You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink and is not preserved in this plain-text rendering.
Posted to commits@flink.apache.org by sr...@apache.org on 2017/03/28 18:05:51 UTC

[1/2] flink git commit: [FLINK-6034] [checkpoints] Introduce KeyedStateHandle abstraction for the snapshots in keyed streams

Repository: flink
Updated Branches:
  refs/heads/master 89866a5ad -> cd5527417


http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java
index b1c94cb..8aa76a5 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.streaming.api.operators;
 
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.util.ExceptionUtils;
@@ -30,8 +30,8 @@ import java.util.concurrent.RunnableFuture;
  */
 public class OperatorSnapshotResult {
 
-	private RunnableFuture<KeyGroupsStateHandle> keyedStateManagedFuture;
-	private RunnableFuture<KeyGroupsStateHandle> keyedStateRawFuture;
+	private RunnableFuture<KeyedStateHandle> keyedStateManagedFuture;
+	private RunnableFuture<KeyedStateHandle> keyedStateRawFuture;
 	private RunnableFuture<OperatorStateHandle> operatorStateManagedFuture;
 	private RunnableFuture<OperatorStateHandle> operatorStateRawFuture;
 
@@ -40,8 +40,8 @@ public class OperatorSnapshotResult {
 	}
 
 	public OperatorSnapshotResult(
-			RunnableFuture<KeyGroupsStateHandle> keyedStateManagedFuture,
-			RunnableFuture<KeyGroupsStateHandle> keyedStateRawFuture,
+			RunnableFuture<KeyedStateHandle> keyedStateManagedFuture,
+			RunnableFuture<KeyedStateHandle> keyedStateRawFuture,
 			RunnableFuture<OperatorStateHandle> operatorStateManagedFuture,
 			RunnableFuture<OperatorStateHandle> operatorStateRawFuture) {
 		this.keyedStateManagedFuture = keyedStateManagedFuture;
@@ -50,19 +50,19 @@ public class OperatorSnapshotResult {
 		this.operatorStateRawFuture = operatorStateRawFuture;
 	}
 
-	public RunnableFuture<KeyGroupsStateHandle> getKeyedStateManagedFuture() {
+	public RunnableFuture<KeyedStateHandle> getKeyedStateManagedFuture() {
 		return keyedStateManagedFuture;
 	}
 
-	public void setKeyedStateManagedFuture(RunnableFuture<KeyGroupsStateHandle> keyedStateManagedFuture) {
+	public void setKeyedStateManagedFuture(RunnableFuture<KeyedStateHandle> keyedStateManagedFuture) {
 		this.keyedStateManagedFuture = keyedStateManagedFuture;
 	}
 
-	public RunnableFuture<KeyGroupsStateHandle> getKeyedStateRawFuture() {
+	public RunnableFuture<KeyedStateHandle> getKeyedStateRawFuture() {
 		return keyedStateRawFuture;
 	}
 
-	public void setKeyedStateRawFuture(RunnableFuture<KeyGroupsStateHandle> keyedStateRawFuture) {
+	public void setKeyedStateRawFuture(RunnableFuture<KeyedStateHandle> keyedStateRawFuture) {
 		this.keyedStateRawFuture = keyedStateRawFuture;
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
index 7abf8d9..30d07b7 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
@@ -21,7 +21,7 @@ package org.apache.flink.streaming.runtime.tasks;
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
@@ -42,16 +42,16 @@ public class OperatorStateHandles {
 
 	private final StreamStateHandle legacyOperatorState;
 
-	private final Collection<KeyGroupsStateHandle> managedKeyedState;
-	private final Collection<KeyGroupsStateHandle> rawKeyedState;
+	private final Collection<KeyedStateHandle> managedKeyedState;
+	private final Collection<KeyedStateHandle> rawKeyedState;
 	private final Collection<OperatorStateHandle> managedOperatorState;
 	private final Collection<OperatorStateHandle> rawOperatorState;
 
 	public OperatorStateHandles(
 			int operatorChainIndex,
 			StreamStateHandle legacyOperatorState,
-			Collection<KeyGroupsStateHandle> managedKeyedState,
-			Collection<KeyGroupsStateHandle> rawKeyedState,
+			Collection<KeyedStateHandle> managedKeyedState,
+			Collection<KeyedStateHandle> rawKeyedState,
 			Collection<OperatorStateHandle> managedOperatorState,
 			Collection<OperatorStateHandle> rawOperatorState) {
 
@@ -83,11 +83,11 @@ public class OperatorStateHandles {
 		return legacyOperatorState;
 	}
 
-	public Collection<KeyGroupsStateHandle> getManagedKeyedState() {
+	public Collection<KeyedStateHandle> getManagedKeyedState() {
 		return managedKeyedState;
 	}
 
-	public Collection<KeyGroupsStateHandle> getRawKeyedState() {
+	public Collection<KeyedStateHandle> getRawKeyedState() {
 		return rawKeyedState;
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 76b2b98..11e8e0d 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -37,7 +37,7 @@ import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackend;
@@ -849,8 +849,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 		private final List<OperatorSnapshotResult> snapshotInProgressList;
 
-		private RunnableFuture<KeyGroupsStateHandle> futureKeyedBackendStateHandles;
-		private RunnableFuture<KeyGroupsStateHandle> futureKeyedStreamStateHandles;
+		private RunnableFuture<KeyedStateHandle> futureKeyedBackendStateHandles;
+		private RunnableFuture<KeyedStateHandle> futureKeyedStreamStateHandles;
 
 		private List<StreamStateHandle> nonPartitionedStateHandles;
 
@@ -892,8 +892,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		public void run() {
 			try {
 				// Keyed state handle future, currently only one (the head) operator can have this
-				KeyGroupsStateHandle keyedStateHandleBackend = FutureUtil.runIfNotDoneAndGet(futureKeyedBackendStateHandles);
-				KeyGroupsStateHandle keyedStateHandleStream = FutureUtil.runIfNotDoneAndGet(futureKeyedStreamStateHandles);
+				KeyedStateHandle keyedStateHandleBackend = FutureUtil.runIfNotDoneAndGet(futureKeyedBackendStateHandles);
+				KeyedStateHandle keyedStateHandleStream = FutureUtil.runIfNotDoneAndGet(futureKeyedStreamStateHandles);
 
 				List<OperatorStateHandle> operatorStatesBackend = new ArrayList<>(snapshotInProgressList.size());
 				List<OperatorStateHandle> operatorStatesStream = new ArrayList<>(snapshotInProgressList.size());
@@ -987,8 +987,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 				ChainedStateHandle<StreamStateHandle> chainedNonPartitionedOperatorsState,
 				ChainedStateHandle<OperatorStateHandle> chainedOperatorStateBackend,
 				ChainedStateHandle<OperatorStateHandle> chainedOperatorStateStream,
-				KeyGroupsStateHandle keyedStateHandleBackend,
-				KeyGroupsStateHandle keyedStateHandleStream) {
+				KeyedStateHandle keyedStateHandleBackend,
+				KeyedStateHandle keyedStateHandleStream) {
 
 			boolean hasAnyState = keyedStateHandleBackend != null
 					|| keyedStateHandleStream != null

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java
index eeee8dc..8f42c1a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java
@@ -51,7 +51,7 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
@@ -559,11 +559,11 @@ public class AbstractStreamOperatorTest {
 
 		final CloseableRegistry closeableRegistry = new CloseableRegistry();
 
-		RunnableFuture<KeyGroupsStateHandle> futureKeyGroupStateHandle = mock(RunnableFuture.class);
+		RunnableFuture<KeyedStateHandle> futureKeyedStateHandle = mock(RunnableFuture.class);
 		RunnableFuture<OperatorStateHandle> futureOperatorStateHandle = mock(RunnableFuture.class);
 
 		StateSnapshotContextSynchronousImpl context = mock(StateSnapshotContextSynchronousImpl.class);
-		when(context.getKeyedStateStreamFuture()).thenReturn(futureKeyGroupStateHandle);
+		when(context.getKeyedStateStreamFuture()).thenReturn(futureKeyedStateHandle);
 		when(context.getOperatorStateStreamFuture()).thenReturn(futureOperatorStateHandle);
 
 		OperatorSnapshotResult operatorSnapshotResult = spy(new OperatorSnapshotResult());
@@ -609,9 +609,9 @@ public class AbstractStreamOperatorTest {
 		verify(context).close();
 		verify(operatorSnapshotResult).cancel();
 
-		verify(futureKeyGroupStateHandle).cancel(anyBoolean());
+		verify(futureKeyedStateHandle).cancel(anyBoolean());
 		verify(futureOperatorStateHandle).cancel(anyBoolean());
-		verify(futureKeyGroupStateHandle).cancel(anyBoolean());
+		verify(futureKeyedStateHandle).cancel(anyBoolean());
 	}
 
 	@Test

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResultTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResultTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResultTest.java
index 490df52..f57eed1 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResultTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResultTest.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.streaming.api.operators;
 
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.util.TestLogger;
 import org.junit.Test;
@@ -41,12 +41,12 @@ public class OperatorSnapshotResultTest extends TestLogger {
 
 		operatorSnapshotResult.cancel();
 
-		KeyGroupsStateHandle keyedManagedStateHandle = mock(KeyGroupsStateHandle.class);
-		RunnableFuture<KeyGroupsStateHandle> keyedStateManagedFuture = mock(RunnableFuture.class);
+		KeyedStateHandle keyedManagedStateHandle = mock(KeyedStateHandle.class);
+		RunnableFuture<KeyedStateHandle> keyedStateManagedFuture = mock(RunnableFuture.class);
 		when(keyedStateManagedFuture.get()).thenReturn(keyedManagedStateHandle);
 
-		KeyGroupsStateHandle keyedRawStateHandle = mock(KeyGroupsStateHandle.class);
-		RunnableFuture<KeyGroupsStateHandle> keyedStateRawFuture = mock(RunnableFuture.class);
+		KeyedStateHandle keyedRawStateHandle = mock(KeyedStateHandle.class);
+		RunnableFuture<KeyedStateHandle> keyedStateRawFuture = mock(RunnableFuture.class);
 		when(keyedStateRawFuture.get()).thenReturn(keyedRawStateHandle);
 
 		OperatorStateHandle operatorManagedStateHandle = mock(OperatorStateHandle.class);

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
index 963c42c..8e0edfc 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
@@ -32,6 +32,7 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateInitializationContextImpl;
 import org.apache.flink.runtime.state.StatePartitionStreamProvider;
@@ -75,7 +76,7 @@ public class StateInitializationContextImplTest {
 
 		ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(64);
 
-		List<KeyGroupsStateHandle> keyGroupsStateHandles = new ArrayList<>(NUM_HANDLES);
+		List<KeyedStateHandle> keyedStateHandles = new ArrayList<>(NUM_HANDLES);
 		int prev = 0;
 		for (int i = 0; i < NUM_HANDLES; ++i) {
 			out.reset();
@@ -91,10 +92,10 @@ public class StateInitializationContextImplTest {
 				++writtenKeyGroups;
 			}
 
-			KeyGroupsStateHandle handle =
+			KeyedStateHandle handle =
 					new KeyGroupsStateHandle(offsets, new ByteStateHandleCloseChecking("kg-" + i, out.toByteArray()));
 
-			keyGroupsStateHandles.add(handle);
+			keyedStateHandles.add(handle);
 		}
 
 		List<OperatorStateHandle> operatorStateHandles = new ArrayList<>(NUM_HANDLES);
@@ -125,7 +126,7 @@ public class StateInitializationContextImplTest {
 						true,
 						stateStore,
 						mock(KeyedStateStore.class),
-						keyGroupsStateHandles,
+						keyedStateHandles,
 						operatorStateHandles,
 						closableRegistry);
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index 58cfefd..4435247 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -51,6 +51,7 @@ import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -186,8 +187,8 @@ public class InterruptSensitiveRestoreTest {
 
 
 		ChainedStateHandle<StreamStateHandle> operatorState = null;
-		List<KeyGroupsStateHandle> keyGroupStateFromBackend = Collections.emptyList();
-		List<KeyGroupsStateHandle> keyGroupStateFromStream = Collections.emptyList();
+		List<KeyedStateHandle> keyedStateFromBackend = Collections.emptyList();
+		List<KeyedStateHandle> keyedStateFromStream = Collections.emptyList();
 		List<Collection<OperatorStateHandle>> operatorStateBackend = Collections.emptyList();
 		List<Collection<OperatorStateHandle>> operatorStateStream = Collections.emptyList();
 
@@ -201,8 +202,8 @@ public class InterruptSensitiveRestoreTest {
 		Collection<OperatorStateHandle> operatorStateHandles =
 				Collections.singletonList(new OperatorStateHandle(operatorStateMetadata, state));
 
-		List<KeyGroupsStateHandle> keyGroupsStateHandles =
-				Collections.singletonList(new KeyGroupsStateHandle(keyGroupRangeOffsets, state));
+		List<KeyedStateHandle> keyedStateHandles =
+				Collections.<KeyedStateHandle>singletonList(new KeyGroupsStateHandle(keyGroupRangeOffsets, state));
 
 		switch (mode) {
 			case OPERATOR_MANAGED:
@@ -212,10 +213,10 @@ public class InterruptSensitiveRestoreTest {
 				operatorStateStream = Collections.singletonList(operatorStateHandles);
 				break;
 			case KEYED_MANAGED:
-				keyGroupStateFromBackend = keyGroupsStateHandles;
+				keyedStateFromBackend = keyedStateHandles;
 				break;
 			case KEYED_RAW:
-				keyGroupStateFromStream = keyGroupsStateHandles;
+				keyedStateFromStream = keyedStateHandles;
 				break;
 			case LEGACY:
 				operatorState = new ChainedStateHandle<>(Collections.singletonList(state));
@@ -228,8 +229,8 @@ public class InterruptSensitiveRestoreTest {
 			operatorState,
 			operatorStateBackend,
 			operatorStateStream,
-			keyGroupStateFromBackend,
-			keyGroupStateFromStream);
+			keyedStateFromBackend,
+			keyedStateFromStream);
 
 		JobInformation jobInformation = new JobInformation(
 			new JobID(),

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index d7e3d6c..f34522b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -61,7 +61,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackendFactory;
@@ -458,8 +458,8 @@ public class StreamTaskTest extends TestLogger {
 
 		StreamOperator<?> streamOperator = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class));
 
-		KeyGroupsStateHandle managedKeyedStateHandle = mock(KeyGroupsStateHandle.class);
-		KeyGroupsStateHandle rawKeyedStateHandle = mock(KeyGroupsStateHandle.class);
+		KeyedStateHandle managedKeyedStateHandle = mock(KeyedStateHandle.class);
+		KeyedStateHandle rawKeyedStateHandle = mock(KeyedStateHandle.class);
 		OperatorStateHandle managedOperatorStateHandle = mock(OperatorStateHandle.class);
 		OperatorStateHandle rawOperatorStateHandle = mock(OperatorStateHandle.class);
 
@@ -563,8 +563,8 @@ public class StreamTaskTest extends TestLogger {
 					(ChainedStateHandle<StreamStateHandle>)invocation.getArguments()[0],
 					(ChainedStateHandle<OperatorStateHandle>)invocation.getArguments()[1],
 					(ChainedStateHandle<OperatorStateHandle>)invocation.getArguments()[2],
-					(KeyGroupsStateHandle)invocation.getArguments()[3],
-					(KeyGroupsStateHandle)invocation.getArguments()[4]);
+					(KeyedStateHandle)invocation.getArguments()[3],
+					(KeyedStateHandle)invocation.getArguments()[4]);
 			}
 		});
 
@@ -574,8 +574,8 @@ public class StreamTaskTest extends TestLogger {
 
 		StreamOperator<?> streamOperator = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class));
 
-		KeyGroupsStateHandle managedKeyedStateHandle = mock(KeyGroupsStateHandle.class);
-		KeyGroupsStateHandle rawKeyedStateHandle = mock(KeyGroupsStateHandle.class);
+		KeyedStateHandle managedKeyedStateHandle = mock(KeyedStateHandle.class);
+		KeyedStateHandle rawKeyedStateHandle = mock(KeyedStateHandle.class);
 		OperatorStateHandle managedOperatorStateHandle = mock(OperatorStateHandle.class);
 		OperatorStateHandle rawOperatorStateHandle = mock(OperatorStateHandle.class);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
index 945103c..912d579 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
@@ -41,6 +41,7 @@ import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -318,7 +319,7 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 
 		StreamStateHandle stateHandle = SavepointV0Serializer.convertOperatorAndFunctionState(state);
 
-		List<KeyGroupsStateHandle> keyGroupStatesList = new ArrayList<>();
+		List<KeyedStateHandle> keyGroupStatesList = new ArrayList<>();
 		if (state.getKvStates() != null) {
 			KeyGroupsStateHandle keyedStateHandle = SavepointV0Serializer.convertKeyedBackendState(
 					state.getKvStates(),
@@ -331,7 +332,7 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 		initializeState(new OperatorStateHandles(0,
 				stateHandle,
 				keyGroupStatesList,
-				Collections.<KeyGroupsStateHandle>emptyList(),
+				Collections.<KeyedStateHandle>emptyList(),
 				Collections.<OperatorStateHandle>emptyList(),
 				Collections.<OperatorStateHandle>emptyList()));
 	}
@@ -364,16 +365,16 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 			KeyGroupRange localKeyGroupRange =
 					keyGroupPartitions.get(subtaskIndex);
 
-			List<KeyGroupsStateHandle> localManagedKeyGroupState = null;
+			List<KeyedStateHandle> localManagedKeyGroupState = null;
 			if (operatorStateHandles.getManagedKeyedState() != null) {
-				localManagedKeyGroupState = StateAssignmentOperation.getKeyGroupsStateHandles(
+				localManagedKeyGroupState = StateAssignmentOperation.getKeyedStateHandles(
 						operatorStateHandles.getManagedKeyedState(),
 						localKeyGroupRange);
 			}
 
-			List<KeyGroupsStateHandle> localRawKeyGroupState = null;
+			List<KeyedStateHandle> localRawKeyGroupState = null;
 			if (operatorStateHandles.getRawKeyedState() != null) {
-				localRawKeyGroupState = StateAssignmentOperation.getKeyGroupsStateHandles(
+				localRawKeyGroupState = StateAssignmentOperation.getKeyedStateHandles(
 						operatorStateHandles.getRawKeyedState(),
 						localKeyGroupRange);
 			}
@@ -442,15 +443,15 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 		List<OperatorStateHandle> mergedManagedOperatorState = new ArrayList<>(handles.length);
 		List<OperatorStateHandle> mergedRawOperatorState = new ArrayList<>(handles.length);
 
-		List<KeyGroupsStateHandle> mergedManagedKeyedState = new ArrayList<>(handles.length);
-		List<KeyGroupsStateHandle> mergedRawKeyedState = new ArrayList<>(handles.length);
+		List<KeyedStateHandle> mergedManagedKeyedState = new ArrayList<>(handles.length);
+		List<KeyedStateHandle> mergedRawKeyedState = new ArrayList<>(handles.length);
 
 		for (OperatorStateHandles handle: handles) {
 
 			Collection<OperatorStateHandle> managedOperatorState = handle.getManagedOperatorState();
 			Collection<OperatorStateHandle> rawOperatorState = handle.getRawOperatorState();
-			Collection<KeyGroupsStateHandle> managedKeyedState = handle.getManagedKeyedState();
-			Collection<KeyGroupsStateHandle> rawKeyedState = handle.getRawKeyedState();
+			Collection<KeyedStateHandle> managedKeyedState = handle.getManagedKeyedState();
+			Collection<KeyedStateHandle> rawKeyedState = handle.getRawKeyedState();
 
 			if (managedOperatorState != null) {
 				mergedManagedOperatorState.addAll(managedOperatorState);
@@ -502,8 +503,8 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 			timestamp,
 			CheckpointOptions.forFullCheckpoint());
 
-		KeyGroupsStateHandle keyedManaged = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getKeyedStateManagedFuture());
-		KeyGroupsStateHandle keyedRaw = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getKeyedStateRawFuture());
+		KeyedStateHandle keyedManaged = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getKeyedStateManagedFuture());
+		KeyedStateHandle keyedRaw = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getKeyedStateRawFuture());
 
 		OperatorStateHandle opManaged = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getOperatorStateManagedFuture());
 		OperatorStateHandle opRaw = FutureUtil.runIfNotDoneAndGet(operatorStateResult.getOperatorStateRawFuture());

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index d45ae21..d9c7387 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
@@ -65,7 +66,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 
 	// when we restore we keep the state here so that we can call restore
 	// when the operator requests the keyed state backend
-	private List<KeyGroupsStateHandle> restoredKeyedState = null;
+	private List<KeyedStateHandle> restoredKeyedState = null;
 
 	public KeyedOneInputStreamOperatorTestHarness(
 			OneInputStreamOperator<IN, OUT> operator,
@@ -144,7 +145,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 		}
 
 		if (keyedStateBackend != null) {
-			RunnableFuture<KeyGroupsStateHandle> keyedSnapshotRunnable = keyedStateBackend.snapshot(
+			RunnableFuture<KeyedStateHandle> keyedSnapshotRunnable = keyedStateBackend.snapshot(
 					checkpointId,
 					timestamp,
 					streamFactory,
@@ -177,14 +178,14 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 			byte keyedStatePresent = (byte) inStream.read();
 			if (keyedStatePresent == 1) {
 				ObjectInputStream ois = new ObjectInputStream(inStream);
-				this.restoredKeyedState = Collections.singletonList((KeyGroupsStateHandle) ois.readObject());
+				this.restoredKeyedState = Collections.singletonList((KeyedStateHandle) ois.readObject());
 			}
 		}
 	}
 
 
-	private static boolean hasMigrationHandles(Collection<KeyGroupsStateHandle> allKeyGroupsHandles) {
-		for (KeyGroupsStateHandle handle : allKeyGroupsHandles) {
+	private static boolean hasMigrationHandles(Collection<KeyedStateHandle> allKeyGroupsHandles) {
+		for (KeyedStateHandle handle : allKeyGroupsHandles) {
 			if (handle instanceof Migration) {
 				return true;
 			}
@@ -225,17 +226,17 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 					keyGroupPartitions.get(subtaskIndex);
 
 			restoredKeyedState = null;
-			Collection<KeyGroupsStateHandle> managedKeyedState = operatorStateHandles.getManagedKeyedState();
+			Collection<KeyedStateHandle> managedKeyedState = operatorStateHandles.getManagedKeyedState();
 			if (managedKeyedState != null) {
 
 				// if we have migration handles, don't reshuffle state and preserve
 				// the migration tag
 				if (hasMigrationHandles(managedKeyedState)) {
-					List<KeyGroupsStateHandle> result = new ArrayList<>(managedKeyedState.size());
+					List<KeyedStateHandle> result = new ArrayList<>(managedKeyedState.size());
 					result.addAll(managedKeyedState);
 					restoredKeyedState = result;
 				} else {
-					restoredKeyedState = StateAssignmentOperation.getKeyGroupsStateHandles(
+					restoredKeyedState = StateAssignmentOperation.getKeyedStateHandles(
 							managedKeyedState,
 							localKeyGroupRange);
 				}

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
index 8e76f70..41a083a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.mockito.invocation.InvocationOnMock;
@@ -50,7 +51,7 @@ public class KeyedTwoInputStreamOperatorTestHarness<K, IN1, IN2, OUT>
 
 	// when we restore we keep the state here so that we can call restore
 	// when the operator requests the keyed state backend
-	private Collection<KeyGroupsStateHandle> restoredKeyedState = null;
+	private Collection<KeyedStateHandle> restoredKeyedState = null;
 
 	public KeyedTwoInputStreamOperatorTestHarness(
 			TwoInputStreamOperator<IN1, IN2, OUT> operator,


[2/2] flink git commit: [FLINK-6034] [checkpoints] Introduce KeyedStateHandle abstraction for the snapshots in keyed streams

Posted by sr...@apache.org.
[FLINK-6034] [checkpoints] Introduce KeyedStateHandle abstraction for the snapshots in keyed streams


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/cd552741
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/cd552741
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/cd552741

Branch: refs/heads/master
Commit: cd5527417a1cae57073a8855c6c3b88c88c780aa
Parents: 89866a5
Author: xiaogang.sxg <xi...@alibaba-inc.com>
Authored: Thu Mar 23 23:32:15 2017 +0800
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Tue Mar 28 20:05:28 2017 +0200

----------------------------------------------------------------------
 .../state/RocksDBKeyedStateBackend.java         | 46 ++++++++++----
 .../state/RocksDBAsyncSnapshotTest.java         |  3 +-
 .../state/RocksDBStateBackendTest.java          | 21 ++++---
 .../cep/operator/CEPMigration12to13Test.java    | 14 ++---
 .../apache/flink/migration/MigrationUtil.java   | 10 +--
 .../checkpoint/StateAssignmentOperation.java    | 41 ++++++------
 .../flink/runtime/checkpoint/SubtaskState.java  | 14 ++---
 .../savepoint/SavepointV1Serializer.java        | 42 +++++++------
 .../state/AbstractKeyedStateBackend.java        |  2 +-
 .../runtime/state/KeyGroupsStateHandle.java     | 39 ++++--------
 .../flink/runtime/state/KeyedStateHandle.java   | 40 ++++++++++++
 .../state/StateInitializationContextImpl.java   | 28 ++++++++-
 .../StateSnapshotContextSynchronousImpl.java    | 12 ++--
 .../flink/runtime/state/TaskStateHandles.java   | 16 ++---
 .../state/heap/HeapKeyedStateBackend.java       | 46 ++++++++++----
 .../checkpoint/CheckpointCoordinatorTest.java   | 29 +++++----
 .../checkpoint/CheckpointStateRestoreTest.java  |  3 +-
 .../savepoint/MigrationV0ToV1Test.java          | 14 ++++-
 .../KeyedStateCheckpointOutputStreamTest.java   |  4 +-
 .../runtime/state/StateBackendTestBase.java     | 66 ++++++++++----------
 ...pKeyedStateBackendSnapshotMigrationTest.java |  3 +-
 .../api/operators/AbstractStreamOperator.java   |  7 ++-
 .../api/operators/OperatorSnapshotResult.java   | 18 +++---
 .../runtime/tasks/OperatorStateHandles.java     | 14 ++---
 .../streaming/runtime/tasks/StreamTask.java     | 14 ++---
 .../operators/AbstractStreamOperatorTest.java   | 10 +--
 .../operators/OperatorSnapshotResultTest.java   | 10 +--
 .../StateInitializationContextImplTest.java     |  9 +--
 .../tasks/InterruptSensitiveRestoreTest.java    | 17 ++---
 .../streaming/runtime/tasks/StreamTaskTest.java | 14 ++---
 .../util/AbstractStreamOperatorTestHarness.java | 25 ++++----
 .../KeyedOneInputStreamOperatorTestHarness.java | 17 ++---
 .../KeyedTwoInputStreamOperatorTestHarness.java |  3 +-
 33 files changed, 389 insertions(+), 262 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 2ce527f..0407070 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -40,6 +40,7 @@ import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.migration.MigrationNamespaceSerializerProxy;
 import org.apache.flink.migration.MigrationUtil;
 import org.apache.flink.migration.contrib.streaming.state.RocksDBStateBackend;
+import org.apache.flink.migration.state.MigrationKeyGroupStateHandle;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.io.async.AbstractAsyncIOCallable;
 import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
@@ -52,6 +53,7 @@ import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.internal.InternalAggregatingState;
@@ -257,7 +259,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * @throws Exception
 	 */
 	@Override
-	public RunnableFuture<KeyGroupsStateHandle> snapshot(
+	public RunnableFuture<KeyedStateHandle> snapshot(
 			final long checkpointId,
 			final long timestamp,
 			final CheckpointStreamFactory streamFactory,
@@ -286,8 +288,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 
 		// implementation of the async IO operation, based on FutureTask
-		AbstractAsyncIOCallable<KeyGroupsStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
-				new AbstractAsyncIOCallable<KeyGroupsStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream>() {
+		AbstractAsyncIOCallable<KeyedStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
+				new AbstractAsyncIOCallable<KeyedStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream>() {
 
 					@Override
 					public CheckpointStreamFactory.CheckpointStateOutputStream openIOHandle() throws Exception {
@@ -620,7 +622,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@Override
-	public void restore(Collection<KeyGroupsStateHandle> restoreState) throws Exception {
+	public void restore(Collection<KeyedStateHandle> restoreState) throws Exception {
 		LOG.info("Initializing RocksDB keyed state backend from snapshot.");
 
 		if (LOG.isDebugEnabled()) {
@@ -669,17 +671,23 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		/**
 		 * Restores all key-groups data that is referenced by the passed state handles.
 		 *
-		 * @param keyGroupsStateHandles List of all key groups state handles that shall be restored.
+		 * @param keyedStateHandles List of all key groups state handles that shall be restored.
 		 * @throws IOException
 		 * @throws ClassNotFoundException
 		 * @throws RocksDBException
 		 */
-		public void doRestore(Collection<KeyGroupsStateHandle> keyGroupsStateHandles)
+		public void doRestore(Collection<KeyedStateHandle> keyedStateHandles)
 				throws IOException, ClassNotFoundException, RocksDBException {
 
-			for (KeyGroupsStateHandle keyGroupsStateHandle : keyGroupsStateHandles) {
-				if (keyGroupsStateHandle != null) {
-					this.currentKeyGroupsStateHandle = keyGroupsStateHandle;
+			for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
+				if (keyedStateHandle != null) {
+
+					if (!(keyedStateHandle instanceof KeyGroupsStateHandle)) {
+						throw new IllegalStateException("Unexpected state handle type, " +
+								"expected: " + KeyGroupsStateHandle.class +
+								", but found: " + keyedStateHandle.getClass());
+					}
+					this.currentKeyGroupsStateHandle = (KeyGroupsStateHandle) keyedStateHandle;
 					restoreKeyGroupsInStateHandle();
 				}
 			}
@@ -761,6 +769,12 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		private void restoreKVStateData() throws IOException, RocksDBException {
 			//for all key-groups in the current state handle...
 			for (Tuple2<Integer, Long> keyGroupOffset : currentKeyGroupsStateHandle.getGroupRangeOffsets()) {
+				int keyGroup = keyGroupOffset.f0;
+
+				// Check that restored key groups all belong to the backend
+				Preconditions.checkState(rocksDBKeyedStateBackend.getKeyGroupRange().contains(keyGroup),
+					"The key group must belong to the backend");
+
 				long offset = keyGroupOffset.f1;
 				//not empty key-group?
 				if (0L != offset) {
@@ -1143,15 +1157,25 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * For backwards compatibility, remove again later!
 	 */
 	@Deprecated
-	private void restoreOldSavepointKeyedState(Collection<KeyGroupsStateHandle> restoreState) throws Exception {
+	private void restoreOldSavepointKeyedState(Collection<KeyedStateHandle> restoreState) throws Exception {
 
 		if (restoreState.isEmpty()) {
 			return;
 		}
 
 		Preconditions.checkState(1 == restoreState.size(), "Only one element expected here.");
+
+		KeyedStateHandle keyedStateHandle = restoreState.iterator().next();
+		if (!(keyedStateHandle instanceof MigrationKeyGroupStateHandle)) {
+			throw new IllegalStateException("Unexpected state handle type, " +
+					"expected: " + MigrationKeyGroupStateHandle.class +
+					", but found: " + keyedStateHandle.getClass());
+		}
+
+		MigrationKeyGroupStateHandle keyGroupStateHandle = (MigrationKeyGroupStateHandle) keyedStateHandle;
+
 		HashMap<String, RocksDBStateBackend.FinalFullyAsyncSnapshot> namedStates;
-		try (FSDataInputStream inputStream = restoreState.iterator().next().openInputStream()) {
+		try (FSDataInputStream inputStream = keyGroupStateHandle.openInputStream()) {
 			namedStates = InstantiationUtil.deserializeObject(inputStream, userCodeClassLoader);
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
index 90de7a6..ffe2ce2 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
@@ -42,6 +42,7 @@ import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
@@ -343,7 +344,7 @@ public class RocksDBAsyncSnapshotTest {
 			StringSerializer.INSTANCE,
 			new ValueStateDescriptor<>("foobar", String.class));
 
-		RunnableFuture<KeyGroupsStateHandle> snapshotFuture = keyedStateBackend.snapshot(
+		RunnableFuture<KeyedStateHandle> snapshotFuture = keyedStateBackend.snapshot(
 			checkpointId, timestamp, checkpointStreamFactory, CheckpointOptions.forFullCheckpoint());
 
 		try {

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
index 708613b..d95a9b4 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.StateBackendTestBase;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
@@ -172,7 +173,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testRunningSnapshotAfterBackendClosed() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
 			CheckpointOptions.forFullCheckpoint());
 
 		RocksDB spyDB = keyedStateBackend.db;
@@ -210,7 +211,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testReleasingSnapshotAfterBackendClosed() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory,
 			CheckpointOptions.forFullCheckpoint());
 
 		RocksDB spyDB = keyedStateBackend.db;
@@ -239,7 +240,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testDismissingSnapshot() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
 		snapshot.cancel(true);
 		verifyRocksObjectsReleased();
 	}
@@ -247,7 +248,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testDismissingSnapshotNotRunnable() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
 		snapshot.cancel(true);
 		Thread asyncSnapshotThread = new Thread(snapshot);
 		asyncSnapshotThread.start();
@@ -264,7 +265,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testCompletingSnapshot() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
 		Thread asyncSnapshotThread = new Thread(snapshot);
 		asyncSnapshotThread.start();
 		waiter.await(); // wait for snapshot to run
@@ -272,10 +273,10 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 		runStateUpdates();
 		blocker.trigger(); // allow checkpointing to start writing
 		waiter.await(); // wait for snapshot stream writing to run
-		KeyGroupsStateHandle keyGroupsStateHandle = snapshot.get();
-		assertNotNull(keyGroupsStateHandle);
-		assertTrue(keyGroupsStateHandle.getStateSize() > 0);
-		assertEquals(2, keyGroupsStateHandle.getNumberOfKeyGroups());
+		KeyedStateHandle keyedStateHandle = snapshot.get();
+		assertNotNull(keyedStateHandle);
+		assertTrue(keyedStateHandle.getStateSize() > 0);
+		assertEquals(2, keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups());
 		assertTrue(testStreamFactory.getLastCreatedStream().isClosed());
 		asyncSnapshotThread.join();
 		verifyRocksObjectsReleased();
@@ -284,7 +285,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 	@Test
 	public void testCancelRunningSnapshot() throws Exception {
 		setupRocksKeyedStateBackend();
-		RunnableFuture<KeyGroupsStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
+		RunnableFuture<KeyedStateHandle> snapshot = keyedStateBackend.snapshot(0L, 0L, testStreamFactory, CheckpointOptions.forFullCheckpoint());
 		Thread asyncSnapshotThread = new Thread(snapshot);
 		asyncSnapshotThread.start();
 		waiter.await(); // wait for snapshot to run

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
index f230bbc..dbe4230 100644
--- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
+++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPMigration12to13Test.java
@@ -26,7 +26,7 @@ import org.apache.flink.cep.nfa.NFA;
 import org.apache.flink.cep.nfa.compiler.NFACompiler;
 import org.apache.flink.cep.pattern.Pattern;
 import org.apache.flink.cep.pattern.conditions.SimpleCondition;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -128,8 +128,8 @@ public class CEPMigration12to13Test {
 		final OperatorStateHandles snapshot = new OperatorStateHandles(
 			(int) ois.readObject(),
 			(StreamStateHandle) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject()
 		);
@@ -243,8 +243,8 @@ public class CEPMigration12to13Test {
 		final OperatorStateHandles snapshot = new OperatorStateHandles(
 			(int) ois.readObject(),
 			(StreamStateHandle) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject()
 		);
@@ -363,8 +363,8 @@ public class CEPMigration12to13Test {
 		final OperatorStateHandles snapshot = new OperatorStateHandles(
 			(int) ois.readObject(),
 			(StreamStateHandle) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
-			(Collection<KeyGroupsStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
+			(Collection<KeyedStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject(),
 			(Collection<OperatorStateHandle>) ois.readObject()
 		);

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java b/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
index 9427f72..a4e3a2e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/migration/MigrationUtil.java
@@ -19,17 +19,17 @@
 package org.apache.flink.migration;
 
 import org.apache.flink.migration.state.MigrationKeyGroupStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 
 import java.util.Collection;
 
 public class MigrationUtil {
 
 	@SuppressWarnings("deprecation")
-	public static boolean isOldSavepointKeyedState(Collection<KeyGroupsStateHandle> keyGroupsStateHandles) {
-		return (keyGroupsStateHandles != null)
-				&& (keyGroupsStateHandles.size() == 1)
-				&& (keyGroupsStateHandles.iterator().next() instanceof MigrationKeyGroupStateHandle);
+	public static boolean isOldSavepointKeyedState(Collection<KeyedStateHandle> keyedStateHandles) {
+		return (keyedStateHandles != null)
+				&& (keyedStateHandles.size() == 1)
+				&& (keyedStateHandles.iterator().next() instanceof MigrationKeyGroupStateHandle);
 	}
 
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 3fda430..ac70e1a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
@@ -160,8 +161,8 @@ public class StateAssignmentOperation {
 		@SuppressWarnings("unchecked")
 		List<OperatorStateHandle>[] parallelOpStatesStream = new List[chainLength];
 
-		List<KeyGroupsStateHandle> parallelKeyedStatesBackend = new ArrayList<>(oldParallelism);
-		List<KeyGroupsStateHandle> parallelKeyedStateStream = new ArrayList<>(oldParallelism);
+		List<KeyedStateHandle> parallelKeyedStatesBackend = new ArrayList<>(oldParallelism);
+		List<KeyedStateHandle> parallelKeyedStateStream = new ArrayList<>(oldParallelism);
 
 		for (int p = 0; p < oldParallelism; ++p) {
 			SubtaskState subtaskState = taskState.getState(p);
@@ -173,12 +174,12 @@ public class StateAssignmentOperation {
 				collectParallelStatesByChainOperator(
 						parallelOpStatesStream, subtaskState.getRawOperatorState());
 
-				KeyGroupsStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
+				KeyedStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
 				if (null != keyedStateBackend) {
 					parallelKeyedStatesBackend.add(keyedStateBackend);
 				}
 
-				KeyGroupsStateHandle keyedStateStream = subtaskState.getRawKeyedState();
+				KeyedStateHandle keyedStateStream = subtaskState.getRawKeyedState();
 				if (null != keyedStateStream) {
 					parallelKeyedStateStream.add(keyedStateStream);
 				}
@@ -252,13 +253,13 @@ public class StateAssignmentOperation {
 					.getTaskVertices()[subTaskIdx]
 					.getCurrentExecutionAttempt();
 
-			List<KeyGroupsStateHandle> newKeyedStatesBackend;
-			List<KeyGroupsStateHandle> newKeyedStateStream;
+			List<KeyedStateHandle> newKeyedStatesBackend;
+			List<KeyedStateHandle> newKeyedStateStream;
 			if (oldParallelism == newParallelism) {
 				SubtaskState subtaskState = taskState.getState(subTaskIdx);
 				if (subtaskState != null) {
-					KeyGroupsStateHandle oldKeyedStatesBackend = subtaskState.getManagedKeyedState();
-					KeyGroupsStateHandle oldKeyedStatesStream = subtaskState.getRawKeyedState();
+					KeyedStateHandle oldKeyedStatesBackend = subtaskState.getManagedKeyedState();
+					KeyedStateHandle oldKeyedStatesStream = subtaskState.getRawKeyedState();
 					newKeyedStatesBackend = oldKeyedStatesBackend != null ? Collections.singletonList(
 							oldKeyedStatesBackend) : null;
 					newKeyedStateStream = oldKeyedStatesStream != null ? Collections.singletonList(
@@ -269,8 +270,8 @@ public class StateAssignmentOperation {
 				}
 			} else {
 				KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(subTaskIdx);
-				newKeyedStatesBackend = getKeyGroupsStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
-				newKeyedStateStream = getKeyGroupsStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
+				newKeyedStatesBackend = getKeyedStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
+				newKeyedStateStream = getKeyedStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
 			}
 
 			TaskStateHandles taskStateHandles = new TaskStateHandles(
@@ -290,19 +291,21 @@ public class StateAssignmentOperation {
 	 * <p>
 	 * <p>This is publicly visible to be used in tests.
 	 */
-	public static List<KeyGroupsStateHandle> getKeyGroupsStateHandles(
-			Collection<KeyGroupsStateHandle> allKeyGroupsHandles,
-			KeyGroupRange subtaskKeyGroupIds) {
+	public static List<KeyedStateHandle> getKeyedStateHandles(
+			Collection<? extends KeyedStateHandle> keyedStateHandles,
+			KeyGroupRange subtaskKeyGroupRange) {
 
-		List<KeyGroupsStateHandle> subtaskKeyGroupStates = new ArrayList<>();
+		List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>();
 
-		for (KeyGroupsStateHandle storedKeyGroup : allKeyGroupsHandles) {
-			KeyGroupsStateHandle intersection = storedKeyGroup.getKeyGroupIntersection(subtaskKeyGroupIds);
-			if (intersection.getNumberOfKeyGroups() > 0) {
-				subtaskKeyGroupStates.add(intersection);
+		for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
+			KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange);
+
+			if (intersectedKeyedStateHandle != null) {
+				subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
 			}
 		}
-		return subtaskKeyGroupStates;
+
+		return subtaskKeyedStateHandles;
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
index 1393e32..9e195b1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -19,7 +19,7 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
@@ -56,12 +56,12 @@ public class SubtaskState implements StateObject {
 	/**
 	 * Snapshot from {@link org.apache.flink.runtime.state.KeyedStateBackend}.
 	 */
-	private final KeyGroupsStateHandle managedKeyedState;
+	private final KeyedStateHandle managedKeyedState;
 
 	/**
 	 * Snapshot written using {@link org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream}.
 	 */
-	private final KeyGroupsStateHandle rawKeyedState;
+	private final KeyedStateHandle rawKeyedState;
 
 	/**
 	 * The state size. This is also part of the deserialized state handle.
@@ -74,8 +74,8 @@ public class SubtaskState implements StateObject {
 			ChainedStateHandle<StreamStateHandle> legacyOperatorState,
 			ChainedStateHandle<OperatorStateHandle> managedOperatorState,
 			ChainedStateHandle<OperatorStateHandle> rawOperatorState,
-			KeyGroupsStateHandle managedKeyedState,
-			KeyGroupsStateHandle rawKeyedState) {
+			KeyedStateHandle managedKeyedState,
+			KeyedStateHandle rawKeyedState) {
 
 		this.legacyOperatorState = checkNotNull(legacyOperatorState, "State");
 		this.managedOperatorState = managedOperatorState;
@@ -114,11 +114,11 @@ public class SubtaskState implements StateObject {
 		return rawOperatorState;
 	}
 
-	public KeyGroupsStateHandle getManagedKeyedState() {
+	public KeyedStateHandle getManagedKeyedState() {
 		return managedKeyedState;
 	}
 
-	public KeyGroupsStateHandle getRawKeyedState() {
+	public KeyedStateHandle getRawKeyedState() {
 		return rawKeyedState;
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
index ba1949a..44461d8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
@@ -155,11 +156,11 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 			serializeOperatorStateHandle(stateHandle, dos);
 		}
 
-		KeyGroupsStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
-		serializeKeyGroupStateHandle(keyedStateBackend, dos);
+		KeyedStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
+		serializeKeyedStateHandle(keyedStateBackend, dos);
 
-		KeyGroupsStateHandle keyedStateStream = subtaskState.getRawKeyedState();
-		serializeKeyGroupStateHandle(keyedStateStream, dos);
+		KeyedStateHandle keyedStateStream = subtaskState.getRawKeyedState();
+		serializeKeyedStateHandle(keyedStateStream, dos);
 	}
 
 	private static SubtaskState deserializeSubtaskState(DataInputStream dis) throws IOException {
@@ -188,9 +189,9 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 			operatorStateStream.add(streamStateHandle);
 		}
 
-		KeyGroupsStateHandle keyedStateBackend = deserializeKeyGroupStateHandle(dis);
+		KeyedStateHandle keyedStateBackend = deserializeKeyedStateHandle(dis);
 
-		KeyGroupsStateHandle keyedStateStream = deserializeKeyGroupStateHandle(dis);
+		KeyedStateHandle keyedStateStream = deserializeKeyedStateHandle(dis);
 
 		ChainedStateHandle<StreamStateHandle> nonPartitionableStateChain =
 				new ChainedStateHandle<>(nonPartitionableState);
@@ -209,23 +210,27 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 				keyedStateStream);
 	}
 
-	private static void serializeKeyGroupStateHandle(
-			KeyGroupsStateHandle stateHandle, DataOutputStream dos) throws IOException {
+	private static void serializeKeyedStateHandle(
+			KeyedStateHandle stateHandle, DataOutputStream dos) throws IOException {
+
+		if (stateHandle == null) {
+			dos.writeByte(NULL_HANDLE);
+		} else if (stateHandle instanceof KeyGroupsStateHandle) {
+			KeyGroupsStateHandle keyGroupsStateHandle = (KeyGroupsStateHandle) stateHandle;
 
-		if (stateHandle != null) {
 			dos.writeByte(KEY_GROUPS_HANDLE);
-			dos.writeInt(stateHandle.getGroupRangeOffsets().getKeyGroupRange().getStartKeyGroup());
-			dos.writeInt(stateHandle.getNumberOfKeyGroups());
-			for (int keyGroup : stateHandle.keyGroups()) {
-				dos.writeLong(stateHandle.getOffsetForKeyGroup(keyGroup));
+			dos.writeInt(keyGroupsStateHandle.getKeyGroupRange().getStartKeyGroup());
+			dos.writeInt(keyGroupsStateHandle.getKeyGroupRange().getNumberOfKeyGroups());
+			for (int keyGroup : keyGroupsStateHandle.getKeyGroupRange()) {
+				dos.writeLong(keyGroupsStateHandle.getOffsetForKeyGroup(keyGroup));
 			}
-			serializeStreamStateHandle(stateHandle.getDelegateStateHandle(), dos);
+			serializeStreamStateHandle(keyGroupsStateHandle.getDelegateStateHandle(), dos);
 		} else {
-			dos.writeByte(NULL_HANDLE);
+			throw new IllegalStateException("Unknown KeyedStateHandle type: " + stateHandle.getClass());
 		}
 	}
 
-	private static KeyGroupsStateHandle deserializeKeyGroupStateHandle(DataInputStream dis) throws IOException {
+	private static KeyedStateHandle deserializeKeyedStateHandle(DataInputStream dis) throws IOException {
 		final int type = dis.readByte();
 		if (NULL_HANDLE == type) {
 			return null;
@@ -237,11 +242,12 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 			for (int i = 0; i < numKeyGroups; ++i) {
 				offsets[i] = dis.readLong();
 			}
-			KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, offsets);
+			KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(
+				keyGroupRange, offsets);
 			StreamStateHandle stateHandle = deserializeStreamStateHandle(dis);
 			return new KeyGroupsStateHandle(keyGroupRangeOffsets, stateHandle);
 		} else {
-			throw new IllegalStateException("Reading invalid KeyGroupsStateHandle, type: " + type);
+			throw new IllegalStateException("Reading invalid KeyedStateHandle, type: " + type);
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
index e6e7b23..e86f1f8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -61,7 +61,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  * @param <K> Type of the key by which state is keyed.
  */
 public abstract class AbstractKeyedStateBackend<K>
-		implements KeyedStateBackend<K>, Snapshotable<KeyGroupsStateHandle>, Closeable {
+		implements KeyedStateBackend<K>, Snapshotable<KeyedStateHandle>, Closeable {
 
 	/** {@link TypeSerializer} for our key. */
 	protected final TypeSerializer<K> keySerializer;

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
index b454e42..bad7fd4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
@@ -29,7 +29,7 @@ import java.io.IOException;
  * consists of a range of key group snapshots. A key group is subset of the available
  * key space. The key groups are identified by their key group indices.
  */
-public class KeyGroupsStateHandle implements StreamStateHandle {
+public class KeyGroupsStateHandle implements StreamStateHandle, KeyedStateHandle {
 
 	private static final long serialVersionUID = -8070326169926626355L;
 
@@ -54,20 +54,18 @@ public class KeyGroupsStateHandle implements StreamStateHandle {
 
 	/**
 	 *
-	 * @return iterable over the key-group range for the key-group state referenced by this handle
+	 * @return the internal key-group range to offsets metadata
 	 */
-	public Iterable<Integer> keyGroups() {
-		return groupRangeOffsets.getKeyGroupRange();
+	public KeyGroupRangeOffsets getGroupRangeOffsets() {
+		return groupRangeOffsets;
 	}
 
-
 	/**
 	 *
-	 * @param keyGroupId the id of a key-group
-	 * @return true if the provided key-group id is contained in the key-group range of this handle
+	 * @return The handle to the actual states
 	 */
-	public boolean containsKeyGroup(int keyGroupId) {
-		return groupRangeOffsets.getKeyGroupRange().contains(keyGroupId);
+	public StreamStateHandle getDelegateStateHandle() {
+		return stateHandle;
 	}
 
 	/**
@@ -85,24 +83,13 @@ public class KeyGroupsStateHandle implements StreamStateHandle {
 	 * @return key-group state over a range that is the intersection between this handle's key-group range and the
 	 *          provided key-group range.
 	 */
-	public KeyGroupsStateHandle getKeyGroupIntersection(KeyGroupRange keyGroupRange) {
+	public KeyGroupsStateHandle getIntersection(KeyGroupRange keyGroupRange) {
 		return new KeyGroupsStateHandle(groupRangeOffsets.getIntersection(keyGroupRange), stateHandle);
 	}
 
-	/**
-	 *
-	 * @return the internal key-group range to offsets metadata
-	 */
-	public KeyGroupRangeOffsets getGroupRangeOffsets() {
-		return groupRangeOffsets;
-	}
-
-	/**
-	 *
-	 * @return number of key-groups in the key-group range of this handle
-	 */
-	public int getNumberOfKeyGroups() {
-		return groupRangeOffsets.getKeyGroupRange().getNumberOfKeyGroups();
+	@Override
+	public KeyGroupRange getKeyGroupRange() {
+		return groupRangeOffsets.getKeyGroupRange();
 	}
 
 	@Override
@@ -120,10 +107,6 @@ public class KeyGroupsStateHandle implements StreamStateHandle {
 		return stateHandle.openInputStream();
 	}
 
-	public StreamStateHandle getDelegateStateHandle() {
-		return stateHandle;
-	}
-
 	@Override
 	public boolean equals(Object o) {
 		if (this == o) {

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
new file mode 100644
index 0000000..dc9c97d
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateHandle.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+/**
+ * Base for the handles of the checkpointed states in keyed streams. When
+ * recovering from failures, the handle will be passed to all tasks whose key
+ * group ranges overlap with it.
+ */
+public interface KeyedStateHandle extends StateObject {
+
+	/**
+	 * Returns the range of the key groups contained in the state.
+	 */
+	KeyGroupRange getKeyGroupRange();
+
+	/**
+	 * Returns a state over a range that is the intersection between this
+	 * handle's key-group range and the provided key-group range.
+	 *
+	 * @param keyGroupRange The key group range to intersect with
+	 */
+	KeyedStateHandle getIntersection(KeyGroupRange keyGroupRange);
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
index 886d214..d82af72 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
@@ -27,9 +27,11 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
+import java.util.List;
 import java.util.NoSuchElementException;
 
 /**
@@ -55,7 +57,7 @@ public class StateInitializationContextImpl implements StateInitializationContex
 			boolean restored,
 			OperatorStateStore operatorStateStore,
 			KeyedStateStore keyedStateStore,
-			Collection<KeyGroupsStateHandle> keyGroupsStateHandles,
+			Collection<KeyedStateHandle> keyedStateHandles,
 			Collection<OperatorStateHandle> operatorStateHandles,
 			CloseableRegistry closableRegistry) {
 
@@ -64,7 +66,7 @@ public class StateInitializationContextImpl implements StateInitializationContex
 		this.operatorStateStore = operatorStateStore;
 		this.keyedStateStore = keyedStateStore;
 		this.operatorStateHandles = operatorStateHandles;
-		this.keyGroupsStateHandles = keyGroupsStateHandles;
+		this.keyGroupsStateHandles = transform(keyedStateHandles);
 
 		this.keyedStateIterable = keyGroupsStateHandles == null ?
 				null
@@ -136,6 +138,26 @@ public class StateInitializationContextImpl implements StateInitializationContex
 		IOUtils.closeQuietly(closableRegistry);
 	}
 
+	private static Collection<KeyGroupsStateHandle> transform(Collection<KeyedStateHandle> keyedStateHandles) {
+		if (keyedStateHandles == null) {
+			return null;
+		}
+
+		List<KeyGroupsStateHandle> keyGroupsStateHandles = new ArrayList<>();
+
+		for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
+			if (! (keyedStateHandle instanceof KeyGroupsStateHandle)) {
+				throw new IllegalStateException("Unexpected state handle type, " +
+					"expected: " + KeyGroupsStateHandle.class +
+					", but found: " + keyedStateHandle.getClass() + ".");
+			}
+
+			keyGroupsStateHandles.add((KeyGroupsStateHandle) keyedStateHandle);
+		}
+
+		return keyGroupsStateHandles;
+	}
+
 	private static class KeyGroupStreamIterator
 			extends AbstractStateStreamIterator<KeyGroupStatePartitionStreamProvider, KeyGroupsStateHandle> {
 
@@ -159,7 +181,7 @@ public class StateInitializationContextImpl implements StateInitializationContex
 
 			while (stateHandleIterator.hasNext()) {
 				currentStateHandle = stateHandleIterator.next();
-				if (currentStateHandle.getNumberOfKeyGroups() > 0) {
+				if (currentStateHandle.getKeyGroupRange().getNumberOfKeyGroups() > 0) {
 					currentOffsetsIterator = currentStateHandle.getGroupRangeOffsets().iterator();
 
 					return true;

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
index 96edccb..5db0138 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
@@ -109,15 +109,17 @@ public class StateSnapshotContextSynchronousImpl implements StateSnapshotContext
 		return operatorStateCheckpointOutputStream;
 	}
 
-	public RunnableFuture<KeyGroupsStateHandle> getKeyedStateStreamFuture() throws IOException {
-		return closeAndUnregisterStreamToObtainStateHandle(keyedStateCheckpointOutputStream);
+	public RunnableFuture<KeyedStateHandle> getKeyedStateStreamFuture() throws IOException {
+		KeyGroupsStateHandle keyGroupsStateHandle = closeAndUnregisterStreamToObtainStateHandle(keyedStateCheckpointOutputStream);
+		return new DoneFuture<KeyedStateHandle>(keyGroupsStateHandle);
 	}
 
 	public RunnableFuture<OperatorStateHandle> getOperatorStateStreamFuture() throws IOException {
-		return closeAndUnregisterStreamToObtainStateHandle(operatorStateCheckpointOutputStream);
+		OperatorStateHandle operatorStateHandle = closeAndUnregisterStreamToObtainStateHandle(operatorStateCheckpointOutputStream);
+		return new DoneFuture<>(operatorStateHandle);
 	}
 
-	private <T extends StreamStateHandle> RunnableFuture<T> closeAndUnregisterStreamToObtainStateHandle(
+	private <T extends StreamStateHandle> T closeAndUnregisterStreamToObtainStateHandle(
 			NonClosingCheckpointOutputStream<T> stream) throws IOException {
 		if (null == stream) {
 			return null;
@@ -126,7 +128,7 @@ public class StateSnapshotContextSynchronousImpl implements StateSnapshotContext
 		closableRegistry.unregisterClosable(stream.getDelegate());
 
 		// for now we only support synchronous writing
-		return new DoneFuture<>(stream.closeAndGetHandle());
+		return stream.closeAndGetHandle();
 	}
 
 	private <T extends StreamStateHandle> void closeAndUnregisterStream(NonClosingCheckpointOutputStream<T> stream) throws IOException {

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
index 417a9dd..450413a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
@@ -40,10 +40,10 @@ public class TaskStateHandles implements Serializable {
 	private final ChainedStateHandle<StreamStateHandle> legacyOperatorState;
 
 	/** Collection of handles which represent the managed keyed state of the head operator */
-	private final Collection<KeyGroupsStateHandle> managedKeyedState;
+	private final Collection<KeyedStateHandle> managedKeyedState;
 
 	/** Collection of handles which represent the raw/streamed keyed state of the head operator */
-	private final Collection<KeyGroupsStateHandle> rawKeyedState;
+	private final Collection<KeyedStateHandle> rawKeyedState;
 
 	/** Outer list represents the operator chain, each collection holds handles for managed state of a single operator */
 	private final List<Collection<OperatorStateHandle>> managedOperatorState;
@@ -67,8 +67,8 @@ public class TaskStateHandles implements Serializable {
 			ChainedStateHandle<StreamStateHandle> legacyOperatorState,
 			List<Collection<OperatorStateHandle>> managedOperatorState,
 			List<Collection<OperatorStateHandle>> rawOperatorState,
-			Collection<KeyGroupsStateHandle> managedKeyedState,
-			Collection<KeyGroupsStateHandle> rawKeyedState) {
+			Collection<KeyedStateHandle> managedKeyedState,
+			Collection<KeyedStateHandle> rawKeyedState) {
 
 		this.legacyOperatorState = legacyOperatorState;
 		this.managedKeyedState = managedKeyedState;
@@ -82,11 +82,11 @@ public class TaskStateHandles implements Serializable {
 		return legacyOperatorState;
 	}
 
-	public Collection<KeyGroupsStateHandle> getManagedKeyedState() {
+	public Collection<KeyedStateHandle> getManagedKeyedState() {
 		return managedKeyedState;
 	}
 
-	public Collection<KeyGroupsStateHandle> getRawKeyedState() {
+	public Collection<KeyedStateHandle> getRawKeyedState() {
 		return rawKeyedState;
 	}
 
@@ -110,8 +110,8 @@ public class TaskStateHandles implements Serializable {
 		return out;
 	}
 
-	private static List<KeyGroupsStateHandle> transform(KeyGroupsStateHandle in) {
-		return in == null ? Collections.<KeyGroupsStateHandle>emptyList() : Collections.singletonList(in);
+	private static <T> List<T> transform(T in) {
+		return in == null ? Collections.<T>emptyList() : Collections.singletonList(in);
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 46ec5c2..a332d7d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -39,6 +39,7 @@ import org.apache.flink.migration.runtime.state.KvStateSnapshot;
 import org.apache.flink.migration.runtime.state.memory.MigrationRestoreSnapshot;
 import org.apache.flink.runtime.io.async.AbstractAsyncIOCallable;
 import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
+import org.apache.flink.migration.state.MigrationKeyGroupStateHandle;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
@@ -50,6 +51,7 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.internal.InternalAggregatingState;
@@ -223,7 +225,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 	@Override
 	@SuppressWarnings("unchecked")
-	public  RunnableFuture<KeyGroupsStateHandle> snapshot(
+	public  RunnableFuture<KeyedStateHandle> snapshot(
 			final long checkpointId,
 			final long timestamp,
 			final CheckpointStreamFactory streamFactory,
@@ -267,8 +269,8 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		//--------------------------------------------------- this becomes the end of sync part
 
 		// implementation of the async IO operation, based on FutureTask
-		final AbstractAsyncIOCallable<KeyGroupsStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
-				new AbstractAsyncIOCallable<KeyGroupsStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream>() {
+		final AbstractAsyncIOCallable<KeyedStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable =
+				new AbstractAsyncIOCallable<KeyedStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream>() {
 
 					AtomicBoolean open = new AtomicBoolean(false);
 
@@ -340,7 +342,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 					}
 				};
 
-		AsyncStoppableTaskWithCallback<KeyGroupsStateHandle> task = AsyncStoppableTaskWithCallback.from(ioCallable);
+		AsyncStoppableTaskWithCallback<KeyedStateHandle> task = AsyncStoppableTaskWithCallback.from(ioCallable);
 
 		if (!asynchronousSnapshots) {
 			task.run();
@@ -354,7 +356,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 	@SuppressWarnings("deprecation")
 	@Override
-	public void restore(Collection<KeyGroupsStateHandle> restoredState) throws Exception {
+	public void restore(Collection<KeyedStateHandle> restoredState) throws Exception {
 		LOG.info("Initializing heap keyed state backend from snapshot.");
 
 		if (LOG.isDebugEnabled()) {
@@ -369,19 +371,26 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@SuppressWarnings({"unchecked"})
-	private void restorePartitionedState(Collection<KeyGroupsStateHandle> state) throws Exception {
+	private void restorePartitionedState(Collection<KeyedStateHandle> state) throws Exception {
 
 		final Map<Integer, String> kvStatesById = new HashMap<>();
 		int numRegisteredKvStates = 0;
 		stateTables.clear();
 
-		for (KeyGroupsStateHandle keyGroupsHandle : state) {
+		for (KeyedStateHandle keyedStateHandle : state) {
 
-			if (keyGroupsHandle == null) {
+			if (keyedStateHandle == null) {
 				continue;
 			}
 
-			FSDataInputStream fsDataInputStream = keyGroupsHandle.openInputStream();
+			if (!(keyedStateHandle instanceof KeyGroupsStateHandle)) {
+				throw new IllegalStateException("Unexpected state handle type, " +
+						"expected: " + KeyGroupsStateHandle.class +
+						", but found: " + keyedStateHandle.getClass());
+			}
+
+			KeyGroupsStateHandle keyGroupsStateHandle = (KeyGroupsStateHandle) keyedStateHandle;
+			FSDataInputStream fsDataInputStream = keyGroupsStateHandle.openInputStream();
 			cancelStreamRegistry.registerClosable(fsDataInputStream);
 
 			try {
@@ -412,9 +421,13 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 					}
 				}
 
-				for (Tuple2<Integer, Long> groupOffset : keyGroupsHandle.getGroupRangeOffsets()) {
+				for (Tuple2<Integer, Long> groupOffset : keyGroupsStateHandle.getGroupRangeOffsets()) {
 					int keyGroupIndex = groupOffset.f0;
 					long offset = groupOffset.f1;
+
+					// Check that restored key groups all belong to the backend.
+					Preconditions.checkState(keyGroupRange.contains(keyGroupIndex), "The key group must belong to the backend.");
+
 					fsDataInputStream.seek(offset);
 
 					int writtenKeyGroupIndex = inView.readInt();
@@ -449,7 +462,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	@SuppressWarnings({"unchecked", "rawtypes", "DeprecatedIsStillUsed"})
 	@Deprecated
 	private void restoreOldSavepointKeyedState(
-			Collection<KeyGroupsStateHandle> stateHandles) throws IOException, ClassNotFoundException {
+			Collection<KeyedStateHandle> stateHandles) throws IOException, ClassNotFoundException {
 
 		if (stateHandles.isEmpty()) {
 			return;
@@ -457,8 +470,17 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		Preconditions.checkState(1 == stateHandles.size(), "Only one element expected here.");
 
+		KeyedStateHandle keyedStateHandle = stateHandles.iterator().next();
+		if (!(keyedStateHandle instanceof MigrationKeyGroupStateHandle)) {
+			throw new IllegalStateException("Unexpected state handle type, " +
+					"expected: " + MigrationKeyGroupStateHandle.class +
+					", but found " + keyedStateHandle.getClass());
+		}
+
+		MigrationKeyGroupStateHandle keyGroupStateHandle = (MigrationKeyGroupStateHandle) keyedStateHandle;
+
 		HashMap<String, KvStateSnapshot<K, ?, ?, ?>> namedStates;
-		try (FSDataInputStream inputStream = stateHandles.iterator().next().openInputStream()) {
+		try (FSDataInputStream inputStream = keyGroupStateHandle.openInputStream()) {
 			namedStates = InstantiationUtil.deserializeObject(inputStream, userCodeClassLoader);
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index d8bba59..117c70d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -41,6 +41,7 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
@@ -2346,13 +2347,13 @@ public class CheckpointCoordinatorTest {
 			ChainedStateHandle<StreamStateHandle> operatorState = taskStateHandles.getLegacyOperatorState();
 			List<Collection<OperatorStateHandle>> opStateBackend = taskStateHandles.getManagedOperatorState();
 			List<Collection<OperatorStateHandle>> opStateRaw = taskStateHandles.getRawOperatorState();
-			Collection<KeyGroupsStateHandle> keyGroupStateBackend = taskStateHandles.getManagedKeyedState();
-			Collection<KeyGroupsStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState();
+			Collection<KeyedStateHandle> keyedStateBackend = taskStateHandles.getManagedKeyedState();
+			Collection<KeyedStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState();
 
 			actualOpStatesBackend.add(opStateBackend);
 			actualOpStatesRaw.add(opStateRaw);
 			assertNull(operatorState);
-			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyGroupStateBackend);
+			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
 			compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
 		}
 		comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend);
@@ -2690,32 +2691,38 @@ public class CheckpointCoordinatorTest {
 
 			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
 					jobVertexID, keyGroupPartitions.get(i), false);
-			Collection<KeyGroupsStateHandle> actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
+			Collection<KeyedStateHandle> actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
 			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), actualPartitionedKeyGroupState);
 		}
 	}
 
 	public static void compareKeyedState(
 			Collection<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
-			Collection<KeyGroupsStateHandle> actualPartitionedKeyGroupState) throws Exception {
+			Collection<? extends KeyedStateHandle> actualPartitionedKeyGroupState) throws Exception {
 
 		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next();
-		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getNumberOfKeyGroups();
+		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
 		int actualTotalKeyGroups = 0;
-		for(KeyGroupsStateHandle keyGroupsStateHandle: actualPartitionedKeyGroupState) {
-			actualTotalKeyGroups += keyGroupsStateHandle.getNumberOfKeyGroups();
+		for(KeyedStateHandle keyedStateHandle: actualPartitionedKeyGroupState) {
+			assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle);
+
+			actualTotalKeyGroups += keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
 		}
 
 		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
 
 		try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.openInputStream()) {
-			for (int groupId : expectedHeadOpKeyGroupStateHandle.keyGroups()) {
+			for (int groupId : expectedHeadOpKeyGroupStateHandle.getKeyGroupRange()) {
 				long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
 				inputStream.seek(offset);
 				int expectedKeyGroupState =
 						InstantiationUtil.deserializeObject(inputStream, Thread.currentThread().getContextClassLoader());
-				for (KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) {
-					if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
+				for (KeyedStateHandle oneActualKeyedStateHandle : actualPartitionedKeyGroupState) {
+
+					assertTrue(oneActualKeyedStateHandle instanceof KeyGroupsStateHandle);
+
+					KeyGroupsStateHandle oneActualKeyGroupStateHandle = (KeyGroupsStateHandle) oneActualKeyedStateHandle;
+					if (oneActualKeyGroupStateHandle.getKeyGroupRange().contains(groupId)) {
 						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
 						try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
 							actualInputStream.seek(actualOffset);

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 18b07eb..7e0a7c1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -31,6 +31,7 @@ import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
@@ -68,7 +69,7 @@ public class CheckpointStateRestoreTest {
 			final ChainedStateHandle<StreamStateHandle> serializedState = CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject());
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
 			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
-			final KeyGroupsStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+			final KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
 
 			final JobID jid = new JobID();
 			final JobVertexID statefulId = new JobVertexID();

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
index 6ab8620..1ecb2e3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/MigrationV0ToV1Test.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.checkpoint.TaskState;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
@@ -58,6 +59,7 @@ import java.util.concurrent.ThreadLocalRandom;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 
 @SuppressWarnings("deprecation")
 public class MigrationV0ToV1Test {
@@ -154,9 +156,15 @@ public class MigrationV0ToV1Test {
 					}
 
 					//check keyed state
-					KeyGroupsStateHandle keyGroupsStateHandle = subtaskState.getManagedKeyedState();
+					KeyedStateHandle keyedStateHandle = subtaskState.getManagedKeyedState();
+
 					if (t % 3 != 0) {
-						assertEquals(1, keyGroupsStateHandle.getNumberOfKeyGroups());
+
+						assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle);
+
+						KeyGroupsStateHandle keyGroupsStateHandle = (KeyGroupsStateHandle) keyedStateHandle;
+
+						assertEquals(1, keyGroupsStateHandle.getKeyGroupRange().getNumberOfKeyGroups());
 						assertEquals(p, keyGroupsStateHandle.getGroupRangeOffsets().getKeyGroupRange().getStartKeyGroup());
 
 						ByteStreamStateHandle stateHandle =
@@ -172,7 +180,7 @@ public class MigrationV0ToV1Test {
 							assertEquals(p, data[1]);
 						}
 					} else {
-						assertEquals(null, keyGroupsStateHandle);
+						assertEquals(null, keyedStateHandle);
 					}
 				}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
index 0c4ed74..cee0b02 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
@@ -135,7 +135,7 @@ public class KeyedStateCheckpointOutputStreamTest {
 		int count = 0;
 		try (FSDataInputStream in = fullHandle.openInputStream()) {
 			DataInputView div = new DataInputViewStreamWrapper(in);
-			for (int kg : fullHandle.keyGroups()) {
+			for (int kg : fullHandle.getKeyGroupRange()) {
 				long off = fullHandle.getOffsetForKeyGroup(kg);
 				if (off >= 0) {
 					in.seek(off);
@@ -152,7 +152,7 @@ public class KeyedStateCheckpointOutputStreamTest {
 		int count = 0;
 		try (FSDataInputStream in = fullHandle.openInputStream()) {
 			DataInputView div = new DataInputViewStreamWrapper(in);
-			for (int kg : fullHandle.keyGroups()) {
+			for (int kg : fullHandle.getKeyGroupRange()) {
 				long off = fullHandle.getOffsetForKeyGroup(kg);
 				in.seek(off);
 				Assert.assertEquals(kg, div.readInt());

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 22bb715..ccc1eae 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -143,13 +143,13 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				env.getTaskKvStateRegistry());
 	}
 
-	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle state) throws Exception {
+	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyedStateHandle state) throws Exception {
 		return restoreKeyedBackend(keySerializer, state, new DummyEnvironment("test", 1, 0));
 	}
 
 	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(
 			TypeSerializer<K> keySerializer,
-			KeyGroupsStateHandle state,
+			KeyedStateHandle state,
 			Environment env) throws Exception {
 		return restoreKeyedBackend(
 				keySerializer,
@@ -163,7 +163,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> state,
+			List<KeyedStateHandle> state,
 			Environment env) throws Exception {
 
 		AbstractKeyedStateBackend<K> backend = getStateBackend().createKeyedStateBackend(
@@ -436,7 +436,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		backend.setCurrentKey(2);
 		state.update(new TestPojo("u2", 2));
 
-		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(
+		KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
 				682375462378L,
 				2,
 				streamFactory,
@@ -497,7 +497,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		backend.setCurrentKey(2);
 		state.update(new TestPojo("u2", 2));
 
-		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(
+		KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
 				682375462378L,
 				2,
 				streamFactory,
@@ -524,7 +524,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		// update to test state backends that eagerly serialize, such as RocksDB
 		state.update(new TestPojo("u1", 11));
 
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(
 				682375462378L,
 				2,
 				streamFactory,
@@ -585,7 +585,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		backend.setCurrentKey(2);
 		state.update(new TestPojo("u2", 2));
 
-		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(
+		KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
 				682375462378L,
 				2,
 				streamFactory,
@@ -611,7 +611,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		// update to test state backends that eagerly serialize, such as RocksDB
 		state.update(new TestPojo("u1", 11));
 
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(
 				682375462378L,
 				2,
 				streamFactory,
@@ -670,7 +670,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -681,7 +681,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.update("u3");
 
 		// draw another snapshot
-		KeyGroupsStateHandle snapshot2 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot2 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -880,7 +880,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals(13, (int) state2.value());
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		backend.dispose();
 		backend = restoreKeyedBackend(
@@ -952,7 +952,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals(42L, (long) state.value());
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		backend.dispose();
 		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
@@ -997,7 +997,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("1", joiner.join(getSerializedList(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -1008,7 +1008,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.add("u3");
 
 		// draw another snapshot
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -1091,7 +1091,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -1102,7 +1102,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.add("u3");
 
 		// draw another snapshot
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -1188,7 +1188,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("Fold-Initial:,1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -1200,7 +1200,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.add(103);
 
 		// draw another snapshot
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -1287,7 +1287,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				getSerializedMap(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
 
 		// draw a snapshot
-		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -1299,7 +1299,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.putAll(new HashMap<Integer, String>() {{ put(1031, "1031"); put(1032, "1032"); }});
 
 		// draw another snapshot
-		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -1606,13 +1606,13 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.update("ShouldBeInSecondHalf");
 
 
-		KeyGroupsStateHandle snapshot = FutureUtil.runIfNotDoneAndGet(backend.snapshot(0, 0, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot = FutureUtil.runIfNotDoneAndGet(backend.snapshot(0, 0, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
-		List<KeyGroupsStateHandle> firstHalfKeyGroupStates = StateAssignmentOperation.getKeyGroupsStateHandles(
+		List<KeyedStateHandle> firstHalfKeyGroupStates = StateAssignmentOperation.getKeyedStateHandles(
 				Collections.singletonList(snapshot),
 				KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 0));
 
-		List<KeyGroupsStateHandle> secondHalfKeyGroupStates = StateAssignmentOperation.getKeyGroupsStateHandles(
+		List<KeyedStateHandle> secondHalfKeyGroupStates = StateAssignmentOperation.getKeyedStateHandles(
 				Collections.singletonList(snapshot),
 				KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 1));
 
@@ -1672,7 +1672,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.update("2");
 
 			// draw a snapshot
-			KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+			KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -1723,7 +1723,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.add("2");
 
 			// draw a snapshot
-			KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+			KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -1776,7 +1776,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.add("2");
 
 			// draw a snapshot
-			KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+			KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -1827,7 +1827,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.put("2", "Second");
 
 			// draw a snapshot
-			KeyGroupsStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
+			KeyedStateHandle snapshot1 = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -2093,7 +2093,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				eq(env.getJobID()), eq(env.getJobVertexId()), eq(expectedKeyGroupRange), eq("banana"), any(KvStateID.class));
 
 
-		KeyGroupsStateHandle snapshot = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
+		KeyedStateHandle snapshot = FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forFullCheckpoint()));
 
 		backend.dispose();
 
@@ -2124,7 +2124,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 
 			// draw a snapshot
-			KeyGroupsStateHandle snapshot =
+			KeyedStateHandle snapshot =
 					FutureUtil.runIfNotDoneAndGet(backend.snapshot(682375462379L, 1, streamFactory, CheckpointOptions.forFullCheckpoint()));
 			assertNull(snapshot);
 			backend.dispose();
@@ -2152,7 +2152,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		streamFactory.setWaiterLatch(waiter);
 
 		AbstractKeyedStateBackend<Integer> backend = null;
-		KeyGroupsStateHandle stateHandle = null;
+		KeyedStateHandle stateHandle = null;
 
 		try {
 			backend = createKeyedBackend(IntSerializer.INSTANCE);
@@ -2167,7 +2167,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				valueState.update(i);
 			}
 
-			RunnableFuture<KeyGroupsStateHandle> snapshot =
+			RunnableFuture<KeyedStateHandle> snapshot =
 					backend.snapshot(0L, 0L, streamFactory, CheckpointOptions.forFullCheckpoint());
 			Thread runner = new Thread(snapshot);
 			runner.start();
@@ -2249,7 +2249,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				valueState.update(i);
 			}
 
-			RunnableFuture<KeyGroupsStateHandle> snapshot =
+			RunnableFuture<KeyedStateHandle> snapshot =
 					backend.snapshot(0L, 0L, streamFactory, CheckpointOptions.forFullCheckpoint());
 
 			Thread runner = new Thread(snapshot);
@@ -2367,7 +2367,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		}
 	}
 
-	private KeyGroupsStateHandle runSnapshot(RunnableFuture<KeyGroupsStateHandle> snapshotRunnableFuture) throws Exception {
+	private KeyedStateHandle runSnapshot(RunnableFuture<KeyedStateHandle> snapshotRunnableFuture) throws Exception {
 		if(!snapshotRunnableFuture.isDone()) {
 			Thread runner = new Thread(snapshotRunnableFuture);
 			runner.start();

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
index da0666a..3754d63 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.internal.InternalListState;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
@@ -63,7 +64,7 @@ public class HeapKeyedStateBackendSnapshotMigrationTest extends HeapStateBackend
 			try (BufferedInputStream bis = new BufferedInputStream((new FileInputStream(resource.getFile())))) {
 				stateHandle = InstantiationUtil.deserializeObject(bis, Thread.currentThread().getContextClassLoader());
 			}
-			keyedBackend.restore(Collections.singleton(stateHandle));
+			keyedBackend.restore(Collections.<KeyedStateHandle>singleton(stateHandle));
 			final ListStateDescriptor<Long> stateDescr = new ListStateDescriptor<>("my-state", Long.class);
 			stateDescr.initializeSerializerUnlessSet(new ExecutionConfig());
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cd552741/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index e40a59b..a6a89b5 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.streaming.api.operators;
 
-import java.io.IOException;
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
 import org.apache.flink.annotation.PublicEvolving;
@@ -47,9 +46,9 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
 import org.apache.flink.runtime.state.KeyGroupsList;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateInitializationContext;
@@ -70,6 +69,7 @@ import org.apache.flink.util.OutputTag;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.Collection;
 import java.util.ConcurrentModificationException;
@@ -198,7 +198,7 @@ public abstract class AbstractStreamOperator<OUT>
 	@Override
 	public final void initializeState(OperatorStateHandles stateHandles) throws Exception {
 
-		Collection<KeyGroupsStateHandle> keyedStateHandlesRaw = null;
+		Collection<KeyedStateHandle> keyedStateHandlesRaw = null;
 		Collection<OperatorStateHandle> operatorStateHandlesRaw = null;
 		Collection<OperatorStateHandle> operatorStateHandlesBackend = null;
 
@@ -473,6 +473,7 @@ public abstract class AbstractStreamOperator<OUT>
 			// and then initialize the timer services
 			for (KeyGroupStatePartitionStreamProvider streamProvider : context.getRawKeyedStateInputs()) {
 				int keyGroupIdx = streamProvider.getKeyGroupId();
+
 				checkArgument(localKeyGroupRange.contains(keyGroupIdx),
 					"Key Group " + keyGroupIdx + " does not belong to the local range.");