You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sr...@apache.org on 2018/08/21 12:31:37 UTC

[flink] 02/02: [FLINK-10042][state] (part 2) Refactoring of snapshot algorithms for better abstraction and cleaner resource management

This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit f803280bb933d968976e79b9efb5953bed308d96
Author: Stefan Richter <s....@data-artisans.com>
AuthorDate: Thu Aug 9 22:23:42 2018 +0200

    [FLINK-10042][state] (part 2) Refactoring of snapshot algorithms for better abstraction and cleaner resource management
    
    This closes #6556.
---
 .../async/AbstractAsyncCallableWithResources.java  | 194 --------
 .../flink/runtime/io/async/AsyncDoneCallback.java  |  33 --
 .../flink/runtime/io/async/AsyncStoppable.java     |  45 --
 .../io/async/AsyncStoppableTaskWithCallback.java   |  59 ---
 .../io/async/StoppableCallbackCallable.java        |  30 --
 .../runtime/state/AbstractSnapshotStrategy.java    |  79 +++
 .../flink/runtime/state/AsyncSnapshotCallable.java | 190 +++++++
 .../runtime/state/DefaultOperatorStateBackend.java | 369 +++++++-------
 .../flink/runtime/state/SnapshotStrategy.java      |  13 +-
 .../apache/flink/runtime/state/Snapshotable.java   |  27 +-
 .../runtime/state/heap/HeapKeyedStateBackend.java  | 145 +++---
 .../runtime/state/AsyncSnapshotCallableTest.java   | 326 ++++++++++++
 .../runtime/state/OperatorStateBackendTest.java    |   4 +-
 .../flink/runtime/state/StateBackendTestBase.java  |   6 +-
 .../state/ttl/mock/MockKeyedStateBackend.java      |   5 +-
 .../streaming/state/RocksDBKeyedStateBackend.java  |  53 +-
 ...yBase.java => RocksDBSnapshotStrategyBase.java} |  57 ++-
 .../state/snapshot/RocksFullSnapshotStrategy.java  | 255 ++++------
 .../snapshot/RocksIncrementalSnapshotStrategy.java | 552 ++++++++++-----------
 .../flink/streaming/runtime/tasks/StreamTask.java  |   4 +-
 .../tasks/TaskCheckpointingBehaviourTest.java      |  11 +-
 .../apache/flink/core/testutils/OneShotLatch.java  |  18 +-
 22 files changed, 1329 insertions(+), 1146 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AbstractAsyncCallableWithResources.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AbstractAsyncCallableWithResources.java
