You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sr...@apache.org on 2017/08/15 12:57:09 UTC

[1/7] flink git commit: [FLINK-7268] Add delaying executor in *EventTimeWindowCheckpointingITCase

Repository: flink
Updated Branches:
  refs/heads/master 3b0321aee -> d29bed383


[FLINK-7268] Add delaying executor in *EventTimeWindowCheckpointingITCase

This helps tease out races, for example the recently discovered one in the
cleanup of incremental state handles at the SharedStateRegistry.

(cherry picked from commit d7683cc)


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/d29bed38
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/d29bed38
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/d29bed38

Branch: refs/heads/master
Commit: d29bed38311f7a01d2241fbf8fa26eac7f012f53
Parents: 91a4b27
Author: Aljoscha Krettek <al...@gmail.com>
Authored: Fri Jul 28 15:01:35 2017 +0200
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Tue Aug 15 14:56:54 2017 +0200

----------------------------------------------------------------------
 ...bstractEventTimeWindowCheckpointingITCase.java | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/d29bed38/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java
index c525a37..4d5fa71 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java
@@ -31,6 +31,8 @@ import org.apache.flink.configuration.HighAvailabilityOptions;
 import org.apache.flink.configuration.TaskManagerOptions;
 import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.highavailability.HighAvailabilityServices;
+import org.apache.flink.runtime.highavailability.HighAvailabilityServicesUtils;
 import org.apache.flink.runtime.minicluster.LocalFlinkMiniCluster;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointListener;
@@ -62,6 +64,9 @@ import java.io.IOException;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
 
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 import static org.apache.flink.test.checkpointing.AbstractEventTimeWindowCheckpointingITCase.StateBackendEnum.ROCKSDB_INCREMENTAL_ZK;
@@ -139,7 +144,18 @@ public abstract class AbstractEventTimeWindowCheckpointingITCase extends TestLog
 			config.setString(HighAvailabilityOptions.HA_STORAGE_PATH, haDir.toURI().toString());
 		}
 
-		cluster = new LocalFlinkMiniCluster(config, false);
+		// purposefully delay in the executor to tease out races
+		final ScheduledExecutorService executor = Executors.newScheduledThreadPool(10);
+		HighAvailabilityServices haServices = HighAvailabilityServicesUtils.createAvailableOrEmbeddedServices(
+			config,
+			new Executor() {
+				@Override
+				public void execute(Runnable command) {
+					executor.schedule(command, 500, MILLISECONDS);
+				}
+			});
+
+		cluster = new LocalFlinkMiniCluster(config, haServices, false);
 		cluster.start();
 
 		env = new TestStreamEnvironment(cluster, PARALLELISM);


[6/7] flink git commit: [FLINK-7213] Introduce state management by OperatorID in TaskManager

Posted by sr...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index 344b340..88b95f5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -23,14 +23,15 @@ import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
-import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.TestLogger;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.powermock.core.classloader.annotations.PrepareForTest;
@@ -42,8 +43,8 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.mockito.Matchers.anyInt;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -89,29 +90,26 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		assertFalse(pendingCheckpoint.isDiscarded());
 
 		final long checkpointId = coord.getPendingCheckpoints().keySet().iterator().next();
-		
-		SubtaskState subtaskState = mock(SubtaskState.class);
+
 
 		StreamStateHandle legacyHandle = mock(StreamStateHandle.class);
-		ChainedStateHandle<StreamStateHandle> chainedLegacyHandle = mock(ChainedStateHandle.class);
-		when(chainedLegacyHandle.get(anyInt())).thenReturn(legacyHandle);
-		when(subtaskState.getLegacyOperatorState()).thenReturn(chainedLegacyHandle);
+		KeyedStateHandle managedKeyedHandle = mock(KeyedStateHandle.class);
+		KeyedStateHandle rawKeyedHandle = mock(KeyedStateHandle.class);
+		OperatorStateHandle managedOpHandle = mock(OperatorStateHandle.class);
+		OperatorStateHandle rawOpHandle = mock(OperatorStateHandle.class);
 
-		OperatorStateHandle managedHandle = mock(OperatorStateHandle.class);
-		ChainedStateHandle<OperatorStateHandle> chainedManagedHandle = mock(ChainedStateHandle.class);
-		when(chainedManagedHandle.get(anyInt())).thenReturn(managedHandle);
-		when(subtaskState.getManagedOperatorState()).thenReturn(chainedManagedHandle);
+		final OperatorSubtaskState operatorSubtaskState = spy(new OperatorSubtaskState(
+			legacyHandle,
+			managedOpHandle,
+			rawOpHandle,
+			managedKeyedHandle,
+			rawKeyedHandle));
 
-		OperatorStateHandle rawHandle = mock(OperatorStateHandle.class);
-		ChainedStateHandle<OperatorStateHandle> chainedRawHandle = mock(ChainedStateHandle.class);
-		when(chainedRawHandle.get(anyInt())).thenReturn(rawHandle);
-		when(subtaskState.getRawOperatorState()).thenReturn(chainedRawHandle);
+		TaskStateSnapshot subtaskState = spy(new TaskStateSnapshot());
+		subtaskState.putSubtaskStateByOperatorID(new OperatorID(), operatorSubtaskState);
+
+		when(subtaskState.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(vertex.getJobvertexId()))).thenReturn(operatorSubtaskState);
 
