You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sr...@apache.org on 2018/08/21 12:31:35 UTC

[flink] branch master updated (07bb90d -> f803280)

This is an automated email from the ASF dual-hosted git repository.

srichter pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git.


    from 07bb90d  [FLINK-10181][rest][docs] Add anchor links to rest requests
     new aba02eb  [FLINK-10042][state] (part 1) Extract snapshot algorithms from inner classes of RocksDBKeyedStateBackend into full classes
     new f803280  [FLINK-10042][state] (part 2) Refactoring of snapshot algorithms for better abstraction and cleaner resource management

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../async/AbstractAsyncCallableWithResources.java  |  194 ----
 .../flink/runtime/io/async/AsyncDoneCallback.java  |   33 -
 .../flink/runtime/io/async/AsyncStoppable.java     |   45 -
 .../io/async/AsyncStoppableTaskWithCallback.java   |   59 --
 .../io/async/StoppableCallbackCallable.java        |   30 -
 .../runtime/state/AbstractSnapshotStrategy.java    |   79 ++
 .../flink/runtime/state/AsyncSnapshotCallable.java |  190 ++++
 .../runtime/state/DefaultOperatorStateBackend.java |  369 +++----
 .../flink/runtime/state/SnapshotStrategy.java      |   12 +-
 .../apache/flink/runtime/state/Snapshotable.java   |   27 +-
 .../runtime/state/heap/HeapKeyedStateBackend.java  |  140 +--
 .../runtime/state/AsyncSnapshotCallableTest.java   |  326 ++++++
 .../runtime/state/OperatorStateBackendTest.java    |    4 +-
 .../flink/runtime/state/StateBackendTestBase.java  |    6 +-
 .../state/ttl/mock/MockKeyedStateBackend.java      |    5 +-
 .../streaming/state/RocksDBKeyedStateBackend.java  | 1098 +++-----------------
 .../snapshot/RocksDBSnapshotStrategyBase.java      |  141 +++
 .../state/snapshot/RocksFullSnapshotStrategy.java  |  421 ++++++++
 .../snapshot/RocksIncrementalSnapshotStrategy.java |  534 ++++++++++
 .../state/snapshot/RocksSnapshotUtil.java          |   37 +-
 .../streaming/state/RocksDBAsyncSnapshotTest.java  |   27 +-
 .../streaming/state/RocksDBStateBackendTest.java   |    1 +
 .../flink/streaming/runtime/tasks/StreamTask.java  |    4 +-
 .../tasks/TaskCheckpointingBehaviourTest.java      |   11 +-
 .../apache/flink/core/testutils/OneShotLatch.java  |   18 +-
 25 files changed, 2138 insertions(+), 1673 deletions(-)
 delete mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AbstractAsyncCallableWithResources.java
 delete mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncDoneCallback.java
 delete mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppable.java
 delete mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppableTaskWithCallback.java
 delete mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/async/StoppableCallbackCallable.java
 create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractSnapshotStrategy.java
 create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncSnapshotCallable.java
 create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncSnapshotCallableTest.java
 create mode 100644 flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksDBSnapshotStrategyBase.java
 create mode 100644 flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
 create mode 100644 flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
 copy flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java => flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksSnapshotUtil.java (52%)


[flink] 01/02: [FLINK-10042][state] (part 1) Extract snapshot algorithms from inner classes of RocksDBKeyedStateBackend into full classes

Posted by sr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit aba02eb3fcc4472c3d5f5a0f527960d79c659c31
Author: Stefan Richter <s....@data-artisans.com>
AuthorDate: Tue Aug 7 15:57:27 2018 +0200

    [FLINK-10042][state] (part 1) Extract snapshot algorithms from inner classes of RocksDBKeyedStateBackend into full classes
---
 .../flink/runtime/state/SnapshotStrategy.java      |    3 +-
 .../runtime/state/heap/HeapKeyedStateBackend.java  |    5 +
 .../streaming/state/RocksDBKeyedStateBackend.java  | 1071 ++------------------
 .../state/snapshot/RocksFullSnapshotStrategy.java  |  478 +++++++++
 .../snapshot/RocksIncrementalSnapshotStrategy.java |  578 +++++++++++
 .../state/snapshot/RocksSnapshotUtil.java          |   51 +
 .../state/snapshot/SnapshotStrategyBase.java       |   90 ++
 .../streaming/state/RocksDBAsyncSnapshotTest.java  |   27 +-
 .../streaming/state/RocksDBStateBackendTest.java   |    1 +
 9 files changed, 1317 insertions(+), 987 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java
index 9139fa7..3ad68af 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java
@@ -28,8 +28,7 @@ import java.util.concurrent.RunnableFuture;
  *
  * @param <S> type of the returned state object that represents the result of the snapshot operation.
  */