deleted file mode 100644
index bc0116c..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AbstractAsyncCallableWithResources.java
+++ /dev/null
@@ -1,194 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.async;
-
-import org.apache.flink.util.ExceptionUtils;
-
-import java.io.IOException;
-
-/**
- * This abstract class encapsulates the lifecycle and execution strategy for asynchronous operations that use resources.
- *
- * @param <V> return type of the asynchronous call.
- */
-public abstract class AbstractAsyncCallableWithResources<V> implements StoppableCallbackCallable<V> {
-
-	/** Tracks if the stop method was called on this object. */
-	private volatile boolean stopped;
-
-	/** Tracks if call method was executed (only before stop calls). */
-	private volatile boolean called;
-
-	/** Stores a collected exception if there was one during stop. */
-	private volatile Exception stopException;
-
-	public AbstractAsyncCallableWithResources() {
-		this.stopped = false;
-		this.called = false;
-	}
-
-	/**
-	 * This method implements the strategy for the actual IO operation:
-	 * <p>
-	 * 1) Acquire resources asynchronously and atomically w.r.t stopping.
-	 * 2) Performs the operation
-	 * 3) Releases resources.
-	 *
-	 * @return Result of the IO operation, e.g. a deserialized object.
-	 * @throws Exception exception that happened during the call.
-	 */
-	@Override
-	public final V call() throws Exception {
-
-		V result = null;
-		Exception collectedException = null;
-
-		try {
-			synchronized (this) {
-
-				if (stopped) {
-					throw new IOException("Task was already stopped.");
-				}
-
-				called = true;
-				// Get resources in async part, atomically w.r.t. stopping.
-				acquireResources();
-			}
-
-			// The main work is performed here.
-			result = performOperation();
-
-		} catch (Exception ex) {
-			collectedException = ex;
-		} finally {
-
-			try {
-				// Cleanup
-				releaseResources();
-			} catch (Exception relEx) {
-				collectedException = ExceptionUtils.firstOrSuppressed(relEx, collectedException);
-			}
-
-			if (collectedException != null) {
-				throw collectedException;
-			}
-		}
-
-		return result;
-	}
-
-	/**
-	 * Open the IO Handle (e.g. a stream) on which the operation will be performed.
-	 *
-	 * @return the opened IO handle that implements #Closeable
-	 * @throws Exception if there was a problem in acquiring.
-	 */
-	protected abstract void acquireResources() throws Exception;
-
-	/**
-	 * Implements the actual operation.
-	 *
-	 * @return Result of the operation
-	 * @throws Exception if there was a problem in executing the operation.
-	 */
-	protected abstract V performOperation() throws Exception;
-
-	/**
-	 * Releases resources acquired by this object.
-	 *
-	 * @throws Exception if there was a problem in releasing resources.
-	 */
-	protected abstract void releaseResources() throws Exception;
-
-	/**
-	 * This method implements how the operation is stopped. Usually this involves interrupting or closing some
-	 * resources like streams to return from blocking calls.
-	 *
-	 * @throws Exception on problems during the stopping.
-	 */
-	protected abstract void stopOperation() throws Exception;
-
-	/**
-	 * Stops the I/O operation by closing the I/O handle. If an exception is thrown on close, it can be accessed via
-	 * #getStopException().
-	 */
-	@Override
-	public final void stop() {
-
-		synchronized (this) {
-
-			// Make sure that call can not enter execution from here.
-			if (stopped) {
-				return;
-			} else {
-				stopped = true;
-			}
-		}
-
-		if (called) {
-			// Async call is executing -> attempt to stop it and releaseResources() will happen inside the async method.
-			try {
-				stopOperation();
-			} catch (Exception stpEx) {
-				this.stopException = stpEx;
-			}
-		} else {
-			// Async call was not executed, so we also need to releaseResources() here.
-			try {
-				releaseResources();
-			} catch (Exception relEx) {
-				stopException = relEx;
-			}
-		}
-	}
-
-	/**
-	 * Optional callback that subclasses can implement. This is called when the callable method completed, e.g. because
-	 * it finished or was stopped.
-	 */
-	@Override
-	public void done(boolean canceled) {
-		//optional callback hook
-	}
-
-	/**
-	 * True once the async method was called.
-	 */
-	public boolean isCalled() {
-		return called;
-	}
-
-	/**
-	 * Check if the IO operation is stopped
-	 *
-	 * @return true if stop() was called
-	 */
-	@Override
-	public boolean isStopped() {
-		return stopped;
-	}
-
-	/**
-	 * Returns a potential exception that might have been observed while stopping the operation.
-	 */
-	@Override
-	public Exception getStopException() {
-		return stopException;
-	}
-}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncDoneCallback.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncDoneCallback.java
deleted file mode 100644
index dcc5525..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncDoneCallback.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.async;
-
-/**
- * Callback for an asynchronous operation that is called on termination
- */
-public interface AsyncDoneCallback {
-
-	/**
-	 * the callback
-	 *
-	 * @param canceled true if the callback is done, but was canceled
-	 */
-	void done(boolean canceled);
-
-}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppable.java
deleted file mode 100644
index 8698600..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppable.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.async;
-
-/**
- * An asynchronous operation that can be stopped.
- */
-public interface AsyncStoppable {
-
-	/**
-	 * Stop the operation
-	 */
-	void stop();
-
-	/**
-	 * Check whether the operation is stopped
-	 *
-	 * @return true iff operation is stopped
-	 */
-	boolean isStopped();
-
-	/**
-	 * Delivers Exception that might happen during {@link #stop()}
-	 *
-	 * @return Exception that can happen during stop
-	 */
-	Exception getStopException();
-
-}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppableTaskWithCallback.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppableTaskWithCallback.java
deleted file mode 100644
index a30c607..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/AsyncStoppableTaskWithCallback.java
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.async;
-
-import org.apache.flink.util.Preconditions;
-
-import java.util.concurrent.FutureTask;
-
-/**
- * @param <V> return type of the callable function
- */
-public class AsyncStoppableTaskWithCallback<V> extends FutureTask<V> {
-
-	protected final StoppableCallbackCallable<V> stoppableCallbackCallable;
-
-	public AsyncStoppableTaskWithCallback(StoppableCallbackCallable<V> callable) {
-		super(Preconditions.checkNotNull(callable));
-		this.stoppableCallbackCallable = callable;
-	}
-
-	@Override
-	public boolean cancel(boolean mayInterruptIfRunning) {
-		final boolean cancel = super.cancel(mayInterruptIfRunning);
-		if (cancel) {
-			stoppableCallbackCallable.stop();
-			// this is where we report done() for the cancel case, after calling stop().
-			stoppableCallbackCallable.done(true);
-		}
-		return cancel;
-	}
-
-	@Override
-	protected void done() {
-		// we suppress forwarding if we have not been canceled, because the cancel case will call to this method separately.
-		if (!isCancelled()) {
-			stoppableCallbackCallable.done(false);
-		}
-	}
-
-	public static <V> AsyncStoppableTaskWithCallback<V> from(StoppableCallbackCallable<V> callable) {
-		return new AsyncStoppableTaskWithCallback<>(callable);
-	}
-}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/StoppableCallbackCallable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/StoppableCallbackCallable.java
deleted file mode 100644
index d459316..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/async/StoppableCallbackCallable.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.io.async;
-
-import java.util.concurrent.Callable;
-
-/**
- * A {@link Callable} that can be stopped and offers a callback on termination.
- *
- * @param <V> return value of the call operation.
- */
-public interface StoppableCallbackCallable<V> extends Callable<V>, AsyncStoppable, AsyncDoneCallback {
-
-}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractSnapshotStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractSnapshotStrategy.java
new file mode 100644
index 0000000..e0debe5
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractSnapshotStrategy.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+
+/**
+ * Abstract base class for implementing {@link SnapshotStrategy} that provides consistent logging across state backends.
+ *
+ * @param <T> type of the snapshot result.
+ */
+public abstract class AbstractSnapshotStrategy<T extends StateObject> implements SnapshotStrategy<SnapshotResult<T>> {
+
+	private static final Logger LOG = LoggerFactory.getLogger(AbstractSnapshotStrategy.class);
+
+	private static final String LOG_SYNC_COMPLETED_TEMPLATE = "{} ({}, synchronous part) in thread {} took {} ms.";
+	private static final String LOG_ASYNC_COMPLETED_TEMPLATE = "{} ({}, asynchronous part) in thread {} took {} ms.";
+
+	/** Descriptive name of the snapshot strategy that will appear in the log outputs and {@link #toString()}. */
+	@Nonnull
+	protected final String description;
+
+	protected AbstractSnapshotStrategy(@Nonnull String description) {
+		this.description = description;
+	}
+
+	/**
+	 * Logs the duration of the synchronous snapshot part from the given start time.
+	 */
+	public void logSyncCompleted(@Nonnull Object checkpointOutDescription, long startTime) {
+		logCompletedInternal(LOG_SYNC_COMPLETED_TEMPLATE, checkpointOutDescription, startTime);
+	}
+
+	/**
+	 * Logs the duration of the asynchronous snapshot part from the given start time.
+	 */
+	public void logAsyncCompleted(@Nonnull Object checkpointOutDescription, long startTime) {
+		logCompletedInternal(LOG_ASYNC_COMPLETED_TEMPLATE, checkpointOutDescription, startTime);
+	}
+
+	private void logCompletedInternal(
+		@Nonnull String template,
+		@Nonnull Object checkpointOutDescription,
+		long startTime) {
+
+		long duration = (System.currentTimeMillis() - startTime);
+
+		LOG.debug(
+			template,
+			description,
+			checkpointOutDescription,
+			Thread.currentThread(),
+			duration);
+	}
+
+	@Override
+	public String toString() {
+		return "SnapshotStrategy {" + description + "}";
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncSnapshotCallable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncSnapshotCallable.java
new file mode 100644
index 0000000..2c1a0be
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncSnapshotCallable.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.fs.CloseableRegistry;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.FutureTask;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Base class that outlines the strategy for asynchronous snapshots. Implementations of this class are typically
+ * instantiated with resources that have been created in the synchronous part of a snapshot. Then, the implementation
+ * of {@link #callInternal()} is invoked in the asynchronous part. All resources created by this method should
+ * be released by the end of the method. If the created resources are {@link Closeable} objects and can block in calls
+ * (e.g. in/output streams), they should be registered with the snapshot's {@link CloseableRegistry} so that they can
+ * be closed and unblocked on cancellation. After {@link #callInternal()} ended, {@link #logAsyncSnapshotComplete(long)}
+ * is called. In that method, implementations can emit log statements about the duration. At the very end, this class
+ * calls {@link #cleanupProvidedResources()}. The implementation of this method should release all provided resources
+ * that have been passed into the snapshot from the synchronous part of the snapshot.
+ *
+ * @param <T> type of the result.
+ */
+public abstract class AsyncSnapshotCallable<T> implements Callable<T> {
+
+	/** Message for the {@link CancellationException}. */
+	private static final String CANCELLATION_EXCEPTION_MSG = "Async snapshot was cancelled.";
+
+	private static final Logger LOG = LoggerFactory.getLogger(AsyncSnapshotCallable.class);
+
+	/** This is used to atomically claim ownership for the resource cleanup. */
+	@Nonnull
+	private final AtomicBoolean resourceCleanupOwnershipTaken;
+
+	/** Registers streams that can block in I/O during snapshot. Forwards close from taskCancelCloseableRegistry. */
+	@Nonnull
+	private final CloseableRegistry snapshotCloseableRegistry;
+
+	protected AsyncSnapshotCallable() {
+		this.snapshotCloseableRegistry = new CloseableRegistry();
+		this.resourceCleanupOwnershipTaken = new AtomicBoolean(false);
+	}
+
+	@Override
+	public T call() throws Exception {
+		final long startTime = System.currentTimeMillis();
+
+		if (resourceCleanupOwnershipTaken.compareAndSet(false, true)) {
+			try {
+				T result = callInternal();
+				logAsyncSnapshotComplete(startTime);
+				return result;
+			} catch (Exception ex) {
+				if (!snapshotCloseableRegistry.isClosed()) {
+					throw ex;
+				}
+			} finally {
+				closeSnapshotIO();
+				cleanup();
+			}
+		}
+
+		throw new CancellationException(CANCELLATION_EXCEPTION_MSG);
+	}
+
+	@VisibleForTesting
+	protected void cancel() {
+		closeSnapshotIO();
+		if (resourceCleanupOwnershipTaken.compareAndSet(false, true)) {
+			cleanup();
+		}
+	}
+
+	/**
+	 * Creates a future task from this and registers it with the given {@link CloseableRegistry}. The task is
+	 * unregistered again in {@link FutureTask#done()}.
+	 */
+	public AsyncSnapshotTask toAsyncSnapshotFutureTask(@Nonnull CloseableRegistry taskRegistry) throws IOException {
+		return new AsyncSnapshotTask(taskRegistry);
+	}
+
+	/**
+	 * {@link FutureTask} that wraps a {@link AsyncSnapshotCallable} and connects it with cancellation and closing.
+	 */
+	public class AsyncSnapshotTask extends FutureTask<T> {
+
+		@Nonnull
+		private final CloseableRegistry taskRegistry;
+
+		@Nonnull
+		private final Closeable cancelOnClose;
+
+		private AsyncSnapshotTask(@Nonnull CloseableRegistry taskRegistry) throws IOException {
+			super(AsyncSnapshotCallable.this);
+			this.cancelOnClose = () -> cancel(true);
+			this.taskRegistry = taskRegistry;
+			taskRegistry.registerCloseable(cancelOnClose);
+		}
+
+		@Override
+		public boolean cancel(boolean mayInterruptIfRunning) {
+			boolean result = super.cancel(mayInterruptIfRunning);
+			if (mayInterruptIfRunning) {
+				AsyncSnapshotCallable.this.cancel();
+			}
+			return result;
+		}
+
+		@Override
+		protected void done() {
+			super.done();
+			taskRegistry.unregisterCloseable(cancelOnClose);
+		}
+	}
+
+	/**
+	 * This method implements the (async) snapshot logic. Resources acquired within this method should be released at
+	 * the end of the method.
+	 */
+	protected abstract T callInternal() throws Exception;
+
+	/**
+	 * This method implements the cleanup of resources that have been passed in (from the sync part). Called after the
+	 * end of {@link #callInternal()}.
+	 */
+	protected abstract void cleanupProvidedResources();
+
+	/**
+	 * This method is invoked after completion of the snapshot and can be overridden to emit a log statement about the
+	 * duration of the async part.
+	 */
+	protected void logAsyncSnapshotComplete(long startTime) {
+
+	}
+
+	/**
+	 * Registers the {@link Closeable} with the snapshot's {@link CloseableRegistry}, so that it will be closed on
+	 * {@link #cancel()} and becomes unblocked. If the registry is already closed, the argument is closed and an
+	 * {@link IOException} is emitted.
+	 */
+	protected void registerCloseableForCancellation(@Nullable Closeable toRegister) throws IOException {
+		snapshotCloseableRegistry.registerCloseable(toRegister);
+	}
+
+	/**
+	 * Unregisters the given argument from the snapshot's {@link CloseableRegistry} and returns <code>true</code> iff
+	 * the argument was registered before the call.
+	 */
+	protected boolean unregisterCloseableFromCancellation(@Nullable Closeable toUnregister) {
+		return snapshotCloseableRegistry.unregisterCloseable(toUnregister);
+	}
+
+	private void cleanup() {
+		cleanupProvidedResources();
+	}
+
+	private void closeSnapshotIO() {
+		try {
+			snapshotCloseableRegistry.close();
+		} catch (IOException e) {
+			LOG.warn("Could not properly close incremental snapshot streams.", e);
+		}
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
index d9fc41e..eae5a3b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -36,8 +36,6 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.io.async.AbstractAsyncCallableWithResources;
-import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.StateMigrationException;
@@ -56,6 +54,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 
 /**
@@ -133,6 +132,8 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 
 	private final Map<String, BackendWritableBroadcastState<?, ?>> accessedBroadcastStatesByName;
 
+	private final AbstractSnapshotStrategy<OperatorStateHandle> snapshotStrategy;
+
 	public DefaultOperatorStateBackend(
 		ClassLoader userClassLoader,
 		ExecutionConfig executionConfig,
@@ -149,6 +150,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 		this.accessedBroadcastStatesByName = new HashMap<>();
 		this.restoredOperatorStateMetaInfos = new HashMap<>();
 		this.restoredBroadcastStateMetaInfos = new HashMap<>();
+		this.snapshotStrategy = new DefaultOperatorStateBackendSnapshotStrategy();
 	}
 
 	public ExecutionConfig getExecutionConfig() {
@@ -307,179 +309,6 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 	//  Snapshot and restore
 	// -------------------------------------------------------------------------------------------
 
-	@Override
-	public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
-			final long checkpointId,
-			final long timestamp,
-			final CheckpointStreamFactory streamFactory,
-			final CheckpointOptions checkpointOptions) throws Exception {
-
-		final long syncStartTime = System.currentTimeMillis();
-
-		if (registeredOperatorStates.isEmpty() && registeredBroadcastStates.isEmpty()) {
-			return DoneFuture.of(SnapshotResult.empty());
-		}
-
-		final Map<String, PartitionableListState<?>> registeredOperatorStatesDeepCopies =
-				new HashMap<>(registeredOperatorStates.size());
-		final Map<String, BackendWritableBroadcastState<?, ?>> registeredBroadcastStatesDeepCopies =
-				new HashMap<>(registeredBroadcastStates.size());
-
-		ClassLoader snapshotClassLoader = Thread.currentThread().getContextClassLoader();
-		Thread.currentThread().setContextClassLoader(userClassloader);
-		try {
-			// eagerly create deep copies of the list and the broadcast states (if any)
-			// in the synchronous phase, so that we can use them in the async writing.
-
-			if (!registeredOperatorStates.isEmpty()) {
-				for (Map.Entry<String, PartitionableListState<?>> entry : registeredOperatorStates.entrySet()) {
-					PartitionableListState<?> listState = entry.getValue();
-					if (null != listState) {
-						listState = listState.deepCopy();
-					}
-					registeredOperatorStatesDeepCopies.put(entry.getKey(), listState);
-				}
-			}
-
-			if (!registeredBroadcastStates.isEmpty()) {
-				for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry : registeredBroadcastStates.entrySet()) {
-					BackendWritableBroadcastState<?, ?> broadcastState = entry.getValue();
-					if (null != broadcastState) {
-						broadcastState = broadcastState.deepCopy();
-					}
-					registeredBroadcastStatesDeepCopies.put(entry.getKey(), broadcastState);
-				}
-			}
-		} finally {
-			Thread.currentThread().setContextClassLoader(snapshotClassLoader);
-		}
-
-		// implementation of the async IO operation, based on FutureTask
-		final AbstractAsyncCallableWithResources<SnapshotResult<OperatorStateHandle>> ioCallable =
-			new AbstractAsyncCallableWithResources<SnapshotResult<OperatorStateHandle>>() {
-
-				CheckpointStreamFactory.CheckpointStateOutputStream out = null;
-
-				@Override
-				protected void acquireResources() throws Exception {
-					openOutStream();
-				}
-
-				@Override
-				protected void releaseResources() {
-					closeOutStream();
-				}
-
-				@Override
-				protected void stopOperation() {
-					closeOutStream();
-				}
-
-				private void openOutStream() throws Exception {
-					out = streamFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE);
-					closeStreamOnCancelRegistry.registerCloseable(out);
-				}
-
-				private void closeOutStream() {
-					if (closeStreamOnCancelRegistry.unregisterCloseable(out)) {
-						IOUtils.closeQuietly(out);
-					}
-				}
-
-				@Nonnull
-				@Override
-				public SnapshotResult<OperatorStateHandle> performOperation() throws Exception {
-					long asyncStartTime = System.currentTimeMillis();
-
-					CheckpointStreamFactory.CheckpointStateOutputStream localOut = this.out;
-
-					// get the registered operator state infos ...
-					List<StateMetaInfoSnapshot> operatorMetaInfoSnapshots =
-						new ArrayList<>(registeredOperatorStatesDeepCopies.size());
-
-					for (Map.Entry<String, PartitionableListState<?>> entry : registeredOperatorStatesDeepCopies.entrySet()) {
-						operatorMetaInfoSnapshots.add(entry.getValue().getStateMetaInfo().snapshot());
-					}
-
-					// ... get the registered broadcast operator state infos ...
-					List<StateMetaInfoSnapshot> broadcastMetaInfoSnapshots =
-							new ArrayList<>(registeredBroadcastStatesDeepCopies.size());
-
-					for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry : registeredBroadcastStatesDeepCopies.entrySet()) {
-						broadcastMetaInfoSnapshots.add(entry.getValue().getStateMetaInfo().snapshot());
-					}
-
-					// ... write them all in the checkpoint stream ...
-					DataOutputView dov = new DataOutputViewStreamWrapper(localOut);
-
-					OperatorBackendSerializationProxy backendSerializationProxy =
-						new OperatorBackendSerializationProxy(operatorMetaInfoSnapshots, broadcastMetaInfoSnapshots);
-
-					backendSerializationProxy.write(dov);
-
-					// ... and then go for the states ...
-
-					// we put BOTH normal and broadcast state metadata here
-					final Map<String, OperatorStateHandle.StateMetaInfo> writtenStatesMetaData =
-							new HashMap<>(registeredOperatorStatesDeepCopies.size() + registeredBroadcastStatesDeepCopies.size());
-
-					for (Map.Entry<String, PartitionableListState<?>> entry :
-							registeredOperatorStatesDeepCopies.entrySet()) {
-
-						PartitionableListState<?> value = entry.getValue();
-						long[] partitionOffsets = value.write(localOut);
-						OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode();
-						writtenStatesMetaData.put(
-							entry.getKey(),
-							new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
-					}
-
-					// ... and the broadcast states themselves ...
-					for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry :
-							registeredBroadcastStatesDeepCopies.entrySet()) {
-
-						BackendWritableBroadcastState<?, ?> value = entry.getValue();
-						long[] partitionOffsets = {value.write(localOut)};
-						OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode();
-						writtenStatesMetaData.put(
-								entry.getKey(),
-								new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
-					}
-
-					// ... and, finally, create the state handle.
-					OperatorStateHandle retValue = null;
-
-					if (closeStreamOnCancelRegistry.unregisterCloseable(out)) {
-
-						StreamStateHandle stateHandle = out.closeAndGetHandle();
-
-						if (stateHandle != null) {
-							retValue = new OperatorStreamStateHandle(writtenStatesMetaData, stateHandle);
-						}
-					}
-
-					if (asynchronousSnapshots) {
-						LOG.debug("DefaultOperatorStateBackend snapshot ({}, asynchronous part) in thread {} took {} ms.",
-							streamFactory, Thread.currentThread(), (System.currentTimeMillis() - asyncStartTime));
-					}
-
-					return SnapshotResult.of(retValue);
-				}
-			};
-
-		AsyncStoppableTaskWithCallback<SnapshotResult<OperatorStateHandle>> task =
-			AsyncStoppableTaskWithCallback.from(ioCallable);
-
-		if (!asynchronousSnapshots) {
-			task.run();
-		}
-
-		LOG.debug("DefaultOperatorStateBackend snapshot ({}, synchronous part) in thread {} took {} ms.",
-				streamFactory, Thread.currentThread(), (System.currentTimeMillis() - syncStartTime));
-
-		return task;
-	}
-
 	public void restore(Collection<OperatorStateHandle> restoreSnapshots) throws Exception {
 
 		if (null == restoreSnapshots || restoreSnapshots.isEmpty()) {
@@ -513,8 +342,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 					final RegisteredOperatorStateBackendMetaInfo<?> restoredMetaInfo =
 						new RegisteredOperatorStateBackendMetaInfo<>(restoredSnapshot);
 
-					if (restoredMetaInfo.getPartitionStateSerializer() == null ||
-						restoredMetaInfo.getPartitionStateSerializer() instanceof UnloadableDummyTypeSerializer) {
+					if (restoredMetaInfo.getPartitionStateSerializer() instanceof UnloadableDummyTypeSerializer) {
 
 						// must fail now if the previous serializer cannot be restored because there is no serializer
 						// capable of reading previous state
@@ -549,8 +377,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 					final RegisteredBroadcastStateBackendMetaInfo<?, ?> restoredMetaInfo =
 						new RegisteredBroadcastStateBackendMetaInfo<>(restoredSnapshot);
 
-					if (restoredMetaInfo.getKeySerializer() == null || restoredMetaInfo.getValueSerializer() == null ||
-						restoredMetaInfo.getKeySerializer() instanceof UnloadableDummyTypeSerializer ||
+					if (restoredMetaInfo.getKeySerializer() instanceof UnloadableDummyTypeSerializer ||
 						restoredMetaInfo.getValueSerializer() instanceof UnloadableDummyTypeSerializer) {
 
 						// must fail now if the previous serializer cannot be restored because there is no serializer
@@ -603,6 +430,23 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 		}
 	}
 
+	@Nonnull
+	@Override
+	public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
+		long checkpointId,
+		long timestamp,
+		@Nonnull CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception {
+
+		long syncStartTime = System.currentTimeMillis();
+
+		RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshotRunner =
+			snapshotStrategy.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+
+		snapshotStrategy.logSyncCompleted(streamFactory, syncStartTime);
+		return snapshotRunner;
+	}
+
 	/**
 	 * Implementation of operator list state.
 	 *
@@ -695,14 +539,14 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 		}
 
 		@Override
-		public void update(List<S> values) throws Exception {
+		public void update(List<S> values) {
 			internalList.clear();
 
 			addAll(values);
 		}
 
 		@Override
-		public void addAll(List<S> values) throws Exception {
+		public void addAll(List<S> values) {
 			if (values != null && !values.isEmpty()) {
 				internalList.addAll(values);
 			}
@@ -848,4 +692,167 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 				"Was [" + actualMode + "], " +
 				"registered with [" + expectedMode + "].");
 	}
+
+	/**
+	 * Snapshot strategy for this backend.
+	 */
+	private class DefaultOperatorStateBackendSnapshotStrategy extends AbstractSnapshotStrategy<OperatorStateHandle> {
+
+		protected DefaultOperatorStateBackendSnapshotStrategy() {
+			super("DefaultOperatorStateBackend snapshot");
+		}
+
+		@Nonnull
+		@Override
+		public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
+			final long checkpointId,
+			final long timestamp,
+			@Nonnull final CheckpointStreamFactory streamFactory,
+			@Nonnull final CheckpointOptions checkpointOptions) throws IOException {
+
+			if (registeredOperatorStates.isEmpty() && registeredBroadcastStates.isEmpty()) {
+				return DoneFuture.of(SnapshotResult.empty());
+			}
+
+			final Map<String, PartitionableListState<?>> registeredOperatorStatesDeepCopies =
+				new HashMap<>(registeredOperatorStates.size());
+			final Map<String, BackendWritableBroadcastState<?, ?>> registeredBroadcastStatesDeepCopies =
+				new HashMap<>(registeredBroadcastStates.size());
+
+			ClassLoader snapshotClassLoader = Thread.currentThread().getContextClassLoader();
+			Thread.currentThread().setContextClassLoader(userClassloader);
+			try {
+				// eagerly create deep copies of the list and the broadcast states (if any)
+				// in the synchronous phase, so that we can use them in the async writing.
+
+				if (!registeredOperatorStates.isEmpty()) {
+					for (Map.Entry<String, PartitionableListState<?>> entry : registeredOperatorStates.entrySet()) {
+						PartitionableListState<?> listState = entry.getValue();
+						if (null != listState) {
+							listState = listState.deepCopy();
+						}
+						registeredOperatorStatesDeepCopies.put(entry.getKey(), listState);
+					}
+				}
+
+				if (!registeredBroadcastStates.isEmpty()) {
+					for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry : registeredBroadcastStates.entrySet()) {
+						BackendWritableBroadcastState<?, ?> broadcastState = entry.getValue();
+						if (null != broadcastState) {
+							broadcastState = broadcastState.deepCopy();
+						}
+						registeredBroadcastStatesDeepCopies.put(entry.getKey(), broadcastState);
+					}
+				}
+			} finally {
+				Thread.currentThread().setContextClassLoader(snapshotClassLoader);
+			}
+
+			AsyncSnapshotCallable<SnapshotResult<OperatorStateHandle>> snapshotCallable =
+				new AsyncSnapshotCallable<SnapshotResult<OperatorStateHandle>>() {
+
+					@Override
+					protected SnapshotResult<OperatorStateHandle> callInternal() throws Exception {
+
+						CheckpointStreamFactory.CheckpointStateOutputStream localOut =
+							streamFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE);
+						registerCloseableForCancellation(localOut);
+
+						// get the registered operator state infos ...
+						List<StateMetaInfoSnapshot> operatorMetaInfoSnapshots =
+							new ArrayList<>(registeredOperatorStatesDeepCopies.size());
+
+						for (Map.Entry<String, PartitionableListState<?>> entry :
+							registeredOperatorStatesDeepCopies.entrySet()) {
+							operatorMetaInfoSnapshots.add(entry.getValue().getStateMetaInfo().snapshot());
+						}
+
+						// ... get the registered broadcast operator state infos ...
+						List<StateMetaInfoSnapshot> broadcastMetaInfoSnapshots =
+							new ArrayList<>(registeredBroadcastStatesDeepCopies.size());
+
+						for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry :
+							registeredBroadcastStatesDeepCopies.entrySet()) {
+							broadcastMetaInfoSnapshots.add(entry.getValue().getStateMetaInfo().snapshot());
+						}
+
+						// ... write them all in the checkpoint stream ...
+						DataOutputView dov = new DataOutputViewStreamWrapper(localOut);
+
+						OperatorBackendSerializationProxy backendSerializationProxy =
+							new OperatorBackendSerializationProxy(operatorMetaInfoSnapshots, broadcastMetaInfoSnapshots);
+
+						backendSerializationProxy.write(dov);
+
+						// ... and then go for the states ...
+
+						// we put BOTH normal and broadcast state metadata here
+						int initialMapCapacity =
+							registeredOperatorStatesDeepCopies.size() + registeredBroadcastStatesDeepCopies.size();
+						final Map<String, OperatorStateHandle.StateMetaInfo> writtenStatesMetaData =
+							new HashMap<>(initialMapCapacity);
+
+						for (Map.Entry<String, PartitionableListState<?>> entry :
+							registeredOperatorStatesDeepCopies.entrySet()) {
+
+							PartitionableListState<?> value = entry.getValue();
+							long[] partitionOffsets = value.write(localOut);
+							OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode();
+							writtenStatesMetaData.put(
+								entry.getKey(),
+								new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
+						}
+
+						// ... and the broadcast states themselves ...
+						for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry :
+							registeredBroadcastStatesDeepCopies.entrySet()) {
+
+							BackendWritableBroadcastState<?, ?> value = entry.getValue();
+							long[] partitionOffsets = {value.write(localOut)};
+							OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode();
+							writtenStatesMetaData.put(
+								entry.getKey(),
+								new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
+						}
+
+						// ... and, finally, create the state handle.
+						OperatorStateHandle retValue = null;
+
+						if (unregisterCloseableFromCancellation(localOut)) {
+
+							StreamStateHandle stateHandle = localOut.closeAndGetHandle();
+
+							if (stateHandle != null) {
+								retValue = new OperatorStreamStateHandle(writtenStatesMetaData, stateHandle);
+							}
+
+							return SnapshotResult.of(retValue);
+						} else {
+							throw new IOException("Stream was already unregistered.");
+						}
+					}
+
+					@Override
+					protected void cleanupProvidedResources() {
+						// nothing to do
+					}
+
+					@Override
+					protected void logAsyncSnapshotComplete(long startTime) {
+						if (asynchronousSnapshots) {
+							logAsyncCompleted(streamFactory, startTime);
+						}
+					}
+				};
+
+			final FutureTask<SnapshotResult<OperatorStateHandle>> task =
+				snapshotCallable.toAsyncSnapshotFutureTask(closeStreamOnCancelRegistry);
+
+			if (!asynchronousSnapshots) {
+				task.run();
+			}
+
+			return task;
+		}
+	}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java
index 3ad68af..53c8663 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotStrategy.java
@@ -18,8 +18,11 @@
 
 package org.apache.flink.runtime.state;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 
+import javax.annotation.Nonnull;
+
 import java.util.concurrent.RunnableFuture;
 
 /**
@@ -28,7 +31,8 @@ import java.util.concurrent.RunnableFuture;
  *
  * @param <S> type of the returned state object that represents the result of the snapshot operation.
  */
-public interface SnapshotStrategy<S extends StateObject> extends CheckpointListener {
+@Internal
+public interface SnapshotStrategy<S extends StateObject> {
 
 	/**
 	 * Operation that writes a snapshot into a stream that is provided by the given {@link CheckpointStreamFactory} and
@@ -42,9 +46,10 @@ public interface SnapshotStrategy<S extends StateObject> extends CheckpointListe
 	 * @param checkpointOptions Options for how to perform this checkpoint.
 	 * @return A runnable future that will yield a {@link StateObject}.
 	 */
-	RunnableFuture<S> performSnapshot(
+	@Nonnull
+	RunnableFuture<S> snapshot(
 		long checkpointId,
 		long timestamp,
-		CheckpointStreamFactory streamFactory,
-		CheckpointOptions checkpointOptions) throws Exception;
+		@Nonnull CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception;
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
index 733339f..1677855 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
@@ -18,9 +18,9 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.annotation.Internal;
 
-import java.util.concurrent.RunnableFuture;
+import javax.annotation.Nullable;
 
 /**
  * Interface for operators that can perform snapshots of their state.
@@ -28,25 +28,8 @@ import java.util.concurrent.RunnableFuture;
  * @param <S> Generic type of the state object that is created as handle to snapshots.
  * @param <R> Generic type of the state object that used in restore.
  */
-public interface Snapshotable<S extends StateObject, R> {
-
-	/**
-	 * Operation that writes a snapshot into a stream that is provided by the given {@link CheckpointStreamFactory} and
-	 * returns a @{@link RunnableFuture} that gives a state handle to the snapshot. It is up to the implementation if
-	 * the operation is performed synchronous or asynchronous. In the later case, the returned Runnable must be executed
-	 * first before obtaining the handle.
-	 *
-	 * @param checkpointId  The ID of the checkpoint.
-	 * @param timestamp     The timestamp of the checkpoint.
-	 * @param streamFactory The factory that we can use for writing our state to streams.
-	 * @param checkpointOptions Options for how to perform this checkpoint.
-	 * @return A runnable future that will yield a {@link StateObject}.
-	 */
-	RunnableFuture<S> snapshot(
-			long checkpointId,
-			long timestamp,
-			CheckpointStreamFactory streamFactory,
-			CheckpointOptions checkpointOptions) throws Exception;
+@Internal
+public interface Snapshotable<S extends StateObject, R> extends SnapshotStrategy<S> {
 
 	/**
 	 * Restores state that was previously snapshotted from the provided parameters. Typically the parameters are state
@@ -54,5 +37,5 @@ public interface Snapshotable<S extends StateObject, R> {
 	 *
 	 * @param state the old state to restore.
 	 */
-	void restore(R state) throws Exception;
+	void restore(@Nullable R state) throws Exception;
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 0e2f16c..05070f9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -37,10 +37,10 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.io.async.AbstractAsyncCallableWithResources;
-import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.AbstractSnapshotStrategy;
+import org.apache.flink.runtime.state.AsyncSnapshotCallable;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
 import org.apache.flink.runtime.state.CheckpointedStateScope;
@@ -60,11 +60,10 @@ import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotResult;
-import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateSnapshot;
-import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
 import org.apache.flink.runtime.state.StateSnapshotRestore;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -92,6 +91,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -344,15 +344,22 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 	}
 
+	@Nonnull
 	@Override
 	@SuppressWarnings("unchecked")
-	public  RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
-			final long checkpointId,
-			final long timestamp,
-			final CheckpointStreamFactory streamFactory,
-			CheckpointOptions checkpointOptions) {
+	public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
+		final long checkpointId,
+		final long timestamp,
+		@Nonnull final CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws IOException {
 
-		return snapshotStrategy.performSnapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+		long startTime = System.currentTimeMillis();
+
+		final RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshotRunner =
+			snapshotStrategy.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+
+		snapshotStrategy.logSyncCompleted(streamFactory, startTime);
+		return snapshotRunner;
 	}
 
 	@SuppressWarnings("deprecation")
@@ -630,9 +637,6 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		}
 
-		default void logOperationCompleted(CheckpointStreamFactory streamFactory, long startTime) {
-
-		}
 
 		boolean isAsynchronous();
 
@@ -642,12 +646,6 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	private class AsyncSnapshotStrategySynchronicityBehavior implements SnapshotStrategySynchronicityBehavior<K> {
 
 		@Override
-		public void logOperationCompleted(CheckpointStreamFactory streamFactory, long startTime) {
-			LOG.debug("Heap backend snapshot ({}, asynchronous part) in thread {} took {} ms.",
-				streamFactory, Thread.currentThread(), (System.currentTimeMillis() - startTime));
-		}
-
-		@Override
 		public boolean isAsynchronous() {
 			return true;
 		}
@@ -682,28 +680,28 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * the concrete strategies. Subclasses must be threadsafe.
 	 */
 	private class HeapSnapshotStrategy
-		implements SnapshotStrategy<SnapshotResult<KeyedStateHandle>>, SnapshotStrategySynchronicityBehavior<K> {
+		extends AbstractSnapshotStrategy<KeyedStateHandle> implements SnapshotStrategySynchronicityBehavior<K> {
 
 		private final SnapshotStrategySynchronicityBehavior<K> snapshotStrategySynchronicityTrait;
 
 		HeapSnapshotStrategy(
 			SnapshotStrategySynchronicityBehavior<K> snapshotStrategySynchronicityTrait) {
+			super("Heap backend snapshot");
 			this.snapshotStrategySynchronicityTrait = snapshotStrategySynchronicityTrait;
 		}
 
+		@Nonnull
 		@Override
-		public RunnableFuture<SnapshotResult<KeyedStateHandle>> performSnapshot(
+		public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
 			long checkpointId,
 			long timestamp,
-			CheckpointStreamFactory primaryStreamFactory,
-			CheckpointOptions checkpointOptions) {
+			@Nonnull CheckpointStreamFactory primaryStreamFactory,
+			@Nonnull CheckpointOptions checkpointOptions) throws IOException {
 
 			if (!hasRegisteredState()) {
 				return DoneFuture.of(SnapshotResult.empty());
 			}
 
-			long syncStartTime = System.currentTimeMillis();
-
 			int numStates = registeredKVStates.size() + registeredPQStates.size();
 
 			Preconditions.checkState(numStates <= Short.MAX_VALUE,
@@ -754,53 +752,23 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 			//--------------------------------------------------- this becomes the end of sync part
 
-			// implementation of the async IO operation, based on FutureTask
-			final AbstractAsyncCallableWithResources<SnapshotResult<KeyedStateHandle>> ioCallable =
-				new AbstractAsyncCallableWithResources<SnapshotResult<KeyedStateHandle>>() {
-
-					CheckpointStreamWithResultProvider streamAndResultExtractor = null;
-
-					@Override
-					protected void acquireResources() throws Exception {
-						streamAndResultExtractor = checkpointStreamSupplier.get();
-						cancelStreamRegistry.registerCloseable(streamAndResultExtractor);
-					}
-
+			final AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>> asyncSnapshotCallable =
+				new AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>>() {
 					@Override
-					protected void releaseResources() {
+					protected SnapshotResult<KeyedStateHandle> callInternal() throws Exception {
 
-						unregisterAndCloseStreamAndResultExtractor();
+						final CheckpointStreamWithResultProvider streamWithResultProvider =
+							checkpointStreamSupplier.get();
 
-						for (StateSnapshot tableSnapshot : cowStateStableSnapshots.values()) {
-							tableSnapshot.release();
-						}
-					}
+						registerCloseableForCancellation(streamWithResultProvider);
 
-					@Override
-					protected void stopOperation() {
-						unregisterAndCloseStreamAndResultExtractor();
-					}
+						final CheckpointStreamFactory.CheckpointStateOutputStream localStream =
+							streamWithResultProvider.getCheckpointOutputStream();
 
-					private void unregisterAndCloseStreamAndResultExtractor() {
-						if (cancelStreamRegistry.unregisterCloseable(streamAndResultExtractor)) {
-							IOUtils.closeQuietly(streamAndResultExtractor);
-							streamAndResultExtractor = null;
-						}
-					}
-
-					@Nonnull
-					@Override
-					protected SnapshotResult<KeyedStateHandle> performOperation() throws Exception {
-
-						long startTime = System.currentTimeMillis();
-
-						CheckpointStreamFactory.CheckpointStateOutputStream localStream =
-							this.streamAndResultExtractor.getCheckpointOutputStream();
-
-						DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(localStream);
+						final DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(localStream);
 						serializationProxy.write(outView);
 
-						long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()];
+						final long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()];
 
 						for (int keyGroupPos = 0; keyGroupPos < keyGroupRange.getNumberOfKeyGroups(); ++keyGroupPos) {
 							int keyGroupId = keyGroupRange.getKeyGroupId(keyGroupPos);
@@ -812,35 +780,46 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 								StateSnapshot.StateKeyGroupWriter partitionedSnapshot =
 
 									stateSnapshot.getValue().getKeyGroupWriter();
-								try (OutputStream kgCompressionOut = keyGroupCompressionDecorator.decorateWithCompression(localStream)) {
-									DataOutputViewStreamWrapper kgCompressionView = new DataOutputViewStreamWrapper(kgCompressionOut);
+								try (
+									OutputStream kgCompressionOut =
+										keyGroupCompressionDecorator.decorateWithCompression(localStream)) {
+									DataOutputViewStreamWrapper kgCompressionView =
+										new DataOutputViewStreamWrapper(kgCompressionOut);
 									kgCompressionView.writeShort(stateNamesToId.get(stateSnapshot.getKey()));
 									partitionedSnapshot.writeStateInKeyGroup(kgCompressionView, keyGroupId);
 								} // this will just close the outer compression stream
 							}
 						}
 
-						if (cancelStreamRegistry.unregisterCloseable(streamAndResultExtractor)) {
+						if (unregisterCloseableFromCancellation(streamWithResultProvider)) {
 							KeyGroupRangeOffsets kgOffs = new KeyGroupRangeOffsets(keyGroupRange, keyGroupRangeOffsets);
 							SnapshotResult<StreamStateHandle> result =
-								streamAndResultExtractor.closeAndFinalizeCheckpointStreamResult();
-							streamAndResultExtractor = null;
-							logOperationCompleted(primaryStreamFactory, startTime);
+								streamWithResultProvider.closeAndFinalizeCheckpointStreamResult();
 							return CheckpointStreamWithResultProvider.toKeyedStateHandleSnapshotResult(result, kgOffs);
+						} else {
+							throw new IOException("Stream already unregistered.");
 						}
+					}
 
-						return SnapshotResult.empty();
+					@Override
+					protected void cleanupProvidedResources() {
+						for (StateSnapshot tableSnapshot : cowStateStableSnapshots.values()) {
+							tableSnapshot.release();
+						}
 					}
-				};
 
-			AsyncStoppableTaskWithCallback<SnapshotResult<KeyedStateHandle>> task =
-				AsyncStoppableTaskWithCallback.from(ioCallable);
+					@Override
+					protected void logAsyncSnapshotComplete(long startTime) {
+						if (snapshotStrategySynchronicityTrait.isAsynchronous()) {
+							logAsyncCompleted(primaryStreamFactory, startTime);
+						}
+					}
+				};
 
+			final FutureTask<SnapshotResult<KeyedStateHandle>> task =
+				asyncSnapshotCallable.toAsyncSnapshotFutureTask(cancelStreamRegistry);
 			finalizeSnapshotBeforeReturnHook(task);
 
-			LOG.debug("Heap backend snapshot (" + primaryStreamFactory + ", synchronous part) in thread " +
-				Thread.currentThread() + " took " + (System.currentTimeMillis() - syncStartTime) + " ms.");
-
 			return task;
 		}
 
@@ -850,11 +829,6 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 
 		@Override
-		public void logOperationCompleted(CheckpointStreamFactory streamFactory, long startTime) {
-			snapshotStrategySynchronicityTrait.logOperationCompleted(streamFactory, startTime);
-		}
-
-		@Override
 		public boolean isAsynchronous() {
 			return snapshotStrategySynchronicityTrait.isAsynchronous();
 		}
@@ -882,11 +856,6 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				}
 			}
 		}
-
-		@Override
-		public void notifyCheckpointComplete(long checkpointId) throws Exception {
-			// nothing to do.
-		}
 	}
 
 	private interface StateFactory {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncSnapshotCallableTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncSnapshotCallableTest.java
new file mode 100644
index 0000000..304a495
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncSnapshotCallableTest.java
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.util.Preconditions;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import javax.annotation.Nonnull;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.FutureTask;
+
+/**
+ * Tests for {@link AsyncSnapshotCallable}.
+ */
+public class AsyncSnapshotCallableTest {
+
+	private static final String METHOD_CALL = "callInternal";
+	private static final String METHOD_LOG = "logAsyncSnapshotComplete";
+	private static final String METHOD_CLEANUP = "cleanupProvidedResources";
+	private static final String METHOD_CANCEL = "cancel";
+	private static final String SUCCESS = "Success!";
+
+	private CloseableRegistry ownerRegistry;
+	private TestBlockingCloseable testProvidedResource;
+	private TestBlockingCloseable testBlocker;
+	private TestAsyncSnapshotCallable testAsyncSnapshotCallable;
+	private FutureTask<String> task;
+
+	@Before
+	public void setup() throws IOException {
+		ownerRegistry = new CloseableRegistry();
+		testProvidedResource = new TestBlockingCloseable();
+		testBlocker = new TestBlockingCloseable();
+		testAsyncSnapshotCallable = new TestAsyncSnapshotCallable(testProvidedResource, testBlocker);
+		task = testAsyncSnapshotCallable.toAsyncSnapshotFutureTask(ownerRegistry);
+		Assert.assertEquals(1, ownerRegistry.getNumberOfRegisteredCloseables());
+	}
+
+	@After
+	public void finalChecks() {
+		Assert.assertTrue(testProvidedResource.isClosed());
+		Assert.assertEquals(0, ownerRegistry.getNumberOfRegisteredCloseables());
+	}
+
+	@Test
+	public void testNormalRun() throws Exception {
+
+		Thread runner = startTask(task);
+
+		while (testBlocker.getWaitersCount() < 1) {
+			Thread.sleep(1L);
+		}
+
+		testBlocker.unblockSuccessfully();
+
+		runner.join();
+
+		Assert.assertEquals(SUCCESS, task.get());
+		Assert.assertEquals(
+			Arrays.asList(METHOD_CALL, METHOD_LOG, METHOD_CLEANUP),
+			testAsyncSnapshotCallable.getInvocationOrder());
+
+		Assert.assertTrue(testBlocker.isClosed());
+	}
+
+	@Test
+	public void testExceptionRun() throws Exception {
+
+		testBlocker.introduceException();
+		Thread runner = startTask(task);
+
+		while (testBlocker.getWaitersCount() < 1) {
+			Thread.sleep(1L);
+		}
+
+		testBlocker.unblockSuccessfully();
+		try {
+			task.get();
+			Assert.fail();
+		} catch (ExecutionException ee) {
+			Assert.assertEquals(IOException.class, ee.getCause().getClass());
+		}
+
+		runner.join();
+
+		Assert.assertEquals(
+			Arrays.asList(METHOD_CALL, METHOD_CLEANUP),
+			testAsyncSnapshotCallable.getInvocationOrder());
+
+		Assert.assertTrue(testBlocker.isClosed());
+	}
+
+	@Test
+	public void testCancelRun() throws Exception {
+
+		Thread runner = startTask(task);
+
+		while (testBlocker.getWaitersCount() < 1) {
+			Thread.sleep(1L);
+		}
+
+		task.cancel(true);
+		testBlocker.unblockExceptionally();
+
+		try {
+			task.get();
+			Assert.fail();
+		} catch (CancellationException ignored) {
+		}
+
+		runner.join();
+
+		Assert.assertEquals(
+			Arrays.asList(METHOD_CALL, METHOD_CANCEL, METHOD_CLEANUP),
+			testAsyncSnapshotCallable.getInvocationOrder());
+		Assert.assertTrue(testProvidedResource.isClosed());
+		Assert.assertTrue(testBlocker.isClosed());
+	}
+
+	@Test
+	public void testCloseRun() throws Exception {
+
+		Thread runner = startTask(task);
+
+		while (testBlocker.getWaitersCount() < 1) {
+			Thread.sleep(1L);
+		}
+
+		ownerRegistry.close();
+
+		try {
+			task.get();
+			Assert.fail();
+		} catch (CancellationException ignored) {
+		}
+
+		runner.join();
+
+		Assert.assertEquals(
+			Arrays.asList(METHOD_CALL, METHOD_CANCEL, METHOD_CLEANUP),
+			testAsyncSnapshotCallable.getInvocationOrder());
+		Assert.assertTrue(testBlocker.isClosed());
+	}
+
+	@Test
+	public void testCancelBeforeRun() throws Exception {
+
+		task.cancel(true);
+
+		Thread runner = startTask(task);
+
+		try {
+			task.get();
+			Assert.fail();
+		} catch (CancellationException ignored) {
+		}
+
+		runner.join();
+
+		Assert.assertEquals(
+			Arrays.asList(METHOD_CANCEL, METHOD_CLEANUP),
+			testAsyncSnapshotCallable.getInvocationOrder());
+
+		Assert.assertTrue(testProvidedResource.isClosed());
+	}
+
+	private Thread startTask(Runnable task)  {
+		Thread runner = new Thread(task);
+		runner.start();
+		return runner;
+	}
+
+	/**
+	 * Test implementation of {@link AsyncSnapshotCallable}.
+	 */
+	private static class TestAsyncSnapshotCallable extends AsyncSnapshotCallable<String> {
+
+		@Nonnull
+		private final TestBlockingCloseable providedResource;
+		@Nonnull
+		private final TestBlockingCloseable blockingResource;
+		@Nonnull
+		private final List<String> invocationOrder;
+
+		TestAsyncSnapshotCallable(
+			@Nonnull TestBlockingCloseable providedResource,
+			@Nonnull TestBlockingCloseable blockingResource) {
+
+			this.providedResource = providedResource;
+			this.blockingResource = blockingResource;
+			this.invocationOrder = new ArrayList<>();
+		}
+
+		@Override
+		protected String callInternal() throws Exception {
+
+			addInvocation(METHOD_CALL);
+			registerCloseableForCancellation(blockingResource);
+			try {
+				blockingResource.simulateBlockingOperation();
+			} finally {
+				if (unregisterCloseableFromCancellation(blockingResource)) {
+					blockingResource.close();
+				}
+			}
+
+			return SUCCESS;
+		}
+
+		@Override
+		protected void cleanupProvidedResources() {
+			addInvocation(METHOD_CLEANUP);
+			providedResource.close();
+		}
+
+		@Override
+		protected void logAsyncSnapshotComplete(long startTime) {
+			addInvocation(METHOD_LOG); // record under the same lock as the other callbacks
+		}
+
+		@Override
+		protected void cancel() {
+			addInvocation(METHOD_CANCEL);
+			super.cancel();
+		}
+
+		@Nonnull
+		public List<String> getInvocationOrder() {
+			synchronized (invocationOrder) {
+				return new ArrayList<>(invocationOrder);
+			}
+		}
+
+		private void addInvocation(@Nonnull String invocation) {
+			synchronized (invocationOrder) {
+				invocationOrder.add(invocation);
+			}
+		}
+	}
+
+	/**
+	 * Mix of a {@link Closeable} and some {@link OneShotLatch} functionality for testing.
+	 */
+	private static class TestBlockingCloseable implements Closeable {
+
+		private final OneShotLatch blockerLatch = new OneShotLatch();
+		private boolean closed = false;
+		private boolean unblocked = false;
+		private boolean exceptionally = false;
+
+		public void simulateBlockingOperation() throws IOException {
+			while (!unblocked) {
+				try {
+					blockerLatch.await();
+				} catch (InterruptedException e) {
+					blockerLatch.reset();
+				}
+			}
+			if (exceptionally) {
+				throw new IOException("Closed in block");
+			}
+		}
+
+		@Override
+		public void close() {
+			Preconditions.checkState(!closed);
+			this.closed = true;
+			unblockExceptionally();
+		}
+
+		public boolean isClosed() {
+			return closed;
+		}
+
+		public void unblockExceptionally() {
+			introduceException();
+			unblock();
+		}
+
+		public void unblockSuccessfully() {
+			unblock();
+		}
+
+		private void unblock() {
+			this.unblocked = true;
+			blockerLatch.trigger();
+		}
+
+		public void introduceException() {
+			this.exceptionally = true;
+		}
+
+		public int getWaitersCount() {
+			return blockerLatch.getWaitersCount();
+		}
+	}
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
index d8918e7..b5988f3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -55,7 +55,6 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.concurrent.CancellationException;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.FutureTask;
@@ -790,8 +789,7 @@ public class OperatorStateBackendTest {
 		try {
 			runnableFuture.get(60, TimeUnit.SECONDS);
 			Assert.fail();
-		} catch (ExecutionException eex) {
-			Assert.assertTrue(eex.getCause() instanceof IOException);
+		} catch (CancellationException expected) {
 		}
 	}
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 059a706..649c6d0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -110,9 +110,9 @@ import java.util.PrimitiveIterator;
 import java.util.Random;
 import java.util.Timer;
 import java.util.TimerTask;
+import java.util.concurrent.CancellationException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
@@ -189,7 +189,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				numberOfKeyGroups,
 				keyGroupRange,
 				env.getTaskKvStateRegistry(),
-			    TtlTimeProvider.DEFAULT);
+				TtlTimeProvider.DEFAULT);
 
 		backend.restore(null);
 
@@ -4015,7 +4015,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			try {
 				snapshot.get();
 				fail("Close was not propagated.");
-			} catch (ExecutionException ex) {
+			} catch (CancellationException ex) {
 				//ignore
 			}
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index 0b5931c..ccfafec 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -170,12 +170,13 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			.map(Map.Entry::getKey);
 	}
 
+	@Nonnull
 	@Override
 	public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
 		long checkpointId,
 		long timestamp,
-		CheckpointStreamFactory streamFactory,
-		CheckpointOptions checkpointOptions) {
+		@Nonnull CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) {
 		return new FutureTask<>(() ->
 			SnapshotResult.of(new MockKeyedStateHandle<>(copy(stateValues, stateSnapshotFilters))));
 	}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 87c7e55..60baaed 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -35,6 +35,7 @@ import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerial
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.contrib.streaming.state.iterator.RocksStateKeysIterator;
+import org.apache.flink.contrib.streaming.state.snapshot.RocksDBSnapshotStrategyBase;
 import org.apache.flink.contrib.streaming.state.snapshot.RocksFullSnapshotStrategy;
 import org.apache.flink.contrib.streaming.state.snapshot.RocksIncrementalSnapshotStrategy;
 import org.apache.flink.core.fs.FSDataInputStream;
@@ -46,8 +47,8 @@ import org.apache.flink.core.memory.ByteArrayDataInputView;
 import org.apache.flink.core.memory.ByteArrayDataOutputView;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
@@ -70,7 +71,6 @@ import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInf
 import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotResult;
-import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateHandleID;
 import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
@@ -207,7 +207,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	private final WriteOptions writeOptions;
 
 	/**
-	 * Information about the k/v states as we create them. This is used to retrieve the
+	 * Information about the k/v states, maintained in the order in which we create them. This is used to retrieve the
 	 * column family that is used for a state and also for sanity checks when restoring.
 	 */
 	private final LinkedHashMap<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> kvStateInformation;
@@ -229,8 +229,11 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	/** The configuration of local recovery. */
 	private final LocalRecoveryConfig localRecoveryConfig;
 
-	/** The snapshot strategy, e.g., if we use full or incremental checkpoints, local state, and so on. */
-	private SnapshotStrategy<SnapshotResult<KeyedStateHandle>> snapshotStrategy;
+	/** The checkpoint snapshot strategy, e.g., if we use full or incremental checkpoints, local state, and so on. */
+	private RocksDBSnapshotStrategyBase<K> checkpointSnapshotStrategy;
+
+	/** The savepoint snapshot strategy. */
+	private RocksDBSnapshotStrategyBase<K> savepointSnapshotStrategy;
 
 	/** Factory for priority queue state. */
 	private final PriorityQueueSetFactory priorityQueueFactory;
@@ -444,17 +447,29 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * @return Future to the state handle of the snapshot data.
 	 * @throws Exception indicating a problem in the synchronous part of the checkpoint.
 	 */
+	@Nonnull
 	@Override
 	public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
 		final long checkpointId,
 		final long timestamp,
-		final CheckpointStreamFactory streamFactory,
-		CheckpointOptions checkpointOptions) throws Exception {
+		@Nonnull final CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception {
+
+		long startTime = System.currentTimeMillis();
 
 		// flush everything into db before taking a snapshot
 		writeBatchWrapper.flush();
 
-		return snapshotStrategy.performSnapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+		RocksDBSnapshotStrategyBase<K> chosenSnapshotStrategy =
+			CheckpointType.SAVEPOINT == checkpointOptions.getCheckpointType() ?
+				savepointSnapshotStrategy : checkpointSnapshotStrategy;
+
+		RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshotRunner =
+			chosenSnapshotStrategy.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+
+		chosenSnapshotStrategy.logSyncCompleted(streamFactory, startTime);
+
+		return snapshotRunner;
 	}
 
 	@Override
@@ -497,7 +512,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	void initializeSnapshotStrategy(
 		@Nullable RocksDBIncrementalRestoreOperation<K> incrementalRestoreOperation) {
 
-		final RocksFullSnapshotStrategy<K> fullSnapshotStrategy =
+		this.savepointSnapshotStrategy =
 			new RocksFullSnapshotStrategy<>(
 				db,
 				rocksDBResourceGuard,
@@ -525,7 +540,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				Preconditions.checkState(lastCompletedCheckpointId >= 0L);
 			}
 			// TODO eventually we might want to separate savepoint and snapshot strategy, i.e. having 2 strategies.
-			this.snapshotStrategy = new RocksIncrementalSnapshotStrategy<>(
+			this.checkpointSnapshotStrategy = new RocksIncrementalSnapshotStrategy<>(
 				db,
 				rocksDBResourceGuard,
 				keySerializer,
@@ -537,17 +552,21 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				instanceBasePath,
 				backendUID,
 				materializedSstFiles,
-				lastCompletedCheckpointId,
-				fullSnapshotStrategy);
+				lastCompletedCheckpointId);
 		} else {
-			this.snapshotStrategy = fullSnapshotStrategy;
+			this.checkpointSnapshotStrategy = savepointSnapshotStrategy;
 		}
 	}
 
 	@Override
 	public void notifyCheckpointComplete(long completedCheckpointId) throws Exception {
-		if (snapshotStrategy != null) {
-			snapshotStrategy.notifyCheckpointComplete(completedCheckpointId);
+
+		if (checkpointSnapshotStrategy != null) {
+			checkpointSnapshotStrategy.notifyCheckpointComplete(completedCheckpointId);
+		}
+		// Without incremental checkpoints both fields reference the same strategy instance; notify it only once.
+		if (savepointSnapshotStrategy != null && savepointSnapshotStrategy != checkpointSnapshotStrategy) {
+			savepointSnapshotStrategy.notifyCheckpointComplete(completedCheckpointId);
 		}
 	}
 
@@ -966,9 +985,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			@Nonnull
 			private final List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
 
-			private
-
-			RestoredDBInstance(
+			private RestoredDBInstance(
 				@Nonnull RocksDB db,
 				@Nonnull List<ColumnFamilyHandle> columnFamilyHandles,
 				@Nonnull List<ColumnFamilyDescriptor> columnFamilyDescriptors,
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/SnapshotStrategyBase.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksDBSnapshotStrategyBase.java
similarity index 57%
rename from flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/SnapshotStrategyBase.java
rename to flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksDBSnapshotStrategyBase.java
index efebe8c..fffd98d 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/SnapshotStrategyBase.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksDBSnapshotStrategyBase.java
@@ -21,6 +21,11 @@ package org.apache.flink.contrib.streaming.state.snapshot;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.AbstractSnapshotStrategy;
+import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
@@ -31,44 +36,60 @@ import org.apache.flink.util.ResourceGuard;
 
 import org.rocksdb.ColumnFamilyHandle;
 import org.rocksdb.RocksDB;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
 
 import java.util.LinkedHashMap;
+import java.util.concurrent.RunnableFuture;
 
 /**
- * Base class for {@link SnapshotStrategy} implementations on RocksDB.
+ * Abstract base class for {@link SnapshotStrategy} implementations for the RocksDB state backend.
  *
  * @param <K> type of the backend keys.
  */
-public abstract class SnapshotStrategyBase<K> implements SnapshotStrategy<SnapshotResult<KeyedStateHandle>> {
+public abstract class RocksDBSnapshotStrategyBase<K>
+	extends AbstractSnapshotStrategy<KeyedStateHandle>
+	implements CheckpointListener {
 
+	private static final Logger LOG = LoggerFactory.getLogger(RocksDBSnapshotStrategyBase.class);
+
+	/** RocksDB instance from the backend. */
 	@Nonnull
 	protected final RocksDB db;
 
+	/** Resource guard for the RocksDB instance. */
 	@Nonnull
 	protected final ResourceGuard rocksDBResourceGuard;
 
+	/** The key serializer of the backend. */
 	@Nonnull
 	protected final TypeSerializer<K> keySerializer;
 
+	/** Key/Value state meta info from the backend. */
 	@Nonnull
 	protected final LinkedHashMap<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> kvStateInformation;
 
+	/** The key-group range for the task. */
 	@Nonnull
 	protected final KeyGroupRange keyGroupRange;
 
+	/** Number of bytes in the key-group prefix. */
 	@Nonnegative
 	protected final int keyGroupPrefixBytes;
 
+	/** The configuration for local recovery. */
 	@Nonnull
 	protected final LocalRecoveryConfig localRecoveryConfig;
 
+	/** A {@link CloseableRegistry} that will be closed when the task is cancelled. */
 	@Nonnull
 	protected final CloseableRegistry cancelStreamRegistry;
 
-	public SnapshotStrategyBase(
+	public RocksDBSnapshotStrategyBase(
+		@Nonnull String description,
 		@Nonnull RocksDB db,
 		@Nonnull ResourceGuard rocksDBResourceGuard,
 		@Nonnull TypeSerializer<K> keySerializer,
@@ -78,6 +99,7 @@ public abstract class SnapshotStrategyBase<K> implements SnapshotStrategy<Snapsh
 		@Nonnull LocalRecoveryConfig localRecoveryConfig,
 		@Nonnull CloseableRegistry cancelStreamRegistry) {
 
+		super(description);
 		this.db = db;
 		this.rocksDBResourceGuard = rocksDBResourceGuard;
 		this.keySerializer = keySerializer;
@@ -87,4 +109,33 @@ public abstract class SnapshotStrategyBase<K> implements SnapshotStrategy<Snapsh
 		this.localRecoveryConfig = localRecoveryConfig;
 		this.cancelStreamRegistry = cancelStreamRegistry;
 	}
+
+	@Nonnull
+	@Override
+	public final RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
+		long checkpointId,
+		long timestamp,
+		@Nonnull CheckpointStreamFactory streamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception {
+
+		if (kvStateInformation.isEmpty()) {
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("Asynchronous RocksDB snapshot performed on empty keyed state at {}. Returning null.",
+					timestamp);
+			}
+			return DoneFuture.of(SnapshotResult.empty());
+		} else {
+			return doSnapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
+		}
+	}
+
+	/**
+	 * This method implements the concrete snapshot logic for a non-empty state.
+	 */
+	@Nonnull
+	protected abstract RunnableFuture<SnapshotResult<KeyedStateHandle>> doSnapshot(
+		long checkpointId,
+		long timestamp,
+		CheckpointStreamFactory streamFactory,
+		CheckpointOptions checkpointOptions) throws Exception;
 }
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
index 0cc9729..0aa091e 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
@@ -30,10 +30,10 @@ import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.state.AsyncSnapshotCallable;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
 import org.apache.flink.runtime.state.CheckpointedStateScope;
-import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
@@ -55,8 +55,6 @@ import org.rocksdb.ReadOptions;
 import org.rocksdb.RocksDB;
 import org.rocksdb.RocksIterator;
 import org.rocksdb.Snapshot;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
@@ -67,11 +65,7 @@ import java.util.ArrayList;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Objects;
-import java.util.concurrent.Callable;
-import java.util.concurrent.CancellationException;
-import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
-import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.END_OF_KEY_GROUP_MARK;
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.hasMetaDataFollowsFlag;
@@ -84,9 +78,9 @@ import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUti
  *
  * @param <K> type of the backend keys.
  */
-public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
+public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K> {
 
-	private static final Logger LOG = LoggerFactory.getLogger(RocksFullSnapshotStrategy.class);
+	private static final String DESCRIPTION = "Asynchronous full RocksDB snapshot";
 
 	/** This decorator is used to apply compression per key-group for the written snapshot data. */
 	@Nonnull
@@ -103,6 +97,7 @@ public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
 		@Nonnull CloseableRegistry cancelStreamRegistry,
 		@Nonnull StreamCompressionDecorator keyGroupCompressionDecorator) {
 		super(
+			DESCRIPTION,
 			db,
 			rocksDBResourceGuard,
 			keySerializer,
@@ -115,45 +110,40 @@ public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
 		this.keyGroupCompressionDecorator = keyGroupCompressionDecorator;
 	}
 
+	@Nonnull
 	@Override
-	public RunnableFuture<SnapshotResult<KeyedStateHandle>> performSnapshot(
+	public RunnableFuture<SnapshotResult<KeyedStateHandle>> doSnapshot(
 		long checkpointId,
 		long timestamp,
-		CheckpointStreamFactory primaryStreamFactory,
-		CheckpointOptions checkpointOptions) throws Exception {
+		@Nonnull CheckpointStreamFactory primaryStreamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception {
 
-		long startTime = System.currentTimeMillis();
+		final SupplierWithException<CheckpointStreamWithResultProvider, Exception> checkpointStreamSupplier =
+			createCheckpointStreamSupplier(checkpointId, primaryStreamFactory, checkpointOptions);
 
-		if (kvStateInformation.isEmpty()) {
-			if (LOG.isDebugEnabled()) {
-				LOG.debug("Asynchronous RocksDB snapshot performed on empty keyed state at {}. Returning null.",
-					timestamp);
-			}
+		final List<StateMetaInfoSnapshot> stateMetaInfoSnapshots = new ArrayList<>(kvStateInformation.size());
+		final List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> metaDataCopy =
+			new ArrayList<>(kvStateInformation.size());
 
-			return DoneFuture.of(SnapshotResult.empty());
+		for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 : kvStateInformation.values()) {
+			// snapshot meta info
+			stateMetaInfoSnapshots.add(tuple2.f1.snapshot());
+			metaDataCopy.add(tuple2);
 		}
 
-		final SupplierWithException<CheckpointStreamWithResultProvider, Exception> supplier =
-
-			localRecoveryConfig.isLocalRecoveryEnabled() &&
-				(CheckpointType.SAVEPOINT != checkpointOptions.getCheckpointType()) ?
-
-				() -> CheckpointStreamWithResultProvider.createDuplicatingStream(
-					checkpointId,
-					CheckpointedStateScope.EXCLUSIVE,
-					primaryStreamFactory,
-					localRecoveryConfig.getLocalStateDirectoryProvider()) :
+		final ResourceGuard.Lease lease = rocksDBResourceGuard.acquireResource();
+		final Snapshot snapshot = db.getSnapshot();
 
-				() -> CheckpointStreamWithResultProvider.createSimpleStream(
-					CheckpointedStateScope.EXCLUSIVE,
-					primaryStreamFactory);
+		final SnapshotAsynchronousPartCallable asyncSnapshotCallable =
+			new SnapshotAsynchronousPartCallable(
+				checkpointStreamSupplier,
+				lease,
+				snapshot,
+				stateMetaInfoSnapshots,
+				metaDataCopy,
+				primaryStreamFactory.toString());
 
-		final CloseableRegistry snapshotCloseableRegistry = new CloseableRegistry();
-
-		final RocksDBFullSnapshotCallable snapshotOperation =
-			new RocksDBFullSnapshotCallable(supplier, snapshotCloseableRegistry);
-
-		return new SnapshotTask(snapshotOperation);
+		return asyncSnapshotCallable.toAsyncSnapshotFutureTask(cancelStreamRegistry);
 	}
 
 	@Override
@@ -161,160 +151,124 @@ public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
 		// nothing to do.
 	}
 
-	/**
-	 * Wrapping task to run a {@link RocksDBFullSnapshotCallable} and delegate cancellation.
-	 */
-	private class SnapshotTask extends FutureTask<SnapshotResult<KeyedStateHandle>> {
+	private SupplierWithException<CheckpointStreamWithResultProvider, Exception> createCheckpointStreamSupplier(
+		long checkpointId,
+		CheckpointStreamFactory primaryStreamFactory,
+		CheckpointOptions checkpointOptions) {
 
-		/** Reference to the callable for cancellation. */
-		@Nonnull
-		private final AutoCloseable callableClose;
+		return localRecoveryConfig.isLocalRecoveryEnabled() &&
+			(CheckpointType.SAVEPOINT != checkpointOptions.getCheckpointType()) ?
 
-		SnapshotTask(@Nonnull RocksDBFullSnapshotCallable callable) {
-			super(callable);
-			this.callableClose = callable;
-		}
+			() -> CheckpointStreamWithResultProvider.createDuplicatingStream(
+				checkpointId,
+				CheckpointedStateScope.EXCLUSIVE,
+				primaryStreamFactory,
+				localRecoveryConfig.getLocalStateDirectoryProvider()) :
 
-		@Override
-		public boolean cancel(boolean mayInterruptIfRunning) {
-			IOUtils.closeQuietly(callableClose);
-			return super.cancel(mayInterruptIfRunning);
-		}
+			() -> CheckpointStreamWithResultProvider.createSimpleStream(
+				CheckpointedStateScope.EXCLUSIVE,
+				primaryStreamFactory);
 	}
 
 	/**
 	 * Encapsulates the process to perform a full snapshot of a RocksDBKeyedStateBackend.
 	 */
 	@VisibleForTesting
-	private class RocksDBFullSnapshotCallable implements Callable<SnapshotResult<KeyedStateHandle>>, AutoCloseable {
-
-		@Nonnull
-		private final KeyGroupRangeOffsets keyGroupRangeOffsets;
+	private class SnapshotAsynchronousPartCallable extends AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>> {
 
+		/** Supplier for the stream into which we write the snapshot. */
 		@Nonnull
 		private final SupplierWithException<CheckpointStreamWithResultProvider, Exception> checkpointStreamSupplier;
 
-		@Nonnull
-		private final CloseableRegistry snapshotCloseableRegistry;
-
+		/** This lease protects the native RocksDB resources. */
 		@Nonnull
 		private final ResourceGuard.Lease dbLease;
 
+		/** RocksDB snapshot. */
 		@Nonnull
 		private final Snapshot snapshot;
 
 		@Nonnull
-		private final ReadOptions readOptions;
-
-		/**
-		 * The state meta data.
-		 */
-		@Nonnull
 		private List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
 
-		/**
-		 * The copied column handle.
-		 */
 		@Nonnull
 		private List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> metaDataCopy;
 
-		private final AtomicBoolean ownedForCleanup;
+		@Nonnull
+		private final String logPathString;
 
-		RocksDBFullSnapshotCallable(
+		SnapshotAsynchronousPartCallable(
 			@Nonnull SupplierWithException<CheckpointStreamWithResultProvider, Exception> checkpointStreamSupplier,
-			@Nonnull CloseableRegistry registry) throws IOException {
+			@Nonnull ResourceGuard.Lease dbLease,
+			@Nonnull Snapshot snapshot,
+			@Nonnull List<StateMetaInfoSnapshot> stateMetaInfoSnapshots,
+			@Nonnull List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> metaDataCopy,
+			@Nonnull String logPathString) {
 
-			this.ownedForCleanup = new AtomicBoolean(false);
 			this.checkpointStreamSupplier = checkpointStreamSupplier;
-			this.keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange);
-			this.snapshotCloseableRegistry = registry;
-
-			this.stateMetaInfoSnapshots = new ArrayList<>(kvStateInformation.size());
-			this.metaDataCopy = new ArrayList<>(kvStateInformation.size());
-			for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 : kvStateInformation.values()) {
-				// snapshot meta info
-				this.stateMetaInfoSnapshots.add(tuple2.f1.snapshot());
-				this.metaDataCopy.add(tuple2);
-			}
-
-			this.dbLease = rocksDBResourceGuard.acquireResource();
-
-			this.readOptions = new ReadOptions();
-			this.snapshot = db.getSnapshot();
-			this.readOptions.setSnapshot(snapshot);
+			this.dbLease = dbLease;
+			this.snapshot = snapshot;
+			this.stateMetaInfoSnapshots = stateMetaInfoSnapshots;
+			this.metaDataCopy = metaDataCopy;
+			this.logPathString = logPathString;
 		}
 
 		@Override
-		public SnapshotResult<KeyedStateHandle> call() throws Exception {
-
-			if (!ownedForCleanup.compareAndSet(false, true)) {
-				throw new CancellationException("Snapshot task was already cancelled, stopping execution.");
-			}
+		protected SnapshotResult<KeyedStateHandle> callInternal() throws Exception {
+			final KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange);
+			final CheckpointStreamWithResultProvider checkpointStreamWithResultProvider =
+				checkpointStreamSupplier.get();
 
-			final long startTime = System.currentTimeMillis();
-			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators = new ArrayList<>(metaDataCopy.size());
+			registerCloseableForCancellation(checkpointStreamWithResultProvider);
+			writeSnapshotToOutputStream(checkpointStreamWithResultProvider, keyGroupRangeOffsets);
 
-			try {
-
-				cancelStreamRegistry.registerCloseable(snapshotCloseableRegistry);
-
-				final CheckpointStreamWithResultProvider checkpointStreamWithResultProvider = checkpointStreamSupplier.get();
-				snapshotCloseableRegistry.registerCloseable(checkpointStreamWithResultProvider);
-
-				final DataOutputView outputView =
-					new DataOutputViewStreamWrapper(checkpointStreamWithResultProvider.getCheckpointOutputStream());
-
-				writeKVStateMetaData(kvStateIterators, outputView);
-				writeKVStateData(kvStateIterators, checkpointStreamWithResultProvider);
+			if (unregisterCloseableFromCancellation(checkpointStreamWithResultProvider)) {
+				return CheckpointStreamWithResultProvider.toKeyedStateHandleSnapshotResult(
+					checkpointStreamWithResultProvider.closeAndFinalizeCheckpointStreamResult(),
+					keyGroupRangeOffsets);
+			} else {
+				throw new IOException("Stream is already unregistered/closed.");
+			}
+		}
 
-				final SnapshotResult<KeyedStateHandle> snapshotResult =
-					createStateHandlesFromStreamProvider(checkpointStreamWithResultProvider);
+		@Override
+		protected void cleanupProvidedResources() {
+			db.releaseSnapshot(snapshot);
+			IOUtils.closeQuietly(snapshot);
+			IOUtils.closeQuietly(dbLease);
+		}
 
-				LOG.info("Asynchronous RocksDB snapshot ({}, asynchronous part) in thread {} took {} ms.",
-					checkpointStreamSupplier, Thread.currentThread(), (System.currentTimeMillis() - startTime));
+		@Override
+		protected void logAsyncSnapshotComplete(long startTime) {
+			logAsyncCompleted(logPathString, startTime);
+		}
 
-				return snapshotResult;
+		private void writeSnapshotToOutputStream(
+			@Nonnull CheckpointStreamWithResultProvider checkpointStreamWithResultProvider,
+			@Nonnull KeyGroupRangeOffsets keyGroupRangeOffsets) throws IOException, InterruptedException {
 
+			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators =
+				new ArrayList<>(metaDataCopy.size());
+			final DataOutputView outputView =
+				new DataOutputViewStreamWrapper(checkpointStreamWithResultProvider.getCheckpointOutputStream());
+			final ReadOptions readOptions = new ReadOptions();
+			try {
+				readOptions.setSnapshot(snapshot);
+				writeKVStateMetaData(kvStateIterators, readOptions, outputView);
+				writeKVStateData(kvStateIterators, checkpointStreamWithResultProvider, keyGroupRangeOffsets);
 			} finally {
 
 				for (Tuple2<RocksIteratorWrapper, Integer> kvStateIterator : kvStateIterators) {
 					IOUtils.closeQuietly(kvStateIterator.f0);
 				}
 
-				cleanupSynchronousStepResources();
-			}
-		}
-
-		private void cleanupSynchronousStepResources() {
-			IOUtils.closeQuietly(readOptions);
-
-			db.releaseSnapshot(snapshot);
-			IOUtils.closeQuietly(snapshot);
-
-			IOUtils.closeQuietly(dbLease);
-
-			if (cancelStreamRegistry.unregisterCloseable(snapshotCloseableRegistry)) {
-				try {
-					snapshotCloseableRegistry.close();
-				} catch (Exception ex) {
-					LOG.warn("Error closing local registry", ex);
-				}
-			}
-		}
-
-		private SnapshotResult<KeyedStateHandle> createStateHandlesFromStreamProvider(
-			CheckpointStreamWithResultProvider checkpointStreamWithResultProvider) throws IOException {
-			if (snapshotCloseableRegistry.unregisterCloseable(checkpointStreamWithResultProvider)) {
-				return CheckpointStreamWithResultProvider.toKeyedStateHandleSnapshotResult(
-					checkpointStreamWithResultProvider.closeAndFinalizeCheckpointStreamResult(),
-					keyGroupRangeOffsets);
-			} else {
-				throw new IOException("Snapshot was already closed before completion.");
+				IOUtils.closeQuietly(readOptions);
 			}
 		}
 
 		private void writeKVStateMetaData(
 			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators,
+			final ReadOptions readOptions,
 			final DataOutputView outputView) throws IOException {
 
 			int kvStateId = 0;
@@ -343,7 +297,8 @@ public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
 
 		private void writeKVStateData(
 			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators,
-			final CheckpointStreamWithResultProvider checkpointStreamWithResultProvider) throws IOException, InterruptedException {
+			final CheckpointStreamWithResultProvider checkpointStreamWithResultProvider,
+			final KeyGroupRangeOffsets keyGroupRangeOffsets) throws IOException, InterruptedException {
 
 			byte[] previousKey = null;
 			byte[] previousValue = null;
@@ -445,18 +400,6 @@ public class RocksFullSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
 				throw new InterruptedException("RocksDB snapshot interrupted.");
 			}
 		}
-
-		@Override
-		public void close() throws Exception {
-
-			if (ownedForCleanup.compareAndSet(false, true)) {
-				cleanupSynchronousStepResources();
-			}
-
-			if (cancelStreamRegistry.unregisterCloseable(snapshotCloseableRegistry)) {
-				snapshotCloseableRegistry.close();
-			}
-		}
 	}
 
 	@SuppressWarnings("unchecked")
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
index 3487fe6..8117031 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java
@@ -28,12 +28,11 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.state.AsyncSnapshotCallable;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
 import org.apache.flink.runtime.state.CheckpointedStateScope;
 import org.apache.flink.runtime.state.DirectoryStateHandle;
-import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.IncrementalKeyedStateHandle;
 import org.apache.flink.runtime.state.IncrementalLocalKeyedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
@@ -45,7 +44,6 @@ import org.apache.flink.runtime.state.PlaceholderStreamStateHandle;
 import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
 import org.apache.flink.runtime.state.SnapshotDirectory;
 import org.apache.flink.runtime.state.SnapshotResult;
-import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateHandleID;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
@@ -65,11 +63,11 @@ import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -77,7 +75,6 @@ import java.util.Map;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.UUID;
-import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.SST_FILE_SUFFIX;
@@ -88,10 +85,12 @@ import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUti
  *
  * @param <K> type of the backend keys.
  */
-public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K> {
+public class RocksIncrementalSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K> {
 
 	private static final Logger LOG = LoggerFactory.getLogger(RocksIncrementalSnapshotStrategy.class);
 
+	private static final String DESCRIPTION = "Asynchronous incremental RocksDB snapshot";
+
 	/** Base path of the RocksDB instance. */
 	@Nonnull
 	private final File instanceBasePath;
@@ -107,10 +106,6 @@ public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K>
 	/** The identifier of the last completed checkpoint. */
 	private long lastCompletedCheckpointId;
 
-	/** We delegate snapshots that are for savepoints to this. */
-	@Nonnull
-	private final SnapshotStrategy<SnapshotResult<KeyedStateHandle>> savepointDelegate;
-
 	public RocksIncrementalSnapshotStrategy(
 		@Nonnull RocksDB db,
 		@Nonnull ResourceGuard rocksDBResourceGuard,
@@ -123,10 +118,10 @@ public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K>
 		@Nonnull File instanceBasePath,
 		@Nonnull UUID backendUID,
 		@Nonnull SortedMap<Long, Set<StateHandleID>> materializedSstFiles,
-		long lastCompletedCheckpointId,
-		@Nonnull SnapshotStrategy<SnapshotResult<KeyedStateHandle>> savepointDelegate) {
+		long lastCompletedCheckpointId) {
 
 		super(
+			DESCRIPTION,
 			db,
 			rocksDBResourceGuard,
 			keySerializer,
@@ -140,33 +135,47 @@ public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K>
 		this.backendUID = backendUID;
 		this.materializedSstFiles = materializedSstFiles;
 		this.lastCompletedCheckpointId = lastCompletedCheckpointId;
-		this.savepointDelegate = savepointDelegate;
 	}
 
+	@Nonnull
 	@Override
-	public RunnableFuture<SnapshotResult<KeyedStateHandle>> performSnapshot(
+	protected RunnableFuture<SnapshotResult<KeyedStateHandle>> doSnapshot(
 		long checkpointId,
 		long checkpointTimestamp,
-		CheckpointStreamFactory checkpointStreamFactory,
-		CheckpointOptions checkpointOptions) throws Exception {
+		@Nonnull CheckpointStreamFactory checkpointStreamFactory,
+		@Nonnull CheckpointOptions checkpointOptions) throws Exception {
 
-		// for savepoints, we delegate to the full snapshot strategy because savepoints are always self-contained.
-		if (CheckpointType.SAVEPOINT == checkpointOptions.getCheckpointType()) {
-			return savepointDelegate.performSnapshot(
+		final SnapshotDirectory snapshotDirectory = prepareLocalSnapshotDirectory(checkpointId);
+		LOG.trace("Local RocksDB checkpoint goes to backup path {}.", snapshotDirectory);
+
+		final List<StateMetaInfoSnapshot> stateMetaInfoSnapshots = new ArrayList<>(kvStateInformation.size());
+		final Set<StateHandleID> baseSstFiles = snapshotMetaData(checkpointId, stateMetaInfoSnapshots);
+
+		takeDBNativeCheckpoint(snapshotDirectory);
+
+		final RocksDBIncrementalSnapshotOperation snapshotOperation =
+			new RocksDBIncrementalSnapshotOperation(
 				checkpointId,
-				checkpointTimestamp,
 				checkpointStreamFactory,
-				checkpointOptions);
-		}
+				snapshotDirectory,
+				baseSstFiles,
+				stateMetaInfoSnapshots);
 
-		if (kvStateInformation.isEmpty()) {
-			if (LOG.isDebugEnabled()) {
-				LOG.debug("Asynchronous RocksDB snapshot performed on empty keyed state at {}. Returning null.", checkpointTimestamp);
+		return snapshotOperation.toAsyncSnapshotFutureTask(cancelStreamRegistry);
+	}
+
+	@Override
+	public void notifyCheckpointComplete(long completedCheckpointId) {
+		synchronized (materializedSstFiles) {
+			if (completedCheckpointId > lastCompletedCheckpointId) {
+				materializedSstFiles.keySet().removeIf(checkpointId -> checkpointId < completedCheckpointId);
+				lastCompletedCheckpointId = completedCheckpointId;
 			}
-			return DoneFuture.of(SnapshotResult.empty());
 		}
+	}
 
-		SnapshotDirectory snapshotDirectory;
+	@Nonnull
+	private SnapshotDirectory prepareLocalSnapshotDirectory(long checkpointId) throws IOException {
 
 		if (localRecoveryConfig.isLocalRecoveryEnabled()) {
 			// create a "permanent" snapshot directory for local recovery.
@@ -186,254 +195,217 @@ public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K>
 			File rdbSnapshotDir = new File(directory, "rocks_db");
 			Path path = new Path(rdbSnapshotDir.toURI());
 			// create a "permanent" snapshot directory because local recovery is active.
-			snapshotDirectory = SnapshotDirectory.permanent(path);
+			try {
+				return SnapshotDirectory.permanent(path);
+			} catch (IOException ex) {
+				try {
+					FileUtils.deleteDirectory(directory);
+				} catch (IOException delEx) {
+					ex = ExceptionUtils.firstOrSuppressed(delEx, ex);
+				}
+				throw ex;
+			}
 		} else {
 			// create a "temporary" snapshot directory because local recovery is inactive.
 			Path path = new Path(instanceBasePath.getAbsolutePath(), "chk-" + checkpointId);
-			snapshotDirectory = SnapshotDirectory.temporary(path);
-		}
-
-		final RocksDBIncrementalSnapshotOperation snapshotOperation =
-			new RocksDBIncrementalSnapshotOperation(
-				checkpointStreamFactory,
-				snapshotDirectory,
-				checkpointId);
-
-		try {
-			snapshotOperation.takeSnapshot();
-		} catch (Exception e) {
-			snapshotOperation.stop();
-			snapshotOperation.releaseResources(true);
-			throw e;
+			return SnapshotDirectory.temporary(path);
 		}
+	}
 
-		return new FutureTask<SnapshotResult<KeyedStateHandle>>(
-			snapshotOperation::runSnapshot
-		) {
-			@Override
-			public boolean cancel(boolean mayInterruptIfRunning) {
-				snapshotOperation.stop();
-				return super.cancel(mayInterruptIfRunning);
-			}
+	private Set<StateHandleID> snapshotMetaData(
+		long checkpointId,
+		@Nonnull List<StateMetaInfoSnapshot> stateMetaInfoSnapshots) {
 
-			@Override
-			protected void done() {
-				snapshotOperation.releaseResources(isCancelled());
-			}
-		};
-	}
+		final long lastCompletedCheckpoint;
+		final Set<StateHandleID> baseSstFiles;
 
-	@Override
-	public void notifyCheckpointComplete(long completedCheckpointId) {
+		// use the last completed checkpoint as the comparison base.
 		synchronized (materializedSstFiles) {
+			lastCompletedCheckpoint = lastCompletedCheckpointId;
+			baseSstFiles = materializedSstFiles.get(lastCompletedCheckpoint);
+		}
+		LOG.trace("Taking incremental snapshot for checkpoint {}. Snapshot is based on last completed checkpoint {} " +
+			"assuming the following (shared) files as base: {}.", checkpointId, lastCompletedCheckpoint, baseSstFiles);
 
-			if (completedCheckpointId < lastCompletedCheckpointId) {
-				return;
-			}
-
-			materializedSstFiles.keySet().removeIf(checkpointId -> checkpointId < completedCheckpointId);
+		// snapshot meta data to save
+		for (Map.Entry<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> stateMetaInfoEntry
+			: kvStateInformation.entrySet()) {
+			stateMetaInfoSnapshots.add(stateMetaInfoEntry.getValue().f1.snapshot());
+		}
+		return baseSstFiles;
+	}
 
-			lastCompletedCheckpointId = completedCheckpointId;
+	private void takeDBNativeCheckpoint(@Nonnull SnapshotDirectory outputDirectory) throws Exception {
+		// create hard links of living files in the output path
+		try (
+			ResourceGuard.Lease ignored = rocksDBResourceGuard.acquireResource();
+			Checkpoint checkpoint = Checkpoint.create(db)) {
+			checkpoint.createCheckpoint(outputDirectory.getDirectory().getPath());
+		} catch (Exception ex) {
+			try {
+				outputDirectory.cleanup();
+			} catch (IOException cleanupEx) {
+				ex = ExceptionUtils.firstOrSuppressed(cleanupEx, ex);
+			}
+			throw ex;
 		}
 	}
 
 	/**
 	 * Encapsulates the process to perform an incremental snapshot of a RocksDBKeyedStateBackend.
 	 */
-	private final class RocksDBIncrementalSnapshotOperation {
+	private final class RocksDBIncrementalSnapshotOperation
+		extends AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>> {
 
-		/**
-		 * Stream factory that creates the outpus streams to DFS.
-		 */
-		private final CheckpointStreamFactory checkpointStreamFactory;
+		private static final int READ_BUFFER_SIZE = 16 * 1024;
 
-		/**
-		 * Id for the current checkpoint.
-		 */
+		/** Id for the current checkpoint. */
 		private final long checkpointId;
 
-		/**
-		 * All sst files that were part of the last previously completed checkpoint.
-		 */
-		private Set<StateHandleID> baseSstFiles;
+		/** Stream factory that creates the output streams to DFS. */
+		@Nonnull
+		private final CheckpointStreamFactory checkpointStreamFactory;
 
-		/**
-		 * The state meta data.
-		 */
+		/** The state meta data. */
+		@Nonnull
 		private final List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
 
-		/**
-		 * Local directory for the RocksDB native backup.
-		 */
-		private SnapshotDirectory localBackupDirectory;
-
-		// Registry for all opened i/o streams
-		private final CloseableRegistry closeableRegistry;
-
-		// new sst files since the last completed checkpoint
-		private final Map<StateHandleID, StreamStateHandle> sstFiles;
-
-		// handles to the misc files in the current snapshot
-		private final Map<StateHandleID, StreamStateHandle> miscFiles;
-
-		// This lease protects from concurrent disposal of the native rocksdb instance.
-		private final ResourceGuard.Lease dbLease;
+		/** Local directory for the RocksDB native backup. */
+		@Nonnull
+		private final SnapshotDirectory localBackupDirectory;
 
-		private SnapshotResult<StreamStateHandle> metaStateHandle;
+		/** All sst files that were part of the last previously completed checkpoint. */
+		@Nullable
+		private final Set<StateHandleID> baseSstFiles;
 
 		private RocksDBIncrementalSnapshotOperation(
-			CheckpointStreamFactory checkpointStreamFactory,
-			SnapshotDirectory localBackupDirectory,
-			long checkpointId) throws IOException {
+			long checkpointId,
+			@Nonnull CheckpointStreamFactory checkpointStreamFactory,
+			@Nonnull SnapshotDirectory localBackupDirectory,
+			@Nullable Set<StateHandleID> baseSstFiles,
+			@Nonnull List<StateMetaInfoSnapshot> stateMetaInfoSnapshots) {
 
 			this.checkpointStreamFactory = checkpointStreamFactory;
+			this.baseSstFiles = baseSstFiles;
 			this.checkpointId = checkpointId;
 			this.localBackupDirectory = localBackupDirectory;
-			this.stateMetaInfoSnapshots = new ArrayList<>();
-			this.closeableRegistry = new CloseableRegistry();
-			this.sstFiles = new HashMap<>();
-			this.miscFiles = new HashMap<>();
-			this.metaStateHandle = null;
-			this.dbLease = rocksDBResourceGuard.acquireResource();
+			this.stateMetaInfoSnapshots = stateMetaInfoSnapshots;
 		}
 
-		private StreamStateHandle materializeStateData(Path filePath) throws Exception {
-			FSDataInputStream inputStream = null;
-			CheckpointStreamFactory.CheckpointStateOutputStream outputStream = null;
+		@Override
+		protected SnapshotResult<KeyedStateHandle> callInternal() throws Exception {
 
-			try {
-				final byte[] buffer = new byte[8 * 1024];
+			boolean completed = false;
 
-				FileSystem backupFileSystem = localBackupDirectory.getFileSystem();
-				inputStream = backupFileSystem.open(filePath);
-				closeableRegistry.registerCloseable(inputStream);
+			// Handle to the meta data file
+			SnapshotResult<StreamStateHandle> metaStateHandle = null;
+			// Handles to new sst files since the last completed checkpoint will go here
+			final Map<StateHandleID, StreamStateHandle> sstFiles = new HashMap<>();
+			// Handles to the misc files in the current snapshot will go here
+			final Map<StateHandleID, StreamStateHandle> miscFiles = new HashMap<>();
 
-				outputStream = checkpointStreamFactory
-					.createCheckpointStateOutputStream(CheckpointedStateScope.SHARED);
-				closeableRegistry.registerCloseable(outputStream);
+			try {
 
-				while (true) {
-					int numBytes = inputStream.read(buffer);
+				metaStateHandle = materializeMetaData();
 
-					if (numBytes == -1) {
-						break;
-					}
+				// Sanity checks - they should never fail
+				Preconditions.checkNotNull(metaStateHandle, "Metadata was not properly created.");
+				Preconditions.checkNotNull(metaStateHandle.getJobManagerOwnedSnapshot(),
+					"Metadata for job manager was not properly created.");
 
-					outputStream.write(buffer, 0, numBytes);
-				}
+				uploadSstFiles(sstFiles, miscFiles);
 
-				StreamStateHandle result = null;
-				if (closeableRegistry.unregisterCloseable(outputStream)) {
-					result = outputStream.closeAndGetHandle();
-					outputStream = null;
+				synchronized (materializedSstFiles) {
+					materializedSstFiles.put(checkpointId, sstFiles.keySet());
 				}
-				return result;
 
-			} finally {
-
-				if (closeableRegistry.unregisterCloseable(inputStream)) {
-					inputStream.close();
+				final IncrementalKeyedStateHandle jmIncrementalKeyedStateHandle =
+					new IncrementalKeyedStateHandle(
+						backendUID,
+						keyGroupRange,
+						checkpointId,
+						sstFiles,
+						miscFiles,
+						metaStateHandle.getJobManagerOwnedSnapshot());
+
+				final DirectoryStateHandle directoryStateHandle = localBackupDirectory.completeSnapshotAndGetHandle();
+				final SnapshotResult<KeyedStateHandle> snapshotResult;
+				if (directoryStateHandle != null && metaStateHandle.getTaskLocalSnapshot() != null) {
+
+					IncrementalLocalKeyedStateHandle localDirKeyedStateHandle =
+						new IncrementalLocalKeyedStateHandle(
+							backendUID,
+							checkpointId,
+							directoryStateHandle,
+							keyGroupRange,
+							metaStateHandle.getTaskLocalSnapshot(),
+							sstFiles.keySet());
+
+					snapshotResult = SnapshotResult.withLocalState(jmIncrementalKeyedStateHandle, localDirKeyedStateHandle);
+				} else {
+					snapshotResult = SnapshotResult.of(jmIncrementalKeyedStateHandle);
 				}
 
-				if (closeableRegistry.unregisterCloseable(outputStream)) {
-					outputStream.close();
+				completed = true;
+
+				return snapshotResult;
+			} finally {
+				if (!completed) {
+					final List<StateObject> statesToDiscard =
+						new ArrayList<>(1 + miscFiles.size() + sstFiles.size());
+					statesToDiscard.add(metaStateHandle);
+					statesToDiscard.addAll(miscFiles.values());
+					statesToDiscard.addAll(sstFiles.values());
+					cleanupIncompleteSnapshot(statesToDiscard);
 				}
 			}
 		}
 
-		@Nonnull
-		private SnapshotResult<StreamStateHandle> materializeMetaData() throws Exception {
-
-			CheckpointStreamWithResultProvider streamWithResultProvider =
-
-				localRecoveryConfig.isLocalRecoveryEnabled() ?
-
-					CheckpointStreamWithResultProvider.createDuplicatingStream(
-						checkpointId,
-						CheckpointedStateScope.EXCLUSIVE,
-						checkpointStreamFactory,
-						localRecoveryConfig.getLocalStateDirectoryProvider()) :
-
-					CheckpointStreamWithResultProvider.createSimpleStream(
-						CheckpointedStateScope.EXCLUSIVE,
-						checkpointStreamFactory);
-
+		@Override
+		protected void cleanupProvidedResources() {
 			try {
-				closeableRegistry.registerCloseable(streamWithResultProvider);
-
-				//no need for compression scheme support because sst-files are already compressed
-				KeyedBackendSerializationProxy<K> serializationProxy =
-					new KeyedBackendSerializationProxy<>(
-						keySerializer,
-						stateMetaInfoSnapshots,
-						false);
-
-				DataOutputView out =
-					new DataOutputViewStreamWrapper(streamWithResultProvider.getCheckpointOutputStream());
-
-				serializationProxy.write(out);
+				if (localBackupDirectory.exists()) {
+					LOG.trace("Running cleanup for local RocksDB backup directory {}.", localBackupDirectory);
+					boolean cleanupOk = localBackupDirectory.cleanup();
 
-				if (closeableRegistry.unregisterCloseable(streamWithResultProvider)) {
-					SnapshotResult<StreamStateHandle> result =
-						streamWithResultProvider.closeAndFinalizeCheckpointStreamResult();
-					streamWithResultProvider = null;
-					return result;
-				} else {
-					throw new IOException("Stream already closed and cannot return a handle.");
-				}
-			} finally {
-				if (streamWithResultProvider != null) {
-					if (closeableRegistry.unregisterCloseable(streamWithResultProvider)) {
-						IOUtils.closeQuietly(streamWithResultProvider);
+					if (!cleanupOk) {
+						LOG.debug("Could not properly cleanup local RocksDB backup directory.");
 					}
 				}
+			} catch (IOException e) {
+				LOG.warn("Could not properly cleanup local RocksDB backup directory.", e);
 			}
 		}
 
-		void takeSnapshot() throws Exception {
-
-			final long lastCompletedCheckpoint;
-
-			// use the last completed checkpoint as the comparison base.
-			synchronized (materializedSstFiles) {
-				lastCompletedCheckpoint = lastCompletedCheckpointId;
-				baseSstFiles = materializedSstFiles.get(lastCompletedCheckpoint);
-			}
-
-			LOG.trace("Taking incremental snapshot for checkpoint {}. Snapshot is based on last completed checkpoint {} " +
-				"assuming the following (shared) files as base: {}.", checkpointId, lastCompletedCheckpoint, baseSstFiles);
-
-			// save meta data
-			for (Map.Entry<String, Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> stateMetaInfoEntry
-				: kvStateInformation.entrySet()) {
-				stateMetaInfoSnapshots.add(stateMetaInfoEntry.getValue().f1.snapshot());
-			}
+		@Override
+		protected void logAsyncSnapshotComplete(long startTime) {
+			logAsyncCompleted(checkpointStreamFactory, startTime);
+		}
 
-			LOG.trace("Local RocksDB checkpoint goes to backup path {}.", localBackupDirectory);
+		private void cleanupIncompleteSnapshot(@Nonnull List<StateObject> statesToDiscard) {
 
-			if (localBackupDirectory.exists()) {
-				throw new IllegalStateException("Unexpected existence of the backup directory.");
+			try {
+				StateUtil.bestEffortDiscardAllStateObjects(statesToDiscard);
+			} catch (Exception e) {
+				LOG.warn("Could not properly discard states.", e);
 			}
 
-			// create hard links of living files in the snapshot path
-			try (Checkpoint checkpoint = Checkpoint.create(db)) {
-				checkpoint.createCheckpoint(localBackupDirectory.getDirectory().getPath());
+			if (localBackupDirectory.isSnapshotCompleted()) {
+				try {
+					DirectoryStateHandle directoryStateHandle =
+						localBackupDirectory.completeSnapshotAndGetHandle();
+					if (directoryStateHandle != null) {
+						directoryStateHandle.discardState();
+					}
+				} catch (Exception e) {
+					LOG.warn("Could not properly discard local state.", e);
+				}
 			}
 		}
 
-		@Nonnull
-		SnapshotResult<KeyedStateHandle> runSnapshot() throws Exception {
-
-			cancelStreamRegistry.registerCloseable(closeableRegistry);
-
-			// write meta data
-			metaStateHandle = materializeMetaData();
-
-			// sanity checks - they should never fail
-			Preconditions.checkNotNull(metaStateHandle,
-				"Metadata was not properly created.");
-			Preconditions.checkNotNull(metaStateHandle.getJobManagerOwnedSnapshot(),
-				"Metadata for job manager was not properly created.");
+		private void uploadSstFiles(
+			@Nonnull Map<StateHandleID, StreamStateHandle> sstFiles,
+			@Nonnull Map<StateHandleID, StreamStateHandle> miscFiles) throws Exception {
 
 			// write state data
 			Preconditions.checkState(localBackupDirectory.exists());
@@ -456,120 +428,104 @@ public class RocksIncrementalSnapshotStrategy<K> extends SnapshotStrategyBase<K>
 								stateHandleID,
 								new PlaceholderStreamStateHandle());
 						} else {
-							sstFiles.put(stateHandleID, materializeStateData(filePath));
+							sstFiles.put(stateHandleID, uploadLocalFileToCheckpointFs(filePath));
 						}
 					} else {
-						StreamStateHandle fileHandle = materializeStateData(filePath);
+						StreamStateHandle fileHandle = uploadLocalFileToCheckpointFs(filePath);
 						miscFiles.put(stateHandleID, fileHandle);
 					}
 				}
 			}
+		}
 
-			synchronized (materializedSstFiles) {
-				materializedSstFiles.put(checkpointId, sstFiles.keySet());
-			}
+		private StreamStateHandle uploadLocalFileToCheckpointFs(Path filePath) throws Exception {
+			FSDataInputStream inputStream = null;
+			CheckpointStreamFactory.CheckpointStateOutputStream outputStream = null;
 
-			IncrementalKeyedStateHandle jmIncrementalKeyedStateHandle = new IncrementalKeyedStateHandle(
-				backendUID,
-				keyGroupRange,
-				checkpointId,
-				sstFiles,
-				miscFiles,
-				metaStateHandle.getJobManagerOwnedSnapshot());
+			try {
+				final byte[] buffer = new byte[READ_BUFFER_SIZE];
 
-			StreamStateHandle taskLocalSnapshotMetaDataStateHandle = metaStateHandle.getTaskLocalSnapshot();
-			DirectoryStateHandle directoryStateHandle = null;
+				FileSystem backupFileSystem = localBackupDirectory.getFileSystem();
+				inputStream = backupFileSystem.open(filePath);
+				registerCloseableForCancellation(inputStream);
 
-			try {
+				outputStream = checkpointStreamFactory
+					.createCheckpointStateOutputStream(CheckpointedStateScope.SHARED);
+				registerCloseableForCancellation(outputStream);
 
-				directoryStateHandle = localBackupDirectory.completeSnapshotAndGetHandle();
-			} catch (IOException ex) {
+				while (true) {
+					int numBytes = inputStream.read(buffer);
 
-				Exception collector = ex;
+					if (numBytes == -1) {
+						break;
+					}
 
-				try {
-					taskLocalSnapshotMetaDataStateHandle.discardState();
-				} catch (Exception discardEx) {
-					collector = ExceptionUtils.firstOrSuppressed(discardEx, collector);
+					outputStream.write(buffer, 0, numBytes);
 				}
 
-				LOG.warn("Problem with local state snapshot.", collector);
-			}
-
-			if (directoryStateHandle != null && taskLocalSnapshotMetaDataStateHandle != null) {
+				StreamStateHandle result = null;
+				if (unregisterCloseableFromCancellation(outputStream)) {
+					result = outputStream.closeAndGetHandle();
+					outputStream = null;
+				}
+				return result;
 
-				IncrementalLocalKeyedStateHandle localDirKeyedStateHandle =
-					new IncrementalLocalKeyedStateHandle(
-						backendUID,
-						checkpointId,
-						directoryStateHandle,
-						keyGroupRange,
-						taskLocalSnapshotMetaDataStateHandle,
-						sstFiles.keySet());
-				return SnapshotResult.withLocalState(jmIncrementalKeyedStateHandle, localDirKeyedStateHandle);
-			} else {
-				return SnapshotResult.of(jmIncrementalKeyedStateHandle);
-			}
-		}
+			} finally {
 
-		void stop() {
+				if (unregisterCloseableFromCancellation(inputStream)) {
+					IOUtils.closeQuietly(inputStream);
+				}
 
-			if (cancelStreamRegistry.unregisterCloseable(closeableRegistry)) {
-				try {
-					closeableRegistry.close();
-				} catch (IOException e) {
-					LOG.warn("Could not properly close io streams.", e);
+				if (unregisterCloseableFromCancellation(outputStream)) {
+					IOUtils.closeQuietly(outputStream);
 				}
 			}
 		}
 
-		void releaseResources(boolean canceled) {
+		@Nonnull
+		private SnapshotResult<StreamStateHandle> materializeMetaData() throws Exception {
 
-			dbLease.close();
+			CheckpointStreamWithResultProvider streamWithResultProvider =
 
-			if (cancelStreamRegistry.unregisterCloseable(closeableRegistry)) {
-				try {
-					closeableRegistry.close();
-				} catch (IOException e) {
-					LOG.warn("Exception on closing registry.", e);
-				}
-			}
+				localRecoveryConfig.isLocalRecoveryEnabled() ?
 
-			try {
-				if (localBackupDirectory.exists()) {
-					LOG.trace("Running cleanup for local RocksDB backup directory {}.", localBackupDirectory);
-					boolean cleanupOk = localBackupDirectory.cleanup();
+					CheckpointStreamWithResultProvider.createDuplicatingStream(
+						checkpointId,
+						CheckpointedStateScope.EXCLUSIVE,
+						checkpointStreamFactory,
+						localRecoveryConfig.getLocalStateDirectoryProvider()) :
 
-					if (!cleanupOk) {
-						LOG.debug("Could not properly cleanup local RocksDB backup directory.");
-					}
-				}
-			} catch (IOException e) {
-				LOG.warn("Could not properly cleanup local RocksDB backup directory.", e);
-			}
+					CheckpointStreamWithResultProvider.createSimpleStream(
+						CheckpointedStateScope.EXCLUSIVE,
+						checkpointStreamFactory);
 
-			if (canceled) {
-				Collection<StateObject> statesToDiscard =
-					new ArrayList<>(1 + miscFiles.size() + sstFiles.size());
+			registerCloseableForCancellation(streamWithResultProvider);
 
-				statesToDiscard.add(metaStateHandle);
-				statesToDiscard.addAll(miscFiles.values());
-				statesToDiscard.addAll(sstFiles.values());
+			try {
+				//no need for compression scheme support because sst-files are already compressed
+				KeyedBackendSerializationProxy<K> serializationProxy =
+					new KeyedBackendSerializationProxy<>(
+						keySerializer,
+						stateMetaInfoSnapshots,
+						false);
 
-				try {
-					StateUtil.bestEffortDiscardAllStateObjects(statesToDiscard);
-				} catch (Exception e) {
-					LOG.warn("Could not properly discard states.", e);
-				}
+				DataOutputView out =
+					new DataOutputViewStreamWrapper(streamWithResultProvider.getCheckpointOutputStream());
 
-				if (localBackupDirectory.isSnapshotCompleted()) {
-					try {
-						DirectoryStateHandle directoryStateHandle = localBackupDirectory.completeSnapshotAndGetHandle();
-						if (directoryStateHandle != null) {
-							directoryStateHandle.discardState();
-						}
-					} catch (Exception e) {
-						LOG.warn("Could not properly discard local state.", e);
+				serializationProxy.write(out);
+
+				if (unregisterCloseableFromCancellation(streamWithResultProvider)) {
+					SnapshotResult<StreamStateHandle> result =
+						streamWithResultProvider.closeAndFinalizeCheckpointStreamResult();
+					streamWithResultProvider = null;
+					return result;
+				} else {
+					throw new IOException("Stream already closed and cannot return a handle.");
+				}
+			} finally {
+				if (streamWithResultProvider != null) {
+					if (unregisterCloseableFromCancellation(streamWithResultProvider)) {
+						IOUtils.closeQuietly(streamWithResultProvider);
 					}
 				}
 			}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index db504d5..9ee8892 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -1076,7 +1076,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 				owner.asyncOperationsThreadPool.submit(asyncCheckpointRunnable);
 
 				if (LOG.isDebugEnabled()) {
-					LOG.debug("{} - finished synchronous part of checkpoint {}." +
+					LOG.debug("{} - finished synchronous part of checkpoint {}. " +
 							"Alignment duration: {} ms, snapshot duration {} ms",
 						owner.getName(), checkpointMetaData.getCheckpointId(),
 						checkpointMetrics.getAlignmentDurationNanos() / 1_000_000,
@@ -1095,7 +1095,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 				}
 
 				if (LOG.isDebugEnabled()) {
-					LOG.debug("{} - did NOT finish synchronous part of checkpoint {}." +
+					LOG.debug("{} - did NOT finish synchronous part of checkpoint {}. " +
 							"Alignment duration: {} ms, snapshot duration {} ms",
 						owner.getName(), checkpointMetaData.getCheckpointId(),
 						checkpointMetrics.getAlignmentDurationNanos() / 1_000_000,
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
index d8f577d..cd8a4fa 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
@@ -82,6 +82,7 @@ import org.apache.flink.util.TestLogger;
 import org.junit.Assert;
 import org.junit.Test;
 
+import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
 import java.io.IOException;
@@ -305,12 +306,13 @@ public class TaskCheckpointingBehaviourTest extends TestLogger {
 				env.getUserClassLoader(),
 				env.getExecutionConfig(),
 				true) {
+				@Nonnull
 				@Override
 				public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
 					long checkpointId,
 					long timestamp,
-					CheckpointStreamFactory streamFactory,
-					CheckpointOptions checkpointOptions) throws Exception {
+					@Nonnull CheckpointStreamFactory streamFactory,
+					@Nonnull CheckpointOptions checkpointOptions) throws Exception {
 
 					throw new Exception("Sync part snapshot exception.");
 				}
@@ -334,12 +336,13 @@ public class TaskCheckpointingBehaviourTest extends TestLogger {
 				env.getUserClassLoader(),
 				env.getExecutionConfig(),
 				true) {
+				@Nonnull
 				@Override
 				public RunnableFuture<SnapshotResult<OperatorStateHandle>> snapshot(
 					long checkpointId,
 					long timestamp,
-					CheckpointStreamFactory streamFactory,
-					CheckpointOptions checkpointOptions) throws Exception {
+					@Nonnull CheckpointStreamFactory streamFactory,
+					@Nonnull CheckpointOptions checkpointOptions) throws Exception {
 
 					return new FutureTask<>(() -> {
 						throw new Exception("Async part snapshot exception.");
diff --git a/flink-test-utils-parent/flink-test-utils-junit/src/main/java/org/apache/flink/core/testutils/OneShotLatch.java b/flink-test-utils-parent/flink-test-utils-junit/src/main/java/org/apache/flink/core/testutils/OneShotLatch.java
index 7fed5eb..bef23bb 100644
--- a/flink-test-utils-parent/flink-test-utils-junit/src/main/java/org/apache/flink/core/testutils/OneShotLatch.java
+++ b/flink-test-utils-parent/flink-test-utils-junit/src/main/java/org/apache/flink/core/testutils/OneShotLatch.java
@@ -18,6 +18,9 @@
 
 package org.apache.flink.core.testutils;
 
+import java.util.Collections;
+import java.util.IdentityHashMap;
+import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
@@ -31,6 +34,7 @@ import java.util.concurrent.TimeoutException;
 public final class OneShotLatch {
 
 	private final Object lock = new Object();
+	private final Set<Thread> waitersSet = Collections.newSetFromMap(new IdentityHashMap<>());
 
 	private volatile boolean triggered;
 
@@ -53,7 +57,13 @@ public final class OneShotLatch {
 	public void await() throws InterruptedException {
 		synchronized (lock) {
 			while (!triggered) {
-				lock.wait();
+				Thread thread = Thread.currentThread();
+				try {
+					waitersSet.add(thread);
+					lock.wait();
+				} finally {
+					waitersSet.remove(thread);
+				}
 			}
 		}
 	}
@@ -108,6 +118,12 @@ public final class OneShotLatch {
 		return triggered;
 	}
 
+	public int getWaitersCount() {
+		synchronized (lock) {
+			return waitersSet.size();
+		}
+	}
+
 	/**
 	 * Resets the latch so that {@link #isTriggered()} returns false.
 	 */