-		KeyedStateHandle managedKeyedHandle = mock(KeyedStateHandle.class);
-		when(subtaskState.getRawKeyedState()).thenReturn(managedKeyedHandle);
-		KeyedStateHandle managedRawHandle = mock(KeyedStateHandle.class);
-		when(subtaskState.getManagedKeyedState()).thenReturn(managedRawHandle);
-		
 		AcknowledgeCheckpoint acknowledgeMessage = new AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId, new CheckpointMetrics(), subtaskState);
 		
 		try {
@@ -126,11 +124,12 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		assertTrue(pendingCheckpoint.isDiscarded());
 
 		// make sure that the subtask state has been discarded after we could not complete it.
-		verify(subtaskState.getLegacyOperatorState().get(0)).discardState();
-		verify(subtaskState.getManagedOperatorState().get(0)).discardState();
-		verify(subtaskState.getRawOperatorState().get(0)).discardState();
-		verify(subtaskState.getManagedKeyedState()).discardState();
-		verify(subtaskState.getRawKeyedState()).discardState();
+		verify(operatorSubtaskState).discardState();
+		verify(operatorSubtaskState.getLegacyOperatorState()).discardState();
+		verify(operatorSubtaskState.getManagedOperatorState().iterator().next()).discardState();
+		verify(operatorSubtaskState.getRawOperatorState().iterator().next()).discardState();
+		verify(operatorSubtaskState.getManagedKeyedState().iterator().next()).discardState();
+		verify(operatorSubtaskState.getRawKeyedState().iterator().next()).discardState();
 	}
 
 	private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore {

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index cb92df6..d9af879 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -44,7 +44,6 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
@@ -93,7 +92,6 @@ import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -102,7 +100,6 @@ import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
-import static org.mockito.Mockito.withSettings;
 
 /**
  * Tests for the checkpoint coordinator.
@@ -555,31 +552,29 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertFalse(checkpoint.isDiscarded());
 			assertFalse(checkpoint.isFullyAcknowledged());
 
-			OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId());
-			OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId());
-
-			Map<OperatorID, OperatorState> operatorStates = checkpoint.getOperatorStates();
-
-			operatorStates.put(opID1, new SpyInjectingOperatorState(
-				opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism()));
-			operatorStates.put(opID2, new SpyInjectingOperatorState(
-				opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex2.getMaxParallelism()));
-
 			// check that the vertices received the trigger checkpoint message
 			{
 				verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class));
 				verify(vertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class));
 			}
 
+			OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId());
+			OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId());
+			TaskStateSnapshot taskOperatorSubtaskStates1 = mock(TaskStateSnapshot.class);
+			TaskStateSnapshot taskOperatorSubtaskStates2 = mock(TaskStateSnapshot.class);
+			OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState2 = mock(OperatorSubtaskState.class);
+			when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1);
+			when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2);
+
 			// acknowledge from one of the tasks
-			AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class));
+			AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates2);
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1);
-			OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
 			assertEquals(1, checkpoint.getNumberOfAcknowledgedTasks());
 			assertEquals(1, checkpoint.getNumberOfNonAcknowledgedTasks());
 			assertFalse(checkpoint.isDiscarded());
 			assertFalse(checkpoint.isFullyAcknowledged());
-			verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class));
+			verify(taskOperatorSubtaskStates2, never()).registerSharedStates(any(SharedStateRegistry.class));
 
 			// acknowledge the same task again (should not matter)
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1);
@@ -588,8 +583,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class));
 
 			// acknowledge the other task.
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates1));
 
 			// the checkpoint is internally converted to a successful checkpoint and the
 			// pending checkpoint object is disposed
@@ -628,9 +622,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew));
-			subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew));
-			subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -852,18 +844,20 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId());
 			OperatorID opID3 = OperatorID.fromJobVertexID(ackVertex3.getJobvertexId());
 
-			Map<OperatorID, OperatorState> operatorStates1 = pending1.getOperatorStates();
+			TaskStateSnapshot taskOperatorSubtaskStates1_1 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates1_2 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates1_3 = spy(new TaskStateSnapshot());
 
-			operatorStates1.put(opID1, new SpyInjectingOperatorState(
-				opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism()));
-			operatorStates1.put(opID2, new SpyInjectingOperatorState(
-				opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism()));
-			operatorStates1.put(opID3, new SpyInjectingOperatorState(
-				opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism()));
+			OperatorSubtaskState subtaskState1_1 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState1_2 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState1_3 = mock(OperatorSubtaskState.class);
+			taskOperatorSubtaskStates1_1.putSubtaskStateByOperatorID(opID1, subtaskState1_1);
+			taskOperatorSubtaskStates1_2.putSubtaskStateByOperatorID(opID2, subtaskState1_2);
+			taskOperatorSubtaskStates1_3.putSubtaskStateByOperatorID(opID3, subtaskState1_3);
 
 			// acknowledge one of the three tasks
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState1_2 = operatorStates1.get(opID2).getState(ackVertex2.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_2));
+
 			// start the second checkpoint
 			// trigger the first checkpoint. this should succeed
 			assertTrue(coord.triggerCheckpoint(timestamp2, false));
@@ -880,14 +874,17 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			}
 			long checkpointId2 = pending2.getCheckpointId();
 
-			Map<OperatorID, OperatorState> operatorStates2 = pending2.getOperatorStates();
+			TaskStateSnapshot taskOperatorSubtaskStates2_1 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates2_2 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates2_3 = spy(new TaskStateSnapshot());
+
+			OperatorSubtaskState subtaskState2_1 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState2_2 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState2_3 = mock(OperatorSubtaskState.class);
 
-			operatorStates2.put(opID1, new SpyInjectingOperatorState(
-				opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism()));
-			operatorStates2.put(opID2, new SpyInjectingOperatorState(
-				opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism()));
-			operatorStates2.put(opID3, new SpyInjectingOperatorState(
-				opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism()));
+			taskOperatorSubtaskStates2_1.putSubtaskStateByOperatorID(opID1, subtaskState2_1);
+			taskOperatorSubtaskStates2_2.putSubtaskStateByOperatorID(opID2, subtaskState2_2);
+			taskOperatorSubtaskStates2_3.putSubtaskStateByOperatorID(opID3, subtaskState2_3);
 
 			// trigger messages should have been sent
 			verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class));
@@ -896,17 +893,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			// we acknowledge one more task from the first checkpoint and the second
 			// checkpoint completely. The second checkpoint should then subsume the first checkpoint
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState2_3 = operatorStates2.get(opID3).getState(ackVertex3.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_3));
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState2_1 = operatorStates2.get(opID1).getState(ackVertex1.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_1));
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState1_1 = operatorStates1.get(opID1).getState(ackVertex1.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_1));
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState2_2 = operatorStates2.get(opID2).getState(ackVertex2.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_2));
 
 			// now, the second checkpoint should be confirmed, and the first discarded
 			// actually both pending checkpoints are discarded, and the second has been transformed
@@ -938,8 +931,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2));
 
 			// send the last remaining ack for the first checkpoint. This should not do anything
-			SubtaskState subtaskState1_3 = mock(SubtaskState.class);
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), subtaskState1_3));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_3));
 			verify(subtaskState1_3, times(1)).discardState();
 
 			coord.shutdown(JobStatus.FINISHED);
@@ -1005,13 +997,11 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId());
 
-			Map<OperatorID, OperatorState> operatorStates = checkpoint.getOperatorStates();
+			TaskStateSnapshot taskOperatorSubtaskStates1 = spy(new TaskStateSnapshot());
+			OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class);
+			taskOperatorSubtaskStates1.putSubtaskStateByOperatorID(opID1, subtaskState1);
 
-			operatorStates.put(opID1, new SpyInjectingOperatorState(
-				opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism()));
-
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState = operatorStates.get(opID1).getState(ackVertex1.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), taskOperatorSubtaskStates1));
 
 			// wait until the checkpoint must have expired.
 			// we check every 250 msecs conservatively for 5 seconds
@@ -1029,7 +1019,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// validate that the received states have been discarded
-			verify(subtaskState, times(1)).discardState();
+			verify(subtaskState1, times(1)).discardState();
 
 			// no confirm message must have been sent
 			verify(commitVertex.getCurrentExecutionAttempt(), times(0)).notifyCheckpointComplete(anyLong(), anyLong());
@@ -1147,26 +1137,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		long checkpointId = pendingCheckpoint.getCheckpointId();
 
 		OperatorID opIDtrigger = OperatorID.fromJobVertexID(triggerVertex.getJobvertexId());
-		OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId());
-		OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId());
-
-		Map<OperatorID, OperatorState> operatorStates = pendingCheckpoint.getOperatorStates();
 
-		operatorStates.put(opIDtrigger, new SpyInjectingOperatorState(
-			opIDtrigger, triggerVertex.getTotalNumberOfParallelSubtasks(), triggerVertex.getMaxParallelism()));
-		operatorStates.put(opID1, new SpyInjectingOperatorState(
-			opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism()));
-		operatorStates.put(opID2, new SpyInjectingOperatorState(
-			opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism()));
+		TaskStateSnapshot taskOperatorSubtaskStatesTrigger = spy(new TaskStateSnapshot());
+		OperatorSubtaskState subtaskStateTrigger = mock(OperatorSubtaskState.class);
+		taskOperatorSubtaskStatesTrigger.putSubtaskStateByOperatorID(opIDtrigger, subtaskStateTrigger);
 
 		// acknowledge the first trigger vertex
-		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)));
-		OperatorSubtaskState storedTriggerSubtaskState = operatorStates.get(opIDtrigger).getState(triggerVertex.getParallelSubtaskIndex());
+		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStatesTrigger));
 
 		// verify that the subtask state has not been discarded
-		verify(storedTriggerSubtaskState, never()).discardState();
+		verify(subtaskStateTrigger, never()).discardState();
 
-		SubtaskState unknownSubtaskState = mock(SubtaskState.class);
+		TaskStateSnapshot unknownSubtaskState = mock(TaskStateSnapshot.class);
 
 		// receive an acknowledge message for an unknown vertex
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState));
@@ -1174,7 +1156,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		// we should discard acknowledge messages from an unknown vertex belonging to our job
 		verify(unknownSubtaskState, times(1)).discardState();
 
-		SubtaskState differentJobSubtaskState = mock(SubtaskState.class);
+		TaskStateSnapshot differentJobSubtaskState = mock(TaskStateSnapshot.class);
 
 		// receive an acknowledge message from an unknown job
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState));
@@ -1183,22 +1165,22 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		verify(differentJobSubtaskState, never()).discardState();
 
 		// duplicate acknowledge message for the trigger vertex
-		SubtaskState triggerSubtaskState = mock(SubtaskState.class);
+		TaskStateSnapshot triggerSubtaskState = mock(TaskStateSnapshot.class);
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState));
 
 		// duplicate acknowledge messages for a known vertex should not trigger discarding the state
 		verify(triggerSubtaskState, never()).discardState();
 
 		// let the checkpoint fail at the first ack vertex
-		reset(storedTriggerSubtaskState);
+		reset(subtaskStateTrigger);
 		coord.receiveDeclineMessage(new DeclineCheckpoint(jobId, ackAttemptId1, checkpointId));
 
 		assertTrue(pendingCheckpoint.isDiscarded());
 
 		// check that we've cleaned up the already acknowledged state
-		verify(storedTriggerSubtaskState, times(1)).discardState();
+		verify(subtaskStateTrigger, times(1)).discardState();
 
-		SubtaskState ackSubtaskState = mock(SubtaskState.class);
+		TaskStateSnapshot ackSubtaskState = mock(TaskStateSnapshot.class);
 
 		// late acknowledge message from the second ack vertex
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new CheckpointMetrics(), ackSubtaskState));
@@ -1213,7 +1195,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		// we should not interfere with different jobs
 		verify(differentJobSubtaskState, never()).discardState();
 
-		SubtaskState unknownSubtaskState2 = mock(SubtaskState.class);
+		TaskStateSnapshot unknownSubtaskState2 = mock(TaskStateSnapshot.class);
 
 		// receive an acknowledge message for an unknown vertex
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState2));
@@ -1470,18 +1452,16 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId());
 		OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId());
-
-		Map<OperatorID, OperatorState> operatorStates = pending.getOperatorStates();
-
-		operatorStates.put(opID1, new SpyInjectingOperatorState(
-			opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism()));
-		operatorStates.put(opID2, new SpyInjectingOperatorState(
-			opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism()));
+		TaskStateSnapshot taskOperatorSubtaskStates1 = mock(TaskStateSnapshot.class);
+		TaskStateSnapshot taskOperatorSubtaskStates2 = mock(TaskStateSnapshot.class);
+		OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class);
+		OperatorSubtaskState subtaskState2 = mock(OperatorSubtaskState.class);
+		when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1);
+		when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2);
 
 		// acknowledge from one of the tasks
-		AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class));
+		AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates2);
 		coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2);
-		OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
 		assertEquals(1, pending.getNumberOfAcknowledgedTasks());
 		assertEquals(1, pending.getNumberOfNonAcknowledgedTasks());
 		assertFalse(pending.isDiscarded());
@@ -1495,8 +1475,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		assertFalse(savepointFuture.isDone());
 
 		// acknowledge the other task.
-		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)));
-		OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
+		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates1));
 
 		// the checkpoint is internally converted to a successful checkpoint and the
 		// pending checkpoint object is disposed
@@ -1536,9 +1515,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew));
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew));
 
-		subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
-		subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
-
 		assertEquals(0, coord.getNumberOfPendingCheckpoints());
 		assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
@@ -2037,20 +2013,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
 		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
-		PendingCheckpoint pending = coord.getPendingCheckpoints().get(checkpointId);
-
-		OperatorID opID1 = OperatorID.fromJobVertexID(jobVertexID1);
-		OperatorID opID2 = OperatorID.fromJobVertexID(jobVertexID2);
-
-		Map<OperatorID, OperatorState> operatorStates = pending.getOperatorStates();
-
-		operatorStates.put(opID1, new SpyInjectingOperatorState(
-			opID1, jobVertex1.getParallelism(), jobVertex1.getMaxParallelism()));
-		operatorStates.put(opID2, new SpyInjectingOperatorState(
-			opID2, jobVertex2.getParallelism(), jobVertex2.getMaxParallelism()));
-
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			SubtaskState subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index));
+			TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index));
 
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
@@ -2063,7 +2027,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		}
 
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
-			SubtaskState subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index));
+			TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index));
 
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
@@ -2165,30 +2129,34 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
+			StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index);
 			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
-			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+				taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
 
 
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
-			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID2, index);
+			StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID2, index);
 			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
-			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2284,17 +2252,20 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
+			StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index);
 			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(
 					jobVertexID1, keyGroupPartitions1.get(index), false);
 
-			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState);
+
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2302,17 +2273,19 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 
-			ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID2, index);
+			StreamStateHandle state = generateStateForVertex(jobVertexID2, index);
 			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(
 					jobVertexID2, keyGroupPartitions2.get(index), false);
 
-			SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(state, null, null, keyGroupState, null);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2438,18 +2411,21 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		//vertex 1
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
-			ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false);
+			StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index);
+			OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID1, index, 2, 8, false);
 			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
 			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true);
 
-			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState);
+
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2460,19 +2436,21 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
 			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true);
-			ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false);
-			ChainedStateHandle<OperatorStateHandle> opStateRaw = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, true);
-			expectedOpStatesBackend.add(opStateBackend);
-			expectedOpStatesRaw.add(opStateRaw);
-			SubtaskState checkpointStateHandles =
-					new SubtaskState(new ChainedStateHandle<>(
-							Collections.<StreamStateHandle>singletonList(null)), opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw);
+			OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, false);
+			OperatorStateHandle opStateRaw = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, true);
+			expectedOpStatesBackend.add(new ChainedStateHandle<>(Collections.singletonList(opStateBackend)));
+			expectedOpStatesRaw.add(new ChainedStateHandle<>(Collections.singletonList(opStateRaw)));
+
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState);
+
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2506,27 +2484,37 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		List<List<Collection<OperatorStateHandle>>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism());
 		List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism());
 		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
-			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false);
-			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true);
 
-			TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
+			List<OperatorID> operatorIDs = newJobVertex2.getOperatorIDs();
 
-			ChainedStateHandle<StreamStateHandle> operatorState = taskStateHandles.getLegacyOperatorState();
-			List<Collection<OperatorStateHandle>> opStateBackend = taskStateHandles.getManagedOperatorState();
-			List<Collection<OperatorStateHandle>> opStateRaw = taskStateHandles.getRawOperatorState();
-			Collection<KeyedStateHandle> keyedStateBackend = taskStateHandles.getManagedKeyedState();
-			Collection<KeyedStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState();
+			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false);
+			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true);
 
-			actualOpStatesBackend.add(opStateBackend);
-			actualOpStatesRaw.add(opStateRaw);
-			// the 'non partition state' is not null because it is recombined.
-			assertNotNull(operatorState);
-			for (int index = 0; index < operatorState.getLength(); index++) {
-				assertNull(operatorState.get(index));
+			TaskStateSnapshot taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
+
+			final int headOpIndex = operatorIDs.size() - 1;
+			List<Collection<OperatorStateHandle>> allParallelManagedOpStates = new ArrayList<>(operatorIDs.size());
+			List<Collection<OperatorStateHandle>> allParallelRawOpStates = new ArrayList<>(operatorIDs.size());
+
+			for (int idx = 0; idx < operatorIDs.size(); ++idx) {
+				OperatorID operatorID = operatorIDs.get(idx);
+				OperatorSubtaskState opState = taskStateHandles.getSubtaskStateByOperatorID(operatorID);
+				Assert.assertNull(opState.getLegacyOperatorState());
+				Collection<OperatorStateHandle> opStateBackend = opState.getManagedOperatorState();
+				Collection<OperatorStateHandle> opStateRaw = opState.getRawOperatorState();
+				allParallelManagedOpStates.add(opStateBackend);
+				allParallelRawOpStates.add(opStateRaw);
+				if (idx == headOpIndex) {
+					Collection<KeyedStateHandle> keyedStateBackend = opState.getManagedKeyedState();
+					Collection<KeyedStateHandle> keyGroupStateRaw = opState.getRawKeyedState();
+					compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
+					compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
+				}
 			}
-			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
-			compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
+			actualOpStatesBackend.add(allParallelManagedOpStates);
+			actualOpStatesRaw.add(allParallelRawOpStates);
 		}
+
 		comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend);
 		comparePartitionableState(expectedOpStatesRaw, actualOpStatesRaw);
 	}
@@ -2578,14 +2566,11 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			operatorStates.put(id.f1, taskState);
 			for (int index = 0; index < taskState.getParallelism(); index++) {
 				StreamStateHandle subNonPartitionedState = 
-					generateStateForVertex(id.f0, index)
-						.get(0);
+					generateStateForVertex(id.f0, index);
 				OperatorStateHandle subManagedOperatorState =
-					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, false)
-						.get(0);
+					generatePartitionableStateHandle(id.f0, index, 2, 8, false);
 				OperatorStateHandle subRawOperatorState =
-					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, true)
-						.get(0);
+					generatePartitionableStateHandle(id.f0, index, 2, 8, true);
 
 				OperatorSubtaskState subtaskState = new OperatorSubtaskState(subNonPartitionedState,
 					subManagedOperatorState,
@@ -2707,57 +2692,75 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		for (int i = 0; i < newJobVertex1.getParallelism(); i++) {
 
-			TaskStateHandles taskStateHandles = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
-			ChainedStateHandle<StreamStateHandle> actualSubNonPartitionedState = taskStateHandles.getLegacyOperatorState();
-			List<Collection<OperatorStateHandle>> actualSubManagedOperatorState = taskStateHandles.getManagedOperatorState();
-			List<Collection<OperatorStateHandle>> actualSubRawOperatorState = taskStateHandles.getRawOperatorState();
+			final List<OperatorID> operatorIds = newJobVertex1.getOperatorIDs();
 
-			assertNull(taskStateHandles.getManagedKeyedState());
-			assertNull(taskStateHandles.getRawKeyedState());
+			TaskStateSnapshot stateSnapshot = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
+
+			OperatorSubtaskState headOpState = stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1));
+			assertTrue(headOpState.getManagedKeyedState().isEmpty());
+			assertTrue(headOpState.getRawKeyedState().isEmpty());
 
 			// operator5
 			{
 				int operatorIndexInChain = 2;
-				assertNull(actualSubNonPartitionedState.get(operatorIndexInChain));
-				assertNull(actualSubManagedOperatorState.get(operatorIndexInChain));
-				assertNull(actualSubRawOperatorState.get(operatorIndexInChain));
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+				assertNull(opState.getLegacyOperatorState());
+				assertTrue(opState.getManagedOperatorState().isEmpty());
+				assertTrue(opState.getRawOperatorState().isEmpty());
 			}
 			// operator1
 			{
 				int operatorIndexInChain = 1;
-				ChainedStateHandle<StreamStateHandle> expectSubNonPartitionedState = generateStateForVertex(id1.f0, i);
-				ChainedStateHandle<OperatorStateHandle> expectedManagedOpState = generateChainedPartitionableStateHandle(
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+				StreamStateHandle expectSubNonPartitionedState = generateStateForVertex(id1.f0, i);
+				OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle(
 					id1.f0, i, 2, 8, false);
-				ChainedStateHandle<OperatorStateHandle> expectedRawOpState = generateChainedPartitionableStateHandle(
+				OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle(
 					id1.f0, i, 2, 8, true);
 
 				assertTrue(CommonTestUtils.isSteamContentEqual(
-					expectSubNonPartitionedState.get(0).openInputStream(),
-					actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream()));
-
-				assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(),
-					actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
-
-				assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(),
-					actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
+					expectSubNonPartitionedState.openInputStream(),
+					opState.getLegacyOperatorState().openInputStream()));
+
+				Collection<OperatorStateHandle> managedOperatorState = opState.getManagedOperatorState();
+				assertEquals(1, managedOperatorState.size());
+				assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(),
+					managedOperatorState.iterator().next().openInputStream()));
+
+				Collection<OperatorStateHandle> rawOperatorState = opState.getRawOperatorState();
+				assertEquals(1, rawOperatorState.size());
+				assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(),
+					rawOperatorState.iterator().next().openInputStream()));
 			}
 			// operator2
 			{
 				int operatorIndexInChain = 0;
-				ChainedStateHandle<StreamStateHandle> expectSubNonPartitionedState = generateStateForVertex(id2.f0, i);
-				ChainedStateHandle<OperatorStateHandle> expectedManagedOpState = generateChainedPartitionableStateHandle(
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+				StreamStateHandle expectSubNonPartitionedState = generateStateForVertex(id2.f0, i);
+				OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle(
 					id2.f0, i, 2, 8, false);
-				ChainedStateHandle<OperatorStateHandle> expectedRawOpState = generateChainedPartitionableStateHandle(
+				OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle(
 					id2.f0, i, 2, 8, true);
 
-				assertTrue(CommonTestUtils.isSteamContentEqual(expectSubNonPartitionedState.get(0).openInputStream(),
-					actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream()));
-
-				assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(),
-					actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
-
-				assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(),
-					actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
+				assertTrue(CommonTestUtils.isSteamContentEqual(
+					expectSubNonPartitionedState.openInputStream(),
+					opState.getLegacyOperatorState().openInputStream()));
+
+				Collection<OperatorStateHandle> managedOperatorState = opState.getManagedOperatorState();
+				assertEquals(1, managedOperatorState.size());
+				assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(),
+					managedOperatorState.iterator().next().openInputStream()));
+
+				Collection<OperatorStateHandle> rawOperatorState = opState.getRawOperatorState();
+				assertEquals(1, rawOperatorState.size());
+				assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(),
+					rawOperatorState.iterator().next().openInputStream()));
 			}
 		}
 
@@ -2765,38 +2768,48 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		List<List<Collection<OperatorStateHandle>>> actualRawOperatorStates = new ArrayList<>(newJobVertex2.getParallelism());
 
 		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
-			TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
+
+			final List<OperatorID> operatorIds = newJobVertex2.getOperatorIDs();
+
+			TaskStateSnapshot stateSnapshot = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
 
 			// operator 3
 			{
 				int operatorIndexInChain = 1;
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
 				List<Collection<OperatorStateHandle>> actualSubManagedOperatorState = new ArrayList<>(1);
-				actualSubManagedOperatorState.add(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain));
+				actualSubManagedOperatorState.add(opState.getManagedOperatorState());
 
 				List<Collection<OperatorStateHandle>> actualSubRawOperatorState = new ArrayList<>(1);
-				actualSubRawOperatorState.add(taskStateHandles.getRawOperatorState().get(operatorIndexInChain));
+				actualSubRawOperatorState.add(opState.getRawOperatorState());
 
 				actualManagedOperatorStates.add(actualSubManagedOperatorState);
 				actualRawOperatorStates.add(actualSubRawOperatorState);
 
-				assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain));
+				assertNull(opState.getLegacyOperatorState());
 			}
 
 			// operator 6
 			{
 				int operatorIndexInChain = 0;
-				assertNull(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain));
-				assertNull(taskStateHandles.getRawOperatorState().get(operatorIndexInChain));
-				assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain));
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+				assertNull(opState.getLegacyOperatorState());
+				assertTrue(opState.getManagedOperatorState().isEmpty());
+				assertTrue(opState.getRawOperatorState().isEmpty());
 
 			}
 
 			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), false);
 			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), true);
 
+			OperatorSubtaskState headOpState =
+				stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1));
 
-			Collection<KeyedStateHandle> keyedStateBackend = taskStateHandles.getManagedKeyedState();
-			Collection<KeyedStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState();
+			Collection<KeyedStateHandle> keyedStateBackend = headOpState.getManagedKeyedState();
+			Collection<KeyedStateHandle> keyGroupStateRaw = headOpState.getRawKeyedState();
 
 
 			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
@@ -2974,19 +2987,50 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		return new Tuple2<>(allSerializedValuesConcatenated, offsets);
 	}
 
-	public static ChainedStateHandle<StreamStateHandle> generateStateForVertex(
+	public static StreamStateHandle generateStateForVertex(
 			JobVertexID jobVertexID,
 			int index) throws IOException {
 
 		Random random = new Random(jobVertexID.hashCode() + index);
 		int value = random.nextInt();
-		return generateChainedStateHandle(value);
+		return generateStreamStateHandle(value);
+	}
+
+	public static StreamStateHandle generateStreamStateHandle(Serializable value) throws IOException {
+		return TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value);
 	}
 
 	public static ChainedStateHandle<StreamStateHandle> generateChainedStateHandle(
 			Serializable value) throws IOException {
 		return ChainedStateHandle.wrapSingleHandle(
-				TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value));
+				generateStreamStateHandle(value));
+	}
+
+	public static OperatorStateHandle generatePartitionableStateHandle(
+		JobVertexID jobVertexID,
+		int index,
+		int namedStates,
+		int partitionsPerState,
+		boolean rawState) throws IOException {
+
+		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
+
+		for (int i = 0; i < namedStates; ++i) {
+			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
+			// generate state
+			int seed = jobVertexID.hashCode() * index + i * namedStates;
+			if (rawState) {
+				seed = (seed + 1) * 31;
+			}
+			Random random = new Random(seed);
+			for (int j = 0; j < partitionsPerState; ++j) {
+				int simulatedStateValue = random.nextInt();
+				testStatesLists.add(simulatedStateValue);
+			}
+			statesListsMap.put("state-" + i, testStatesLists);
+		}
+
+		return generatePartitionableStateHandle(statesListsMap);
 	}
 
 	public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
@@ -3013,11 +3057,11 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			statesListsMap.put("state-" + i, testStatesLists);
 		}
 
-		return generateChainedPartitionableStateHandle(statesListsMap);
+		return ChainedStateHandle.wrapSingleHandle(generatePartitionableStateHandle(statesListsMap));
 	}
 
-	private static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
-			Map<String, List<? extends Serializable>> states) throws IOException {
+	private static OperatorStateHandle generatePartitionableStateHandle(
+		Map<String, List<? extends Serializable>> states) throws IOException {
 
 		List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size());
 
@@ -3032,20 +3076,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		int idx = 0;
 		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
 			offsetsMap.put(
-					entry.getKey(),
-					new OperatorStateHandle.StateMetaInfo(
-							serializationWithOffsets.f1.get(idx),
-							OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+				entry.getKey(),
+				new OperatorStateHandle.StateMetaInfo(
+					serializationWithOffsets.f1.get(idx),
+					OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
 			++idx;
 		}
 
 		ByteStreamStateHandle streamStateHandle = new TestByteStreamStateHandleDeepCompare(
-				String.valueOf(UUID.randomUUID()),
-				serializationWithOffsets.f0);
+			String.valueOf(UUID.randomUUID()),
+			serializationWithOffsets.f0);
 
-		OperatorStateHandle operatorStateHandle =
-				new OperatorStateHandle(offsetsMap, streamStateHandle);
-		return ChainedStateHandle.wrapSingleHandle(operatorStateHandle);
+		return new OperatorStateHandle(offsetsMap, streamStateHandle);
 	}
 
 	static ExecutionJobVertex mockExecutionJobVertex(
@@ -3139,24 +3181,23 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		return vertex;
 	}
 
-	static SubtaskState mockSubtaskState(
+	static TaskStateSnapshot mockSubtaskState(
 		JobVertexID jobVertexID,
 		int index,
 		KeyGroupRange keyGroupRange) throws IOException {
 
-		ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID, index);
-		ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID, index, 2, 8, false);
+		StreamStateHandle nonPartitionedState = generateStateForVertex(jobVertexID, index);
+		OperatorStateHandle partitionableState = generatePartitionableStateHandle(jobVertexID, index, 2, 8, false);
 		KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false);
 
-		SubtaskState subtaskState = mock(SubtaskState.class, withSettings().serializable());
+		TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot());
+		OperatorSubtaskState subtaskState = spy(new OperatorSubtaskState(
+			nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null)
+		);
 
-		doReturn(nonPartitionedState).when(subtaskState).getLegacyOperatorState();
-		doReturn(partitionableState).when(subtaskState).getManagedOperatorState();
-		doReturn(null).when(subtaskState).getRawOperatorState();
-		doReturn(partitionedKeyGroupState).when(subtaskState).getManagedKeyedState();
-		doReturn(null).when(subtaskState).getRawKeyedState();
+		subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), subtaskState);
 
-		return subtaskState;
+		return subtaskStates;
 	}
 
 	public static void verifyStateRestore(
@@ -3165,27 +3206,27 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
 
-			TaskStateHandles taskStateHandles = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
+			final List<OperatorID> operatorIds = executionJobVertex.getOperatorIDs();
 
-			ChainedStateHandle<StreamStateHandle> expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
-			ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = taskStateHandles.getLegacyOperatorState();
+			TaskStateSnapshot stateSnapshot = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
+
+			OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID));
+
+			StreamStateHandle expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
 			assertTrue(CommonTestUtils.isSteamContentEqual(
-					expectNonPartitionedState.get(0).openInputStream(),
-					actualNonPartitionedState.get(0).openInputStream()));
+					expectNonPartitionedState.openInputStream(),
+				operatorState.getLegacyOperatorState().openInputStream()));
 
 			ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend =
 					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);
 
-			List<Collection<OperatorStateHandle>> actualPartitionableState = taskStateHandles.getManagedOperatorState();
-
 			assertTrue(CommonTestUtils.isSteamContentEqual(
 					expectedOpStateBackend.get(0).openInputStream(),
-					actualPartitionableState.get(0).iterator().next().openInputStream()));
+					operatorState.getManagedOperatorState().iterator().next().openInputStream()));
 
 			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
 					jobVertexID, keyGroupPartitions.get(i), false);
-			Collection<KeyedStateHandle> actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
-			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), actualPartitionedKeyGroupState);
+			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), operatorState.getManagedKeyedState());
 		}
 	}
 
@@ -3632,17 +3673,4 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			"The latest completed (proper) checkpoint should have been added to the completed checkpoint store.",
 			completedCheckpointStore.getLatestCheckpoint().getCheckpointID() == checkpointIDCounter.getLast());
 	}
-
-	private static final class SpyInjectingOperatorState extends OperatorState {
-
-		private static final long serialVersionUID = -4004437428483663815L;
-
-		public SpyInjectingOperatorState(OperatorID taskID, int parallelism, int maxParallelism) {
-			super(taskID, parallelism, maxParallelism);
-		}
-
-		public void putState(int subtaskIndex, OperatorSubtaskState subtaskState) {
-			super.putState(subtaskIndex, spy(subtaskState));
-		}
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 7d24568..6ce071b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -34,18 +34,18 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.util.SerializableObject;
+
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
 import org.junit.Test;
 import org.mockito.Mockito;
 
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
@@ -118,10 +118,20 @@ public class CheckpointStateRestoreTest {
 			PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next();
 			final long checkpointId = pending.getCheckpointId();
 
-			SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null);
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
+			final TaskStateSnapshot subtaskStates = new TaskStateSnapshot();
+
+			subtaskStates.putSubtaskStateByOperatorID(
+				OperatorID.fromJobVertexID(statefulId),
+				new OperatorSubtaskState(
+					serializedState.get(0),
+					Collections.<OperatorStateHandle>emptyList(),
+					Collections.<OperatorStateHandle>emptyList(),
+					Collections.singletonList(serializedKeyGroupStates),
+					Collections.<KeyedStateHandle>emptyList()));
+
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));
 
@@ -133,33 +143,26 @@ public class CheckpointStateRestoreTest {
 
 			// verify that each stateful vertex got the state
 
-			final TaskStateHandles taskStateHandles = new TaskStateHandles(
-					serializedState,
-					Collections.<Collection<OperatorStateHandle>>singletonList(null),
-					Collections.<Collection<OperatorStateHandle>>singletonList(null),
-					Collections.singletonList(serializedKeyGroupStates),
-					null);
-
-			BaseMatcher<TaskStateHandles> matcher = new BaseMatcher<TaskStateHandles>() {
+			BaseMatcher<TaskStateSnapshot> matcher = new BaseMatcher<TaskStateSnapshot>() {
 				@Override
 				public boolean matches(Object o) {
-					if (o instanceof TaskStateHandles) {
-						return o.equals(taskStateHandles);
+					if (o instanceof TaskStateSnapshot) {
+						return Objects.equals(o, subtaskStates);
 					}
 					return false;
 				}
 
 				@Override
 				public void describeTo(Description description) {
-					description.appendValue(taskStateHandles);
+					description.appendValue(subtaskStates);
 				}
 			};
 
 			verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher));
 			verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher));
 			verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher));
-			verify(statelessExec1, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
-			verify(statelessExec2, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
+			verify(statelessExec1, times(0)).setInitialState(Mockito.<TaskStateSnapshot>any());
+			verify(statelessExec2, times(0)).setInitialState(Mockito.<TaskStateSnapshot>any());
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -250,9 +253,9 @@ public class CheckpointStateRestoreTest {
 		Map<OperatorID, OperatorState> checkpointTaskStates = new HashMap<>();
 		{
 			OperatorState taskState = new OperatorState(operatorId1, 3, 3);
-			taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null));
-			taskState.putState(1, new OperatorSubtaskState(serializedState, null, null, null, null));
-			taskState.putState(2, new OperatorSubtaskState(serializedState, null, null, null, null));
+			taskState.putState(0, new OperatorSubtaskState(serializedState));
+			taskState.putState(1, new OperatorSubtaskState(serializedState));
+			taskState.putState(2, new OperatorSubtaskState(serializedState));
 
 			checkpointTaskStates.put(operatorId1, taskState);
 		}
@@ -279,7 +282,7 @@ public class CheckpointStateRestoreTest {
 		// There is no task for this
 		{
 			OperatorState taskState = new OperatorState(newOperatorID, 1, 1);
-			taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null));
+			taskState.putState(0, new OperatorSubtaskState(serializedState));
 
 			checkpointTaskStates.put(newOperatorID, taskState);
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index 1fe4e65..320dc2d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -331,7 +331,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		boolean discarded;
 
 		public TestOperatorSubtaskState() {
-			super(null, null, null, null, null);
+			super();
 			this.registered = false;
 			this.discarded = false;
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
index 7d103d0..7ebb49a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
@@ -324,7 +324,7 @@ public class PendingCheckpointTest {
 	@Test
 	public void testNonNullSubtaskStateLeadsToStatefulTask() throws Exception {
 		PendingCheckpoint pending = createPendingCheckpoint(CheckpointProperties.forStandardCheckpoint(), null);
-		pending.acknowledgeTask(ATTEMPT_ID, mock(SubtaskState.class), mock(CheckpointMetrics.class));
+		pending.acknowledgeTask(ATTEMPT_ID, mock(TaskStateSnapshot.class), mock(CheckpointMetrics.class));
 		Assert.assertFalse(pending.getOperatorStates().isEmpty());
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
index 36c9cad..9ed4851 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.blob.BlobKey;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.JobInformation;
@@ -30,7 +31,6 @@ import org.apache.flink.runtime.executiongraph.TaskInformation;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.operators.BatchTask;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.SerializedValue;
 
 import org.junit.Test;
@@ -73,7 +73,7 @@ public class TaskDeploymentDescriptorTest {
 			final SerializedValue<TaskInformation> serializedJobVertexInformation = new SerializedValue<>(new TaskInformation(
 				vertexID, taskName, currentNumberOfSubtasks, numberOfKeyGroups, invokableClass.getName(), taskConfiguration));
 			final int targetSlotNumber = 47;
-			final TaskStateHandles taskStateHandles = new TaskStateHandles();
+			final TaskStateSnapshot taskStateHandles = new TaskStateSnapshot();
 
 			final TaskDeploymentDescriptor orig = new TaskDeploymentDescriptor(
 				serializedJobInformation,

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
index 0eed90d..c9b7a40 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.time.Time;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
 import org.apache.flink.runtime.checkpoint.StandaloneCheckpointRecoveryFactory;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
 import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
@@ -38,7 +39,6 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobmanager.slots.AllocatedSlot;
 import org.apache.flink.runtime.jobmanager.slots.SlotOwner;
 import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
@@ -51,8 +51,10 @@ import java.net.InetAddress;
 import java.util.Iterator;
 import java.util.concurrent.TimeUnit;
 
-import static org.mockito.Mockito.*;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
 
 /**
  * Tests that the execution vertex handles locality preferences well.
@@ -169,7 +171,7 @@ public class ExecutionVertexLocalityTest extends TestLogger {
 
 			// target state
 			ExecutionVertex target = graph.getAllVertices().get(targetVertexId).getTaskVertices()[i];
-			target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateHandles.class));
+			target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateSnapshot.class));
 		}
 
 		// validate that the target vertices have the state's location as the location preference

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index a63b02d..23f0a38 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -18,16 +18,6 @@
 
 package org.apache.flink.runtime.jobmanager;
 
-import akka.actor.ActorRef;
-import akka.actor.ActorSystem;
-import akka.actor.Identify;
-import akka.actor.PoisonPill;
-import akka.actor.Props;
-import akka.japi.pf.FI;
-import akka.japi.pf.ReceiveBuilder;
-import akka.pattern.Patterns;
-import akka.testkit.CallingThreadDispatcher;
-import akka.testkit.JavaTestKit;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
@@ -44,8 +34,9 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
 import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager;
 import org.apache.flink.runtime.executiongraph.restart.FixedDelayRestartStrategy;
@@ -59,6 +50,7 @@ import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings;
@@ -69,9 +61,6 @@ import org.apache.flink.runtime.leaderelection.TestingLeaderElectionService;
 import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService;
 import org.apache.flink.runtime.messages.JobManagerMessages;
 import org.apache.flink.runtime.metrics.MetricRegistry;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManager;
 import org.apache.flink.runtime.testingUtils.TestingJobManager;
@@ -83,23 +72,24 @@ import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
 import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.util.InstantiationUtil;
-
 import org.apache.flink.util.TestLogger;
+
+import akka.actor.ActorRef;
+import akka.actor.ActorSystem;
+import akka.actor.Identify;
+import akka.actor.PoisonPill;
+import akka.actor.Props;
+import akka.japi.pf.FI;
+import akka.japi.pf.ReceiveBuilder;
+import akka.pattern.Patterns;
+import akka.testkit.CallingThreadDispatcher;
+import akka.testkit.JavaTestKit;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
-import scala.Int;
-import scala.Option;
-import scala.PartialFunction;
-import scala.concurrent.Await;
-import scala.concurrent.Future;
-import scala.concurrent.duration.Deadline;
-import scala.concurrent.duration.FiniteDuration;
-import scala.runtime.BoxedUnit;
-
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -113,6 +103,15 @@ import java.util.concurrent.Executor;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 
+import scala.Int;
+import scala.Option;
+import scala.PartialFunction;
+import scala.concurrent.Await;
+import scala.concurrent.Future;
+import scala.concurrent.duration.Deadline;
+import scala.concurrent.duration.FiniteDuration;
+import scala.runtime.BoxedUnit;
+
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThat;
@@ -552,10 +551,10 @@ public class JobManagerHARecoveryTest extends TestLogger {
 
 		@Override
 		public void setInitialState(
-				TaskStateHandles taskStateHandles) throws Exception {
+			TaskStateSnapshot taskStateHandles) throws Exception {
 			int subtaskIndex = getIndexInSubtaskGroup();
 			if (subtaskIndex < recoveredStates.length) {
-				try (FSDataInputStream in = taskStateHandles.getLegacyOperatorState().get(0).openInputStream()) {
+				try (FSDataInputStream in = taskStateHandles.getSubtaskStateMappings().iterator().next().getValue().getLegacyOperatorState().openInputStream()) {
 					recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in, getUserCodeClassLoader());
 				}
 			}
@@ -567,10 +566,11 @@ public class JobManagerHARecoveryTest extends TestLogger {
 					String.valueOf(UUID.randomUUID()),
 					InstantiationUtil.serializeObject(checkpointMetaData.getCheckpointId()));
 
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle =
-					new ChainedStateHandle<StreamStateHandle>(Collections.singletonList(byteStreamStateHandle));
-			SubtaskState checkpointStateHandles =
-					new SubtaskState(chainedStateHandle, null, null, null, null);
+			TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot();
+			checkpointStateHandles.putSubtaskStateByOperatorID(
+				OperatorID.fromJobVertexID(getEnvironment().getJobVertexId()),
+				new OperatorSubtaskState(byteStreamStateHandle)
+			);
 
 			getEnvironment().acknowledgeCheckpoint(
 					checkpointMetaData.getCheckpointId(),

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index bc420cc..d022cdc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -24,14 +24,17 @@ import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.StreamStateHandle;
+
 import org.junit.Test;
 
 import java.io.IOException;
@@ -68,13 +71,17 @@ public class CheckpointMessagesTest {
 
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42);
 
-			SubtaskState checkpointStateHandles =
-					new SubtaskState(
-							CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()),
-							CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
-							null,
-							CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
-							null);
+			TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot();
+			checkpointStateHandles.putSubtaskStateByOperatorID(
+				new OperatorID(),
+				new OperatorSubtaskState(
+					CheckpointCoordinatorTest.generateStreamStateHandle(new MyHandle()),
+					CheckpointCoordinatorTest.generatePartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
+					null,
+					CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
+					null
+				)
+			);
 
 			AcknowledgeCheckpoint withState = new AcknowledgeCheckpoint(
 					new JobID(),

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index 851fa96..8ed06b2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -26,7 +26,7 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -156,7 +156,7 @@ public class DummyEnvironment implements Environment {
 	}
 
 	@Override
-	public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) {
+	public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) {
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index 4f0242e..7514cc4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -27,7 +27,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -50,8 +50,8 @@ import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 import org.apache.flink.types.Record;
 import org.apache.flink.util.MutableObjectIterator;
-
 import org.apache.flink.util.Preconditions;
+
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
@@ -354,7 +354,7 @@ public class MockEnvironment implements Environment {
 	}
 
 	@Override
-	public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) {
+	public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) {
 		throw new UnsupportedOperationException();
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index c6d2fec..085a386 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -49,7 +50,6 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 import org.apache.flink.util.SerializedValue;
 
@@ -187,7 +187,7 @@ public class TaskAsyncCallTest {
 			Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 			Collections.<InputGateDeploymentDescriptor>emptyList(),
 			0,
-			new TaskStateHandles(),
+			new TaskStateSnapshot(),
 			mock(MemoryManager.class),
 			mock(IOManager.class),
 			networkEnvironment,
@@ -228,7 +228,7 @@ public class TaskAsyncCallTest {
 		}
 
 		@Override
-		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {}
+		public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception {}
 
 		@Override
 		public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) {


[7/7] flink git commit: [FLINK-7213] Introduce state management by OperatorID in TaskManager

Posted by sr...@apache.org.
[FLINK-7213] Introduce state management by OperatorID in TaskManager


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/b71154a7
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/b71154a7
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/b71154a7

Branch: refs/heads/master
Commit: b71154a734ea9f4489dffe1be6761efbb90cff41
Parents: 3b0321a
Author: Stefan Richter <s....@data-artisans.com>
Authored: Mon Jun 26 18:07:59 2017 +0200
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Tue Aug 15 14:56:54 2017 +0200

----------------------------------------------------------------------
 .../state/RocksDBAsyncSnapshotTest.java         |  21 +-
 .../checkpoint/CheckpointCoordinator.java       |   7 +-
 .../CheckpointCoordinatorGateway.java           |   2 +-
 .../flink/runtime/checkpoint/OperatorState.java |   4 +-
 .../checkpoint/OperatorSubtaskState.java        | 224 ++++++---
 .../runtime/checkpoint/PendingCheckpoint.java   |  54 +-
 .../RoundRobinOperatorStateRepartitioner.java   |   4 +
 .../checkpoint/StateAssignmentOperation.java    | 177 ++++---
 .../runtime/checkpoint/TaskStateSnapshot.java   | 139 ++++++
 .../savepoint/SavepointV2Serializer.java        |  20 +-
 .../deployment/TaskDeploymentDescriptor.java    |   8 +-
 .../flink/runtime/execution/Environment.java    |   9 +-
 .../flink/runtime/executiongraph/Execution.java |   8 +-
 .../runtime/executiongraph/ExecutionVertex.java |   8 +-
 .../runtime/jobgraph/tasks/StatefulTask.java    |   8 +-
 .../flink/runtime/jobmaster/JobMaster.java      |   5 +-
 .../checkpoint/AcknowledgeCheckpoint.java       |   8 +-
 .../state/StateInitializationContextImpl.java   |  11 +-
 .../flink/runtime/state/TaskStateHandles.java   | 172 -------
 .../rpc/RpcCheckpointResponder.java             |   4 +-
 .../ActorGatewayCheckpointResponder.java        |   4 +-
 .../taskmanager/CheckpointResponder.java        |   4 +-
 .../runtime/taskmanager/RuntimeEnvironment.java |   4 +-
 .../apache/flink/runtime/taskmanager/Task.java  |   8 +-
 .../CheckpointCoordinatorFailureTest.java       |  49 +-
 .../checkpoint/CheckpointCoordinatorTest.java   | 498 ++++++++++---------
 .../checkpoint/CheckpointStateRestoreTest.java  |  49 +-
 .../CompletedCheckpointStoreTest.java           |   2 +-
 .../checkpoint/PendingCheckpointTest.java       |   2 +-
 .../TaskDeploymentDescriptorTest.java           |   4 +-
 .../ExecutionVertexLocalityTest.java            |  10 +-
 .../jobmanager/JobManagerHARecoveryTest.java    |  60 +--
 .../messages/CheckpointMessagesTest.java        |  23 +-
 .../operators/testutils/DummyEnvironment.java   |   4 +-
 .../operators/testutils/MockEnvironment.java    |   6 +-
 .../runtime/taskmanager/TaskAsyncCallTest.java  |   6 +-
 .../flink/runtime/taskmanager/TaskStopTest.java |  26 +-
 .../runtime/util/JvmExitOnFatalErrorTest.java   |   7 +-
 .../flink/streaming/api/graph/StreamConfig.java |  13 +-
 .../api/graph/StreamingJobGraphGenerator.java   |  11 +-
 .../api/operators/AbstractStreamOperator.java   |  19 +-
 .../streaming/api/operators/StreamOperator.java |   6 +-
 .../runtime/tasks/OperatorStateHandles.java     |  19 -
 .../streaming/runtime/tasks/StreamTask.java     | 196 +++-----
 .../AbstractUdfStreamOperatorLifecycleTest.java |   5 +-
 .../operators/async/AsyncWaitOperatorTest.java  |  16 +-
 .../streaming/runtime/io/BarrierBufferTest.java |   4 +-
 .../runtime/io/BarrierTrackerTest.java          |   4 +-
 .../runtime/operators/StreamTaskTimerTest.java  |   2 +
 .../TestProcessingTimeServiceTest.java          |   2 +
 .../runtime/tasks/BlockingCheckpointsTest.java  |   2 +
 .../tasks/InterruptSensitiveRestoreTest.java    |  55 +-
 .../runtime/tasks/OneInputStreamTaskTest.java   |  34 +-
 .../SourceExternalCheckpointTriggerTest.java    |   2 +
 .../runtime/tasks/SourceStreamTaskTest.java     |   3 +
 .../runtime/tasks/StreamMockEnvironment.java    |   4 +-
 .../StreamTaskCancellationBarrierTest.java      |   3 +
 .../tasks/StreamTaskTerminationTest.java        |   2 +
 .../streaming/runtime/tasks/StreamTaskTest.java |  79 ++-
 .../runtime/tasks/StreamTaskTestHarness.java    |   2 +
 .../runtime/tasks/TwoInputStreamTaskTest.java   |   5 +
 .../util/AbstractStreamOperatorTestHarness.java |  22 +-
 .../test/checkpointing/SavepointITCase.java     |   2 +-
 63 files changed, 1185 insertions(+), 986 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
index d2edf0e..c752e53 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
@@ -32,8 +32,10 @@ import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
@@ -74,6 +76,7 @@ import java.io.IOException;
 import java.lang.reflect.Field;
 import java.net.URI;
 import java.util.Arrays;
+import java.util.Map;
 import java.util.UUID;
 import java.util.concurrent.CancellationException;
 import java.util.concurrent.ExecutionException;
@@ -81,7 +84,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.TimeUnit;
 
-import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
@@ -137,6 +140,7 @@ public class RocksDBAsyncSnapshotTest {
 		streamConfig.setStateBackend(backend);
 
 		streamConfig.setStreamOperator(new AsyncCheckpointOperator());
+		streamConfig.setOperatorID(new OperatorID());
 
 		final OneShotLatch delayCheckpointLatch = new OneShotLatch();
 		final OneShotLatch ensureCheckpointLatch = new OneShotLatch();
@@ -152,7 +156,7 @@ public class RocksDBAsyncSnapshotTest {
 			public void acknowledgeCheckpoint(
 					long checkpointId,
 					CheckpointMetrics checkpointMetrics,
-					SubtaskState checkpointStateHandles) {
+					TaskStateSnapshot checkpointStateHandles) {
 
 				super.acknowledgeCheckpoint(checkpointId, checkpointMetrics);
 
@@ -164,8 +168,16 @@ public class RocksDBAsyncSnapshotTest {
 					throw new RuntimeException(e);
 				}
 
+				boolean hasManagedKeyedState = false;
+				for (Map.Entry<OperatorID, OperatorSubtaskState> entry : checkpointStateHandles.getSubtaskStateMappings()) {
+					OperatorSubtaskState state = entry.getValue();
+					if (state != null) {
+						hasManagedKeyedState |= state.getManagedKeyedState() != null;
+					}
+				}
+
 				// should be one k/v state
-				assertNotNull(checkpointStateHandles.getManagedKeyedState());
+				assertTrue(hasManagedKeyedState);
 
 				// we now know that the checkpoint went through
 				ensureCheckpointLatch.trigger();
@@ -241,6 +253,7 @@ public class RocksDBAsyncSnapshotTest {
 		streamConfig.setStateBackend(backend);
 
 		streamConfig.setStreamOperator(new AsyncCheckpointOperator());
+		streamConfig.setOperatorID(new OperatorID());
 
 		StreamMockEnvironment mockEnv = new StreamMockEnvironment(
 				testHarness.jobConfig,

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 6f41867..0b64a73 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -40,7 +40,6 @@ import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.state.SharedStateRegistry;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.StringUtils;
@@ -1016,7 +1015,7 @@ public class CheckpointCoordinator {
 	 * Restores the latest checkpointed state.
 	 *
 	 * @param tasks Map of job vertices to restore. State for these vertices is
-	 * restored via {@link Execution#setInitialState(TaskStateHandles)}.
+	 * restored via {@link Execution#setInitialState(TaskStateSnapshot)}.
 	 * @param errorIfNoCheckpoint Fail if no completed checkpoint is available to
 	 * restore from.
 	 * @param allowNonRestoredState Allow checkpoint state that cannot be mapped
@@ -1102,7 +1101,7 @@ public class CheckpointCoordinator {
 	 *                         mapped to any job vertex in tasks.
 	 * @param tasks            Map of job vertices to restore. State for these 
 	 *                         vertices is restored via 
-	 *                         {@link Execution#setInitialState(TaskStateHandles)}.
+	 *                         {@link Execution#setInitialState(TaskStateSnapshot)}.
 	 * @param userClassLoader  The class loader to resolve serialized classes in 
 	 *                         legacy savepoint versions. 
 	 */
@@ -1256,7 +1255,7 @@ public class CheckpointCoordinator {
 			final JobID jobId,
 			final ExecutionAttemptID executionAttemptID,
 			final long checkpointId,
-			final SubtaskState subtaskState) {
+			final TaskStateSnapshot subtaskState) {
 
 		if (subtaskState != null) {
 			executor.execute(new Runnable() {

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java
index 43d66ee..22244f6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorGateway.java
@@ -29,7 +29,7 @@ public interface CheckpointCoordinatorGateway extends RpcGateway {
 			final ExecutionAttemptID executionAttemptID,
 			final long checkpointId,
 			final CheckpointMetrics checkpointMetrics,
-			final SubtaskState subtaskState);
+			final TaskStateSnapshot subtaskState);
 
 	void declineCheckpoint(
 			JobID jobID,

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java
index b153028..145ff6a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java
@@ -30,8 +30,8 @@ import java.util.Map;
 import java.util.Objects;
 
 /**
- * Simple container class which contains the raw/managed/legacy operator state and key-group state handles for the sub
- * tasks of an operator.
+ * Simple container class which contains the raw/managed/legacy operator state and key-group state handles from all sub
+ * tasks of an operator and therefore represents the complete state of a logical operator.
  */
 public class OperatorState implements CompositeStateHandle {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
index e2ae632..296b5ab 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.checkpoint;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.state.CompositeStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
@@ -25,13 +26,35 @@ import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.util.Preconditions;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.Arrays;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
 
 /**
- * Container for the state of one parallel subtask of an operator. This is part of the {@link OperatorState}.
+ * This class encapsulates the state for one parallel instance of an operator. The complete state of a (logical)
+ * operator (e.g. a flatmap operator) consists of the union of all {@link OperatorSubtaskState}s from all
+ * parallel tasks that each execute one physical instance of the operator.
+ *
+ * <p>The full state of the logical operator is represented by {@link OperatorState} which consists of
+ * {@link OperatorSubtaskState}s.
+ *
+ * <p>Typically, we expect all collections in this class to be of size 0 or 1, because there is up to one state handle
+ * produced per state type (e.g. managed-keyed, raw-operator, ...). In particular, this holds when taking a snapshot.
+ * The purpose of having the state handles in collections is that this class is also reused in restoring state.
+ * Under normal circumstances, the expected size of each collection is still 0 or 1, except for scale-down. In
+ * scale-down, one operator subtask can become responsible for the state of multiple previous subtasks. The collections
+ * can then store all the state handles that are relevant to build up the new subtask state.
+ *
+ * <p>There is no collection for legacy state because it is not rescalable.
  */
 public class OperatorSubtaskState implements CompositeStateHandle {
 
@@ -46,27 +69,32 @@ public class OperatorSubtaskState implements CompositeStateHandle {
 	 * Can be removed when we remove the APIs for non-repartitionable operator state.
 	 */
 	@Deprecated
+	@Nullable
 	private final StreamStateHandle legacyOperatorState;
 
 	/**
 	 * Snapshot from the {@link org.apache.flink.runtime.state.OperatorStateBackend}.
 	 */
-	private final OperatorStateHandle managedOperatorState;
+	@Nonnull
+	private final Collection<OperatorStateHandle> managedOperatorState;
 
 	/**
 	 * Snapshot written using {@link org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream}.
 	 */
-	private final OperatorStateHandle rawOperatorState;
+	@Nonnull
+	private final Collection<OperatorStateHandle> rawOperatorState;
 
 	/**
 	 * Snapshot from {@link org.apache.flink.runtime.state.KeyedStateBackend}.
 	 */
-	private final KeyedStateHandle managedKeyedState;
+	@Nonnull
+	private final Collection<KeyedStateHandle> managedKeyedState;
 
 	/**
 	 * Snapshot written using {@link org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream}.
 	 */
-	private final KeyedStateHandle rawKeyedState;
+	@Nonnull
+	private final Collection<KeyedStateHandle> rawKeyedState;
 
 	/**
 	 * The state size. This is also part of the deserialized state handle.
@@ -75,31 +103,79 @@ public class OperatorSubtaskState implements CompositeStateHandle {
 	 */
 	private final long stateSize;
 
+	@VisibleForTesting
+	public OperatorSubtaskState(StreamStateHandle legacyOperatorState) {
+
+		this(legacyOperatorState,
+			Collections.<OperatorStateHandle>emptyList(),
+			Collections.<OperatorStateHandle>emptyList(),
+			Collections.<KeyedStateHandle>emptyList(),
+			Collections.<KeyedStateHandle>emptyList());
+	}
+
+	/**
+	 * Creates an empty state: no legacy state and empty collections for all other state types.
+	 */
+	public OperatorSubtaskState() {
+		this(null);
+	}
+
 	public OperatorSubtaskState(
 		StreamStateHandle legacyOperatorState,
-		OperatorStateHandle managedOperatorState,
-		OperatorStateHandle rawOperatorState,
-		KeyedStateHandle managedKeyedState,
-		KeyedStateHandle rawKeyedState) {
+		Collection<OperatorStateHandle> managedOperatorState,
+		Collection<OperatorStateHandle> rawOperatorState,
+		Collection<KeyedStateHandle> managedKeyedState,
+		Collection<KeyedStateHandle> rawKeyedState) {
 
 		this.legacyOperatorState = legacyOperatorState;
-		this.managedOperatorState = managedOperatorState;
-		this.rawOperatorState = rawOperatorState;
-		this.managedKeyedState = managedKeyedState;
-		this.rawKeyedState = rawKeyedState;
+		this.managedOperatorState = Preconditions.checkNotNull(managedOperatorState);
+		this.rawOperatorState = Preconditions.checkNotNull(rawOperatorState);
+		this.managedKeyedState = Preconditions.checkNotNull(managedKeyedState);
+		this.rawKeyedState = Preconditions.checkNotNull(rawKeyedState);
 
 		try {
 			long calculateStateSize = getSizeNullSafe(legacyOperatorState);
-			calculateStateSize += getSizeNullSafe(managedOperatorState);
-			calculateStateSize += getSizeNullSafe(rawOperatorState);
-			calculateStateSize += getSizeNullSafe(managedKeyedState);
-			calculateStateSize += getSizeNullSafe(rawKeyedState);
+			calculateStateSize += sumAllSizes(managedOperatorState);
+			calculateStateSize += sumAllSizes(rawOperatorState);
+			calculateStateSize += sumAllSizes(managedKeyedState);
+			calculateStateSize += sumAllSizes(rawKeyedState);
 			stateSize = calculateStateSize;
 		} catch (Exception e) {
 			throw new RuntimeException("Failed to get state size.", e);
 		}
 	}
 
+	/**
+	 * Convenience constructor for the typical case of at most one state handle per state type. Null handles are
+	 * translated into empty collections; the legacy state is passed through unchanged and may be null.
+	 */
+	public OperatorSubtaskState(
+		StreamStateHandle legacyOperatorState,
+		OperatorStateHandle managedOperatorState,
+		OperatorStateHandle rawOperatorState,
+		KeyedStateHandle managedKeyedState,
+		KeyedStateHandle rawKeyedState) {
+
+		this(legacyOperatorState,
+			singletonOrEmptyOnNull(managedOperatorState),
+			singletonOrEmptyOnNull(rawOperatorState),
+			singletonOrEmptyOnNull(managedKeyedState),
+			singletonOrEmptyOnNull(rawKeyedState));
+	}
+
+	private static <T> Collection<T> singletonOrEmptyOnNull(T element) {
+		return element != null ? Collections.singletonList(element) : Collections.<T>emptyList();
+	}
+
+	private static long sumAllSizes(Collection<? extends StateObject> stateObject) throws Exception {
+		long size = 0L;
+		for (StateObject object : stateObject) {
+			size += getSizeNullSafe(object);
+		}
+
+		return size;
+	}
+
 	private static long getSizeNullSafe(StateObject stateObject) throws Exception {
 		return stateObject != null ? stateObject.getStateSize() : 0L;
 	}
@@ -111,36 +187,58 @@ public class OperatorSubtaskState implements CompositeStateHandle {
 	 * Can be removed when we remove the APIs for non-repartitionable operator state.
 	 */
 	@Deprecated
+	@Nullable
 	public StreamStateHandle getLegacyOperatorState() {
 		return legacyOperatorState;
 	}
 
-	public OperatorStateHandle getManagedOperatorState() {
+	/**
+	 * Returns the handles to the managed operator state (never null, empty if there is none).
+	 */
+	@Nonnull
+	public Collection<OperatorStateHandle> getManagedOperatorState() {
 		return managedOperatorState;
 	}
 
-	public OperatorStateHandle getRawOperatorState() {
+	/**
+	 * Returns the handles to the raw operator state (never null, empty if there is none).
+	 */
+	@Nonnull
+	public Collection<OperatorStateHandle> getRawOperatorState() {
 		return rawOperatorState;
 	}
 
-	public KeyedStateHandle getManagedKeyedState() {
+	/**
+	 * Returns the handles to the managed keyed state (never null, empty if there is none).
+	 */
+	@Nonnull
+	public Collection<KeyedStateHandle> getManagedKeyedState() {
 		return managedKeyedState;
 	}
 
-	public KeyedStateHandle getRawKeyedState() {
+	/**
+	 * Returns the handles to the raw keyed state (never null, empty if there is none).
+	 */
+	@Nonnull
+	public Collection<KeyedStateHandle> getRawKeyedState() {
 		return rawKeyedState;
 	}
 
 	@Override
 	public void discardState() {
 		try {
-			StateUtil.bestEffortDiscardAllStateObjects(
-				Arrays.asList(
-					legacyOperatorState,
-					managedOperatorState,
-					rawOperatorState,
-					managedKeyedState,
-					rawKeyedState));
+			List<StateObject> toDispose =
+				new ArrayList<>(1 +
+					managedOperatorState.size() +
+					rawOperatorState.size() +
+					managedKeyedState.size() +
+					rawKeyedState.size());
+			toDispose.add(legacyOperatorState);
+			toDispose.addAll(managedOperatorState);
+			toDispose.addAll(rawOperatorState);
+			toDispose.addAll(managedKeyedState);
+			toDispose.addAll(rawKeyedState);
+			StateUtil.bestEffortDiscardAllStateObjects(toDispose);
 		} catch (Exception e) {
 			LOG.warn("Error while discarding operator states.", e);
 		}
@@ -148,12 +246,17 @@ public class OperatorSubtaskState implements CompositeStateHandle {
 
 	@Override
 	public void registerSharedStates(SharedStateRegistry sharedStateRegistry) {
-		if (managedKeyedState != null) {
-			managedKeyedState.registerSharedStates(sharedStateRegistry);
-		}
+		registerSharedState(sharedStateRegistry, managedKeyedState);
+		registerSharedState(sharedStateRegistry, rawKeyedState);
+	}
 
-		if (rawKeyedState != null) {
-			rawKeyedState.registerSharedStates(sharedStateRegistry);
+	private static void registerSharedState(
+		SharedStateRegistry sharedStateRegistry,
+		Iterable<KeyedStateHandle> stateHandles) {
+		for (KeyedStateHandle stateHandle : stateHandles) {
+			if (stateHandle != null) {
+				stateHandle.registerSharedStates(sharedStateRegistry);
+			}
 		}
 	}
 
@@ -175,44 +278,32 @@ public class OperatorSubtaskState implements CompositeStateHandle {
 
 		OperatorSubtaskState that = (OperatorSubtaskState) o;
 
-		if (stateSize != that.stateSize) {
+		if (getStateSize() != that.getStateSize()) {
 			return false;
 		}
-
-		if (legacyOperatorState != null ?
-			!legacyOperatorState.equals(that.legacyOperatorState)
-			: that.legacyOperatorState != null) {
+		if (getLegacyOperatorState() != null ? !getLegacyOperatorState().equals(that.getLegacyOperatorState()) : that.getLegacyOperatorState() != null) {
 			return false;
 		}
-		if (managedOperatorState != null ?
-			!managedOperatorState.equals(that.managedOperatorState)
-			: that.managedOperatorState != null) {
+		if (!getManagedOperatorState().equals(that.getManagedOperatorState())) {
 			return false;
 		}
-		if (rawOperatorState != null ?
-			!rawOperatorState.equals(that.rawOperatorState)
-			: that.rawOperatorState != null) {
+		if (!getRawOperatorState().equals(that.getRawOperatorState())) {
 			return false;
 		}
-		if (managedKeyedState != null ?
-			!managedKeyedState.equals(that.managedKeyedState)
-			: that.managedKeyedState != null) {
+		if (!getManagedKeyedState().equals(that.getManagedKeyedState())) {
 			return false;
 		}
-		return rawKeyedState != null ?
-			rawKeyedState.equals(that.rawKeyedState)
-			: that.rawKeyedState == null;
-
+		return getRawKeyedState().equals(that.getRawKeyedState());
 	}
 
 	@Override
 	public int hashCode() {
-		int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0;
-		result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0);
-		result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0);
-		result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0);
-		result = 31 * result + (rawKeyedState != null ? rawKeyedState.hashCode() : 0);
-		result = 31 * result + (int) (stateSize ^ (stateSize >>> 32));
+		int result = getLegacyOperatorState() != null ? getLegacyOperatorState().hashCode() : 0;
+		result = 31 * result + getManagedOperatorState().hashCode();
+		result = 31 * result + getRawOperatorState().hashCode();
+		result = 31 * result + getManagedKeyedState().hashCode();
+		result = 31 * result + getRawKeyedState().hashCode();
+		result = 31 * result + (int) (getStateSize() ^ (getStateSize() >>> 32));
 		return result;
 	}
 
@@ -227,4 +318,21 @@ public class OperatorSubtaskState implements CompositeStateHandle {
 			", stateSize=" + stateSize +
 			'}';
 	}
+
+	public boolean hasState() {
+		return legacyOperatorState != null
+			|| hasState(managedOperatorState)
+			|| hasState(rawOperatorState)
+			|| hasState(managedKeyedState)
+			|| hasState(rawKeyedState);
+	}
+
+	private boolean hasState(Iterable<? extends StateObject> states) {
+		for (StateObject state : states) {
+			if (state != null) {
+				return true;
+			}
+		}
+		return false;
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index 3472fc2..16231dd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -25,19 +25,18 @@ import org.apache.flink.runtime.checkpoint.savepoint.SavepointV2;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.OperatorID;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyedStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.Preconditions;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
 import javax.annotation.concurrent.GuardedBy;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -353,13 +352,13 @@ public class PendingCheckpoint {
 	 * Acknowledges the task with the given execution attempt id and the given subtask state.
 	 *
 	 * @param executionAttemptId of the acknowledged task
-	 * @param subtaskState of the acknowledged task
+	 * @param operatorSubtaskStates of the acknowledged task
 	 * @param metrics Checkpoint metrics for the stats
 	 * @return TaskAcknowledgeResult of the operation
 	 */
 	public TaskAcknowledgeResult acknowledgeTask(
 			ExecutionAttemptID executionAttemptId,
-			SubtaskState subtaskState,
+			TaskStateSnapshot operatorSubtaskStates,
 			CheckpointMetrics metrics) {
 
 		synchronized (lock) {
@@ -383,21 +382,19 @@ public class PendingCheckpoint {
 			int subtaskIndex = vertex.getParallelSubtaskIndex();
 			long ackTimestamp = System.currentTimeMillis();
 
-			long stateSize = 0;
-			if (subtaskState != null) {
-				stateSize = subtaskState.getStateSize();
-
-				@SuppressWarnings("deprecation")
-				ChainedStateHandle<StreamStateHandle> nonPartitionedState =
-					subtaskState.getLegacyOperatorState();
-				ChainedStateHandle<OperatorStateHandle> partitioneableState =
-					subtaskState.getManagedOperatorState();
-				ChainedStateHandle<OperatorStateHandle> rawOperatorState =
-					subtaskState.getRawOperatorState();
-
-				// break task state apart into separate operator states
-				for (int x = 0; x < operatorIDs.size(); x++) {
-					OperatorID operatorID = operatorIDs.get(x);
+			long stateSize = 0L;
+
+			if (operatorSubtaskStates != null) {
+				for (OperatorID operatorID : operatorIDs) {
+
+					OperatorSubtaskState operatorSubtaskState =
+						operatorSubtaskStates.getSubtaskStateByOperatorID(operatorID);
+
+					// if no real operatorSubtaskState was reported, we insert an empty state
+					if (operatorSubtaskState == null) {
+						operatorSubtaskState = new OperatorSubtaskState();
+					}
+
 					OperatorState operatorState = operatorStates.get(operatorID);
 
 					if (operatorState == null) {
@@ -408,23 +405,8 @@ public class PendingCheckpoint {
 						operatorStates.put(operatorID, operatorState);
 					}
 
-					KeyedStateHandle managedKeyedState = null;
-					KeyedStateHandle rawKeyedState = null;
-
-					// only the head operator retains the keyed state
-					if (x == operatorIDs.size() - 1) {
-						managedKeyedState = subtaskState.getManagedKeyedState();
-						rawKeyedState = subtaskState.getRawKeyedState();
-					}
-
-					OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(
-							nonPartitionedState != null ? nonPartitionedState.get(x) : null,
-							partitioneableState != null ? partitioneableState.get(x) : null,
-							rawOperatorState != null ? rawOperatorState.get(x) : null,
-							managedKeyedState,
-							rawKeyedState);
-
 					operatorState.putState(subtaskIndex, operatorSubtaskState);
+					stateSize += operatorSubtaskState.getStateSize();
 				}
 			}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
index 046096f..4513ef8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
@@ -89,6 +89,10 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
 
 		for (OperatorStateHandle psh : previousParallelSubtaskStates) {
 
+			if (psh == null) {
+				continue;
+			}
+
 			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e :
 					psh.getStateNameToPartitionOffsets().entrySet()) {
 				OperatorStateHandle.StateMetaInfo metaInfo = e.getValue();

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 5712ea1..b69285e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -23,15 +23,14 @@ import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
-import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.Preconditions;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -185,7 +184,8 @@ public class StateAssignmentOperation {
 					subNonPartitionableState);
 
 				// PartitionedState
-				reAssignSubPartitionableState(newManagedOperatorStates,
+				reAssignSubPartitionableState(
+					newManagedOperatorStates,
 					newRawOperatorStates,
 					subTaskIndex,
 					operatorIndex,
@@ -193,36 +193,57 @@ public class StateAssignmentOperation {
 					subRawOperatorState);
 
 				// KeyedState
-				if (operatorIndex == operatorIDs.size() - 1) {
-					subKeyedState = reAssignSubKeyedStates(operatorState,
+				if (isHeadOperator(operatorIndex, operatorIDs)) {
+					subKeyedState = reAssignSubKeyedStates(
+						operatorState,
 						keyGroupPartitions,
 						subTaskIndex,
 						newParallelism,
 						oldParallelism);
-
 				}
 			}
 
-
 			// check if a stateless task
 			if (!allElementsAreNull(subNonPartitionableState) ||
 				!allElementsAreNull(subManagedOperatorState) ||
 				!allElementsAreNull(subRawOperatorState) ||
 				subKeyedState != null) {
 
-				TaskStateHandles taskStateHandles = new TaskStateHandles(
+				TaskStateSnapshot taskState = new TaskStateSnapshot();
 
-					new ChainedStateHandle<>(subNonPartitionableState),
-					subManagedOperatorState,
-					subRawOperatorState,
-					subKeyedState != null ? subKeyedState.f0 : null,
-					subKeyedState != null ? subKeyedState.f1 : null);
+				for (int i = 0; i < operatorIDs.size(); ++i) {
+
+					OperatorID operatorID = operatorIDs.get(i);
+
+					Collection<KeyedStateHandle> rawKeyed = Collections.emptyList();
+					Collection<KeyedStateHandle> managedKeyed = Collections.emptyList();
+
+					// keyed state case
+					if (subKeyedState != null) {
+						managedKeyed = subKeyedState.f0;
+						rawKeyed = subKeyedState.f1;
+					}
+
+					OperatorSubtaskState operatorSubtaskState =
+						new OperatorSubtaskState(
+							subNonPartitionableState.get(i),
+							subManagedOperatorState.get(i),
+							subRawOperatorState.get(i),
+							managedKeyed,
+							rawKeyed
+						);
+
+					taskState.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState);
+				}
 
-				currentExecutionAttempt.setInitialState(taskStateHandles);
+				currentExecutionAttempt.setInitialState(taskState);
 			}
 		}
 	}
 
+	private static boolean isHeadOperator(int opIdx, List<OperatorID> operatorIDs) {
+		return opIdx == operatorIDs.size() - 1;
+	}
 
 	public void checkParallelismPreconditions(List<OperatorState> operatorStates, ExecutionJobVertex executionJobVertex) {
 
@@ -239,18 +260,18 @@ public class StateAssignmentOperation {
 			List<Collection<OperatorStateHandle>> subManagedOperatorState,
 			List<Collection<OperatorStateHandle>> subRawOperatorState) {
 
-		if (newMangedOperatorStates.get(operatorIndex) != null) {
-			subManagedOperatorState.add(newMangedOperatorStates.get(operatorIndex).get(subTaskIndex));
+		if (newMangedOperatorStates.get(operatorIndex) != null && !newMangedOperatorStates.get(operatorIndex).isEmpty()) {
+			Collection<OperatorStateHandle> operatorStateHandles = newMangedOperatorStates.get(operatorIndex).get(subTaskIndex);
+			subManagedOperatorState.add(operatorStateHandles != null ? operatorStateHandles : Collections.<OperatorStateHandle>emptyList());
 		} else {
-			subManagedOperatorState.add(null);
+			subManagedOperatorState.add(Collections.<OperatorStateHandle>emptyList());
 		}
-		if (newRawOperatorStates.get(operatorIndex) != null) {
-			subRawOperatorState.add(newRawOperatorStates.get(operatorIndex).get(subTaskIndex));
+		if (newRawOperatorStates.get(operatorIndex) != null && !newRawOperatorStates.get(operatorIndex).isEmpty()) {
+			Collection<OperatorStateHandle> operatorStateHandles = newRawOperatorStates.get(operatorIndex).get(subTaskIndex);
+			subRawOperatorState.add(operatorStateHandles != null ? operatorStateHandles : Collections.<OperatorStateHandle>emptyList());
 		} else {
-			subRawOperatorState.add(null);
+			subRawOperatorState.add(Collections.<OperatorStateHandle>emptyList());
 		}
-
-
 	}
 
 	private Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> reAssignSubKeyedStates(
@@ -265,24 +286,22 @@ public class StateAssignmentOperation {
 
 		if (newParallelism == oldParallelism) {
 			if (operatorState.getState(subTaskIndex) != null) {
-				KeyedStateHandle oldSubManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState();
-				KeyedStateHandle oldSubRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState();
-				subManagedKeyedState = oldSubManagedKeyedState != null ? Collections.singletonList(
-					oldSubManagedKeyedState) : null;
-				subRawKeyedState = oldSubRawKeyedState != null ? Collections.singletonList(
-					oldSubRawKeyedState) : null;
+				subManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState();
+				subRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState();
 			} else {
-				subManagedKeyedState = null;
-				subRawKeyedState = null;
+				subManagedKeyedState = Collections.emptyList();
+				subRawKeyedState = Collections.emptyList();
 			}
 		} else {
 			subManagedKeyedState = getManagedKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex));
 			subRawKeyedState = getRawKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex));
 		}
-		if (subManagedKeyedState == null && subRawKeyedState == null) {
+
+		if (subManagedKeyedState.isEmpty() && subRawKeyedState.isEmpty()) {
 			return null;
+		} else {
+			return new Tuple2<>(subManagedKeyedState, subRawKeyedState);
 		}
-		return new Tuple2<>(subManagedKeyedState, subRawKeyedState);
 	}
 
 
@@ -318,7 +337,7 @@ public class StateAssignmentOperation {
 			List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates,
 			List<List<Collection<OperatorStateHandle>>> newRawOperatorStates) {
 
-		//collect the old partitionalbe state
+		//collect the old partitionable state
 		List<List<OperatorStateHandle>> oldManagedOperatorStates = new ArrayList<>();
 		List<List<OperatorStateHandle>> oldRawOperatorStates = new ArrayList<>();
 
@@ -351,19 +370,16 @@ public class StateAssignmentOperation {
 			for (int i = 0; i < operatorState.getParallelism(); i++) {
 				OperatorSubtaskState operatorSubtaskState = operatorState.getState(i);
 				if (operatorSubtaskState != null) {
-					if (operatorSubtaskState.getManagedOperatorState() != null) {
-						if (managedOperatorState == null) {
-							managedOperatorState = new ArrayList<>();
-						}
-						managedOperatorState.add(operatorSubtaskState.getManagedOperatorState());
+
+					if (managedOperatorState == null) {
+						managedOperatorState = new ArrayList<>();
 					}
+					managedOperatorState.addAll(operatorSubtaskState.getManagedOperatorState());
 
-					if (operatorSubtaskState.getRawOperatorState() != null) {
-						if (rawOperatorState == null) {
-							rawOperatorState = new ArrayList<>();
-						}
-						rawOperatorState.add(operatorSubtaskState.getRawOperatorState());
+					if (rawOperatorState == null) {
+						rawOperatorState = new ArrayList<>();
 					}
+					rawOperatorState.addAll(operatorSubtaskState.getRawOperatorState());
 				}
 
 			}
@@ -382,21 +398,19 @@ public class StateAssignmentOperation {
 	 * @return all managedKeyedStateHandles which have intersection with given KeyGroupRange
 	 */
 	public static List<KeyedStateHandle> getManagedKeyedStateHandles(
-			OperatorState operatorState,
-			KeyGroupRange subtaskKeyGroupRange) {
+		OperatorState operatorState,
+		KeyGroupRange subtaskKeyGroupRange) {
 
-		List<KeyedStateHandle> subtaskKeyedStateHandles = null;
+		List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>();
 
 		for (int i = 0; i < operatorState.getParallelism(); i++) {
-			if (operatorState.getState(i) != null && operatorState.getState(i).getManagedKeyedState() != null) {
-				KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getManagedKeyedState().getIntersection(subtaskKeyGroupRange);
+			if (operatorState.getState(i) != null) {
 
-				if (intersectedKeyedStateHandle != null) {
-					if (subtaskKeyedStateHandles == null) {
-						subtaskKeyedStateHandles = new ArrayList<>();
-					}
-					subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
-				}
+				Collection<KeyedStateHandle> keyedStateHandles = operatorState.getState(i).getManagedKeyedState();
+				extractIntersectingState(
+					keyedStateHandles,
+					subtaskKeyGroupRange,
+					subtaskKeyedStateHandles);
 			}
 		}
 
@@ -415,22 +429,40 @@ public class StateAssignmentOperation {
 		OperatorState operatorState,
 		KeyGroupRange subtaskKeyGroupRange) {
 
-		List<KeyedStateHandle> subtaskKeyedStateHandles = null;
+		List<KeyedStateHandle> extractedKeyedStateHandles = new ArrayList<>();
 
 		for (int i = 0; i < operatorState.getParallelism(); i++) {
-			if (operatorState.getState(i) != null && operatorState.getState(i).getRawKeyedState() != null) {
-				KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getRawKeyedState().getIntersection(subtaskKeyGroupRange);
+			if (operatorState.getState(i) != null) {
+				Collection<KeyedStateHandle> rawKeyedState = operatorState.getState(i).getRawKeyedState();
+				extractIntersectingState(
+					rawKeyedState,
+					subtaskKeyGroupRange,
+					extractedKeyedStateHandles);
+			}
+		}
+
+		return extractedKeyedStateHandles;
+	}
+
+	/**
+	 * Extracts certain key group ranges from the given state handles and adds them to the collector.
+	 */
+	private static void extractIntersectingState(
+		Collection<KeyedStateHandle> originalSubtaskStateHandles,
+		KeyGroupRange rangeToExtract,
+		List<KeyedStateHandle> extractedStateCollector) {
+
+		for (KeyedStateHandle keyedStateHandle : originalSubtaskStateHandles) {
+
+			if (keyedStateHandle != null) {
+
+				KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(rangeToExtract);
 
 				if (intersectedKeyedStateHandle != null) {
-					if (subtaskKeyedStateHandles == null) {
-						subtaskKeyedStateHandles = new ArrayList<>();
-					}
-					subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
+					extractedStateCollector.add(intersectedKeyedStateHandle);
 				}
 			}
 		}
-
-		return subtaskKeyedStateHandles;
 	}
 
 	/**
@@ -554,7 +586,7 @@ public class StateAssignmentOperation {
 			int newParallelism) {
 
 		if (chainOpParallelStates == null) {
-			return null;
+			return Collections.emptyList();
 		}
 
 		//We only redistribute if the parallelism of the operator changed from previous executions
@@ -567,20 +599,23 @@ public class StateAssignmentOperation {
 			List<Collection<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism);
 			for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) {
 
-				Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsets =
+				if (operatorStateHandle != null) {
+					Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsets =
 						operatorStateHandle.getStateNameToPartitionOffsets();
 
-				for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) {
 
-					// if we find any broadcast state, we cannot take the shortcut and need to go through repartitioning
-					if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) {
-						return opStateRepartitioner.repartitionState(
+					for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) {
+
+						// if we find any broadcast state, we cannot take the shortcut and need to go through repartitioning
+						if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) {
+							return opStateRepartitioner.repartitionState(
 								chainOpParallelStates,
 								newParallelism);
+						}
 					}
-				}
 
-				repackStream.add(Collections.singletonList(operatorStateHandle));
+					repackStream.add(Collections.singletonList(operatorStateHandle));
+				}
 			}
 			return repackStream;
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java
new file mode 100644
index 0000000..c416f3f
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshot.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.CompositeStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.runtime.state.StateUtil;
+import org.apache.flink.util.Preconditions;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * This class encapsulates state handles to the snapshots of all operator instances executed within one task. A task
+ * can run multiple operator instances as a result of operator chaining, and all operator instances from the chain can
+ * register their state under their operator id. Each operator instance is a physical execution responsible for
+ * processing a partition of the data that goes through a logical operator. This partitioning happens to parallelize
+ * execution of logical operators, e.g. distributing a map function.
+ *
+ * <p>One instance of this class contains the information that one task will send to acknowledge a checkpoint request by
+ * the checkpoint coordinator. Tasks run operator instances in parallel, so the union of all
+ * {@link TaskStateSnapshot} that are collected by the checkpoint coordinator from all tasks represents the whole
+ * state of a job at the time of the checkpoint.
+ *
+ * <p>This class should be called TaskState once the old class with this name that we keep for backwards
+ * compatibility goes away.
+ */
+public class TaskStateSnapshot implements CompositeStateHandle {
+
+	private static final long serialVersionUID = 1L;
+
+	/** Mapping from an operator id to the state of one subtask of this operator. */
+	private final Map<OperatorID, OperatorSubtaskState> subtaskStatesByOperatorID;
+
+	public TaskStateSnapshot() {
+		this(10);
+	}
+
+	public TaskStateSnapshot(int size) {
+		this(new HashMap<OperatorID, OperatorSubtaskState>(size));
+	}
+
+	public TaskStateSnapshot(Map<OperatorID, OperatorSubtaskState> subtaskStatesByOperatorID) {
+		this.subtaskStatesByOperatorID = Preconditions.checkNotNull(subtaskStatesByOperatorID);
+	}
+
+	/**
+	 * Returns the subtask state for the given operator id (or null if not contained).
+	 */
+	public OperatorSubtaskState getSubtaskStateByOperatorID(OperatorID operatorID) {
+		return subtaskStatesByOperatorID.get(operatorID);
+	}
+
+	/**
+	 * Maps the given operator id to the given subtask state. Returns the subtask state of a previous mapping, if such
+	 * a mapping existed or null otherwise.
+	 */
+	public OperatorSubtaskState putSubtaskStateByOperatorID(OperatorID operatorID, OperatorSubtaskState state) {
+		return subtaskStatesByOperatorID.put(operatorID, Preconditions.checkNotNull(state));
+	}
+
+	/**
+	 * Returns the set of all mappings from operator id to the corresponding subtask state.
+	 */
+	public Set<Map.Entry<OperatorID, OperatorSubtaskState>> getSubtaskStateMappings() {
+		return subtaskStatesByOperatorID.entrySet();
+	}
+
+	@Override
+	public void discardState() throws Exception {
+		StateUtil.bestEffortDiscardAllStateObjects(subtaskStatesByOperatorID.values());
+	}
+
+	@Override
+	public long getStateSize() {
+		long size = 0L;
+
+		for (OperatorSubtaskState subtaskState : subtaskStatesByOperatorID.values()) {
+			if (subtaskState != null) {
+				size += subtaskState.getStateSize();
+			}
+		}
+
+		return size;
+	}
+
+	@Override
+	public void registerSharedStates(SharedStateRegistry stateRegistry) {
+		for (OperatorSubtaskState operatorSubtaskState : subtaskStatesByOperatorID.values()) {
+			if (operatorSubtaskState != null) {
+				operatorSubtaskState.registerSharedStates(stateRegistry);
+			}
+		}
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o) {
+			return true;
+		}
+		if (o == null || getClass() != o.getClass()) {
+			return false;
+		}
+
+		TaskStateSnapshot that = (TaskStateSnapshot) o;
+
+		return subtaskStatesByOperatorID.equals(that.subtaskStatesByOperatorID);
+	}
+
+	@Override
+	public int hashCode() {
+		return subtaskStatesByOperatorID.hashCode();
+	}
+
+	@Override
+	public String toString() {
+		return "TaskOperatorSubtaskStates{" +
+			"subtaskStatesByOperatorID=" + subtaskStatesByOperatorID +
+			'}';
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java
index 4cbbfcf..15628a0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java
@@ -240,6 +240,18 @@ class SavepointV2Serializer implements SavepointSerializer<SavepointV2> {
 	//  task state (de)serialization methods
 	// ------------------------------------------------------------------------
 
+	private static <T> T extractSingleton(Collection<T> collection) {
+		if (collection == null || collection.isEmpty()) {
+			return null;
+		}
+
+		if (collection.size() == 1) {
+			return collection.iterator().next();
+		} else {
+			throw new IllegalStateException("Expected singleton collection, but found size: " + collection.size());
+		}
+	}
+
 	private static void serializeSubtaskState(OperatorSubtaskState subtaskState, DataOutputStream dos) throws IOException {
 
 		dos.writeLong(-1);
@@ -252,7 +264,7 @@ class SavepointV2Serializer implements SavepointSerializer<SavepointV2> {
 			serializeStreamStateHandle(nonPartitionableState, dos);
 		}
 
-		OperatorStateHandle operatorStateBackend = subtaskState.getManagedOperatorState();
+		OperatorStateHandle operatorStateBackend = extractSingleton(subtaskState.getManagedOperatorState());
 
 		len = operatorStateBackend != null ? 1 : 0;
 		dos.writeInt(len);
@@ -260,7 +272,7 @@ class SavepointV2Serializer implements SavepointSerializer<SavepointV2> {
 			serializeOperatorStateHandle(operatorStateBackend, dos);
 		}
 
-		OperatorStateHandle operatorStateFromStream = subtaskState.getRawOperatorState();
+		OperatorStateHandle operatorStateFromStream = extractSingleton(subtaskState.getRawOperatorState());
 
 		len = operatorStateFromStream != null ? 1 : 0;
 		dos.writeInt(len);
@@ -268,10 +280,10 @@ class SavepointV2Serializer implements SavepointSerializer<SavepointV2> {
 			serializeOperatorStateHandle(operatorStateFromStream, dos);
 		}
 
-		KeyedStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
+		KeyedStateHandle keyedStateBackend = extractSingleton(subtaskState.getManagedKeyedState());
 		serializeKeyedStateHandle(keyedStateBackend, dos);
 
-		KeyedStateHandle keyedStateStream = subtaskState.getRawKeyedState();
+		KeyedStateHandle keyedStateStream = extractSingleton(subtaskState.getRawKeyedState());
 		serializeKeyedStateHandle(keyedStateStream, dos);
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
index 0578b78..1fa5eb5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
@@ -18,11 +18,11 @@
 
 package org.apache.flink.runtime.deployment;
 
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.JobInformation;
 import org.apache.flink.runtime.executiongraph.TaskInformation;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.SerializedValue;
 
@@ -64,7 +64,7 @@ public final class TaskDeploymentDescriptor implements Serializable {
 	private final int targetSlotNumber;
 
 	/** State handles for the sub task. */
-	private final TaskStateHandles taskStateHandles;
+	private final TaskStateSnapshot taskStateHandles;
 
 	public TaskDeploymentDescriptor(
 			SerializedValue<JobInformation> serializedJobInformation,
@@ -74,7 +74,7 @@ public final class TaskDeploymentDescriptor implements Serializable {
 			int subtaskIndex,
 			int attemptNumber,
 			int targetSlotNumber,
-			TaskStateHandles taskStateHandles,
+			TaskStateSnapshot taskStateHandles,
 			Collection<ResultPartitionDeploymentDescriptor> resultPartitionDeploymentDescriptors,
 			Collection<InputGateDeploymentDescriptor> inputGateDeploymentDescriptors) {
 
@@ -153,7 +153,7 @@ public final class TaskDeploymentDescriptor implements Serializable {
 		return inputGates;
 	}
 
-	public TaskStateHandles getTaskStateHandles() {
+	public TaskStateSnapshot getTaskStateHandles() {
 		return taskStateHandles;
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
index 9e9f7c4..203ee85 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
@@ -18,8 +18,6 @@
 
 package org.apache.flink.runtime.execution;
 
-import java.util.Map;
-import java.util.concurrent.Future;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.TaskInfo;
@@ -28,7 +26,7 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
@@ -41,6 +39,9 @@ import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.internal.InternalKvState;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 
+import java.util.Map;
+import java.util.concurrent.Future;
+
 /**
  * The Environment gives the code executed in a task access to the task's properties
  * (such as name, parallelism), the configurations, the data stream readers and writers,
@@ -175,7 +176,7 @@ public interface Environment {
 	 * @param checkpointMetrics metrics for this checkpoint
 	 * @param subtaskState All state handles for the checkpointed state
 	 */
-	void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState);
+	void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState);
 
 	/**
 	 * Declines a checkpoint. This tells the checkpoint coordinator that this task will

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
index bd5bc7f..2074820 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.time.Time;
 import org.apache.flink.runtime.JobException;
 import org.apache.flink.runtime.accumulators.StringifiedAccumulatorResult;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
@@ -41,7 +42,6 @@ import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
 import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway;
 import org.apache.flink.runtime.messages.Acknowledge;
 import org.apache.flink.runtime.messages.StackTraceSampleResponse;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.util.ExceptionUtils;
 
@@ -133,7 +133,7 @@ public class Execution implements AccessExecution, Archiveable<ArchivedExecution
 	private volatile Throwable failureCause;          // once assigned, never changes
 
 	/** The handle to the state that the task gets on restore */
-	private volatile TaskStateHandles taskState;
+	private volatile TaskStateSnapshot taskState;
 
 	// ------------------------ Accumulators & Metrics ------------------------
 
@@ -253,7 +253,7 @@ public class Execution implements AccessExecution, Archiveable<ArchivedExecution
 		return state.isTerminal();
 	}
 
-	public TaskStateHandles getTaskStateHandles() {
+	public TaskStateSnapshot getTaskStateSnapshot() {
 		return taskState;
 	}
 
@@ -263,7 +263,7 @@ public class Execution implements AccessExecution, Archiveable<ArchivedExecution
 	 *
 	 * @param checkpointStateHandles all checkpointed operator state
 	 */
-	public void setInitialState(TaskStateHandles checkpointStateHandles) {
+	public void setInitialState(TaskStateSnapshot checkpointStateHandles) {
 		checkState(state == CREATED, "Can only assign operator state when execution attempt is in CREATED");
 		this.taskState = checkpointStateHandles;
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
index 0ff71e7..9aac133 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
@@ -22,7 +22,9 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.Archiveable;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.time.Time;
+import org.apache.flink.configuration.JobManagerOptions;
 import org.apache.flink.runtime.JobException;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.PartialInputChannelDeploymentDescriptor;
@@ -38,11 +40,9 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobEdge;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.configuration.JobManagerOptions;
 import org.apache.flink.runtime.jobmanager.scheduler.CoLocationConstraint;
 import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.runtime.util.EvictingBoundedList;
 import org.apache.flink.util.ExceptionUtils;
@@ -457,7 +457,7 @@ public class ExecutionVertex implements AccessExecutionVertex, Archiveable<Archi
 	 */
 	public Iterable<TaskManagerLocation> getPreferredLocationsBasedOnState() {
 		TaskManagerLocation priorLocation;
-		if (currentExecution.getTaskStateHandles() != null && (priorLocation = getLatestPriorLocation()) != null) {
+		if (currentExecution.getTaskStateSnapshot() != null && (priorLocation = getLatestPriorLocation()) != null) {
 			return Collections.singleton(priorLocation);
 		}
 		else {
@@ -719,7 +719,7 @@ public class ExecutionVertex implements AccessExecutionVertex, Archiveable<Archi
 	TaskDeploymentDescriptor createDeploymentDescriptor(
 			ExecutionAttemptID executionId,
 			SimpleSlot targetSlot,
-			TaskStateHandles taskStateHandles,
+			TaskStateSnapshot taskStateHandles,
 			int attemptNumber) throws ExecutionGraphException {
 		
 		// Produced intermediate results

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
index 0930011..00db01f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
@@ -21,7 +21,7 @@ package org.apache.flink.runtime.jobgraph.tasks;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.state.TaskStateHandles;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 
 /**
  * This interface must be implemented by any invokable that has recoverable state and participates
@@ -35,7 +35,7 @@ public interface StatefulTask {
 	 *
 	 * @param taskStateHandles All state handle for the task.
 	 */
-	void setInitialState(TaskStateHandles taskStateHandles) throws Exception;
+	void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception;
 
 	/**
 	 * This method is called to trigger a checkpoint, asynchronously by the checkpoint
@@ -43,8 +43,8 @@ public interface StatefulTask {
 	 * 
 	 * <p>This method is called for tasks that start the checkpoints by injecting the initial barriers,
 	 * i.e., the source tasks. In contrast, checkpoints on downstream operators, which are the result of
-	 * receiving checkpoint barriers, invoke the {@link #triggerCheckpointOnBarrier(CheckpointMetaData, CheckpointMetrics)}
-	 * method.
+	 * receiving checkpoint barriers, invoke the
+	 * {@link #triggerCheckpointOnBarrier(CheckpointMetaData, CheckpointOptions, CheckpointMetrics)} method.
 	 *
 	 * @param checkpointMetaData Meta data for about this checkpoint
 	 * @param checkpointOptions Options for performing this checkpoint

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java
index 31036f6..25df19b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java
@@ -31,7 +31,7 @@ import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinator;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.client.JobExecutionException;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
@@ -96,6 +96,7 @@ import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 
 import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -586,7 +587,7 @@ public class JobMaster extends RpcEndpoint implements JobMasterGateway {
 			final ExecutionAttemptID executionAttemptID,
 			final long checkpointId,
 			final CheckpointMetrics checkpointMetrics,
-			final SubtaskState checkpointState) {
+			final TaskStateSnapshot checkpointState) {
 
 		final CheckpointCoordinator checkpointCoordinator = executionGraph.getCheckpointCoordinator();
 		final AcknowledgeCheckpoint ackMessage = 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
index 9721c2c..65e3019 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
@@ -21,7 +21,7 @@ package org.apache.flink.runtime.messages.checkpoint;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 
 /**
@@ -36,7 +36,7 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 
 	private static final long serialVersionUID = -7606214777192401493L;
 
-	private final SubtaskState subtaskState;
+	private final TaskStateSnapshot subtaskState;
 
 	private final CheckpointMetrics checkpointMetrics;
 
@@ -47,7 +47,7 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 			ExecutionAttemptID taskExecutionId,
 			long checkpointId,
 			CheckpointMetrics checkpointMetrics,
-			SubtaskState subtaskState) {
+			TaskStateSnapshot subtaskState) {
 
 		super(job, taskExecutionId, checkpointId);
 
@@ -64,7 +64,7 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 	//  properties
 	// ------------------------------------------------------------------------
 
-	public SubtaskState getSubtaskState() {
+	public TaskStateSnapshot getSubtaskState() {
 		return subtaskState;
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
index d82af72..031d7c7 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.commons.io.IOUtils;
 import org.apache.flink.api.common.state.KeyedStateStore;
 import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.java.tuple.Tuple2;
@@ -26,6 +25,8 @@ import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.util.Preconditions;
 
+import org.apache.commons.io.IOUtils;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -139,6 +140,7 @@ public class StateInitializationContextImpl implements StateInitializationContex
 	}
 
 	private static Collection<KeyGroupsStateHandle> transform(Collection<KeyedStateHandle> keyedStateHandles) {
+
 		if (keyedStateHandles == null) {
 			return null;
 		}
@@ -146,13 +148,14 @@ public class StateInitializationContextImpl implements StateInitializationContex
 		List<KeyGroupsStateHandle> keyGroupsStateHandles = new ArrayList<>();
 
 		for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
-			if (! (keyedStateHandle instanceof KeyGroupsStateHandle)) {
+
+			if (keyedStateHandle instanceof KeyGroupsStateHandle) {
+				keyGroupsStateHandles.add((KeyGroupsStateHandle) keyedStateHandle);
+			} else if (keyedStateHandle != null) {
 				throw new IllegalStateException("Unexpected state handle type, " +
 					"expected: " + KeyGroupsStateHandle.class +
 					", but found: " + keyedStateHandle.getClass() + ".");
 			}
-
-			keyGroupsStateHandles.add((KeyGroupsStateHandle) keyedStateHandle);
 		}
 
 		return keyGroupsStateHandles;

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
deleted file mode 100644
index 2fde548..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
+++ /dev/null
@@ -1,172 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state;
-
-import org.apache.flink.runtime.checkpoint.SubtaskState;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.List;
-
-/**
- * This class encapsulates all state handles for a task.
- */
-public class TaskStateHandles implements Serializable {
-
-	public static final TaskStateHandles EMPTY = new TaskStateHandles();
-
-	private static final long serialVersionUID = 267686583583579359L;
-
-	/**
-	 * State handle with the (non-partitionable) legacy operator state
-	 *
-	 * @deprecated Non-repartitionable operator state that has been deprecated.
-	 * Can be removed when we remove the APIs for non-repartitionable operator state.
-	 */
-	@Deprecated
-	private final ChainedStateHandle<StreamStateHandle> legacyOperatorState;
-
-	/** Collection of handles which represent the managed keyed state of the head operator */
-	private final Collection<KeyedStateHandle> managedKeyedState;
-
-	/** Collection of handles which represent the raw/streamed keyed state of the head operator */
-	private final Collection<KeyedStateHandle> rawKeyedState;
-
-	/** Outer list represents the operator chain, each collection holds handles for managed state of a single operator */
-	private final List<Collection<OperatorStateHandle>> managedOperatorState;
-
-	/** Outer list represents the operator chain, each collection holds handles for raw/streamed state of a single operator */
-	private final List<Collection<OperatorStateHandle>> rawOperatorState;
-
-	public TaskStateHandles() {
-		this(null, null, null, null, null);
-	}
-
-	public TaskStateHandles(SubtaskState checkpointStateHandles) {
-		this(checkpointStateHandles.getLegacyOperatorState(),
-				transform(checkpointStateHandles.getManagedOperatorState()),
-				transform(checkpointStateHandles.getRawOperatorState()),
-				transform(checkpointStateHandles.getManagedKeyedState()),
-				transform(checkpointStateHandles.getRawKeyedState()));
-	}
-
-	public TaskStateHandles(
-			ChainedStateHandle<StreamStateHandle> legacyOperatorState,
-			List<Collection<OperatorStateHandle>> managedOperatorState,
-			List<Collection<OperatorStateHandle>> rawOperatorState,
-			Collection<KeyedStateHandle> managedKeyedState,
-			Collection<KeyedStateHandle> rawKeyedState) {
-
-		this.legacyOperatorState = legacyOperatorState;
-		this.managedKeyedState = managedKeyedState;
-		this.rawKeyedState = rawKeyedState;
-		this.managedOperatorState = managedOperatorState;
-		this.rawOperatorState = rawOperatorState;
-	}
-
-	/**
-	 * @deprecated Non-repartitionable operator state that has been deprecated.
-	 * Can be removed when we remove the APIs for non-repartitionable operator state.
-	 */
-	@Deprecated
-	public ChainedStateHandle<StreamStateHandle> getLegacyOperatorState() {
-		return legacyOperatorState;
-	}
-
-	public Collection<KeyedStateHandle> getManagedKeyedState() {
-		return managedKeyedState;
-	}
-
-	public Collection<KeyedStateHandle> getRawKeyedState() {
-		return rawKeyedState;
-	}
-
-	public List<Collection<OperatorStateHandle>> getRawOperatorState() {
-		return rawOperatorState;
-	}
-
-	public List<Collection<OperatorStateHandle>> getManagedOperatorState() {
-		return managedOperatorState;
-	}
-
-	private static List<Collection<OperatorStateHandle>> transform(ChainedStateHandle<OperatorStateHandle> in) {
-		if (null == in) {
-			return Collections.emptyList();
-		}
-		List<Collection<OperatorStateHandle>> out = new ArrayList<>(in.getLength());
-		for (int i = 0; i < in.getLength(); ++i) {
-			OperatorStateHandle osh = in.get(i);
-			out.add(osh != null ? Collections.singletonList(osh) : null);
-		}
-		return out;
-	}
-
-	private static <T> List<T> transform(T in) {
-		return in == null ? Collections.<T>emptyList() : Collections.singletonList(in);
-	}
-
-	@Override
-	public boolean equals(Object o) {
-		if (this == o) {
-			return true;
-		}
-		if (o == null || getClass() != o.getClass()) {
-			return false;
-		}
-
-		TaskStateHandles that = (TaskStateHandles) o;
-
-		if (legacyOperatorState != null ?
-				!legacyOperatorState.equals(that.legacyOperatorState)
-				: that.legacyOperatorState != null) {
-			return false;
-		}
-		if (managedKeyedState != null ?
-				!managedKeyedState.equals(that.managedKeyedState)
-				: that.managedKeyedState != null) {
-			return false;
-		}
-		if (rawKeyedState != null ?
-				!rawKeyedState.equals(that.rawKeyedState)
-				: that.rawKeyedState != null) {
-			return false;
-		}
-
-		if (rawOperatorState != null ?
-				!rawOperatorState.equals(that.rawOperatorState)
-				: that.rawOperatorState != null) {
-			return false;
-		}
-		return managedOperatorState != null ?
-				managedOperatorState.equals(that.managedOperatorState)
-				: that.managedOperatorState == null;
-	}
-
-	@Override
-	public int hashCode() {
-		int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0;
-		result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0);
-		result = 31 * result + (rawKeyedState != null ? rawKeyedState.hashCode() : 0);
-		result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0);
-		result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0);
-		return result;
-	}
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java
index bf60161..aba8bda 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/rpc/RpcCheckpointResponder.java
@@ -21,7 +21,7 @@ package org.apache.flink.runtime.taskexecutor.rpc;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorGateway;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.util.Preconditions;
@@ -40,7 +40,7 @@ public class RpcCheckpointResponder implements CheckpointResponder {
 			ExecutionAttemptID executionAttemptID,
 			long checkpointId,
 			CheckpointMetrics checkpointMetrics,
-			SubtaskState subtaskState) {
+			TaskStateSnapshot subtaskState) {
 
 		checkpointCoordinatorGateway.acknowledgeCheckpoint(
 			jobID,

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
index ad0df71..e9f600d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
@@ -20,7 +20,7 @@ package org.apache.flink.runtime.taskmanager;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
@@ -44,7 +44,7 @@ public class ActorGatewayCheckpointResponder implements CheckpointResponder {
 			ExecutionAttemptID executionAttemptID,
 			long checkpointId,
 			CheckpointMetrics checkpointMetrics,
-			SubtaskState checkpointStateHandles) {
+			TaskStateSnapshot checkpointStateHandles) {
 
 		AcknowledgeCheckpoint message = new AcknowledgeCheckpoint(
 				jobID, executionAttemptID, checkpointId, checkpointMetrics,

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
index cc66a3f..b3584a6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
@@ -20,7 +20,7 @@ package org.apache.flink.runtime.taskmanager;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 
 /**
@@ -47,7 +47,7 @@ public interface CheckpointResponder {
 		ExecutionAttemptID executionAttemptID,
 		long checkpointId,
 		CheckpointMetrics checkpointMetrics,
-		SubtaskState subtaskState);
+		TaskStateSnapshot subtaskState);
 
 	/**
 	 * Declines the given checkpoint.

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
index 788a590..92b5886 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
@@ -26,7 +26,7 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -245,7 +245,7 @@ public class RuntimeEnvironment implements Environment {
 	public void acknowledgeCheckpoint(
 			long checkpointId,
 			CheckpointMetrics checkpointMetrics,
-			SubtaskState checkpointStateHandles) {
+			TaskStateSnapshot checkpointStateHandles) {
 
 		checkpointResponder.acknowledgeCheckpoint(
 				jobId, executionId, checkpointId, checkpointMetrics,

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index 596d365..04cb990 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -34,6 +34,7 @@ import org.apache.flink.runtime.blob.BlobKey;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineTaskNotCheckpointingException;
 import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineTaskNotReadyException;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
@@ -67,16 +68,17 @@ import org.apache.flink.runtime.jobmanager.PartitionProducerDisposedException;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.WrappingRuntimeException;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.net.URL;
 import java.util.Collection;
@@ -250,7 +252,7 @@ public class Task implements Runnable, TaskActions {
 	 * The handles to the states that the task was initialized with. Will be set
 	 * to null after the initialization, to be memory friendly.
 	 */
-	private volatile TaskStateHandles taskStateHandles;
+	private volatile TaskStateSnapshot taskStateHandles;
 
 	/** Initialized from the Flink configuration. May also be set at the ExecutionConfig */
 	private long taskCancellationInterval;
@@ -272,7 +274,7 @@ public class Task implements Runnable, TaskActions {
 		Collection<ResultPartitionDeploymentDescriptor> resultPartitionDeploymentDescriptors,
 		Collection<InputGateDeploymentDescriptor> inputGateDeploymentDescriptors,
 		int targetSlotNumber,
-		TaskStateHandles taskStateHandles,
+		TaskStateSnapshot taskStateHandles,
 		MemoryManager memManager,
 		IOManager ioManager,
 		NetworkEnvironment networkEnvironment,


[4/7] flink git commit: [FLINK-7213] Introduce state management by OperatorID in TaskManager

Posted by sr...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index 923b912..09e9a1b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -32,7 +32,9 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -49,6 +51,7 @@ import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
@@ -56,7 +59,6 @@ import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
-import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupRange;
@@ -65,7 +67,6 @@ import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackendFactory;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
@@ -128,6 +129,7 @@ import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyCollectionOf;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Matchers.eq;
@@ -158,6 +160,7 @@ public class StreamTaskTest extends TestLogger {
 	public void testEarlyCanceling() throws Exception {
 		Deadline deadline = new FiniteDuration(2, TimeUnit.MINUTES).fromNow();
 		StreamConfig cfg = new StreamConfig(new Configuration());
+		cfg.setOperatorID(new OperatorID(4711L, 42L));
 		cfg.setStreamOperator(new SlowlyDeserializingOperator());
 		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
 
@@ -203,6 +206,7 @@ public class StreamTaskTest extends TestLogger {
 		taskManagerConfig.setString(CoreOptions.STATE_BACKEND, MockStateBackend.class.getName());
 
 		StreamConfig cfg = new StreamConfig(new Configuration());
+		cfg.setOperatorID(new OperatorID(4711L, 42L));
 		cfg.setStreamOperator(new StreamSource<>(new MockSourceFunction()));
 		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
 
@@ -227,6 +231,7 @@ public class StreamTaskTest extends TestLogger {
 		taskManagerConfig.setString(CoreOptions.STATE_BACKEND, MockStateBackend.class.getName());
 
 		StreamConfig cfg = new StreamConfig(new Configuration());
+		cfg.setOperatorID(new OperatorID(4711L, 42L));
 		cfg.setStreamOperator(new StreamSource<>(new MockSourceFunction()));
 		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
 
@@ -324,6 +329,13 @@ public class StreamTaskTest extends TestLogger {
 		when(streamOperator2.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle2);
 		when(streamOperator3.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle3);
 
+		OperatorID operatorID1 = new OperatorID();
+		OperatorID operatorID2 = new OperatorID();
+		OperatorID operatorID3 = new OperatorID();
+		when(streamOperator1.getOperatorID()).thenReturn(operatorID1);
+		when(streamOperator2.getOperatorID()).thenReturn(operatorID2);
+		when(streamOperator3.getOperatorID()).thenReturn(operatorID3);
+
 		// set up the task
 
 		StreamOperator<?>[] streamOperators = {streamOperator1, streamOperator2, streamOperator3};
@@ -399,6 +411,13 @@ public class StreamTaskTest extends TestLogger {
 		when(streamOperator2.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle2);
 		when(streamOperator3.snapshotLegacyOperatorState(anyLong(), anyLong(), any(CheckpointOptions.class))).thenReturn(streamStateHandle3);
 
+		OperatorID operatorID1 = new OperatorID();
+		OperatorID operatorID2 = new OperatorID();
+		OperatorID operatorID3 = new OperatorID();
+		when(streamOperator1.getOperatorID()).thenReturn(operatorID1);
+		when(streamOperator2.getOperatorID()).thenReturn(operatorID2);
+		when(streamOperator3.getOperatorID()).thenReturn(operatorID3);
+
 		StreamOperator<?>[] streamOperators = {streamOperator1, streamOperator2, streamOperator3};
 
 		OperatorChain<Void, AbstractStreamOperator<Void>> operatorChain = mock(OperatorChain.class);
@@ -455,7 +474,7 @@ public class StreamTaskTest extends TestLogger {
 
 				return null;
 			}
-		}).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(SubtaskState.class));
+		}).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class));
 
 		StreamTask<?, AbstractStreamOperator<?>> streamTask = mock(StreamTask.class, Mockito.CALLS_REAL_METHODS);
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp);
@@ -505,18 +524,19 @@ public class StreamTaskTest extends TestLogger {
 
 		acknowledgeCheckpointLatch.await();
 
-		ArgumentCaptor<SubtaskState> subtaskStateCaptor = ArgumentCaptor.forClass(SubtaskState.class);
+		ArgumentCaptor<TaskStateSnapshot> subtaskStateCaptor = ArgumentCaptor.forClass(TaskStateSnapshot.class);
 
 		// check that the checkpoint has been completed
 		verify(mockEnvironment).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), subtaskStateCaptor.capture());
 
-		SubtaskState subtaskState = subtaskStateCaptor.getValue();
+		TaskStateSnapshot subtaskStates = subtaskStateCaptor.getValue();
+		OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateMappings().iterator().next().getValue();
 
 		// check that the subtask state contains the expected state handles
-		assertEquals(managedKeyedStateHandle, subtaskState.getManagedKeyedState());
-		assertEquals(rawKeyedStateHandle, subtaskState.getRawKeyedState());
-		assertEquals(new ChainedStateHandle<>(Collections.singletonList(managedOperatorStateHandle)), subtaskState.getManagedOperatorState());
-		assertEquals(new ChainedStateHandle<>(Collections.singletonList(rawOperatorStateHandle)), subtaskState.getRawOperatorState());
+		assertEquals(Collections.singletonList(managedKeyedStateHandle), subtaskState.getManagedKeyedState());
+		assertEquals(Collections.singletonList(rawKeyedStateHandle), subtaskState.getRawKeyedState());
+		assertEquals(Collections.singletonList(managedOperatorStateHandle), subtaskState.getManagedOperatorState());
+		assertEquals(Collections.singletonList(rawOperatorStateHandle), subtaskState.getRawOperatorState());
 
 		// check that the state handles have not been discarded
 		verify(managedKeyedStateHandle, never()).discardState();
@@ -558,18 +578,26 @@ public class StreamTaskTest extends TestLogger {
 		Environment mockEnvironment = mock(Environment.class);
 		when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo);
 
-		whenNew(SubtaskState.class).withAnyArguments().thenAnswer(new Answer<SubtaskState>() {
-			@Override
-			public SubtaskState answer(InvocationOnMock invocation) throws Throwable {
+		whenNew(OperatorSubtaskState.class).
+			withArguments(
+				any(StreamStateHandle.class),
+				anyCollectionOf(OperatorStateHandle.class),
+				anyCollectionOf(OperatorStateHandle.class),
+				anyCollectionOf(KeyedStateHandle.class),
+				anyCollectionOf(KeyedStateHandle.class)).
+			thenAnswer(new Answer<OperatorSubtaskState>() {
+				@Override
+			public OperatorSubtaskState answer(InvocationOnMock invocation) throws Throwable {
 				createSubtask.trigger();
 				completeSubtask.await();
-
-				return new SubtaskState(
-					(ChainedStateHandle<StreamStateHandle>) invocation.getArguments()[0],
-					(ChainedStateHandle<OperatorStateHandle>) invocation.getArguments()[1],
-					(ChainedStateHandle<OperatorStateHandle>) invocation.getArguments()[2],
-					(KeyedStateHandle) invocation.getArguments()[3],
-					(KeyedStateHandle) invocation.getArguments()[4]);
+				Object[] arguments = invocation.getArguments();
+				return new OperatorSubtaskState(
+					(StreamStateHandle) arguments[0],
+					(OperatorStateHandle) arguments[1],
+					(OperatorStateHandle) arguments[2],
+					(KeyedStateHandle) arguments[3],
+					(KeyedStateHandle) arguments[4]
+				);
 			}
 		});
 
@@ -577,7 +605,9 @@ public class StreamTaskTest extends TestLogger {
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp);
 		streamTask.setEnvironment(mockEnvironment);
 
-		StreamOperator<?> streamOperator = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class));
+		final StreamOperator<?> streamOperator = mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class));
+		final OperatorID operatorID = new OperatorID();
+		when(streamOperator.getOperatorID()).thenReturn(operatorID);
 
 		KeyedStateHandle managedKeyedStateHandle = mock(KeyedStateHandle.class);
 		KeyedStateHandle rawKeyedStateHandle = mock(KeyedStateHandle.class);
@@ -636,7 +666,7 @@ public class StreamTaskTest extends TestLogger {
 		}
 
 		// check that the checkpoint has not been acknowledged
-		verify(mockEnvironment, never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), any(SubtaskState.class));
+		verify(mockEnvironment, never()).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), any(TaskStateSnapshot.class));
 
 		// check that the state handles have been discarded
 		verify(managedKeyedStateHandle).discardState();
@@ -676,7 +706,7 @@ public class StreamTaskTest extends TestLogger {
 				checkpointCompletedLatch.trigger();
 				return null;
 			}
-		}).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(SubtaskState.class));
+		}).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class));
 
 		when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo);
 
@@ -688,6 +718,9 @@ public class StreamTaskTest extends TestLogger {
 		StreamOperator<?> statelessOperator =
 				mock(StreamOperator.class, withSettings().extraInterfaces(StreamCheckpointedOperator.class));
 
+		final OperatorID operatorID = new OperatorID();
+		when(statelessOperator.getOperatorID()).thenReturn(operatorID);
+
 		// mock the returned empty snapshot result (all state handles are null)
 		OperatorSnapshotResult statelessOperatorSnapshotResult = new OperatorSnapshotResult();
 		when(statelessOperator.snapshotState(anyLong(), anyLong(), any(CheckpointOptions.class)))
@@ -803,7 +836,7 @@ public class StreamTaskTest extends TestLogger {
 			Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 			Collections.<InputGateDeploymentDescriptor>emptyList(),
 			0,
-			new TaskStateHandles(),
+			new TaskStateSnapshot(),
 			mock(MemoryManager.class),
 			mock(IOManager.class),
 			network,

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
index a02fe4e..19d48e1 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
@@ -142,6 +143,7 @@ public class StreamTaskTestHarness<OUT> {
 		streamConfig.setNumberOfOutputs(1);
 		streamConfig.setTypeSerializerOut(outputSerializer);
 		streamConfig.setVertexID(0);
+		streamConfig.setOperatorID(new OperatorID(4711L, 123L));
 
 		StreamOperator<OUT> dummyOperator = new AbstractStreamOperator<OUT>() {
 			private static final long serialVersionUID = 1L;

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
index 66531ac..d785c0d 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.streaming.api.functions.co.CoMapFunction;
 import org.apache.flink.streaming.api.functions.co.RichCoMapFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
@@ -64,6 +65,7 @@ public class TwoInputStreamTaskTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		CoStreamMap<String, Integer, String> coMapOperator = new CoStreamMap<String, Integer, String>(new TestOpenCloseMapFunction());
 		streamConfig.setStreamOperator(coMapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		long initialTime = 0L;
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
@@ -110,6 +112,7 @@ public class TwoInputStreamTaskTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		CoStreamMap<String, Integer, String> coMapOperator = new CoStreamMap<String, Integer, String>(new IdentityMap());
 		streamConfig.setStreamOperator(coMapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
 		long initialTime = 0L;
@@ -216,6 +219,7 @@ public class TwoInputStreamTaskTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		CoStreamMap<String, Integer, String> coMapOperator = new CoStreamMap<String, Integer, String>(new IdentityMap());
 		streamConfig.setStreamOperator(coMapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
 		long initialTime = 0L;
@@ -296,6 +300,7 @@ public class TwoInputStreamTaskTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		CoStreamMap<String, Integer, String> coMapOperator = new CoStreamMap<String, Integer, String>(new IdentityMap());
 		streamConfig.setStreamOperator(coMapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
 		long initialTime = 0L;

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
index 47e8726..15802353 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
@@ -32,9 +32,11 @@ import org.apache.flink.migration.streaming.runtime.tasks.StreamTaskState;
 import org.apache.flink.migration.util.MigrationInstantiationUtil;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
 import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
 import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
@@ -154,6 +156,7 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 		Configuration underlyingConfig = environment.getTaskConfiguration();
 		this.config = new StreamConfig(underlyingConfig);
 		this.config.setCheckpointingEnabled(true);
+		this.config.setOperatorID(new OperatorID());
 		this.executionConfig = environment.getExecutionConfig();
 		this.closableRegistry = new CloseableRegistry();
 		this.checkpointLock = new Object();
@@ -336,7 +339,7 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 	}
 
 	/**
-	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}.
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorSubtaskState)}.
 	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}
 	 * if it was not called before.
 	 *
@@ -393,13 +396,12 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 					rawOperatorState,
 					numSubtasks).get(subtaskIndex);
 
-			OperatorStateHandles massagedOperatorStateHandles = new OperatorStateHandles(
-					0,
-					operatorStateHandles.getLegacyOperatorState(),
-					localManagedKeyGroupState,
-					localRawKeyGroupState,
-					localManagedOperatorState,
-					localRawOperatorState);
+			OperatorSubtaskState massagedOperatorStateHandles = new OperatorSubtaskState(
+				operatorStateHandles.getLegacyOperatorState(),
+				nullToEmptyCollection(localManagedOperatorState),
+				nullToEmptyCollection(localRawOperatorState),
+				nullToEmptyCollection(localManagedKeyGroupState),
+				nullToEmptyCollection(localRawKeyGroupState));
 
 			operator.initializeState(massagedOperatorStateHandles);
 		} else {
@@ -408,6 +410,10 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 		initializeCalled = true;
 	}
 
+	private static <T> Collection<T> nullToEmptyCollection(Collection<T> collection) {
+		return collection != null ? collection : Collections.<T>emptyList();
+	}
+
 	/**
 	 * Takes the different {@link OperatorStateHandles} created by calling {@link #snapshot(long, long)}
 	 * on different instances of {@link AbstractStreamOperatorTestHarness} (each one representing one subtask)

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
index bf1bb1b..cc23545 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
@@ -354,7 +354,7 @@ public class SavepointITCase extends TestLogger {
 
 					errMsg = "Initial operator state mismatch.";
 					assertEquals(errMsg, subtaskState.getLegacyOperatorState(),
-						tdd.getTaskStateHandles().getLegacyOperatorState().get(chainIndexAndJobVertex.f0));
+						tdd.getTaskStateHandles().getSubtaskStateByOperatorID(operatorState.getOperatorID()).getLegacyOperatorState());
 				}
 			}
 


[2/7] flink git commit: [FLINK-7268] [checkpoints] Scope SharedStateRegistry objects per (re)start

Posted by sr...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-tests/src/test/java/org/apache/flink/test/checkpointing/HAIncrementalRocksDbBackendEventTimeWindowCheckpointingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/HAIncrementalRocksDbBackendEventTimeWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/HAIncrementalRocksDbBackendEventTimeWindowCheckpointingITCase.java
new file mode 100644
index 0000000..394815f
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/HAIncrementalRocksDbBackendEventTimeWindowCheckpointingITCase.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing;
+
+/**
+ * Integration tests for incremental RocksDB backend.
+ */
+public class HAIncrementalRocksDbBackendEventTimeWindowCheckpointingITCase extends AbstractEventTimeWindowCheckpointingITCase {
+
+	public HAIncrementalRocksDbBackendEventTimeWindowCheckpointingITCase() {
+		super(StateBackendEnum.ROCKSDB_INCREMENTAL_ZK);
+	}
+
+	@Override
+	protected int numElementsPerKey() {
+		return 3000;
+	}
+
+	@Override
+	protected int windowSize() {
+		return 1000;
+	}
+
+	@Override
+	protected int windowSlide() {
+		return 100;
+	}
+
+	@Override
+	protected int numKeys() {
+		return 100;
+	}
+}


[3/7] flink git commit: [FLINK-7268] [checkpoints] Scope SharedStateRegistry objects per (re)start

Posted by sr...@apache.org.
[FLINK-7268] [checkpoints] Scope SharedStateRegistry objects per (re)start


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/91a4b276
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/91a4b276
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/91a4b276

Branch: refs/heads/master
Commit: 91a4b276171afb760bfff9ccf30593e648e91dfb
Parents: b71154a
Author: Stefan Richter <s....@data-artisans.com>
Authored: Tue Jul 25 12:04:16 2017 +0200
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Tue Aug 15 14:56:54 2017 +0200

----------------------------------------------------------------------
 .../state/RocksDBKeyedStateBackend.java         |  13 +-
 .../checkpoint/CheckpointCoordinator.java       |  32 +-
 .../runtime/checkpoint/CompletedCheckpoint.java |   2 +
 .../checkpoint/CompletedCheckpointStore.java    |   5 +-
 .../StandaloneCompletedCheckpointStore.java     |   4 +-
 .../ZooKeeperCompletedCheckpointStore.java      |  12 +-
 .../runtime/executiongraph/ExecutionGraph.java  |   6 +-
 .../executiongraph/ExecutionJobVertex.java      |   2 +-
 .../flink/runtime/jobmaster/JobMaster.java      |   2 +-
 .../state/IncrementalKeyedStateHandle.java      |  68 ++--
 .../runtime/state/KeyGroupsStateHandle.java     |   2 +-
 .../runtime/state/MultiStreamStateHandle.java   |  10 +-
 .../runtime/state/SharedStateRegistry.java      |  52 ++-
 .../state/SharedStateRegistryFactory.java       |  35 ++
 .../state/memory/ByteStreamStateHandle.java     |   1 +
 ...tCoordinatorExternalizedCheckpointsTest.java |  22 +-
 .../CheckpointCoordinatorFailureTest.java       |   7 +-
 .../CheckpointCoordinatorMasterHooksTest.java   |   7 +-
 .../checkpoint/CheckpointCoordinatorTest.java   | 341 +++++++++++++++++--
 .../checkpoint/CheckpointStateRestoreTest.java  |  10 +-
 ...ZooKeeperCompletedCheckpointStoreITCase.java |  25 +-
 .../ZooKeeperCompletedCheckpointStoreTest.java  |   7 +-
 .../state/IncrementalKeyedStateHandleTest.java  |  75 +++-
 .../RecoverableCompletedCheckpointStore.java    |  33 +-
 .../streaming/runtime/tasks/StreamTask.java     |  21 +-
 ...tractEventTimeWindowCheckpointingITCase.java |  76 +++--
 ...ckendEventTimeWindowCheckpointingITCase.java |  49 +++
 27 files changed, 743 insertions(+), 176 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index bba5b55..756cfdd 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -253,7 +253,10 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		this.restoredKvStateMetaInfos = new HashMap<>();
 		this.materializedSstFiles = new TreeMap<>();
 		this.backendUID = UUID.randomUUID();
-		LOG.debug("Setting initial keyed backend uid for operator {} to {}.", this.operatorIdentifier, this.backendUID);
+
+		LOG.debug("Setting initial backend ID in RocksDBKeyedStateBackend for operator {} to {}.",
+			this.operatorIdentifier,
+			this.backendUID);
 	}
 
 	/**
@@ -883,11 +886,17 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		void takeSnapshot() throws Exception {
 			assert (Thread.holdsLock(stateBackend.asyncSnapshotLock));
 
+			final long lastCompletedCheckpoint;
+
 			// use the last completed checkpoint as the comparison base.
 			synchronized (stateBackend.materializedSstFiles) {
-				baseSstFiles = stateBackend.materializedSstFiles.get(stateBackend.lastCompletedCheckpointId);
+				lastCompletedCheckpoint = stateBackend.lastCompletedCheckpointId;
+				baseSstFiles = stateBackend.materializedSstFiles.get(lastCompletedCheckpoint);
 			}
 
+			LOG.trace("Taking incremental snapshot for checkpoint {}. Snapshot is based on last completed checkpoint {} " +
+				"assuming the following (shared) files as base: {}.", checkpointId, lastCompletedCheckpoint, baseSstFiles);
+
 			// save meta data
 			for (Map.Entry<String, Tuple2<ColumnFamilyHandle, RegisteredKeyedBackendStateMetaInfo<?, ?>>> stateMetaInfoEntry
 					: stateBackend.kvStateInformation.entrySet()) {

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 0b64a73..c98d3aa 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -40,6 +40,7 @@ import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.runtime.state.SharedStateRegistryFactory;
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.StringUtils;
@@ -174,8 +175,11 @@ public class CheckpointCoordinator {
 	@Nullable
 	private CheckpointStatsTracker statsTracker;
 
+	/** A factory for SharedStateRegistry objects */
+	private final SharedStateRegistryFactory sharedStateRegistryFactory;
+
 	/** Registry that tracks state which is shared across (incremental) checkpoints */
-	private final SharedStateRegistry sharedStateRegistry;
+	private SharedStateRegistry sharedStateRegistry;
 
 	// --------------------------------------------------------------------------------------------
 
@@ -192,7 +196,8 @@ public class CheckpointCoordinator {
 			CheckpointIDCounter checkpointIDCounter,
 			CompletedCheckpointStore completedCheckpointStore,
 			@Nullable String checkpointDirectory,
-			Executor executor) {
+			Executor executor,
+			SharedStateRegistryFactory sharedStateRegistryFactory) {
 
 		// sanity checks
 		checkArgument(baseInterval > 0, "Checkpoint timeout must be larger than zero");
@@ -230,7 +235,8 @@ public class CheckpointCoordinator {
 		this.completedCheckpointStore = checkNotNull(completedCheckpointStore);
 		this.checkpointDirectory = checkpointDirectory;
 		this.executor = checkNotNull(executor);
-		this.sharedStateRegistry = new SharedStateRegistry(executor);
+		this.sharedStateRegistryFactory = checkNotNull(sharedStateRegistryFactory);
+		this.sharedStateRegistry = sharedStateRegistryFactory.create(executor);
 
 		this.recentPendingCheckpoints = new ArrayDeque<>(NUM_GHOST_CHECKPOINT_IDS);
 		this.masterHooks = new HashMap<>();
@@ -1043,10 +1049,23 @@ public class CheckpointCoordinator {
 				throw new IllegalStateException("CheckpointCoordinator is shut down");
 			}
 
-			// Recover the checkpoints
-			completedCheckpointStore.recover(sharedStateRegistry);
+			// We create a new shared state registry object, so that all pending async disposal requests from previous
+			// runs will go against the old object (where they can do no harm).
+			// This must happen under the checkpoint lock.
+			sharedStateRegistry.close();
+			sharedStateRegistry = sharedStateRegistryFactory.create(executor);
+
+			// Recover the checkpoints, TODO this could be done only when there is a new leader, not on each recovery
+			completedCheckpointStore.recover();
+
+			// Now, we re-register all (shared) states from the checkpoint store with the new registry
+			for (CompletedCheckpoint completedCheckpoint : completedCheckpointStore.getAllCheckpoints()) {
+				completedCheckpoint.registerSharedStatesAfterRestored(sharedStateRegistry);
+			}
+
+			LOG.debug("Status of the shared state registry after restore: {}.", sharedStateRegistry);
 
-			// restore from the latest checkpoint
+			// Restore from the latest checkpoint
 			CompletedCheckpoint latest = completedCheckpointStore.getLatestCheckpoint();
 
 			if (latest == null) {
@@ -1120,7 +1139,6 @@ public class CheckpointCoordinator {
 		CompletedCheckpoint savepoint = SavepointLoader.loadAndValidateSavepoint(
 				job, tasks, savepointPath, userClassLoader, allowNonRestored);
 
-		savepoint.registerSharedStatesAfterRestored(sharedStateRegistry);
 		completedCheckpointStore.addCheckpoint(savepoint);
 		
 		// Reset the checkpoint ID counter

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
index 7c3edee..d3f61e4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
@@ -209,6 +209,8 @@ public class CompletedCheckpoint implements Serializable {
 
 	private void doDiscard() throws Exception {
 
+		LOG.trace("Executing discard procedure for {}.", this);
+
 		try {
 			// collect exceptions and continue cleanup
 			Exception exception = null;

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
index 45d407e..82193b5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
@@ -19,7 +19,6 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 
 import java.util.List;
 
@@ -33,10 +32,8 @@ public interface CompletedCheckpointStore {
 	 *
 	 * <p>After a call to this method, {@link #getLatestCheckpoint()} returns the latest
 	 * available checkpoint.
-	 *
-	 * @param sharedStateRegistry the shared state registry to register recovered states.
 	 */
-	void recover(SharedStateRegistry sharedStateRegistry) throws Exception;
+	void recover() throws Exception;
 
 	/**
 	 * Adds a {@link CompletedCheckpoint} instance to the list of completed checkpoints.

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
index fbb0198..63e7468 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
@@ -20,7 +20,7 @@ package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobmanager.HighAvailabilityMode;
-import org.apache.flink.runtime.state.SharedStateRegistry;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -57,7 +57,7 @@ public class StandaloneCompletedCheckpointStore implements CompletedCheckpointSt
 	}
 
 	@Override
-	public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void recover() throws Exception {
 		// Nothing to do
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
index c4cb6bc..88dd0d4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
@@ -18,20 +18,21 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import org.apache.curator.framework.CuratorFramework;
-import org.apache.curator.utils.ZKPaths;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobmanager.HighAvailabilityMode;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore;
 import org.apache.flink.util.FlinkException;
+
+import org.apache.curator.framework.CuratorFramework;
+import org.apache.curator.utils.ZKPaths;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
@@ -138,14 +139,13 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 	 * that the history of checkpoints is consistent.
 	 */
 	@Override
-	public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void recover() throws Exception {
 		LOG.info("Recovering checkpoints from ZooKeeper.");
 
 		// Clear local handles in order to prevent duplicates on
 		// recovery. The local handles should reflect the state
 		// of ZooKeeper.
 		completedCheckpoints.clear();
-		sharedStateRegistry.clear();
 
 		// Get all there is first
 		List<Tuple2<RetrievableStateHandle<CompletedCheckpoint>, String>> initialCheckpoints;
@@ -170,8 +170,6 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 			try {
 				completedCheckpoint = retrieveCompletedCheckpoint(checkpointStateHandle);
 				if (completedCheckpoint != null) {
-					// Re-register all shared states in the checkpoint.
-					completedCheckpoint.registerSharedStatesAfterRestored(sharedStateRegistry);
 					completedCheckpoints.add(completedCheckpoint);
 				}
 			} catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
index 139f484..2e5f3d1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
@@ -61,6 +61,7 @@ import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings;
 import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup;
 import org.apache.flink.runtime.jobmanager.scheduler.NoResourceAvailableException;
 import org.apache.flink.runtime.query.KvStateLocationRegistry;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StateBackend;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.util.SerializedThrowable;
@@ -69,8 +70,8 @@ import org.apache.flink.util.FlinkException;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.SerializedValue;
-
 import org.apache.flink.util.StringUtils;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -456,7 +457,8 @@ public class ExecutionGraph implements AccessExecutionGraph, Archiveable<Archive
 			checkpointIDCounter,
 			checkpointStore,
 			checkpointDir,
-			ioExecutor);
+			ioExecutor,
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		// register the master hooks on the checkpoint coordinator
 		for (MasterTriggerRestoreHook<?> hook : masterHooks) {

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java
index 5ee7a9f..e6d49d2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java
@@ -26,6 +26,7 @@ import org.apache.flink.api.common.accumulators.Accumulator;
 import org.apache.flink.api.common.accumulators.AccumulatorHelper;
 import org.apache.flink.api.common.time.Time;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.JobManagerOptions;
 import org.apache.flink.core.io.InputSplit;
 import org.apache.flink.core.io.InputSplitAssigner;
 import org.apache.flink.core.io.InputSplitSource;
@@ -39,7 +40,6 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.JobEdge;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.configuration.JobManagerOptions;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup;
 import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java
index 25df19b..d6019db 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobmaster/JobMaster.java
@@ -89,9 +89,9 @@ import org.apache.flink.runtime.taskexecutor.TaskExecutorGateway;
 import org.apache.flink.runtime.taskexecutor.slot.SlotOffer;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
-import org.apache.flink.util.SerializedThrowable;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedThrowable;
 
 import org.slf4j.Logger;
 

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java
index 0085890..0268b10 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java
@@ -65,27 +65,27 @@ public class IncrementalKeyedStateHandle implements KeyedStateHandle {
 	private final UUID backendIdentifier;
 
 	/**
-	 * The key-group range covered by this state handle
+	 * The key-group range covered by this state handle.
 	 */
 	private final KeyGroupRange keyGroupRange;
 
 	/**
-	 * The checkpoint Id
+	 * The checkpoint Id.
 	 */
 	private final long checkpointId;
 
 	/**
-	 * Shared state in the incremental checkpoint. This i
+	 * Shared state in the incremental checkpoint.
 	 */
 	private final Map<StateHandleID, StreamStateHandle> sharedState;
 
 	/**
-	 * Private state in the incremental checkpoint
+	 * Private state in the incremental checkpoint.
 	 */
 	private final Map<StateHandleID, StreamStateHandle> privateState;
 
 	/**
-	 * Primary meta data state of the incremental checkpoint
+	 * Primary meta data state of the incremental checkpoint.
 	 */
 	private final StreamStateHandle metaStateHandle;
 
@@ -143,16 +143,21 @@ public class IncrementalKeyedStateHandle implements KeyedStateHandle {
 
 	@Override
 	public KeyedStateHandle getIntersection(KeyGroupRange keyGroupRange) {
-		if (this.keyGroupRange.getIntersection(keyGroupRange) != KeyGroupRange.EMPTY_KEY_GROUP_RANGE) {
-			return this;
-		} else {
-			return null;
-		}
+		return KeyGroupRange.EMPTY_KEY_GROUP_RANGE.equals(this.keyGroupRange.getIntersection(keyGroupRange)) ?
+			null : this;
 	}
 
 	@Override
 	public void discardState() throws Exception {
 
+		SharedStateRegistry registry = this.sharedStateRegistry;
+		final boolean isRegistered = (registry != null);
+
+		LOG.trace("Discarding IncrementalKeyedStateHandle (registered = {}) for checkpoint {} from backend with id {}.",
+			isRegistered,
+			checkpointId,
+			backendIdentifier);
+
 		try {
 			metaStateHandle.discardState();
 		} catch (Exception e) {
@@ -168,19 +173,20 @@ public class IncrementalKeyedStateHandle implements KeyedStateHandle {
 		// If this was not registered, we can delete the shared state. We can simply apply this
 		// to all handles, because all handles that have not been created for the first time for this
 		// are only placeholders at this point (disposing them is a NOP).
-		if (sharedStateRegistry == null) {
-			try {
-				StateUtil.bestEffortDiscardAllStateObjects(sharedState.values());
-			} catch (Exception e) {
-				LOG.warn("Could not properly discard new sst file states.", e);
-			}
-		} else {
+		if (isRegistered) {
 			// If this was registered, we only unregister all our referenced shared states
 			// from the registry.
 			for (StateHandleID stateHandleID : sharedState.keySet()) {
-				sharedStateRegistry.unregisterReference(
+				registry.unregisterReference(
 					createSharedStateRegistryKeyFromFileName(stateHandleID));
 			}
+		} else {
+			// Otherwise, we assume that we own those handles and dispose of them directly.
+			try {
+				StateUtil.bestEffortDiscardAllStateObjects(sharedState.values());
+			} catch (Exception e) {
+				LOG.warn("Could not properly discard new sst file states.", e);
+			}
 		}
 	}
 
@@ -202,10 +208,21 @@ public class IncrementalKeyedStateHandle implements KeyedStateHandle {
 	@Override
 	public void registerSharedStates(SharedStateRegistry stateRegistry) {
 
-		Preconditions.checkState(sharedStateRegistry == null, "The state handle has already registered its shared states.");
+		// This is a quick check that prevents registering twice with the same registry. However, the code does allow
+		// registering again with a different registry. The implication is that ownership is transferred to this new
+		// registry. This should only happen in case of a restart, when the CheckpointCoordinator creates a new
+		// SharedStateRegistry for the current attempt and the old registry becomes meaningless. We also assume that
+		// an old registry object from a previous run is due to be GCed and will never be used for registration again.
+		Preconditions.checkState(
+			sharedStateRegistry != stateRegistry,
+			"The state handle has already registered its shared states to the given registry.");
 
 		sharedStateRegistry = Preconditions.checkNotNull(stateRegistry);
 
+		LOG.trace("Registering IncrementalKeyedStateHandle for checkpoint {} from backend with id {}.",
+			checkpointId,
+			backendIdentifier);
+
 		for (Map.Entry<StateHandleID, StreamStateHandle> sharedStateHandle : sharedState.entrySet()) {
 			SharedStateRegistryKey registryKey =
 				createSharedStateRegistryKeyFromFileName(sharedStateHandle.getKey());
@@ -284,5 +301,18 @@ public class IncrementalKeyedStateHandle implements KeyedStateHandle {
 		result = 31 * result + getMetaStateHandle().hashCode();
 		return result;
 	}
+
+	@Override
+	public String toString() {
+		return "IncrementalKeyedStateHandle{" +
+			"backendIdentifier=" + backendIdentifier +
+			", keyGroupRange=" + keyGroupRange +
+			", checkpointId=" + checkpointId +
+			", sharedState=" + sharedState +
+			", privateState=" + privateState +
+			", metaStateHandle=" + metaStateHandle +
+			", registered=" + (sharedStateRegistry != null) +
+			'}';
+	}
 }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
index 8e38ad4..8092f6c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
@@ -141,7 +141,7 @@ public class KeyGroupsStateHandle implements StreamStateHandle, KeyedStateHandle
 	public String toString() {
 		return "KeyGroupsStateHandle{" +
 				"groupRangeOffsets=" + groupRangeOffsets +
-				", data=" + stateHandle +
+				", stateHandle=" + stateHandle +
 				'}';
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java
index b95dace..1960c1c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/MultiStreamStateHandle.java
@@ -38,7 +38,7 @@ public class MultiStreamStateHandle implements StreamStateHandle {
 	private final List<StreamStateHandle> stateHandles;
 	private final long stateSize;
 
-	public MultiStreamStateHandle(List<StreamStateHandle> stateHandles) throws IOException {
+	public MultiStreamStateHandle(List<StreamStateHandle> stateHandles) {
 		this.stateHandles = Preconditions.checkNotNull(stateHandles);
 		long calculateSize = 0L;
 		for(StreamStateHandle stateHandle : stateHandles) {
@@ -62,6 +62,14 @@ public class MultiStreamStateHandle implements StreamStateHandle {
 		return stateSize;
 	}
 
+	@Override
+	public String toString() {
+		return "MultiStreamStateHandle{" +
+			"stateHandles=" + stateHandles +
+			", stateSize=" + stateSize +
+			'}';
+	}
+
 	static final class MultiFSDataInputStream extends AbstractMultiFSDataInputStream {
 
 		private final TreeMap<Long, StreamStateHandle> stateHandleMap;

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
index e0ca873..347f30c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
@@ -38,13 +38,24 @@ import java.util.concurrent.Executor;
  * maintain the reference count of {@link StreamStateHandle}s by a key that (logically) identifies
  * them.
  */
-public class SharedStateRegistry {
+public class SharedStateRegistry implements AutoCloseable {
 
 	private static final Logger LOG = LoggerFactory.getLogger(SharedStateRegistry.class);
 
+	/** A singleton object for the default implementation of a {@link SharedStateRegistryFactory}. */
+	public static final SharedStateRegistryFactory DEFAULT_FACTORY = new SharedStateRegistryFactory() {
+		@Override
+		public SharedStateRegistry create(Executor deleteExecutor) {
+			return new SharedStateRegistry(deleteExecutor);
+		}
+	};
+
 	/** All registered state objects by an artificial key */
 	private final Map<SharedStateRegistryKey, SharedStateRegistry.SharedStateEntry> registeredStates;
 
+	/** This flag indicates whether the registry is still open or {@link #close()} has already been called. */
+	private boolean open;
+
 	/** Executor for async state deletion */
 	private final Executor asyncDisposalExecutor;
 
@@ -56,6 +67,7 @@ public class SharedStateRegistry {
 	public SharedStateRegistry(Executor asyncDisposalExecutor) {
 		this.registeredStates = new HashMap<>();
 		this.asyncDisposalExecutor = Preconditions.checkNotNull(asyncDisposalExecutor);
+		this.open = true;
 	}
 
 	/**
@@ -82,6 +94,9 @@ public class SharedStateRegistry {
 		SharedStateRegistry.SharedStateEntry entry;
 
 		synchronized (registeredStates) {
+
+			Preconditions.checkState(open, "Attempt to register state to closed SharedStateRegistry.");
+
 			entry = registeredStates.get(registrationKey);
 
 			if (entry == null) {
@@ -96,6 +111,11 @@ public class SharedStateRegistry {
 				// delete if this is a real duplicate
 				if (!Objects.equals(state, entry.stateHandle)) {
 					scheduledStateDeletion = state;
+					LOG.trace("Identified duplicate state registration under key {}. New state {} was determined to " +
+							"be an unnecessary copy of existing state {} and will be dropped.",
+						registrationKey,
+						state,
+						entry.stateHandle);
 				}
 				entry.increaseReferenceCount();
 			}
@@ -112,7 +132,8 @@ public class SharedStateRegistry {
 	 *
 	 * @param registrationKey the shared state for which we release a reference.
 	 * @return the result of the request, consisting of the reference count after this operation
-	 * and the state handle, or null if the state handle was deleted through this request.
+	 * and the state handle, or null if the state handle was deleted through this request. Returns null if the registry
+	 * was previously closed.
 	 */
 	public Result unregisterReference(SharedStateRegistryKey registrationKey) {
 
@@ -123,6 +144,7 @@ public class SharedStateRegistry {
 		SharedStateRegistry.SharedStateEntry entry;
 
 		synchronized (registeredStates) {
+
 			entry = registeredStates.get(registrationKey);
 
 			Preconditions.checkState(entry != null,
@@ -164,10 +186,18 @@ public class SharedStateRegistry {
 		}
 	}
 
+	@Override
+	public String toString() {
+		synchronized (registeredStates) {
+			return "SharedStateRegistry{" +
+				"registeredStates=" + registeredStates +
+				'}';
+		}
+	}
+
 	private void scheduleAsyncDelete(StreamStateHandle streamStateHandle) {
 		// We do the small optimization to not issue discards for placeholders, which are NOPs.
 		if (streamStateHandle != null && !isPlaceholder(streamStateHandle)) {
-
 			LOG.trace("Scheduled delete of state handle {}.", streamStateHandle);
 			asyncDisposalExecutor.execute(
 				new SharedStateRegistry.AsyncDisposalRunnable(streamStateHandle));
@@ -178,6 +208,13 @@ public class SharedStateRegistry {
 		return stateHandle instanceof PlaceholderStreamStateHandle;
 	}
 
+	@Override
+	public void close() {
+		synchronized (registeredStates) {
+			open = false;
+		}
+	}
+
 	/**
 	 * An entry in the registry, tracking the handle and the corresponding reference count.
 	 */
@@ -279,13 +316,4 @@ public class SharedStateRegistry {
 			}
 		}
 	}
-
-	/**
-	 * Clears the registry.
-	 */
-	public void clear() {
-		synchronized (registeredStates) {
-			registeredStates.clear();
-		}
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistryFactory.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistryFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistryFactory.java
new file mode 100644
index 0000000..05c9825
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistryFactory.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import java.util.concurrent.Executor;
+
+/**
+ * Simple factory to produce {@link SharedStateRegistry} objects.
+ */
+public interface SharedStateRegistryFactory {
+
+	/**
+	 * Factory method for {@link SharedStateRegistry}.
+	 *
+	 * @param deleteExecutor executor used to run (async) deletes.
+	 * @return a SharedStateRegistry object
+	 */
+	SharedStateRegistry create(Executor deleteExecutor);
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
index 9ba9d35..3a43d4f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
@@ -95,6 +95,7 @@ public class ByteStreamStateHandle implements StreamStateHandle {
 	public String toString() {
 		return "ByteStreamStateHandle{" +
 			"handleName='" + handleName + '\'' +
+			", dataBytes=" + data.length +
 			'}';
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java
index d293eea..edc29fe 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java
@@ -18,14 +18,6 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.checkpoint.savepoint.SavepointLoader;
 import org.apache.flink.runtime.concurrent.Executors;
@@ -37,11 +29,22 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
+
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
+import java.io.File;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
 /**
  * CheckpointCoordinator tests for externalized checkpoints.
  *
@@ -91,7 +94,8 @@ public class CheckpointCoordinatorExternalizedCheckpointsTest {
 			new StandaloneCheckpointIDCounter(),
 			new StandaloneCompletedCheckpointStore(1),
 			checkpointDir.getAbsolutePath(),
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		assertEquals(0, coord.getNumberOfPendingCheckpoints());
 		assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index 88b95f5..26db772 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -79,7 +79,8 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			new FailingCompletedCheckpointStore(),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		coord.triggerCheckpoint(triggerTimestamp, false);
 
@@ -111,7 +112,7 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		when(subtaskState.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(vertex.getJobvertexId()))).thenReturn(operatorSubtaskState);
 
 		AcknowledgeCheckpoint acknowledgeMessage = new AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId, new CheckpointMetrics(), subtaskState);
-		
+
 		try {
 			coord.receiveAcknowledgeMessage(acknowledgeMessage);
 			fail("Expected a checkpoint exception because the completed checkpoint store could not " +
@@ -135,7 +136,7 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 	private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore {
 
 		@Override
-		public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
+		public void recover() throws Exception {
 			throw new UnsupportedOperationException("Not implemented.");
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
index e23f6a2..2f860e0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
@@ -28,9 +28,9 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 
 import org.junit.Test;
-
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
@@ -46,14 +46,12 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executor;
 
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest.mockExecutionVertex;
-
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-
 import static org.mockito.Matchers.eq;
 import static org.mockito.Matchers.isNull;
 import static org.mockito.Mockito.any;
@@ -404,7 +402,8 @@ public class CheckpointCoordinatorMasterHooksTest {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(10),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 	}
 
 	private static <T> T mockGeneric(Class<?> clazz) {

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index d9af879..45cbbc3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -36,32 +36,36 @@ import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.IncrementalKeyedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.PlaceholderStreamStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.runtime.state.SharedStateRegistryFactory;
+import org.apache.flink.runtime.state.StateHandleID;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
 import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
+import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
+import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.TestLogger;
 
-import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
-import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
-
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
+import org.mockito.verification.VerificationMode;
 
 import java.io.IOException;
 import java.io.Serializable;
@@ -139,7 +143,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(1),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			// nothing should be happening
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
@@ -199,7 +204,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(1),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			// nothing should be happening
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
@@ -250,7 +256,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(1),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			// nothing should be happening
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
@@ -302,7 +309,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(1),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -406,7 +414,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(1),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -525,7 +534,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(1),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -692,7 +702,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(2),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -822,7 +833,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(10),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -986,7 +998,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(2),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			// trigger a checkpoint, partially acknowledged
 			assertTrue(coord.triggerCheckpoint(timestamp, false));
@@ -1063,7 +1076,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(2),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			assertTrue(coord.triggerCheckpoint(timestamp, false));
 
@@ -1126,7 +1140,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			new StandaloneCompletedCheckpointStore(1),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		assertTrue(coord.triggerCheckpoint(timestamp, false));
 
@@ -1258,7 +1273,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(2),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 
 			coord.startCheckpointScheduler();
@@ -1350,7 +1366,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(2),
 				"dummy-path",
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 		try {
 			coord.startCheckpointScheduler();
@@ -1423,7 +1440,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			new StandaloneCompletedCheckpointStore(1),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		assertEquals(0, coord.getNumberOfPendingCheckpoints());
 		assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -1574,7 +1592,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			counter,
 			new StandaloneCompletedCheckpointStore(10),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		String savepointDir = tmpFolder.newFolder().getAbsolutePath();
 
@@ -1680,7 +1699,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(2),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			coord.startCheckpointScheduler();
 
@@ -1753,7 +1773,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(2),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			coord.startCheckpointScheduler();
 
@@ -1835,7 +1856,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(2),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			coord.startCheckpointScheduler();
 
@@ -1887,7 +1909,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			checkpointIDCounter,
 			new StandaloneCompletedCheckpointStore(2),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		List<CompletableFuture<CompletedCheckpoint>> savepointFutures = new ArrayList<>();
 
@@ -1940,7 +1963,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			new StandaloneCompletedCheckpointStore(2),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		String savepointDir = tmpFolder.newFolder().getAbsolutePath();
 
@@ -2002,7 +2026,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			store,
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		// trigger the checkpoint
 		coord.triggerCheckpoint(timestamp, false);
@@ -2116,7 +2141,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			new StandaloneCompletedCheckpointStore(1),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		// trigger the checkpoint
 		coord.triggerCheckpoint(timestamp, false);
@@ -2237,7 +2263,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			new StandaloneCompletedCheckpointStore(1),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		// trigger the checkpoint
 		coord.triggerCheckpoint(timestamp, false);
@@ -2395,7 +2422,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			new StandaloneCompletedCheckpointStore(1),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		// trigger the checkpoint
 		coord.triggerCheckpoint(timestamp, false);
@@ -2686,7 +2714,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			standaloneCompletedCheckpointStore,
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		coord.restoreLatestCheckpointedState(tasks, false, true);
 
@@ -2847,7 +2876,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(1),
 				"fake-directory",
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			assertTrue(coord.triggerCheckpoint(timestamp, false));
 
@@ -3351,7 +3381,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			new StandaloneCompletedCheckpointStore(1),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		// Periodic
 		CheckpointTriggerResult triggerResult = coord.triggerCheckpoint(
@@ -3529,7 +3560,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			new StandaloneCompletedCheckpointStore(1),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		CheckpointStatsTracker tracker = mock(CheckpointStatsTracker.class);
 		coord.setCheckpointStatsTracker(tracker);
@@ -3567,7 +3599,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCheckpointIDCounter(),
 			store,
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		store.addCheckpoint(new CompletedCheckpoint(
 			new JobID(),
@@ -3623,7 +3656,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			checkpointIDCounter,
 			completedCheckpointStore,
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		// trigger a first checkpoint
 		assertTrue(
@@ -3673,4 +3707,245 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			"The latest completed (proper) checkpoint should have been added to the completed checkpoint store.",
 			completedCheckpointStore.getLatestCheckpoint().getCheckpointID() == checkpointIDCounter.getLast());
 	}
+
+	@Test
+	public void testSharedStateRegistrationOnRestore() throws Exception {
+
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+
+		int parallelism1 = 2;
+		int maxParallelism1 = 4;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices =
+			allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
+
+		RecoverableCompletedCheckpointStore store = new RecoverableCompletedCheckpointStore(10);
+
+		final List<SharedStateRegistry> createdSharedStateRegistries = new ArrayList<>(2);
+
+		// set up the coordinator and validate the initial state
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			ExternalizedCheckpointSettings.none(),
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			new StandaloneCheckpointIDCounter(),
+			store,
+			null,
+			Executors.directExecutor(),
+			new SharedStateRegistryFactory() {
+				@Override
+				public SharedStateRegistry create(Executor deleteExecutor) {
+					SharedStateRegistry instance = new SharedStateRegistry(deleteExecutor);
+					createdSharedStateRegistries.add(instance);
+					return instance;
+				}
+			});
+
+		final int numCheckpoints = 3;
+
+		List<KeyGroupRange> keyGroupPartitions1 =
+			StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+
+		for (int i = 0; i < numCheckpoints; ++i) {
+			performIncrementalCheckpoint(jid, coord, jobVertex1, keyGroupPartitions1, timestamp + i, i);
+		}
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+		assertEquals(numCheckpoints, completedCheckpoints.size());
+
+		int sharedHandleCount = 0;
+
+		List<Map<StateHandleID, StreamStateHandle>> sharedHandlesByCheckpoint = new ArrayList<>(numCheckpoints);
+
+		for (int i = 0; i < numCheckpoints; ++i) {
+			sharedHandlesByCheckpoint.add(new HashMap<StateHandleID, StreamStateHandle>(2));
+		}
+
+		int cp = 0;
+		for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) {
+			for (OperatorState taskState : completedCheckpoint.getOperatorStates().values()) {
+				for (OperatorSubtaskState subtaskState : taskState.getStates()) {
+					for (KeyedStateHandle keyedStateHandle : subtaskState.getManagedKeyedState()) {
+						// verify that the handle was registered exactly once with the current registry
+						verify(keyedStateHandle, times(1)).registerSharedStates(createdSharedStateRegistries.get(0));
+						IncrementalKeyedStateHandle incrementalKeyedStateHandle = (IncrementalKeyedStateHandle) keyedStateHandle;
+
+						sharedHandlesByCheckpoint.get(cp).putAll(incrementalKeyedStateHandle.getSharedState());
+
+						for (StreamStateHandle streamStateHandle : incrementalKeyedStateHandle.getSharedState().values()) {
+							assertTrue(!(streamStateHandle instanceof PlaceholderStreamStateHandle));
+							verify(streamStateHandle, never()).discardState();
+							++sharedHandleCount;
+						}
+
+						for (StreamStateHandle streamStateHandle : incrementalKeyedStateHandle.getPrivateState().values()) {
+							verify(streamStateHandle, never()).discardState();
+						}
+
+						verify(incrementalKeyedStateHandle.getMetaStateHandle(), never()).discardState();
+					}
+
+					verify(subtaskState, never()).discardState();
+				}
+			}
+			++cp;
+		}
+
+		// 2 (parallelism) x (1 (CP0) + 2 (CP1) + 2 (CP2)) = 10
+		assertEquals(10, sharedHandleCount);
+
+		// discard CP0
+		store.removeOldestCheckpoint();
+
+		// we expect no shared state was discarded because the state of CP0 is still referenced by CP1
+		for (Map<StateHandleID, StreamStateHandle> cpList : sharedHandlesByCheckpoint) {
+			for (StreamStateHandle streamStateHandle : cpList.values()) {
+				verify(streamStateHandle, never()).discardState();
+			}
+		}
+
+		// shutdown the store
+		store.shutdown(JobStatus.SUSPENDED);
+
+		// restore the store
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+		tasks.put(jobVertexID1, jobVertex1);
+		coord.restoreLatestCheckpointedState(tasks, true, false);
+
+		// validate that all shared states are registered again after the recovery.
+		cp = 0;
+		for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) {
+			for (OperatorState taskState : completedCheckpoint.getOperatorStates().values()) {
+				for (OperatorSubtaskState subtaskState : taskState.getStates()) {
+					for (KeyedStateHandle keyedStateHandle : subtaskState.getManagedKeyedState()) {
+						VerificationMode verificationMode;
+						// CP0 was discarded before the restore, so only CP1 and CP2 register once with the new registry
+						if (cp > 0) {
+							verificationMode = times(1);
+						} else {
+							verificationMode = never();
+						}
+
+						// check the expected number of registrations with the new registry
+						verify(keyedStateHandle, verificationMode).registerSharedStates(createdSharedStateRegistries.get(1));
+					}
+				}
+			}
+			++cp;
+		}
+
+		// discard CP1
+		store.removeOldestCheckpoint();
+
+		// we expect that all shared state from CP0 is no longer referenced and discarded. CP2 is still live and also
+		// references the state from CP1, so we expect they are not discarded.
+		for (Map<StateHandleID, StreamStateHandle> cpList : sharedHandlesByCheckpoint) {
+			for (Map.Entry<StateHandleID, StreamStateHandle> entry : cpList.entrySet()) {
+				String key = entry.getKey().getKeyString();
+				int belongToCP = Integer.parseInt(String.valueOf(key.charAt(key.length() - 1)));
+				if (belongToCP == 0) {
+					verify(entry.getValue(), times(1)).discardState();
+				} else {
+					verify(entry.getValue(), never()).discardState();
+				}
+			}
+		}
+
+		// discard CP2
+		store.removeOldestCheckpoint();
+
+		// we expect all shared state was discarded now, because all CPs are discarded
+		for (Map<StateHandleID, StreamStateHandle> cpList : sharedHandlesByCheckpoint) {
+			for (StreamStateHandle streamStateHandle : cpList.values()) {
+				verify(streamStateHandle, times(1)).discardState();
+			}
+		}
+	}
+
+	private void performIncrementalCheckpoint(
+		JobID jid,
+		CheckpointCoordinator coord,
+		ExecutionJobVertex jobVertex1,
+		List<KeyGroupRange> keyGroupPartitions1,
+		long timestamp,
+		int cpSequenceNumber) throws Exception {
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp, false);
+
+		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
+
+			KeyGroupRange keyGroupRange = keyGroupPartitions1.get(index);
+
+			Map<StateHandleID, StreamStateHandle> privateState = new HashMap<>();
+			privateState.put(
+				new StateHandleID("private-1"),
+				spy(new ByteStreamStateHandle("private-1", new byte[]{'p'})));
+
+			Map<StateHandleID, StreamStateHandle> sharedState = new HashMap<>();
+
+			// let all but the first CP overlap by one shared state.
+			if (cpSequenceNumber > 0) {
+				sharedState.put(
+					new StateHandleID("shared-" + (cpSequenceNumber - 1)),
+					spy(new PlaceholderStreamStateHandle()));
+			}
+
+			sharedState.put(
+				new StateHandleID("shared-" + cpSequenceNumber),
+				spy(new ByteStreamStateHandle("shared-" + cpSequenceNumber + "-" + keyGroupRange, new byte[]{'s'})));
+
+			IncrementalKeyedStateHandle managedState =
+				spy(new IncrementalKeyedStateHandle(
+					new UUID(42L, 42L),
+					keyGroupRange,
+					checkpointId,
+					sharedState,
+					privateState,
+					spy(new ByteStreamStateHandle("meta", new byte[]{'m'}))));
+
+			OperatorSubtaskState operatorSubtaskState =
+				spy(new OperatorSubtaskState(null,
+					Collections.<OperatorStateHandle>emptyList(),
+					Collections.<OperatorStateHandle>emptyList(),
+					Collections.<KeyedStateHandle>singletonList(managedState),
+					Collections.<KeyedStateHandle>emptyList()));
+
+			Map<OperatorID, OperatorSubtaskState> opStates = new HashMap<>();
+
+			opStates.put(jobVertex1.getOperatorIDs().get(0), operatorSubtaskState);
+
+			TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot(opStates);
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				new CheckpointMetrics(),
+				taskStateSnapshot);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 6ce071b..791bffa 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.util.SerializableObject;
 
@@ -109,7 +110,8 @@ public class CheckpointStateRestoreTest {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(1),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			// create ourselves a checkpoint with state
 			final long timestamp = 34623786L;
@@ -186,7 +188,8 @@ public class CheckpointStateRestoreTest {
 				new StandaloneCheckpointIDCounter(),
 				new StandaloneCompletedCheckpointStore(1),
 				null,
-				Executors.directExecutor());
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY);
 
 			try {
 				coord.restoreLatestCheckpointedState(new HashMap<JobVertexID, ExecutionJobVertex>(), true, false);
@@ -243,7 +246,8 @@ public class CheckpointStateRestoreTest {
 			new StandaloneCheckpointIDCounter(),
 			new StandaloneCompletedCheckpointStore(1),
 			null,
-			Executors.directExecutor());
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY);
 
 		StreamStateHandle serializedState = CheckpointCoordinatorTest
 				.generateChainedStateHandle(new SerializableObject())

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
index 77423c2..dc2b11e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
@@ -18,13 +18,14 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import org.apache.curator.framework.CuratorFramework;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment;
+
+import org.apache.curator.framework.CuratorFramework;
 import org.apache.zookeeper.data.Stat;
 import org.junit.AfterClass;
 import org.junit.Before;
@@ -106,8 +107,9 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints());
 
 		// Recover
-		sharedStateRegistry.clear();
-		checkpoints.recover(sharedStateRegistry);
+		sharedStateRegistry.close();
+		sharedStateRegistry = new SharedStateRegistry();
+		checkpoints.recover();
 
 		assertEquals(3, ZOOKEEPER.getClient().getChildren().forPath(CHECKPOINT_PATH).size());
 		assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints());
@@ -148,8 +150,8 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 		assertNull(client.checkExists().forPath(CHECKPOINT_PATH + ZooKeeperCompletedCheckpointStore.checkpointIdToPath(checkpoint.getCheckpointID())));
 
-		sharedStateRegistry.clear();
-		store.recover(sharedStateRegistry);
+		sharedStateRegistry.close();
+		store.recover();
 
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 	}
@@ -182,8 +184,8 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		assertEquals("The checkpoint node should not be locked.", 0, stat.getNumChildren());
 
 		// Recover again
-		sharedStateRegistry.clear();
-		store.recover(sharedStateRegistry);
+		sharedStateRegistry.close();
+		store.recover();
 
 		CompletedCheckpoint recovered = store.getLatestCheckpoint();
 		assertEquals(checkpoint, recovered);
@@ -209,8 +211,8 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 			checkpointStore.addCheckpoint(checkpoint);
 		}
 
-		sharedStateRegistry.clear();
-		checkpointStore.recover(sharedStateRegistry);
+		sharedStateRegistry.close();
+		checkpointStore.recover();
 
 		CompletedCheckpoint latestCheckpoint = checkpointStore.getLatestCheckpoint();
 
@@ -239,8 +241,9 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		zkCheckpointStore1.addCheckpoint(completedCheckpoint);
 
 		// recover the checkpoint by a different checkpoint store
-		sharedStateRegistry.clear();
-		zkCheckpointStore2.recover(sharedStateRegistry);
+		sharedStateRegistry.close();
+		sharedStateRegistry = new SharedStateRegistry();
+		zkCheckpointStore2.recover();
 
 		CompletedCheckpoint recoveredCheckpoint = zkCheckpointStore2.getLatestCheckpoint();
 		assertTrue(recoveredCheckpoint instanceof TestCompletedCheckpoint);

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
index 91bab85..3171f1f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
@@ -52,7 +52,6 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyCollection;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
@@ -162,11 +161,7 @@ public class ZooKeeperCompletedCheckpointStoreTest extends TestLogger {
 			stateStorage,
 			Executors.directExecutor());
 
-		SharedStateRegistry sharedStateRegistry = spy(new SharedStateRegistry());
-		zooKeeperCompletedCheckpointStore.recover(sharedStateRegistry);
-
-		verify(retrievableStateHandle1.retrieveState(), times(1)).registerSharedStatesAfterRestored(sharedStateRegistry);
-		verify(retrievableStateHandle2.retrieveState(), times(1)).registerSharedStatesAfterRestored(sharedStateRegistry);
+		zooKeeperCompletedCheckpointStore.recover();
 
 		CompletedCheckpoint latestCompletedCheckpoint = zooKeeperCompletedCheckpointStore.getLatestCheckpoint();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java
index c1b3ccd..9f6f88e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandleTest.java
@@ -19,12 +19,15 @@
 package org.apache.flink.runtime.state;
 
 import org.apache.flink.runtime.checkpoint.savepoint.CheckpointTestUtils;
+
 import org.junit.Test;
 
 import java.util.Map;
 import java.util.Random;
 import java.util.UUID;
 
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.powermock.api.mockito.PowerMockito.spy;
@@ -59,8 +62,6 @@ public class IncrementalKeyedStateHandleTest {
 	@Test
 	public void testSharedStateDeRegistration() throws Exception {
 
-		Random rnd = new Random(42);
-
 		SharedStateRegistry registry = spy(new SharedStateRegistry());
 
 		// Create two state handles with overlapping shared state
@@ -186,6 +187,76 @@ public class IncrementalKeyedStateHandleTest {
 		verify(stateHandle2.getMetaStateHandle(), times(1)).discardState();
 	}
 
+	/**
+	 * This tests that re-registration of shared state with another registry works as expected. This simulates a
+	 * recovery from a checkpoint, when the checkpoint coordinator creates a new shared state registry and re-registers
+	 * all live checkpoint states.
+	 */
+	@Test
+	public void testSharedStateReRegistration() throws Exception {
+
+		SharedStateRegistry stateRegistryA = spy(new SharedStateRegistry());
+
+		IncrementalKeyedStateHandle stateHandleX = create(new Random(1));
+		IncrementalKeyedStateHandle stateHandleY = create(new Random(2));
+		IncrementalKeyedStateHandle stateHandleZ = create(new Random(3));
+
+		// Now we register for the first time ...
+		stateHandleX.registerSharedStates(stateRegistryA);
+		stateHandleY.registerSharedStates(stateRegistryA);
+		stateHandleZ.registerSharedStates(stateRegistryA);
+
+		try {
+			// Second attempt should fail
+			stateHandleX.registerSharedStates(stateRegistryA);
+			fail("Should not be able to register twice with the same registry.");
+		} catch (IllegalStateException ignore) {
+		}
+
+		// Everything should be discarded for this handle
+		stateHandleZ.discardState();
+		verify(stateHandleZ.getMetaStateHandle(), times(1)).discardState();
+		for (StreamStateHandle stateHandle : stateHandleZ.getSharedState().values()) {
+			verify(stateHandle, times(1)).discardState();
+		}
+
+		// Close the first registry
+		stateRegistryA.close();
+
+		// Attempt to register to closed registry should trigger exception
+		try {
+			create(new Random(4)).registerSharedStates(stateRegistryA);
+			fail("Should not be able to register new state to closed registry.");
+		} catch (IllegalStateException ignore) {
+		}
+
+		// All state should still get discarded
+		stateHandleY.discardState();
+		verify(stateHandleY.getMetaStateHandle(), times(1)).discardState();
+		for (StreamStateHandle stateHandle : stateHandleY.getSharedState().values()) {
+			verify(stateHandle, times(1)).discardState();
+		}
+
+		// This should still be unaffected
+		verify(stateHandleX.getMetaStateHandle(), never()).discardState();
+		for (StreamStateHandle stateHandle : stateHandleX.getSharedState().values()) {
+			verify(stateHandle, never()).discardState();
+		}
+
+		// We re-register the handle with a new registry
+		SharedStateRegistry sharedStateRegistryB = spy(new SharedStateRegistry());
+		stateHandleX.registerSharedStates(sharedStateRegistryB);
+		stateHandleX.discardState();
+
+		// Should be completely discarded because it is tracked through the new registry
+		verify(stateHandleX.getMetaStateHandle(), times(1)).discardState();
+		for (StreamStateHandle stateHandle : stateHandleX.getSharedState().values()) {
+			verify(stateHandle, times(1)).discardState();
+		}
+
+		sharedStateRegistryB.close();
+	}
+
 	private static IncrementalKeyedStateHandle create(Random rnd) {
 		return new IncrementalKeyedStateHandle(
 			UUID.nameUUIDFromBytes("test".getBytes()),

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
index a0c4412..037ecd1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
@@ -21,7 +21,8 @@ package org.apache.flink.runtime.testutils;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
 import org.apache.flink.runtime.jobgraph.JobStatus;
-import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.util.Preconditions;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -41,14 +42,21 @@ public class RecoverableCompletedCheckpointStore implements CompletedCheckpointS
 
 	private final ArrayDeque<CompletedCheckpoint> suspended = new ArrayDeque<>(2);
 
+	private final int maxRetainedCheckpoints;
+
+	public RecoverableCompletedCheckpointStore() {
+		this(1);
+	}
+
+	public RecoverableCompletedCheckpointStore(int maxRetainedCheckpoints) {
+		Preconditions.checkArgument(maxRetainedCheckpoints > 0);
+		this.maxRetainedCheckpoints = maxRetainedCheckpoints;
+	}
+
 	@Override
-	public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void recover() throws Exception {
 		checkpoints.addAll(suspended);
 		suspended.clear();
-
-		for (CompletedCheckpoint checkpoint : checkpoints) {
-			checkpoint.registerSharedStatesAfterRestored(sharedStateRegistry);
-		}
 	}
 
 	@Override
@@ -56,13 +64,16 @@ public class RecoverableCompletedCheckpointStore implements CompletedCheckpointS
 
 		checkpoints.addLast(checkpoint);
 
-
-		if (checkpoints.size() > 1) {
-			CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst();
-			checkpointToSubsume.discardOnSubsume();
+		if (checkpoints.size() > maxRetainedCheckpoints) {
+			removeOldestCheckpoint();
 		}
 	}
 
+	public void removeOldestCheckpoint() throws Exception {
+		CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst();
+		checkpointToSubsume.discardOnSubsume();
+	}
+
 	@Override
 	public CompletedCheckpoint getLatestCheckpoint() throws Exception {
 		return checkpoints.isEmpty() ? null : checkpoints.getLast();
@@ -96,7 +107,7 @@ public class RecoverableCompletedCheckpointStore implements CompletedCheckpointS
 
 	@Override
 	public int getMaxNumberOfRetainedCheckpoints() {
-		return 1;
+		return maxRetainedCheckpoints;
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index cb8639b..1ba5fb1 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -18,6 +18,7 @@
 package org.apache.flink.streaming.runtime.tasks;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.api.common.accumulators.Accumulator;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.CloseableRegistry;
@@ -779,9 +780,11 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 	}
 
 	private String createOperatorIdentifier(StreamOperator<?> operator, int vertexId) {
+
+		TaskInfo taskInfo = getEnvironment().getTaskInfo();
 		return operator.getClass().getSimpleName() +
-				"_" + vertexId +
-				"_" + getEnvironment().getTaskInfo().getIndexOfThisSubtask();
+			"_" + operator.getOperatorID() +
+			"_(" + taskInfo.getIndexOfThisSubtask() + "/" + taskInfo.getNumberOfParallelSubtasks() + ")";
 	}
 
 	/**
@@ -892,18 +895,22 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 				if (asyncCheckpointState.compareAndSet(CheckpointingOperation.AsynCheckpointState.RUNNING,
 						CheckpointingOperation.AsynCheckpointState.COMPLETED)) {
 
+					TaskStateSnapshot acknowledgedState = hasState ? taskOperatorSubtaskStates : null;
+
 					// we signal stateless tasks by reporting null, so that there are no attempts to assign empty state
 					// to stateless tasks on restore. This enables simple job modifications that only concern
 					// stateless without the need to assign them uids to match their (always empty) states.
 					owner.getEnvironment().acknowledgeCheckpoint(
 						checkpointMetaData.getCheckpointId(),
 						checkpointMetrics,
-						hasState ? taskOperatorSubtaskStates : null);
+						acknowledgedState);
+
+					LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms",
+						owner.getName(), checkpointMetaData.getCheckpointId(), asyncDurationMillis);
+
+					LOG.trace("{} - reported the following states in snapshot for checkpoint {}: {}.",
+						owner.getName(), checkpointMetaData.getCheckpointId(), acknowledgedState);
 
-					if (LOG.isDebugEnabled()) {
-						LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms",
-							owner.getName(), checkpointMetaData.getCheckpointId(), asyncDurationMillis);
-					}
 				} else {
 					LOG.debug("{} - asynchronous part of checkpoint {} could not be completed because it was closed before.",
 						owner.getName(),

http://git-wip-us.apache.org/repos/asf/flink/blob/91a4b276/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java
index 22ed847..c525a37 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AbstractEventTimeWindowCheckpointingITCase.java
@@ -27,6 +27,7 @@ import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple4;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.HighAvailabilityOptions;
 import org.apache.flink.configuration.TaskManagerOptions;
 import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
 import org.apache.flink.core.fs.Path;
@@ -48,21 +49,22 @@ import org.apache.flink.test.util.SuccessException;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 
+import org.apache.curator.test.TestingServer;
 import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Before;
-import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 import org.junit.rules.TestName;
 
+import java.io.File;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.apache.flink.test.checkpointing.AbstractEventTimeWindowCheckpointingITCase.StateBackendEnum.ROCKSDB_INCREMENTAL_ZK;
 import static org.apache.flink.test.util.TestUtils.tryExecute;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -87,6 +89,8 @@ public abstract class AbstractEventTimeWindowCheckpointingITCase extends TestLog
 
 	private static TestStreamEnvironment env;
 
+	private static TestingServer zkServer;
+
 	@Rule
 	public TemporaryFolder tempFolder = new TemporaryFolder();
 
@@ -101,11 +105,27 @@ public abstract class AbstractEventTimeWindowCheckpointingITCase extends TestLog
 	}
 
 	enum StateBackendEnum {
-		MEM, FILE, ROCKSDB_FULLY_ASYNC, ROCKSDB_INCREMENTAL, MEM_ASYNC, FILE_ASYNC
+		MEM, FILE, ROCKSDB_FULLY_ASYNC, ROCKSDB_INCREMENTAL, ROCKSDB_INCREMENTAL_ZK, MEM_ASYNC, FILE_ASYNC
 	}
 
-	@BeforeClass
-	public static void startTestCluster() {
+	@Before
+	public void startTestCluster() throws Exception {
+
+		// print a message when starting a test method to avoid Travis' <tt>"Maven produced no
+		// output for xxx seconds."</tt> messages
+		System.out.println(
+			"Starting " + getClass().getCanonicalName() + "#" + name.getMethodName() + ".");
+
+		// Testing HA Scenario / ZKCompletedCheckpointStore with incremental checkpoints
+		if (ROCKSDB_INCREMENTAL_ZK.equals(stateBackendEnum)) {
+			zkServer = new TestingServer();
+			zkServer.start();
+		}
+
+		TemporaryFolder temporaryFolder = new TemporaryFolder();
+		temporaryFolder.create();
+		final File haDir = temporaryFolder.newFolder();
+
 		Configuration config = new Configuration();
 		config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, 2);
 		config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, PARALLELISM / 2);
@@ -113,28 +133,18 @@ public abstract class AbstractEventTimeWindowCheckpointingITCase extends TestLog
 		// the default network buffers size (10% of heap max =~ 150MB) seems to much for this test case
 		config.setLong(TaskManagerOptions.NETWORK_BUFFERS_MEMORY_MAX, 80L << 20); // 80 MB
 
+		if (zkServer != null) {
+			config.setString(HighAvailabilityOptions.HA_MODE, "ZOOKEEPER");
+			config.setString(HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, zkServer.getConnectString());
+			config.setString(HighAvailabilityOptions.HA_STORAGE_PATH, haDir.toURI().toString());
+		}
+
 		cluster = new LocalFlinkMiniCluster(config, false);
 		cluster.start();
 
 		env = new TestStreamEnvironment(cluster, PARALLELISM);
 		env.getConfig().setUseSnapshotCompression(true);
-	}
-
-	@AfterClass
-	public static void stopTestCluster() {
-		if (cluster != null) {
-			cluster.stop();
-		}
-	}
-
-	@Before
-	public void beforeTest() throws IOException {
-		// print a message when starting a test method to avoid Travis' <tt>"Maven produced no
-		// output for xxx seconds."</tt> messages
-		System.out.println(
-			"Starting " + getClass().getCanonicalName() + "#" + name.getMethodName() + ".");
 
-		// init state back-end
 		switch (stateBackendEnum) {
 			case MEM:
 				this.stateBackend = new MemoryStateBackend(MAX_MEM_STATE_SIZE, false);
@@ -159,7 +169,8 @@ public abstract class AbstractEventTimeWindowCheckpointingITCase extends TestLog
 				this.stateBackend = rdb;
 				break;
 			}
-			case ROCKSDB_INCREMENTAL: {
+			case ROCKSDB_INCREMENTAL:
+			case ROCKSDB_INCREMENTAL_ZK: {
 				String rocksDb = tempFolder.newFolder().getAbsolutePath();
 				String backups = tempFolder.newFolder().getAbsolutePath();
 				// we use the fs backend with small threshold here to test the behaviour with file
@@ -173,16 +184,25 @@ public abstract class AbstractEventTimeWindowCheckpointingITCase extends TestLog
 				this.stateBackend = rdb;
 				break;
 			}
-
+			default:
+				throw new IllegalStateException("No backend selected.");
 		}
 	}
 
-	/**
-	 * Prints a message when finishing a test method to avoid Travis' <tt>"Maven produced no output
-	 * for xxx seconds."</tt> messages.
-	 */
 	@After
-	public void afterTest() {
+	public void stopTestCluster() throws IOException {
+		if (cluster != null) {
+			cluster.stop();
+			cluster = null;
+		}
+
+		if (zkServer != null) {
+			zkServer.stop();
+			zkServer = null;
+		}
+
+		// Prints a message when finishing a test method to avoid Travis' <tt>"Maven produced no output
+		// for xxx seconds."</tt> messages.
 		System.out.println(
 			"Finished " + getClass().getCanonicalName() + "#" + name.getMethodName() + ".");
 	}


[5/7] flink git commit: [FLINK-7213] Introduce state management by OperatorID in TaskManager

Posted by sr...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java
index 40678de..1ebd4ad 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskStopTest.java
@@ -20,39 +20,41 @@ package org.apache.flink.runtime.taskmanager;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
-import org.apache.flink.runtime.executiongraph.JobInformation;
-import org.apache.flink.runtime.executiongraph.TaskInformation;
-import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
-import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
-import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
-import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
-import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.JobInformation;
+import org.apache.flink.runtime.executiongraph.TaskInformation;
 import org.apache.flink.runtime.filecache.FileCache;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.network.NetworkEnvironment;
+import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
+import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.jobgraph.tasks.StoppableTask;
 import org.apache.flink.runtime.memory.MemoryManager;
-import org.apache.flink.runtime.state.TaskStateHandles;
+import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
+import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
-import scala.concurrent.duration.FiniteDuration;
 
 import java.lang.reflect.Field;
 import java.util.Collections;
 import java.util.concurrent.Executor;
 
+import scala.concurrent.duration.FiniteDuration;
+
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -88,7 +90,7 @@ public class TaskStopTest {
 			Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 			Collections.<InputGateDeploymentDescriptor>emptyList(),
 			0,
-			mock(TaskStateHandles.class),
+			mock(TaskStateSnapshot.class),
 			mock(MemoryManager.class),
 			mock(IOManager.class),
 			mock(NetworkEnvironment.class),

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java
index f262bf2..c1df5a3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java
@@ -27,7 +27,7 @@ import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.blob.BlobKey;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -70,7 +70,8 @@ import java.util.concurrent.Executor;
 import java.util.concurrent.Executors;
 
 import static org.junit.Assume.assumeTrue;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 /**
  * Test that verifies the behavior of blocking shutdown hooks and of the
@@ -232,7 +233,7 @@ public class JvmExitOnFatalErrorTest {
 		private static final class NoOpCheckpointResponder implements CheckpointResponder {
 
 			@Override
-			public void acknowledgeCheckpoint(JobID j, ExecutionAttemptID e, long i, CheckpointMetrics c, SubtaskState s) {}
+			public void acknowledgeCheckpoint(JobID j, ExecutionAttemptID e, long i, CheckpointMetrics c, TaskStateSnapshot s) {}
 
 			@Override
 			public void declineCheckpoint(JobID j, ExecutionAttemptID e, long l, Throwable t) {}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
index 77caa34..13100db 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
@@ -21,6 +21,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.util.CorruptConfigurationException;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.util.ClassLoaderUtil;
@@ -76,6 +77,7 @@ public class StreamConfig implements Serializable {
 	private static final String OUT_STREAM_EDGES = "outStreamEdges";
 	private static final String IN_STREAM_EDGES = "inStreamEdges";
 	private static final String OPERATOR_NAME = "operatorName";
+	private static final String OPERATOR_ID = "operatorID";
 	private static final String CHAIN_END = "chainEnd";
 
 	private static final String CHECKPOINTING_ENABLED = "checkpointing";
@@ -213,7 +215,7 @@ public class StreamConfig implements Serializable {
 		}
 	}
 
-	public <T> T getStreamOperator(ClassLoader cl) {
+	public <T extends StreamOperator<?>> T getStreamOperator(ClassLoader cl) {
 		try {
 			return InstantiationUtil.readObjectFromConfig(this.config, SERIALIZEDUDF, cl);
 		}
@@ -411,6 +413,15 @@ public class StreamConfig implements Serializable {
 		}
 	}
 
+	public void setOperatorID(OperatorID operatorID) {
+		this.config.setBytes(OPERATOR_ID, operatorID.getBytes());
+	}
+
+	public OperatorID getOperatorID() {
+		byte[] operatorIDBytes = config.getBytes(OPERATOR_ID, null);
+		return new OperatorID(Preconditions.checkNotNull(operatorIDBytes));
+	}
+
 	public void setOperatorName(String name) {
 		this.config.setString(OPERATOR_NAME, name);
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index e70962b..abaa74e 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -246,7 +246,9 @@ public class StreamingJobGraphGenerator {
 				operatorHashes = new ArrayList<>();
 				chainedOperatorHashes.put(startNodeId, operatorHashes);
 			}
-			operatorHashes.add(new Tuple2<>(hashes.get(currentNodeId), legacyHashes.get(1).get(currentNodeId)));
+
+			byte[] primaryHashBytes = hashes.get(currentNodeId);
+			operatorHashes.add(new Tuple2<>(primaryHashBytes, legacyHashes.get(1).get(currentNodeId)));
 
 			chainedNames.put(currentNodeId, createChainedName(currentNodeId, chainableOutputs));
 			chainedMinResources.put(currentNodeId, createChainedMinResources(currentNodeId, chainableOutputs));
@@ -280,13 +282,16 @@ public class StreamingJobGraphGenerator {
 					chainedConfigs.put(startNodeId, new HashMap<Integer, StreamConfig>());
 				}
 				config.setChainIndex(chainIndex);
-				config.setOperatorName(streamGraph.getStreamNode(currentNodeId).getOperatorName());
+				StreamNode node = streamGraph.getStreamNode(currentNodeId);
+				config.setOperatorName(node.getOperatorName());
 				chainedConfigs.get(startNodeId).put(currentNodeId, config);
 			}
+
+			config.setOperatorID(new OperatorID(primaryHashBytes));
+
 			if (chainableOutputs.isEmpty()) {
 				config.setChainEnd();
 			}
-
 			return transitiveOutEdges;
 
 		} else {

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index d711518..a72b9fe 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -36,6 +36,8 @@ import org.apache.flink.metrics.Gauge;
 import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions.CheckpointType;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.metrics.groups.OperatorMetricGroup;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
@@ -60,7 +62,6 @@ import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.util.OutputTag;
@@ -179,7 +180,6 @@ public abstract class AbstractStreamOperator<OUT>
 	public void setup(StreamTask<?, ?> containingTask, StreamConfig config, Output<StreamRecord<OUT>> output) {
 		this.container = containingTask;
 		this.config = config;
-
 		this.metrics = container.getEnvironment().getMetricGroup().addOperator(config.getOperatorName());
 		this.output = new CountingOutput(output, ((OperatorMetricGroup) this.metrics).getIOMetricGroup().getNumRecordsOutCounter());
 		if (config.isChainStart()) {
@@ -208,13 +208,13 @@ public abstract class AbstractStreamOperator<OUT>
 	}
 
 	@Override
-	public final void initializeState(OperatorStateHandles stateHandles) throws Exception {
+	public final void initializeState(OperatorSubtaskState stateHandles) throws Exception {
 
 		Collection<KeyedStateHandle> keyedStateHandlesRaw = null;
 		Collection<OperatorStateHandle> operatorStateHandlesRaw = null;
 		Collection<OperatorStateHandle> operatorStateHandlesBackend = null;
 
-		boolean restoring = null != stateHandles;
+		boolean restoring = (null != stateHandles);
 
 		initKeyedState(); //TODO we should move the actual initialization of this from StreamTask to this class
 
@@ -266,13 +266,13 @@ public abstract class AbstractStreamOperator<OUT>
 	 * Can be removed when we remove the APIs for non-repartitionable operator state.
 	 */
 	@Deprecated
-	private void restoreStreamCheckpointed(OperatorStateHandles stateHandles) throws Exception {
+	private void restoreStreamCheckpointed(OperatorSubtaskState stateHandles) throws Exception {
 		StreamStateHandle state = stateHandles.getLegacyOperatorState();
 		if (null != state) {
 			if (this instanceof CheckpointedRestoringOperator) {
 
-				LOG.debug("Restore state of task {} in chain ({}).",
-						stateHandles.getOperatorChainIndex(), getContainingTask().getName());
+				LOG.debug("Restore state of task {} in operator with id ({}).",
+					getContainingTask().getName(), getOperatorID());
 
 				FSDataInputStream is = state.openInputStream();
 				try {
@@ -973,6 +973,11 @@ public abstract class AbstractStreamOperator<OUT>
 		}
 	}
 
+	@Override
+	public OperatorID getOperatorID() {
+		return config.getOperatorID();
+	}
+
 	@VisibleForTesting
 	public int numProcessingTimeTimers() {
 		return timeServiceManager == null ? 0 :

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
index 61578b2..9d5e02b 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
@@ -20,10 +20,11 @@ package org.apache.flink.streaming.api.operators;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 
 import java.io.Serializable;
@@ -123,7 +124,7 @@ public interface StreamOperator<OUT> extends Serializable {
 	 *
 	 * @param stateHandles state handles to the operator state.
 	 */
-	void initializeState(OperatorStateHandles stateHandles) throws Exception;
+	void initializeState(OperatorSubtaskState stateHandles) throws Exception;
 
 	/**
 	 * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager.
@@ -149,4 +150,5 @@ public interface StreamOperator<OUT> extends Serializable {
 
 	MetricGroup getMetricGroup();
 
+	OperatorID getOperatorID();
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
index 1a79f54..4914075 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
@@ -20,13 +20,10 @@ package org.apache.flink.streaming.runtime.tasks;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.CollectionUtil;
-import org.apache.flink.util.Preconditions;
 
 import java.util.Collection;
 import java.util.List;
@@ -63,22 +60,6 @@ public class OperatorStateHandles {
 		this.rawOperatorState = rawOperatorState;
 	}
 
-	public OperatorStateHandles(TaskStateHandles taskStateHandles, int operatorChainIndex) {
-		Preconditions.checkNotNull(taskStateHandles);
-
-		this.operatorChainIndex = operatorChainIndex;
-
-		ChainedStateHandle<StreamStateHandle> legacyState = taskStateHandles.getLegacyOperatorState();
-		this.legacyOperatorState = ChainedStateHandle.isNullOrEmpty(legacyState) ?
-				null : legacyState.get(operatorChainIndex);
-
-		this.rawKeyedState = taskStateHandles.getRawKeyedState();
-		this.managedKeyedState = taskStateHandles.getManagedKeyedState();
-
-		this.managedOperatorState = getSafeItemAtIndexOrNull(taskStateHandles.getManagedOperatorState(), operatorChainIndex);
-		this.rawOperatorState = getSafeItemAtIndexOrNull(taskStateHandles.getRawOperatorState(), operatorChainIndex);
-	}
-
 	public StreamStateHandle getLegacyOperatorState() {
 		return legacyOperatorState;
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index c35a6dc..cb8639b 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -25,17 +25,18 @@ import org.apache.flink.core.fs.FileSystemSafetyNet;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
-import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
@@ -44,7 +45,6 @@ import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackend;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.graph.StreamConfig;
@@ -54,7 +54,6 @@ import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.streamstatus.StreamStatusMaintainer;
-import org.apache.flink.util.CollectionUtil;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.FutureUtil;
 import org.apache.flink.util.Preconditions;
@@ -64,13 +63,11 @@ import org.slf4j.LoggerFactory;
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collection;
-import java.util.List;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
-import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -158,7 +155,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 	/** The map of user-defined accumulators of this task. */
 	private Map<String, Accumulator<?, ?>> accumulatorMap;
 
-	private TaskStateHandles restoreStateHandles;
+	private TaskStateSnapshot taskStateSnapshot;
 
 	/** The currently active background materialization threads. */
 	private final CloseableRegistry cancelables = new CloseableRegistry();
@@ -508,8 +505,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 	// ------------------------------------------------------------------------
 
 	@Override
-	public void setInitialState(TaskStateHandles taskStateHandles) {
-		this.restoreStateHandles = taskStateHandles;
+	public void setInitialState(TaskStateSnapshot taskStateHandles) {
+		this.taskStateSnapshot = taskStateHandles;
 	}
 
 	@Override
@@ -658,12 +655,11 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 	private void initializeState() throws Exception {
 
-		boolean restored = null != restoreStateHandles;
+		boolean restored = null != taskStateSnapshot;
 
 		if (restored) {
-			checkRestorePreconditions(operatorChain.getChainLength());
 			initializeOperators(true);
-			restoreStateHandles = null; // free for GC
+			taskStateSnapshot = null; // free for GC
 		} else {
 			initializeOperators(false);
 		}
@@ -674,8 +670,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		for (int chainIdx = 0; chainIdx < allOperators.length; ++chainIdx) {
 			StreamOperator<?> operator = allOperators[chainIdx];
 			if (null != operator) {
-				if (restored && restoreStateHandles != null) {
-					operator.initializeState(new OperatorStateHandles(restoreStateHandles, chainIdx));
+				if (restored && taskStateSnapshot != null) {
+					operator.initializeState(taskStateSnapshot.getSubtaskStateByOperatorID(operator.getOperatorID()));
 				} else {
 					operator.initializeState(null);
 				}
@@ -683,26 +679,6 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		}
 	}
 
-	private void checkRestorePreconditions(int operatorChainLength) {
-
-		ChainedStateHandle<StreamStateHandle> nonPartitionableOperatorStates =
-				restoreStateHandles.getLegacyOperatorState();
-		List<Collection<OperatorStateHandle>> operatorStates =
-				restoreStateHandles.getManagedOperatorState();
-
-		if (nonPartitionableOperatorStates != null) {
-			Preconditions.checkState(nonPartitionableOperatorStates.getLength() == operatorChainLength,
-					"Invalid Invalid number of operator states. Found :" + nonPartitionableOperatorStates.getLength()
-							+ ". Expected: " + operatorChainLength);
-		}
-
-		if (!CollectionUtil.isNullOrEmpty(operatorStates)) {
-			Preconditions.checkArgument(operatorStates.size() == operatorChainLength,
-					"Invalid number of operator states. Found :" + operatorStates.size() +
-							". Expected: " + operatorChainLength);
-		}
-	}
-
 	// ------------------------------------------------------------------------
 	//  State backend
 	// ------------------------------------------------------------------------
@@ -768,8 +744,13 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		cancelables.registerClosable(keyedStateBackend);
 
 		// restore if we have some old state
-		Collection<KeyedStateHandle> restoreKeyedStateHandles =
-			restoreStateHandles == null ? null : restoreStateHandles.getManagedKeyedState();
+		Collection<KeyedStateHandle> restoreKeyedStateHandles = null;
+
+		if (taskStateSnapshot != null) {
+			OperatorSubtaskState stateByOperatorID =
+				taskStateSnapshot.getSubtaskStateByOperatorID(headOperator.getOperatorID());
+			restoreKeyedStateHandles = stateByOperatorID != null ? stateByOperatorID.getManagedKeyedState() : null;
+		}
 
 		keyedStateBackend.restore(restoreKeyedStateHandles);
 
@@ -850,12 +831,9 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 		private final StreamTask<?, ?> owner;
 
-		private final List<OperatorSnapshotResult> snapshotInProgressList;
-
-		private RunnableFuture<KeyedStateHandle> futureKeyedBackendStateHandles;
-		private RunnableFuture<KeyedStateHandle> futureKeyedStreamStateHandles;
+		private final Map<OperatorID, OperatorSnapshotResult> operatorSnapshotsInProgress;
 
-		private List<StreamStateHandle> nonPartitionedStateHandles;
+		private Map<OperatorID, StreamStateHandle> nonPartitionedStateHandles;
 
 		private final CheckpointMetaData checkpointMetaData;
 		private final CheckpointMetrics checkpointMetrics;
@@ -867,81 +845,60 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 		AsyncCheckpointRunnable(
 				StreamTask<?, ?> owner,
-				List<StreamStateHandle> nonPartitionedStateHandles,
-				List<OperatorSnapshotResult> snapshotInProgressList,
+				Map<OperatorID, StreamStateHandle> nonPartitionedStateHandles,
+				Map<OperatorID, OperatorSnapshotResult> operatorSnapshotsInProgress,
 				CheckpointMetaData checkpointMetaData,
 				CheckpointMetrics checkpointMetrics,
 				long asyncStartNanos) {
 
 			this.owner = Preconditions.checkNotNull(owner);
-			this.snapshotInProgressList = Preconditions.checkNotNull(snapshotInProgressList);
+			this.operatorSnapshotsInProgress = Preconditions.checkNotNull(operatorSnapshotsInProgress);
 			this.checkpointMetaData = Preconditions.checkNotNull(checkpointMetaData);
 			this.checkpointMetrics = Preconditions.checkNotNull(checkpointMetrics);
 			this.nonPartitionedStateHandles = nonPartitionedStateHandles;
 			this.asyncStartNanos = asyncStartNanos;
-
-			if (!snapshotInProgressList.isEmpty()) {
-				// TODO Currently only the head operator of a chain can have keyed state, so simply access it directly.
-				int headIndex = snapshotInProgressList.size() - 1;
-				OperatorSnapshotResult snapshotInProgress = snapshotInProgressList.get(headIndex);
-				if (null != snapshotInProgress) {
-					this.futureKeyedBackendStateHandles = snapshotInProgress.getKeyedStateManagedFuture();
-					this.futureKeyedStreamStateHandles = snapshotInProgress.getKeyedStateRawFuture();
-				}
-			}
 		}
 
 		@Override
 		public void run() {
 			FileSystemSafetyNet.initializeSafetyNetForThread();
 			try {
-				// Keyed state handle future, currently only one (the head) operator can have this
-				KeyedStateHandle keyedStateHandleBackend = FutureUtil.runIfNotDoneAndGet(futureKeyedBackendStateHandles);
-				KeyedStateHandle keyedStateHandleStream = FutureUtil.runIfNotDoneAndGet(futureKeyedStreamStateHandles);
-
-				List<OperatorStateHandle> operatorStatesBackend = new ArrayList<>(snapshotInProgressList.size());
-				List<OperatorStateHandle> operatorStatesStream = new ArrayList<>(snapshotInProgressList.size());
-
-				for (OperatorSnapshotResult snapshotInProgress : snapshotInProgressList) {
-					if (null != snapshotInProgress) {
-						operatorStatesBackend.add(
-								FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture()));
-						operatorStatesStream.add(
-								FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture()));
-					} else {
-						operatorStatesBackend.add(null);
-						operatorStatesStream.add(null);
-					}
-				}
+				boolean hasState = false;
+				final TaskStateSnapshot taskOperatorSubtaskStates =
+					new TaskStateSnapshot(operatorSnapshotsInProgress.size());
 
-				final long asyncEndNanos = System.nanoTime();
-				final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000;
+				for (Map.Entry<OperatorID, OperatorSnapshotResult> entry : operatorSnapshotsInProgress.entrySet()) {
 
-				checkpointMetrics.setAsyncDurationMillis(asyncDurationMillis);
+					OperatorID operatorID = entry.getKey();
+					OperatorSnapshotResult snapshotInProgress = entry.getValue();
 
-				ChainedStateHandle<StreamStateHandle> chainedNonPartitionedOperatorsState =
-						new ChainedStateHandle<>(nonPartitionedStateHandles);
+					OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(
+						nonPartitionedStateHandles.get(operatorID),
+						FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture()),
+						FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture()),
+						FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateManagedFuture()),
+						FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateRawFuture())
+					);
 
-				ChainedStateHandle<OperatorStateHandle> chainedOperatorStateBackend =
-						new ChainedStateHandle<>(operatorStatesBackend);
+					hasState |= operatorSubtaskState.hasState();
+					taskOperatorSubtaskStates.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState);
+				}
 
-				ChainedStateHandle<OperatorStateHandle> chainedOperatorStateStream =
-						new ChainedStateHandle<>(operatorStatesStream);
+				final long asyncEndNanos = System.nanoTime();
+				final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000;
 
-				SubtaskState subtaskState = createSubtaskStateFromSnapshotStateHandles(
-						chainedNonPartitionedOperatorsState,
-						chainedOperatorStateBackend,
-						chainedOperatorStateStream,
-						keyedStateHandleBackend,
-						keyedStateHandleStream);
+				checkpointMetrics.setAsyncDurationMillis(asyncDurationMillis);
 
 				if (asyncCheckpointState.compareAndSet(CheckpointingOperation.AsynCheckpointState.RUNNING,
 						CheckpointingOperation.AsynCheckpointState.COMPLETED)) {
 
+					// we signal stateless tasks by reporting null, so that there are no attempts to assign empty state
+					// to stateless tasks on restore. This enables simple job modifications that only concern
+					// stateless operators without the need to assign them uids to match their (always empty) states.
 					owner.getEnvironment().acknowledgeCheckpoint(
 						checkpointMetaData.getCheckpointId(),
 						checkpointMetrics,
-						subtaskState);
+						hasState ? taskOperatorSubtaskStates : null);
 
 					if (LOG.isDebugEnabled()) {
 						LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms",
@@ -988,38 +945,13 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 			}
 		}
 
-		private SubtaskState createSubtaskStateFromSnapshotStateHandles(
-				ChainedStateHandle<StreamStateHandle> chainedNonPartitionedOperatorsState,
-				ChainedStateHandle<OperatorStateHandle> chainedOperatorStateBackend,
-				ChainedStateHandle<OperatorStateHandle> chainedOperatorStateStream,
-				KeyedStateHandle keyedStateHandleBackend,
-				KeyedStateHandle keyedStateHandleStream) {
-
-			boolean hasAnyState = keyedStateHandleBackend != null
-					|| keyedStateHandleStream != null
-					|| !chainedOperatorStateBackend.isEmpty()
-					|| !chainedOperatorStateStream.isEmpty()
-					|| !chainedNonPartitionedOperatorsState.isEmpty();
-
-			// we signal a stateless task by reporting null, so that there are no attempts to assign empty state to
-			// stateless tasks on restore. This allows for simple job modifications that only concern stateless without
-			// the need to assign them uids to match their (always empty) states.
-			return hasAnyState ? new SubtaskState(
-					chainedNonPartitionedOperatorsState,
-					chainedOperatorStateBackend,
-					chainedOperatorStateStream,
-					keyedStateHandleBackend,
-					keyedStateHandleStream)
-					: null;
-		}
-
 		private void cleanup() throws Exception {
 			if (asyncCheckpointState.compareAndSet(CheckpointingOperation.AsynCheckpointState.RUNNING, CheckpointingOperation.AsynCheckpointState.DISCARDED)) {
 				LOG.debug("Cleanup AsyncCheckpointRunnable for checkpoint {} of {}.", checkpointMetaData.getCheckpointId(), owner.getName());
 				Exception exception = null;
 
 				// clean up ongoing operator snapshot results and non partitioned state handles
-				for (OperatorSnapshotResult operatorSnapshotResult : snapshotInProgressList) {
+				for (OperatorSnapshotResult operatorSnapshotResult : operatorSnapshotsInProgress.values()) {
 					if (operatorSnapshotResult != null) {
 						try {
 							operatorSnapshotResult.cancel();
@@ -1031,7 +963,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 				// discard non partitioned state handles
 				try {
-					StateUtil.bestEffortDiscardAllStateObjects(nonPartitionedStateHandles);
+					StateUtil.bestEffortDiscardAllStateObjects(nonPartitionedStateHandles.values());
 				} catch (Exception discardException) {
 					exception = ExceptionUtils.firstOrSuppressed(discardException, exception);
 				}
@@ -1069,8 +1001,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 		// ------------------------
 
-		private final List<StreamStateHandle> nonPartitionedStates;
-		private final List<OperatorSnapshotResult> snapshotInProgressList;
+		private final Map<OperatorID, StreamStateHandle> nonPartitionedStates;
+		private final Map<OperatorID, OperatorSnapshotResult> operatorSnapshotsInProgress;
 
 		public CheckpointingOperation(
 				StreamTask<?, ?> owner,
@@ -1083,8 +1015,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 			this.checkpointOptions = Preconditions.checkNotNull(checkpointOptions);
 			this.checkpointMetrics = Preconditions.checkNotNull(checkpointMetrics);
 			this.allOperators = owner.operatorChain.getAllOperators();
-			this.nonPartitionedStates = new ArrayList<>(allOperators.length);
-			this.snapshotInProgressList = new ArrayList<>(allOperators.length);
+			this.nonPartitionedStates = new HashMap<>(allOperators.length);
+			this.operatorSnapshotsInProgress = new HashMap<>(allOperators.length);
 		}
 
 		public void executeCheckpointing() throws Exception {
@@ -1119,7 +1051,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 			} finally {
 				if (failed) {
 					// Cleanup to release resources
-					for (OperatorSnapshotResult operatorSnapshotResult : snapshotInProgressList) {
+					for (OperatorSnapshotResult operatorSnapshotResult : operatorSnapshotsInProgress.values()) {
 						if (null != operatorSnapshotResult) {
 							try {
 								operatorSnapshotResult.cancel();
@@ -1130,7 +1062,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 					}
 
 					// Cleanup non partitioned state handles
-					for (StreamStateHandle nonPartitionedState : nonPartitionedStates) {
+					for (StreamStateHandle nonPartitionedState : nonPartitionedStates.values()) {
 						if (nonPartitionedState != null) {
 							try {
 								nonPartitionedState.discardState();
@@ -1156,21 +1088,19 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		private void checkpointStreamOperator(StreamOperator<?> op) throws Exception {
 			if (null != op) {
 				// first call the legacy checkpoint code paths
-				nonPartitionedStates.add(op.snapshotLegacyOperatorState(
-						checkpointMetaData.getCheckpointId(),
-						checkpointMetaData.getTimestamp(),
-						checkpointOptions));
+				StreamStateHandle legacyOperatorState = op.snapshotLegacyOperatorState(
+					checkpointMetaData.getCheckpointId(),
+					checkpointMetaData.getTimestamp(),
+					checkpointOptions);
+
+				OperatorID operatorID = op.getOperatorID();
+				nonPartitionedStates.put(operatorID, legacyOperatorState);
 
 				OperatorSnapshotResult snapshotInProgress = op.snapshotState(
 						checkpointMetaData.getCheckpointId(),
 						checkpointMetaData.getTimestamp(),
 						checkpointOptions);
-
-				snapshotInProgressList.add(snapshotInProgress);
-			} else {
-				nonPartitionedStates.add(null);
-				OperatorSnapshotResult emptySnapshotInProgress = new OperatorSnapshotResult();
-				snapshotInProgressList.add(emptySnapshotInProgress);
+				operatorSnapshotsInProgress.put(operatorID, snapshotInProgress);
 			}
 		}
 
@@ -1179,7 +1109,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 			AsyncCheckpointRunnable asyncCheckpointRunnable = new AsyncCheckpointRunnable(
 					owner,
 					nonPartitionedStates,
-					snapshotInProgressList,
+					operatorSnapshotsInProgress,
 					checkpointMetaData,
 					checkpointMetrics,
 					startAsyncPartNano);

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java
index e8b4c9e..ff5f589 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -84,7 +85,7 @@ public class AbstractUdfStreamOperatorLifecycleTest {
 			"UDF::close");
 
 	private static final String ALL_METHODS_STREAM_OPERATOR = "[close[], dispose[], getChainingStrategy[], " +
-			"getMetricGroup[], initializeState[class org.apache.flink.streaming.runtime.tasks.OperatorStateHandles], " +
+			"getMetricGroup[], getOperatorID[], initializeState[class org.apache.flink.runtime.checkpoint.OperatorSubtaskState], " +
 			"notifyOfCompletedCheckpoint[long], open[], setChainingStrategy[class " +
 			"org.apache.flink.streaming.api.operators.ChainingStrategy], setKeyContextElement1[class " +
 			"org.apache.flink.streaming.runtime.streamrecord.StreamRecord], " +
@@ -132,6 +133,7 @@ public class AbstractUdfStreamOperatorLifecycleTest {
 		MockSourceFunction srcFun = new MockSourceFunction();
 
 		cfg.setStreamOperator(new LifecycleTrackingStreamSource(srcFun, true));
+		cfg.setOperatorID(new OperatorID());
 		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
 
 		Task task = StreamTaskTest.createTask(SourceStreamTask.class, cfg, taskManagerConfig);
@@ -154,6 +156,7 @@ public class AbstractUdfStreamOperatorLifecycleTest {
 		StreamConfig cfg = new StreamConfig(new Configuration());
 		MockSourceFunction srcFun = new MockSourceFunction();
 		cfg.setStreamOperator(new LifecycleTrackingStreamSource(srcFun, false));
+		cfg.setOperatorID(new OperatorID());
 		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
 
 		Task task = StreamTaskTest.createTask(SourceStreamTask.class, cfg, taskManagerConfig);

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
index f9a1cd0..1dd99fe 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
@@ -29,15 +29,15 @@ import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 import org.apache.flink.streaming.api.datastream.AsyncDataStream;
@@ -500,7 +500,9 @@ public class AsyncWaitOperatorTest extends TestLogger {
 			AsyncDataStream.OutputMode.ORDERED);
 
 		final StreamConfig streamConfig = testHarness.getStreamConfig();
+		OperatorID operatorID = new OperatorID(42L, 4711L);
 		streamConfig.setStreamOperator(operator);
+		streamConfig.setOperatorID(operatorID);
 
 		final AcknowledgeStreamMockEnvironment env = new AcknowledgeStreamMockEnvironment(
 				testHarness.jobConfig,
@@ -540,7 +542,8 @@ public class AsyncWaitOperatorTest extends TestLogger {
 
 		// set the operator state from previous attempt into the restored one
 		final OneInputStreamTask<Integer, Integer> restoredTask = new OneInputStreamTask<>();
-		restoredTask.setInitialState(new TaskStateHandles(env.getCheckpointStateHandles()));
+		TaskStateSnapshot subtaskStates = env.getCheckpointStateHandles();
+		restoredTask.setInitialState(subtaskStates);
 
 		final OneInputStreamTaskTestHarness<Integer, Integer> restoredTaskHarness =
 				new OneInputStreamTaskTestHarness<>(restoredTask, BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO);
@@ -553,6 +556,7 @@ public class AsyncWaitOperatorTest extends TestLogger {
 			AsyncDataStream.OutputMode.ORDERED);
 
 		restoredTaskHarness.getStreamConfig().setStreamOperator(restoredOperator);
+		restoredTaskHarness.getStreamConfig().setOperatorID(operatorID);
 
 		restoredTaskHarness.invoke();
 		restoredTaskHarness.waitForTaskRunning();
@@ -595,7 +599,7 @@ public class AsyncWaitOperatorTest extends TestLogger {
 
 	private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment {
 		private volatile long checkpointId;
-		private volatile SubtaskState checkpointStateHandles;
+		private volatile TaskStateSnapshot checkpointStateHandles;
 
 		private final OneShotLatch checkpointLatch = new OneShotLatch();
 
@@ -614,7 +618,7 @@ public class AsyncWaitOperatorTest extends TestLogger {
 		public void acknowledgeCheckpoint(
 				long checkpointId,
 				CheckpointMetrics checkpointMetrics,
-				SubtaskState checkpointStateHandles) {
+				TaskStateSnapshot checkpointStateHandles) {
 
 			this.checkpointId = checkpointId;
 			this.checkpointStateHandles = checkpointStateHandles;
@@ -625,7 +629,7 @@ public class AsyncWaitOperatorTest extends TestLogger {
 			return checkpointLatch;
 		}
 
-		public SubtaskState getCheckpointStateHandles() {
+		public TaskStateSnapshot getCheckpointStateHandles() {
 			return checkpointStateHandles;
 		}
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
index c2cf7f3..491b23d 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineOnCancellationBarrierException;
 import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineSubsumedException;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -34,7 +35,6 @@ import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
-import org.apache.flink.runtime.state.TaskStateHandles;
 
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
@@ -1484,7 +1484,7 @@ public class BarrierBufferTest {
 		}
 
 		@Override
-		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {
+		public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception {
 			throw new UnsupportedOperationException("should never be called");
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
index 847db5c..cde9010 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
@@ -22,13 +22,13 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
-import org.apache.flink.runtime.state.TaskStateHandles;
 
 import org.junit.Test;
 
@@ -498,7 +498,7 @@ public class BarrierTrackerTest {
 		}
 
 		@Override
-		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {
+		public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception {
 			throw new UnsupportedOperationException("should never be called");
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java
index 6e3be03..65e59f8 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamTaskTimerTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.operators;
 
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.StreamMap;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
@@ -53,6 +54,7 @@ public class StreamTaskTimerTest {
 
 		StreamMap<String, String> mapOperator = new StreamMap<>(new DummyMapFunction<String>());
 		streamConfig.setStreamOperator(mapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		testHarness.invoke();
 		testHarness.waitForTaskRunning();

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java
index 675ffa3..d621b0b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.streaming.runtime.operators;
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.StreamMap;
 import org.apache.flink.streaming.runtime.tasks.AsyncExceptionHandler;
@@ -53,6 +54,7 @@ public class TestProcessingTimeServiceTest {
 
 		StreamMap<String, String> mapOperator = new StreamMap<>(new StreamTaskTimerTest.DummyMapFunction<String>());
 		streamConfig.setStreamOperator(mapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		testHarness.invoke();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
index 51328ab..3b8178b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
@@ -45,6 +45,7 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment;
 import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
@@ -93,6 +94,7 @@ public class BlockingCheckpointsTest {
 		Configuration taskConfig = new Configuration();
 		StreamConfig cfg = new StreamConfig(taskConfig);
 		cfg.setStreamOperator(new TestOperator());
+		cfg.setOperatorID(new OperatorID());
 		cfg.setStateBackend(new LockingStreamStateBackend());
 
 		Task task = createTask(taskConfig);

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index 25b504b..82e4f31 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -26,6 +26,8 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.blob.BlobKey;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -40,11 +42,11 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment;
 import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
 import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.FunctionSnapshotContext;
@@ -55,7 +57,6 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskManagerActions;
@@ -135,18 +136,18 @@ public class InterruptSensitiveRestoreTest {
 
 		IN_RESTORE_LATCH.reset();
 		Configuration taskConfig = new Configuration();
-		StreamConfig cfg = new StreamConfig(taskConfig);
-		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+		StreamConfig streamConfig = new StreamConfig(taskConfig);
+		streamConfig.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
 		switch (mode) {
 			case OPERATOR_MANAGED:
 			case OPERATOR_RAW:
 			case KEYED_MANAGED:
 			case KEYED_RAW:
-				cfg.setStateKeySerializer(IntSerializer.INSTANCE);
-				cfg.setStreamOperator(new StreamSource<>(new TestSource()));
+				streamConfig.setStateKeySerializer(IntSerializer.INSTANCE);
+				streamConfig.setStreamOperator(new StreamSource<>(new TestSource()));
 				break;
 			case LEGACY:
-				cfg.setStreamOperator(new StreamSource<>(new TestSourceLegacy()));
+				streamConfig.setStreamOperator(new StreamSource<>(new TestSourceLegacy()));
 				break;
 			default:
 				throw new IllegalArgumentException();
@@ -154,7 +155,7 @@ public class InterruptSensitiveRestoreTest {
 
 		StreamStateHandle lockingHandle = new InterruptLockingStateHandle();
 
-		Task task = createTask(taskConfig, lockingHandle, mode);
+		Task task = createTask(streamConfig, taskConfig, lockingHandle, mode);
 
 		// start the task and wait until it is in "restore"
 		task.startTaskThread();
@@ -178,19 +179,20 @@ public class InterruptSensitiveRestoreTest {
 	// ------------------------------------------------------------------------
 
 	private static Task createTask(
-			Configuration taskConfig,
-			StreamStateHandle state,
-			int mode) throws IOException {
+		StreamConfig streamConfig,
+		Configuration taskConfig,
+		StreamStateHandle state,
+		int mode) throws IOException {
 
 		NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class);
 		when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class)))
 				.thenReturn(mock(TaskKvStateRegistry.class));
 
-		ChainedStateHandle<StreamStateHandle> operatorState = null;
-		List<KeyedStateHandle> keyedStateFromBackend = Collections.emptyList();
-		List<KeyedStateHandle> keyedStateFromStream = Collections.emptyList();
-		List<Collection<OperatorStateHandle>> operatorStateBackend = Collections.emptyList();
-		List<Collection<OperatorStateHandle>> operatorStateStream = Collections.emptyList();
+		StreamStateHandle operatorState = null;
+		Collection<KeyedStateHandle> keyedStateFromBackend = Collections.emptyList();
+		Collection<KeyedStateHandle> keyedStateFromStream = Collections.emptyList();
+		Collection<OperatorStateHandle> operatorStateBackend = Collections.emptyList();
+		Collection<OperatorStateHandle> operatorStateStream = Collections.emptyList();
 
 		Map<String, OperatorStateHandle.StateMetaInfo> operatorStateMetadata = new HashMap<>(1);
 		OperatorStateHandle.StateMetaInfo metaInfo =
@@ -207,10 +209,10 @@ public class InterruptSensitiveRestoreTest {
 
 		switch (mode) {
 			case OPERATOR_MANAGED:
-				operatorStateBackend = Collections.singletonList(operatorStateHandles);
+				operatorStateBackend = operatorStateHandles;
 				break;
 			case OPERATOR_RAW:
-				operatorStateStream = Collections.singletonList(operatorStateHandles);
+				operatorStateStream = operatorStateHandles;
 				break;
 			case KEYED_MANAGED:
 				keyedStateFromBackend = keyedStateHandles;
@@ -219,29 +221,35 @@ public class InterruptSensitiveRestoreTest {
 				keyedStateFromStream = keyedStateHandles;
 				break;
 			case LEGACY:
-				operatorState = new ChainedStateHandle<>(Collections.singletonList(state));
+				operatorState = state;
 				break;
 			default:
 				throw new IllegalArgumentException();
 		}
 
-		TaskStateHandles taskStateHandles = new TaskStateHandles(
+		OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(
 			operatorState,
 			operatorStateBackend,
 			operatorStateStream,
 			keyedStateFromBackend,
 			keyedStateFromStream);
 
+		JobVertexID jobVertexID = new JobVertexID();
+		OperatorID operatorID = OperatorID.fromJobVertexID(jobVertexID);
+		streamConfig.setOperatorID(operatorID);
+
+		TaskStateSnapshot stateSnapshot = new TaskStateSnapshot();
+		stateSnapshot.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState);
 		JobInformation jobInformation = new JobInformation(
 			new JobID(),
 			"test job name",
 			new SerializedValue<>(new ExecutionConfig()),
-			new Configuration(),
+			taskConfig,
 			Collections.<BlobKey>emptyList(),
 			Collections.<URL>emptyList());
 
 		TaskInformation taskInformation = new TaskInformation(
-			new JobVertexID(),
+			jobVertexID,
 			"test task name",
 			1,
 			1,
@@ -258,7 +266,7 @@ public class InterruptSensitiveRestoreTest {
 			Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 			Collections.<InputGateDeploymentDescriptor>emptyList(),
 			0,
-			taskStateHandles,
+			stateSnapshot,
 			mock(MemoryManager.class),
 			mock(IOManager.class),
 			networkEnvironment,
@@ -273,7 +281,6 @@ public class InterruptSensitiveRestoreTest {
 			mock(ResultPartitionConsumableNotifier.class),
 			mock(PartitionProducerStateChecker.class),
 			mock(Executor.class));
-
 	}
 
 	// ------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
index f7987a1..3190620 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
@@ -34,13 +34,13 @@ import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.graph.StreamEdge;
@@ -109,6 +109,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		StreamMap<String, String> mapOperator = new StreamMap<String, String>(new TestOpenCloseMapFunction());
 		streamConfig.setStreamOperator(mapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		long initialTime = 0L;
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
@@ -151,6 +152,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		StreamMap<String, String> mapOperator = new StreamMap<String, String>(new IdentityMap());
 		streamConfig.setStreamOperator(mapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
 		long initialTime = 0L;
@@ -261,15 +263,21 @@ public class OneInputStreamTaskTest extends TestLogger {
 		// ------------------ setup the chain ------------------
 
 		TriggerableFailOnWatermarkTestOperator headOperator = new TriggerableFailOnWatermarkTestOperator();
+		OperatorID headOperatorId = new OperatorID();
+
 		StreamConfig headOperatorConfig = testHarness.getStreamConfig();
 
 		WatermarkGeneratingTestOperator watermarkOperator = new WatermarkGeneratingTestOperator();
+		OperatorID watermarkOperatorId = new OperatorID();
+
 		StreamConfig watermarkOperatorConfig = new StreamConfig(new Configuration());
 
 		TriggerableFailOnWatermarkTestOperator tailOperator = new TriggerableFailOnWatermarkTestOperator();
+		OperatorID tailOperatorId = new OperatorID();
 		StreamConfig tailOperatorConfig = new StreamConfig(new Configuration());
 
 		headOperatorConfig.setStreamOperator(headOperator);
+		headOperatorConfig.setOperatorID(headOperatorId);
 		headOperatorConfig.setChainStart();
 		headOperatorConfig.setChainIndex(0);
 		headOperatorConfig.setChainedOutputs(Collections.singletonList(new StreamEdge(
@@ -282,6 +290,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		)));
 
 		watermarkOperatorConfig.setStreamOperator(watermarkOperator);
+		watermarkOperatorConfig.setOperatorID(watermarkOperatorId);
 		watermarkOperatorConfig.setTypeSerializerIn1(StringSerializer.INSTANCE);
 		watermarkOperatorConfig.setChainIndex(1);
 		watermarkOperatorConfig.setChainedOutputs(Collections.singletonList(new StreamEdge(
@@ -303,6 +312,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 			null));
 
 		tailOperatorConfig.setStreamOperator(tailOperator);
+		tailOperatorConfig.setOperatorID(tailOperatorId);
 		tailOperatorConfig.setTypeSerializerIn1(StringSerializer.INSTANCE);
 		tailOperatorConfig.setBufferTimeout(0);
 		tailOperatorConfig.setChainIndex(2);
@@ -412,6 +422,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		StreamMap<String, String> mapOperator = new StreamMap<String, String>(new IdentityMap());
 		streamConfig.setStreamOperator(mapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
 		long initialTime = 0L;
@@ -471,6 +482,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		StreamMap<String, String> mapOperator = new StreamMap<String, String>(new IdentityMap());
 		streamConfig.setStreamOperator(mapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<Object>();
 		long initialTime = 0L;
@@ -580,15 +592,20 @@ public class OneInputStreamTaskTest extends TestLogger {
 		testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());
 
 		final OneInputStreamTask<String, String> restoredTask = new OneInputStreamTask<String, String>();
-		restoredTask.setInitialState(new TaskStateHandles(env.getCheckpointStateHandles()));
 
-		final OneInputStreamTaskTestHarness<String, String> restoredTaskHarness = new OneInputStreamTaskTestHarness<String, String>(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
+		final OneInputStreamTaskTestHarness<String, String> restoredTaskHarness =
+			new OneInputStreamTaskTestHarness<String, String>(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
 		restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO);
 
 		StreamConfig restoredTaskStreamConfig = restoredTaskHarness.getStreamConfig();
 
 		configureChainedTestingStreamOperator(restoredTaskStreamConfig, numberChainedTasks, seed, recoveryTimestamp);
 
+		TaskStateSnapshot stateHandles = env.getCheckpointStateHandles();
+		Assert.assertEquals(numberChainedTasks, stateHandles.getSubtaskStateMappings().size());
+
+		restoredTask.setInitialState(stateHandles);
+
 		TestingStreamOperator.numberRestoreCalls = 0;
 
 		restoredTaskHarness.invoke();
@@ -601,6 +618,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		TestingStreamOperator.numberRestoreCalls = 0;
 	}
 
+
 	//==============================================================================================
 	// Utility functions and classes
 	//==============================================================================================
@@ -618,6 +636,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 
 		TestingStreamOperator<Integer, Integer> previousOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp);
 		streamConfig.setStreamOperator(previousOperator);
+		streamConfig.setOperatorID(new OperatorID(0L, 0L));
 
 		// create the chain of operators
 		Map<Integer, StreamConfig> chainedTaskConfigs = new HashMap<>(numberChainedTasks - 1);
@@ -627,6 +646,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 			TestingStreamOperator<Integer, Integer> chainedOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp);
 			StreamConfig chainedConfig = new StreamConfig(new Configuration());
 			chainedConfig.setStreamOperator(chainedOperator);
+			chainedConfig.setOperatorID(new OperatorID(0L, chainedIndex));
 			chainedTaskConfigs.put(chainedIndex, chainedConfig);
 
 			StreamEdge outputEdge = new StreamEdge(
@@ -673,7 +693,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 
 	private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment {
 		private volatile long checkpointId;
-		private volatile SubtaskState checkpointStateHandles;
+		private volatile TaskStateSnapshot checkpointStateHandles;
 
 		private final OneShotLatch checkpointLatch = new OneShotLatch();
 
@@ -692,7 +712,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		public void acknowledgeCheckpoint(
 				long checkpointId,
 				CheckpointMetrics checkpointMetrics,
-				SubtaskState checkpointStateHandles) {
+				TaskStateSnapshot checkpointStateHandles) {
 
 			this.checkpointId = checkpointId;
 			this.checkpointStateHandles = checkpointStateHandles;
@@ -703,7 +723,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 			return checkpointLatch;
 		}
 
-		public SubtaskState getCheckpointStateHandles() {
+		public TaskStateSnapshot getCheckpointStateHandles() {
 			return checkpointStateHandles;
 		}
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java
index 47a5350..b3b0a9f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceExternalCheckpointTriggerTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.streaming.api.checkpoint.ExternallyInducedSource;
 import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
@@ -64,6 +65,7 @@ public class SourceExternalCheckpointTriggerTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		StreamSource<Long, ?> sourceOperator = new StreamSource<>(source);
 		streamConfig.setStreamOperator(sourceOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		// this starts the source thread
 		testHarness.invoke();

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
index 27818bc..8867632 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
@@ -24,6 +24,7 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
@@ -63,6 +64,7 @@ public class SourceStreamTaskTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		StreamSource<String, ?> sourceOperator = new StreamSource<>(new OpenCloseTestSource());
 		streamConfig.setStreamOperator(sourceOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		testHarness.invoke();
 		testHarness.waitForTaskCompletion();
@@ -106,6 +108,7 @@ public class SourceStreamTaskTest {
 			StreamConfig streamConfig = testHarness.getStreamConfig();
 			StreamSource<Tuple2<Long, Integer>, ?> sourceOperator = new StreamSource<>(new MockSource(numElements, sourceCheckpointDelay, sourceReadDelay));
 			streamConfig.setStreamOperator(sourceOperator);
+			streamConfig.setOperatorID(new OperatorID());
 
 			// prepare the
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
index 5b995c6..231f59e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
@@ -28,7 +28,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
@@ -333,7 +333,7 @@ public class StreamMockEnvironment implements Environment {
 	}
 
 	@Override
-	public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) {
+	public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) {
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java
index 6e3c299..36bdc05 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java
@@ -24,6 +24,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineOnCancellationBarrierException;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.streaming.api.functions.co.CoMapFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
@@ -91,6 +92,7 @@ public class StreamTaskCancellationBarrierTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		StreamMap<String, String> mapOperator = new StreamMap<>(new IdentityMap());
 		streamConfig.setStreamOperator(mapOperator);
+		streamConfig.setOperatorID(new OperatorID());
 
 		StreamMockEnvironment environment = spy(testHarness.createEnvironment());
 
@@ -135,6 +137,7 @@ public class StreamTaskCancellationBarrierTest {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 		CoStreamMap<String, String, String> op = new CoStreamMap<>(new UnionCoMap());
 		streamConfig.setStreamOperator(op);
+		streamConfig.setOperatorID(new OperatorID());
 
 		StreamMockEnvironment environment = spy(testHarness.createEnvironment());
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
index 4f2135d..702d833 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
@@ -42,6 +42,7 @@ import org.apache.flink.runtime.io.network.NetworkEnvironment;
 import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
@@ -107,6 +108,7 @@ public class StreamTaskTerminationTest extends TestLogger {
 		final AbstractStateBackend blockingStateBackend = new BlockingStateBackend();
 
 		streamConfig.setStreamOperator(noOpStreamOperator);
+		streamConfig.setOperatorID(new OperatorID());
 		streamConfig.setStateBackend(blockingStateBackend);
 
 		final long checkpointId = 0L;