You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sr...@apache.org on 2017/03/28 18:05:52 UTC

[2/2] flink git commit: [FLINK-6034] [checkpoints] Introduce KeyedStateHandle abstraction for the snapshots in keyed streams

[FLINK-6034] [checkpoints] Introduce KeyedStateHandle abstraction for the snapshots in keyed streams


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/cd552741
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/cd552741
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/cd552741

Branch: refs/heads/master
Commit: cd5527417a1cae57073a8855c6c3b88c88c780aa
Parents: 89866a5
Author: xiaogang.sxg <xi...@alibaba-inc.com>
Authored: Thu Mar 23 23:32:15 2017 +0800
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Tue Mar 28 20:05:28 2017 +0200

----------------------------------------------------------------------
 .../state/RocksDBKeyedStateBackend.java         | 46 ++++++++++----
 .../state/RocksDBAsyncSnapshotTest.java         |  3 +-
 .../state/RocksDBStateBackendTest.java          | 21 ++++---
 .../cep/operator/CEPMigration12to13Test.java    | 14 ++---
 .../apache/flink/migration/MigrationUtil.java   | 10 +--
 .../checkpoint/StateAssignmentOperation.java    | 41 ++++++------
 .../flink/runtime/checkpoint/SubtaskState.java  | 14 ++---
 .../savepoint/SavepointV1Serializer.java        | 42 +++++++------
 .../state/AbstractKeyedStateBackend.java        |  2 +-
 .../runtime/state/KeyGroupsStateHandle.java     | 39 ++++--------
 .../flink/runtime/state/KeyedStateHandle.java   | 40 ++++++++++++
 .../state/StateInitializationContextImpl.java   | 28 ++++++++-
 .../StateSnapshotContextSynchronousImpl.java    | 12 ++--
 .../flink/runtime/state/TaskStateHandles.java   | 16 ++---
 .../state/heap/HeapKeyedStateBackend.java       | 46 ++++++++++----
 .../checkpoint/CheckpointCoordinatorTest.java   | 29 +++++----
 .../checkpoint/CheckpointStateRestoreTest.java  |  3 +-
 .../savepoint/MigrationV0ToV1Test.java          | 14 ++++-
 .../KeyedStateCheckpointOutputStreamTest.java   |  4 +-
 .../runtime/state/StateBackendTestBase.java     | 66 ++++++++++----------
 ...pKeyedStateBackendSnapshotMigrationTest.java |  3 +-
 .../api/operators/AbstractStreamOperator.java   |  7 ++-
 .../api/operators/OperatorSnapshotResult.java   | 18 +++---
 .../runtime/tasks/OperatorStateHandles.java     | 14 ++---
 .../streaming/runtime/tasks/StreamTask.java     | 14 ++---
 .../operators/AbstractStreamOperatorTest.java   | 10 +--
 .../operators/OperatorSnapshotResultTest.java   | 10 +--
 .../StateInitializationContextImplTest.java     |  9 +--
 .../tasks/InterruptSensitiveRestoreTest.java    | 17 ++---
 .../streaming/runtime/tasks/StreamTaskTest.java | 14 ++---
 .../util/AbstractStreamOperatorTestHarness.java | 25 ++++----
 .../KeyedOneInputStreamOperatorTestHarness.java | 17 ++---
 .../KeyedTwoInputStreamOperatorTestHarness.java |  3 +-
 33 files changed, 389 insertions(+), 262 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 2ce527f..0407070 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -40,6 +40,7 @@ import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.migration.MigrationNamespaceSerializerProxy;
 import org.apache.flink.migration.MigrationUtil;
 import org.apache.flink.migration.contrib.streaming.state.RocksDBStateBackend;
+import org.apache.flink.migration.state.MigrationKeyGroupStateHandle;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.io.async.AbstractAsyncIOCallable;
 import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
@@ -52,6 +53,7 @@ import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.internal.InternalAggregatingState;
@@ -257,7 +259,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * @throws Exception
 	 */
 	@Override
