Posted to commits@flink.apache.org by sr...@apache.org on 2017/05/05 14:30:28 UTC

[2/2] flink git commit: [FLINK-6364] [checkpoint] Incremental checkpointing in RocksDBKeyedStateBackend

[FLINK-6364] [checkpoint] Incremental checkpointing in RocksDBKeyedStateBackend


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/6e94cf19
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/6e94cf19
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/6e94cf19

Branch: refs/heads/master
Commit: 6e94cf19736b3b3751abe55cf0f3ce4aa740ef96
Parents: 5795ebe
Author: xiaogang.sxg <xi...@alibaba-inc.com>
Authored: Sat Apr 29 23:44:36 2017 +0800
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Fri May 5 16:30:06 2017 +0200

----------------------------------------------------------------------
 .../RocksDBIncrementalKeyedStateHandle.java     | 248 +++++++
 .../state/RocksDBKeyedStateBackend.java         | 711 ++++++++++++++++++-
 .../streaming/state/RocksDBStateBackend.java    |  60 +-
 .../state/RocksDBAggregatingStateTest.java      |   6 +-
 .../state/RocksDBAsyncSnapshotTest.java         |  19 +-
 .../streaming/state/RocksDBListStateTest.java   |   6 +-
 .../state/RocksDBReducingStateTest.java         |   6 +-
 .../state/RocksDBStateBackendTest.java          |  41 +-
 .../flink/runtime/checkpoint/SubtaskState.java  |  16 +-
 .../state/AbstractKeyedStateBackend.java        |  10 +
 .../runtime/state/KeyGroupsStateHandle.java     |  10 +
 .../flink/runtime/state/KeyedStateHandle.java   |   2 +-
 .../apache/flink/runtime/state/StateUtil.java   |   9 +
 .../state/heap/HeapKeyedStateBackend.java       |   9 +
 .../runtime/state/StateBackendTestBase.java     |  15 +-
 .../api/operators/AbstractStreamOperator.java   |   6 +-
 .../streaming/runtime/tasks/StreamTask.java     |   7 +-
 .../KeyedOneInputStreamOperatorTestHarness.java |   4 +-
 .../PartitionedStateCheckpointingITCase.java    |  45 ++
 .../KVStateRequestSerializerRocksDBTest.java    |   8 +-
 20 files changed, 1163 insertions(+), 75 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalKeyedStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalKeyedStateHandle.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalKeyedStateHandle.java
new file mode 100644
index 0000000..5ac9e46
--- /dev/null
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalKeyedStateHandle.java
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.state.CompositeStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.SharedStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.runtime.state.StateUtil;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.util.Preconditions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Map;
+
+/**
+ * The handle to states in incremental snapshots taken by {@link RocksDBKeyedStateBackend}.
+ *
+ * <p>The states contained in an incremental snapshot include:
+ * <ul>
+ * <li> New SST state, which includes the sst files produced since the last completed
+ *   checkpoint. These files can be referenced by succeeding checkpoints if the
+ *   checkpoint successfully completes. </li>
+ * <li> Old SST state, which includes the sst files already materialized in previous
+ *   checkpoints. </li>
+ * <li> MISC state, which includes the other files in the RocksDB instance, e.g. the
+ *   LOG and MANIFEST files. These files are mutable and hence cannot be shared by
+ *   other checkpoints. </li>
+ * <li> Meta state, which includes the meta information of the existing states. </li>
+ * </ul>
+ */
+public class RocksDBIncrementalKeyedStateHandle implements KeyedStateHandle, CompositeStateHandle {
+
+	private static final Logger LOG = LoggerFactory.getLogger(RocksDBIncrementalKeyedStateHandle.class);
+
+	private static final long serialVersionUID = -8328808513197388231L;
+
+	private final JobID jobId;
+
+	private final String operatorIdentifier;
+
+	private final KeyGroupRange keyGroupRange;
+
+	private final long checkpointId;
+
+	private final Map<String, StreamStateHandle> newSstFiles;
+
+	private final Map<String, StreamStateHandle> oldSstFiles;
+
+	private final Map<String, StreamStateHandle> miscFiles;
+
+	private final StreamStateHandle metaStateHandle;
+
+	/**
+	 * True if the state handle has already registered shared states.
+	 *
+	 * <p>Once the shared states are registered, it is the {@link SharedStateRegistry}'s
+	 * responsibility to maintain them. But in cases where the state handle is
+	 * discarded before the registration is performed, the handle should delete
+	 * all the shared states it created.
+	 */
+	private boolean registered;
+
+	RocksDBIncrementalKeyedStateHandle(
+			JobID jobId,
+			String operatorIdentifier,
+			KeyGroupRange keyGroupRange,
+			long checkpointId,
+			Map<String, StreamStateHandle> newSstFiles,
+			Map<String, StreamStateHandle> oldSstFiles,
+			Map<String, StreamStateHandle> miscFiles,
+			StreamStateHandle metaStateHandle) {
+
+		this.jobId = Preconditions.checkNotNull(jobId);
+		this.operatorIdentifier = Preconditions.checkNotNull(operatorIdentifier);
+		this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange);
+		this.checkpointId = checkpointId;
+		this.newSstFiles = Preconditions.checkNotNull(newSstFiles);
+		this.oldSstFiles = Preconditions.checkNotNull(oldSstFiles);
+		this.miscFiles = Preconditions.checkNotNull(miscFiles);
+		this.metaStateHandle = Preconditions.checkNotNull(metaStateHandle);
+		this.registered = false;
+	}
+
+	@Override
+	public KeyGroupRange getKeyGroupRange() {
+		return keyGroupRange;
+	}
+
+	long getCheckpointId() {
+		return checkpointId;
+	}
+
+	Map<String, StreamStateHandle> getNewSstFiles() {
+		return newSstFiles;
+	}
+
+	Map<String, StreamStateHandle> getOldSstFiles() {
+		return oldSstFiles;
+	}
+
+	Map<String, StreamStateHandle> getMiscFiles() {
+		return miscFiles;
+	}
+
+	StreamStateHandle getMetaStateHandle() {
+		return metaStateHandle;
+	}
+
+	@Override
+	public KeyedStateHandle getIntersection(KeyGroupRange keyGroupRange) {
+		if (this.keyGroupRange.getIntersection(keyGroupRange) != KeyGroupRange.EMPTY_KEY_GROUP_RANGE) {
+			return this;
+		} else {
+			return null;
+		}
+	}
+
+	@Override
+	public void discardState() throws Exception {
+
+		try {
+			metaStateHandle.discardState();
+		} catch (Exception e) {
+			LOG.warn("Could not properly discard meta data.", e);
+		}
+
+		try {
+			StateUtil.bestEffortDiscardAllStateObjects(miscFiles.values());
+		} catch (Exception e) {
+			LOG.warn("Could not properly discard misc file states.", e);
+		}
+
+		if (!registered) {
+			try {
+				StateUtil.bestEffortDiscardAllStateObjects(newSstFiles.values());
+			} catch (Exception e) {
+				LOG.warn("Could not properly discard new sst file states.", e);
+			}
+		}
+	}
+
+	@Override
+	public long getStateSize() {
+		long size = StateUtil.getStateSize(metaStateHandle);
+
+		for (StreamStateHandle newSstFileHandle : newSstFiles.values()) {
+			size += newSstFileHandle.getStateSize();
+		}
+
+		for (StreamStateHandle oldSstFileHandle : oldSstFiles.values()) {
+			size += oldSstFileHandle.getStateSize();
+		}
+
+		for (StreamStateHandle miscFileHandle : miscFiles.values()) {
+			size += miscFileHandle.getStateSize();
+		}
+
+		return size;
+	}
+
+	@Override
+	public void registerSharedStates(SharedStateRegistry stateRegistry) {
+		Preconditions.checkState(!registered, "The state handle has already registered its shared states.");
+
+		for (Map.Entry<String, StreamStateHandle> newSstFileEntry : newSstFiles.entrySet()) {
+			SstFileStateHandle stateHandle = new SstFileStateHandle(newSstFileEntry.getKey(), newSstFileEntry.getValue());
+
+			int referenceCount = stateRegistry.register(stateHandle);
+			Preconditions.checkState(referenceCount == 1);
+		}
+
+		for (Map.Entry<String, StreamStateHandle> oldSstFileEntry : oldSstFiles.entrySet()) {
+			SstFileStateHandle stateHandle = new SstFileStateHandle(oldSstFileEntry.getKey(), oldSstFileEntry.getValue());
+
+			int referenceCount = stateRegistry.register(stateHandle);
+			Preconditions.checkState(referenceCount > 1);
+		}
+
+		registered = true;
+	}
+
+	@Override
+	public void unregisterSharedStates(SharedStateRegistry stateRegistry) {
+		Preconditions.checkState(registered, "The state handle has not registered its shared states yet.");
+
+		for (Map.Entry<String, StreamStateHandle> newSstFileEntry : newSstFiles.entrySet()) {
+			stateRegistry.unregister(new SstFileStateHandle(newSstFileEntry.getKey(), newSstFileEntry.getValue()));
+		}
+
+		for (Map.Entry<String, StreamStateHandle> oldSstFileEntry : oldSstFiles.entrySet()) {
+			stateRegistry.unregister(new SstFileStateHandle(oldSstFileEntry.getKey(), oldSstFileEntry.getValue()));
+		}
+
+		registered = false;
+	}
+
+	private class SstFileStateHandle implements SharedStateHandle {
+
+		private static final long serialVersionUID = 9092049285789170669L;
+
+		private final String fileName;
+
+		private final StreamStateHandle delegateStateHandle;
+
+		private SstFileStateHandle(
+				String fileName,
+				StreamStateHandle delegateStateHandle) {
+			this.fileName = fileName;
+			this.delegateStateHandle = delegateStateHandle;
+		}
+
+		@Override
+		public String getRegistrationKey() {
+			return jobId + "-" + operatorIdentifier + "-" + keyGroupRange + "-" + fileName;
+		}
+
+		@Override
+		public void discardState() throws Exception {
+			delegateStateHandle.discardState();
+		}
+
+		@Override
+		public long getStateSize() {
+			return delegateStateHandle.getStateSize();
+		}
+	}
+}
+

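The shared state registration above is what makes sst files reusable across
checkpoints: each file is keyed by job, operator, key-group range and file name, and
a new file must register with a reference count of exactly 1, while a re-referenced
old file must see a count greater than 1. A minimal, self-contained sketch of that
invariant (simplified counting map and illustrative names; not the actual
SharedStateRegistry):

    import java.util.HashMap;
    import java.util.Map;

    public class SharedStateCountingSketch {

        private final Map<String, Integer> referenceCounts = new HashMap<>();

        // mirrors SstFileStateHandle#getRegistrationKey()
        static String registrationKey(
                String jobId, String operatorIdentifier, String keyGroupRange, String fileName) {
            return jobId + "-" + operatorIdentifier + "-" + keyGroupRange + "-" + fileName;
        }

        int register(String key) {
            return referenceCounts.merge(key, 1, Integer::sum);
        }

        public static void main(String[] args) {
            SharedStateCountingSketch registry = new SharedStateCountingSketch();
            String key = registrationKey("jobA", "op-1", "KeyGroupRange{0,127}", "000042.sst");

            // checkpoint 1: the file is new, so the count must be exactly 1
            if (registry.register(key) != 1) {
                throw new IllegalStateException("new sst file registered more than once");
            }

            // checkpoint 2: the same file is re-referenced, so the count must exceed 1
            if (registry.register(key) <= 1) {
                throw new IllegalStateException("old sst file was never registered before");
            }
        }
    }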
http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 199a5a4..ee5f956 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -31,7 +31,12 @@ import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerial
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.runtime.DataInputViewStream;
 import org.apache.flink.configuration.ConfigConstants;
+import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileStatus;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
@@ -55,6 +60,8 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
+import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.internal.InternalAggregatingState;
 import org.apache.flink.runtime.state.internal.InternalFoldingState;
@@ -67,6 +74,7 @@ import org.apache.flink.util.FileUtils;
 import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
+import org.rocksdb.Checkpoint;
 import org.rocksdb.ColumnFamilyDescriptor;
 import org.rocksdb.ColumnFamilyHandle;
 import org.rocksdb.ColumnFamilyOptions;
@@ -83,13 +91,20 @@ import java.io.EOFException;
 import java.io.File;
 import java.io.IOException;
 import java.io.ObjectInputStream;
+import java.nio.file.Files;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Comparator;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.PriorityQueue;
+import java.util.SortedMap;
+import java.util.TreeMap;
+import java.util.UUID;
+import java.util.concurrent.Callable;
+import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 
 /**
@@ -102,6 +117,10 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 	private static final Logger LOG = LoggerFactory.getLogger(RocksDBKeyedStateBackend.class);
 
+	private final JobID jobId;
+
+	private final String operatorIdentifier;
+
 	/** The column family options from the options factory */
 	private final ColumnFamilyOptions columnOptions;
 
@@ -137,6 +156,17 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	/** Number of bytes required to prefix the key groups. */
 	private final int keyGroupPrefixBytes;
 
+	/** True if incremental checkpointing is enabled */
+	private final boolean enableIncrementalCheckpointing;
+
+	/** The sst files materialized in pending checkpoints */
+	private final SortedMap<Long, Map<String, StreamStateHandle>> materializedSstFiles = new TreeMap<>();
+
+	/** The identifier of the last completed checkpoint */
+	private long lastCompletedCheckpointId = -1;
+
+	private static final String SST_FILE_SUFFIX = ".sst";
+
 	public RocksDBKeyedStateBackend(
 			JobID jobId,
 			String operatorIdentifier,
@@ -148,10 +178,17 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			ExecutionConfig executionConfig
+			ExecutionConfig executionConfig,
+			boolean enableIncrementalCheckpointing
 	) throws IOException {
 
 		super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange, executionConfig);
+
+		this.jobId = Preconditions.checkNotNull(jobId);
+		this.operatorIdentifier = Preconditions.checkNotNull(operatorIdentifier);
+
+		this.enableIncrementalCheckpointing = enableIncrementalCheckpointing;
+
 		this.columnOptions = Preconditions.checkNotNull(columnFamilyOptions);
 		this.dbOptions = Preconditions.checkNotNull(dbOptions);
 
@@ -174,21 +211,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			throw new IOException("Error cleaning RocksDB data directory.", e);
 		}
 
-		List<ColumnFamilyDescriptor> columnFamilyDescriptors = new ArrayList<>(1);
-		// RocksDB seems to need this...
-		columnFamilyDescriptors.add(new ColumnFamilyDescriptor("default".getBytes(ConfigConstants.DEFAULT_CHARSET)));
-		List<ColumnFamilyHandle> columnFamilyHandles = new ArrayList<>(1);
-		try {
-
-			db = RocksDB.open(
-					Preconditions.checkNotNull(dbOptions),
-					instanceRocksDBPath.getAbsolutePath(),
-					columnFamilyDescriptors,
-					columnFamilyHandles);
-
-		} catch (RocksDBException e) {
-			throw new IOException("Error while opening RocksDB instance.", e);
-		}
 		keyGroupPrefixBytes = getNumberOfKeyGroups() > (Byte.MAX_VALUE + 1) ? 2 : 1;
 		kvStateInformation = new HashMap<>();
 	}
@@ -265,9 +287,71 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			final CheckpointStreamFactory streamFactory,
 			CheckpointOptions checkpointOptions) throws Exception {
 
+		if (checkpointOptions.getCheckpointType() != CheckpointOptions.CheckpointType.SAVEPOINT &&
+			enableIncrementalCheckpointing) {
+			return snapshotIncrementally(checkpointId, timestamp, streamFactory);
+		} else {
+			return snapshotFully(checkpointId, timestamp, streamFactory);
+		}
+	}
+
+	private RunnableFuture<KeyedStateHandle> snapshotIncrementally(
+			final long checkpointId,
+			final long checkpointTimestamp,
+			final CheckpointStreamFactory checkpointStreamFactory) throws Exception {
+
+		final RocksDBIncrementalSnapshotOperation snapshotOperation =
+			new RocksDBIncrementalSnapshotOperation(
+				this,
+				checkpointStreamFactory,
+				checkpointId,
+				checkpointTimestamp);
+
+		synchronized (asyncSnapshotLock) {
+			if (db == null) {
+				throw new IOException("RocksDB closed.");
+			}
+
+			if (!hasRegisteredState()) {
+				if (LOG.isDebugEnabled()) {
+					LOG.debug("Asynchronous RocksDB snapshot performed on empty keyed state at " +
+							checkpointTimestamp + ". Returning null.");
+				}
+				return DoneFuture.nullValue();
+			}
+
+			snapshotOperation.takeSnapshot();
+		}
+
+		return new FutureTask<KeyedStateHandle>(
+			new Callable<KeyedStateHandle>() {
+				@Override
+				public KeyedStateHandle call() throws Exception {
+					return snapshotOperation.materializeSnapshot();
+				}
+			}
+		) {
+			@Override
+			public boolean cancel(boolean mayInterruptIfRunning) {
+				snapshotOperation.stop();
+				return super.cancel(mayInterruptIfRunning);
+			}
+
+			@Override
+			protected void done() {
+				snapshotOperation.releaseResources(isCancelled());
+			}
+		};
+	}
+
+	private RunnableFuture<KeyedStateHandle> snapshotFully(
+			final long checkpointId,
+			final long timestamp,
+			final CheckpointStreamFactory streamFactory) throws Exception {
+
 		long startTime = System.currentTimeMillis();
 
-		final RocksDBSnapshotOperation snapshotOperation = new RocksDBSnapshotOperation(this, streamFactory);
+		final RocksDBFullSnapshotOperation snapshotOperation = new RocksDBFullSnapshotOperation(this, streamFactory);
 		// hold the db lock while operation on the db to guard us against async db disposal
 		synchronized (asyncSnapshotLock) {
 
@@ -342,7 +426,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	/**
 	 * Encapsulates the process to perform a snapshot of a RocksDBKeyedStateBackend.
 	 */
-	static final class RocksDBSnapshotOperation {
+	static final class RocksDBFullSnapshotOperation {
 
 		static final int FIRST_BIT_IN_BYTE_MASK = 0x80;
 		static final int END_OF_KEY_GROUP_MARK = 0xFFFF;
@@ -362,7 +446,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		private DataOutputView outputView;
 		private KeyGroupsStateHandle snapshotResultStateHandle;
 
-		RocksDBSnapshotOperation(
+		RocksDBFullSnapshotOperation(
 				RocksDBKeyedStateBackend<?> stateBackend,
 				CheckpointStreamFactory checkpointStreamFactory) {
 
@@ -607,11 +691,11 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 
 		static void clearMetaDataFollowsFlag(byte[] key) {
-			key[0] &= (~RocksDBSnapshotOperation.FIRST_BIT_IN_BYTE_MASK);
+			key[0] &= (~RocksDBFullSnapshotOperation.FIRST_BIT_IN_BYTE_MASK);
 		}
 
 		static boolean hasMetaDataFollowsFlag(byte[] key) {
-			return 0 != (key[0] & RocksDBSnapshotOperation.FIRST_BIT_IN_BYTE_MASK);
+			return 0 != (key[0] & RocksDBFullSnapshotOperation.FIRST_BIT_IN_BYTE_MASK);
 		}
 
 		private static void checkInterrupted() throws InterruptedException {
@@ -621,6 +705,239 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 	}
 
+	private static class RocksDBIncrementalSnapshotOperation {
+
+		private final RocksDBKeyedStateBackend<?> stateBackend;
+
+		private final CheckpointStreamFactory checkpointStreamFactory;
+
+		private final long checkpointId;
+
+		private final long checkpointTimestamp;
+
+		private Map<String, StreamStateHandle> baseSstFiles;
+
+		private List<KeyedBackendSerializationProxy.StateMetaInfo<?, ?>> stateMetaInfos = new ArrayList<>();
+
+		private FileSystem backupFileSystem;
+		private Path backupPath;
+
+		// Registry for all opened i/o streams
+		private CloseableRegistry closeableRegistry = new CloseableRegistry();
+
+		// new sst files since the last completed checkpoint
+		private Map<String, StreamStateHandle> newSstFiles = new HashMap<>();
+
+		// old sst files which have been materialized in previous completed checkpoints
+		private Map<String, StreamStateHandle> oldSstFiles = new HashMap<>();
+
+		// handles to the misc files in the current snapshot
+		private Map<String, StreamStateHandle> miscFiles = new HashMap<>();
+
+		private StreamStateHandle metaStateHandle = null;
+
+		private RocksDBIncrementalSnapshotOperation(
+				RocksDBKeyedStateBackend<?> stateBackend,
+				CheckpointStreamFactory checkpointStreamFactory,
+				long checkpointId,
+				long checkpointTimestamp) {
+
+			this.stateBackend = stateBackend;
+			this.checkpointStreamFactory = checkpointStreamFactory;
+			this.checkpointId = checkpointId;
+			this.checkpointTimestamp = checkpointTimestamp;
+		}
+
+		private StreamStateHandle materializeStateData(Path filePath) throws Exception {
+			FSDataInputStream inputStream = null;
+			CheckpointStreamFactory.CheckpointStateOutputStream outputStream = null;
+
+			try {
+				final byte[] buffer = new byte[1024];
+
+				FileSystem backupFileSystem = backupPath.getFileSystem();
+				inputStream = backupFileSystem.open(filePath);
+				closeableRegistry.registerClosable(inputStream);
+
+				outputStream = checkpointStreamFactory
+					.createCheckpointStateOutputStream(checkpointId, checkpointTimestamp);
+				closeableRegistry.registerClosable(outputStream);
+
+				while (true) {
+					int numBytes = inputStream.read(buffer);
+
+					if (numBytes == -1) {
+						break;
+					}
+
+					outputStream.write(buffer, 0, numBytes);
+				}
+
+				closeableRegistry.unregisterClosable(outputStream);
+				StreamStateHandle result = outputStream.closeAndGetHandle();
+				outputStream = null;
+
+				return result;
+			} finally {
+				if (inputStream != null) {
+					closeableRegistry.unregisterClosable(inputStream);
+					inputStream.close();
+				}
+
+				if (outputStream != null) {
+					closeableRegistry.unregisterClosable(outputStream);
+					outputStream.close();
+				}
+			}
+		}
+
+		private StreamStateHandle materializeMetaData() throws Exception {
+			CheckpointStreamFactory.CheckpointStateOutputStream outputStream = null;
+
+			try {
+				outputStream = checkpointStreamFactory
+					.createCheckpointStateOutputStream(checkpointId, checkpointTimestamp);
+				stateBackend.cancelStreamRegistry.registerClosable(outputStream);
+
+				KeyedBackendSerializationProxy serializationProxy =
+					new KeyedBackendSerializationProxy(stateBackend.keySerializer, stateMetaInfos);
+				DataOutputView out = new DataOutputViewStreamWrapper(outputStream);
+
+				serializationProxy.write(out);
+
+				stateBackend.cancelStreamRegistry.unregisterClosable(outputStream);
+				StreamStateHandle result = outputStream.closeAndGetHandle();
+				outputStream = null;
+
+				return result;
+			} finally {
+				if (outputStream != null) {
+					stateBackend.cancelStreamRegistry.unregisterClosable(outputStream);
+					outputStream.close();
+				}
+			}
+		}
+
+		void takeSnapshot() throws Exception {
+			// use the last completed checkpoint as the comparison base.
+			baseSstFiles = stateBackend.materializedSstFiles.get(stateBackend.lastCompletedCheckpointId);
+
+			// save meta data
+			for (Map.Entry<String, Tuple2<ColumnFamilyHandle, RegisteredBackendStateMetaInfo<?, ?>>> stateMetaInfoEntry : stateBackend.kvStateInformation.entrySet()) {
+
+				RegisteredBackendStateMetaInfo<?, ?> metaInfo = stateMetaInfoEntry.getValue().f1;
+
+				KeyedBackendSerializationProxy.StateMetaInfo<?, ?> metaInfoProxy =
+					new KeyedBackendSerializationProxy.StateMetaInfo<>(
+						metaInfo.getStateType(),
+						metaInfo.getName(),
+						metaInfo.getNamespaceSerializer(),
+						metaInfo.getStateSerializer());
+
+				stateMetaInfos.add(metaInfoProxy);
+			}
+
+			// save state data
+			backupPath = new Path(stateBackend.instanceBasePath.getAbsolutePath(), "chk-" + checkpointId);
+			backupFileSystem = backupPath.getFileSystem();
+			if (backupFileSystem.exists(backupPath)) {
+				throw new IllegalStateException("Unexpected existence of the backup directory.");
+			}
+
+			// create hard links of living files in the checkpoint path
+			Checkpoint checkpoint = Checkpoint.create(stateBackend.db);
+			checkpoint.createCheckpoint(backupPath.getPath());
+		}
+
+		KeyedStateHandle materializeSnapshot() throws Exception {
+
+			synchronized (stateBackend.asyncSnapshotLock) {
+
+				if (stateBackend.db == null) {
+					throw new IOException("RocksDB closed.");
+				}
+
+				stateBackend.cancelStreamRegistry.registerClosable(closeableRegistry);
+
+				// write meta data
+				metaStateHandle = materializeMetaData();
+
+				// write state data
+				Preconditions.checkState(backupFileSystem.exists(backupPath));
+
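+				// an sst file already contained in the base (the sst files of the last
+				// completed checkpoint) is only re-referenced; all other files are copied
+				// into the checkpoint stream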
+				FileStatus[] fileStatuses = backupFileSystem.listStatus(backupPath);
+				if (fileStatuses != null) {
+					for (FileStatus fileStatus : fileStatuses) {
+						Path filePath = fileStatus.getPath();
+						String fileName = filePath.getName();
+
+						if (fileName.endsWith(SST_FILE_SUFFIX)) {
+							StreamStateHandle fileHandle =
+								baseSstFiles == null ? null : baseSstFiles.get(fileName);
+
+							if (fileHandle == null) {
+								fileHandle = materializeStateData(filePath);
+
+								newSstFiles.put(fileName, fileHandle);
+							} else {
+								oldSstFiles.put(fileName, fileHandle);
+							}
+						} else {
+							StreamStateHandle fileHandle = materializeStateData(filePath);
+							miscFiles.put(fileName, fileHandle);
+						}
+					}
+				}
+
+				Map<String, StreamStateHandle> sstFiles = new HashMap<>(newSstFiles.size() + oldSstFiles.size());
+				sstFiles.putAll(newSstFiles);
+				sstFiles.putAll(oldSstFiles);
+
+				stateBackend.materializedSstFiles.put(checkpointId, sstFiles);
+
+				return new RocksDBIncrementalKeyedStateHandle(stateBackend.jobId,
+					stateBackend.operatorIdentifier, stateBackend.keyGroupRange,
+					checkpointId, newSstFiles, oldSstFiles, miscFiles, metaStateHandle);
+			}
+		}
+
+		void stop() {
+			try {
+				closeableRegistry.close();
+			} catch (IOException e) {
+				LOG.warn("Could not properly close io streams.", e);
+			}
+		}
+
+		void releaseResources(boolean canceled) {
+			stateBackend.cancelStreamRegistry.unregisterClosable(closeableRegistry);
+
+			if (backupPath != null) {
+				try {
+					if (backupFileSystem.exists(backupPath)) {
+						backupFileSystem.delete(backupPath, true);
+					}
+				} catch (Exception e) {
+					LOG.warn("Could not properly delete the checkpoint directory.", e);
+				}
+			}
+
+			if (canceled) {
+				List<StateObject> statesToDiscard = new ArrayList<>();
+
+				statesToDiscard.add(metaStateHandle);
+				statesToDiscard.addAll(miscFiles.values());
+				statesToDiscard.addAll(newSstFiles.values());
+
+				try {
+					StateUtil.bestEffortDiscardAllStateObjects(statesToDiscard);
+				} catch (Exception e) {
+					LOG.warn("Could not properly discard states.", e);
+				}
+			}
+		}
+	}
+
 	@Override
 	public void restore(Collection<KeyedStateHandle> restoreState) throws Exception {
 		LOG.info("Initializing RocksDB keyed state backend from snapshot.");
@@ -630,11 +947,16 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 
 		try {
-			if (MigrationUtil.isOldSavepointKeyedState(restoreState)) {
+			if (restoreState == null || restoreState.isEmpty()) {
+				createDB();
+			} else if (MigrationUtil.isOldSavepointKeyedState(restoreState)) {
 				LOG.info("Converting RocksDB state from old savepoint.");
 				restoreOldSavepointKeyedState(restoreState);
+			} else if (restoreState.iterator().next() instanceof RocksDBIncrementalKeyedStateHandle) {
+				RocksDBIncrementalRestoreOperation restoreOperation = new RocksDBIncrementalRestoreOperation(this);
+				restoreOperation.restore(restoreState);
 			} else {
-				RocksDBRestoreOperation restoreOperation = new RocksDBRestoreOperation(this);
+				RocksDBFullRestoreOperation restoreOperation = new RocksDBFullRestoreOperation(this);
 				restoreOperation.doRestore(restoreState);
 			}
 		} catch (Exception ex) {
@@ -643,10 +965,68 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 	}
 
+	@Override
+	public void notifyOfCompletedCheckpoint(long completedCheckpointId) {
+		synchronized (asyncSnapshotLock) {
+			if (completedCheckpointId < lastCompletedCheckpointId) {
+				return;
+			}
+
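+			// drop the sst file records of all checkpoints subsumed by the completed one;
+			// they can no longer serve as the base for future incremental snapshots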
+			Iterator<Long> materializedCheckpointIterator = materializedSstFiles.keySet().iterator();
+			while (materializedCheckpointIterator.hasNext()) {
+				long materializedCheckpointId = materializedCheckpointIterator.next();
+
+				if (materializedCheckpointId < completedCheckpointId) {
+					materializedCheckpointIterator.remove();
+				}
+			}
+
+			lastCompletedCheckpointId = completedCheckpointId;
+		}
+	}
+
+	private void createDB() throws IOException {
+		db = openDB(instanceRocksDBPath.getAbsolutePath(),
+			new ArrayList<ColumnFamilyDescriptor>(),
+			null);
+	}
+
+	private RocksDB openDB(
+			String path,
+			List<ColumnFamilyDescriptor> stateColumnFamilyDescriptors,
+			List<ColumnFamilyHandle> stateColumnFamilyHandles) throws IOException {
+
+		List<ColumnFamilyDescriptor> columnFamilyDescriptors = new ArrayList<>(stateColumnFamilyDescriptors);
+		columnFamilyDescriptors.add(
+			new ColumnFamilyDescriptor(
+				"default".getBytes(ConfigConstants.DEFAULT_CHARSET), columnOptions));
+
+		List<ColumnFamilyHandle> columnFamilyHandles = new ArrayList<>(columnFamilyDescriptors.size());
+
+		RocksDB db;
+
+		try {
+			db = RocksDB.open(
+					Preconditions.checkNotNull(dbOptions),
+					Preconditions.checkNotNull(path),
+					columnFamilyDescriptors,
+					columnFamilyHandles);
+		} catch (RocksDBException e) {
+			throw new IOException("Error while opening RocksDB instance.", e);
+		}
+
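+		// the 'default' column family descriptor was appended last above, so its
+		// handle is the last element and is excluded from the returned handles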
+		if (stateColumnFamilyHandles != null) {
+			stateColumnFamilyHandles.addAll(
+				columnFamilyHandles.subList(0, columnFamilyHandles.size() - 1));
+		}
+
+		return db;
+	}
+
 	/**
 	 * Encapsulates the process of restoring a RocksDBKeyedStateBackend from a snapshot.
 	 */
-	static final class RocksDBRestoreOperation {
+	static final class RocksDBFullRestoreOperation {
 
 		private final RocksDBKeyedStateBackend<?> rocksDBKeyedStateBackend;
 
@@ -664,7 +1044,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		 *
 		 * @param rocksDBKeyedStateBackend the state backend into which we restore
 		 */
-		public RocksDBRestoreOperation(RocksDBKeyedStateBackend<?> rocksDBKeyedStateBackend) {
+		public RocksDBFullRestoreOperation(RocksDBKeyedStateBackend<?> rocksDBKeyedStateBackend) {
 			this.rocksDBKeyedStateBackend = Preconditions.checkNotNull(rocksDBKeyedStateBackend);
 		}
 
@@ -679,6 +1059,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		public void doRestore(Collection<KeyedStateHandle> keyedStateHandles)
 				throws IOException, ClassNotFoundException, RocksDBException {
 
+			rocksDBKeyedStateBackend.createDB();
+
 			for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
 				if (keyedStateHandle != null) {
 
@@ -787,14 +1169,14 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 					while (keyGroupHasMoreKeys) {
 						byte[] key = BytePrimitiveArraySerializer.INSTANCE.deserialize(currentStateHandleInView);
 						byte[] value = BytePrimitiveArraySerializer.INSTANCE.deserialize(currentStateHandleInView);
-						if (RocksDBSnapshotOperation.hasMetaDataFollowsFlag(key)) {
+						if (RocksDBFullSnapshotOperation.hasMetaDataFollowsFlag(key)) {
 							//clear the signal bit in the key to make it ready for insertion again
-							RocksDBSnapshotOperation.clearMetaDataFollowsFlag(key);
+							RocksDBFullSnapshotOperation.clearMetaDataFollowsFlag(key);
 							rocksDBKeyedStateBackend.db.put(handle, key, value);
 							//TODO this could be aware of keyGroupPrefixBytes and write only one byte if possible
-							kvStateId = RocksDBSnapshotOperation.END_OF_KEY_GROUP_MARK
+							kvStateId = RocksDBFullSnapshotOperation.END_OF_KEY_GROUP_MARK
 									& currentStateHandleInView.readShort();
-							if (RocksDBSnapshotOperation.END_OF_KEY_GROUP_MARK == kvStateId) {
+							if (RocksDBFullSnapshotOperation.END_OF_KEY_GROUP_MARK == kvStateId) {
 								keyGroupHasMoreKeys = false;
 							} else {
 								handle = currentStateHandleKVStateColumnFamilies.get(kvStateId);
@@ -808,6 +1190,272 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 	}
 
+	private static class RocksDBIncrementalRestoreOperation {
+
+		private final RocksDBKeyedStateBackend<?> stateBackend;
+
+		private RocksDBIncrementalRestoreOperation(RocksDBKeyedStateBackend<?> stateBackend) {
+			this.stateBackend = stateBackend;
+		}
+
+		private List<KeyedBackendSerializationProxy.StateMetaInfo<?, ?>> readMetaData(
+				StreamStateHandle metaStateHandle) throws Exception {
+
+			FSDataInputStream inputStream = null;
+
+			try {
+				inputStream = metaStateHandle.openInputStream();
+				stateBackend.cancelStreamRegistry.registerClosable(inputStream);
+
+				KeyedBackendSerializationProxy serializationProxy =
+					new KeyedBackendSerializationProxy(stateBackend.userCodeClassLoader);
+				DataInputView in = new DataInputViewStreamWrapper(inputStream);
+				serializationProxy.read(in);
+
+				return serializationProxy.getNamedStateSerializationProxies();
+			} finally {
+				if (inputStream != null) {
+					stateBackend.cancelStreamRegistry.unregisterClosable(inputStream);
+					inputStream.close();
+				}
+			}
+		}
+
+		private void readStateData(
+				Path restoreFilePath,
+				StreamStateHandle remoteFileHandle) throws IOException {
+
+			FileSystem restoreFileSystem = restoreFilePath.getFileSystem();
+
+			FSDataInputStream inputStream = null;
+			FSDataOutputStream outputStream = null;
+
+			try {
+				inputStream = remoteFileHandle.openInputStream();
+				stateBackend.cancelStreamRegistry.registerClosable(inputStream);
+
+				outputStream = restoreFileSystem.create(restoreFilePath, FileSystem.WriteMode.OVERWRITE);
+				stateBackend.cancelStreamRegistry.registerClosable(outputStream);
+
+				byte[] buffer = new byte[1024];
+				while (true) {
+					int numBytes = inputStream.read(buffer);
+					if (numBytes == -1) {
+						break;
+					}
+
+					outputStream.write(buffer, 0, numBytes);
+				}
+			} finally {
+				if (inputStream != null) {
+					stateBackend.cancelStreamRegistry.unregisterClosable(inputStream);
+					inputStream.close();
+				}
+
+				if (outputStream != null) {
+					stateBackend.cancelStreamRegistry.unregisterClosable(outputStream);
+					outputStream.close();
+				}
+			}
+		}
+
+		private void restoreInstance(
+				RocksDBIncrementalKeyedStateHandle restoreStateHandle,
+				boolean hasExtraKeys) throws Exception {
+
+			// read state data
+			Path restoreInstancePath = new Path(
+				stateBackend.instanceBasePath.getAbsolutePath(),
+				UUID.randomUUID().toString());
+
+			try {
+				Map<String, StreamStateHandle> newSstFiles = restoreStateHandle.getNewSstFiles();
+				for (Map.Entry<String, StreamStateHandle> newSstFileEntry : newSstFiles.entrySet()) {
+					String fileName = newSstFileEntry.getKey();
+					StreamStateHandle remoteFileHandle = newSstFileEntry.getValue();
+
+					readStateData(new Path(restoreInstancePath, fileName), remoteFileHandle);
+				}
+
+				Map<String, StreamStateHandle> oldSstFiles = restoreStateHandle.getOldSstFiles();
+				for (Map.Entry<String, StreamStateHandle> oldSstFileEntry : oldSstFiles.entrySet()) {
+					String fileName = oldSstFileEntry.getKey();
+					StreamStateHandle remoteFileHandle = oldSstFileEntry.getValue();
+
+					readStateData(new Path(restoreInstancePath, fileName), remoteFileHandle);
+				}
+
+				Map<String, StreamStateHandle> miscFiles = restoreStateHandle.getMiscFiles();
+				for (Map.Entry<String, StreamStateHandle> miscFileEntry : miscFiles.entrySet()) {
+					String fileName = miscFileEntry.getKey();
+					StreamStateHandle remoteFileHandle = miscFileEntry.getValue();
+
+					readStateData(new Path(restoreInstancePath, fileName), remoteFileHandle);
+				}
+
+				// read meta data
+				List<KeyedBackendSerializationProxy.StateMetaInfo<?, ?>> stateMetaInfoProxies =
+					readMetaData(restoreStateHandle.getMetaStateHandle());
+
+				List<ColumnFamilyDescriptor> columnFamilyDescriptors = new ArrayList<>();
+
+				for (KeyedBackendSerializationProxy.StateMetaInfo<?, ?> stateMetaInfoProxy : stateMetaInfoProxies) {
+
+					ColumnFamilyDescriptor columnFamilyDescriptor = new ColumnFamilyDescriptor(
+						stateMetaInfoProxy.getStateName().getBytes(ConfigConstants.DEFAULT_CHARSET),
+						stateBackend.columnOptions);
+
+					columnFamilyDescriptors.add(columnFamilyDescriptor);
+				}
+
+				if (hasExtraKeys) {
+
+					List<ColumnFamilyHandle> columnFamilyHandles = new ArrayList<>();
+
+					try (RocksDB restoreDb = stateBackend.openDB(
+							restoreInstancePath.getPath(),
+							columnFamilyDescriptors,
+							columnFamilyHandles)) {
+
+						for (int i = 0; i < columnFamilyHandles.size(); ++i) {
+							ColumnFamilyHandle columnFamilyHandle = columnFamilyHandles.get(i);
+							ColumnFamilyDescriptor columnFamilyDescriptor = columnFamilyDescriptors.get(i);
+							KeyedBackendSerializationProxy.StateMetaInfo<?, ?> stateMetaInfoProxy = stateMetaInfoProxies.get(i);
+
+							Tuple2<ColumnFamilyHandle, RegisteredBackendStateMetaInfo<?, ?>> registeredStateMetaInfoEntry =
+								stateBackend.kvStateInformation.get(stateMetaInfoProxy.getStateName());
+
+							if (null == registeredStateMetaInfoEntry) {
+
+								RegisteredBackendStateMetaInfo<?, ?> stateMetaInfo =
+									new RegisteredBackendStateMetaInfo<>(stateMetaInfoProxy);
+
+								registeredStateMetaInfoEntry =
+									new Tuple2<ColumnFamilyHandle, RegisteredBackendStateMetaInfo<?, ?>>(
+										stateBackend.db.createColumnFamily(columnFamilyDescriptor),
+										stateMetaInfo);
+
+								stateBackend.kvStateInformation.put(
+									stateMetaInfoProxy.getStateName(),
+									registeredStateMetaInfoEntry);
+							}
+
+							ColumnFamilyHandle targetColumnFamilyHandle = registeredStateMetaInfoEntry.f0;
+
+							try (RocksIterator iterator = restoreDb.newIterator(columnFamilyHandle)) {
+
+								int startKeyGroup = stateBackend.getKeyGroupRange().getStartKeyGroup();
+								byte[] startKeyGroupPrefixBytes = new byte[stateBackend.keyGroupPrefixBytes];
+								for (int j = 0; j < stateBackend.keyGroupPrefixBytes; ++j) {
+									startKeyGroupPrefixBytes[j] = (byte)(startKeyGroup >>> ((stateBackend.keyGroupPrefixBytes - j - 1) * Byte.SIZE));
+								}
+
+								iterator.seek(startKeyGroupPrefixBytes);
+
+								while (iterator.isValid()) {
+
+									int keyGroup = 0;
+									for (int j = 0; j < stateBackend.keyGroupPrefixBytes; ++j) {
+										keyGroup = (keyGroup << Byte.SIZE) + iterator.key()[j];
+									}
+
+									if (stateBackend.keyGroupRange.contains(keyGroup)) {
+										stateBackend.db.put(targetColumnFamilyHandle,
+											iterator.key(), iterator.value());
+									}
+
+									iterator.next();
+								}
+							}
+						}
+					}
+				} else {
+
+					// create hard links in the instance directory
+					if (!stateBackend.instanceRocksDBPath.mkdirs()) {
+						throw new IOException("Could not create RocksDB data directory.");
+					}
+
+					for (String newSstFileName : newSstFiles.keySet()) {
+						File restoreFile = new File(restoreInstancePath.getPath(), newSstFileName);
+						File targetFile = new File(stateBackend.instanceRocksDBPath, newSstFileName);
+
+						Files.createLink(targetFile.toPath(), restoreFile.toPath());
+					}
+
+					for (String oldSstFileName : oldSstFiles.keySet()) {
+						File restoreFile = new File(restoreInstancePath.getPath(), oldSstFileName);
+						File targetFile = new File(stateBackend.instanceRocksDBPath, oldSstFileName);
+
+						Files.createLink(targetFile.toPath(), restoreFile.toPath());
+					}
+
+					for (String miscFileName : miscFiles.keySet()) {
+						File restoreFile = new File(restoreInstancePath.getPath(), miscFileName);
+						File targetFile = new File(stateBackend.instanceRocksDBPath, miscFileName);
+
+						Files.createLink(targetFile.toPath(), restoreFile.toPath());
+					}
+
+					List<ColumnFamilyHandle> columnFamilyHandles = new ArrayList<>();
+					stateBackend.db = stateBackend.openDB(
+						stateBackend.instanceRocksDBPath.getAbsolutePath(),
+						columnFamilyDescriptors, columnFamilyHandles);
+
+					for (int i = 0; i < columnFamilyDescriptors.size(); ++i) {
+						KeyedBackendSerializationProxy.StateMetaInfo<?, ?> stateMetaInfoProxy = stateMetaInfoProxies.get(i);
+
+						ColumnFamilyHandle columnFamilyHandle = columnFamilyHandles.get(i);
+						RegisteredBackendStateMetaInfo<?, ?> stateMetaInfo =
+							new RegisteredBackendStateMetaInfo<>(stateMetaInfoProxy);
+
+						stateBackend.kvStateInformation.put(
+							stateMetaInfoProxy.getStateName(),
+							new Tuple2<ColumnFamilyHandle, RegisteredBackendStateMetaInfo<?, ?>>(
+								columnFamilyHandle, stateMetaInfo));
+					}
+
+					// use the restored sst files as the base for succeeding checkpoints
+					Map<String, StreamStateHandle> sstFiles = new HashMap<>();
+					sstFiles.putAll(newSstFiles);
+					sstFiles.putAll(oldSstFiles);
+					stateBackend.materializedSstFiles.put(restoreStateHandle.getCheckpointId(), sstFiles);
+
+					stateBackend.lastCompletedCheckpointId = restoreStateHandle.getCheckpointId();
+				}
+			} finally {
+				FileSystem restoreFileSystem = restoreInstancePath.getFileSystem();
+				if (restoreFileSystem.exists(restoreInstancePath)) {
+					restoreFileSystem.delete(restoreInstancePath, true);
+				}
+			}
+		}
+
+		void restore(Collection<KeyedStateHandle> restoreStateHandles) throws Exception {
+
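+			// the restored handles contain extra key groups if they stem from more than
+			// one subtask or do not exactly match this backend's key-group range (e.g.
+			// after rescaling); in that case the files cannot simply be hard-linked, and
+			// the keys in range are copied instead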
+			boolean hasExtraKeys = (restoreStateHandles.size() > 1 ||
+				!restoreStateHandles.iterator().next().getKeyGroupRange().equals(stateBackend.keyGroupRange));
+
+			if (hasExtraKeys) {
+				stateBackend.createDB();
+			}
+
+			for (KeyedStateHandle rawStateHandle : restoreStateHandles) {
+
+				if (!(rawStateHandle instanceof RocksDBIncrementalKeyedStateHandle)) {
+					throw new IllegalStateException("Unexpected state handle type, " +
+						"expected " + RocksDBIncrementalKeyedStateHandle.class +
+						", but found " + rawStateHandle.getClass());
+				}
+
+				RocksDBIncrementalKeyedStateHandle keyedStateHandle = (RocksDBIncrementalKeyedStateHandle) rawStateHandle;
+
+				restoreInstance(keyedStateHandle, hasExtraKeys);
+			}
+		}
+	}
+
 	// ------------------------------------------------------------------------
 	//  State factories
 	// ------------------------------------------------------------------------
@@ -1160,10 +1808,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 */
 	@Deprecated
 	private void restoreOldSavepointKeyedState(Collection<KeyedStateHandle> restoreState) throws Exception {
-
-		if (restoreState.isEmpty()) {
-			return;
-		}
+		createDB();
 
 		Preconditions.checkState(1 == restoreState.size(), "Only one element expected here.");
 

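The core of the incremental snapshot above is the classification in
materializeSnapshot(): an sst file that already appears among the sst files of the
last completed checkpoint is only re-referenced, while everything else is copied to
the checkpoint stream. A minimal sketch of that decision (illustrative file names
and handle strings):

    import java.util.HashMap;
    import java.util.Map;

    public class SstClassificationSketch {

        public static void main(String[] args) {
            // sst files materialized by the last completed checkpoint (the base)
            Map<String, String> baseSstFiles = new HashMap<>();
            baseSstFiles.put("000007.sst", "handle-7");

            // files found in the local backup directory of the current checkpoint
            String[] localFiles = {"000007.sst", "000011.sst", "MANIFEST-000004"};

            Map<String, String> newSstFiles = new HashMap<>();  // uploaded
            Map<String, String> oldSstFiles = new HashMap<>();  // re-referenced
            Map<String, String> miscFiles = new HashMap<>();    // always uploaded (mutable)

            for (String fileName : localFiles) {
                if (fileName.endsWith(".sst")) {
                    String baseHandle = baseSstFiles.get(fileName);
                    if (baseHandle == null) {
                        newSstFiles.put(fileName, "uploaded-" + fileName);
                    } else {
                        oldSstFiles.put(fileName, baseHandle);
                    }
                } else {
                    miscFiles.put(fileName, "uploaded-" + fileName);
                }
            }

            System.out.println("new: " + newSstFiles.keySet()
                + ", old: " + oldSstFiles.keySet()
                + ", misc: " + miscFiles.keySet());
        }
    }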
http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index 80c9a29..55b8be2 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -109,6 +109,9 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 	/** Whether we already lazily initialized our local storage directories. */
 	private transient boolean isInitialized = false;
 
+	/** True if incremental checkpointing is enabled */
+	private boolean enableIncrementalCheckpointing;
+
 
 	/**
 	 * Creates a new {@code RocksDBStateBackend} that stores its checkpoint data in the
@@ -123,7 +126,24 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 	 * @throws IOException Thrown, if no file system can be found for the scheme in the URI.
 	 */
 	public RocksDBStateBackend(String checkpointDataUri) throws IOException {
-		this(new Path(checkpointDataUri).toUri());
+		this(new Path(checkpointDataUri).toUri(), false);
+	}
+
+	/**
+	 * Creates a new {@code RocksDBStateBackend} that stores its checkpoint data in the
+	 * file system and location defined by the given URI.
+	 *
+	 * <p>A state backend that stores checkpoints in HDFS or S3 must specify the file system
+	 * host and port in the URI, or have the Hadoop configuration that describes the file system
+	 * (host / high-availability group / possibly credentials) either referenced from the Flink
+	 * config, or included in the classpath.
+	 *
+	 * @param checkpointDataUri The URI describing the filesystem and path to the checkpoint data directory.
+	 * @param enableIncrementalCheckpointing True if incremental checkpointing is enabled.
+	 * @throws IOException Thrown, if no file system can be found for the scheme in the URI.
+	 */
+	public RocksDBStateBackend(String checkpointDataUri, boolean enableIncrementalCheckpointing) throws IOException {
+		this(new Path(checkpointDataUri).toUri(), enableIncrementalCheckpointing);
 	}
 
 	/**
@@ -139,7 +159,24 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 	 * @throws IOException Thrown, if no file system can be found for the scheme in the URI.
 	 */
 	public RocksDBStateBackend(URI checkpointDataUri) throws IOException {
-		this(new FsStateBackend(checkpointDataUri));
+		this(new FsStateBackend(checkpointDataUri), false);
+	}
+
+	/**
+	 * Creates a new {@code RocksDBStateBackend} that stores its checkpoint data in the
+	 * file system and location defined by the given URI.
+	 *
+	 * <p>A state backend that stores checkpoints in HDFS or S3 must specify the file system
+	 * host and port in the URI, or have the Hadoop configuration that describes the file system
+	 * (host / high-availability group / possibly credentials) either referenced from the Flink
+	 * config, or included in the classpath.
+	 *
+	 * @param checkpointDataUri The URI describing the filesystem and path to the checkpoint data directory.
+	 * @param enableIncrementalCheckpointing True if incremental checkpointing is enabled.
+	 * @throws IOException Thrown, if no file system can be found for the scheme in the URI.
+	 */
+	public RocksDBStateBackend(URI checkpointDataUri, boolean enableIncrementalCheckpointing) throws IOException {
+		this(new FsStateBackend(checkpointDataUri), enableIncrementalCheckpointing);
 	}
 
 	/**
@@ -156,6 +193,22 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 		this.checkpointStreamBackend = requireNonNull(checkpointStreamBackend);
 	}
 
+	/**
+	 * Creates a new {@code RocksDBStateBackend} that uses the given state backend to store its
+	 * checkpoint data streams. Typically, one would supply a filesystem or database state backend
+	 * here where the snapshots from RocksDB would be stored.
+	 *
+	 * <p>The snapshots of the RocksDB state will be stored using the given backend's
+	 * {@link AbstractStateBackend#createStreamFactory(JobID, String) checkpoint stream}.
+	 *
+	 * @param checkpointStreamBackend The backend to store the checkpoint data streams.
+	 * @param enableIncrementalCheckpointing True if incremental checkpointing is enabled.
+	 */
+	public RocksDBStateBackend(AbstractStateBackend checkpointStreamBackend, boolean enableIncrementalCheckpointing) {
+		this.checkpointStreamBackend = requireNonNull(checkpointStreamBackend);
+		this.enableIncrementalCheckpointing = enableIncrementalCheckpointing;
+	}
+
 	// ------------------------------------------------------------------------
 	//  State backend methods
 	// ------------------------------------------------------------------------
@@ -260,7 +313,8 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 				keySerializer,
 				numberOfKeyGroups,
 				keyGroupRange,
-				env.getExecutionConfig());
+				env.getExecutionConfig(),
+				enableIncrementalCheckpointing);
 	}
 
 	// ------------------------------------------------------------------------

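With the new constructors, enabling incremental checkpointing from user code is a
one-line change; a minimal usage sketch (the checkpoint URI is illustrative):

    import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

    public class IncrementalBackendExample {
        public static void main(String[] args) throws Exception {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            env.enableCheckpointing(10000);
            // the second constructor argument enables incremental checkpointing
            env.setStateBackend(new RocksDBStateBackend("hdfs:///flink/checkpoints", true));
        }
    }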
http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAggregatingStateTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAggregatingStateTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAggregatingStateTest.java
index 983e569..1b65466 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAggregatingStateTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAggregatingStateTest.java
@@ -204,7 +204,7 @@ public class RocksDBAggregatingStateTest {
 	}
 
 	private static RocksDBKeyedStateBackend<String> createKeyedBackend(RocksDBStateBackend backend) throws Exception {
-		return (RocksDBKeyedStateBackend<String>) backend.createKeyedStateBackend(
+		RocksDBKeyedStateBackend<String> keyedBackend = (RocksDBKeyedStateBackend<String>) backend.createKeyedStateBackend(
 						new DummyEnvironment("TestTask", 1, 0),
 						new JobID(),
 						"test-op",
@@ -212,6 +212,10 @@ public class RocksDBAggregatingStateTest {
 						16,
 						new KeyGroupRange(2, 3),
 						mock(TaskKvStateRegistry.class));
+
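+		// the backend no longer opens RocksDB in its constructor; restoring from an
+		// empty state creates a fresh DB instance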
+		keyedBackend.restore(null);
+
+		return keyedBackend;
 	}
 
 	//  test functions

http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
index ffe2ce2..812babb 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
@@ -41,7 +41,6 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
@@ -338,6 +337,8 @@ public class RocksDBAsyncSnapshotTest {
 			new KeyGroupRange(0, 0),
 			null);
 
+		keyedStateBackend.restore(null);
+
 		// register a state so that the state backend has to checkpoint something
 		keyedStateBackend.getPartitionedState(
 			"namespace",
@@ -360,19 +361,21 @@ public class RocksDBAsyncSnapshotTest {
 	@Test
 	public void testConsistentSnapshotSerializationFlagsAndMasks() {
 
-		Assert.assertEquals(0xFFFF, RocksDBKeyedStateBackend.RocksDBSnapshotOperation.END_OF_KEY_GROUP_MARK);
-		Assert.assertEquals(0x80, RocksDBKeyedStateBackend.RocksDBSnapshotOperation.FIRST_BIT_IN_BYTE_MASK);
+		Assert.assertEquals(0xFFFF, RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.END_OF_KEY_GROUP_MARK);
+		Assert.assertEquals(0x80, RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.FIRST_BIT_IN_BYTE_MASK);
 
 		byte[] expectedKey = new byte[] {42, 42};
 		byte[] modKey = expectedKey.clone();
 
-		Assert.assertFalse(RocksDBKeyedStateBackend.RocksDBSnapshotOperation.hasMetaDataFollowsFlag(modKey));
+		Assert.assertFalse(
+			RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.hasMetaDataFollowsFlag(modKey));
 
-		RocksDBKeyedStateBackend.RocksDBSnapshotOperation.setMetaDataFollowsFlagInKey(modKey);
-		Assert.assertTrue(RocksDBKeyedStateBackend.RocksDBSnapshotOperation.hasMetaDataFollowsFlag(modKey));
+		RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.setMetaDataFollowsFlagInKey(modKey);
+		Assert.assertTrue(RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.hasMetaDataFollowsFlag(modKey));
 
-		RocksDBKeyedStateBackend.RocksDBSnapshotOperation.clearMetaDataFollowsFlag(modKey);
-		Assert.assertFalse(RocksDBKeyedStateBackend.RocksDBSnapshotOperation.hasMetaDataFollowsFlag(modKey));
+		RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.clearMetaDataFollowsFlag(modKey);
+		Assert.assertFalse(
+			RocksDBKeyedStateBackend.RocksDBFullSnapshotOperation.hasMetaDataFollowsFlag(modKey));
 
 		Assert.assertTrue(Arrays.equals(expectedKey, modKey));
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBListStateTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBListStateTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBListStateTest.java
index d8d0308..e7efcfa 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBListStateTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBListStateTest.java
@@ -210,7 +210,7 @@ public class RocksDBListStateTest {
 	// ------------------------------------------------------------------------
 
 	private static RocksDBKeyedStateBackend<String> createKeyedBackend(RocksDBStateBackend backend) throws Exception {
-		return (RocksDBKeyedStateBackend<String>) backend.createKeyedStateBackend(
+		RocksDBKeyedStateBackend<String> keyedBackend = (RocksDBKeyedStateBackend<String>) backend.createKeyedStateBackend(
 				new DummyEnvironment("TestTask", 1, 0),
 				new JobID(),
 				"test-op",
@@ -218,6 +218,10 @@ public class RocksDBListStateTest {
 				16,
 				new KeyGroupRange(2, 3),
 				mock(TaskKvStateRegistry.class));
+
+		keyedBackend.restore(null);
+
+		return keyedBackend;
 	}
 
 	private static <T> void validateResult(Iterable<T> values, Set<T> expected) {

http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBReducingStateTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBReducingStateTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBReducingStateTest.java
index fb854f2..a8b4535 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBReducingStateTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBReducingStateTest.java
@@ -210,7 +210,7 @@ public class RocksDBReducingStateTest {
 	// ------------------------------------------------------------------------
 
 	private static RocksDBKeyedStateBackend<String> createKeyedBackend(RocksDBStateBackend backend) throws Exception {
-		return (RocksDBKeyedStateBackend<String>) backend.createKeyedStateBackend(
+		RocksDBKeyedStateBackend<String> keyedBackend = (RocksDBKeyedStateBackend<String>) backend.createKeyedStateBackend(
 				new DummyEnvironment("TestTask", 1, 0),
 				new JobID(),
 				"test-op",
@@ -218,6 +218,10 @@ public class RocksDBReducingStateTest {
 				16,
 				new KeyGroupRange(2, 3),
 				mock(TaskKvStateRegistry.class));
+
+		keyedBackend.restore(null);
+
+		return keyedBackend;
 	}
 
 	// ------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
index b5f18a4..fad1559 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
@@ -33,7 +33,6 @@ import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.StateBackendTestBase;
 import org.apache.flink.runtime.state.VoidNamespace;
@@ -42,8 +41,11 @@ import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
+import org.rocksdb.Checkpoint;
 import org.rocksdb.ColumnFamilyDescriptor;
 import org.rocksdb.ColumnFamilyHandle;
 import org.rocksdb.ReadOptions;
@@ -55,6 +57,7 @@ import org.rocksdb.Snapshot;
 import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 import java.util.concurrent.RunnableFuture;
@@ -73,6 +76,7 @@ import static org.powermock.api.mockito.PowerMockito.spy;
 /**
  * Tests for the partitioned state part of {@link RocksDBStateBackend}.
  */
+@RunWith(Parameterized.class)
 public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBackend> {
 
 	private OneShotLatch blocker;
@@ -83,17 +87,25 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	private ValueState<Integer> testState1;
 	private ValueState<String> testState2;
 
+	@Parameterized.Parameters
+	public static Collection<Boolean> parameters() {
+		return Arrays.asList(false, true);
+	}
+
+	@Parameterized.Parameter
+	public boolean enableIncrementalCheckpointing;
+
 	@Rule
 	public TemporaryFolder tempFolder = new TemporaryFolder();
 
 	// Store it because we need it for the cleanup test.
-	private String dbPath;
+	String dbPath;
 
 	@Override
 	protected RocksDBStateBackend getStateBackend() throws IOException {
 		dbPath = tempFolder.newFolder().getAbsolutePath();
 		String checkpointPath = tempFolder.newFolder().toURI().toString();
-		RocksDBStateBackend backend = new RocksDBStateBackend(new FsStateBackend(checkpointPath));
+		RocksDBStateBackend backend = new RocksDBStateBackend(new FsStateBackend(checkpointPath), enableIncrementalCheckpointing);
 		backend.setDbStoragePath(dbPath);
 		return backend;
 	}
@@ -105,7 +117,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 		testStreamFactory = new BlockerCheckpointStreamFactory(1024 * 1024);
 		testStreamFactory.setBlockerLatch(blocker);
 		testStreamFactory.setWaiterLatch(waiter);
-		testStreamFactory.setAfterNumberInvocations(100);
+		testStreamFactory.setAfterNumberInvocations(10);
 
 		RocksDBStateBackend backend = getStateBackend();
 		Environment env = new DummyEnvironment("TestTask", 1, 0);
@@ -119,6 +131,8 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 				new KeyGroupRange(0, 1),
 				mock(TaskKvStateRegistry.class));
 
+		keyedStateBackend.restore(null);
+
 		testState1 = keyedStateBackend.getPartitionedState(
 				VoidNamespace.INSTANCE,
 				VoidNamespaceSerializer.INSTANCE,
@@ -178,8 +192,10 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 
 		RocksDB spyDB = keyedStateBackend.db;
 
-		verify(spyDB, times(1)).getSnapshot();
-		verify(spyDB, times(0)).releaseSnapshot(any(Snapshot.class));
+		if (!enableIncrementalCheckpointing) {
+			verify(spyDB, times(1)).getSnapshot();
+			verify(spyDB, times(0)).releaseSnapshot(any(Snapshot.class));
+		}
 
 		this.keyedStateBackend.dispose();
 		verify(spyDB, times(1)).close();
@@ -216,8 +232,10 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 
 		RocksDB spyDB = keyedStateBackend.db;
 
-		verify(spyDB, times(1)).getSnapshot();
-		verify(spyDB, times(0)).releaseSnapshot(any(Snapshot.class));
+		if (!enableIncrementalCheckpointing) {
+			verify(spyDB, times(1)).getSnapshot();
+			verify(spyDB, times(0)).releaseSnapshot(any(Snapshot.class));
+		}
 
 		this.keyedStateBackend.dispose();
 		verify(spyDB, times(1)).close();
@@ -319,7 +337,6 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 		backend.setCurrentKey(1);
 		state.update("Hello");
 
-
 		Collection<File> allFilesInDbDir =
 				FileUtils.listFilesAndDirs(new File(dbPath), new AcceptAllFilter(), new AcceptAllFilter());
 
@@ -356,8 +373,10 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 		assertNotNull(null, keyedStateBackend.db);
 		RocksDB spyDB = keyedStateBackend.db;
 
-		verify(spyDB, times(1)).getSnapshot();
-		verify(spyDB, times(1)).releaseSnapshot(any(Snapshot.class));
+		if (!enableIncrementalCheckpointing) {
+			verify(spyDB, times(1)).getSnapshot();
+			verify(spyDB, times(1)).releaseSnapshot(any(Snapshot.class));
+		}
 
 		keyedStateBackend.dispose();
 		verify(spyDB, times(1)).close();

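The snapshot-related verifications are now conditional because the incremental path does not read through a RocksDB Snapshot at all; per the new org.rocksdb.Checkpoint import, it materializes a consistent local copy of the database (largely hard links to immutable SST files) and uploads from there. A hedged sketch of that RocksDB API, with a hypothetical target path; this is not the Flink code path itself:

    import org.rocksdb.Checkpoint;
    import org.rocksdb.RocksDB;
    import org.rocksdb.RocksDBException;

    final class LocalCheckpointSketch {

        static void materialize(RocksDB db, long checkpointId) throws RocksDBException {
            Checkpoint checkpoint = Checkpoint.create(db);
            try {
                // Writes a consistent copy of the DB into the given directory;
                // unchanged SST files are hard-linked rather than copied.
                checkpoint.createCheckpoint("/tmp/chk-" + checkpointId);
            } finally {
                checkpoint.close();
            }
        }
    }
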
http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
index 121ac57..a77baf3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -152,12 +152,24 @@ public class SubtaskState implements CompositeStateHandle {
 
 	@Override
 	public void registerSharedStates(SharedStateRegistry sharedStateRegistry) {
-		// No shared states
+		if (managedKeyedState != null) {
+			managedKeyedState.registerSharedStates(sharedStateRegistry);
+		}
+
+		if (rawKeyedState != null) {
+			rawKeyedState.registerSharedStates(sharedStateRegistry);
+		}
 	}
 
 	@Override
 	public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) {
-		// No shared states
+		if (managedKeyedState != null) {
+			managedKeyedState.unregisterSharedStates(sharedStateRegistry);
+		}
+
+		if (rawKeyedState != null) {
+			rawKeyedState.unregisterSharedStates(sharedStateRegistry);
+		}
 	}
 
 	@Override

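SubtaskState participates in shared-state management now because incremental keyed handles can reference files shared across checkpoints, so registration must cascade from the composite down to its children. The null-safe forwarding above amounts to this generic pattern (a sketch, not Flink code):

    // Forwards registration to any non-null composite children; incremental
    // RocksDB handles use this to reference shared SST files in the registry.
    static void registerAll(SharedStateRegistry registry, CompositeStateHandle... handles) {
        for (CompositeStateHandle handle : handles) {
            if (handle != null) {
                handle.registerSharedStates(registry);
            }
        }
    }

An analogous loop with unregisterSharedStates covers the release side.
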
http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
index e86f1f8..61f397c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -212,6 +212,16 @@ public abstract class AbstractKeyedStateBackend<K>
 			MapStateDescriptor<UK, UV> stateDesc) throws Exception;
 
 	/**
+	 * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager.
+	 *
+	 * @param checkpointId The ID of the checkpoint that has been completed.
+	 *
+	 * @throws Exception Exceptions during checkpoint acknowledgement may be forwarded and will cause
+	 *                   the program to fail and enter recovery.
+	 */
+	public abstract void notifyOfCompletedCheckpoint(long checkpointId) throws Exception;
+
+	/**
 	 * @see KeyedStateBackend
 	 */
 	@Override

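The new hook gives backends a callback once the JobManager has acknowledged a checkpoint. Full-snapshot backends can ignore it (see HeapKeyedStateBackend below), while the incremental RocksDB backend needs it to know which uploaded files may serve as the baseline for future deltas. A hedged sketch of such an override; the field name and pruning rule are illustrative assumptions, not the actual implementation:

    // Hypothetical bookkeeping: the SST files uploaded for each checkpoint id.
    private final TreeMap<Long, Set<String>> uploadedSstFiles = new TreeMap<>();

    @Override
    public void notifyOfCompletedCheckpoint(long completedCheckpointId) {
        synchronized (uploadedSstFiles) {
            // Checkpoints older than the confirmed one are no longer needed
            // as an upload baseline, so their bookkeeping can be dropped.
            uploadedSstFiles.headMap(completedCheckpointId).clear();
        }
    }
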
http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
index bad7fd4..8280460 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
@@ -93,6 +93,16 @@ public class KeyGroupsStateHandle implements StreamStateHandle, KeyedStateHandle
 	}
 
 	@Override
+	public void registerSharedStates(SharedStateRegistry stateRegistry) {
+		// No shared states
+	}
+
+	@Override
+	public void unregisterSharedStates(SharedStateRegistry stateRegistry) {
+		// No shared states
+	}
+
+	@Override
 	public void discardState() throws Exception {
 		stateHandle.discardState();
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
index dc9c97d..704ec14 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
@@ -23,7 +23,7 @@ package org.apache.flink.runtime.state;
  * recovering from failures, the handle will be passed to all tasks whose key
  * group ranges overlap with it.
  */
-public interface KeyedStateHandle extends StateObject {
+public interface KeyedStateHandle extends CompositeStateHandle {
 
 	/**
 	 * Returns the range of the key groups contained in the state.

http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
index b250831..6f231e4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
@@ -33,6 +33,15 @@ public class StateUtil {
 	}
 
 	/**
+	 * Returns the size of a state object, or 0 if the given handle is null.
+	 *
+	 * @param handle The handle to the state object; may be null
+	 */
+	public static long getStateSize(StateObject handle) {
+		return handle == null ? 0 : handle.getStateSize();
+	}
+
+	/**
 	 * Iterates through the passed state handles and calls discardState() on each handle that is not null. All
 	 * occurring exceptions are suppressed and collected until the iteration is over and emitted as a single exception.
 	 *

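The helper keeps size aggregation null-tolerant for composites whose constituent handles may be absent, e.g. (sketch, using the SubtaskState field names from above):

    long totalSize = StateUtil.getStateSize(managedKeyedState)  // 0 for a null handle
                   + StateUtil.getStateSize(rawKeyedState);
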
http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 38817cd..ead89b3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -328,6 +328,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	@SuppressWarnings("deprecation")
 	@Override
 	public void restore(Collection<KeyedStateHandle> restoredState) throws Exception {
+		if (restoredState == null || restoredState.isEmpty()) {
+			return;
+		}
+
 		LOG.info("Initializing heap keyed state backend from snapshot.");
 
 		if (LOG.isDebugEnabled()) {
@@ -426,6 +430,11 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@Override
+	public void notifyOfCompletedCheckpoint(long checkpointId) {
+		// Nothing to do
+	}
+
+	@Override
 	public String toString() {
 		return "HeapKeyedStateBackend";
 	}

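Since StreamTask (below) now calls restore() unconditionally, every backend must treat the no-state case as a fresh start; the added guard implements that contract for both shapes it can take (sketch, backend variable and Collections import assumed):

    // Either call leaves the backend freshly initialized and empty:
    heapBackend.restore(null);
    heapBackend.restore(Collections.<KeyedStateHandle>emptyList());
    // Only a non-empty collection triggers actual state reconstruction.
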
http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index ccc1eae..60f9c81 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -133,7 +133,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
 			Environment env) throws Exception {
-		return getStateBackend().createKeyedStateBackend(
+
+		AbstractKeyedStateBackend<K> backend = getStateBackend().createKeyedStateBackend(
 				env,
 				new JobID(),
 				"test_op",
@@ -141,6 +142,10 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				numberOfKeyGroups,
 				keyGroupRange,
 				env.getTaskKvStateRegistry());
+
+		backend.restore(null);
+
+		return backend;
 	}
 
 	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyedStateHandle state) throws Exception {
@@ -2197,9 +2202,11 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 
 		Assert.assertNotNull(stateHandle);
 
-		backend = createKeyedBackend(IntSerializer.INSTANCE);
+		backend = null;
+
 		try {
-			backend.restore(Collections.singleton(stateHandle));
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, stateHandle);
+
 			InternalValueState<VoidNamespace, Integer> valueState = backend.createValueState(
 					VoidNamespaceSerializer.INSTANCE,
 					new ValueStateDescriptor<>("test", IntSerializer.INSTANCE));
@@ -2297,7 +2304,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	 * Returns the value by getting the serialized value and deserializing it
 	 * if it is not null.
 	 */
-	private static <V, K, N> V getSerializedValue(
+	protected static <V, K, N> V getSerializedValue(
 			InternalKvState<N> kvState,
 			K key,
 			TypeSerializer<K> keySerializer,

http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index 1850007..d45ad42 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -504,7 +504,11 @@ public abstract class AbstractStreamOperator<OUT>
 	}
 
 	@Override
-	public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {}
+	public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {
+		if (keyedStateBackend != null) {
+			keyedStateBackend.notifyOfCompletedCheckpoint(checkpointId);
+		}
+	}
 
 	/**
 	 * Returns a checkpoint stream factory for the provided options.

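With the base implementation no longer empty, operators that override this hook should delegate to super, or the keyed backend (and thus incremental checkpoint bookkeeping) will miss the confirmation. Sketch of a well-behaved override:

    @Override
    public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {
        super.notifyOfCompletedCheckpoint(checkpointId); // keep the keyed backend informed
        // ... operator-specific reaction to the completed checkpoint ...
    }
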
http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 57e43de..bc66751 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -765,9 +765,10 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		cancelables.registerClosable(keyedStateBackend);
 
 		// restore if we have some old state
-		if (null != restoreStateHandles && null != restoreStateHandles.getManagedKeyedState()) {
-			keyedStateBackend.restore(restoreStateHandles.getManagedKeyedState());
-		}
+		Collection<KeyedStateHandle> restoreKeyedStateHandles =
+			restoreStateHandles == null ? null : restoreStateHandles.getManagedKeyedState();
+
+		keyedStateBackend.restore(restoreKeyedStateHandles);
 
 		@SuppressWarnings("unchecked")
 		AbstractKeyedStateBackend<K> typedBackend = (AbstractKeyedStateBackend<K>) keyedStateBackend;

http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index d9c7387..c6d0bce 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -116,9 +116,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 							keyGroupRange,
 							mockTask.getEnvironment().getTaskKvStateRegistry());
 
-					if (restoredKeyedState != null) {
-						keyedStateBackend.restore(restoredKeyedState);
-					}
+					keyedStateBackend.restore(restoredKeyedState);
 
 					return keyedStateBackend;
 				}

http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
index 4761d70..517c82b 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
@@ -22,6 +22,8 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
 import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -35,11 +37,18 @@ import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
+import org.apache.flink.runtime.state.AbstractStateBackend;
+import org.apache.flink.runtime.state.filesystem.FsStateBackend;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 /**
  * A simple test that runs a streaming topology with checkpointing enabled.
@@ -50,15 +59,49 @@ import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunctio
  * It is designed to check partitioned states.
  */
 @SuppressWarnings("serial")
+@RunWith(Parameterized.class)
 public class PartitionedStateCheckpointingITCase extends StreamFaultToleranceTestBase {
 
+	private static final int MAX_MEM_STATE_SIZE = 10 * 1024 * 1024;
+
 	final long NUM_STRINGS = 10_000_000L;
 	final static int NUM_KEYS = 40;
 
+	@Parameterized.Parameters
+	public static Collection<AbstractStateBackend> parameters() throws IOException {
+		TemporaryFolder tempFolder = new TemporaryFolder();
+		tempFolder.create();
+
+		MemoryStateBackend syncMemBackend = new MemoryStateBackend(MAX_MEM_STATE_SIZE, false);
+		MemoryStateBackend asyncMemBackend = new MemoryStateBackend(MAX_MEM_STATE_SIZE, true);
+
+		FsStateBackend syncFsBackend = new FsStateBackend("file://" + tempFolder.newFolder().getAbsolutePath(), false);
+		FsStateBackend asyncFsBackend = new FsStateBackend("file://" + tempFolder.newFolder().getAbsolutePath(), true);
+
+		RocksDBStateBackend fullRocksDbBackend = new RocksDBStateBackend(new MemoryStateBackend(MAX_MEM_STATE_SIZE), false);
+		fullRocksDbBackend.setDbStoragePath(tempFolder.newFolder().getAbsolutePath());
+
+		RocksDBStateBackend incRocksDbBackend = new RocksDBStateBackend(new MemoryStateBackend(MAX_MEM_STATE_SIZE), true);
+		incRocksDbBackend.setDbStoragePath(tempFolder.newFolder().getAbsolutePath());
+
+		return Arrays.asList(
+			syncMemBackend,
+			asyncMemBackend,
+			syncFsBackend,
+			asyncFsBackend,
+			fullRocksDbBackend,
+			incRocksDbBackend);
+	}
+
+	@Parameterized.Parameter
+	public AbstractStateBackend stateBackend;
+
 	@Override
 	public void testProgram(StreamExecutionEnvironment env) {
 		assertTrue("Broken test setup", (NUM_STRINGS/2) % NUM_KEYS == 0);
 
+		env.setStateBackend(stateBackend);
+
 		DataStream<Integer> stream1 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));
 		DataStream<Integer> stream2 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));
 
@@ -163,6 +206,7 @@ public class PartitionedStateCheckpointingITCase extends StreamFaultToleranceTes
 
 		OnceFailingPartitionedSum(long numElements) {
 			this.numElements = numElements;
+			this.hasFailed = false;
 		}
 
 		@Override
@@ -181,6 +225,7 @@ public class PartitionedStateCheckpointingITCase extends StreamFaultToleranceTes
 		@Override
 		public Tuple2<Integer, Long> map(Integer value) throws Exception {
 			count++;
+
 			if (!hasFailed && count >= failurePos) {
 				hasFailed = true;
 				throw new Exception("Test Failure");

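The widened parameterization runs the fault-tolerance topology against full and incremental RocksDB checkpoints alike. For reference, a user job enables the incremental mode through the same constructor the test uses; the paths below are placeholders, and the snippet belongs in a method that may throw Exception:

    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    env.enableCheckpointing(5000);

    RocksDBStateBackend rocksDbBackend =
            new RocksDBStateBackend(new FsStateBackend("file:///tmp/checkpoints"), true);
    rocksDbBackend.setDbStoragePath("/tmp/rocksdb-working-dir");

    env.setStateBackend(rocksDbBackend);
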
http://git-wip-us.apache.org/repos/asf/flink/blob/6e94cf19/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java b/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java
index 3c86f90..05f72c2 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java
@@ -81,7 +81,7 @@ public final class KVStateRequestSerializerRocksDBTest {
 			super(jobId, operatorIdentifier, userCodeClassLoader,
 				instanceBasePath,
 				dbOptions, columnFamilyOptions, kvStateRegistry, keySerializer,
-				numberOfKeyGroups, keyGroupRange, executionConfig);
+				numberOfKeyGroups, keyGroupRange, executionConfig, false);
 		}
 
 		@Override
@@ -120,6 +120,7 @@ public final class KVStateRequestSerializerRocksDBTest {
 				1, new KeyGroupRange(0, 0),
 				new ExecutionConfig()
 			);
+		longHeapKeyedStateBackend.restore(null);
 		longHeapKeyedStateBackend.setCurrentKey(key);
 
 		final InternalListState<VoidNamespace, Long> listState = longHeapKeyedStateBackend
@@ -154,8 +155,9 @@ public final class KVStateRequestSerializerRocksDBTest {
 				mock(TaskKvStateRegistry.class),
 				LongSerializer.INSTANCE,
 				1, new KeyGroupRange(0, 0),
-				new ExecutionConfig()
-			);
+				new ExecutionConfig(),
+				false);
+		longHeapKeyedStateBackend.restore(null);
 		longHeapKeyedStateBackend.setCurrentKey(key);
 
 		final InternalMapState<VoidNamespace, Long, String> mapState = (InternalMapState<VoidNamespace, Long, String>)