-@FunctionalInterface
-public interface SnapshotStrategy<S extends StateObject> {
+public interface SnapshotStrategy<S extends StateObject> extends CheckpointListener {
 
 	/**
 	 * Operation that writes a snapshot into a stream that is provided by the given {@link CheckpointStreamFactory} and
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index bc1e0f5..0e2f16c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -882,6 +882,11 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				}
 			}
 		}
+
+		@Override
+		public void notifyCheckpointComplete(long checkpointId) throws Exception {
+			// nothing to do.
+		}
 	}
 
 	private interface StateFactory {
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index c159976..87c7e55 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -35,9 +35,8 @@ import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerial
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.contrib.streaming.state.iterator.RocksStateKeysIterator;
-import org.apache.flink.contrib.streaming.state.iterator.RocksStatesPerKeyGroupMergeIterator;
-import org.apache.flink.contrib.streaming.state.iterator.RocksTransformingIteratorWrapper;
-import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.contrib.streaming.state.snapshot.RocksFullSnapshotStrategy;
+import org.apache.flink.contrib.streaming.state.snapshot.RocksIncrementalSnapshotStrategy;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileStatus;
@@ -47,32 +46,22 @@ import org.apache.flink.core.memory.ByteArrayDataInputView;
 import org.apache.flink.core.memory.ByteArrayDataOutputView;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.CheckpointType;
-import org.apache.flink.runtime.io.async.AbstractAsyncCallableWithResources;
-import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
-import org.apache.flink.runtime.state.CheckpointedStateScope;
 import org.apache.flink.runtime.state.DirectoryStateHandle;
-import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.IncrementalKeyedStateHandle;
 import org.apache.flink.runtime.state.IncrementalLocalKeyedStateHandle;
 import org.apache.flink.runtime.state.KeyExtractorFunction;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.Keyed;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
-import org.apache.flink.runtime.state.LocalRecoveryDirectoryProvider;
-import org.apache.flink.runtime.state.PlaceholderStreamStateHandle;
 import org.apache.flink.runtime.state.PriorityComparable;
 import org.apache.flink.runtime.state.PriorityComparator;
 import org.apache.flink.runtime.state.PriorityQueueSetFactory;
@@ -80,14 +69,11 @@ import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
-import org.apache.flink.runtime.state.SnapshotDirectory;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateHandleID;
-import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
-import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
@@ -96,25 +82,19 @@ import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import org.apache.flink.runtime.state.heap.KeyGroupPartitionedPriorityQueue;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
-import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.FileUtils;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.ResourceGuard;
 import org.apache.flink.util.StateMigrationException;
-import org.apache.flink.util.function.SupplierWithException;
 
-import org.rocksdb.Checkpoint;
 import org.rocksdb.ColumnFamilyDescriptor;
 import org.rocksdb.ColumnFamilyHandle;
 import org.rocksdb.ColumnFamilyOptions;
 import org.rocksdb.DBOptions;
-import org.rocksdb.ReadOptions;
 import org.rocksdb.RocksDB;
 import org.rocksdb.RocksDBException;
-import org.rocksdb.RocksIterator;
-import org.rocksdb.Snapshot;
 import org.rocksdb.WriteOptions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -125,7 +105,6 @@ import javax.annotation.Nullable;
 import java.io.File;
 import java.io.IOException;
 import java.io.InputStream;
-import java.io.OutputStream;
 import java.nio.file.Files;
 import java.nio.file.StandardCopyOption;
 import java.util.ArrayList;
@@ -144,12 +123,16 @@ import java.util.Spliterator;
 import java.util.Spliterators;
 import java.util.TreeMap;
 import java.util.UUID;
-import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import java.util.stream.StreamSupport;
 
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.END_OF_KEY_GROUP_MARK;
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.SST_FILE_SUFFIX;
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.clearMetaDataFollowsFlag;
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.hasMetaDataFollowsFlag;
+
 /**
  * An {@link AbstractKeyedStateBackend} that stores its state in {@code RocksDB} and serializes state to
  * streams provided by a {@link org.apache.flink.runtime.state.CheckpointStreamFactory} upon
@@ -167,9 +150,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	/** The name of the merge operator in RocksDB. Do not change except you know exactly what you do. */
 	public static final String MERGE_OPERATOR_NAME = "stringappendtest";
 
-	/** File suffix of sstable files. */
-	private static final String SST_FILE_SUFFIX = ".sst";
-
 	private static final Map<Class<? extends StateDescriptor>, StateFactory> STATE_FACTORIES =
 		Stream.of(
 			Tuple2.of(ValueStateDescriptor.class, (StateFactory) RocksDBValueState::create),
@@ -230,7 +210,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * Information about the k/v states as we create them. This is used to retrieve the
 	 * column family that is used for a state and also for sanity checks when restoring.
 	 */
-	private final Map<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> kvStateInformation;
+	private final LinkedHashMap<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> kvStateInformation;
 
 	/**
 	 * Map of state names to their corresponding restored state meta info.
@@ -246,20 +226,11 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	/** True if incremental checkpointing is enabled. */
 	private final boolean enableIncrementalCheckpointing;
 
-	/** The state handle ids of all sst files materialized in snapshots for previous checkpoints. */
-	private final SortedMap<Long, Set<StateHandleID>> materializedSstFiles;
-
-	/** The identifier of the last completed checkpoint. */
-	private long lastCompletedCheckpointId = -1L;
-
-	/** Unique ID of this backend. */
-	private UUID backendUID;
-
 	/** The configuration of local recovery. */
 	private final LocalRecoveryConfig localRecoveryConfig;
 
 	/** The snapshot strategy, e.g., if we use full or incremental checkpoints, local state, and so on. */
-	private final SnapshotStrategy<SnapshotResult<KeyedStateHandle>> snapshotStrategy;
+	private SnapshotStrategy<SnapshotResult<KeyedStateHandle>> snapshotStrategy;
 
 	/** Factory for priority queue state. */
 	private final PriorityQueueSetFactory priorityQueueFactory;
@@ -314,12 +285,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			RocksDBKeySerializationUtils.computeRequiredBytesInKeyGroupPrefix(getNumberOfKeyGroups());
 		this.kvStateInformation = new LinkedHashMap<>();
 		this.restoredKvStateMetaInfos = new HashMap<>();
-		this.materializedSstFiles = new TreeMap<>();
-		this.backendUID = UUID.randomUUID();
-
-		this.snapshotStrategy = enableIncrementalCheckpointing ?
-			new IncrementalSnapshotStrategy() :
-			new FullSnapshotStrategy();
 
 		this.writeOptions = new WriteOptions().setDisableWAL(true);
 
@@ -333,8 +298,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			default:
 				throw new IllegalArgumentException("Unknown priority queue state type: " + priorityQueueStateType);
 		}
-
-		LOG.debug("Setting initial keyed backend uid for operator {} to {}.", this.operatorIdentifier, this.backendUID);
 	}
 
 	private static void checkAndCreateDirectory(File directory) throws IOException {
@@ -508,41 +471,83 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		restoredKvStateMetaInfos.clear();
 
 		try {
+			RocksDBIncrementalRestoreOperation<K> incrementalRestoreOperation = null;
 			if (restoreState == null || restoreState.isEmpty()) {
 				createDB();
 			} else {
 				KeyedStateHandle firstStateHandle = restoreState.iterator().next();
 				if (firstStateHandle instanceof IncrementalKeyedStateHandle
 					|| firstStateHandle instanceof IncrementalLocalKeyedStateHandle) {
-					RocksDBIncrementalRestoreOperation<K> restoreOperation = new RocksDBIncrementalRestoreOperation<>(this);
-					restoreOperation.restore(restoreState);
+					incrementalRestoreOperation = new RocksDBIncrementalRestoreOperation<>(this);
+					incrementalRestoreOperation.restore(restoreState);
 				} else {
-					RocksDBFullRestoreOperation<K> restoreOperation = new RocksDBFullRestoreOperation<>(this);
-					restoreOperation.doRestore(restoreState);
+					RocksDBFullRestoreOperation<K> fullRestoreOperation = new RocksDBFullRestoreOperation<>(this);
+					fullRestoreOperation.doRestore(restoreState);
 				}
 			}
+
+			initializeSnapshotStrategy(incrementalRestoreOperation);
 		} catch (Exception ex) {
 			dispose();
 			throw ex;
 		}
 	}
 
-	@Override
-	public void notifyCheckpointComplete(long completedCheckpointId) {
-
-		if (!enableIncrementalCheckpointing) {
-			return;
-		}
-
-		synchronized (materializedSstFiles) {
-
-			if (completedCheckpointId < lastCompletedCheckpointId) {
-				return;
+	@VisibleForTesting
+	void initializeSnapshotStrategy(
+		@Nullable RocksDBIncrementalRestoreOperation<K> incrementalRestoreOperation) {
+
+		final RocksFullSnapshotStrategy<K> fullSnapshotStrategy =
+			new RocksFullSnapshotStrategy<>(
+				db,
+				rocksDBResourceGuard,
+				keySerializer,
+				kvStateInformation,
+				keyGroupRange,
+				keyGroupPrefixBytes,
+				localRecoveryConfig,
+				cancelStreamRegistry,
+				keyGroupCompressionDecorator);
+
+		if (enableIncrementalCheckpointing) {
+			final UUID backendUID;
+			final SortedMap<Long, Set<StateHandleID>> materializedSstFiles;
+			final long lastCompletedCheckpointId;
+
+			if (incrementalRestoreOperation == null) {
+				backendUID = UUID.randomUUID();
+				materializedSstFiles = new TreeMap<>();
+				lastCompletedCheckpointId = -1L;
+			} else {
+				backendUID = Preconditions.checkNotNull(incrementalRestoreOperation.getRestoredBackendUID());
+				materializedSstFiles = Preconditions.checkNotNull(incrementalRestoreOperation.getRestoredSstFiles());
+				lastCompletedCheckpointId = incrementalRestoreOperation.getLastCompletedCheckpointId();
+				Preconditions.checkState(lastCompletedCheckpointId >= 0L);
 			}
+			// TODO eventually we might want to separate savepoint and snapshot strategy, i.e. having 2 strategies.
+			this.snapshotStrategy = new RocksIncrementalSnapshotStrategy<>(
+				db,
+				rocksDBResourceGuard,
+				keySerializer,
+				kvStateInformation,
+				keyGroupRange,
+				keyGroupPrefixBytes,
+				localRecoveryConfig,
+				cancelStreamRegistry,
+				instanceBasePath,
+				backendUID,
+				materializedSstFiles,
+				lastCompletedCheckpointId,
+				fullSnapshotStrategy);
+		} else {
+			this.snapshotStrategy = fullSnapshotStrategy;
+		}
+	}
 
-			materializedSstFiles.keySet().removeIf(checkpointId -> checkpointId < completedCheckpointId);
-
-			lastCompletedCheckpointId = completedCheckpointId;
+	@Override
+	public void notifyCheckpointComplete(long completedCheckpointId) throws Exception {
+		if (snapshotStrategy != null) {
+			snapshotStrategy.notifyCheckpointComplete(completedCheckpointId);
 		}
 	}
 
@@ -656,10 +661,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		/**
 		 * Restore the KV-state / ColumnFamily meta data for all key-groups referenced by the current state handle.
-		 *
-		 * @throws IOException
-		 * @throws ClassNotFoundException
-		 * @throws RocksDBException
 		 */
 		private void restoreKVStateMetaData() throws IOException, StateMigrationException, RocksDBException {
 
@@ -724,9 +725,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		/**
 		 * Restore the KV-state / ColumnFamily data for all key-groups referenced by the current state handle.
-		 *
-		 * @throws IOException
-		 * @throws RocksDBException
 		 */
 		private void restoreKVStateData() throws IOException, RocksDBException {
 			//for all key-groups in the current state handle...
@@ -752,14 +750,14 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 							while (keyGroupHasMoreKeys) {
 								byte[] key = BytePrimitiveArraySerializer.INSTANCE.deserialize(compressedKgInputView);
 								byte[] value = BytePrimitiveArraySerializer.INSTANCE.deserialize(compressedKgInputView);
-								if (RocksDBFullSnapshotOperation.hasMetaDataFollowsFlag(key)) {
+								if (hasMetaDataFollowsFlag(key)) {
 									//clear the signal bit in the key to make it ready for insertion again
-									RocksDBFullSnapshotOperation.clearMetaDataFollowsFlag(key);
+									clearMetaDataFollowsFlag(key);
 									writeBatchWrapper.put(handle, key, value);
 									//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
-									kvStateId = RocksDBFullSnapshotOperation.END_OF_KEY_GROUP_MARK
+									kvStateId = END_OF_KEY_GROUP_MARK
 										& compressedKgInputView.readShort();
-									if (RocksDBFullSnapshotOperation.END_OF_KEY_GROUP_MARK == kvStateId) {
+									if (END_OF_KEY_GROUP_MARK == kvStateId) {
 										keyGroupHasMoreKeys = false;
 									} else {
 										handle = currentStateHandleKVStateColumnFamilies.get(kvStateId);
@@ -781,9 +779,26 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	private static class RocksDBIncrementalRestoreOperation<T> {
 
 		private final RocksDBKeyedStateBackend<T> stateBackend;
+		private final SortedMap<Long, Set<StateHandleID>> restoredSstFiles;
+		private UUID restoredBackendUID;
+		private long lastCompletedCheckpointId;
 
 		private RocksDBIncrementalRestoreOperation(RocksDBKeyedStateBackend<T> stateBackend) {
+
 			this.stateBackend = stateBackend;
+			this.restoredSstFiles = new TreeMap<>();
+		}
+
+		SortedMap<Long, Set<StateHandleID>> getRestoredSstFiles() {
+			return restoredSstFiles;
+		}
+
+		UUID getRestoredBackendUID() {
+			return restoredBackendUID;
+		}
+
+		long getLastCompletedCheckpointId() {
+			return lastCompletedCheckpointId;
 		}
 
 		/**
@@ -872,6 +887,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		 */
 		void restoreWithRescaling(Collection<KeyedStateHandle> restoreStateHandles) throws Exception {
 
+			this.restoredBackendUID = UUID.randomUUID();
+
 			initTargetDB(restoreStateHandles, stateBackend.keyGroupRange);
 
 			byte[] startKeyGroupPrefixBytes = new byte[stateBackend.keyGroupPrefixBytes];
@@ -949,6 +966,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			@Nonnull
 			private final List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
 
+			private
+
 			RestoredDBInstance(
 				@Nonnull RocksDB db,
 				@Nonnull List<ColumnFamilyHandle> columnFamilyHandles,
@@ -1113,10 +1132,10 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			List<ColumnFamilyDescriptor> columnFamilyDescriptors,
 			List<StateMetaInfoSnapshot> stateMetaInfoSnapshots) throws Exception {
 			// pick up again the old backend id, so the we can reference existing state
-			stateBackend.backendUID = restoreStateHandle.getBackendIdentifier();
+			this.restoredBackendUID = restoreStateHandle.getBackendIdentifier();
 
 			LOG.debug("Restoring keyed backend uid in operator {} from incremental snapshot to {}.",
-				stateBackend.operatorIdentifier, stateBackend.backendUID);
+				stateBackend.operatorIdentifier, this.restoredBackendUID);
 
 			// create hard links in the instance directory
 			if (!stateBackend.instanceRocksDBPath.mkdirs()) {
@@ -1150,13 +1169,11 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			}
 
 			// use the restore sst files as the base for succeeding checkpoints
-			synchronized (stateBackend.materializedSstFiles) {
-				stateBackend.materializedSstFiles.put(
+				restoredSstFiles.put(
 					restoreStateHandle.getCheckpointId(),
 					restoreStateHandle.getSharedStateHandleIDs());
-			}
 
-			stateBackend.lastCompletedCheckpointId = restoreStateHandle.getCheckpointId();
+			lastCompletedCheckpointId = restoreStateHandle.getCheckpointId();
 		}
 
 		/**
@@ -1447,881 +1464,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		return count;
 	}
 
-	private class FullSnapshotStrategy implements SnapshotStrategy<SnapshotResult<KeyedStateHandle>> {
-
-		@Override
-		public RunnableFuture<SnapshotResult<KeyedStateHandle>> performSnapshot(
-			long checkpointId,
-			long timestamp,
-			CheckpointStreamFactory primaryStreamFactory,
-			CheckpointOptions checkpointOptions) throws Exception {
-
-			long startTime = System.currentTimeMillis();
-
-			if (kvStateInformation.isEmpty()) {
-				if (LOG.isDebugEnabled()) {
-					LOG.debug("Asynchronous RocksDB snapshot performed on empty keyed state at {}. Returning null.",
-						timestamp);
-				}
-
-				return DoneFuture.of(SnapshotResult.empty());
-			}
-
-			final SupplierWithException<CheckpointStreamWithResultProvider, Exception> supplier =
-
-				localRecoveryConfig.isLocalRecoveryEnabled() &&
-					(CheckpointType.SAVEPOINT != checkpointOptions.getCheckpointType()) ?
-
-					() -> CheckpointStreamWithResultProvider.createDuplicatingStream(
-						checkpointId,
-						CheckpointedStateScope.EXCLUSIVE,
-						primaryStreamFactory,
-						localRecoveryConfig.getLocalStateDirectoryProvider()) :
-
-					() -> CheckpointStreamWithResultProvider.createSimpleStream(
-						CheckpointedStateScope.EXCLUSIVE,
-						primaryStreamFactory);
-
-			final CloseableRegistry snapshotCloseableRegistry = new CloseableRegistry();
-
-			final RocksDBFullSnapshotOperation<K> snapshotOperation =
-				new RocksDBFullSnapshotOperation<>(
-					RocksDBKeyedStateBackend.this,
-					supplier,
-					snapshotCloseableRegistry);
-
-			snapshotOperation.takeDBSnapShot();
-
-			// implementation of the async IO operation, based on FutureTask
-			AbstractAsyncCallableWithResources<SnapshotResult<KeyedStateHandle>> ioCallable =
-				new AbstractAsyncCallableWithResources<SnapshotResult<KeyedStateHandle>>() {
-
-					@Override
-					protected void acquireResources() throws Exception {
-						cancelStreamRegistry.registerCloseable(snapshotCloseableRegistry);
-						snapshotOperation.openCheckpointStream();
-					}
-
-					@Override
-					protected void releaseResources() throws Exception {
-						closeLocalRegistry();
-						releaseSnapshotOperationResources();
-					}
-
-					private void releaseSnapshotOperationResources() {
-						// hold the db lock while operation on the db to guard us against async db disposal
-						snapshotOperation.releaseSnapshotResources();
-					}
-
-					@Override
-					protected void stopOperation() throws Exception {
-						closeLocalRegistry();
-					}
-
-					private void closeLocalRegistry() {
-						if (cancelStreamRegistry.unregisterCloseable(snapshotCloseableRegistry)) {
-							try {
-								snapshotCloseableRegistry.close();
-							} catch (Exception ex) {
-								LOG.warn("Error closing local registry", ex);
-							}
-						}
-					}
-
-					@Nonnull
-					@Override
-					public SnapshotResult<KeyedStateHandle> performOperation() throws Exception {
-						long startTime = System.currentTimeMillis();
-
-						if (isStopped()) {
-							throw new IOException("RocksDB closed.");
-						}
-
-						snapshotOperation.writeDBSnapshot();
-
-						LOG.debug("Asynchronous RocksDB snapshot ({}, asynchronous part) in thread {} took {} ms.",
-							primaryStreamFactory, Thread.currentThread(), (System.currentTimeMillis() - startTime));
-
-						return snapshotOperation.getSnapshotResultStateHandle();
-					}
-				};
-
-			LOG.debug("Asynchronous RocksDB snapshot ({}, synchronous part) in thread {} took {} ms.",
-				primaryStreamFactory, Thread.currentThread(), (System.currentTimeMillis() - startTime));
-			return AsyncStoppableTaskWithCallback.from(ioCallable);
-		}
-	}
-
-	private class IncrementalSnapshotStrategy implements SnapshotStrategy<SnapshotResult<KeyedStateHandle>> {
-
-		private final SnapshotStrategy<SnapshotResult<KeyedStateHandle>> savepointDelegate;
-
-		public IncrementalSnapshotStrategy() {
-			this.savepointDelegate = new FullSnapshotStrategy();
-		}
-
-		@Override
-		public RunnableFuture<SnapshotResult<KeyedStateHandle>> performSnapshot(
-			long checkpointId,
-			long checkpointTimestamp,
-			CheckpointStreamFactory checkpointStreamFactory,
-			CheckpointOptions checkpointOptions) throws Exception {
-
-			// for savepoints, we delegate to the full snapshot strategy because savepoints are always self-contained.
-			if (CheckpointType.SAVEPOINT == checkpointOptions.getCheckpointType()) {
-				return savepointDelegate.performSnapshot(
-					checkpointId,
-					checkpointTimestamp,
-					checkpointStreamFactory,
-					checkpointOptions);
-			}
-
-			if (db == null) {
-				throw new IOException("RocksDB closed.");
-			}
-
-			if (kvStateInformation.isEmpty()) {
-				if (LOG.isDebugEnabled()) {
-					LOG.debug("Asynchronous RocksDB snapshot performed on empty keyed state at {}. Returning null.", checkpointTimestamp);
-				}
-				return DoneFuture.of(SnapshotResult.empty());
-			}
-
-			SnapshotDirectory snapshotDirectory;
-
-			if (localRecoveryConfig.isLocalRecoveryEnabled()) {
-				// create a "permanent" snapshot directory for local recovery.
-				LocalRecoveryDirectoryProvider directoryProvider = localRecoveryConfig.getLocalStateDirectoryProvider();
-				File directory = directoryProvider.subtaskSpecificCheckpointDirectory(checkpointId);
-
-				if (directory.exists()) {
-					FileUtils.deleteDirectory(directory);
-				}
-
-				if (!directory.mkdirs()) {
-					throw new IOException("Local state base directory for checkpoint " + checkpointId +
-						" already exists: " + directory);
-				}
-
-				// introduces an extra directory because RocksDB wants a non-existing directory for native checkpoints.
-				File rdbSnapshotDir = new File(directory, "rocks_db");
-				Path path = new Path(rdbSnapshotDir.toURI());
-				// create a "permanent" snapshot directory because local recovery is active.
-				snapshotDirectory = SnapshotDirectory.permanent(path);
-			} else {
-				// create a "temporary" snapshot directory because local recovery is inactive.
-				Path path = new Path(instanceBasePath.getAbsolutePath(), "chk-" + checkpointId);
-				snapshotDirectory = SnapshotDirectory.temporary(path);
-			}
-
-			final RocksDBIncrementalSnapshotOperation<K> snapshotOperation =
-				new RocksDBIncrementalSnapshotOperation<>(
-					RocksDBKeyedStateBackend.this,
-					checkpointStreamFactory,
-					snapshotDirectory,
-					checkpointId);
-
-			try {
-				snapshotOperation.takeSnapshot();
-			} catch (Exception e) {
-				snapshotOperation.stop();
-				snapshotOperation.releaseResources(true);
-				throw e;
-			}
-
-			return new FutureTask<SnapshotResult<KeyedStateHandle>>(
-				snapshotOperation::runSnapshot
-			) {
-				@Override
-				public boolean cancel(boolean mayInterruptIfRunning) {
-					snapshotOperation.stop();
-					return super.cancel(mayInterruptIfRunning);
-				}
-
-				@Override
-				protected void done() {
-					snapshotOperation.releaseResources(isCancelled());
-				}
-			};
-		}
-	}
-
-	/**
-	 * Encapsulates the process to perform a full snapshot of a RocksDBKeyedStateBackend.
-	 */
-	@VisibleForTesting
-	static class RocksDBFullSnapshotOperation<K>
-		extends AbstractAsyncCallableWithResources<SnapshotResult<KeyedStateHandle>> {
-
-		static final int FIRST_BIT_IN_BYTE_MASK = 0x80;
-		static final int END_OF_KEY_GROUP_MARK = 0xFFFF;
-
-		private final RocksDBKeyedStateBackend<K> stateBackend;
-		private final KeyGroupRangeOffsets keyGroupRangeOffsets;
-		private final SupplierWithException<CheckpointStreamWithResultProvider, Exception> checkpointStreamSupplier;
-		private final CloseableRegistry snapshotCloseableRegistry;
-		private final ResourceGuard.Lease dbLease;
-
-		private Snapshot snapshot;
-		private ReadOptions readOptions;
-
-		/**
-		 * The state meta data.
-		 */
-		private List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
-
-		/**
-		 * The copied column handle.
-		 */
-		private List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> copiedMeta;
-
-		private List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators;
-
-		private CheckpointStreamWithResultProvider checkpointStreamWithResultProvider;
-		private DataOutputView outputView;
-
-		RocksDBFullSnapshotOperation(
-			RocksDBKeyedStateBackend<K> stateBackend,
-			SupplierWithException<CheckpointStreamWithResultProvider, Exception> checkpointStreamSupplier,
-			CloseableRegistry registry) throws IOException {
-
-			this.stateBackend = stateBackend;
-			this.checkpointStreamSupplier = checkpointStreamSupplier;
-			this.keyGroupRangeOffsets = new KeyGroupRangeOffsets(stateBackend.keyGroupRange);
-			this.snapshotCloseableRegistry = registry;
-			this.dbLease = this.stateBackend.rocksDBResourceGuard.acquireResource();
-		}
-
-		/**
-		 * 1) Create a snapshot object from RocksDB.
-		 *
-		 */
-		public void takeDBSnapShot() {
-			Preconditions.checkArgument(snapshot == null, "Only one ongoing snapshot allowed!");
-
-			this.stateMetaInfoSnapshots = new ArrayList<>(stateBackend.kvStateInformation.size());
-
-			this.copiedMeta = new ArrayList<>(stateBackend.kvStateInformation.size());
-
-			for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 :
-				stateBackend.kvStateInformation.values()) {
-				// snapshot meta info
-				this.stateMetaInfoSnapshots.add(tuple2.f1.snapshot());
-				this.copiedMeta.add(tuple2);
-			}
-			this.snapshot = stateBackend.db.getSnapshot();
-		}
-
-		/**
-		 * 2) Open CheckpointStateOutputStream through the checkpointStreamFactory into which we will write.
-		 *
-		 * @throws Exception
-		 */
-		public void openCheckpointStream() throws Exception {
-			Preconditions.checkArgument(checkpointStreamWithResultProvider == null,
-				"Output stream for snapshot is already set.");
-
-			checkpointStreamWithResultProvider = checkpointStreamSupplier.get();
-			snapshotCloseableRegistry.registerCloseable(checkpointStreamWithResultProvider);
-			outputView = new DataOutputViewStreamWrapper(
-				checkpointStreamWithResultProvider.getCheckpointOutputStream());
-		}
-
-		/**
-		 * 3) Write the actual data from RocksDB from the time we took the snapshot object in (1).
-		 *
-		 * @throws IOException
-		 */
-		public void writeDBSnapshot() throws IOException, InterruptedException, RocksDBException {
-
-			if (null == snapshot) {
-				throw new IOException("No snapshot available. Might be released due to cancellation.");
-			}
-
-			Preconditions.checkNotNull(checkpointStreamWithResultProvider, "No output stream to write snapshot.");
-			writeKVStateMetaData();
-			writeKVStateData();
-		}
-
-		/**
-		 * 4) Returns a snapshot result for the completed snapshot.
-		 *
-		 * @return snapshot result for the completed snapshot.
-		 */
-		@Nonnull
-		public SnapshotResult<KeyedStateHandle> getSnapshotResultStateHandle() throws IOException {
-
-			if (snapshotCloseableRegistry.unregisterCloseable(checkpointStreamWithResultProvider)) {
-
-				SnapshotResult<StreamStateHandle> res =
-					checkpointStreamWithResultProvider.closeAndFinalizeCheckpointStreamResult();
-				checkpointStreamWithResultProvider = null;
-				return CheckpointStreamWithResultProvider.toKeyedStateHandleSnapshotResult(res, keyGroupRangeOffsets);
-			}
-
-			return SnapshotResult.empty();
-		}
-
-		/**
-		 * 5) Release the snapshot object for RocksDB and clean up.
-		 */
-		public void releaseSnapshotResources() {
-
-			checkpointStreamWithResultProvider = null;
-
-			if (null != kvStateIterators) {
-				for (Tuple2<RocksIteratorWrapper, Integer> kvStateIterator : kvStateIterators) {
-					IOUtils.closeQuietly(kvStateIterator.f0);
-				}
-				kvStateIterators = null;
-			}
-
-			if (null != snapshot) {
-				if (null != stateBackend.db) {
-					stateBackend.db.releaseSnapshot(snapshot);
-				}
-				IOUtils.closeQuietly(snapshot);
-				snapshot = null;
-			}
-
-			if (null != readOptions) {
-				IOUtils.closeQuietly(readOptions);
-				readOptions = null;
-			}
-
-			this.dbLease.close();
-		}
-
-		private void writeKVStateMetaData() throws IOException {
-
-			this.kvStateIterators = new ArrayList<>(copiedMeta.size());
-
-			int kvStateId = 0;
-
-			//retrieve iterator for this k/v states
-			readOptions = new ReadOptions();
-			readOptions.setSnapshot(snapshot);
-
-			for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 : copiedMeta) {
-				RocksIteratorWrapper rocksIteratorWrapper =
-					getRocksIterator(stateBackend.db, tuple2.f0, tuple2.f1, readOptions);
-				kvStateIterators.add(new Tuple2<>(rocksIteratorWrapper, kvStateId));
-				++kvStateId;
-			}
-
-			KeyedBackendSerializationProxy<K> serializationProxy =
-				new KeyedBackendSerializationProxy<>(
-					// TODO: this code assumes that writing a serializer is threadsafe, we should support to
-					// get a serialized form already at state registration time in the future
-					stateBackend.getKeySerializer(),
-					stateMetaInfoSnapshots,
-					!Objects.equals(
-						UncompressedStreamCompressionDecorator.INSTANCE,
-						stateBackend.keyGroupCompressionDecorator));
-
-			serializationProxy.write(outputView);
-		}
-
-		private void writeKVStateData() throws IOException, InterruptedException {
-			byte[] previousKey = null;
-			byte[] previousValue = null;
-			DataOutputView kgOutView = null;
-			OutputStream kgOutStream = null;
-			CheckpointStreamFactory.CheckpointStateOutputStream checkpointOutputStream =
-				checkpointStreamWithResultProvider.getCheckpointOutputStream();
-
-			try {
-				// Here we transfer ownership of RocksIterators to the RocksStatesPerKeyGroupMergeIterator
-				try (RocksStatesPerKeyGroupMergeIterator mergeIterator = new RocksStatesPerKeyGroupMergeIterator(
-					kvStateIterators, stateBackend.keyGroupPrefixBytes)) {
-
-					// handover complete, null out to prevent double close
-					kvStateIterators = null;
-
-					//preamble: setup with first key-group as our lookahead
-					if (mergeIterator.isValid()) {
-						//begin first key-group by recording the offset
-						keyGroupRangeOffsets.setKeyGroupOffset(
-							mergeIterator.keyGroup(),
-							checkpointOutputStream.getPos());
-						//write the k/v-state id as metadata
-						kgOutStream = stateBackend.keyGroupCompressionDecorator.
-							decorateWithCompression(checkpointOutputStream);
-						kgOutView = new DataOutputViewStreamWrapper(kgOutStream);
-						//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
-						kgOutView.writeShort(mergeIterator.kvStateId());
-						previousKey = mergeIterator.key();
-						previousValue = mergeIterator.value();
-						mergeIterator.next();
-					}
-
-					//main loop: write k/v pairs ordered by (key-group, kv-state), thereby tracking key-group offsets.
-					while (mergeIterator.isValid()) {
-
-						assert (!hasMetaDataFollowsFlag(previousKey));
-
-						//set signal in first key byte that meta data will follow in the stream after this k/v pair
-						if (mergeIterator.isNewKeyGroup() || mergeIterator.isNewKeyValueState()) {
-
-							//be cooperative and check for interruption from time to time in the hot loop
-							checkInterrupted();
-
-							setMetaDataFollowsFlagInKey(previousKey);
-						}
-
-						writeKeyValuePair(previousKey, previousValue, kgOutView);
-
-						//write meta data if we have to
-						if (mergeIterator.isNewKeyGroup()) {
-							//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
-							kgOutView.writeShort(END_OF_KEY_GROUP_MARK);
-							// this will just close the outer stream
-							kgOutStream.close();
-							//begin new key-group
-							keyGroupRangeOffsets.setKeyGroupOffset(
-								mergeIterator.keyGroup(),
-								checkpointOutputStream.getPos());
-							//write the kev-state
-							//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
-							kgOutStream = stateBackend.keyGroupCompressionDecorator.
-								decorateWithCompression(checkpointOutputStream);
-							kgOutView = new DataOutputViewStreamWrapper(kgOutStream);
-							kgOutView.writeShort(mergeIterator.kvStateId());
-						} else if (mergeIterator.isNewKeyValueState()) {
-							//write the k/v-state
-							//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
-							kgOutView.writeShort(mergeIterator.kvStateId());
-						}
-
-						//request next k/v pair
-						previousKey = mergeIterator.key();
-						previousValue = mergeIterator.value();
-						mergeIterator.next();
-					}
-				}
-
-				//epilogue: write last key-group
-				if (previousKey != null) {
-					assert (!hasMetaDataFollowsFlag(previousKey));
-					setMetaDataFollowsFlagInKey(previousKey);
-					writeKeyValuePair(previousKey, previousValue, kgOutView);
-					//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
-					kgOutView.writeShort(END_OF_KEY_GROUP_MARK);
-					// this will just close the outer stream
-					kgOutStream.close();
-					kgOutStream = null;
-				}
-
-			} finally {
-				// this will just close the outer stream
-				IOUtils.closeQuietly(kgOutStream);
-			}
-		}
-
-		private void writeKeyValuePair(byte[] key, byte[] value, DataOutputView out) throws IOException {
-			BytePrimitiveArraySerializer.INSTANCE.serialize(key, out);
-			BytePrimitiveArraySerializer.INSTANCE.serialize(value, out);
-		}
-
-		static void setMetaDataFollowsFlagInKey(byte[] key) {
-			key[0] |= FIRST_BIT_IN_BYTE_MASK;
-		}
-
-		static void clearMetaDataFollowsFlag(byte[] key) {
-			key[0] &= (~RocksDBFullSnapshotOperation.FIRST_BIT_IN_BYTE_MASK);
-		}
-
-		static boolean hasMetaDataFollowsFlag(byte[] key) {
-			return 0 != (key[0] & RocksDBFullSnapshotOperation.FIRST_BIT_IN_BYTE_MASK);
-		}
-
-		private static void checkInterrupted() throws InterruptedException {
-			if (Thread.currentThread().isInterrupted()) {
-				throw new InterruptedException("RocksDB snapshot interrupted.");
-			}
-		}
-
-		@Override
-		protected void acquireResources() throws Exception {
-			stateBackend.cancelStreamRegistry.registerCloseable(snapshotCloseableRegistry);
-			openCheckpointStream();
-		}
-
-		@Override
-		protected void releaseResources() {
-			closeLocalRegistry();
-			releaseSnapshotOperationResources();
-		}
-
-		private void releaseSnapshotOperationResources() {
-			// hold the db lock while operation on the db to guard us against async db disposal
-			releaseSnapshotResources();
-		}
-
-		@Override
-		protected void stopOperation() {
-			closeLocalRegistry();
-		}
-
-		private void closeLocalRegistry() {
-			if (stateBackend.cancelStreamRegistry.unregisterCloseable(snapshotCloseableRegistry)) {
-				try {
-					snapshotCloseableRegistry.close();
-				} catch (Exception ex) {
-					LOG.warn("Error closing local registry", ex);
-				}
-			}
-		}
-
-		@Nonnull
-		@Override
-		public SnapshotResult<KeyedStateHandle> performOperation() throws Exception {
-			long startTime = System.currentTimeMillis();
-
-			if (isStopped()) {
-				throw new IOException("RocksDB closed.");
-			}
-
-			writeDBSnapshot();
-
-			LOG.debug("Asynchronous RocksDB snapshot ({}, asynchronous part) in thread {} took {} ms.",
-				checkpointStreamSupplier, Thread.currentThread(), (System.currentTimeMillis() - startTime));
-
-			return getSnapshotResultStateHandle();
-		}
-	}
-
-	/**
-	 * Encapsulates the process to perform an incremental snapshot of a RocksDBKeyedStateBackend.
-	 */
-	private static final class RocksDBIncrementalSnapshotOperation<K> {
-
-		/** The backend which we snapshot. */
-		private final RocksDBKeyedStateBackend<K> stateBackend;
-
-		/** Stream factory that creates the outpus streams to DFS. */
-		private final CheckpointStreamFactory checkpointStreamFactory;
-
-		/** Id for the current checkpoint. */
-		private final long checkpointId;
-
-		/** All sst files that were part of the last previously completed checkpoint. */
-		private Set<StateHandleID> baseSstFiles;
-
-		/** The state meta data. */
-		private final List<StateMetaInfoSnapshot> stateMetaInfoSnapshots = new ArrayList<>();
-
-		/** Local directory for the RocksDB native backup. */
-		private SnapshotDirectory localBackupDirectory;
-
-		// Registry for all opened i/o streams
-		private final CloseableRegistry closeableRegistry = new CloseableRegistry();
-
-		// new sst files since the last completed checkpoint
-		private final Map<StateHandleID, StreamStateHandle> sstFiles = new HashMap<>();
-
-		// handles to the misc files in the current snapshot
-		private final Map<StateHandleID, StreamStateHandle> miscFiles = new HashMap<>();
-
-		// This lease protects from concurrent disposal of the native rocksdb instance.
-		private final ResourceGuard.Lease dbLease;
-
-		private SnapshotResult<StreamStateHandle> metaStateHandle = null;
-
-		private RocksDBIncrementalSnapshotOperation(
-			RocksDBKeyedStateBackend<K> stateBackend,
-			CheckpointStreamFactory checkpointStreamFactory,
-			SnapshotDirectory localBackupDirectory,
-			long checkpointId) throws IOException {
-
-			this.stateBackend = stateBackend;
-			this.checkpointStreamFactory = checkpointStreamFactory;
-			this.checkpointId = checkpointId;
-			this.dbLease = this.stateBackend.rocksDBResourceGuard.acquireResource();
-			this.localBackupDirectory = localBackupDirectory;
-		}
-
-		private StreamStateHandle materializeStateData(Path filePath) throws Exception {
-			FSDataInputStream inputStream = null;
-			CheckpointStreamFactory.CheckpointStateOutputStream outputStream = null;
-
-			try {
-				final byte[] buffer = new byte[8 * 1024];
-
-				FileSystem backupFileSystem = localBackupDirectory.getFileSystem();
-				inputStream = backupFileSystem.open(filePath);
-				closeableRegistry.registerCloseable(inputStream);
-
-				outputStream = checkpointStreamFactory
-					.createCheckpointStateOutputStream(CheckpointedStateScope.SHARED);
-				closeableRegistry.registerCloseable(outputStream);
-
-				while (true) {
-					int numBytes = inputStream.read(buffer);
-
-					if (numBytes == -1) {
-						break;
-					}
-
-					outputStream.write(buffer, 0, numBytes);
-				}
-
-				StreamStateHandle result = null;
-				if (closeableRegistry.unregisterCloseable(outputStream)) {
-					result = outputStream.closeAndGetHandle();
-					outputStream = null;
-				}
-				return result;
-
-			} finally {
-
-				if (closeableRegistry.unregisterCloseable(inputStream)) {
-					inputStream.close();
-				}
-
-				if (closeableRegistry.unregisterCloseable(outputStream)) {
-					outputStream.close();
-				}
-			}
-		}
-
-		@Nonnull
-		private SnapshotResult<StreamStateHandle> materializeMetaData() throws Exception {
-
-			LocalRecoveryConfig localRecoveryConfig = stateBackend.localRecoveryConfig;
-
-			CheckpointStreamWithResultProvider streamWithResultProvider =
-
-				localRecoveryConfig.isLocalRecoveryEnabled() ?
-
-					CheckpointStreamWithResultProvider.createDuplicatingStream(
-						checkpointId,
-						CheckpointedStateScope.EXCLUSIVE,
-						checkpointStreamFactory,
-						localRecoveryConfig.getLocalStateDirectoryProvider()) :
-
-					CheckpointStreamWithResultProvider.createSimpleStream(
-						CheckpointedStateScope.EXCLUSIVE,
-						checkpointStreamFactory);
-
-			try {
-				closeableRegistry.registerCloseable(streamWithResultProvider);
-
-				//no need for compression scheme support because sst-files are already compressed
-				KeyedBackendSerializationProxy<K> serializationProxy =
-					new KeyedBackendSerializationProxy<>(
-						stateBackend.keySerializer,
-						stateMetaInfoSnapshots,
-						false);
-
-				DataOutputView out =
-					new DataOutputViewStreamWrapper(streamWithResultProvider.getCheckpointOutputStream());
-
-				serializationProxy.write(out);
-
-				if (closeableRegistry.unregisterCloseable(streamWithResultProvider)) {
-					SnapshotResult<StreamStateHandle> result =
-						streamWithResultProvider.closeAndFinalizeCheckpointStreamResult();
-					streamWithResultProvider = null;
-					return result;
-				} else {
-					throw new IOException("Stream already closed and cannot return a handle.");
-				}
-			} finally {
-				if (streamWithResultProvider != null) {
-					if (closeableRegistry.unregisterCloseable(streamWithResultProvider)) {
-						IOUtils.closeQuietly(streamWithResultProvider);
-					}
-				}
-			}
-		}
-
-		void takeSnapshot() throws Exception {
-
-			final long lastCompletedCheckpoint;
-
-			// use the last completed checkpoint as the comparison base.
-			synchronized (stateBackend.materializedSstFiles) {
-				lastCompletedCheckpoint = stateBackend.lastCompletedCheckpointId;
-				baseSstFiles = stateBackend.materializedSstFiles.get(lastCompletedCheckpoint);
-			}
-
-			LOG.trace("Taking incremental snapshot for checkpoint {}. Snapshot is based on last completed checkpoint {} " +
-				"assuming the following (shared) files as base: {}.", checkpointId, lastCompletedCheckpoint, baseSstFiles);
-
-			// save meta data
-			for (Map.Entry<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> stateMetaInfoEntry
-				: stateBackend.kvStateInformation.entrySet()) {
-				stateMetaInfoSnapshots.add(stateMetaInfoEntry.getValue().f1.snapshot());
-			}
-
-			LOG.trace("Local RocksDB checkpoint goes to backup path {}.", localBackupDirectory);
-
-			if (localBackupDirectory.exists()) {
-				throw new IllegalStateException("Unexpected existence of the backup directory.");
-			}
-
-			// create hard links of living files in the snapshot path
-			try (Checkpoint checkpoint = Checkpoint.create(stateBackend.db)) {
-				checkpoint.createCheckpoint(localBackupDirectory.getDirectory().getPath());
-			}
-		}
-
-		@Nonnull
-		SnapshotResult<KeyedStateHandle> runSnapshot() throws Exception {
-
-			stateBackend.cancelStreamRegistry.registerCloseable(closeableRegistry);
-
-			// write meta data
-			metaStateHandle = materializeMetaData();
-
-			// sanity checks - they should never fail
-			Preconditions.checkNotNull(metaStateHandle,
-				"Metadata was not properly created.");
-			Preconditions.checkNotNull(metaStateHandle.getJobManagerOwnedSnapshot(),
-				"Metadata for job manager was not properly created.");
-
-			// write state data
-			Preconditions.checkState(localBackupDirectory.exists());
-
-			FileStatus[] fileStatuses = localBackupDirectory.listStatus();
-			if (fileStatuses != null) {
-				for (FileStatus fileStatus : fileStatuses) {
-					final Path filePath = fileStatus.getPath();
-					final String fileName = filePath.getName();
-					final StateHandleID stateHandleID = new StateHandleID(fileName);
-
-					if (fileName.endsWith(SST_FILE_SUFFIX)) {
-						final boolean existsAlready =
-							baseSstFiles != null && baseSstFiles.contains(stateHandleID);
-
-						if (existsAlready) {
-							// we introduce a placeholder state handle, that is replaced with the
-							// original from the shared state registry (created from a previous checkpoint)
-							sstFiles.put(
-								stateHandleID,
-								new PlaceholderStreamStateHandle());
-						} else {
-							sstFiles.put(stateHandleID, materializeStateData(filePath));
-						}
-					} else {
-						StreamStateHandle fileHandle = materializeStateData(filePath);
-						miscFiles.put(stateHandleID, fileHandle);
-					}
-				}
-			}
-
-			synchronized (stateBackend.materializedSstFiles) {
-				stateBackend.materializedSstFiles.put(checkpointId, sstFiles.keySet());
-			}
-
-			IncrementalKeyedStateHandle jmIncrementalKeyedStateHandle = new IncrementalKeyedStateHandle(
-				stateBackend.backendUID,
-				stateBackend.keyGroupRange,
-				checkpointId,
-				sstFiles,
-				miscFiles,
-				metaStateHandle.getJobManagerOwnedSnapshot());
-
-			StreamStateHandle taskLocalSnapshotMetaDataStateHandle = metaStateHandle.getTaskLocalSnapshot();
-			DirectoryStateHandle directoryStateHandle = null;
-
-			try {
-
-				directoryStateHandle = localBackupDirectory.completeSnapshotAndGetHandle();
-			} catch (IOException ex) {
-
-				Exception collector = ex;
-
-				try {
-					taskLocalSnapshotMetaDataStateHandle.discardState();
-				} catch (Exception discardEx) {
-					collector = ExceptionUtils.firstOrSuppressed(discardEx, collector);
-				}
-
-				LOG.warn("Problem with local state snapshot.", collector);
-			}
-
-			if (directoryStateHandle != null && taskLocalSnapshotMetaDataStateHandle != null) {
-
-				IncrementalLocalKeyedStateHandle localDirKeyedStateHandle =
-					new IncrementalLocalKeyedStateHandle(
-						stateBackend.backendUID,
-						checkpointId,
-						directoryStateHandle,
-						stateBackend.keyGroupRange,
-						taskLocalSnapshotMetaDataStateHandle,
-						sstFiles.keySet());
-				return SnapshotResult.withLocalState(jmIncrementalKeyedStateHandle, localDirKeyedStateHandle);
-			} else {
-				return SnapshotResult.of(jmIncrementalKeyedStateHandle);
-			}
-		}
-
-		void stop() {
-
-			if (stateBackend.cancelStreamRegistry.unregisterCloseable(closeableRegistry)) {
-				try {
-					closeableRegistry.close();
-				} catch (IOException e) {
-					LOG.warn("Could not properly close io streams.", e);
-				}
-			}
-		}
-
-		void releaseResources(boolean canceled) {
-
-			dbLease.close();
-
-			if (stateBackend.cancelStreamRegistry.unregisterCloseable(closeableRegistry)) {
-				try {
-					closeableRegistry.close();
-				} catch (IOException e) {
-					LOG.warn("Exception on closing registry.", e);
-				}
-			}
-
-			try {
-				if (localBackupDirectory.exists()) {
-					LOG.trace("Running cleanup for local RocksDB backup directory {}.", localBackupDirectory);
-					boolean cleanupOk = localBackupDirectory.cleanup();
-
-					if (!cleanupOk) {
-						LOG.debug("Could not properly cleanup local RocksDB backup directory.");
-					}
-				}
-			} catch (IOException e) {
-				LOG.warn("Could not properly cleanup local RocksDB backup directory.", e);
-			}
-
-			if (canceled) {
-				Collection<StateObject> statesToDiscard =
-					new ArrayList<>(1 + miscFiles.size() + sstFiles.size());
-
-				statesToDiscard.add(metaStateHandle);
-				statesToDiscard.addAll(miscFiles.values());
-				statesToDiscard.addAll(sstFiles.values());
-
-				try {
-					StateUtil.bestEffortDiscardAllStateObjects(statesToDiscard);
-				} catch (Exception e) {
-					LOG.warn("Could not properly discard states.", e);
-				}
-
-				if (localBackupDirectory.isSnapshotCompleted()) {
-					try {
-						DirectoryStateHandle directoryStateHandle = localBackupDirectory.completeSnapshotAndGetHandle();
-						if (directoryStateHandle != null) {
-							directoryStateHandle.discardState();
-						}
-					} catch (Exception e) {
-						LOG.warn("Could not properly discard local state.", e);
-					}
-				}
-			}
-		}
-	}
-
 	public static RocksIteratorWrapper getRocksIterator(RocksDB db) {
 		return new RocksIteratorWrapper(db.newIterator());
 	}
@@ -2332,23 +1474,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		return new RocksIteratorWrapper(db.newIterator(columnFamilyHandle));
 	}
 
-	@SuppressWarnings("unchecked")
-	private static RocksIteratorWrapper getRocksIterator(
-		RocksDB db,
-		ColumnFamilyHandle columnFamilyHandle,
-		RegisteredStateMetaInfoBase metaInfo,
-		ReadOptions readOptions) {
-		StateSnapshotTransformer<byte[]> stateSnapshotTransformer = null;
-		if (metaInfo instanceof RegisteredKeyValueStateBackendMetaInfo) {
-			stateSnapshotTransformer = (StateSnapshotTransformer<byte[]>)
-				((RegisteredKeyValueStateBackendMetaInfo<?, ?>) metaInfo).getSnapshotTransformer();
-		}
-		RocksIterator rocksIterator = db.newIterator(columnFamilyHandle, readOptions);
-		return stateSnapshotTransformer == null ?
-			new RocksIteratorWrapper(rocksIterator) :
-			new RocksTransformingIteratorWrapper(rocksIterator, stateSnapshotTransformer);
-	}
-
 	/**
 	 * Encapsulates the logic and resources in connection with creating priority queue state structures.
 	 */
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
new file mode 100644
index 0000000..0cc9729
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
@@ -0,0 +1,478 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state.snapshot;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.contrib.streaming.state.RocksIteratorWrapper;
+import org.apache.flink.contrib.streaming.state.iterator.RocksStatesPerKeyGroupMergeIterator;
+import org.apache.flink.contrib.streaming.state.iterator.RocksTransformingIteratorWrapper;
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
+import org.apache.flink.runtime.state.CheckpointedStateScope;
+import org.apache.flink.runtime.state.DoneFuture;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.LocalRecoveryConfig;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
+import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StreamCompressionDecorator;
+import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.ResourceGuard;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.rocksdb.ColumnFamilyHandle;
+import org.rocksdb.ReadOptions;
+import org.rocksdb.RocksDB;
+import org.rocksdb.RocksIterator;
+import org.rocksdb.Snapshot;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Objects;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.FutureTask;
+import java.util.concurrent.RunnableFuture;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.END_OF_KEY_GROUP_MARK;
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.hasMetaDataFollowsFlag;
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.setMetaDataFollowsFlagInKey;
+
+/**
+ * Snapshot strategy to create full snapshots of
+ * {@link org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend}. Iterates and writes all states from a
+ * RocksDB snapshot of the column families.
+ *
+ * @param <K> type of the backend keys.
+ */
+public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
+
+	private static final Logger LOG = LoggerFactory.getLogger(RocksFullSnapshotStrategy.class);
+
+	/** This decorator is used to apply compression per key-group for the written snapshot data. */
+	@Nonnull
+	private final StreamCompressionDecorator keyGroupCompressionDecorator;
+
+	public RocksFullSnapshotStrategy(
+		@Nonnull RocksDB db,
+		@Nonnull ResourceGuard rocksDBResourceGuard,
+		@Nonnull TypeSerializer<K> keySerializer,
+		@Nonnull LinkedHashMap<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> kvStateInformation,
+		@Nonnull KeyGroupRange keyGroupRange,
+		@Nonnegative int keyGroupPrefixBytes,
+		@Nonnull LocalRecoveryConfig localRecoveryConfig,
+		@Nonnull CloseableRegistry cancelStreamRegistry,
+		@Nonnull StreamCompressionDecorator keyGroupCompressionDecorator) {
+		super(
+			db,
+			rocksDBResourceGuard,
+			keySerializer,
+			kvStateInformation,
+			keyGroupRange,
+			keyGroupPrefixBytes,
+			localRecoveryConfig,
+			cancelStreamRegistry);
+
+		this.keyGroupCompressionDecorator = keyGroupCompressionDecorator;
+	}
+
+	/**
+	 * Creates the future that writes a full snapshot. The RocksDB snapshot and the db lease are acquired right
+	 * here (in the constructor of the callable); the returned task only streams the data once it is run.
+	 */
+	@Override
+	public RunnableFuture<SnapshotResult<KeyedStateHandle>> performSnapshot(
+		long checkpointId,
+		long timestamp,
+		CheckpointStreamFactory primaryStreamFactory,
+		CheckpointOptions checkpointOptions) throws Exception {
+
+		if (kvStateInformation.isEmpty()) {
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("Asynchronous RocksDB snapshot performed on empty keyed state at {}. Returning null.",
+					timestamp);
+			}
+			// no registered state: hand out an already completed future with an empty result
+			return DoneFuture.of(SnapshotResult.empty());
+		}
+
+		// the duplicating (local recovery) stream is only used for checkpoints, savepoints get the simple stream
+		final SupplierWithException<CheckpointStreamWithResultProvider, Exception> supplier =
+			localRecoveryConfig.isLocalRecoveryEnabled() &&
+				(CheckpointType.SAVEPOINT != checkpointOptions.getCheckpointType()) ?
+				() -> CheckpointStreamWithResultProvider.createDuplicatingStream(
+					checkpointId,
+					CheckpointedStateScope.EXCLUSIVE,
+					primaryStreamFactory,
+					localRecoveryConfig.getLocalStateDirectoryProvider()) :
+				() -> CheckpointStreamWithResultProvider.createSimpleStream(
+					CheckpointedStateScope.EXCLUSIVE,
+					primaryStreamFactory);
+
+		final CloseableRegistry snapshotCloseableRegistry = new CloseableRegistry();
+
+		final RocksDBFullSnapshotCallable snapshotOperation =
+			new RocksDBFullSnapshotCallable(supplier, snapshotCloseableRegistry);
+
+		return new SnapshotTask(snapshotOperation);
+	}
+
+	@Override
+	public void notifyCheckpointComplete(long checkpointId) {
+		// nothing to do: no field declared by this class depends on which checkpoints have completed.
+	}
+
+	/**
+	 * {@link FutureTask} that runs a {@link RocksDBFullSnapshotCallable} and closes that callable when cancelled.
+	 */
+	private class SnapshotTask extends FutureTask<SnapshotResult<KeyedStateHandle>> {
+
+		/** The wrapped callable, kept only as {@link AutoCloseable} so that cancel can close it. */
+		@Nonnull
+		private final AutoCloseable callableClose;
+
+		SnapshotTask(@Nonnull RocksDBFullSnapshotCallable callable) {
+			super(callable);
+			this.callableClose = callable;
+		}
+
+		@Override
+		public boolean cancel(boolean mayInterruptIfRunning) {
+			IOUtils.closeQuietly(callableClose); // closed before, and regardless of the outcome of, super.cancel
+			return super.cancel(mayInterruptIfRunning);
+		}
+	}
+
+	/**
+	 * Encapsulates the process to perform a full snapshot of a RocksDBKeyedStateBackend.
+	 */
+	@VisibleForTesting
+	private class RocksDBFullSnapshotCallable implements Callable<SnapshotResult<KeyedStateHandle>>, AutoCloseable {
+
+		@Nonnull
+		private final KeyGroupRangeOffsets keyGroupRangeOffsets;
+
+		@Nonnull
+		private final SupplierWithException<CheckpointStreamWithResultProvider, Exception> checkpointStreamSupplier;
+
+		@Nonnull
+		private final CloseableRegistry snapshotCloseableRegistry;
+
+		@Nonnull
+		private final ResourceGuard.Lease dbLease;
+
+		@Nonnull
+		private final Snapshot snapshot;
+
+		@Nonnull
+		private final ReadOptions readOptions;
+
+		/**
+		 * The state meta data.
+		 */
+		@Nonnull
+		private List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
+
+		/**
+		 * The copied column handle.
+		 */
+		@Nonnull
+		private List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> metaDataCopy;
+
+		private final AtomicBoolean ownedForCleanup;
+
+		RocksDBFullSnapshotCallable(
+			@Nonnull SupplierWithException<CheckpointStreamWithResultProvider, Exception> checkpointStreamSupplier,
+			@Nonnull CloseableRegistry registry) throws IOException {
+			// Captures all that call() needs later: copies of the meta data, a db lease and a RocksDB snapshot.
+			this.ownedForCleanup = new AtomicBoolean(false);
+			this.checkpointStreamSupplier = checkpointStreamSupplier;
+			this.keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange);
+			this.snapshotCloseableRegistry = registry;
+
+			this.stateMetaInfoSnapshots = new ArrayList<>(kvStateInformation.size());
+			this.metaDataCopy = new ArrayList<>(kvStateInformation.size());
+			for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 : kvStateInformation.values()) {
+				// snapshot the meta info and remember the (column family, meta info) pair in map iteration order
+				this.stateMetaInfoSnapshots.add(tuple2.f1.snapshot());
+				this.metaDataCopy.add(tuple2);
+			}
+			// NOTE(review): the lease is not released by this constructor if one of the statements below throws.
+			this.dbLease = rocksDBResourceGuard.acquireResource();
+
+			this.readOptions = new ReadOptions();
+			this.snapshot = db.getSnapshot();
+			this.readOptions.setSnapshot(snapshot);
+		}
+
+		@Override
+		public SnapshotResult<KeyedStateHandle> call() throws Exception {
+			// Claim ownership of the cleanup; if close() claimed it first, the snapshot was cancelled before start.
+			if (!ownedForCleanup.compareAndSet(false, true)) {
+				throw new CancellationException("Snapshot task was already cancelled, stopping execution.");
+			}
+
+			final long startTime = System.currentTimeMillis();
+			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators = new ArrayList<>(metaDataCopy.size());
+
+			try {
+
+				cancelStreamRegistry.registerCloseable(snapshotCloseableRegistry);
+
+				final CheckpointStreamWithResultProvider checkpointStreamWithResultProvider = checkpointStreamSupplier.get();
+				snapshotCloseableRegistry.registerCloseable(checkpointStreamWithResultProvider);
+
+				final DataOutputView outputView =
+					new DataOutputViewStreamWrapper(checkpointStreamWithResultProvider.getCheckpointOutputStream());
+				// writeKVStateMetaData fills kvStateIterators, writeKVStateData then consumes them
+				writeKVStateMetaData(kvStateIterators, outputView);
+				writeKVStateData(kvStateIterators, checkpointStreamWithResultProvider);
+
+				final SnapshotResult<KeyedStateHandle> snapshotResult =
+					createStateHandlesFromStreamProvider(checkpointStreamWithResultProvider);
+				// NOTE(review): the first placeholder is fed with the stream supplier, not with a checkpoint id
+				LOG.info("Asynchronous RocksDB snapshot ({}, asynchronous part) in thread {} took {} ms.",
+					checkpointStreamSupplier, Thread.currentThread(), (System.currentTimeMillis() - startTime));
+
+				return snapshotResult;
+
+			} finally {
+				// close the iterators (quietly), then release read options, snapshot, lease and registry
+				for (Tuple2<RocksIteratorWrapper, Integer> kvStateIterator : kvStateIterators) {
+					IOUtils.closeQuietly(kvStateIterator.f0);
+				}
+
+				cleanupSynchronousStepResources();
+			}
+		}
+
+		private void cleanupSynchronousStepResources() {
+			// Releases what the constructor acquired: read options, snapshot and lease; then closes the registry.
+			IOUtils.closeQuietly(readOptions);
+			db.releaseSnapshot(snapshot);
+			IOUtils.closeQuietly(snapshot);
+
+			IOUtils.closeQuietly(dbLease);
+			// the registry is closed only by the caller that manages to unregister it from cancelStreamRegistry
+			if (cancelStreamRegistry.unregisterCloseable(snapshotCloseableRegistry)) {
+				try {
+					snapshotCloseableRegistry.close();
+				} catch (Exception ex) {
+					LOG.warn("Error closing local registry", ex);
+				}
+			}
+		}
+
+		private SnapshotResult<KeyedStateHandle> createStateHandlesFromStreamProvider(
+			CheckpointStreamWithResultProvider checkpointStreamWithResultProvider) throws IOException {
+			// A failed unregister means the registry no longer tracks the provider; this is reported as "closed".
+			if (!snapshotCloseableRegistry.unregisterCloseable(checkpointStreamWithResultProvider)) {
+				throw new IOException("Snapshot was already closed before completion.");
+			}
+			return CheckpointStreamWithResultProvider.toKeyedStateHandleSnapshotResult(
+				checkpointStreamWithResultProvider.closeAndFinalizeCheckpointStreamResult(),
+				keyGroupRangeOffsets);
+		}
+
+		private void writeKVStateMetaData(
+			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators,
+			final DataOutputView outputView) throws IOException {
+
+			int kvStateId = 0;
+
+			for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 : metaDataCopy) {
+
+				RocksIteratorWrapper rocksIteratorWrapper =
+					getRocksIterator(db, tuple2.f0, tuple2.f1, readOptions);
+
+				kvStateIterators.add(Tuple2.of(rocksIteratorWrapper, kvStateId));
+				++kvStateId;
+			}
+
+			KeyedBackendSerializationProxy<K> serializationProxy =
+				new KeyedBackendSerializationProxy<>(
+					// TODO: this code assumes that writing a serializer is threadsafe, we should support to
+					// get a serialized form already at state registration time in the future
+					keySerializer,
+					stateMetaInfoSnapshots,
+					!Objects.equals(
+						UncompressedStreamCompressionDecorator.INSTANCE,
+						keyGroupCompressionDecorator));
+
+			serializationProxy.write(outputView);
+		}
+
+		private void writeKVStateData(
+			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators,
+			final CheckpointStreamWithResultProvider checkpointStreamWithResultProvider) throws IOException, InterruptedException {
+			// Writes all k/v pairs in merge-iterator order; a flag in a key announces that meta data follows the pair.
+			byte[] previousKey = null;
+			byte[] previousValue = null;
+			DataOutputView kgOutView = null;
+			OutputStream kgOutStream = null;
+			CheckpointStreamFactory.CheckpointStateOutputStream checkpointOutputStream =
+				checkpointStreamWithResultProvider.getCheckpointOutputStream();
+
+			try {
+				// Here we transfer ownership of RocksIterators to the RocksStatesPerKeyGroupMergeIterator
+				try (RocksStatesPerKeyGroupMergeIterator mergeIterator = new RocksStatesPerKeyGroupMergeIterator(
+					kvStateIterators, keyGroupPrefixBytes)) {
+
+					//preamble: setup with first key-group as our lookahead
+					if (mergeIterator.isValid()) {
+						//begin first key-group by recording the offset
+						keyGroupRangeOffsets.setKeyGroupOffset(
+							mergeIterator.keyGroup(),
+							checkpointOutputStream.getPos());
+						//open the (possibly compressing) stream of this key-group and write the k/v-state id as metadata
+						kgOutStream = keyGroupCompressionDecorator.decorateWithCompression(checkpointOutputStream);
+						kgOutView = new DataOutputViewStreamWrapper(kgOutStream);
+						//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
+						kgOutView.writeShort(mergeIterator.kvStateId());
+						previousKey = mergeIterator.key();
+						previousValue = mergeIterator.value();
+						mergeIterator.next();
+					}
+
+					//main loop: write k/v pairs ordered by (key-group, kv-state), thereby tracking key-group offsets.
+					while (mergeIterator.isValid()) {
+						// note: a pair is written one iteration late, once it is known whether meta data follows it
+						assert (!hasMetaDataFollowsFlag(previousKey));
+
+						//set signal in first key byte that meta data will follow in the stream after this k/v pair
+						if (mergeIterator.isNewKeyGroup() || mergeIterator.isNewKeyValueState()) {
+
+							//be cooperative and check for interruption from time to time in the hot loop
+							checkInterrupted();
+
+							setMetaDataFollowsFlagInKey(previousKey);
+						}
+
+						writeKeyValuePair(previousKey, previousValue, kgOutView);
+
+						//write meta data if we have to
+						if (mergeIterator.isNewKeyGroup()) {
+							//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
+							kgOutView.writeShort(END_OF_KEY_GROUP_MARK);
+							// closes only the wrapping stream; checkpointOutputStream is still used for the next key-group
+							kgOutStream.close();
+							//begin new key-group
+							keyGroupRangeOffsets.setKeyGroupOffset(
+								mergeIterator.keyGroup(),
+								checkpointOutputStream.getPos());
+							//open the stream of the new key-group and write the k/v-state id
+							//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
+							kgOutStream = keyGroupCompressionDecorator.decorateWithCompression(checkpointOutputStream);
+							kgOutView = new DataOutputViewStreamWrapper(kgOutStream);
+							kgOutView.writeShort(mergeIterator.kvStateId());
+						} else if (mergeIterator.isNewKeyValueState()) {
+							//write the k/v-state
+							//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
+							kgOutView.writeShort(mergeIterator.kvStateId());
+						}
+
+						//request next k/v pair
+						previousKey = mergeIterator.key();
+						previousValue = mergeIterator.value();
+						mergeIterator.next();
+					}
+				}
+
+				//epilogue: write last key-group
+				if (previousKey != null) {
+					assert (!hasMetaDataFollowsFlag(previousKey));
+					setMetaDataFollowsFlagInKey(previousKey);
+					writeKeyValuePair(previousKey, previousValue, kgOutView);
+					//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
+					kgOutView.writeShort(END_OF_KEY_GROUP_MARK);
+					// closes only the wrapping stream; setting it to null keeps the finally block from closing it again
+					kgOutStream.close();
+					kgOutStream = null;
+				}
+
+			} finally {
+				// only relevant on failure paths: quietly closes a wrapping stream that is still open
+				IOUtils.closeQuietly(kgOutStream);
+			}
+		}
+
+		private void writeKeyValuePair(byte[] key, byte[] value, DataOutputView out) throws IOException {
+			BytePrimitiveArraySerializer.INSTANCE.serialize(key, out); // the key is written first ...
+			BytePrimitiveArraySerializer.INSTANCE.serialize(value, out); // ... then the value, with the same serializer
+		}
+
+		private void checkInterrupted() throws InterruptedException {
+			if (Thread.currentThread().isInterrupted()) { // only queries the interrupt flag, does not clear it
+				throw new InterruptedException("RocksDB snapshot interrupted.");
+			}
+		}
+
+		@Override
+		public void close() throws Exception {
+			// Cancellation path: if call() has not claimed the cleanup, release the constructor's resources here.
+			if (ownedForCleanup.compareAndSet(false, true)) {
+				cleanupSynchronousStepResources();
+			}
+			// In any case close the registry of this snapshot, unless it was already unregistered elsewhere.
+			if (cancelStreamRegistry.unregisterCloseable(snapshotCloseableRegistry)) {
+				snapshotCloseableRegistry.close();
+			}
+		}
+	}
+
+	@SuppressWarnings("unchecked")
+	private static RocksIteratorWrapper getRocksIterator(
+		RocksDB db,
+		ColumnFamilyHandle columnFamilyHandle,
+		RegisteredStateMetaInfoBase metaInfo,
+		ReadOptions readOptions) {
+		StateSnapshotTransformer<byte[]> stateSnapshotTransformer = null;
+		if (metaInfo instanceof RegisteredKeyValueStateBackendMetaInfo) {
+			stateSnapshotTransformer = (StateSnapshotTransformer<byte[]>)
+				((RegisteredKeyValueStateBackendMetaInfo<?, ?>) metaInfo).getSnapshotTransformer();
+		}
+		RocksIterator rocksIterator = db.newIterator(columnFamilyHandle, readOptions);
+		return stateSnapshotTransformer == null ?
+			new RocksIteratorWrapper(rocksIterator) :
+			new RocksTransformingIteratorWrapper(rocksIterator, stateSnapshotTransformer);
+	}
+}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
new file mode 100644
index 0000000..3487fe6
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
@@ -0,0 +1,578 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state.snapshot;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FileStatus;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
+import org.apache.flink.runtime.state.CheckpointedStateScope;
+import org.apache.flink.runtime.state.DirectoryStateHandle;
+import org.apache.flink.runtime.state.DoneFuture;
+import org.apache.flink.runtime.state.IncrementalKeyedStateHandle;
+import org.apache.flink.runtime.state.IncrementalLocalKeyedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.LocalRecoveryConfig;
+import org.apache.flink.runtime.state.LocalRecoveryDirectoryProvider;
+import org.apache.flink.runtime.state.PlaceholderStreamStateHandle;
+import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
+import org.apache.flink.runtime.state.SnapshotDirectory;
+import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.runtime.state.SnapshotStrategy;
+import org.apache.flink.runtime.state.StateHandleID;
+import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.StateUtil;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.FileUtils;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.ResourceGuard;
+
+import org.rocksdb.Checkpoint;
+import org.rocksdb.ColumnFamilyHandle;
+import org.rocksdb.RocksDB;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.UUID;
+import java.util.concurrent.FutureTask;
+import java.util.concurrent.RunnableFuture;
+
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.SST_FILE_SUFFIX;
+
+/**
+ * Snapshot strategy for {@link org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend} that is based
+ * on RocksDB's native checkpoints and creates incremental snapshots.
+ *
+ * @param <K> type of the backend keys.
+ */
+public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
+
+	private static final Logger LOG = LoggerFactory.getLogger(RocksIncrementalSnapshotStrategy.class);
+
+	/** Base path of the RocksDB instance. */
+	@Nonnull
+	private final File instanceBasePath;
+
+	/** The unique ID of this backend. */
+	@Nonnull
+	private final UUID backendUID;
+
+	/** Stores the materialized sstable files from all snapshots that build the incremental history. */
+	@Nonnull
+	private final SortedMap<Long, Set<StateHandleID>> materializedSstFiles;
+
+	/** The identifier of the last completed checkpoint. */
+	private long lastCompletedCheckpointId;
+
+	/** We delegate snapshots that are for savepoints to this. */
+	@Nonnull
+	private final SnapshotStrategy<SnapshotResult<KeyedStateHandle>> savepointDelegate;
+
+	public RocksIncrementalSnapshotStrategy(
+		@Nonnull RocksDB db,
+		@Nonnull ResourceGuard rocksDBResourceGuard,
+		@Nonnull TypeSerializer<K> keySerializer,
+		@Nonnull LinkedHashMap<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> kvStateInformation,
+		@Nonnull KeyGroupRange keyGroupRange,
+		@Nonnegative int keyGroupPrefixBytes,
+		@Nonnull LocalRecoveryConfig localRecoveryConfig,
+		@Nonnull CloseableRegistry cancelStreamRegistry,
+		@Nonnull File instanceBasePath,
+		@Nonnull UUID backendUID,
+		@Nonnull SortedMap<Long, Set<StateHandleID>> materializedSstFiles,
+		long lastCompletedCheckpointId,
+		@Nonnull SnapshotStrategy<SnapshotResult<KeyedStateHandle>> savepointDelegate) {
+
+		super(
+			db,
+			rocksDBResourceGuard,
+			keySerializer,
+			kvStateInformation,
+			keyGroupRange,
+			keyGroupPrefixBytes,
+			localRecoveryConfig,
+			cancelStreamRegistry);
+
+		this.instanceBasePath = instanceBasePath;
+		this.backendUID = backendUID;
+		this.materializedSstFiles = materializedSstFiles;
+		this.lastCompletedCheckpointId = lastCompletedCheckpointId;
+		this.savepointDelegate = savepointDelegate;
+	}
+
+	/**
+	 * Runs the synchronous part of an incremental snapshot ({@code takeSnapshot()}) and returns the task that
+	 * performs the remaining work ({@code runSnapshot()}). Savepoints are handed to {@link #savepointDelegate}.
+	 */
+	@Override
+	public RunnableFuture<SnapshotResult<KeyedStateHandle>> performSnapshot(
+		long checkpointId,
+		long checkpointTimestamp,
+		CheckpointStreamFactory checkpointStreamFactory,
+		CheckpointOptions checkpointOptions) throws Exception {
+
+		// for savepoints, we delegate to the full snapshot strategy because savepoints are always self-contained.
+		if (CheckpointType.SAVEPOINT == checkpointOptions.getCheckpointType()) {
+			return savepointDelegate.performSnapshot(
+				checkpointId, checkpointTimestamp, checkpointStreamFactory, checkpointOptions);
+		}
+
+		if (kvStateInformation.isEmpty()) {
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("Asynchronous RocksDB snapshot performed on empty keyed state at {}. Returning null.", checkpointTimestamp);
+			}
+			return DoneFuture.of(SnapshotResult.empty());
+		}
+
+		SnapshotDirectory snapshotDirectory;
+
+		if (localRecoveryConfig.isLocalRecoveryEnabled()) {
+			// create a "permanent" snapshot directory for local recovery.
+			LocalRecoveryDirectoryProvider directoryProvider = localRecoveryConfig.getLocalStateDirectoryProvider();
+			File directory = directoryProvider.subtaskSpecificCheckpointDirectory(checkpointId);
+
+			if (directory.exists()) {
+				FileUtils.deleteDirectory(directory);
+			}
+			// a stale directory was removed above, so a failing mkdirs() means that it could not be created
+			if (!directory.mkdirs()) {
+				throw new IOException("Local state base directory for checkpoint " + checkpointId +
+					" does not exist and could not be created: " + directory);
+			}
+
+			// introduces an extra directory because RocksDB wants a non-existing directory for native checkpoints.
+			File rdbSnapshotDir = new File(directory, "rocks_db");
+			Path path = new Path(rdbSnapshotDir.toURI());
+			snapshotDirectory = SnapshotDirectory.permanent(path);
+		} else {
+			// create a "temporary" snapshot directory because local recovery is inactive.
+			Path path = new Path(instanceBasePath.getAbsolutePath(), "chk-" + checkpointId);
+			snapshotDirectory = SnapshotDirectory.temporary(path);
+		}
+
+		final RocksDBIncrementalSnapshotOperation snapshotOperation =
+			new RocksDBIncrementalSnapshotOperation(
+				checkpointStreamFactory,
+				snapshotDirectory,
+				checkpointId);
+
+		try {
+			snapshotOperation.takeSnapshot();
+		} catch (Exception e) { // no task is handed out in this case, so stop and release right away
+			snapshotOperation.stop();
+			snapshotOperation.releaseResources(true);
+			throw e;
+		}
+
+		return new FutureTask<SnapshotResult<KeyedStateHandle>>(
+			snapshotOperation::runSnapshot
+		) {
+			@Override
+			public boolean cancel(boolean mayInterruptIfRunning) {
+				snapshotOperation.stop();
+				return super.cancel(mayInterruptIfRunning);
+			}
+
+			@Override
+			protected void done() {
+				snapshotOperation.releaseResources(isCancelled());
+			}
+		};
+	}
+
+	@Override
+	public void notifyCheckpointComplete(long completedCheckpointId) {
+		synchronized (materializedSstFiles) {
+			// ignore notifications for checkpoints older than the last completed one that was recorded
+			if (completedCheckpointId < lastCompletedCheckpointId) {
+				return;
+			}
+			// drop the entries of all checkpoints with an id below the one that just completed
+			materializedSstFiles.keySet().removeIf(checkpointId -> checkpointId < completedCheckpointId);
+
+			lastCompletedCheckpointId = completedCheckpointId;
+		}
+	}
+
+	/**
+	 * Encapsulates the process to perform an incremental snapshot of a RocksDBKeyedStateBackend.
+	 */
+	private final class RocksDBIncrementalSnapshotOperation {
+
+		/**
+		 * Stream factory that creates the output streams to DFS.
+		 */
+		private final CheckpointStreamFactory checkpointStreamFactory;
+
+		/**
+		 * Id for the current checkpoint.
+		 */
+		private final long checkpointId;
+
+		/**
+		 * All sst files that were part of the last previously completed checkpoint.
+		 */
+		private Set<StateHandleID> baseSstFiles;
+
+		/**
+		 * The state meta data.
+		 */
+		private final List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
+
+		/**
+		 * Local directory for the RocksDB native backup.
+		 */
+		private SnapshotDirectory localBackupDirectory;
+
+		// Registry for all opened i/o streams
+		private final CloseableRegistry closeableRegistry;
+
+		// new sst files since the last completed checkpoint
+		private final Map<StateHandleID, StreamStateHandle> sstFiles;
+
+		// handles to the misc files in the current snapshot
+		private final Map<StateHandleID, StreamStateHandle> miscFiles;
+
+		// This lease protects from concurrent disposal of the native rocksdb instance.
+		private final ResourceGuard.Lease dbLease;
+
+		private SnapshotResult<StreamStateHandle> metaStateHandle;
+
+		private RocksDBIncrementalSnapshotOperation(
+			CheckpointStreamFactory checkpointStreamFactory,
+			SnapshotDirectory localBackupDirectory,
+			long checkpointId) throws IOException {
+
+			this.checkpointStreamFactory = checkpointStreamFactory;
+			this.checkpointId = checkpointId;
+			this.localBackupDirectory = localBackupDirectory;
+			this.stateMetaInfoSnapshots = new ArrayList<>();
+			this.closeableRegistry = new CloseableRegistry();
+			this.sstFiles = new HashMap<>();
+			this.miscFiles = new HashMap<>();
+			this.metaStateHandle = null;
+			this.dbLease = rocksDBResourceGuard.acquireResource();
+		}
+
+		private StreamStateHandle materializeStateData(Path filePath) throws Exception {
+			FSDataInputStream inputStream = null;
+			CheckpointStreamFactory.CheckpointStateOutputStream outputStream = null;
+			// Copies the file at filePath (opened via the backup directory's file system) into a SHARED checkpoint stream.
+			try {
+				final byte[] buffer = new byte[8 * 1024];
+
+				FileSystem backupFileSystem = localBackupDirectory.getFileSystem();
+				inputStream = backupFileSystem.open(filePath);
+				closeableRegistry.registerCloseable(inputStream);
+
+				outputStream = checkpointStreamFactory
+					.createCheckpointStateOutputStream(CheckpointedStateScope.SHARED);
+				closeableRegistry.registerCloseable(outputStream);
+
+				while (true) {
+					int numBytes = inputStream.read(buffer);
+
+					if (numBytes == -1) {
+						break;
+					}
+
+					outputStream.write(buffer, 0, numBytes);
+				}
+				// NOTE(review): if the stream can no longer be unregistered, null is returned rather than an exception
+				StreamStateHandle result = null;
+				if (closeableRegistry.unregisterCloseable(outputStream)) {
+					result = outputStream.closeAndGetHandle();
+					outputStream = null;
+				}
+				return result;
+
+			} finally {
+				// each stream is closed here only if this method is still the one that unregisters it
+				if (closeableRegistry.unregisterCloseable(inputStream)) {
+					inputStream.close();
+				}
+
+				if (closeableRegistry.unregisterCloseable(outputStream)) {
+					outputStream.close();
+				}
+			}
+		}
+
+		@Nonnull
+		private SnapshotResult<StreamStateHandle> materializeMetaData() throws Exception {
+			// Writes the key serializer and the state meta info snapshots into an EXCLUSIVE checkpoint stream.
+			CheckpointStreamWithResultProvider streamWithResultProvider =
+
+				localRecoveryConfig.isLocalRecoveryEnabled() ?
+
+					CheckpointStreamWithResultProvider.createDuplicatingStream(
+						checkpointId,
+						CheckpointedStateScope.EXCLUSIVE,
+						checkpointStreamFactory,
+						localRecoveryConfig.getLocalStateDirectoryProvider()) :
+
+					CheckpointStreamWithResultProvider.createSimpleStream(
+						CheckpointedStateScope.EXCLUSIVE,
+						checkpointStreamFactory);
+
+			try {
+				closeableRegistry.registerCloseable(streamWithResultProvider);
+
+				//no need for compression scheme support because sst-files are already compressed
+				KeyedBackendSerializationProxy<K> serializationProxy =
+					new KeyedBackendSerializationProxy<>(
+						keySerializer,
+						stateMetaInfoSnapshots,
+						false);
+
+				DataOutputView out =
+					new DataOutputViewStreamWrapper(streamWithResultProvider.getCheckpointOutputStream());
+
+				serializationProxy.write(out);
+				// on success the reference is set to null, so that the finally block does not close the stream again
+				if (closeableRegistry.unregisterCloseable(streamWithResultProvider)) {
+					SnapshotResult<StreamStateHandle> result =
+						streamWithResultProvider.closeAndFinalizeCheckpointStreamResult();
+					streamWithResultProvider = null;
+					return result;
+				} else {
+					throw new IOException("Stream already closed and cannot return a handle.");
+				}
+			} finally {
+				if (streamWithResultProvider != null) {
+					if (closeableRegistry.unregisterCloseable(streamWithResultProvider)) {
+						IOUtils.closeQuietly(streamWithResultProvider);
+					}
+				}
+			}
+		}
+
+		/**
+		 * Prepares the incremental snapshot: picks the sst files of the last completed checkpoint
+		 * as comparison base, snapshots the meta info of all registered states and lets RocksDB
+		 * create a checkpoint in the local backup directory.
+		 *
+		 * @throws IllegalStateException if the local backup directory already exists.
+		 * @throws Exception if creating the RocksDB checkpoint fails.
+		 */
+		void takeSnapshot() throws Exception {
+
+			final long baseCheckpointId;
+
+			// the files of the last completed checkpoint serve as the base of this snapshot
+			synchronized (materializedSstFiles) {
+				baseCheckpointId = lastCompletedCheckpointId;
+				baseSstFiles = materializedSstFiles.get(baseCheckpointId);
+			}
+
+			LOG.trace("Taking incremental snapshot for checkpoint {}. Snapshot is based on last completed checkpoint {} " +
+				"assuming the following (shared) files as base: {}.", checkpointId, baseCheckpointId, baseSstFiles);
+
+			// collect a meta info snapshot for every registered state
+			for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> stateInfo : kvStateInformation.values()) {
+				stateMetaInfoSnapshots.add(stateInfo.f1.snapshot());
+			}
+
+			LOG.trace("Local RocksDB checkpoint goes to backup path {}.", localBackupDirectory);
+
+			if (localBackupDirectory.exists()) {
+				throw new IllegalStateException("Unexpected existence of the backup directory.");
+			}
+
+			// let RocksDB write its checkpoint into the backup directory
+			try (Checkpoint rocksCheckpoint = Checkpoint.create(db)) {
+				rocksCheckpoint.createCheckpoint(localBackupDirectory.getDirectory().getPath());
+			}
+		}
+
+		/**
+		 * Materializes the snapshot that was prepared in the local backup directory.
+		 *
+		 * <p>The method registers this operation's {@code closeableRegistry} with the
+		 * {@code cancelStreamRegistry}, writes the meta data, and then walks over all files of the
+		 * local backup directory. Sst files that are already contained in {@code baseSstFiles} are
+		 * represented by a {@link PlaceholderStreamStateHandle}, all other files are uploaded. The
+		 * ids of the sst files are finally recorded in {@code materializedSstFiles} under the
+		 * current checkpoint id.
+		 *
+		 * @return a result with job manager owned state and, if both the directory handle and the
+		 *         task-local meta data handle are available, additional task-local state.
+		 * @throws Exception if writing meta data or state data fails.
+		 */
+		@Nonnull
+		SnapshotResult<KeyedStateHandle> runSnapshot() throws Exception {
+
+			cancelStreamRegistry.registerCloseable(closeableRegistry);
+
+			// write meta data
+			metaStateHandle = materializeMetaData();
+
+			// sanity checks - they should never fail
+			Preconditions.checkNotNull(metaStateHandle,
+				"Metadata was not properly created.");
+			Preconditions.checkNotNull(metaStateHandle.getJobManagerOwnedSnapshot(),
+				"Metadata for job manager was not properly created.");
+
+			// write state data
+			Preconditions.checkState(localBackupDirectory.exists());
+
+			FileStatus[] fileStatuses = localBackupDirectory.listStatus();
+			if (fileStatuses != null) {
+				for (FileStatus fileStatus : fileStatuses) {
+					final Path filePath = fileStatus.getPath();
+					final String fileName = filePath.getName();
+					final StateHandleID stateHandleID = new StateHandleID(fileName);
+
+					if (fileName.endsWith(SST_FILE_SUFFIX)) {
+						final boolean existsAlready =
+							baseSstFiles != null && baseSstFiles.contains(stateHandleID);
+
+						if (existsAlready) {
+							// we introduce a placeholder state handle, that is replaced with the
+							// original from the shared state registry (created from a previous checkpoint)
+							sstFiles.put(
+								stateHandleID,
+								new PlaceholderStreamStateHandle());
+						} else {
+							sstFiles.put(stateHandleID, materializeStateData(filePath));
+						}
+					} else {
+						StreamStateHandle fileHandle = materializeStateData(filePath);
+						miscFiles.put(stateHandleID, fileHandle);
+					}
+				}
+			}
+
+			// remember which sst files belong to this checkpoint, as base for later snapshots
+			synchronized (materializedSstFiles) {
+				materializedSstFiles.put(checkpointId, sstFiles.keySet());
+			}
+
+			IncrementalKeyedStateHandle jmIncrementalKeyedStateHandle = new IncrementalKeyedStateHandle(
+				backendUID,
+				keyGroupRange,
+				checkpointId,
+				sstFiles,
+				miscFiles,
+				metaStateHandle.getJobManagerOwnedSnapshot());
+
+			StreamStateHandle taskLocalSnapshotMetaDataStateHandle = metaStateHandle.getTaskLocalSnapshot();
+			DirectoryStateHandle directoryStateHandle = null;
+
+			try {
+
+				directoryStateHandle = localBackupDirectory.completeSnapshotAndGetHandle();
+			} catch (IOException ex) {
+
+				Exception collector = ex;
+
+				// the task-local meta data handle can be absent (it is null-checked further down),
+				// so only attempt to discard it when it exists instead of provoking an NPE that
+				// would end up as a misleading suppressed exception in the log
+				if (taskLocalSnapshotMetaDataStateHandle != null) {
+					try {
+						taskLocalSnapshotMetaDataStateHandle.discardState();
+					} catch (Exception discardEx) {
+						collector = ExceptionUtils.firstOrSuppressed(discardEx, collector);
+					}
+				}
+
+				// a failed local snapshot does not fail the checkpoint, we continue without local state
+				LOG.warn("Problem with local state snapshot.", collector);
+			}
+
+			if (directoryStateHandle != null && taskLocalSnapshotMetaDataStateHandle != null) {
+
+				IncrementalLocalKeyedStateHandle localDirKeyedStateHandle =
+					new IncrementalLocalKeyedStateHandle(
+						backendUID,
+						checkpointId,
+						directoryStateHandle,
+						keyGroupRange,
+						taskLocalSnapshotMetaDataStateHandle,
+						sstFiles.keySet());
+				return SnapshotResult.withLocalState(jmIncrementalKeyedStateHandle, localDirKeyedStateHandle);
+			} else {
+				return SnapshotResult.of(jmIncrementalKeyedStateHandle);
+			}
+		}
+
+		/**
+		 * Stops a running snapshot by closing this operation's {@code closeableRegistry}, provided
+		 * that the registry could be unregistered from the {@code cancelStreamRegistry}. Problems
+		 * on closing are logged and not propagated.
+		 */
+		void stop() {
+
+			if (!cancelStreamRegistry.unregisterCloseable(closeableRegistry)) {
+				// registry was not registered (anymore), so there is nothing for us to close
+				return;
+			}
+
+			try {
+				closeableRegistry.close();
+			} catch (IOException e) {
+				LOG.warn("Could not properly close io streams.", e);
+			}
+		}
+
+		/**
+		 * Releases everything this snapshot operation holds: the lease on the RocksDB resource
+		 * guard, the closeable registry and the local backup directory. All failures in here are
+		 * logged and not propagated.
+		 *
+		 * @param canceled if {@code true}, the state objects created so far (meta data, misc files,
+		 *                 sst files) are additionally discarded, as well as the local directory
+		 *                 state if the local snapshot was already completed.
+		 */
+		void releaseResources(boolean canceled) {
+
+			// give back the lease that was acquired from the resource guard in the constructor
+			dbLease.close();
+
+			// close our registry only if it is still registered with the cancel registry
+			if (cancelStreamRegistry.unregisterCloseable(closeableRegistry)) {
+				try {
+					closeableRegistry.close();
+				} catch (IOException e) {
+					LOG.warn("Exception on closing registry.", e);
+				}
+			}
+
+			try {
+				if (localBackupDirectory.exists()) {
+					LOG.trace("Running cleanup for local RocksDB backup directory {}.", localBackupDirectory);
+					// NOTE(review): what cleanup() removes depends on the directory implementation - confirm
+					boolean cleanupOk = localBackupDirectory.cleanup();
+
+					if (!cleanupOk) {
+						LOG.debug("Could not properly cleanup local RocksDB backup directory.");
+					}
+				}
+			} catch (IOException e) {
+				LOG.warn("Could not properly cleanup local RocksDB backup directory.", e);
+			}
+
+			if (canceled) {
+				// on cancellation nobody will take ownership of the written state, so discard it here
+				Collection<StateObject> statesToDiscard =
+					new ArrayList<>(1 + miscFiles.size() + sstFiles.size());
+
+				// NOTE(review): metaStateHandle is still null if meta data was never written;
+				// presumably the discard helper tolerates null entries - confirm
+				statesToDiscard.add(metaStateHandle);
+				statesToDiscard.addAll(miscFiles.values());
+				statesToDiscard.addAll(sstFiles.values());
+
+				try {
+					StateUtil.bestEffortDiscardAllStateObjects(statesToDiscard);
+				} catch (Exception e) {
+					LOG.warn("Could not properly discard states.", e);
+				}
+
+				// a completed local snapshot is not removed by the cleanup above, discard it explicitly
+				if (localBackupDirectory.isSnapshotCompleted()) {
+					try {
+						DirectoryStateHandle directoryStateHandle = localBackupDirectory.completeSnapshotAndGetHandle();
+						if (directoryStateHandle != null) {
+							directoryStateHandle.discardState();
+						}
+					} catch (Exception e) {
+						LOG.warn("Could not properly discard local state.", e);
+					}
+				}
+			}
+		}
+	}
+}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksSnapshotUtil.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksSnapshotUtil.java
new file mode 100644
index 0000000..bf2bbdb
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksSnapshotUtil.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state.snapshot;
+
+/**
+ * Utility methods and constants around RocksDB creating and restoring snapshots for
+ * {@link org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend}.
+ */
+public class RocksSnapshotUtil {
+
+	/** File suffix of sstable files. */
+	public static final String SST_FILE_SUFFIX = ".sst";
+
+	/** Mask that selects the highest-order bit of a byte. */
+	public static final int FIRST_BIT_IN_BYTE_MASK = 0x80;
+
+	/**
+	 * Marker value with all 16 low-order bits set.
+	 * NOTE(review): presumably written after the last entry of a key-group - confirm against the writer.
+	 */
+	public static final int END_OF_KEY_GROUP_MARK = 0xFFFF;
+
+	/** This class only offers static members and must not be instantiated. */
+	private RocksSnapshotUtil() {
+		throw new AssertionError();
+	}
+
+	/**
+	 * Sets the highest-order bit in the first byte of the given key. All other bits are untouched.
+	 *
+	 * @param key the key bytes to modify in place, must contain at least one byte.
+	 */
+	public static void setMetaDataFollowsFlagInKey(byte[] key) {
+		key[0] = (byte) (key[0] | FIRST_BIT_IN_BYTE_MASK);
+	}
+
+	/**
+	 * Clears the highest-order bit in the first byte of the given key. All other bits are untouched.
+	 *
+	 * @param key the key bytes to modify in place, must contain at least one byte.
+	 */
+	public static void clearMetaDataFollowsFlag(byte[] key) {
+		key[0] = (byte) (key[0] & ~FIRST_BIT_IN_BYTE_MASK);
+	}
+
+	/**
+	 * Checks whether the highest-order bit in the first byte of the given key is set.
+	 *
+	 * @param key the key bytes to inspect, must contain at least one byte.
+	 * @return {@code true} iff the flag bit is set.
+	 */
+	public static boolean hasMetaDataFollowsFlag(byte[] key) {
+		return (key[0] & FIRST_BIT_IN_BYTE_MASK) != 0;
+	}
+}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/SnapshotStrategyBase.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/SnapshotStrategyBase.java
new file mode 100644
index 0000000..efebe8c
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/SnapshotStrategyBase.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state.snapshot;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.LocalRecoveryConfig;
+import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
+import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.runtime.state.SnapshotStrategy;
+import org.apache.flink.util.ResourceGuard;
+
+import org.rocksdb.ColumnFamilyHandle;
+import org.rocksdb.RocksDB;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import java.util.LinkedHashMap;
+
+/**
+ * Base class for {@link SnapshotStrategy} implementations on RocksDB.
+ *
+ * @param <K> type of the backend keys.
+ */
+public abstract class SnapshotStrategyBase<K> implements SnapshotStrategy<SnapshotResult<KeyedStateHandle>> {
+
+	/** The RocksDB instance that is snapshotted. */
+	@Nonnull
+	protected final RocksDB db;
+
+	/** Guard from which a lease is acquired while a snapshot works with the RocksDB instance. */
+	@Nonnull
+	protected final ResourceGuard rocksDBResourceGuard;
+
+	/** Serializer for the backend keys. */
+	@Nonnull
+	protected final TypeSerializer<K> keySerializer;
+
+	/** Per state name: the column family handle together with the registered meta info of the state. */
+	@Nonnull
+	protected final LinkedHashMap<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> kvStateInformation;
+
+	/** The key-group range that is reported in the created state handles. */
+	@Nonnull
+	protected final KeyGroupRange keyGroupRange;
+
+	/** NOTE(review): presumably the number of bytes of the key-group prefix in RocksDB keys - confirm. */
+	@Nonnegative
+	protected final int keyGroupPrefixBytes;
+
+	/** Configuration that tells whether and where task-local state is written. */
+	@Nonnull
+	protected final LocalRecoveryConfig localRecoveryConfig;
+
+	/** Registry with which running snapshots register their closeables so they can be canceled. */
+	@Nonnull
+	protected final CloseableRegistry cancelStreamRegistry;
+
+	/**
+	 * Stores the given collaborators, no further work is done on construction.
+	 */
+	public SnapshotStrategyBase(
+		@Nonnull RocksDB db,
+		@Nonnull ResourceGuard rocksDBResourceGuard,
+		@Nonnull TypeSerializer<K> keySerializer,
+		@Nonnull LinkedHashMap<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> kvStateInformation,
+		@Nonnull KeyGroupRange keyGroupRange,
+		@Nonnegative int keyGroupPrefixBytes,
+		@Nonnull LocalRecoveryConfig localRecoveryConfig,
+		@Nonnull CloseableRegistry cancelStreamRegistry) {
+
+		// access to the database
+		this.db = db;
+		this.rocksDBResourceGuard = rocksDBResourceGuard;
+		this.cancelStreamRegistry = cancelStreamRegistry;
+
+		// description of the keyed state
+		this.keySerializer = keySerializer;
+		this.kvStateInformation = kvStateInformation;
+		this.keyGroupRange = keyGroupRange;
+		this.keyGroupPrefixBytes = keyGroupPrefixBytes;
+
+		// local recovery
+		this.localRecoveryConfig = localRecoveryConfig;
+	}
+}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
index e344638..c872553 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
@@ -91,6 +91,11 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.TimeUnit;
 
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.END_OF_KEY_GROUP_MARK;
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.FIRST_BIT_IN_BYTE_MASK;
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.clearMetaDataFollowsFlag;
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.hasMetaDataFollowsFlag;
+import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.setMetaDataFollowsFlagInKey;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Mockito.spy;
@@ -425,21 +430,19 @@ public class RocksDBAsyncSnapshotTest extends TestLogger {
 	@Test
 	public void testConsistentSnapshotSerializationFlagsAndMasks() {
 
-		Assert.assertEquals(0xFFFF, RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.END_OF_KEY_GROUP_MARK);
-		Assert.assertEquals(0x80, RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.FIRST_BIT_IN_BYTE_MASK);
+		Assert.assertEquals(0xFFFF, END_OF_KEY_GROUP_MARK);
+		Assert.assertEquals(0x80, FIRST_BIT_IN_BYTE_MASK);
 
 		byte[] expectedKey = new byte[] {42, 42};
 		byte[] modKey = expectedKey.clone();
 
-		Assert.assertFalse(
-			RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.hasMetaDataFollowsFlag(modKey));
+		Assert.assertFalse(hasMetaDataFollowsFlag(modKey));
 
-		RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.setMetaDataFollowsFlagInKey(modKey);
-		Assert.assertTrue(RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.hasMetaDataFollowsFlag(modKey));
+		setMetaDataFollowsFlagInKey(modKey);
+		Assert.assertTrue(hasMetaDataFollowsFlag(modKey));
 
-		RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.clearMetaDataFollowsFlag(modKey);
-		Assert.assertFalse(
-			RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.hasMetaDataFollowsFlag(modKey));
+		clearMetaDataFollowsFlag(modKey);
+		Assert.assertFalse(hasMetaDataFollowsFlag(modKey));
 
 		Assert.assertTrue(Arrays.equals(expectedKey, modKey));
 	}
@@ -504,12 +507,12 @@ public class RocksDBAsyncSnapshotTest extends TestLogger {
 
 		@Nullable
 		@Override
-		public StreamStateHandle closeAndGetHandle() throws IOException {
+		public StreamStateHandle closeAndGetHandle() {
 			throw new UnsupportedOperationException();
 		}
 
 		@Override
-		public long getPos() throws IOException {
+		public long getPos() {
 			throw new UnsupportedOperationException();
 		}
 
@@ -529,7 +532,7 @@ public class RocksDBAsyncSnapshotTest extends TestLogger {
 		}
 
 		@Override
-		public void close() throws IOException {
+		public void close() {
 			throw new UnsupportedOperationException();
 		}
 	}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
index 0ea0d3f..4916251 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
@@ -191,6 +191,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 		allCreatedCloseables = new ArrayList<>();
 
 		keyedStateBackend.db = spy(keyedStateBackend.db);
+		keyedStateBackend.initializeSnapshotStrategy(null);
 
 		doAnswer(new Answer<Object>() {
 


[flink] 02/02: [FLINK-10042][state] (part 2) Refactoring of snapshot algorithms for better abstraction and cleaner resource management

Posted by sr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit f803280bb933d968976e79b9efb5953bed308d96
Author: Stefan Richter <s....@data-artisans.com>
AuthorDate: Thu Aug 9 22:23:42 2018 +0200

    [FLINK-10042][state] (part 2) Refactoring of snapshot algorithms for better abstraction and cleaner resource management
    
    This closes #6556.
---
 .../async/AbstractAsyncCallableWithResources.java  | 194 --------
 .../flink/runtime/io/async/AsyncDoneCallback.java  |  33 --
 .../flink/runtime/io/async/AsyncStoppable.java     |  45 --
 .../io/async/AsyncStoppableTaskWithCallback.java   |  59 ---
 .../io/async/StoppableCallbackCallable.java        |  30 --
 .../runtime/state/AbstractSnapshotStrategy.java    |  79 +++
 .../flink/runtime/state/AsyncSnapshotCallable.java | 190 +++++++
 .../runtime/state/DefaultOperatorStateBackend.java | 369 +++++++-------
 .../flink/runtime/state/SnapshotStrategy.java      |  13 +-
 .../apache/flink/runtime/state/Snapshotable.java   |  27 +-
 .../runtime/state/heap/HeapKeyedStateBackend.java  | 145 +++---
 .../runtime/state/AsyncSnapshotCallableTest.java   | 326 ++++++++++++
 .../runtime/state/OperatorStateBackendTest.java    |   4 +-
 .../flink/runtime/state/StateBackendTestBase.java  |   6 +-
 .../state/ttl/mock/MockKeyedStateBackend.java      |   5 +-
 .../streaming/state/RocksDBKeyedStateBackend.java  |  53 +-
 ...yBase.java => RocksDBSnapshotStrategyBase.java} |  57 ++-
 .../state/snapshot/RocksFullSnapshotStrategy.java  | 255 ++++------
 .../snapshot/RocksIncrementalSnapshotStrategy.java | 552 ++++++++++-----------
 .../flink/streaming/runtime/tasks/StreamTask.java  |   4 +-
 .../tasks/TaskCheckpointingBehaviourTest.java      |  11 +-
 .../apache/flink/core/testutils/OneShotLatch.java  |  18 +-
 22 files changed, 1329 insertions(+), 1146 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AbstractAsyncCallableWithResources.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AbstractAsyncCallableWithResources.java
deleted file mode 100644
index bc0116c..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AbstractAsyncCallableWithResources.java
+++ /dev/null
@@ -1,194 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.async;
-
-import org.apache.flink.util.ExceptionUtils;
-
-import java.io.IOException;
-
-/**
- * This abstract class encapsulates the lifecycle and execution strategy for asynchronous operations that use resources.
- *
- * @param <V> return type of the asynchronous call.
- */
-public abstract class AbstractAsyncCallableWithResources<V> implements StoppableCallbackCallable<V> {
-
-	/** Tracks if the stop method was called on this object. */
-	private volatile boolean stopped;
-
-	/** Tracks if call method was executed (only before stop calls). */
-	private volatile boolean called;
-
-	/** Stores a collected exception if there was one during stop. */
-	private volatile Exception stopException;
-
-	public AbstractAsyncCallableWithResources() {
-		this.stopped = false;
-		this.called = false;
-	}
-
-	/**
-	 * This method implements the strategy for the actual IO operation:
-	 * <p>
-	 * 1) Acquire resources asynchronously and atomically w.r.t stopping.
-	 * 2) Performs the operation
-	 * 3) Releases resources.
-	 *
-	 * @return Result of the IO operation, e.g. a deserialized object.
-	 * @throws Exception exception that happened during the call.
-	 */
-	@Override
-	public final V call() throws Exception {
-
-		V result = null;
-		Exception collectedException = null;
-
-		try {
-			synchronized (this) {
-
-				if (stopped) {
-					throw new IOException("Task was already stopped.");
-				}
-
-				called = true;
-				// Get resources in async part, atomically w.r.t. stopping.
-				acquireResources();
-			}
-
-			// The main work is performed here.
-			result = performOperation();
-
-		} catch (Exception ex) {
-			collectedException = ex;
-		} finally {
-
-			try {
-				// Cleanup
-				releaseResources();
-			} catch (Exception relEx) {
-				collectedException = ExceptionUtils.firstOrSuppressed(relEx, collectedException);
-			}
-
-			if (collectedException != null) {
-				throw collectedException;
-			}
-		}
-
-		return result;
-	}
-
-	/**
-	 * Open the IO Handle (e.g. a stream) on which the operation will be performed.
-	 *
-	 * @return the opened IO handle that implements #Closeable
-	 * @throws Exception if there was a problem in acquiring.
-	 */
-	protected abstract void acquireResources() throws Exception;
-
-	/**
-	 * Implements the actual operation.
-	 *
-	 * @return Result of the operation
-	 * @throws Exception if there was a problem in executing the operation.
-	 */
-	protected abstract V performOperation() throws Exception;
-
-	/**
-	 * Releases resources acquired by this object.
-	 *
-	 * @throws Exception if there was a problem in releasing resources.
-	 */
-	protected abstract void releaseResources() throws Exception;
-
-	/**
-	 * This method implements how the operation is stopped. Usually this involves interrupting or closing some
-	 * resources like streams to return from blocking calls.
-	 *
-	 * @throws Exception on problems during the stopping.
-	 */
-	protected abstract void stopOperation() throws Exception;
-
-	/**
-	 * Stops the I/O operation by closing the I/O handle. If an exception is thrown on close, it can be accessed via
-	 * #getStopException().
-	 */
-	@Override
-	public final void stop() {
-
-		synchronized (this) {
-
-			// Make sure that call can not enter execution from here.
-			if (stopped) {
-				return;
-			} else {
-				stopped = true;
-			}
-		}
-
-		if (called) {
-			// Async call is executing -> attempt to stop it and releaseResources() will happen inside the async method.
-			try {
-				stopOperation();
-			} catch (Exception stpEx) {
-				this.stopException = stpEx;
-			}
-		} else {
-			// Async call was not executed, so we also need to releaseResources() here.
-			try {
-				releaseResources();
-			} catch (Exception relEx) {
-				stopException = relEx;
-			}
-		}
-	}
-
-	/**
-	 * Optional callback that subclasses can implement. This is called when the callable method completed, e.g. because
-	 * it finished or was stopped.
-	 */
-	@Override
-	public void done(boolean canceled) {
-		//optional callback hook
-	}
-
-	/**
-	 * True once the async method was called.
-	 */
-	public boolean isCalled() {
-		return called;
-	}
-
-	/**
-	 * Check if the IO operation is stopped
-	 *
-	 * @return true if stop() was called
-	 */
-	@Override
-	public boolean isStopped() {
-		return stopped;
-	}
-
-	/**
-	 * Returns a potential exception that might have been observed while stopping the operation.
-	 */
-	@Override
-	public Exception getStopException() {
-		return stopException;
-	}
-}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncDoneCallback.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncDoneCallback.java
deleted file mode 100644
index dcc5525..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncDoneCallback.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.async;
-
-/**
- * Callback for an asynchronous operation that is called on termination
- */
-public interface AsyncDoneCallback {
-
-	/**
-	 * the callback
-	 *
-	 * @param canceled true if the callback is done, but was canceled
-	 */
-	void done(boolean canceled);
-
-}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppable.java
deleted file mode 100644
index 8698600..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppable.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.async;
-
-/**
- * An asynchronous operation that can be stopped.
- */
-public interface AsyncStoppable {
-
-	/**
-	 * Stop the operation
-	 */
-	void stop();
-
-	/**
-	 * Check whether the operation is stopped
-	 *
-	 * @return true iff operation is stopped
-	 */
-	boolean isStopped();
-
-	/**
-	 * Delivers Exception that might happen during {@link #stop()}
-	 *
-	 * @return Exception that can happen during stop
-	 */
-	Exception getStopException();
-
-}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppableTaskWithCallback.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppableTaskWithCallback.java
deleted file mode 100644
index a30c607..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppableTaskWithCallback.java
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.async;
-
-import org.apache.flink.util.Preconditions;
-
-import java.util.concurrent.FutureTask;
-
-/**
- * @param <V> return type of the callable function
- */
-public class AsyncStoppableTaskWithCallback<V> extends FutureTask<V> {
-
-	protected final StoppableCallbackCallable<V> stoppableCallbackCallable;
-
-	public AsyncStoppableTaskWithCallback(StoppableCallbackCallable<V> callable) {
-		super(Preconditions.checkNotNull(callable));
-		this.stoppableCallbackCallable = callable;
-	}
-
-	@Override
-	public boolean cancel(boolean mayInterruptIfRunning) {
-		final boolean cancel = super.cancel(mayInterruptIfRunning);
-		if (cancel) {
-			stoppableCallbackCallable.stop();
-			// this is where we report done() for the cancel case, after calling stop().
-			stoppableCallbackCallable.done(true);
-		}
-		return cancel;
-	}
-
-	@Override
-	protected void done() {
-		// we suppress forwarding if we have not been canceled, because the cancel case will call to this method separately.
-		if (!isCancelled()) {
-			stoppableCallbackCallable.done(false);
-		}
-	}
-
-	public static <V> AsyncStoppableTaskWithCallback<V> from(StoppableCallbackCallable<V> callable) {
-		return new AsyncStoppableTaskWithCallback<>(callable);
-	}
-}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/StoppableCallbackCallable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/StoppableCallbackCallable.java
deleted file mode 100644
index d459316..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/StoppableCallbackCallable.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.async;
-
-import java.util.concurrent.Callable;
-
-/**
- * A {@link Callable} that can be stopped and offers a callback on termination.
- *
- * @param <V> return value of the call operation.
- */
-public interface StoppableCallbackCallable<V> extends Callable<V>, AsyncStoppable, AsyncDoneCallback {
-
-}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractSnapshotStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractSnapshotStrategy.java
new file mode 100644
index 0000000..e0debe5
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractSnapshotStrategy.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+/**
+ * Abstract base class for implementing {@link SnapshotStrategy} that provides consistent logging across state backends.
+ *
+ * @param <T> type of the snapshot result.
+ */
+public abstract class AbstractSnapshotStrategy<T extends StateObject> implements SnapshotStrategy<SnapshotResult<T>> {
+
+	private static final Logger LOG = LoggerFactory.getLogger(AbstractSnapshotStrategy.class);
+
+	private static final String LOG_SYNC_COMPLETED_TEMPLATE = "{} ({}, synchronous part) in thread {} took {} ms.";
+	private static final String LOG_ASYNC_COMPLETED_TEMPLATE = "{} ({}, asynchronous part) in thread {} took {} ms.";
+
+	/** Descriptive name of the snapshot strategy that will appear in the log outputs and {@link #toString()}. */
+	@Nonnull
+	protected final String description;
+
+	protected AbstractSnapshotStrategy(@Nonnull String description) {
+		this.description = description;
+	}
+
+	/**
+	 * Logs the duration of the synchronous snapshot part from the given start time.
+	 */
+	public void logSyncCompleted(@Nonnull Object checkpointOutDescription, long startTime) {
+		logCompletedInternal(LOG_SYNC_COMPLETED_TEMPLATE, checkpointOutDescription, startTime);
+	}
+
+	/**
+	 * Logs the duration of the asynchronous snapshot part from the given start time.
+	 */
+	public void logAsyncCompleted(@Nonnull Object checkpointOutDescription, long startTime) {
+		logCompletedInternal(LOG_ASYNC_COMPLETED_TEMPLATE, checkpointOutDescription, startTime);
+	}
+
+	private void logCompletedInternal(
+		@Nonnull String template,
+		@Nonnull Object checkpointOutDescription,
+		long startTime) {
+
+		long duration = (System.currentTimeMillis() - startTime);
+
+		LOG.debug(
+			template,
+			description,
+			checkpointOutDescription,
+			Thread.currentThread(),
+			duration);
+	}
+
+	@Override
+	public String toString() {
+		return "SnapshotStrategy {" + description + "}";
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncSnapshotCallable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncSnapshotCallable.java
new file mode 100644
index 0000000..2c1a0be
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncSnapshotCallable.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.fs.CloseableRegistry;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.FutureTask;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Base class that outlines the strategy for asynchronous snapshots. Implementations of this class are typically
+ * instantiated with resources that have been created in the synchronous part of a snapshot. Then, the implementation
+ * of {@link #callInternal()} is invoked in the asynchronous part. All resources created by this method should
+ * be released by the end of the method. If the created resources are {@link Closeable} objects and can block in calls
+ * (e.g. in/output streams), they should be registered with the snapshot's {@link CloseableRegistry} so that they can
+ * be closed and unblocked on cancellation. After {@link #callInternal()} ended, {@link #logAsyncSnapshotComplete(long)}
+ * is called. In that method, implementations can emit log statements about the duration. At the very end, this class
+ * calls {@link #cleanupProvidedResources()}. The implementation of this method should release all provided resources
+ * that have been passed into the snapshot from the synchronous part of the snapshot.
+ *
+ * @param <T> type of the result.
+ */
+public abstract class AsyncSnapshotCallable<T> implements Callable<T> {
+
+	/** Message for the {@link CancellationException}. */
+	private static final String CANCELLATION_EXCEPTION_MSG = "Async snapshot was cancelled.";
+
+	private static final Logger LOG = LoggerFactory.getLogger(AsyncSnapshotCallable.class);
+
+	/** This is used to atomically claim ownership for the resource cleanup. */
+	@Nonnull
+	private final AtomicBoolean resourceCleanupOwnershipTaken;
+
+	/** Registers streams that can block in I/O during snapshot. Forwards close from taskCancelCloseableRegistry. */
+	@Nonnull
+	private final CloseableRegistry snapshotCloseableRegistry;
+
+	protected AsyncSnapshotCallable() {
+		this.snapshotCloseableRegistry = new CloseableRegistry();
+		this.resourceCleanupOwnershipTaken = new AtomicBoolean(false);
+	}
+
+	@Override
+	public T call() throws Exception {
+		final long startTime = System.currentTimeMillis();
+
+		if (resourceCleanupOwnershipTaken.compareAndSet(false, true)) {
+			try {
+				T result = callInternal();
+				logAsyncSnapshotComplete(startTime);
+				return result;
+			} catch (Exception ex) {
+				if (!snapshotCloseableRegistry.isClosed()) {
+					throw ex;
+				}
+			} finally {
+				closeSnapshotIO();
+				cleanup();
+			}
+		}
+
+		throw new CancellationException(CANCELLATION_EXCEPTION_MSG);
+	}
+
+	@VisibleForTesting
+	protected void cancel() {
+		closeSnapshotIO();
+		if (resourceCleanupOwnershipTaken.compareAndSet(false, true)) {
+			cleanup();
+		}
+	}
+
+	/**
+	 * Creates a future task from this and registers it with the given {@link CloseableRegistry}. The task is
+	 * unregistered again in {@link FutureTask#done()}.
+	 */
+	public AsyncSnapshotTask toAsyncSnapshotFutureTask(@Nonnull CloseableRegistry taskRegistry) throws IOException {
+		return new AsyncSnapshotTask(taskRegistry);
+	}
+
+	/**
+	 * {@link FutureTask} that wraps a {@link AsyncSnapshotCallable} and connects it with cancellation and closing.
+	 */
+	public class AsyncSnapshotTask extends FutureTask<T> {
+
+		@Nonnull
+		private final CloseableRegistry taskRegistry;
+
+		@Nonnull
+		private final Closeable cancelOnClose;
+
+		private AsyncSnapshotTask(@Nonnull CloseableRegistry taskRegistry) throws IOException {
+			super(AsyncSnapshotCallable.this);
+			this.cancelOnClose = () -> cancel(true);
+			this.taskRegistry = taskRegistry;
+			taskRegistry.registerCloseable(cancelOnClose);
+		}
+
+		@Override
+		public boolean cancel(boolean mayInterruptIfRunning) {
+			boolean result = super.cancel(mayInterruptIfRunning);
+			if (mayInterruptIfRunning) {
+				AsyncSnapshotCallable.this.cancel();
+			}
+			return result;
+		}
+
+		@Override
+		protected void done() {
+			super.done();
+			taskRegistry.unregisterCloseable(cancelOnClose);
+		}
+	}
+
+	/**
+	 * This method implements the (async) snapshot logic. Resources acquired within this method should be released at
+	 * the end of the method.
+	 */
+	protected abstract T callInternal() throws Exception;
+
+	/**
+	 * This method implements the cleanup of resources that have been passed in (from the sync part). Called after the
+	 * end of {@link #callInternal()}.
+	 */
+	protected abstract void cleanupProvidedResources();
+
+	/**
+	 * This method is invoked after completion of the snapshot and can be overridden to output a log statement about the
+	 * duration of the async part.
+	 */
+	protected void logAsyncSnapshotComplete(long startTime) {
+
+	}
+
+	/**
+	 * Registers the {@link Closeable} with the snapshot's {@link CloseableRegistry}, so that it will be closed on
+	 * {@link #cancel()} and becomes unblocked. If the registry is already closed, the argument is closed and an
+	 * {@link IOException} is emitted.
+	 */
+	protected void registerCloseableForCancellation(@Nullable Closeable toRegister) throws IOException {
+		snapshotCloseableRegistry.registerCloseable(toRegister);
+	}
+
+	/**
+	 * Unregisters the given argument from the snapshot's {@link CloseableRegistry} and returns <code>true</code> iff
+	 * the argument was registered before the call.
+	 */
+	protected boolean unregisterCloseableFromCancellation(@Nullable Closeable toUnregister) {
+		return snapshotCloseableRegistry.unregisterCloseable(toUnregister);
+	}
+
+	private void cleanup() {
+		cleanupProvidedResources();
+	}
+
+	private void closeSnapshotIO() {
+		try {
+			snapshotCloseableRegistry.close();
+		} catch (IOException e) {
+			LOG.warn("Could not properly close incremental snapshot streams.", e);
+		}
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
index d9fc41e..eae5a3b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -36,8 +36,6 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.io.async.AbstractAsyncCallableWithResources;
-import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.StateMigrationException;
@@ -56,6 +54,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 
 /**
@@ -133,6 +132,8 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 
 	private final Map<String, BackendWritableBroadcastState<?, ?>> accessedBroadcastStatesByName;
 
+	private final AbstractSnapshotStrategy<OperatorStateHandle> snapshotStrategy;
+
 	public DefaultOperatorStateBackend(
 		ClassLoader userClassLoader,
 		ExecutionConfig executionConfig,
@@ -149,6 +150,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 		this.accessedBroadcastStatesByName = new HashMap<>();
 		this.restoredOperatorStateMetaInfos = new HashMap<>();
 		this.restoredBroadcastStateMetaInfos = new HashMap<>();
+		this.snapshotStrategy = new DefaultOperatorStateBackendSnapshotStrategy();
 	}
 
 	public ExecutionConfig getExecutionConfig() {
@@ -307,179 +309,6 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 	//  Snapshot and restore
 	// -------------------------------------------------------------------------------------------
 
-	@Override
-	public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
-			final long checkpointId,
-			final long timestamp,
-			final CheckpointStreamFactory streamFactory,
-			final CheckpointOptions checkpointOptions) throws Exception {
-
-		final long syncStartTime = System.currentTimeMillis();
-
-		if (registeredOperatorStates.isEmpty() && registeredBroadcastStates.isEmpty()) {
-			return DoneFuture.of(SnapshotResult.empty());
-		}
-
-		final Map<String, PartitionableListState<?>> registeredOperatorStatesDeepCopies =
-				new HashMap<>(registeredOperatorStates.size());
-		final Map<String, BackendWritableBroadcastState<?, ?>> registeredBroadcastStatesDeepCopies =
-				new HashMap<>(registeredBroadcastStates.size());
-
-		ClassLoader snapshotClassLoader = Thread.currentThread().getContextClassLoader();
-		Thread.currentThread().setContextClassLoader(userClassloader);
-		try {
-			// eagerly create deep copies of the list and the broadcast states (if any)
-			// in the synchronous phase, so that we can use them in the async writing.
-
-			if (!registeredOperatorStates.isEmpty()) {
-				for (Map.Entry<String, PartitionableListState<?>> entry : registeredOperatorStates.entrySet()) {
-					PartitionableListState<?> listState = entry.getValue();
-					if (null != listState) {
-						listState = listState.deepCopy();
-					}
-					registeredOperatorStatesDeepCopies.put(entry.getKey(), listState);
-				}
-			}
-
-			if (!registeredBroadcastStates.isEmpty()) {
-				for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry : registeredBroadcastStates.entrySet()) {
-					BackendWritableBroadcastState<?, ?> broadcastState = entry.getValue();
-					if (null != broadcastState) {
-						broadcastState = broadcastState.deepCopy();
-					}
-					registeredBroadcastStatesDeepCopies.put(entry.getKey(), broadcastState);
-				}
-			}
-		} finally {
-			Thread.currentThread().setContextClassLoader(snapshotClassLoader);
-		}
-
-		// implementation of the async IO operation, based on FutureTask
-		final AbstractAsyncCallableWithResources<SnapshotResult<OperatorStateHandle>> ioCallable =
-			new AbstractAsyncCallableWithResources<SnapshotResult<OperatorStateHandle>>() {
-
-				CheckpointStreamFactory.CheckpointStateOutputStream out = null;
-
-				@Override
-				protected void acquireResources() throws Exception {
-					openOutStream();
-				}
-
-				@Override
-				protected void releaseResources() {
-					closeOutStream();
-				}
-
-				@Override
-				protected void stopOperation() {
-					closeOutStream();
-				}
-
-				private void openOutStream() throws Exception {
-					out = streamFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE);
-					closeStreamOnCancelRegistry.registerCloseable(out);
-				}
-
-				private void closeOutStream() {
-					if (closeStreamOnCancelRegistry.unregisterCloseable(out)) {
-						IOUtils.closeQuietly(out);
-					}
-				}
-
-				@Nonnull
-				@Override
-				public SnapshotResult<OperatorStateHandle> performOperation() throws Exception {
-					long asyncStartTime = System.currentTimeMillis();
-
-					CheckpointStreamFactory.CheckpointStateOutputStream localOut = this.out;
-
-					// get the registered operator state infos ...
-					List<StateMetaInfoSnapshot> operatorMetaInfoSnapshots =
-						new ArrayList<>(registeredOperatorStatesDeepCopies.size());
-
-					for (Map.Entry<String, PartitionableListState<?>> entry : registeredOperatorStatesDeepCopies.entrySet()) {
-						operatorMetaInfoSnapshots.add(entry.getValue().getStateMetaInfo().snapshot());
-					}
-
-					// ... get the registered broadcast operator state infos ...
-					List<StateMetaInfoSnapshot> broadcastMetaInfoSnapshots =
-							new ArrayList<>(registeredBroadcastStatesDeepCopies.size());
-
-					for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry : registeredBroadcastStatesDeepCopies.entrySet()) {
-						broadcastMetaInfoSnapshots.add(entry.getValue().getStateMetaInfo().snapshot());
-					}
-
-					// ... write them all in the checkpoint stream ...
-					DataOutputView dov = new DataOutputViewStreamWrapper(localOut);
-
-					OperatorBackendSerializationProxy backendSerializationProxy =
-						new OperatorBackendSerializationProxy(operatorMetaInfoSnapshots, broadcastMetaInfoSnapshots);
-
-					backendSerializationProxy.write(dov);
-
-					// ... and then go for the states ...
-
-					// we put BOTH normal and broadcast state metadata here
-					final Map<String, OperatorStateHandle.StateMetaInfo> writtenStatesMetaData =
-							new HashMap<>(registeredOperatorStatesDeepCopies.size() + registeredBroadcastStatesDeepCopies.size());
-
-					for (Map.Entry<String, PartitionableListState<?>> entry :
-							registeredOperatorStatesDeepCopies.entrySet()) {
-
-						PartitionableListState<?> value = entry.getValue();
-						long[] partitionOffsets = value.write(localOut);
-						OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode();
-						writtenStatesMetaData.put(
-							entry.getKey(),
-							new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
-					}
-
-					// ... and the broadcast states themselves ...
-					for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry :
-							registeredBroadcastStatesDeepCopies.entrySet()) {
-
-						BackendWritableBroadcastState<?, ?> value = entry.getValue();
-						long[] partitionOffsets = {value.write(localOut)};
-						OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode();
-						writtenStatesMetaData.put(
-								entry.getKey(),
-								new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
-					}
-
-					// ... and, finally, create the state handle.
-					OperatorStateHandle retValue = null;
-
-					if (closeStreamOnCancelRegistry.unregisterCloseable(out)) {
-
-						StreamStateHandle stateHandle = out.closeAndGetHandle();
-
-						if (stateHandle != null) {
-							retValue = new OperatorStreamStateHandle(writtenStatesMetaData, stateHandle);
-						}
-					}
-
-					if (asynchronousSnapshots) {
-						LOG.debug("DefaultOperatorStateBackend snapshot ({}, asynchronous part) in thread {} took {} ms.",
-							streamFactory, Thread.currentThread(), (System.currentTimeMillis() - asyncStartTime));
-					}
-
-					return SnapshotResult.of(retValue);
-				}
-			};
-
-		AsyncStoppableTaskWithCallback<SnapshotResult<OperatorStateHandle>> task =
-			AsyncStoppableTaskWithCallback.from(ioCallable);
-
-		if (!asynchronousSnapshots) {
-			task.run();
-		}
-
-		LOG.debug("DefaultOperatorStateBackend snapshot ({}, synchronous part) in thread {} took {} ms.",
-				streamFactory, Thread.currentThread(), (System.currentTimeMillis() - syncStartTime));
-
-		return task;
-	}
-
 	public void restore(Collection<OperatorStateHandle> restoreSnapshots) throws Exception {
 
 		if (null == restoreSnapshots || restoreSnapshots.isEmpty()) {
@@ -513,8 +342,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 					final RegisteredOperatorStateBackendMetaInfo<?> restoredMetaInfo =
 						new RegisteredOperatorStateBackendMetaInfo<>(restoredSnapshot);
 
-					if (restoredMetaInfo.getPartitionStateSerializer() == null ||
-						restoredMetaInfo.getPartitionStateSerializer() instanceof UnloadableDummyTypeSerializer) {
+					if (restoredMetaInfo.getPartitionStateSerializer() instanceof UnloadableDummyTypeSerializer) {
 
 						// must fail now if the previous serializer cannot be restored because there is no serializer
 						// capable of reading previous state
@@ -549,8 +377,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 					final RegisteredBroadcastStateBackendMetaInfo<?, ?> restoredMetaInfo =
 						new RegisteredBroadcastStateBackendMetaInfo<>(restoredSnapshot);
 
-					if (restoredMetaInfo.getKeySerializer() == null || restoredMetaInfo.getValueSerializer() == null ||
-						restoredMetaInfo.getKeySerializer() instanceof UnloadableDummyTypeSerializer ||
+					if (restoredMetaInfo.getKeySerializer() instanceof UnloadableDummyTypeSerializer ||
 						restoredMetaInfo.getValueSerializer() instanceof UnloadableDummyTypeSerializer) {
 
 						// must fail now if the previous serializer cannot be restored because there is no serializer
@@ -603,6 +430,23 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 		}
 	}
 
+	@Nonnull
+	@Override
+	public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
+		long checkpointId,
+		long timestamp,
+		@Nonnull CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception {
+
+		long syncStartTime = System.currentTimeMillis();
+
+		RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshotRunner =
+			snapshotStrategy.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+
+		snapshotStrategy.logSyncCompleted(streamFactory, syncStartTime);
+		return snapshotRunner;
+	}
+
 	/**
 	 * Implementation of operator list state.
 	 *
@@ -695,14 +539,14 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 		}
 
 		@Override
-		public void update(List<S> values) throws Exception {
+		public void update(List<S> values) {
 			internalList.clear();
 
 			addAll(values);
 		}
 
 		@Override
-		public void addAll(List<S> values) throws Exception {
+		public void addAll(List<S> values) {
 			if (values != null && !values.isEmpty()) {
 				internalList.addAll(values);
 			}
@@ -848,4 +692,167 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 				"Was [" + actualMode + "], " +
 				"registered with [" + expectedMode + "].");
 	}
+
+	/**
+	 * Snapshot strategy for this backend.
+	 */
+	private class DefaultOperatorStateBackendSnapshotStrategy extends AbstractSnapshotStrategy<OperatorStateHandle> {
+
+		protected DefaultOperatorStateBackendSnapshotStrategy() {
+			super("DefaultOperatorStateBackend snapshot");
+		}
+
+		@Nonnull
+		@Override
+		public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
+			final long checkpointId,
+			final long timestamp,
+			@Nonnull final CheckpointStreamFactory streamFactory,
+			@Nonnull final CheckpointOptions checkpointOptions) throws IOException {
+
+			if (registeredOperatorStates.isEmpty() && registeredBroadcastStates.isEmpty()) {
+				return DoneFuture.of(SnapshotResult.empty());
+			}
+
+			final Map<String, PartitionableListState<?>> registeredOperatorStatesDeepCopies =
+				new HashMap<>(registeredOperatorStates.size());
+			final Map<String, BackendWritableBroadcastState<?, ?>> registeredBroadcastStatesDeepCopies =
+				new HashMap<>(registeredBroadcastStates.size());
+
+			ClassLoader snapshotClassLoader = Thread.currentThread().getContextClassLoader();
+			Thread.currentThread().setContextClassLoader(userClassloader);
+			try {
+				// eagerly create deep copies of the list and the broadcast states (if any)
+				// in the synchronous phase, so that we can use them in the async writing.
+
+				if (!registeredOperatorStates.isEmpty()) {
+					for (Map.Entry<String, PartitionableListState<?>> entry : registeredOperatorStates.entrySet()) {
+						PartitionableListState<?> listState = entry.getValue();
+						if (null != listState) {
+							listState = listState.deepCopy();
+						}
+						registeredOperatorStatesDeepCopies.put(entry.getKey(), listState);
+					}
+				}
+
+				if (!registeredBroadcastStates.isEmpty()) {
+					for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry : registeredBroadcastStates.entrySet()) {
+						BackendWritableBroadcastState<?, ?> broadcastState = entry.getValue();
+						if (null != broadcastState) {
+							broadcastState = broadcastState.deepCopy();
+						}
+						registeredBroadcastStatesDeepCopies.put(entry.getKey(), broadcastState);
+					}
+				}
+			} finally {
+				Thread.currentThread().setContextClassLoader(snapshotClassLoader);
+			}
+
+			AsyncSnapshotCallable<SnapshotResult<OperatorStateHandle>> snapshotCallable =
+				new AsyncSnapshotCallable<SnapshotResult<OperatorStateHandle>>() {
+
+					@Override
+					protected SnapshotResult<OperatorStateHandle> callInternal() throws Exception {
+
+						CheckpointStreamFactory.CheckpointStateOutputStream localOut =
+							streamFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE);
+						registerCloseableForCancellation(localOut);
+
+						// get the registered operator state infos ...
+						List<StateMetaInfoSnapshot> operatorMetaInfoSnapshots =
+							new ArrayList<>(registeredOperatorStatesDeepCopies.size());
+
+						for (Map.Entry<String, PartitionableListState<?>> entry :
+							registeredOperatorStatesDeepCopies.entrySet()) {
+							operatorMetaInfoSnapshots.add(entry.getValue().getStateMetaInfo().snapshot());
+						}
+
+						// ... get the registered broadcast operator state infos ...
+						List<StateMetaInfoSnapshot> broadcastMetaInfoSnapshots =
+							new ArrayList<>(registeredBroadcastStatesDeepCopies.size());
+
+						for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry :
+							registeredBroadcastStatesDeepCopies.entrySet()) {
+							broadcastMetaInfoSnapshots.add(entry.getValue().getStateMetaInfo().snapshot());
+						}
+
+						// ... write them all in the checkpoint stream ...
+						DataOutputView dov = new DataOutputViewStreamWrapper(localOut);
+
+						OperatorBackendSerializationProxy backendSerializationProxy =
+							new OperatorBackendSerializationProxy(operatorMetaInfoSnapshots, broadcastMetaInfoSnapshots);
+
+						backendSerializationProxy.write(dov);
+
+						// ... and then go for the states ...
+
+						// we put BOTH normal and broadcast state metadata here
+						int initialMapCapacity =
+							registeredOperatorStatesDeepCopies.size() + registeredBroadcastStatesDeepCopies.size();
+						final Map<String, OperatorStateHandle.StateMetaInfo> writtenStatesMetaData =
+							new HashMap<>(initialMapCapacity);
+
+						for (Map.Entry<String, PartitionableListState<?>> entry :
+							registeredOperatorStatesDeepCopies.entrySet()) {
+
+							PartitionableListState<?> value = entry.getValue();
+							long[] partitionOffsets = value.write(localOut);
+							OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode();
+							writtenStatesMetaData.put(
+								entry.getKey(),
+								new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
+						}
+
+						// ... and the broadcast states themselves ...
+						for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry :
+							registeredBroadcastStatesDeepCopies.entrySet()) {
+
+							BackendWritableBroadcastState<?, ?> value = entry.getValue();
+							long[] partitionOffsets = {value.write(localOut)};
+							OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode();
+							writtenStatesMetaData.put(
+								entry.getKey(),
+								new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
+						}
+
+						// ... and, finally, create the state handle.
+						OperatorStateHandle retValue = null;
+
+						if (unregisterCloseableFromCancellation(localOut)) {
+
+							StreamStateHandle stateHandle = localOut.closeAndGetHandle();
+
+							if (stateHandle != null) {
+								retValue = new OperatorStreamStateHandle(writtenStatesMetaData, stateHandle);
+							}
+
+							return SnapshotResult.of(retValue);
+						} else {
+							throw new IOException("Stream was already unregistered.");
+						}
+					}
+
+					@Override
+					protected void cleanupProvidedResources() {
+						// nothing to do
+					}
+
+					@Override
+					protected void logAsyncSnapshotComplete(long startTime) {
+						if (asynchronousSnapshots) {
+							logAsyncCompleted(streamFactory, startTime);
+						}
+					}
+				};
+
+			final FutureTask<SnapshotResult<OperatorStateHandle>> task =
+				snapshotCallable.toAsyncSnapshotFutureTask(closeStreamOnCancelRegistry);
+
+			if (!asynchronousSnapshots) {
+				task.run();
+			}
+
+			return task;
+		}
+	}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java
index 3ad68af..53c8663 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java
@@ -18,8 +18,11 @@
 
 package org.apache.flink.runtime.state;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 
+import javax.annotation.Nonnull;
+
 import java.util.concurrent.RunnableFuture;
 
 /**
@@ -28,7 +31,8 @@ import java.util.concurrent.RunnableFuture;
  *
  * @param <S> type of the returned state object that represents the result of the snapshot operation.
  */
-public interface SnapshotStrategy<S extends StateObject> extends CheckpointListener {
+@Internal
+public interface SnapshotStrategy<S extends StateObject> {
 
 	/**
 	 * Operation that writes a snapshot into a stream that is provided by the given {@link CheckpointStreamFactory} and
@@ -42,9 +46,10 @@ public interface SnapshotStrategy<S extends StateObject> extends CheckpointListe
 	 * @param checkpointOptions Options for how to perform this checkpoint.
 	 * @return A runnable future that will yield a {@link StateObject}.
 	 */
-	RunnableFuture<S> performSnapshot(
+	@Nonnull
+	RunnableFuture<S> snapshot(
 		long checkpointId,
 		long timestamp,
-		CheckpointStreamFactory streamFactory,
-		CheckpointOptions checkpointOptions) throws Exception;
+		@Nonnull CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception;
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
index 733339f..1677855 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
@@ -18,9 +18,9 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.annotation.Internal;
 
-import java.util.concurrent.RunnableFuture;
+import javax.annotation.Nullable;
 
 /**
  * Interface for operators that can perform snapshots of their state.
@@ -28,25 +28,8 @@ import java.util.concurrent.RunnableFuture;
  * @param <S> Generic type of the state object that is created as handle to snapshots.
  * @param <R> Generic type of the state object that used in restore.
  */
-public interface Snapshotable<S extends StateObject, R> {
-
-	/**
-	 * Operation that writes a snapshot into a stream that is provided by the given {@link CheckpointStreamFactory} and
-	 * returns a @{@link RunnableFuture} that gives a state handle to the snapshot. It is up to the implementation if
-	 * the operation is performed synchronous or asynchronous. In the later case, the returned Runnable must be executed
-	 * first before obtaining the handle.
-	 *
-	 * @param checkpointId  The ID of the checkpoint.
-	 * @param timestamp     The timestamp of the checkpoint.
-	 * @param streamFactory The factory that we can use for writing our state to streams.
-	 * @param checkpointOptions Options for how to perform this checkpoint.
-	 * @return A runnable future that will yield a {@link StateObject}.
-	 */
-	RunnableFuture<S> snapshot(
-			long checkpointId,
-			long timestamp,
-			CheckpointStreamFactory streamFactory,
-			CheckpointOptions checkpointOptions) throws Exception;
+@Internal
+public interface Snapshotable<S extends StateObject, R> extends SnapshotStrategy<S> {
 
 	/**
 	 * Restores state that was previously snapshotted from the provided parameters. Typically the parameters are state
@@ -54,5 +37,5 @@ public interface Snapshotable<S extends StateObject, R> {
 	 *
 	 * @param state the old state to restore.
 	 */
-	void restore(R state) throws Exception;
+	void restore(@Nullable R state) throws Exception;
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 0e2f16c..05070f9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -37,10 +37,10 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.io.async.AbstractAsyncCallableWithResources;
-import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.AbstractSnapshotStrategy;
+import org.apache.flink.runtime.state.AsyncSnapshotCallable;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
 import org.apache.flink.runtime.state.CheckpointedStateScope;
@@ -60,11 +60,10 @@ import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotResult;
-import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateSnapshot;
-import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
 import org.apache.flink.runtime.state.StateSnapshotRestore;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -92,6 +91,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -344,15 +344,22 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 	}
 
+	@Nonnull
 	@Override
 	@SuppressWarnings("unchecked")
-	public  RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
-			final long checkpointId,
-			final long timestamp,
-			final CheckpointStreamFactory streamFactory,
-			CheckpointOptions checkpointOptions) {
+	public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
+		final long checkpointId,
+		final long timestamp,
+		@Nonnull final CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws IOException {
 
-		return snapshotStrategy.performSnapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+		long startTime = System.currentTimeMillis();
+
+		final RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshotRunner =
+			snapshotStrategy.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+
+		snapshotStrategy.logSyncCompleted(streamFactory, startTime);
+		return snapshotRunner;
 	}
 
 	@SuppressWarnings("deprecation")
@@ -630,9 +637,6 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		}
 
-		default void logOperationCompleted(CheckpointStreamFactory streamFactory, long startTime) {
-
-		}
 
 		boolean isAsynchronous();
 
@@ -642,12 +646,6 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	private class AsyncSnapshotStrategySynchronicityBehavior implements SnapshotStrategySynchronicityBehavior<K> {
 
 		@Override
-		public void logOperationCompleted(CheckpointStreamFactory streamFactory, long startTime) {
-			LOG.debug("Heap backend snapshot ({}, asynchronous part) in thread {} took {} ms.",
-				streamFactory, Thread.currentThread(), (System.currentTimeMillis() - startTime));
-		}
-
-		@Override
 		public boolean isAsynchronous() {
 			return true;
 		}
@@ -682,28 +680,28 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * the concrete strategies. Subclasses must be threadsafe.
 	 */
 	private class HeapSnapshotStrategy
-		implements SnapshotStrategy<SnapshotResult<KeyedStateHandle>>, SnapshotStrategySynchronicityBehavior<K> {
+		extends AbstractSnapshotStrategy<KeyedStateHandle> implements SnapshotStrategySynchronicityBehavior<K> {
 
 		private final SnapshotStrategySynchronicityBehavior<K> snapshotStrategySynchronicityTrait;
 
 		HeapSnapshotStrategy(
 			SnapshotStrategySynchronicityBehavior<K> snapshotStrategySynchronicityTrait) {
+			super("Heap backend snapshot");
 			this.snapshotStrategySynchronicityTrait = snapshotStrategySynchronicityTrait;
 		}
 
+		@Nonnull
 		@Override
-		public RunnableFuture<SnapshotResult<KeyedStateHandle>> performSnapshot(
+		public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
 			long checkpointId,
 			long timestamp,
-			CheckpointStreamFactory primaryStreamFactory,
-			CheckpointOptions checkpointOptions) {
+			@Nonnull CheckpointStreamFactory primaryStreamFactory,
+			@Nonnull CheckpointOptions checkpointOptions) throws IOException {
 
 			if (!hasRegisteredState()) {
 				return DoneFuture.of(SnapshotResult.empty());
 			}
 
-			long syncStartTime = System.currentTimeMillis();
-
 			int numStates = registeredKVStates.size() + registeredPQStates.size();
 
 			Preconditions.checkState(numStates <= Short.MAX_VALUE,
@@ -754,53 +752,23 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 			//--------------------------------------------------- this becomes the end of sync part
 
-			// implementation of the async IO operation, based on FutureTask
-			final AbstractAsyncCallableWithResources<SnapshotResult<KeyedStateHandle>> ioCallable =
-				new AbstractAsyncCallableWithResources<SnapshotResult<KeyedStateHandle>>() {
-
-					CheckpointStreamWithResultProvider streamAndResultExtractor = null;
-
-					@Override
-					protected void acquireResources() throws Exception {
-						streamAndResultExtractor = checkpointStreamSupplier.get();
-						cancelStreamRegistry.registerCloseable(streamAndResultExtractor);
-					}
-
+			final AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>> asyncSnapshotCallable =
+				new AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>>() {
 					@Override
-					protected void releaseResources() {
+					protected SnapshotResult<KeyedStateHandle> callInternal() throws Exception {
 
-						unregisterAndCloseStreamAndResultExtractor();
+						final CheckpointStreamWithResultProvider streamWithResultProvider =
+							checkpointStreamSupplier.get();
 
-						for (StateSnapshot tableSnapshot : cowStateStableSnapshots.values()) {
-							tableSnapshot.release();
-						}
-					}
+						registerCloseableForCancellation(streamWithResultProvider);
 
-					@Override
-					protected void stopOperation() {
-						unregisterAndCloseStreamAndResultExtractor();
-					}
+						final CheckpointStreamFactory.CheckpointStateOutputStream localStream =
+							streamWithResultProvider.getCheckpointOutputStream();
 
-					private void unregisterAndCloseStreamAndResultExtractor() {
-						if (cancelStreamRegistry.unregisterCloseable(streamAndResultExtractor)) {
-							IOUtils.closeQuietly(streamAndResultExtractor);
-							streamAndResultExtractor = null;
-						}
-					}
-
-					@Nonnull
-					@Override
-					protected SnapshotResult<KeyedStateHandle> performOperation() throws Exception {
-
-						long startTime = System.currentTimeMillis();
-
-						CheckpointStreamFactory.CheckpointStateOutputStream localStream =
-							this.streamAndResultExtractor.getCheckpointOutputStream();
-
-						DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(localStream);
+						final DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(localStream);
 						serializationProxy.write(outView);
 
-						long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()];
+						final long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()];
 
 						for (int keyGroupPos = 0; keyGroupPos < keyGroupRange.getNumberOfKeyGroups(); ++keyGroupPos) {
 							int keyGroupId = keyGroupRange.getKeyGroupId(keyGroupPos);
@@ -812,35 +780,46 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 								StateSnapshot.StateKeyGroupWriter partitionedSnapshot =
 
 									stateSnapshot.getValue().getKeyGroupWriter();
-								try (OutputStream kgCompressionOut = keyGroupCompressionDecorator.decorateWithCompression(localStream)) {
-									DataOutputViewStreamWrapper kgCompressionView = new DataOutputViewStreamWrapper(kgCompressionOut);
+								try (
+									OutputStream kgCompressionOut =
+										keyGroupCompressionDecorator.decorateWithCompression(localStream)) {
+									DataOutputViewStreamWrapper kgCompressionView =
+										new DataOutputViewStreamWrapper(kgCompressionOut);
 									kgCompressionView.writeShort(stateNamesToId.get(stateSnapshot.getKey()));
 									partitionedSnapshot.writeStateInKeyGroup(kgCompressionView, keyGroupId);
 								} // this will just close the outer compression stream
 							}
 						}
 
-						if (cancelStreamRegistry.unregisterCloseable(streamAndResultExtractor)) {
+						if (unregisterCloseableFromCancellation(streamWithResultProvider)) {
 							KeyGroupRangeOffsets kgOffs = new KeyGroupRangeOffsets(keyGroupRange, keyGroupRangeOffsets);
 							SnapshotResult<StreamStateHandle> result =
-								streamAndResultExtractor.closeAndFinalizeCheckpointStreamResult();
-							streamAndResultExtractor = null;
-							logOperationCompleted(primaryStreamFactory, startTime);
+								streamWithResultProvider.closeAndFinalizeCheckpointStreamResult();
 							return CheckpointStreamWithResultProvider.toKeyedStateHandleSnapshotResult(result, kgOffs);
+						} else {
+							throw new IOException("Stream already unregistered.");
 						}
+					}
 
-						return SnapshotResult.empty();
+					@Override
+					protected void cleanupProvidedResources() {
+						for (StateSnapshot tableSnapshot : cowStateStableSnapshots.values()) {
+							tableSnapshot.release();
+						}
 					}
-				};
 
-			AsyncStoppableTaskWithCallback<SnapshotResult<KeyedStateHandle>> task =
-				AsyncStoppableTaskWithCallback.from(ioCallable);
+					@Override
+					protected void logAsyncSnapshotComplete(long startTime) {
+						if (snapshotStrategySynchronicityTrait.isAsynchronous()) {
+							logAsyncCompleted(primaryStreamFactory, startTime);
+						}
+					}
+				};
 
+			final FutureTask<SnapshotResult<KeyedStateHandle>> task =
+				asyncSnapshotCallable.toAsyncSnapshotFutureTask(cancelStreamRegistry);
 			finalizeSnapshotBeforeReturnHook(task);
 
-			LOG.debug("Heap backend snapshot (" + primaryStreamFactory + ", synchronous part) in thread " +
-				Thread.currentThread() + " took " + (System.currentTimeMillis() - syncStartTime) + " ms.");
-
 			return task;
 		}
 
@@ -850,11 +829,6 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 
 		@Override
-		public void logOperationCompleted(CheckpointStreamFactory streamFactory, long startTime) {
-			snapshotStrategySynchronicityTrait.logOperationCompleted(streamFactory, startTime);
-		}
-
-		@Override
 		public boolean isAsynchronous() {
 			return snapshotStrategySynchronicityTrait.isAsynchronous();
 		}
@@ -882,11 +856,6 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				}
 			}
 		}
-
-		@Override
-		public void notifyCheckpointComplete(long checkpointId) throws Exception {
-			// nothing to do.
-		}
 	}
 
 	private interface StateFactory {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncSnapshotCallableTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncSnapshotCallableTest.java
new file mode 100644
index 0000000..304a495
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncSnapshotCallableTest.java
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.util.Preconditions;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import javax.annotation.Nonnull;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.FutureTask;
+
+/**
+ * Tests for {@link AsyncSnapshotCallable}.
+ */
+public class AsyncSnapshotCallableTest {
+
+	private static final String METHOD_CALL = "callInternal";
+	private static final String METHOD_LOG = "logAsyncSnapshotComplete";
+	private static final String METHOD_CLEANUP = "cleanupProvidedResources";
+	private static final String METHOD_CANCEL = "cancel";
+	private static final String SUCCESS = "Success!";
+
+	private CloseableRegistry ownerRegistry;
+	private TestBlockingCloseable testProvidedResource;
+	private TestBlockingCloseable testBlocker;
+	private TestAsyncSnapshotCallable testAsyncSnapshotCallable;
+	private FutureTask<String> task;
+
+	@Before
+	public void setup() throws IOException {
+		ownerRegistry = new CloseableRegistry();
+		testProvidedResource = new TestBlockingCloseable();
+		testBlocker = new TestBlockingCloseable();
+		testAsyncSnapshotCallable = new TestAsyncSnapshotCallable(testProvidedResource, testBlocker);
+		task = testAsyncSnapshotCallable.toAsyncSnapshotFutureTask(ownerRegistry);
+		Assert.assertEquals(1, ownerRegistry.getNumberOfRegisteredCloseables());
+	}
+
+	@After
+	public void finalChecks() {
+		Assert.assertTrue(testProvidedResource.isClosed());
+		Assert.assertEquals(0, ownerRegistry.getNumberOfRegisteredCloseables());
+	}
+
+	@Test
+	public void testNormalRun() throws Exception {
+
+		Thread runner = startTask(task);
+
+		while (testBlocker.getWaitersCount() < 1) {
+			Thread.sleep(1L);
+		}
+
+		testBlocker.unblockSuccessfully();
+
+		runner.join();
+
+		Assert.assertEquals(SUCCESS, task.get());
+		Assert.assertEquals(
+			Arrays.asList(METHOD_CALL, METHOD_LOG, METHOD_CLEANUP),
+			testAsyncSnapshotCallable.getInvocationOrder());
+
+		Assert.assertTrue(testBlocker.isClosed());
+	}
+
+	@Test
+	public void testExceptionRun() throws Exception {
+
+		testBlocker.introduceException();
+		Thread runner = startTask(task);
+
+		while (testBlocker.getWaitersCount() < 1) {
+			Thread.sleep(1L);
+		}
+
+		testBlocker.unblockSuccessfully();
+		try {
+			task.get();
+			Assert.fail();
+		} catch (ExecutionException ee) {
+			Assert.assertEquals(IOException.class, ee.getCause().getClass());
+		}
+
+		runner.join();
+
+		Assert.assertEquals(
+			Arrays.asList(METHOD_CALL, METHOD_CLEANUP),
+			testAsyncSnapshotCallable.getInvocationOrder());
+
+		Assert.assertTrue(testBlocker.isClosed());
+	}
+
+	@Test
+	public void testCancelRun() throws Exception {
+
+		Thread runner = startTask(task);
+
+		while (testBlocker.getWaitersCount() < 1) {
+			Thread.sleep(1L);
+		}
+
+		task.cancel(true);
+		testBlocker.unblockExceptionally();
+
+		try {
+			task.get();
+			Assert.fail();
+		} catch (CancellationException ignored) {
+		}
+
+		runner.join();
+
+		Assert.assertEquals(
+			Arrays.asList(METHOD_CALL, METHOD_CANCEL, METHOD_CLEANUP),
+			testAsyncSnapshotCallable.getInvocationOrder());
+		Assert.assertTrue(testProvidedResource.isClosed());
+		Assert.assertTrue(testBlocker.isClosed());
+	}
+
+	@Test
+	public void testCloseRun() throws Exception {
+
+		Thread runner = startTask(task);
+
+		while (testBlocker.getWaitersCount() < 1) {
+			Thread.sleep(1L);
+		}
+
+		ownerRegistry.close();
+
+		try {
+			task.get();
+			Assert.fail();
+		} catch (CancellationException ignored) {
+		}
+
+		runner.join();
+
+		Assert.assertEquals(
+			Arrays.asList(METHOD_CALL, METHOD_CANCEL, METHOD_CLEANUP),
+			testAsyncSnapshotCallable.getInvocationOrder());
+		Assert.assertTrue(testBlocker.isClosed());
+	}
+
+	@Test
+	public void testCancelBeforeRun() throws Exception {
+
+		task.cancel(true);
+
+		Thread runner = startTask(task);
+
+		try {
+			task.get();
+			Assert.fail();
+		} catch (CancellationException ignored) {
+		}
+
+		runner.join();
+
+		Assert.assertEquals(
+			Arrays.asList(METHOD_CANCEL, METHOD_CLEANUP),
+			testAsyncSnapshotCallable.getInvocationOrder());
+
+		Assert.assertTrue(testProvidedResource.isClosed());
+	}
+
+	private Thread startTask(Runnable task)  {
+		Thread runner = new Thread(task);
+		runner.start();
+		return runner;
+	}
+
+	/**
+	 * Test implementation of {@link AsyncSnapshotCallable}.
+	 */
+	private static class TestAsyncSnapshotCallable extends AsyncSnapshotCallable<String> {
+
+		@Nonnull
+		private final TestBlockingCloseable providedResource;
+		@Nonnull
+		private final TestBlockingCloseable blockingResource;
+		@Nonnull
+		private final List<String> invocationOrder;
+
+		TestAsyncSnapshotCallable(
+			@Nonnull TestBlockingCloseable providedResource,
+			@Nonnull TestBlockingCloseable blockingResource) {
+
+			this.providedResource = providedResource;
+			this.blockingResource = blockingResource;
+			this.invocationOrder = new ArrayList<>();
+		}
+
+		@Override
+		protected String callInternal() throws Exception {
+
+			addInvocation(METHOD_CALL);
+			registerCloseableForCancellation(blockingResource);
+			try {
+				blockingResource.simulateBlockingOperation();
+			} finally {
+				if (unregisterCloseableFromCancellation(blockingResource)) {
+					blockingResource.close();
+				}
+			}
+
+			return SUCCESS;
+		}
+
+		@Override
+		protected void cleanupProvidedResources() {
+			addInvocation(METHOD_CLEANUP);
+			providedResource.close();
+		}
+
+		@Override
+		protected void logAsyncSnapshotComplete(long startTime) {
+			addInvocation(METHOD_LOG); // record via the synchronized helper, like the other hooks
+		}
+
+		@Override
+		protected void cancel() {
+			addInvocation(METHOD_CANCEL);
+			super.cancel();
+		}
+
+		@Nonnull
+		public List<String> getInvocationOrder() {
+			synchronized (invocationOrder) {
+				return new ArrayList<>(invocationOrder);
+			}
+		}
+
+		private void addInvocation(@Nonnull String invocation) {
+			synchronized (invocationOrder) {
+				invocationOrder.add(invocation);
+			}
+		}
+	}
+
+	/**
+	 * Mix of a {@link Closeable} and some {@link OneShotLatch} functionality for testing.
+	 */
+	private static class TestBlockingCloseable implements Closeable {
+
+		private final OneShotLatch blockerLatch = new OneShotLatch();
+		private boolean closed = false;
+		private boolean unblocked = false;
+		private boolean exceptionally = false;
+
+		public void simulateBlockingOperation() throws IOException {
+			while (!unblocked) {
+				try {
+					blockerLatch.await();
+				} catch (InterruptedException e) {
+					blockerLatch.reset();
+				}
+			}
+			if (exceptionally) {
+				throw new IOException("Closed in block");
+			}
+		}
+
+		@Override
+		public void close() {
+			Preconditions.checkState(!closed);
+			this.closed = true;
+			unblockExceptionally();
+		}
+
+		public boolean isClosed() {
+			return closed;
+		}
+
+		public void unblockExceptionally() {
+			introduceException();
+			unblock();
+		}
+
+		public void unblockSuccessfully() {
+			unblock();
+		}
+
+		private void unblock() {
+			this.unblocked = true;
+			blockerLatch.trigger();
+		}
+
+		public void introduceException() {
+			this.exceptionally = true;
+		}
+
+		public int getWaitersCount() {
+			return blockerLatch.getWaitersCount();
+		}
+	}
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
index d8918e7..b5988f3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -55,7 +55,6 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.concurrent.CancellationException;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.FutureTask;
@@ -790,8 +789,7 @@ public class OperatorStateBackendTest {
 		try {
 			runnableFuture.get(60, TimeUnit.SECONDS);
 			Assert.fail();
-		} catch (ExecutionException eex) {
-			Assert.assertTrue(eex.getCause() instanceof IOException);
+		} catch (CancellationException expected) {
 		}
 	}
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 059a706..649c6d0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -110,9 +110,9 @@ import java.util.PrimitiveIterator;
 import java.util.Random;
 import java.util.Timer;
 import java.util.TimerTask;
+import java.util.concurrent.CancellationException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
@@ -189,7 +189,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				numberOfKeyGroups,
 				keyGroupRange,
 				env.getTaskKvStateRegistry(),
-			    TtlTimeProvider.DEFAULT);
+				TtlTimeProvider.DEFAULT);
 
 		backend.restore(null);
 
@@ -4015,7 +4015,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			try {
 				snapshot.get();
 				fail("Close was not propagated.");
-			} catch (ExecutionException ex) {
+			} catch (CancellationException ex) {
 				//ignore
 			}
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index 0b5931c..ccfafec 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -170,12 +170,13 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			.map(Map.Entry::getKey);
 	}
 
+	@Nonnull
 	@Override
 	public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
 		long checkpointId,
 		long timestamp,
-		CheckpointStreamFactory streamFactory,
-		CheckpointOptions checkpointOptions) {
+		@Nonnull CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) {
 		return new FutureTask<>(() ->
 			SnapshotResult.of(new MockKeyedStateHandle<>(copy(stateValues, stateSnapshotFilters))));
 	}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 87c7e55..60baaed 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -35,6 +35,7 @@ import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerial
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.contrib.streaming.state.iterator.RocksStateKeysIterator;
+import org.apache.flink.contrib.streaming.state.snapshot.RocksDBSnapshotStrategyBase;
 import org.apache.flink.contrib.streaming.state.snapshot.RocksFullSnapshotStrategy;
 import org.apache.flink.contrib.streaming.state.snapshot.RocksIncrementalSnapshotStrategy;
 import org.apache.flink.core.fs.FSDataInputStream;
@@ -46,8 +47,8 @@ import org.apache.flink.core.memory.ByteArrayDataInputView;
 import org.apache.flink.core.memory.ByteArrayDataOutputView;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
@@ -70,7 +71,6 @@ import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInf
 import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotResult;
-import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateHandleID;
 import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
@@ -207,7 +207,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	private final WriteOptions writeOptions;
 
 	/**
-	 * Information about the k/v states as we create them. This is used to retrieve the
+	 * Information about the k/v states, maintained in the order in which we create them. This is used to retrieve the
 	 * column family that is used for a state and also for sanity checks when restoring.
 	 */
 	private final LinkedHashMap<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> kvStateInformation;
@@ -229,8 +229,11 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	/** The configuration of local recovery. */
 	private final LocalRecoveryConfig localRecoveryConfig;
 
-	/** The snapshot strategy, e.g., if we use full or incremental checkpoints, local state, and so on. */
-	private SnapshotStrategy<SnapshotResult<KeyedStateHandle>> snapshotStrategy;
+	/** The checkpoint snapshot strategy, e.g., if we use full or incremental checkpoints, local state, and so on. */
+	private RocksDBSnapshotStrategyBase<K> checkpointSnapshotStrategy;
+
+	/** The savepoint snapshot strategy. */
+	private RocksDBSnapshotStrategyBase<K> savepointSnapshotStrategy;
 
 	/** Factory for priority queue state. */
 	private final PriorityQueueSetFactory priorityQueueFactory;
@@ -444,17 +447,29 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * @return Future to the state handle of the snapshot data.
 	 * @throws Exception indicating a problem in the synchronous part of the checkpoint.
 	 */
+	@Nonnull
 	@Override
 	public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
 		final long checkpointId,
 		final long timestamp,
-		final CheckpointStreamFactory streamFactory,
-		CheckpointOptions checkpointOptions) throws Exception {
+		@Nonnull final CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception {
+
+		long startTime = System.currentTimeMillis();
 
 		// flush everything into db before taking a snapshot
 		writeBatchWrapper.flush();
 
-		return snapshotStrategy.performSnapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+		RocksDBSnapshotStrategyBase<K> chosenSnapshotStrategy =
+			CheckpointType.SAVEPOINT == checkpointOptions.getCheckpointType() ?
+				savepointSnapshotStrategy : checkpointSnapshotStrategy;
+
+		RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshotRunner =
+			chosenSnapshotStrategy.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+
+		chosenSnapshotStrategy.logSyncCompleted(streamFactory, startTime);
+
+		return snapshotRunner;
 	}
 
 	@Override
@@ -497,7 +512,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	void initializeSnapshotStrategy(
 		@Nullable RocksDBIncrementalRestoreOperation<K> incrementalRestoreOperation) {
 
-		final RocksFullSnapshotStrategy<K> fullSnapshotStrategy =
+		this.savepointSnapshotStrategy =
 			new RocksFullSnapshotStrategy<>(
 				db,
 				rocksDBResourceGuard,
@@ -525,7 +540,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				Preconditions.checkState(lastCompletedCheckpointId >= 0L);
 			}
 			// TODO eventually we might want to separate savepoint and snapshot strategy, i.e. having 2 strategies.
-			this.snapshotStrategy = new RocksIncrementalSnapshotStrategy<>(
+			this.checkpointSnapshotStrategy = new RocksIncrementalSnapshotStrategy<>(
 				db,
 				rocksDBResourceGuard,
 				keySerializer,
@@ -537,17 +552,21 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				instanceBasePath,
 				backendUID,
 				materializedSstFiles,
-				lastCompletedCheckpointId,
-				fullSnapshotStrategy);
+				lastCompletedCheckpointId);
 		} else {
-			this.snapshotStrategy = fullSnapshotStrategy;
+			this.checkpointSnapshotStrategy = savepointSnapshotStrategy;
 		}
 	}
 
 	@Override
 	public void notifyCheckpointComplete(long completedCheckpointId) throws Exception {
-		if (snapshotStrategy != null) {
-			snapshotStrategy.notifyCheckpointComplete(completedCheckpointId);
+
+		if (checkpointSnapshotStrategy != null) {
+			checkpointSnapshotStrategy.notifyCheckpointComplete(completedCheckpointId);
+		}
+
+		if (savepointSnapshotStrategy != null) {
+			savepointSnapshotStrategy.notifyCheckpointComplete(completedCheckpointId);
 		}
 	}
 
@@ -966,9 +985,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			@Nonnull
 			private final List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
 
-			private
-
-			RestoredDBInstance(
+			private RestoredDBInstance(
 				@Nonnull RocksDB db,
 				@Nonnull List<ColumnFamilyHandle> columnFamilyHandles,
 				@Nonnull List<ColumnFamilyDescriptor> columnFamilyDescriptors,
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/SnapshotStrategyBase.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksDBSnapshotStrategyBase.java
similarity index 57%
rename from flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/SnapshotStrategyBase.java
rename to flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksDBSnapshotStrategyBase.java
index efebe8c..fffd98d 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/SnapshotStrategyBase.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksDBSnapshotStrategyBase.java
@@ -21,6 +21,11 @@ package org.apache.flink.contrib.streaming.state.snapshot;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.AbstractSnapshotStrategy;
+import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
@@ -31,44 +36,60 @@ import org.apache.flink.util.ResourceGuard;
 
 import org.rocksdb.ColumnFamilyHandle;
 import org.rocksdb.RocksDB;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
 
 import java.util.LinkedHashMap;
+import java.util.concurrent.RunnableFuture;
 
 /**
- * Base class for {@link SnapshotStrategy} implementations on RocksDB.
+ * Abstract base class for {@link SnapshotStrategy} implementations for the RocksDB state backend.
  *
  * @param <K> type of the backend keys.
  */
-public abstract class SnapshotStrategyBase<K> implements SnapshotStrategy<SnapshotResult<KeyedStateHandle>> {
+public abstract class RocksDBSnapshotStrategyBase<K>
+	extends AbstractSnapshotStrategy<KeyedStateHandle>
+	implements CheckpointListener {
 
+	private static final Logger LOG = LoggerFactory.getLogger(RocksDBSnapshotStrategyBase.class);
+
+	/** RocksDB instance from the backend. */
 	@Nonnull
 	protected final RocksDB db;
 
+	/** Resource guard for the RocksDB instance. */
 	@Nonnull
 	protected final ResourceGuard rocksDBResourceGuard;
 
+	/** The key serializer of the backend. */
 	@Nonnull
 	protected final TypeSerializer<K> keySerializer;
 
+	/** Key/Value state meta info from the backend. */
 	@Nonnull
 	protected final LinkedHashMap<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> kvStateInformation;
 
+	/** The key-group range for the task. */
 	@Nonnull
 	protected final KeyGroupRange keyGroupRange;
 
+	/** Number of bytes in the key-group prefix. */
 	@Nonnegative
 	protected final int keyGroupPrefixBytes;
 
+	/** The configuration for local recovery. */
 	@Nonnull
 	protected final LocalRecoveryConfig localRecoveryConfig;
 
+	/** A {@link CloseableRegistry} that will be closed when the task is cancelled. */
 	@Nonnull
 	protected final CloseableRegistry cancelStreamRegistry;
 
-	public SnapshotStrategyBase(
+	public RocksDBSnapshotStrategyBase(
+		@Nonnull String description,
 		@Nonnull RocksDB db,
 		@Nonnull ResourceGuard rocksDBResourceGuard,
 		@Nonnull TypeSerializer<K> keySerializer,
@@ -78,6 +99,7 @@ public abstract class SnapshotStrategyBase<K> implements SnapshotStrategy<Snapsh
 		@Nonnull LocalRecoveryConfig localRecoveryConfig,
 		@Nonnull CloseableRegistry cancelStreamRegistry) {
 
+		super(description);
 		this.db = db;
 		this.rocksDBResourceGuard = rocksDBResourceGuard;
 		this.keySerializer = keySerializer;
@@ -87,4 +109,33 @@ public abstract class SnapshotStrategyBase<K> implements SnapshotStrategy<Snapsh
 		this.localRecoveryConfig = localRecoveryConfig;
 		this.cancelStreamRegistry = cancelStreamRegistry;
 	}
+
+	/** Returns a completed empty result if no state is registered, else delegates to {@code doSnapshot}. */
+	@Nonnull
+	@Override
+	public final RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
+		long checkpointId,
+		long timestamp,
+		@Nonnull CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception {
+
+		if (kvStateInformation.isEmpty()) {
+			// Nothing to write: short-cut with an already completed future (never null).
+			LOG.debug("Asynchronous RocksDB snapshot performed on empty keyed state at {}. " +
+				"Returning empty snapshot result.", timestamp);
+			return DoneFuture.of(SnapshotResult.empty());
+		} else {
+			return doSnapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+		}
+	}
+
+	/**
+	 * This method implements the concrete snapshot logic for a non-empty state.
+	 */
+	@Nonnull
+	protected abstract RunnableFuture<SnapshotResult<KeyedStateHandle>> doSnapshot(
+		long checkpointId,
+		long timestamp,
+		CheckpointStreamFactory streamFactory,
+		CheckpointOptions checkpointOptions) throws Exception;
 }
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
index 0cc9729..0aa091e 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
@@ -30,10 +30,10 @@ import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.state.AsyncSnapshotCallable;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
 import org.apache.flink.runtime.state.CheckpointedStateScope;
-import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
@@ -55,8 +55,6 @@ import org.rocksdb.ReadOptions;
 import org.rocksdb.RocksDB;
 import org.rocksdb.RocksIterator;
 import org.rocksdb.Snapshot;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
@@ -67,11 +65,7 @@ import java.util.ArrayList;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Objects;
-import java.util.concurrent.Callable;
-import java.util.concurrent.CancellationException;
-import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
-import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.END_OF_KEY_GROUP_MARK;
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.hasMetaDataFollowsFlag;
@@ -84,9 +78,9 @@ import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUti
  *
  * @param <K> type of the backend keys.
  */
-public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
+public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K> {
 
-	private static final Logger LOG = LoggerFactory.getLogger(RocksFullSnapshotStrategy.class);
+	private static final String DESCRIPTION = "Asynchronous full RocksDB snapshot";
 
 	/** This decorator is used to apply compression per key-group for the written snapshot data. */
 	@Nonnull
@@ -103,6 +97,7 @@ public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
 		@Nonnull CloseableRegistry cancelStreamRegistry,
 		@Nonnull StreamCompressionDecorator keyGroupCompressionDecorator) {
 		super(
+			DESCRIPTION,
 			db,
 			rocksDBResourceGuard,
 			keySerializer,
@@ -115,45 +110,40 @@ public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
 		this.keyGroupCompressionDecorator = keyGroupCompressionDecorator;
 	}
 
+	@Nonnull
 	@Override
-	public RunnableFuture<SnapshotResult<KeyedStateHandle>> performSnapshot(
+	public RunnableFuture<SnapshotResult<KeyedStateHandle>> doSnapshot(
 		long checkpointId,
 		long timestamp,
-		CheckpointStreamFactory primaryStreamFactory,
-		CheckpointOptions checkpointOptions) throws Exception {
+		@Nonnull CheckpointStreamFactory primaryStreamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception {
 
-		long startTime = System.currentTimeMillis();
+		final SupplierWithException<CheckpointStreamWithResultProvider, Exception> checkpointStreamSupplier =
+			createCheckpointStreamSupplier(checkpointId, primaryStreamFactory, checkpointOptions);
 
-		if (kvStateInformation.isEmpty()) {
-			if (LOG.isDebugEnabled()) {
-				LOG.debug("Asynchronous RocksDB snapshot performed on empty keyed state at {}. Returning null.",
-					timestamp);
-			}
+		final List<StateMetaInfoSnapshot> stateMetaInfoSnapshots = new ArrayList<>(kvStateInformation.size());
+		final List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> metaDataCopy =
+			new ArrayList<>(kvStateInformation.size());
 
-			return DoneFuture.of(SnapshotResult.empty());
+		for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 : kvStateInformation.values()) {
+			// snapshot meta info
+			stateMetaInfoSnapshots.add(tuple2.f1.snapshot());
+			metaDataCopy.add(tuple2);
 		}
 
-		final SupplierWithException<CheckpointStreamWithResultProvider, Exception> supplier =
-
-			localRecoveryConfig.isLocalRecoveryEnabled() &&
-				(CheckpointType.SAVEPOINT != checkpointOptions.getCheckpointType()) ?
-
-				() -> CheckpointStreamWithResultProvider.createDuplicatingStream(
-					checkpointId,
-					CheckpointedStateScope.EXCLUSIVE,
-					primaryStreamFactory,
-					localRecoveryConfig.getLocalStateDirectoryProvider()) :
+		final ResourceGuard.Lease lease = rocksDBResourceGuard.acquireResource();
+		final Snapshot snapshot = db.getSnapshot();
 
-				() -> CheckpointStreamWithResultProvider.createSimpleStream(
-					CheckpointedStateScope.EXCLUSIVE,
-					primaryStreamFactory);
+		final SnapshotAsynchronousPartCallable asyncSnapshotCallable =
+			new SnapshotAsynchronousPartCallable(
+				checkpointStreamSupplier,
+				lease,
+				snapshot,
+				stateMetaInfoSnapshots,
+				metaDataCopy,
+				primaryStreamFactory.toString());
 
-		final CloseableRegistry snapshotCloseableRegistry = new CloseableRegistry();
-
-		final RocksDBFullSnapshotCallable snapshotOperation =
-			new RocksDBFullSnapshotCallable(supplier, snapshotCloseableRegistry);
-
-		return new SnapshotTask(snapshotOperation);
+		return asyncSnapshotCallable.toAsyncSnapshotFutureTask(cancelStreamRegistry);
 	}
 
 	@Override
@@ -161,160 +151,124 @@ public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
 		// nothing to do.
 	}
 
-	/**
-	 * Wrapping task to run a {@link RocksDBFullSnapshotCallable} and delegate cancellation.
-	 */
-	private class SnapshotTask extends FutureTask<SnapshotResult<KeyedStateHandle>> {
+	private SupplierWithException<CheckpointStreamWithResultProvider, Exception> createCheckpointStreamSupplier(
+		long checkpointId,
+		CheckpointStreamFactory primaryStreamFactory,
+		CheckpointOptions checkpointOptions) {
 
-		/** Reference to the callable for cancellation. */
-		@Nonnull
-		private final AutoCloseable callableClose;
+		return localRecoveryConfig.isLocalRecoveryEnabled() &&
+			(CheckpointType.SAVEPOINT != checkpointOptions.getCheckpointType()) ?
 
-		SnapshotTask(@Nonnull RocksDBFullSnapshotCallable callable) {
-			super(callable);
-			this.callableClose = callable;
-		}
+			() -> CheckpointStreamWithResultProvider.createDuplicatingStream(
+				checkpointId,
+				CheckpointedStateScope.EXCLUSIVE,
+				primaryStreamFactory,
+				localRecoveryConfig.getLocalStateDirectoryProvider()) :
 
-		@Override
-		public boolean cancel(boolean mayInterruptIfRunning) {
-			IOUtils.closeQuietly(callableClose);
-			return super.cancel(mayInterruptIfRunning);
-		}
+			() -> CheckpointStreamWithResultProvider.createSimpleStream(
+				CheckpointedStateScope.EXCLUSIVE,
+				primaryStreamFactory);
 	}
 
 	/**
 	 * Encapsulates the process to perform a full snapshot of a RocksDBKeyedStateBackend.
 	 */
 	@VisibleForTesting
-	private class RocksDBFullSnapshotCallable implements Callable<SnapshotResult<KeyedStateHandle>>, AutoCloseable {
-
-		@Nonnull
-		private final KeyGroupRangeOffsets keyGroupRangeOffsets;
+	private class SnapshotAsynchronousPartCallable extends AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>> {
 
+		/** Supplier for the stream into which we write the snapshot. */
 		@Nonnull
 		private final SupplierWithException<CheckpointStreamWithResultProvider, Exception> checkpointStreamSupplier;
 
-		@Nonnull
-		private final CloseableRegistry snapshotCloseableRegistry;
-
+		/** This lease protects the native RocksDB resources. */
 		@Nonnull
 		private final ResourceGuard.Lease dbLease;
 
+		/** RocksDB snapshot. */
 		@Nonnull
 		private final Snapshot snapshot;
 
 		@Nonnull
-		private final ReadOptions readOptions;
-
-		/**
-		 * The state meta data.
-		 */
-		@Nonnull
 		private List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
 
-		/**
-		 * The copied column handle.
-		 */
 		@Nonnull
 		private List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> metaDataCopy;
 
-		private final AtomicBoolean ownedForCleanup;
+		@Nonnull
+		private final String logPathString;
 
-		RocksDBFullSnapshotCallable(
+		SnapshotAsynchronousPartCallable(
 			@Nonnull SupplierWithException<CheckpointStreamWithResultProvider, Exception> checkpointStreamSupplier,
-			@Nonnull CloseableRegistry registry) throws IOException {
+			@Nonnull ResourceGuard.Lease dbLease,
+			@Nonnull Snapshot snapshot,
+			@Nonnull List<StateMetaInfoSnapshot> stateMetaInfoSnapshots,
+			@Nonnull List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> metaDataCopy,
+			@Nonnull String logPathString) {
 
-			this.ownedForCleanup = new AtomicBoolean(false);
 			this.checkpointStreamSupplier = checkpointStreamSupplier;
-			this.keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange);
-			this.snapshotCloseableRegistry = registry;
-
-			this.stateMetaInfoSnapshots = new ArrayList<>(kvStateInformation.size());
-			this.metaDataCopy = new ArrayList<>(kvStateInformation.size());
-			for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 : kvStateInformation.values()) {
-				// snapshot meta info
-				this.stateMetaInfoSnapshots.add(tuple2.f1.snapshot());
-				this.metaDataCopy.add(tuple2);
-			}
-
-			this.dbLease = rocksDBResourceGuard.acquireResource();
-
-			this.readOptions = new ReadOptions();
-			this.snapshot = db.getSnapshot();
-			this.readOptions.setSnapshot(snapshot);
+			this.dbLease = dbLease;
+			this.snapshot = snapshot;
+			this.stateMetaInfoSnapshots = stateMetaInfoSnapshots;
+			this.metaDataCopy = metaDataCopy;
+			this.logPathString = logPathString;
 		}
 
 		@Override
-		public SnapshotResult<KeyedStateHandle> call() throws Exception {
-
-			if (!ownedForCleanup.compareAndSet(false, true)) {
-				throw new CancellationException("Snapshot task was already cancelled, stopping execution.");
-			}
+		protected SnapshotResult<KeyedStateHandle> callInternal() throws Exception {
+			final KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange);
+			final CheckpointStreamWithResultProvider checkpointStreamWithResultProvider =
+				checkpointStreamSupplier.get();
 
-			final long startTime = System.currentTimeMillis();
-			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators = new ArrayList<>(metaDataCopy.size());
+			registerCloseableForCancellation(checkpointStreamWithResultProvider);
+			writeSnapshotToOutputStream(checkpointStreamWithResultProvider, keyGroupRangeOffsets);
 
-			try {
-
-				cancelStreamRegistry.registerCloseable(snapshotCloseableRegistry);
-
-				final CheckpointStreamWithResultProvider checkpointStreamWithResultProvider = checkpointStreamSupplier.get();
-				snapshotCloseableRegistry.registerCloseable(checkpointStreamWithResultProvider);
-
-				final DataOutputView outputView =
-					new DataOutputViewStreamWrapper(checkpointStreamWithResultProvider.getCheckpointOutputStream());
-
-				writeKVStateMetaData(kvStateIterators, outputView);
-				writeKVStateData(kvStateIterators, checkpointStreamWithResultProvider);
+			if (unregisterCloseableFromCancellation(checkpointStreamWithResultProvider)) {
+				return CheckpointStreamWithResultProvider.toKeyedStateHandleSnapshotResult(
+					checkpointStreamWithResultProvider.closeAndFinalizeCheckpointStreamResult(),
+					keyGroupRangeOffsets);
+			} else {
+				throw new IOException("Stream is already unregistered/closed.");
+			}
+		}
 
-				final SnapshotResult<KeyedStateHandle> snapshotResult =
-					createStateHandlesFromStreamProvider(checkpointStreamWithResultProvider);
+		@Override
+		protected void cleanupProvidedResources() {
+			db.releaseSnapshot(snapshot);
+			IOUtils.closeQuietly(snapshot);
+			IOUtils.closeQuietly(dbLease);
+		}
 
-				LOG.info("Asynchronous RocksDB snapshot ({}, asynchronous part) in thread {} took {} ms.",
-					checkpointStreamSupplier, Thread.currentThread(), (System.currentTimeMillis() - startTime));
+		@Override
+		protected void logAsyncSnapshotComplete(long startTime) {
+			logAsyncCompleted(logPathString, startTime);
+		}
 
-				return snapshotResult;
+		private void writeSnapshotToOutputStream(
+			@Nonnull CheckpointStreamWithResultProvider checkpointStreamWithResultProvider,
+			@Nonnull KeyGroupRangeOffsets keyGroupRangeOffsets) throws IOException, InterruptedException {
 
+			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators =
+				new ArrayList<>(metaDataCopy.size());
+			final DataOutputView outputView =
+				new DataOutputViewStreamWrapper(checkpointStreamWithResultProvider.getCheckpointOutputStream());
+			final ReadOptions readOptions = new ReadOptions();
+			try {
+				readOptions.setSnapshot(snapshot);
+				writeKVStateMetaData(kvStateIterators, readOptions, outputView);
+				writeKVStateData(kvStateIterators, checkpointStreamWithResultProvider, keyGroupRangeOffsets);
 			} finally {
 
 				for (Tuple2<RocksIteratorWrapper, Integer> kvStateIterator : kvStateIterators) {
 					IOUtils.closeQuietly(kvStateIterator.f0);
 				}
 
-				cleanupSynchronousStepResources();
-			}
-		}
-
-		private void cleanupSynchronousStepResources() {
-			IOUtils.closeQuietly(readOptions);
-
-			db.releaseSnapshot(snapshot);
-			IOUtils.closeQuietly(snapshot);
-
-			IOUtils.closeQuietly(dbLease);
-
-			if (cancelStreamRegistry.unregisterCloseable(snapshotCloseableRegistry)) {
-				try {
-					snapshotCloseableRegistry.close();
-				} catch (Exception ex) {
-					LOG.warn("Error closing local registry", ex);
-				}
-			}
-		}
-
-		private SnapshotResult<KeyedStateHandle> createStateHandlesFromStreamProvider(
-			CheckpointStreamWithResultProvider checkpointStreamWithResultProvider) throws IOException {
-			if (snapshotCloseableRegistry.unregisterCloseable(checkpointStreamWithResultProvider)) {
-				return CheckpointStreamWithResultProvider.toKeyedStateHandleSnapshotResult(
-					checkpointStreamWithResultProvider.closeAndFinalizeCheckpointStreamResult(),
-					keyGroupRangeOffsets);
-			} else {
-				throw new IOException("Snapshot was already closed before completion.");
+				IOUtils.closeQuietly(readOptions);
 			}
 		}
 
 		private void writeKVStateMetaData(
 			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators,
+			final ReadOptions readOptions,
 			final DataOutputView outputView) throws IOException {
 
 			int kvStateId = 0;
@@ -343,7 +297,8 @@ public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
 
 		private void writeKVStateData(
 			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators,
-			final CheckpointStreamWithResultProvider checkpointStreamWithResultProvider) throws IOException, InterruptedException {
+			final CheckpointStreamWithResultProvider checkpointStreamWithResultProvider,
+			final KeyGroupRangeOffsets keyGroupRangeOffsets) throws IOException, InterruptedException {
 
 			byte[] previousKey = null;
 			byte[] previousValue = null;
@@ -445,18 +400,6 @@ public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
 				throw new InterruptedException("RocksDB snapshot interrupted.");
 			}
 		}
-
-		@Override
-		public void close() throws Exception {
-
-			if (ownedForCleanup.compareAndSet(false, true)) {
-				cleanupSynchronousStepResources();
-			}
-
-			if (cancelStreamRegistry.unregisterCloseable(snapshotCloseableRegistry)) {
-				snapshotCloseableRegistry.close();
-			}
-		}
 	}
 
 	@SuppressWarnings("unchecked")
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
index 3487fe6..8117031 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
@@ -28,12 +28,11 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.state.AsyncSnapshotCallable;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
 import org.apache.flink.runtime.state.CheckpointedStateScope;
 import org.apache.flink.runtime.state.DirectoryStateHandle;
-import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.IncrementalKeyedStateHandle;
 import org.apache.flink.runtime.state.IncrementalLocalKeyedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
@@ -45,7 +44,6 @@ import org.apache.flink.runtime.state.PlaceholderStreamStateHandle;
 import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
 import org.apache.flink.runtime.state.SnapshotDirectory;
 import org.apache.flink.runtime.state.SnapshotResult;
-import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateHandleID;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
@@ -65,11 +63,11 @@ import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -77,7 +75,6 @@ import java.util.Map;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.UUID;
-import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.SST_FILE_SUFFIX;
@@ -88,10 +85,12 @@ import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUti
  *
  * @param <K> type of the backend keys.
  */
-public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
+public class RocksIncrementalSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K> {
 
 	private static final Logger LOG = LoggerFactory.getLogger(RocksIncrementalSnapshotStrategy.class);
 
+	private static final String DESCRIPTION = "Asynchronous incremental RocksDB snapshot";
+
 	/** Base path of the RocksDB instance. */
 	@Nonnull
 	private final File instanceBasePath;
@@ -107,10 +106,6 @@ public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K>
 	/** The identifier of the last completed checkpoint. */
 	private long lastCompletedCheckpointId;
 
-	/** We delegate snapshots that are for savepoints to this. */
-	@Nonnull
-	private final SnapshotStrategy<SnapshotResult<KeyedStateHandle>> savepointDelegate;
-
 	public RocksIncrementalSnapshotStrategy(
 		@Nonnull RocksDB db,
 		@Nonnull ResourceGuard rocksDBResourceGuard,
@@ -123,10 +118,10 @@ public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K>
 		@Nonnull File instanceBasePath,
 		@Nonnull UUID backendUID,
 		@Nonnull SortedMap<Long, Set<StateHandleID>> materializedSstFiles,
-		long lastCompletedCheckpointId,
-		@Nonnull SnapshotStrategy<SnapshotResult<KeyedStateHandle>> savepointDelegate) {
+		long lastCompletedCheckpointId) {
 
 		super(
+			DESCRIPTION,
 			db,
 			rocksDBResourceGuard,
 			keySerializer,
@@ -140,33 +135,47 @@ public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K>
 		this.backendUID = backendUID;
 		this.materializedSstFiles = materializedSstFiles;
 		this.lastCompletedCheckpointId = lastCompletedCheckpointId;
-		this.savepointDelegate = savepointDelegate;
 	}
 
+	@Nonnull
 	@Override
-	public RunnableFuture<SnapshotResult<KeyedStateHandle>> performSnapshot(
+	protected RunnableFuture<SnapshotResult<KeyedStateHandle>> doSnapshot(
 		long checkpointId,
 		long checkpointTimestamp,
-		CheckpointStreamFactory checkpointStreamFactory,
-		CheckpointOptions checkpointOptions) throws Exception {
+		@Nonnull CheckpointStreamFactory checkpointStreamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception {
 
-		// for savepoints, we delegate to the full snapshot strategy because savepoints are always self-contained.
-		if (CheckpointType.SAVEPOINT == checkpointOptions.getCheckpointType()) {
-			return savepointDelegate.performSnapshot(
+		final SnapshotDirectory snapshotDirectory = prepareLocalSnapshotDirectory(checkpointId);
+		LOG.trace("Local RocksDB checkpoint goes to backup path {}.", snapshotDirectory);
+
+		final List<StateMetaInfoSnapshot> stateMetaInfoSnapshots = new ArrayList<>(kvStateInformation.size());
+		final Set<StateHandleID> baseSstFiles = snapshotMetaData(checkpointId, stateMetaInfoSnapshots);
+
+		takeDBNativeCheckpoint(snapshotDirectory);
+
+		final RocksDBIncrementalSnapshotOperation snapshotOperation =
+			new RocksDBIncrementalSnapshotOperation(
 				checkpointId,
-				checkpointTimestamp,
 				checkpointStreamFactory,
-				checkpointOptions);
-		}
+				snapshotDirectory,
+				baseSstFiles,
+				stateMetaInfoSnapshots);
 
-		if (kvStateInformation.isEmpty()) {
-			if (LOG.isDebugEnabled()) {
-				LOG.debug("Asynchronous RocksDB snapshot performed on empty keyed state at {}. Returning null.", checkpointTimestamp);
+		return snapshotOperation.toAsyncSnapshotFutureTask(cancelStreamRegistry);
+	}
+
+	@Override
+	public void notifyCheckpointComplete(long completedCheckpointId) {
+		synchronized (materializedSstFiles) {
+			if (completedCheckpointId > lastCompletedCheckpointId) {
+				materializedSstFiles.keySet().removeIf(checkpointId -> checkpointId < completedCheckpointId);
+				lastCompletedCheckpointId = completedCheckpointId;
 			}
-			return DoneFuture.of(SnapshotResult.empty());
 		}
+	}
 
-		SnapshotDirectory snapshotDirectory;
+	@Nonnull
+	private SnapshotDirectory prepareLocalSnapshotDirectory(long checkpointId) throws IOException {
 
 		if (localRecoveryConfig.isLocalRecoveryEnabled()) {
 			// create a "permanent" snapshot directory for local recovery.
@@ -186,254 +195,217 @@ public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K>
 			File rdbSnapshotDir = new File(directory, "rocks_db");
 			Path path = new Path(rdbSnapshotDir.toURI());
 			// create a "permanent" snapshot directory because local recovery is active.
-			snapshotDirectory = SnapshotDirectory.permanent(path);
+			try {
+				return SnapshotDirectory.permanent(path);
+			} catch (IOException ex) {
+				try {
+					FileUtils.deleteDirectory(directory);
+				} catch (IOException delEx) {
+					ex = ExceptionUtils.firstOrSuppressed(delEx, ex);
+				}
+				throw ex;
+			}
 		} else {
 			// create a "temporary" snapshot directory because local recovery is inactive.
 			Path path = new Path(instanceBasePath.getAbsolutePath(), "chk-" + checkpointId);
-			snapshotDirectory = SnapshotDirectory.temporary(path);
-		}
-
-		final RocksDBIncrementalSnapshotOperation snapshotOperation =
-			new RocksDBIncrementalSnapshotOperation(
-				checkpointStreamFactory,
-				snapshotDirectory,
-				checkpointId);
-
-		try {
-			snapshotOperation.takeSnapshot();
-		} catch (Exception e) {
-			snapshotOperation.stop();
-			snapshotOperation.releaseResources(true);
-			throw e;
+			return SnapshotDirectory.temporary(path);
 		}
+	}
 
-		return new FutureTask<SnapshotResult<KeyedStateHandle>>(
-			snapshotOperation::runSnapshot
-		) {
-			@Override
-			public boolean cancel(boolean mayInterruptIfRunning) {
-				snapshotOperation.stop();
-				return super.cancel(mayInterruptIfRunning);
-			}
+	private Set<StateHandleID> snapshotMetaData(
+		long checkpointId,
+		@Nonnull List<StateMetaInfoSnapshot> stateMetaInfoSnapshots) {
 
-			@Override
-			protected void done() {
-				snapshotOperation.releaseResources(isCancelled());
-			}
-		};
-	}
+		final long lastCompletedCheckpoint;
+		final Set<StateHandleID> baseSstFiles;
 
-	@Override
-	public void notifyCheckpointComplete(long completedCheckpointId) {
+		// use the last completed checkpoint as the comparison base.
 		synchronized (materializedSstFiles) {
+			lastCompletedCheckpoint = lastCompletedCheckpointId;
+			baseSstFiles = materializedSstFiles.get(lastCompletedCheckpoint);
+		}
+		LOG.trace("Taking incremental snapshot for checkpoint {}. Snapshot is based on last completed checkpoint {} " +
+			"assuming the following (shared) files as base: {}.", checkpointId, lastCompletedCheckpoint, baseSstFiles);
 
-			if (completedCheckpointId < lastCompletedCheckpointId) {
-				return;
-			}
-
-			materializedSstFiles.keySet().removeIf(checkpointId -> checkpointId < completedCheckpointId);
+		// snapshot meta data to save
+		for (Map.Entry<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> stateMetaInfoEntry
+			: kvStateInformation.entrySet()) {
+			stateMetaInfoSnapshots.add(stateMetaInfoEntry.getValue().f1.snapshot());
+		}
+		return baseSstFiles;
+	}
 
-			lastCompletedCheckpointId = completedCheckpointId;
+	private void takeDBNativeCheckpoint(@Nonnull SnapshotDirectory outputDirectory) throws Exception {
+		// create hard links of living files in the output path
+		try (
+			ResourceGuard.Lease ignored = rocksDBResourceGuard.acquireResource();
+			Checkpoint checkpoint = Checkpoint.create(db)) {
+			checkpoint.createCheckpoint(outputDirectory.getDirectory().getPath());
+		} catch (Exception ex) {
+			try {
+				outputDirectory.cleanup();
+			} catch (IOException cleanupEx) {
+				ex = ExceptionUtils.firstOrSuppressed(cleanupEx, ex);
+			}
+			throw ex;
 		}
 	}
 
 	/**
 	 * Encapsulates the process to perform an incremental snapshot of a RocksDBKeyedStateBackend.
 	 */
-	private final class RocksDBIncrementalSnapshotOperation {
+	private final class RocksDBIncrementalSnapshotOperation
+		extends AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>> {
 
-		/**
-		 * Stream factory that creates the outpus streams to DFS.
-		 */
-		private final CheckpointStreamFactory checkpointStreamFactory;
+		private static final int READ_BUFFER_SIZE = 16 * 1024;
 
-		/**
-		 * Id for the current checkpoint.
-		 */
+		/** Id for the current checkpoint. */
 		private final long checkpointId;
 
-		/**
-		 * All sst files that were part of the last previously completed checkpoint.
-		 */
-		private Set<StateHandleID> baseSstFiles;
+		/** Stream factory that creates the output streams to DFS. */
+		@Nonnull
+		private final CheckpointStreamFactory checkpointStreamFactory;
 
-		/**
-		 * The state meta data.
-		 */
+		/** The state meta data. */
+		@Nonnull
 		private final List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
 
-		/**
-		 * Local directory for the RocksDB native backup.
-		 */
-		private SnapshotDirectory localBackupDirectory;
-
-		// Registry for all opened i/o streams
-		private final CloseableRegistry closeableRegistry;
-
-		// new sst files since the last completed checkpoint
-		private final Map<StateHandleID, StreamStateHandle> sstFiles;
-
-		// handles to the misc files in the current snapshot
-		private final Map<StateHandleID, StreamStateHandle> miscFiles;
-
-		// This lease protects from concurrent disposal of the native rocksdb instance.
-		private final ResourceGuard.Lease dbLease;
+		/** Local directory for the RocksDB native backup. */
+		@Nonnull
+		private final SnapshotDirectory localBackupDirectory;
 
-		private SnapshotResult<StreamStateHandle> metaStateHandle;
+		/** All sst files that were part of the last previously completed checkpoint. */
+		@Nullable
+		private final Set<StateHandleID> baseSstFiles;
 
 		private RocksDBIncrementalSnapshotOperation(
-			CheckpointStreamFactory checkpointStreamFactory,
-			SnapshotDirectory localBackupDirectory,
-			long checkpointId) throws IOException {
+			long checkpointId,
+			@Nonnull CheckpointStreamFactory checkpointStreamFactory,
+			@Nonnull SnapshotDirectory localBackupDirectory,
+			@Nullable Set<StateHandleID> baseSstFiles,
+			@Nonnull List<StateMetaInfoSnapshot> stateMetaInfoSnapshots) {
 
 			this.checkpointStreamFactory = checkpointStreamFactory;
+			this.baseSstFiles = baseSstFiles;
 			this.checkpointId = checkpointId;
 			this.localBackupDirectory = localBackupDirectory;
-			this.stateMetaInfoSnapshots = new ArrayList<>();
-			this.closeableRegistry = new CloseableRegistry();
-			this.sstFiles = new HashMap<>();
-			this.miscFiles = new HashMap<>();
-			this.metaStateHandle = null;
-			this.dbLease = rocksDBResourceGuard.acquireResource();
+			this.stateMetaInfoSnapshots = stateMetaInfoSnapshots;
 		}
 
-		private StreamStateHandle materializeStateData(Path filePath) throws Exception {
-			FSDataInputStream inputStream = null;
-			CheckpointStreamFactory.CheckpointStateOutputStream outputStream = null;
+		@Override
+		protected SnapshotResult<KeyedStateHandle> callInternal() throws Exception {
 
-			try {
-				final byte[] buffer = new byte[8 * 1024];
+			boolean completed = false;
 
-				FileSystem backupFileSystem = localBackupDirectory.getFileSystem();
-				inputStream = backupFileSystem.open(filePath);
-				closeableRegistry.registerCloseable(inputStream);
+			// Handle to the meta data file
+			SnapshotResult<StreamStateHandle> metaStateHandle = null;
+			// Handles to new sst files since the last completed checkpoint will go here
+			final Map<StateHandleID, StreamStateHandle> sstFiles = new HashMap<>();
+			// Handles to the misc files in the current snapshot will go here
+			final Map<StateHandleID, StreamStateHandle> miscFiles = new HashMap<>();
 
-				outputStream = checkpointStreamFactory
-					.createCheckpointStateOutputStream(CheckpointedStateScope.SHARED);
-				closeableRegistry.registerCloseable(outputStream);
+			try {
 
-				while (true) {
-					int numBytes = inputStream.read(buffer);
+				metaStateHandle = materializeMetaData();
 
-					if (numBytes == -1) {
-						break;
-					}
+				// Sanity checks - they should never fail
+				Preconditions.checkNotNull(metaStateHandle, "Metadata was not properly created.");
+				Preconditions.checkNotNull(metaStateHandle.getJobManagerOwnedSnapshot(),
+					"Metadata for job manager was not properly created.");
 
-					outputStream.write(buffer, 0, numBytes);
-				}
+				uploadSstFiles(sstFiles, miscFiles);
 
-				StreamStateHandle result = null;
-				if (closeableRegistry.unregisterCloseable(outputStream)) {
-					result = outputStream.closeAndGetHandle();
-					outputStream = null;
+				synchronized (materializedSstFiles) {
+					materializedSstFiles.put(checkpointId, sstFiles.keySet());
 				}
-				return result;
 
-			} finally {
-
-				if (closeableRegistry.unregisterCloseable(inputStream)) {
-					inputStream.close();
+				final IncrementalKeyedStateHandle jmIncrementalKeyedStateHandle =
+					new IncrementalKeyedStateHandle(
+						backendUID,
+						keyGroupRange,
+						checkpointId,
+						sstFiles,
+						miscFiles,
+						metaStateHandle.getJobManagerOwnedSnapshot());
+
+				final DirectoryStateHandle directoryStateHandle = localBackupDirectory.completeSnapshotAndGetHandle();
+				final SnapshotResult<KeyedStateHandle> snapshotResult;
+				if (directoryStateHandle != null && metaStateHandle.getTaskLocalSnapshot() != null) {
+
+					IncrementalLocalKeyedStateHandle localDirKeyedStateHandle =
+						new IncrementalLocalKeyedStateHandle(
+							backendUID,
+							checkpointId,
+							directoryStateHandle,
+							keyGroupRange,
+							metaStateHandle.getTaskLocalSnapshot(),
+							sstFiles.keySet());
+
+					snapshotResult = SnapshotResult.withLocalState(jmIncrementalKeyedStateHandle, localDirKeyedStateHandle);
+				} else {
+					snapshotResult = SnapshotResult.of(jmIncrementalKeyedStateHandle);
 				}
 
-				if (closeableRegistry.unregisterCloseable(outputStream)) {
-					outputStream.close();
+				completed = true;
+
+				return snapshotResult;
+			} finally {
+				if (!completed) {
+					final List<StateObject> statesToDiscard =
+						new ArrayList<>(1 + miscFiles.size() + sstFiles.size());
+					statesToDiscard.add(metaStateHandle);
+					statesToDiscard.addAll(miscFiles.values());
+					statesToDiscard.addAll(sstFiles.values());
+					cleanupIncompleteSnapshot(statesToDiscard);
 				}
 			}
 		}
 
-		@Nonnull
-		private SnapshotResult<StreamStateHandle> materializeMetaData() throws Exception {
-
-			CheckpointStreamWithResultProvider streamWithResultProvider =
-
-				localRecoveryConfig.isLocalRecoveryEnabled() ?
-
-					CheckpointStreamWithResultProvider.createDuplicatingStream(
-						checkpointId,
-						CheckpointedStateScope.EXCLUSIVE,
-						checkpointStreamFactory,
-						localRecoveryConfig.getLocalStateDirectoryProvider()) :
-
-					CheckpointStreamWithResultProvider.createSimpleStream(
-						CheckpointedStateScope.EXCLUSIVE,
-						checkpointStreamFactory);
-
+		@Override
+		protected void cleanupProvidedResources() {
 			try {
-				closeableRegistry.registerCloseable(streamWithResultProvider);
-
-				//no need for compression scheme support because sst-files are already compressed
-				KeyedBackendSerializationProxy<K> serializationProxy =
-					new KeyedBackendSerializationProxy<>(
-						keySerializer,
-						stateMetaInfoSnapshots,
-						false);
-
-				DataOutputView out =
-					new DataOutputViewStreamWrapper(streamWithResultProvider.getCheckpointOutputStream());
-
-				serializationProxy.write(out);
+				if (localBackupDirectory.exists()) {
+					LOG.trace("Running cleanup for local RocksDB backup directory {}.", localBackupDirectory);
+					boolean cleanupOk = localBackupDirectory.cleanup();
 
-				if (closeableRegistry.unregisterCloseable(streamWithResultProvider)) {
-					SnapshotResult<StreamStateHandle> result =
-						streamWithResultProvider.closeAndFinalizeCheckpointStreamResult();
-					streamWithResultProvider = null;
-					return result;
-				} else {
-					throw new IOException("Stream already closed and cannot return a handle.");
-				}
-			} finally {
-				if (streamWithResultProvider != null) {
-					if (closeableRegistry.unregisterCloseable(streamWithResultProvider)) {
-						IOUtils.closeQuietly(streamWithResultProvider);
+					if (!cleanupOk) {
+						LOG.debug("Could not properly cleanup local RocksDB backup directory.");
 					}
 				}
+			} catch (IOException e) {
+				LOG.warn("Could not properly cleanup local RocksDB backup directory.", e);
 			}
 		}
 
-		void takeSnapshot() throws Exception {
-
-			final long lastCompletedCheckpoint;
-
-			// use the last completed checkpoint as the comparison base.
-			synchronized (materializedSstFiles) {
-				lastCompletedCheckpoint = lastCompletedCheckpointId;
-				baseSstFiles = materializedSstFiles.get(lastCompletedCheckpoint);
-			}
-
-			LOG.trace("Taking incremental snapshot for checkpoint {}. Snapshot is based on last completed checkpoint {} " +
-				"assuming the following (shared) files as base: {}.", checkpointId, lastCompletedCheckpoint, baseSstFiles);
-
-			// save meta data
-			for (Map.Entry<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> stateMetaInfoEntry
-				: kvStateInformation.entrySet()) {
-				stateMetaInfoSnapshots.add(stateMetaInfoEntry.getValue().f1.snapshot());
-			}
+		@Override
+		protected void logAsyncSnapshotComplete(long startTime) {
+			logAsyncCompleted(checkpointStreamFactory, startTime);
+		}
 
-			LOG.trace("Local RocksDB checkpoint goes to backup path {}.", localBackupDirectory);
+		private void cleanupIncompleteSnapshot(@Nonnull List<StateObject> statesToDiscard) {
 
-			if (localBackupDirectory.exists()) {
-				throw new IllegalStateException("Unexpected existence of the backup directory.");
+			try {
+				StateUtil.bestEffortDiscardAllStateObjects(statesToDiscard);
+			} catch (Exception e) {
+				LOG.warn("Could not properly discard states.", e);
 			}
 
-			// create hard links of living files in the snapshot path
-			try (Checkpoint checkpoint = Checkpoint.create(db)) {
-				checkpoint.createCheckpoint(localBackupDirectory.getDirectory().getPath());
+			if (localBackupDirectory.isSnapshotCompleted()) {
+				try {
+					DirectoryStateHandle directoryStateHandle =
+						localBackupDirectory.completeSnapshotAndGetHandle();
+					if (directoryStateHandle != null) {
+						directoryStateHandle.discardState();
+					}
+				} catch (Exception e) {
+					LOG.warn("Could not properly discard local state.", e);
+				}
 			}
 		}
 
-		@Nonnull
-		SnapshotResult<KeyedStateHandle> runSnapshot() throws Exception {
-
-			cancelStreamRegistry.registerCloseable(closeableRegistry);
-
-			// write meta data
-			metaStateHandle = materializeMetaData();
-
-			// sanity checks - they should never fail
-			Preconditions.checkNotNull(metaStateHandle,
-				"Metadata was not properly created.");
-			Preconditions.checkNotNull(metaStateHandle.getJobManagerOwnedSnapshot(),
-				"Metadata for job manager was not properly created.");
+		private void uploadSstFiles(
+			@Nonnull Map<StateHandleID, StreamStateHandle> sstFiles,
+			@Nonnull Map<StateHandleID, StreamStateHandle> miscFiles) throws Exception {
 
 			// write state data
 			Preconditions.checkState(localBackupDirectory.exists());
@@ -456,120 +428,104 @@ public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K>
 								stateHandleID,
 								new PlaceholderStreamStateHandle());
 						} else {
-							sstFiles.put(stateHandleID, materializeStateData(filePath));
+							sstFiles.put(stateHandleID, uploadLocalFileToCheckpointFs(filePath));
 						}
 					} else {
-						StreamStateHandle fileHandle = materializeStateData(filePath);
+						StreamStateHandle fileHandle = uploadLocalFileToCheckpointFs(filePath);
 						miscFiles.put(stateHandleID, fileHandle);
 					}
 				}
 			}
+		}
 
-			synchronized (materializedSstFiles) {
-				materializedSstFiles.put(checkpointId, sstFiles.keySet());
-			}
+		private StreamStateHandle uploadLocalFileToCheckpointFs(Path filePath) throws Exception {
+			FSDataInputStream inputStream = null;
+			CheckpointStreamFactory.CheckpointStateOutputStream outputStream = null;
 
-			IncrementalKeyedStateHandle jmIncrementalKeyedStateHandle = new IncrementalKeyedStateHandle(
-				backendUID,
-				keyGroupRange,
-				checkpointId,
-				sstFiles,
-				miscFiles,
-				metaStateHandle.getJobManagerOwnedSnapshot());
+			try {
+				final byte[] buffer = new byte[READ_BUFFER_SIZE];
 
-			StreamStateHandle taskLocalSnapshotMetaDataStateHandle = metaStateHandle.getTaskLocalSnapshot();
-			DirectoryStateHandle directoryStateHandle = null;
+				FileSystem backupFileSystem = localBackupDirectory.getFileSystem();
+				inputStream = backupFileSystem.open(filePath);
+				registerCloseableForCancellation(inputStream);
 
-			try {
+				outputStream = checkpointStreamFactory
+					.createCheckpointStateOutputStream(CheckpointedStateScope.SHARED);
+				registerCloseableForCancellation(outputStream);
 
-				directoryStateHandle = localBackupDirectory.completeSnapshotAndGetHandle();
-			} catch (IOException ex) {
+				while (true) {
+					int numBytes = inputStream.read(buffer);
 
-				Exception collector = ex;
+					if (numBytes == -1) {
+						break;
+					}
 
-				try {
-					taskLocalSnapshotMetaDataStateHandle.discardState();
-				} catch (Exception discardEx) {
-					collector = ExceptionUtils.firstOrSuppressed(discardEx, collector);
+					outputStream.write(buffer, 0, numBytes);
 				}
 
-				LOG.warn("Problem with local state snapshot.", collector);
-			}
-
-			if (directoryStateHandle != null && taskLocalSnapshotMetaDataStateHandle != null) {
+				StreamStateHandle result = null;
+				if (unregisterCloseableFromCancellation(outputStream)) {
+					result = outputStream.closeAndGetHandle();
+					outputStream = null;
+				}
+				return result;
 
-				IncrementalLocalKeyedStateHandle localDirKeyedStateHandle =
-					new IncrementalLocalKeyedStateHandle(
-						backendUID,
-						checkpointId,
-						directoryStateHandle,
-						keyGroupRange,
-						taskLocalSnapshotMetaDataStateHandle,
-						sstFiles.keySet());
-				return SnapshotResult.withLocalState(jmIncrementalKeyedStateHandle, localDirKeyedStateHandle);
-			} else {
-				return SnapshotResult.of(jmIncrementalKeyedStateHandle);
-			}
-		}
+			} finally {
 
-		void stop() {
+				if (unregisterCloseableFromCancellation(inputStream)) {
+					IOUtils.closeQuietly(inputStream);
+				}
 
-			if (cancelStreamRegistry.unregisterCloseable(closeableRegistry)) {
-				try {
-					closeableRegistry.close();
-				} catch (IOException e) {
-					LOG.warn("Could not properly close io streams.", e);
+				if (unregisterCloseableFromCancellation(outputStream)) {
+					IOUtils.closeQuietly(outputStream);
 				}
 			}
 		}
 
-		void releaseResources(boolean canceled) {
+		@Nonnull
+		private SnapshotResult<StreamStateHandle> materializeMetaData() throws Exception {
 
-			dbLease.close();
+			CheckpointStreamWithResultProvider streamWithResultProvider =
 
-			if (cancelStreamRegistry.unregisterCloseable(closeableRegistry)) {
-				try {
-					closeableRegistry.close();
-				} catch (IOException e) {
-					LOG.warn("Exception on closing registry.", e);
-				}
-			}
+				localRecoveryConfig.isLocalRecoveryEnabled() ?
 
-			try {
-				if (localBackupDirectory.exists()) {
-					LOG.trace("Running cleanup for local RocksDB backup directory {}.", localBackupDirectory);
-					boolean cleanupOk = localBackupDirectory.cleanup();
+					CheckpointStreamWithResultProvider.createDuplicatingStream(
+						checkpointId,
+						CheckpointedStateScope.EXCLUSIVE,
+						checkpointStreamFactory,
+						localRecoveryConfig.getLocalStateDirectoryProvider()) :
 
-					if (!cleanupOk) {
-						LOG.debug("Could not properly cleanup local RocksDB backup directory.");
-					}
-				}
-			} catch (IOException e) {
-				LOG.warn("Could not properly cleanup local RocksDB backup directory.", e);
-			}
+					CheckpointStreamWithResultProvider.createSimpleStream(
+						CheckpointedStateScope.EXCLUSIVE,
+						checkpointStreamFactory);
 
-			if (canceled) {
-				Collection<StateObject> statesToDiscard =
-					new ArrayList<>(1 + miscFiles.size() + sstFiles.size());
+			registerCloseableForCancellation(streamWithResultProvider);
 
-				statesToDiscard.add(metaStateHandle);
-				statesToDiscard.addAll(miscFiles.values());
-				statesToDiscard.addAll(sstFiles.values());
+			try {
+				//no need for compression scheme support because sst-files are already compressed
+				KeyedBackendSerializationProxy<K> serializationProxy =
+					new KeyedBackendSerializationProxy<>(
+						keySerializer,
+						stateMetaInfoSnapshots,
+						false);
 
-				try {
-					StateUtil.bestEffortDiscardAllStateObjects(statesToDiscard);
-				} catch (Exception e) {
-					LOG.warn("Could not properly discard states.", e);
-				}
+				DataOutputView out =
+					new DataOutputViewStreamWrapper(streamWithResultProvider.getCheckpointOutputStream());
 
-				if (localBackupDirectory.isSnapshotCompleted()) {
-					try {
-						DirectoryStateHandle directoryStateHandle = localBackupDirectory.completeSnapshotAndGetHandle();
-						if (directoryStateHandle != null) {
-							directoryStateHandle.discardState();
-						}
-					} catch (Exception e) {
-						LOG.warn("Could not properly discard local state.", e);
+				serializationProxy.write(out);
+
+				if (unregisterCloseableFromCancellation(streamWithResultProvider)) {
+					SnapshotResult<StreamStateHandle> result =
+						streamWithResultProvider.closeAndFinalizeCheckpointStreamResult();
+					streamWithResultProvider = null;
+					return result;
+				} else {
+					throw new IOException("Stream already closed and cannot return a handle.");
+				}
+			} finally {
+				if (streamWithResultProvider != null) {
+					if (unregisterCloseableFromCancellation(streamWithResultProvider)) {
+						IOUtils.closeQuietly(streamWithResultProvider);
 					}
 				}
 			}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index db504d5..9ee8892 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -1076,7 +1076,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 				owner.asyncOperationsThreadPool.submit(asyncCheckpointRunnable);
 
 				if (LOG.isDebugEnabled()) {
-					LOG.debug("{} - finished synchronous part of checkpoint {}." +
+					LOG.debug("{} - finished synchronous part of checkpoint {}. " +
 							"Alignment duration: {} ms, snapshot duration {} ms",
 						owner.getName(), checkpointMetaData.getCheckpointId(),
 						checkpointMetrics.getAlignmentDurationNanos() / 1_000_000,
@@ -1095,7 +1095,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 				}
 
 				if (LOG.isDebugEnabled()) {
-					LOG.debug("{} - did NOT finish synchronous part of checkpoint {}." +
+					LOG.debug("{} - did NOT finish synchronous part of checkpoint {}. " +
 							"Alignment duration: {} ms, snapshot duration {} ms",
 						owner.getName(), checkpointMetaData.getCheckpointId(),
 						checkpointMetrics.getAlignmentDurationNanos() / 1_000_000,
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
index d8f577d..cd8a4fa 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
@@ -82,6 +82,7 @@ import org.apache.flink.util.TestLogger;
 import org.junit.Assert;
 import org.junit.Test;
 
+import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
 import java.io.IOException;
@@ -305,12 +306,13 @@ public class TaskCheckpointingBehaviourTest extends TestLogger {
 				env.getUserClassLoader(),
 				env.getExecutionConfig(),
 				true) {
+				@Nonnull
 				@Override
 				public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
 					long checkpointId,
 					long timestamp,
-					CheckpointStreamFactory streamFactory,
-					CheckpointOptions checkpointOptions) throws Exception {
+					@Nonnull CheckpointStreamFactory streamFactory,
+					@Nonnull CheckpointOptions checkpointOptions) throws Exception {
 
 					throw new Exception("Sync part snapshot exception.");
 				}
@@ -334,12 +336,13 @@ public class TaskCheckpointingBehaviourTest extends TestLogger {
 				env.getUserClassLoader(),
 				env.getExecutionConfig(),
 				true) {
+				@Nonnull
 				@Override
 				public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
 					long checkpointId,
 					long timestamp,
-					CheckpointStreamFactory streamFactory,
-					CheckpointOptions checkpointOptions) throws Exception {
+					@Nonnull CheckpointStreamFactory streamFactory,
+					@Nonnull CheckpointOptions checkpointOptions) throws Exception {
 
 					return new FutureTask<>(() -> {
 						throw new Exception("Async part snapshot exception.");
diff --git a/flink-test-utils-parent/flink-test-utils-junit/src/main/java/org/apache/flink/core/testutils/OneShotLatch.java b/flink-test-utils-parent/flink-test-utils-junit/src/main/java/org/apache/flink/core/testutils/OneShotLatch.java
index 7fed5eb..bef23bb 100644
--- a/flink-test-utils-parent/flink-test-utils-junit/src/main/java/org/apache/flink/core/testutils/OneShotLatch.java
+++ b/flink-test-utils-parent/flink-test-utils-junit/src/main/java/org/apache/flink/core/testutils/OneShotLatch.java
@@ -18,6 +18,9 @@
 
 package org.apache.flink.core.testutils;
 
+import java.util.Collections;
+import java.util.IdentityHashMap;
+import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
@@ -31,6 +34,7 @@ import java.util.concurrent.TimeoutException;
 public final class OneShotLatch {
 
 	private final Object lock = new Object();
+	private final Set<Thread> waitersSet = Collections.newSetFromMap(new IdentityHashMap<>());
 
 	private volatile boolean triggered;
 
@@ -53,7 +57,13 @@ public final class OneShotLatch {
 	public void await() throws InterruptedException {
 		synchronized (lock) {
 			while (!triggered) {
-				lock.wait();
+				Thread thread = Thread.currentThread();
+				try {
+					waitersSet.add(thread);
+					lock.wait();
+				} finally {
+					waitersSet.remove(thread);
+				}
 			}
 		}
 	}
@@ -108,6 +118,12 @@ public final class OneShotLatch {
 		return triggered;
 	}
 
+	public int getWaitersCount() {
+		synchronized (lock) {
+			return waitersSet.size();
+		}
+	}
+
 	/**
 	 * Resets the latch so that {@link #isTriggered()} returns false.
 	 */