-	public RunnableFuture<KeyGroupsStateHandle> snapshot(
+	public RunnableFuture<KeyedStateHandle> snapshot(
 			final long checkpointId,
 			final long timestamp,
 			final CheckpointStreamFactory streamFactory,
@@ -286,8 +288,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 
 		// implementation of the async IO operation, based on FutureTask
-		AbstractAsyncIOCallable<KeyGroupsStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
-				new AbstractAsyncIOCallable<KeyGroupsStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream>() {
+		AbstractAsyncIOCallable<KeyedStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
+				new AbstractAsyncIOCallable<KeyedStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream>() {
 
 					@Override
 					public CheckpointStreamFactory.CheckpointStateOutputStream openIOHandle() throws Exception {
@@ -620,7 +622,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@Override
-	public void restore(Collection<KeyGroupsStateHandle> restoreState) throws Exception {
+	public void restore(Collection<KeyedStateHandle> restoreState) throws Exception {
 		LOG.info("Initializing RocksDB keyed state backend from snapshot.");
 
 		if (LOG.isDebugEnabled()) {
@@ -669,17 +671,23 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		/**
 		 * Restores all key-groups data that is referenced by the passed state handles.
 		 *
-		 * @param keyGroupsStateHandles List of all key groups state handles that shall be restored.
+		 * @param keyedStateHandles List of all key groups state handles that shall be restored.
 		 * @throws IOException
 		 * @throws ClassNotFoundException
 		 * @throws RocksDBException
 		 */
-		public void doRestore(Collection<KeyGroupsStateHandle> keyGroupsStateHandles)
+		public void doRestore(Collection<KeyedStateHandle> keyedStateHandles)
 				throws IOException, ClassNotFoundException, RocksDBException {
 
-			for (KeyGroupsStateHandle keyGroupsStateHandle : keyGroupsStateHandles) {
-				if (keyGroupsStateHandle != null) {
-					this.currentKeyGroupsStateHandle = keyGroupsStateHandle;
+			for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
+				if (keyedStateHandle != null) {
+
+					if (!(keyedStateHandle instanceof KeyGroupsStateHandle)) {
+						throw new IllegalStateException("Unexpected state handle type, " +
+								"expected: " + KeyGroupsStateHandle.class +
+								", but found: " + keyedStateHandle.getClass());
+					}
+					this.currentKeyGroupsStateHandle = (KeyGroupsStateHandle) keyedStateHandle;
 					restoreKeyGroupsInStateHandle();
 				}
 			}
@@ -761,6 +769,12 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		private void restoreKVStateData() throws IOException, RocksDBException {
 			//for all key-groups in the current state handle...
 			for (Tuple2<Integer, Long> keyGroupOffset : currentKeyGroupsStateHandle.getGroupRangeOffsets()) {
+				int keyGroup = keyGroupOffset.f0;
+
+				// Check that restored key groups all belong to the backend
+				Preconditions.checkState(rocksDBKeyedStateBackend.getKeyGroupRange().contains(keyGroup),
+					"The key group must belong to the backend");
+
 				long offset = keyGroupOffset.f1;
 				//not empty key-group?
 				if (0L != offset) {
@@ -1143,15 +1157,25 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * For backwards compatibility, remove again later!
 	 */
 	@Deprecated
-	private void restoreOldSavepointKeyedState(Collection<KeyGroupsStateHandle> restoreState) throws Exception {
+	private void restoreOldSavepointKeyedState(Collection<KeyedStateHandle> restoreState) throws Exception {
 
 		if (restoreState.isEmpty()) {
 			return;
 		}
 
 		Preconditions.checkState(1 == restoreState.size(), "Only one element expected here.");
+
+		KeyedStateHandle keyedStateHandle = restoreState.iterator().next();
+		if (!(keyedStateHandle instanceof MigrationKeyGroupStateHandle)) {
+			throw new IllegalStateException("Unexpected state handle type, " +
+					"expected: " + MigrationKeyGroupStateHandle.class +
+					", but found: " + keyedStateHandle.getClass());
+		}
+
+		MigrationKeyGroupStateHandle keyGroupStateHandle = (MigrationKeyGroupStateHandle) keyedStateHandle;
+
 		HashMap<String, RocksDBStateBackend.FinalFullyAsyncSnapshot> namedStates;
-		try (FSDataInputStream inputStream = restoreState.iterator().next().openInputStream()) {
+		try (FSDataInputStream inputStream = keyGroupStateHandle.openInputStream()) {
 			namedStates = InstantiationUtil.deserializeObject(inputStream, userCodeClassLoader);
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
index 90de7a6..ffe2ce2 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
@@ -42,6 +42,7 @@ import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
@@ -343,7 +344,7 @@ public class RocksDBAsyncSnapshotTest {
 			StringSerializer.INSTANCE,
 			new ValueStateDescriptor<>("foobar", String.class));
 
-		RunnableFuture<KeyGroupsStateHandle> snapshotFuture = keyedStateBackend.snapshot(
+		RunnableFuture<KeyedStateHandle> snapshotFuture = keyedStateBackend.snapshot(
 			checkpointId, timestamp, checkpointStreamFactory, CheckpointOptions.forFullCheckpoint());
 
 		try {

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
index 708613b..d95a9b4 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.StateBackendTestBase;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
@@ -172,7 +173,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testRunningSnapshotAfterBackendClosed() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
 			CheckpointOptions.forFullCheckpoint());
 
 		RocksDB spyDB = keyedStateBackend.db;
@@ -210,7 +211,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testReleasingSnapshotAfterBackendClosed() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
 			CheckpointOptions.forFullCheckpoint());
 
 		RocksDB spyDB = keyedStateBackend.db;
@@ -239,7 +240,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testDismissingSnapshot() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
 		snapshot.cancel(true);
 		verifyRocksObjectsReleased();
 	}
@@ -247,7 +248,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testDismissingSnapshotNotRunnable() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
 		snapshot.cancel(true);
 		Thread asyncSnapshotThread = new Thread(snapshot);
 		asyncSnapshotThread.start();
@@ -264,7 +265,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testCompletingSnapshot() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
 		Thread asyncSnapshotThread = new Thread(snapshot);
 		asyncSnapshotThread.start();
 		waiter.await(); // wait for snapshot to run
@@ -272,10 +273,10 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 		runStateUpdates();
 		blocker.trigger(); // allow checkpointing to start writing
 		waiter.await(); // wait for snapshot stream writing to run
-		KeyGroupsStateHandle keyGroupsStateHandle = snapshot.get();
-		assertNotNull(keyGroupsStateHandle);
-		assertTrue(keyGroupsStateHandle.getStateSize() > 0);
-		assertEquals(2, keyGroupsStateHandle.getNumberOfKeyGroups());
+		KeyedStateHandle keyedStateHandle = snapshot.get();
+		assertNotNull(keyedStateHandle);
+		assertTrue(keyedStateHandle.getStateSize() > 0);
+		assertEquals(2, keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups());
 		assertTrue(testStreamFactory.getLastCreatedStream().isClosed());
 		asyncSnapshotThread.join();
 		verifyRocksObjectsReleased();
@@ -284,7 +285,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testCancelRunningSnapshot() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
 		Thread asyncSnapshotThread = new Thread(snapshot);
 		asyncSnapshotThread.start();
 		waiter.await(); // wait for snapshot to run

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
index f230bbc..dbe4230 100644
--- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
+++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
@@ -26,7 +26,7 @@ import org.apache.flink.cep.nfa.NFA;
 import org.apache.flink.cep.nfa.compiler.NFACompiler;
 import org.apache.flink.cep.pattern.Pattern;
 import org.apache.flink.cep.pattern.conditions.SimpleCondition;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -128,8 +128,8 @@ public class CEPMigration12to13Test {
 		final OperatorStateHandles snapshot = new OperatorStateHandles(
 			(int) ois.readObject(),
 			(StreamStateHandle) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject()
 		);
@@ -243,8 +243,8 @@ public class CEPMigration12to13Test {
 		final OperatorStateHandles snapshot = new OperatorStateHandles(
 			(int) ois.readObject(),
 			(StreamStateHandle) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject()
 		);
@@ -363,8 +363,8 @@ public class CEPMigration12to13Test {
 		final OperatorStateHandles snapshot = new OperatorStateHandles(
 			(int) ois.readObject(),
 			(StreamStateHandle) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject()
 		);

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java b/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
index 9427f72..a4e3a2e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
@@ -19,17 +19,17 @@
 package org.apache.flink.migration;
 
 import org.apache.flink.migration.state.MigrationKeyGroupStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 
 import java.util.Collection;
 
 public class MigrationUtil {
 
 	@SuppressWarnings("deprecation")
-	public static boolean isOldSavepointKeyedState(Collection<KeyGroupsStateHandle> keyGroupsStateHandles) {
-		return (keyGroupsStateHandles != null)
-				&& (keyGroupsStateHandles.size() == 1)
-				&& (keyGroupsStateHandles.iterator().next() instanceof MigrationKeyGroupStateHandle);
+	public static boolean isOldSavepointKeyedState(Collection<KeyedStateHandle> keyedStateHandles) {
+		return (keyedStateHandles != null)
+				&& (keyedStateHandles.size() == 1)
+				&& (keyedStateHandles.iterator().next() instanceof MigrationKeyGroupStateHandle);
 	}
 
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 3fda430..ac70e1a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
@@ -160,8 +161,8 @@ public class StateAssignmentOperation {
 		@SuppressWarnings("unchecked")
 		List<OperatorStateHandle>[] parallelOpStatesStream = new List[chainLength];
 
-		List<KeyGroupsStateHandle> parallelKeyedStatesBackend = new ArrayList<>(oldParallelism);
-		List<KeyGroupsStateHandle> parallelKeyedStateStream = new ArrayList<>(oldParallelism);
+		List<KeyedStateHandle> parallelKeyedStatesBackend = new ArrayList<>(oldParallelism);
+		List<KeyedStateHandle> parallelKeyedStateStream = new ArrayList<>(oldParallelism);
 
 		for (int p = 0; p < oldParallelism; ++p) {
 			SubtaskState subtaskState = taskState.getState(p);
@@ -173,12 +174,12 @@ public class StateAssignmentOperation {
 				collectParallelStatesByChainOperator(
 						parallelOpStatesStream, subtaskState.getRawOperatorState());
 
-				KeyGroupsStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
+				KeyedStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
 				if (null != keyedStateBackend) {
 					parallelKeyedStatesBackend.add(keyedStateBackend);
 				}
 
-				KeyGroupsStateHandle keyedStateStream = subtaskState.getRawKeyedState();
+				KeyedStateHandle keyedStateStream = subtaskState.getRawKeyedState();
 				if (null != keyedStateStream) {
 					parallelKeyedStateStream.add(keyedStateStream);
 				}
@@ -252,13 +253,13 @@ public class StateAssignmentOperation {
 					.getTaskVertices()[subTaskIdx]
 					.getCurrentExecutionAttempt();
 
-			List<KeyGroupsStateHandle> newKeyedStatesBackend;
-			List<KeyGroupsStateHandle> newKeyedStateStream;
+			List<KeyedStateHandle> newKeyedStatesBackend;
+			List<KeyedStateHandle> newKeyedStateStream;
 			if (oldParallelism == newParallelism) {
 				SubtaskState subtaskState = taskState.getState(subTaskIdx);
 				if (subtaskState != null) {
-					KeyGroupsStateHandle oldKeyedStatesBackend = subtaskState.getManagedKeyedState();
-					KeyGroupsStateHandle oldKeyedStatesStream = subtaskState.getRawKeyedState();
+					KeyedStateHandle oldKeyedStatesBackend = subtaskState.getManagedKeyedState();
+					KeyedStateHandle oldKeyedStatesStream = subtaskState.getRawKeyedState();
 					newKeyedStatesBackend = oldKeyedStatesBackend != null ? Collections.singletonList(
 							oldKeyedStatesBackend) : null;
 					newKeyedStateStream = oldKeyedStatesStream != null ? Collections.singletonList(
@@ -269,8 +270,8 @@ public class StateAssignmentOperation {
 				}
 			} else {
 				KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(subTaskIdx);
-				newKeyedStatesBackend = getKeyGroupsStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
-				newKeyedStateStream = getKeyGroupsStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
+				newKeyedStatesBackend = getKeyedStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
+				newKeyedStateStream = getKeyedStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
 			}
 
 			TaskStateHandles taskStateHandles = new TaskStateHandles(
@@ -290,19 +291,21 @@ public class StateAssignmentOperation {
 	 * <p>
 	 * <p>This is publicly visible to be used in tests.
 	 */
-	public static List<KeyGroupsStateHandle> getKeyGroupsStateHandles(
-			Collection<KeyGroupsStateHandle> allKeyGroupsHandles,
-			KeyGroupRange subtaskKeyGroupIds) {
+	public static List<KeyedStateHandle> getKeyedStateHandles(
+			Collection<? extends KeyedStateHandle> keyedStateHandles,
+			KeyGroupRange subtaskKeyGroupRange) {
 
-		List<KeyGroupsStateHandle> subtaskKeyGroupStates = new ArrayList<>();
+		List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>();
 
-		for (KeyGroupsStateHandle storedKeyGroup : allKeyGroupsHandles) {
-			KeyGroupsStateHandle intersection = storedKeyGroup.getKeyGroupIntersection(subtaskKeyGroupIds);
-			if (intersection.getNumberOfKeyGroups() > 0) {
-				subtaskKeyGroupStates.add(intersection);
+		for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
+			KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange);
+
+			if (intersectedKeyedStateHandle != null) {
+				subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
 			}
 		}
-		return subtaskKeyGroupStates;
+
+		return subtaskKeyedStateHandles;
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
index 1393e32..9e195b1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -19,7 +19,7 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
@@ -56,12 +56,12 @@ public class SubtaskState implements StateObject {
 	/**
 	 * Snapshot from {@link org.apache.flink.runtime.state.KeyedStateBackend}.
 	 */
-	private final KeyGroupsStateHandle managedKeyedState;
+	private final KeyedStateHandle managedKeyedState;
 
 	/**
 	 * Snapshot written using {@link org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream}.
 	 */
-	private final KeyGroupsStateHandle rawKeyedState;
+	private final KeyedStateHandle rawKeyedState;
 
 	/**
 	 * The state size. This is also part of the deserialized state handle.
@@ -74,8 +74,8 @@ public class SubtaskState implements StateObject {
 			ChainedStateHandle<StreamStateHandle> legacyOperatorState,
 			ChainedStateHandle<OperatorStateHandle> managedOperatorState,
 			ChainedStateHandle<OperatorStateHandle> rawOperatorState,
-			KeyGroupsStateHandle managedKeyedState,
-			KeyGroupsStateHandle rawKeyedState) {
+			KeyedStateHandle managedKeyedState,
+			KeyedStateHandle rawKeyedState) {
 
 		this.legacyOperatorState = checkNotNull(legacyOperatorState, "State");
 		this.managedOperatorState = managedOperatorState;
@@ -114,11 +114,11 @@ public class SubtaskState implements StateObject {
 		return rawOperatorState;
 	}
 
-	public KeyGroupsStateHandle getManagedKeyedState() {
+	public KeyedStateHandle getManagedKeyedState() {
 		return managedKeyedState;
 	}
 
-	public KeyGroupsStateHandle getRawKeyedState() {
+	public KeyedStateHandle getRawKeyedState() {
 		return rawKeyedState;
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
index ba1949a..44461d8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
@@ -155,11 +156,11 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 			serializeOperatorStateHandle(stateHandle, dos);
 		}
 
-		KeyGroupsStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
-		serializeKeyGroupStateHandle(keyedStateBackend, dos);
+		KeyedStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
+		serializeKeyedStateHandle(keyedStateBackend, dos);
 
-		KeyGroupsStateHandle keyedStateStream = subtaskState.getRawKeyedState();
-		serializeKeyGroupStateHandle(keyedStateStream, dos);
+		KeyedStateHandle keyedStateStream = subtaskState.getRawKeyedState();
+		serializeKeyedStateHandle(keyedStateStream, dos);
 	}
 
 	private static SubtaskState deserializeSubtaskState(DataInputStream dis) throws IOException {
@@ -188,9 +189,9 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 			operatorStateStream.add(streamStateHandle);
 		}
 
-		KeyGroupsStateHandle keyedStateBackend = deserializeKeyGroupStateHandle(dis);
+		KeyedStateHandle keyedStateBackend = deserializeKeyedStateHandle(dis);
 
-		KeyGroupsStateHandle keyedStateStream = deserializeKeyGroupStateHandle(dis);
+		KeyedStateHandle keyedStateStream = deserializeKeyedStateHandle(dis);
 
 		ChainedStateHandle<StreamStateHandle> nonPartitionableStateChain =
 				new ChainedStateHandle<>(nonPartitionableState);
@@ -209,23 +210,27 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 				keyedStateStream);
 	}
 
-	private static void serializeKeyGroupStateHandle(
-			KeyGroupsStateHandle stateHandle, DataOutputStream dos) throws IOException {
+	private static void serializeKeyedStateHandle(
+			KeyedStateHandle stateHandle, DataOutputStream dos) throws IOException {
+
+		if (stateHandle == null) {
+			dos.writeByte(NULL_HANDLE);
+		} else if (stateHandle instanceof KeyGroupsStateHandle) {
+			KeyGroupsStateHandle keyGroupsStateHandle = (KeyGroupsStateHandle) stateHandle;
 
-		if (stateHandle != null) {
 			dos.writeByte(KEY_GROUPS_HANDLE);
-			dos.writeInt(stateHandle.getGroupRangeOffsets().getKeyGroupRange().getStartKeyGroup());
-			dos.writeInt(stateHandle.getNumberOfKeyGroups());
-			for (int keyGroup : stateHandle.keyGroups()) {
-				dos.writeLong(stateHandle.getOffsetForKeyGroup(keyGroup));
+			dos.writeInt(keyGroupsStateHandle.getKeyGroupRange().getStartKeyGroup());
+			dos.writeInt(keyGroupsStateHandle.getKeyGroupRange().getNumberOfKeyGroups());
+			for (int keyGroup : keyGroupsStateHandle.getKeyGroupRange()) {
+				dos.writeLong(keyGroupsStateHandle.getOffsetForKeyGroup(keyGroup));
 			}
-			serializeStreamStateHandle(stateHandle.getDelegateStateHandle(), dos);
+			serializeStreamStateHandle(keyGroupsStateHandle.getDelegateStateHandle(), dos);
 		} else {
-			dos.writeByte(NULL_HANDLE);
+			throw new IllegalStateException("Unknown KeyedStateHandle type: " + stateHandle.getClass());
 		}
 	}
 
-	private static KeyGroupsStateHandle deserializeKeyGroupStateHandle(DataInputStream dis) throws IOException {
+	private static KeyedStateHandle deserializeKeyedStateHandle(DataInputStream dis) throws IOException {
 		final int type = dis.readByte();
 		if (NULL_HANDLE == type) {
 			return null;
@@ -237,11 +242,12 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 			for (int i = 0; i < numKeyGroups; ++i) {
 				offsets[i] = dis.readLong();
 			}
-			KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, offsets);
+			KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(
+				keyGroupRange, offsets);
 			StreamStateHandle stateHandle = deserializeStreamStateHandle(dis);
 			return new KeyGroupsStateHandle(keyGroupRangeOffsets, stateHandle);
 		} else {
-			throw new IllegalStateException("Reading invalid KeyGroupsStateHandle, type: " + type);
+			throw new IllegalStateException("Reading invalid KeyedStateHandle, type: " + type);
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
index e6e7b23..e86f1f8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -61,7 +61,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  * @param <K> Type of the key by which state is keyed.
  */
 public abstract class AbstractKeyedStateBackend<K>
-		implements KeyedStateBackend<K>, Snapshotable<KeyGroupsStateHandle>, Closeable {
+		implements KeyedStateBackend<K>, Snapshotable<KeyedStateHandle>, Closeable {
 
 	/** {@link TypeSerializer} for our key. */
 	protected final TypeSerializer<K> keySerializer;

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
index b454e42..bad7fd4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
@@ -29,7 +29,7 @@ import java.io.IOException;
  * consists of a range of key group snapshots. A key group is subset of the available
  * key space. The key groups are identified by their key group indices.
  */
-public class KeyGroupsStateHandle implements StreamStateHandle {
+public class KeyGroupsStateHandle implements StreamStateHandle, KeyedStateHandle {
 
 	private static final long serialVersionUID = -8070326169926626355L;
 
@@ -54,20 +54,18 @@ public class KeyGroupsStateHandle implements StreamStateHandle {
 
 	/**
 	 *
-	 * @return iterable over the key-group range for the key-group state referenced by this handle
+	 * @return the internal key-group range to offsets metadata
 	 */
-	public Iterable<Integer> keyGroups() {
-		return groupRangeOffsets.getKeyGroupRange();
+	public KeyGroupRangeOffsets getGroupRangeOffsets() {
+		return groupRangeOffsets;
 	}
 
-
 	/**
 	 *
-	 * @param keyGroupId the id of a key-group
-	 * @return true if the provided key-group id is contained in the key-group range of this handle
+	 * @return The handle to the actual states
 	 */
-	public boolean containsKeyGroup(int keyGroupId) {
-		return groupRangeOffsets.getKeyGroupRange().contains(keyGroupId);
+	public StreamStateHandle getDelegateStateHandle() {
+		return stateHandle;
 	}
 
 	/**
@@ -85,24 +83,13 @@ public class KeyGroupsStateHandle implements StreamStateHandle {
 	 * @return key-group state over a range that is the intersection between this handle's key-group range and the
 	 *          provided key-group range.
 	 */
-	public KeyGroupsStateHandle getKeyGroupIntersection(KeyGroupRange keyGroupRange) {
+	public KeyGroupsStateHandle getIntersection(KeyGroupRange keyGroupRange) {
 		return new KeyGroupsStateHandle(groupRangeOffsets.getIntersection(keyGroupRange), stateHandle);
 	}
 
-	/**
-	 *
-	 * @return the internal key-group range to offsets metadata
-	 */
-	public KeyGroupRangeOffsets getGroupRangeOffsets() {
-		return groupRangeOffsets;
-	}
-
-	/**
-	 *
-	 * @return number of key-groups in the key-group range of this handle
-	 */
-	public int getNumberOfKeyGroups() {
-		return groupRangeOffsets.getKeyGroupRange().getNumberOfKeyGroups();
+	@Override
+	public KeyGroupRange getKeyGroupRange() {
+		return groupRangeOffsets.getKeyGroupRange();
 	}
 
 	@Override
@@ -120,10 +107,6 @@ public class KeyGroupsStateHandle implements StreamStateHandle {
 		return stateHandle.openInputStream();
 	}
 
-	public StreamStateHandle getDelegateStateHandle() {
-		return stateHandle;
-	}
-
 	@Override
 	public boolean equals(Object o) {
 		if (this == o) {

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
new file mode 100644
index 0000000..dc9c97d
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+/**
+ * Base interface for handles to the checkpointed state of keyed streams. When
+ * recovering from failures, the handle will be passed to all tasks whose key
+ * group ranges overlap with it.
+ */
+public interface KeyedStateHandle extends StateObject {
+
+	/**
+	 * Returns the range of the key groups contained in the state.
+	 */
+	KeyGroupRange getKeyGroupRange();
+
+	/**
+	 * Returns a handle for the intersection of this handle's key-group range with the given range.
+	 *
+	 * @param keyGroupRange The key group range to intersect with
+	 * @return A state handle over the intersection of the two key-group ranges
+	 */
+	KeyedStateHandle getIntersection(KeyGroupRange keyGroupRange);
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
index 886d214..d82af72 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
@@ -27,9 +27,11 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
+import java.util.List;
 import java.util.NoSuchElementException;
 
 /**
@@ -55,7 +57,7 @@ public class StateInitializationContextImpl implements StateInitializationContex
 			boolean restored,
 			OperatorStateStore operatorStateStore,
 			KeyedStateStore keyedStateStore,
-			Collection<KeyGroupsStateHandle> keyGroupsStateHandles,
+			Collection<KeyedStateHandle> keyedStateHandles,
 			Collection<OperatorStateHandle> operatorStateHandles,
 			CloseableRegistry closableRegistry) {
 
@@ -64,7 +66,7 @@ public class StateInitializationContextImpl implements StateInitializationContex
 		this.operatorStateStore = operatorStateStore;
 		this.keyedStateStore = keyedStateStore;
 		this.operatorStateHandles = operatorStateHandles;
-		this.keyGroupsStateHandles = keyGroupsStateHandles;
+		this.keyGroupsStateHandles = transform(keyedStateHandles);
 
 		this.keyedStateIterable = keyGroupsStateHandles == null ?
 				null
@@ -136,6 +138,26 @@ public class StateInitializationContextImpl implements StateInitializationContex
 		IOUtils.closeQuietly(closableRegistry);
 	}
 
+	private static Collection<KeyGroupsStateHandle> transform(Collection<KeyedStateHandle> keyedStateHandles) {
+		if (keyedStateHandles == null) {
+			return null;
+		}
+
+		List<KeyGroupsStateHandle> keyGroupsStateHandles = new ArrayList<>();
+
+		for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
+			if (! (keyedStateHandle instanceof KeyGroupsStateHandle)) {
+				throw new IllegalStateException("Unexpected state handle type, " +
+					"expected: " + KeyGroupsStateHandle.class +
+					", but found: " + keyedStateHandle.getClass() + ".");
+			}
+
+			keyGroupsStateHandles.add((KeyGroupsStateHandle) keyedStateHandle);
+		}
+
+		return keyGroupsStateHandles;
+	}
+
 	private static class KeyGroupStreamIterator
 			extends AbstractStateStreamIterator<KeyGroupStatePartitionStreamProvider, KeyGroupsStateHandle> {
 
@@ -159,7 +181,7 @@ public class StateInitializationContextImpl implements StateInitializationContex
 
 			while (stateHandleIterator.hasNext()) {
 				currentStateHandle = stateHandleIterator.next();
-				if (currentStateHandle.getNumberOfKeyGroups() > 0) {
+				if (currentStateHandle.getKeyGroupRange().getNumberOfKeyGroups() > 0) {
 					currentOffsetsIterator = currentStateHandle.getGroupRangeOffsets().iterator();
 
 					return true;

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
index 96edccb..5db0138 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
@@ -109,15 +109,17 @@ public class StateSnapshotContextSynchronousImpl implements StateSnapshotContext
 		return operatorStateCheckpointOutputStream;
 	}
 
-	public RunnableFuture<KeyGroupsStateHandle> getKeyedStateStreamFuture() throws IOException {
-		return closeAndUnregisterStreamToObtainStateHandle(keyedStateCheckpointOutputStream);
+	public RunnableFuture<KeyedStateHandle> getKeyedStateStreamFuture() throws IOException {
+		KeyGroupsStateHandle keyGroupsStateHandle = closeAndUnregisterStreamToObtainStateHandle(keyedStateCheckpointOutputStream);
+		return new DoneFuture<KeyedStateHandle>(keyGroupsStateHandle);
 	}
 
 	public RunnableFuture<OperatorStateHandle> getOperatorStateStreamFuture() throws IOException {
-		return closeAndUnregisterStreamToObtainStateHandle(operatorStateCheckpointOutputStream);
+		OperatorStateHandle operatorStateHandle = closeAndUnregisterStreamToObtainStateHandle(operatorStateCheckpointOutputStream);
+		return new DoneFuture<>(operatorStateHandle);
 	}
 
-	private <T extends StreamStateHandle> RunnableFuture<T> closeAndUnregisterStreamToObtainStateHandle(
+	private <T extends StreamStateHandle> T closeAndUnregisterStreamToObtainStateHandle(
 			NonClosingCheckpointOutputStream<T> stream) throws IOException {
 		if (null == stream) {
 			return null;
@@ -126,7 +128,7 @@ public class StateSnapshotContextSynchronousImpl implements StateSnapshotContext
 		closableRegistry.unregisterClosable(stream.getDelegate());
 
 		// for now we only support synchronous writing
-		return new DoneFuture<>(stream.closeAndGetHandle());
+		return stream.closeAndGetHandle();
 	}
 
 	private <T extends StreamStateHandle> void closeAndUnregisterStream(NonClosingCheckpointOutputStream<T> stream) throws IOException {

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
index 417a9dd..450413a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
@@ -40,10 +40,10 @@ public class TaskStateHandles implements Serializable {
 	private final ChainedStateHandle<StreamStateHandle> legacyOperatorState;
 
 	/** Collection of handles which represent the managed keyed state of the head operator */
-	private final Collection<KeyGroupsStateHandle> managedKeyedState;
+	private final Collection<KeyedStateHandle> managedKeyedState;
 
 	/** Collection of handles which represent the raw/streamed keyed state of the head operator */
-	private final Collection<KeyGroupsStateHandle> rawKeyedState;
+	private final Collection<KeyedStateHandle> rawKeyedState;
 
 	/** Outer list represents the operator chain, each collection holds handles for managed state of a single operator */
 	private final List<Collection<OperatorStateHandle>> managedOperatorState;
@@ -67,8 +67,8 @@ public class TaskStateHandles implements Serializable {
 			ChainedStateHandle<StreamStateHandle> legacyOperatorState,
 			List<Collection<OperatorStateHandle>> managedOperatorState,
 			List<Collection<OperatorStateHandle>> rawOperatorState,
-			Collection<KeyGroupsStateHandle> managedKeyedState,
-			Collection<KeyGroupsStateHandle> rawKeyedState) {
+			Collection<KeyedStateHandle> managedKeyedState,
+			Collection<KeyedStateHandle> rawKeyedState) {
 
 		this.legacyOperatorState = legacyOperatorState;
 		this.managedKeyedState = managedKeyedState;
@@ -82,11 +82,11 @@ public class TaskStateHandles implements Serializable {
 		return legacyOperatorState;
 	}
 
-	public Collection<KeyGroupsStateHandle> getManagedKeyedState() {
+	public Collection<KeyedStateHandle> getManagedKeyedState() {
 		return managedKeyedState;
 	}
 
-	public Collection<KeyGroupsStateHandle> getRawKeyedState() {
+	public Collection<KeyedStateHandle> getRawKeyedState() {
 		return rawKeyedState;
 	}
 
@@ -110,8 +110,8 @@ public class TaskStateHandles implements Serializable {
 		return out;
 	}
 
-	private static List<KeyGroupsStateHandle> transform(KeyGroupsStateHandle in) {
-		return in == null ? Collections.<KeyGroupsStateHandle>emptyList() : Collections.singletonList(in);
+	private static <T> List<T> transform(T in) {
+		return in == null ? Collections.<T>emptyList() : Collections.singletonList(in);
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 46ec5c2..a332d7d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -39,6 +39,7 @@ import org.apache.flink.migration.runtime.state.KvStateSnapshot;
 import org.apache.flink.migration.runtime.state.memory.MigrationRestoreSnapshot;
 import org.apache.flink.runtime.io.async.AbstractAsyncIOCallable;
 import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
+import org.apache.flink.migration.state.MigrationKeyGroupStateHandle;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
@@ -50,6 +51,7 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.internal.InternalAggregatingState;
@@ -223,7 +225,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 	@Override
 	@SuppressWarnings("unchecked")
-	public  RunnableFuture<KeyGroupsStateHandle> snapshot(
+	public  RunnableFuture<KeyedStateHandle> snapshot(
 			final long checkpointId,
 			final long timestamp,
 			final CheckpointStreamFactory streamFactory,
@@ -267,8 +269,8 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		//--------------------------------------------------- this becomes the end of sync part
 
 		// implementation of the async IO operation, based on FutureTask
-		final AbstractAsyncIOCallable<KeyGroupsStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
-				new AbstractAsyncIOCallable<KeyGroupsStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream>() {
+		final AbstractAsyncIOCallable<KeyedStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
+				new AbstractAsyncIOCallable<KeyedStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream>() {
 
 					AtomicBoolean open = new AtomicBoolean(false);
 
@@ -340,7 +342,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 					}
 				};
 
-		AsyncStoppableTaskWithCallback<KeyGroupsStateHandle> task = AsyncStoppableTaskWithCallback.from(ioCallable);
+		AsyncStoppableTaskWithCallback<KeyedStateHandle> task = AsyncStoppableTaskWithCallback.from(ioCallable);
 
 		if (!asynchronousSnapshots) {
 			task.run();
@@ -354,7 +356,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 	@SuppressWarnings("deprecation")
 	@Override
-	public void restore(Collection<KeyGroupsStateHandle> restoredState) throws Exception {
+	public void restore(Collection<KeyedStateHandle> restoredState) throws Exception {
 		LOG.info("Initializing heap keyed state backend from snapshot.");
 
 		if (LOG.isDebugEnabled()) {
@@ -369,19 +371,26 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@SuppressWarnings({"unchecked"})
-	private void restorePartitionedState(Collection<KeyGroupsStateHandle> state) throws Exception {
+	private void restorePartitionedState(Collection<KeyedStateHandle> state) throws Exception {
 
 		final Map<Integer, String> kvStatesById = new HashMap<>();
 		int numRegisteredKvStates = 0;
 		stateTables.clear();
 
-		for (KeyGroupsStateHandle keyGroupsHandle : state) {
+		for (KeyedStateHandle keyedStateHandle : state) {
 
-			if (keyGroupsHandle == null) {
+			if (keyedStateHandle == null) {
 				continue;
 			}
 
-			FSDataInputStream fsDataInputStream = keyGroupsHandle.openInputStream();
+			if (!(keyedStateHandle instanceof KeyGroupsStateHandle)) {
+				throw new IllegalStateException("Unexpected state handle type, " +
+						"expected: " + KeyGroupsStateHandle.class +
+						", but found: " + keyedStateHandle.getClass());
+			}
+
+			KeyGroupsStateHandle keyGroupsStateHandle = (KeyGroupsStateHandle) keyedStateHandle;
+			FSDataInputStream fsDataInputStream = keyGroupsStateHandle.openInputStream();
 			cancelStreamRegistry.registerClosable(fsDataInputStream);
 
 			try {
@@ -412,9 +421,13 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 					}
 				}
 
-				for (Tuple2<Integer, Long> groupOffset : keyGroupsHandle.getGroupRangeOffsets()) {
+				for (Tuple2<Integer, Long> groupOffset : keyGroupsStateHandle.getGroupRangeOffsets()) {
 					int keyGroupIndex = groupOffset.f0;
 					long offset = groupOffset.f1;
+
+					// Check that restored key groups all belong to the backend.
+					Preconditions.checkState(keyGroupRange.contains(keyGroupIndex), "The key group must belong to the backend.");
+
 					fsDataInputStream.seek(offset);
 
 					int writtenKeyGroupIndex = inView.readInt();
@@ -449,7 +462,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	@SuppressWarnings({"unchecked", "rawtypes", "DeprecatedIsStillUsed"})
 	@Deprecated
 	private void restoreOldSavepointKeyedState(
-			Collection<KeyGroupsStateHandle> stateHandles) throws IOException, ClassNotFoundException {
+			Collection<KeyedStateHandle> stateHandles) throws IOException, ClassNotFoundException {
 
 		if (stateHandles.isEmpty()) {
 			return;
@@ -457,8 +470,17 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		Preconditions.checkState(1 == stateHandles.size(), "Only one element expected here.");
 
+		KeyedStateHandle keyedStateHandle = stateHandles.iterator().next();
+		if (!(keyedStateHandle instanceof MigrationKeyGroupStateHandle)) {
+			throw new IllegalStateException("Unexpected state handle type, " +
+					"expected: " + MigrationKeyGroupStateHandle.class +
+					", but found " + keyedStateHandle.getClass());
+		}
+
+		MigrationKeyGroupStateHandle keyGroupStateHandle = (MigrationKeyGroupStateHandle) keyedStateHandle;
+
 		HashMap<String, KvStateSnapshot<K, ?, ?, ?>> namedStates;
-		try (FSDataInputStream inputStream = stateHandles.iterator().next().openInputStream()) {
+		try (FSDataInputStream inputStream = keyGroupStateHandle.openInputStream()) {
 			namedStates = InstantiationUtil.deserializeObject(inputStream, userCodeClassLoader);
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index d8bba59..117c70d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -41,6 +41,7 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
@@ -2346,13 +2347,13 @@ public class CheckpointCoordinatorTest {
 			ChainedStateHandle<StreamStateHandle> operatorState = taskStateHandles.getLegacyOperatorState();
 			List<Collection<OperatorStateHandle>> opStateBackend = taskStateHandles.getManagedOperatorState();
 			List<Collection<OperatorStateHandle>> opStateRaw = taskStateHandles.getRawOperatorState();
-			Collection<KeyGroupsStateHandle> keyGroupStateBackend = taskStateHandles.getManagedKeyedState();
-			Collection<KeyGroupsStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState();
+			Collection<KeyedStateHandle> keyedStateBackend = taskStateHandles.getManagedKeyedState();
+			Collection<KeyedStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState();
 
 			actualOpStatesBackend.add(opStateBackend);
 			actualOpStatesRaw.add(opStateRaw);
 			assertNull(operatorState);
-			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyGroupStateBackend);
+			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
 			compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
 		}
 		comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend);
@@ -2690,32 +2691,38 @@ public class CheckpointCoordinatorTest {
 
 			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
 					jobVertexID, keyGroupPartitions.get(i), false);
-			Collection<KeyGroupsStateHandle> actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
+			Collection<KeyedStateHandle> actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
 			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), actualPartitionedKeyGroupState);
 		}
 	}
 
 	public static void compareKeyedState(
 			Collection<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
-			Collection<KeyGroupsStateHandle> actualPartitionedKeyGroupState) throws Exception {
+			Collection<? extends KeyedStateHandle> actualPartitionedKeyGroupState) throws Exception {
 
 		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next();
-		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getNumberOfKeyGroups();
+		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
 		int actualTotalKeyGroups = 0;
-		for(KeyGroupsStateHandle keyGroupsStateHandle: actualPartitionedKeyGroupState) {
-			actualTotalKeyGroups += keyGroupsStateHandle.getNumberOfKeyGroups();
+		for(KeyedStateHandle keyedStateHandle: actualPartitionedKeyGroupState) {
+			assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle);
+
+			actualTotalKeyGroups += keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
 		}
 
 		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
 
 		try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.openInputStream()) {
-			for (int groupId : expectedHeadOpKeyGroupStateHandle.keyGroups()) {
+			for (int groupId : expectedHeadOpKeyGroupStateHandle.getKeyGroupRange()) {
 				long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
 				inputStream.seek(offset);
 				int expectedKeyGroupState =
 						InstantiationUtil.deserializeObject(inputStream, Thread.currentThread().getContextClassLoader());
-				for (KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) {
-					if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
+				for (KeyedStateHandle oneActualKeyedStateHandle : actualPartitionedKeyGroupState) {
+
+					assertTrue(oneActualKeyedStateHandle instanceof KeyGroupsStateHandle);
+
+					KeyGroupsStateHandle oneActualKeyGroupStateHandle = (KeyGroupsStateHandle) oneActualKeyedStateHandle;
+					if (oneActualKeyGroupStateHandle.getKeyGroupRange().contains(groupId)) {
 						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
 						try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
 							actualInputStream.seek(actualOffset);

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 18b07eb..7e0a7c1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -31,6 +31,7 @@ import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
@@ -68,7 +69,7 @@ public class CheckpointStateRestoreTest {
 			final ChainedStateHandle<StreamStateHandle> serializedState = CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject());
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
 			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
-			final KeyGroupsStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+			final KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
 
 			final JobID jid = new JobID();
 			final JobVertexID statefulId = new JobVertexID();

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
index 6ab8620..1ecb2e3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.checkpoint.TaskState;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
@@ -58,6 +59,7 @@ import java.util.concurrent.ThreadLocalRandom;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 
 @SuppressWarnings("deprecation")
 public class MigrationV0ToV1Test {
@@ -154,9 +156,15 @@ public class MigrationV0ToV1Test {
 					}
 
 					//check keyed state
-					KeyGroupsStateHandle keyGroupsStateHandle = subtaskState.getManagedKeyedState();
+					KeyedStateHandle keyedStateHandle = subtaskState.getManagedKeyedState();
+
 					if (t % 3 != 0) {
-						assertEquals(1, keyGroupsStateHandle.getNumberOfKeyGroups());
+
+						assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle);
+
+						KeyGroupsStateHandle keyGroupsStateHandle = (KeyGroupsStateHandle) keyedStateHandle;
+
+						assertEquals(1, keyGroupsStateHandle.getKeyGroupRange().getNumberOfKeyGroups());
 						assertEquals(p, keyGroupsStateHandle.getGroupRangeOffsets().getKeyGroupRange().getStartKeyGroup());
 
 						ByteStreamStateHandle stateHandle =
@@ -172,7 +180,7 @@ public class MigrationV0ToV1Test {
 							assertEquals(p, data[1]);
 						}
 					} else {
-						assertEquals(null, keyGroupsStateHandle);
+						assertEquals(null, keyedStateHandle);
 					}
 				}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
index 0c4ed74..cee0b02 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
@@ -135,7 +135,7 @@ public class KeyedStateCheckpointOutputStreamTest {
 		int count = 0;
 		try (FSDataInputStream in = fullHandle.openInputStream()) {
 			DataInputView div = new DataInputViewStreamWrapper(in);
-			for (int kg : fullHandle.keyGroups()) {
+			for (int kg : fullHandle.getKeyGroupRange()) {
 				long off = fullHandle.getOffsetForKeyGroup(kg);
 				if (off >= 0) {
 					in.seek(off);
@@ -152,7 +152,7 @@ public class KeyedStateCheckpointOutputStreamTest {
 		int count = 0;
 		try (FSDataInputStream in = fullHandle.openInputStream()) {
 			DataInputView div = new DataInputViewStreamWrapper(in);
-			for (int kg : fullHandle.keyGroups()) {
+			for (int kg : fullHandle.getKeyGroupRange()) {
 				long off = fullHandle.getOffsetForKeyGroup(kg);
 				in.seek(off);
 				Assert.assertEquals(kg, div.readInt());

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 22bb715..ccc1eae 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -143,13 +143,13 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				env.getTaskKvStateRegistry());
 	}
 
-	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle state) throws Exception {
+	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyedStateHandle state) throws Exception {
 		return restoreKeyedBackend(keySerializer, state, new DummyEnvironment("test", 1, 0));
 	}
 
 	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(
 			TypeSerializer<K> keySerializer,
-			KeyGroupsStateHandle state,
+			KeyedStateHandle state,
 			Environment env) throws Exception {
 		return restoreKeyedBackend(
 				keySerializer,
@@ -163,7 +163,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> state,
+			List<KeyedStateHandle> state,
 			Environment env) throws Exception {
 
 		AbstractKeyedStateBackend<K> backend = getStateBackend().createKeyedStateBackend(
@@ -436,7 +436,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		backend.setCurrentKey(2);
 		state.update(new TestPojo("u2", 2));
 
-		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(
+		KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
 				682375462378L,
 				2,
 				streamFactory,
@@ -497,7 +497,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		backend.setCurrentKey(2);
 		state.update(new TestPojo("u2", 2));
 
-		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(
+		KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
 				682375462378L,
 				2,
 				streamFactory,
@@ -524,7 +524,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		// update to test state backends that eagerly serialize, such as RocksDB
 		state.update(new TestPojo("u1", 11));
 
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(
 				682375462378L,
 				2,
 				streamFactory,
@@ -585,7 +585,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		backend.setCurrentKey(2);
 		state.update(new TestPojo("u2", 2));
 
-		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(
+		KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
 				682375462378L,
 				2,
 				streamFactory,
@@ -611,7 +611,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		// update to test state backends that eagerly serialize, such as RocksDB
 		state.update(new TestPojo("u1", 11));
 
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(
 				682375462378L,
 				2,
 				streamFactory,
@@ -670,7 +670,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -681,7 +681,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.update("u3");
 
 		// draw another snapshot
-		KeyGroupsStateHandle snapshot2 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot2 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -880,7 +880,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals(13, (int) state2.value());
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		backend.dispose();
 		backend = restoreKeyedBackend(
@@ -952,7 +952,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals(42L, (long) state.value());
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		backend.dispose();
 		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
@@ -997,7 +997,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("1", joiner.join(getSerializedList(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -1008,7 +1008,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.add("u3");
 
 		// draw another snapshot
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -1091,7 +1091,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -1102,7 +1102,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.add("u3");
 
 		// draw another snapshot
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -1188,7 +1188,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("Fold-Initial:,1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -1200,7 +1200,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.add(103);
 
 		// draw another snapshot
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -1287,7 +1287,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				getSerializedMap(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -1299,7 +1299,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.putAll(new HashMap<Integer, String>() {{ put(1031, "1031"); put(1032, "1032"); }});
 
 		// draw another snapshot
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -1606,13 +1606,13 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.update("ShouldBeInSecondHalf");
 
 
-		KeyGroupsStateHandle snapshot = FutureUtil.runIfNotDoneAndGet(backend.snapshot(0, 0, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot = FutureUtil.runIfNotDoneAndGet(backend.snapshot(0, 0, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
-		List<KeyGroupsStateHandle> firstHalfKeyGroupStates = StateAssignmentOperation.getKeyGroupsStateHandles(
+		List<KeyedStateHandle> firstHalfKeyGroupStates = StateAssignmentOperation.getKeyedStateHandles(
 				Collections.singletonList(snapshot),
 				KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 0));
 
-		List<KeyGroupsStateHandle> secondHalfKeyGroupStates = StateAssignmentOperation.getKeyGroupsStateHandles(
+		List<KeyedStateHandle> secondHalfKeyGroupStates = StateAssignmentOperation.getKeyedStateHandles(
 				Collections.singletonList(snapshot),
 				KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 1));
 
@@ -1672,7 +1672,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.update("2");
 
 			// draw a snapshot
-			KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+			KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -1723,7 +1723,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.add("2");
 
 			// draw a snapshot
-			KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+			KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -1776,7 +1776,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.add("2");
 
 			// draw a snapshot
-			KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+			KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -1827,7 +1827,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.put("2", "Second");
 
 			// draw a snapshot
-			KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+			KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -2093,7 +2093,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				eq(env.getJobID()), eq(env.getJobVertexId()), eq(expectedKeyGroupRange), eq("banana"), any(KvStateID.class));
 
 
-		KeyGroupsStateHandle snapshot = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		backend.dispose();
 
@@ -2124,7 +2124,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 
 			// draw a snapshot
-			KeyGroupsStateHandle snapshot =
+			KeyedStateHandle snapshot =
 					FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 1, streamFactory, CheckpointOptions.forFullCheckpoint()));
 			assertNull(snapshot);
 			backend.dispose();
@@ -2152,7 +2152,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		streamFactory.setWaiterLatch(waiter);
 
 		AbstractKeyedStateBackend<Integer> backend = null;
-		KeyGroupsStateHandle stateHandle = null;
+		KeyedStateHandle stateHandle = null;
 
 		try {
 			backend = createKeyedBackend(IntSerializer.INSTANCE);
@@ -2167,7 +2167,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				valueState.update(i);
 			}
 
-			RunnableFuture<KeyGroupsStateHandle> snapshot =
+			RunnableFuture<KeyedStateHandle> snapshot =
 					backend.snapshot(0L, 0L, streamFactory, CheckpointOptions.forFullCheckpoint());
 			Thread runner = new Thread(snapshot);
 			runner.start();
@@ -2249,7 +2249,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				valueState.update(i);
 			}
 
-			RunnableFuture<KeyGroupsStateHandle> snapshot =
+			RunnableFuture<KeyedStateHandle> snapshot =
 					backend.snapshot(0L, 0L, streamFactory, CheckpointOptions.forFullCheckpoint());
 
 			Thread runner = new Thread(snapshot);
@@ -2367,7 +2367,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		}
 	}
 
-	private KeyGroupsStateHandle runSnapshot(RunnableFuture<KeyGroupsStateHandle> snapshotRunnableFuture) throws Exception {
+	private KeyedStateHandle runSnapshot(RunnableFuture<KeyedStateHandle> snapshotRunnableFuture) throws Exception {
 		if(!snapshotRunnableFuture.isDone()) {
 			Thread runner = new Thread(snapshotRunnableFuture);
 			runner.start();

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
index da0666a..3754d63 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.internal.InternalListState;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
@@ -63,7 +64,7 @@ public class HeapKeyedStateBackendSnapshotMigrationTest extends HeapStateBackend
 			try (BufferedInputStream bis = new BufferedInputStream((new FileInputStream(resource.getFile())))) {
 				stateHandle = InstantiationUtil.deserializeObject(bis, Thread.currentThread().getContextClassLoader());
 			}
-			keyedBackend.restore(Collections.singleton(stateHandle));
+			keyedBackend.restore(Collections.<KeyedStateHandle>singleton(stateHandle));
 			final ListStateDescriptor<Long> stateDescr = new ListStateDescriptor<>("my-state", Long.class);
 			stateDescr.initializeSerializerUnlessSet(new ExecutionConfig());
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index e40a59b..a6a89b5 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.streaming.api.operators;
 
-import java.io.IOException;
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
 import org.apache.flink.annotation.PublicEvolving;
@@ -47,9 +46,9 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
 import org.apache.flink.runtime.state.KeyGroupsList;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateInitializationContext;
@@ -70,6 +69,7 @@ import org.apache.flink.util.OutputTag;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.Collection;
 import java.util.ConcurrentModificationException;
@@ -198,7 +198,7 @@ public abstract class AbstractStreamOperator<OUT>
 	@Override
 	public final void initializeState(OperatorStateHandles stateHandles) throws Exception {
 
-		Collection<KeyGroupsStateHandle> keyedStateHandlesRaw = null;
+		Collection<KeyedStateHandle> keyedStateHandlesRaw = null;
 		Collection<OperatorStateHandle> operatorStateHandlesRaw = null;
 		Collection<OperatorStateHandle> operatorStateHandlesBackend = null;
 
@@ -473,6 +473,7 @@ public abstract class AbstractStreamOperator<OUT>
 			// and then initialize the timer services
 			for (KeyGroupStatePartitionStreamProvider streamProvider : context.getRawKeyedStateInputs()) {
 				int keyGroupIdx = streamProvider.getKeyGroupId();
+
 				checkArgument(localKeyGroupRange.contains(keyGroupIdx),
 					"Key Group " + keyGroupIdx + " does not belong to the local range.");