Posted to commits@flink.apache.org by al...@apache.org on 2016/10/20 14:15:23 UTC

[5/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index bbe10d1..fd425f3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -38,13 +38,13 @@ import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
@@ -1847,15 +1847,15 @@ public class CheckpointCoordinatorTest {
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 
-		List<KeyGroupRange> keyGroupPartitions1 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
-		List<KeyGroupRange> keyGroupPartitions2 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID1, index);
-			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8);
-			List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false);
+			KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
 
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState);
+			SubtaskState checkpointStateHandles = new SubtaskState(nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -1867,9 +1867,9 @@ public class CheckpointCoordinatorTest {
 
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID2, index);
-			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8);
-			List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState);
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false);
+			KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
+			SubtaskState checkpointStateHandles = new SubtaskState(nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -1952,13 +1952,13 @@ public class CheckpointCoordinatorTest {
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 
-		List<KeyGroupRange> keyGroupPartitions1 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
-		List<KeyGroupRange> keyGroupPartitions2 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
+			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -1971,8 +1971,8 @@ public class CheckpointCoordinatorTest {
 
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID2, index);
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
+			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -2067,17 +2067,17 @@ public class CheckpointCoordinatorTest {
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 
-		List<KeyGroupRange> keyGroupPartitions1 = 
-				CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
-		List<KeyGroupRange> keyGroupPartitions2 = 
-				CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+		List<KeyGroupRange> keyGroupPartitions1 =
+				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 =
+				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
-					jobVertexID1, keyGroupPartitions1.get(index));
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(
+					jobVertexID1, keyGroupPartitions1.get(index), false);
 
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
+			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -2091,10 +2091,10 @@ public class CheckpointCoordinatorTest {
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 
 			ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID2, index);
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
-					jobVertexID2, keyGroupPartitions2.get(index));
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(
+					jobVertexID2, keyGroupPartitions2.get(index), false);
 
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(state, null, keyGroupState);
+			SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -2132,24 +2132,36 @@ public class CheckpointCoordinatorTest {
 			"non-partitioned state changed.");
 	}
 
+	@Test
+	public void testRestoreLatestCheckpointedStateScaleIn() throws Exception {
+		testRestoreLatestCheckpointedStateWithChangingParallelism(false);
+	}
+
+	@Test
+	public void testRestoreLatestCheckpointedStateScaleOut() throws Exception {
+		testRestoreLatestCheckpointedStateWithChangingParallelism(true);
+	}
+
 	/**
 	 * Tests the checkpoint restoration with changing parallelism of job vertex with partitioned
 	 * state.
 	 *
 	 * @throws Exception
 	 */
-	@Test
-	public void testRestoreLatestCheckpointedStateWithChangingParallelism() throws Exception {
+	private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean scaleOut) throws Exception {
 		final JobID jid = new JobID();
 		final long timestamp = System.currentTimeMillis();
 
 		final JobVertexID jobVertexID1 = new JobVertexID();
 		final JobVertexID jobVertexID2 = new JobVertexID();
 		int parallelism1 = 3;
-		int parallelism2 = 2;
+		int parallelism2 = scaleOut ? 2 : 13;
+
 		int maxParallelism1 = 42;
 		int maxParallelism2 = 13;
 
+		int newParallelism2 = scaleOut ? 13 : 2;
+
 		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
 				jobVertexID1,
 				parallelism1,
@@ -2190,18 +2202,20 @@ public class CheckpointCoordinatorTest {
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 
-		List<KeyGroupRange> keyGroupPartitions1 = 
-				CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
-		List<KeyGroupRange> keyGroupPartitions2 = 
-				CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+		List<KeyGroupRange> keyGroupPartitions1 =
+				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 =
+				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
+		//vertex 1
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
-			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8);
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
+			ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false);
+			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
+			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true);
 
 
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, partitionableState, keyGroupState);
+			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -2211,13 +2225,19 @@ public class CheckpointCoordinatorTest {
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
 
-
-		final List<ChainedStateHandle<OperatorStateHandle>> originalPartitionableStates = new ArrayList<>(jobVertex2.getParallelism());
+		//vertex 2
+		final List<ChainedStateHandle<OperatorStateHandle>> expectedOpStatesBackend = new ArrayList<>(jobVertex2.getParallelism());
+		final List<ChainedStateHandle<OperatorStateHandle>> expectedOpStatesRaw = new ArrayList<>(jobVertex2.getParallelism());
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8);
-			originalPartitionableStates.add(partitionableState);
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(null, partitionableState, keyGroupState);
+			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
+			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true);
+			ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false);
+			ChainedStateHandle<OperatorStateHandle> opStateRaw = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, true);
+			expectedOpStatesBackend.add(opStateBackend);
+			expectedOpStatesRaw.add(opStateRaw);
+			SubtaskState checkpointStateHandles =
+					new SubtaskState(new ChainedStateHandle<>(
+							Collections.<StreamStateHandle>singletonList(null)), opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -2233,16 +2253,15 @@ public class CheckpointCoordinatorTest {
 
 		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
 
-		int newParallelism2 = 13;
-
-		List<KeyGroupRange> newKeyGroupPartitions2 = 
-				CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, newParallelism2);
+		List<KeyGroupRange> newKeyGroupPartitions2 =
+				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, newParallelism2);
 
 		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
 				jobVertexID1,
 				parallelism1,
 				maxParallelism1);
 
+		// rescale vertex 2
 		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
 				jobVertexID2,
 				newParallelism2,
@@ -2254,19 +2273,28 @@ public class CheckpointCoordinatorTest {
 
 		// verify the restored state
 		verifiyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
-		List<List<Collection<OperatorStateHandle>>> actualPartitionableStates = new ArrayList<>(newJobVertex2.getParallelism());
+		List<List<Collection<OperatorStateHandle>>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism());
+		List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism());
 		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
-			List<KeyGroupsStateHandle> originalKeyGroupState = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i));
+			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false);
+			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true);
+
+			TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
 
-			ChainedStateHandle<StreamStateHandle> operatorState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
-			List<Collection<OperatorStateHandle>> partitionableState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
-			List<KeyGroupsStateHandle> keyGroupState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
+			ChainedStateHandle<StreamStateHandle> operatorState = taskStateHandles.getLegacyOperatorState();
+			List<Collection<OperatorStateHandle>> opStateBackend = taskStateHandles.getManagedOperatorState();
+			List<Collection<OperatorStateHandle>> opStateRaw = taskStateHandles.getRawOperatorState();
+			Collection<KeyGroupsStateHandle> keyGroupStateBackend = taskStateHandles.getManagedKeyedState();
+			Collection<KeyGroupsStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState();
 
-			actualPartitionableStates.add(partitionableState);
+			actualOpStatesBackend.add(opStateBackend);
+			actualOpStatesRaw.add(opStateRaw);
 			assertNull(operatorState);
-			compareKeyPartitionedState(originalKeyGroupState, keyGroupState);
+			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyGroupStateBackend);
+			compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
 		}
-		comparePartitionableState(originalPartitionableStates, actualPartitionableStates);
+		comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend);
+		comparePartitionableState(expectedOpStatesRaw, actualOpStatesRaw);
 	}
 
 	/**
@@ -2320,15 +2348,41 @@ public class CheckpointCoordinatorTest {
 	//  Utilities
 	// ------------------------------------------------------------------------
 
-	public static List<KeyGroupsStateHandle> generateKeyGroupState(
+	static void sendAckMessageToCoordinator(
+			CheckpointCoordinator coord,
+			long checkpointId, JobID jid,
+			ExecutionJobVertex jobVertex,
+			JobVertexID jobVertexID,
+			List<KeyGroupRange> keyGroupPartitions) throws Exception {
+
+		for (int index = 0; index < jobVertex.getParallelism(); index++) {
+			ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID, index);
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(
+					jobVertexID,
+					keyGroupPartitions.get(index), false);
+
+			SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null, 0);
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+					jid,
+					jobVertex.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					new CheckpointMetaData(checkpointId, 0L),
+					checkpointStateHandles);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+	}
+
+	public static KeyGroupsStateHandle generateKeyGroupState(
 			JobVertexID jobVertexID,
-			KeyGroupRange keyGroupPartition) throws IOException {
+			KeyGroupRange keyGroupPartition, boolean rawState) throws IOException {
 
 		List<Integer> testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups());
 
 		// generate state for one keygroup
 		for (int keyGroupIndex : keyGroupPartition) {
-			Random random = new Random(jobVertexID.hashCode() + keyGroupIndex);
+			int vertexHash = jobVertexID.hashCode();
+			int seed = rawState ? (vertexHash * (31 + keyGroupIndex)) : (vertexHash + keyGroupIndex);
+			Random random = new Random(seed);
 			int simulatedStateValue = random.nextInt();
 			testStatesLists.add(simulatedStateValue);
 		}
@@ -2336,7 +2390,7 @@ public class CheckpointCoordinatorTest {
 		return generateKeyGroupState(keyGroupPartition, testStatesLists);
 	}
 
-	public static List<KeyGroupsStateHandle> generateKeyGroupState(
+	public static KeyGroupsStateHandle generateKeyGroupState(
 			KeyGroupRange keyGroupRange,
 			List<? extends Serializable> states) throws IOException {
 
@@ -2353,9 +2407,7 @@ public class CheckpointCoordinatorTest {
 		KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(
 				keyGroupRangeOffsets,
 				allSerializedStatesHandle);
-		List<KeyGroupsStateHandle> keyGroupsStateHandleList = new ArrayList<>();
-		keyGroupsStateHandleList.add(keyGroupsStateHandle);
-		return keyGroupsStateHandleList;
+		return keyGroupsStateHandle;
 	}
 
 	public static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets(
@@ -2412,14 +2464,19 @@ public class CheckpointCoordinatorTest {
 			JobVertexID jobVertexID,
 			int index,
 			int namedStates,
-			int partitionsPerState) throws IOException {
+			int partitionsPerState,
+			boolean rawState) throws IOException {
 
 		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
 
 		for (int i = 0; i < namedStates; ++i) {
 			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
 			// generate state
-			Random random = new Random(jobVertexID.hashCode() * index + i * namedStates);
+			int seed = jobVertexID.hashCode() * index + i * namedStates;
+			if (rawState) {
+				seed = (seed + 1) * 31;
+			}
+			Random random = new Random(seed);
 			for (int j = 0; j < partitionsPerState; ++j) {
 				int simulatedStateValue = random.nextInt();
 				testStatesLists.add(simulatedStateValue);
@@ -2454,7 +2511,7 @@ public class CheckpointCoordinatorTest {
 				serializationWithOffsets.f0);
 
 		OperatorStateHandle operatorStateHandle =
-				new OperatorStateHandle(streamStateHandle, offsetsMap);
+				new OperatorStateHandle(offsetsMap, streamStateHandle);
 		return ChainedStateHandle.wrapSingleHandle(operatorStateHandle);
 	}
 
@@ -2528,37 +2585,35 @@ public class CheckpointCoordinatorTest {
 
 		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
 
+			TaskStateHandles taskStateHandles = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
+
 			ChainedStateHandle<StreamStateHandle> expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
-			ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = executionJobVertex.
-					getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
+			ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = taskStateHandles.getLegacyOperatorState();
 			assertTrue(CommonTestUtils.isSteamContentEqual(
 					expectNonPartitionedState.get(0).openInputStream(),
 					actualNonPartitionedState.get(0).openInputStream()));
 
-			ChainedStateHandle<OperatorStateHandle> expectedPartitionableState =
-					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8);
+			ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend =
+					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);
 
-			List<Collection<OperatorStateHandle>> actualPartitionableState = executionJobVertex.
-					getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
+			List<Collection<OperatorStateHandle>> actualPartitionableState = taskStateHandles.getManagedOperatorState();
 
 			assertTrue(CommonTestUtils.isSteamContentEqual(
-					expectedPartitionableState.get(0).openInputStream(),
+					expectedOpStateBackend.get(0).openInputStream(),
 					actualPartitionableState.get(0).iterator().next().openInputStream()));
 
-			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState = generateKeyGroupState(
-					jobVertexID,
-					keyGroupPartitions.get(i));
-			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState = executionJobVertex.
-					getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
-			compareKeyPartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState);
+			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
+					jobVertexID, keyGroupPartitions.get(i), false);
+			Collection<KeyGroupsStateHandle> actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
+			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), actualPartitionedKeyGroupState);
 		}
 	}
 
-	public static void compareKeyPartitionedState(
-			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
-			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState) throws Exception {
+	public static void compareKeyedState(
+			Collection<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
+			Collection<KeyGroupsStateHandle> actualPartitionedKeyGroupState) throws Exception {
 
-		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.get(0);
+		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next();
 		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getNumberOfKeyGroups();
 		int actualTotalKeyGroups = 0;
 		for(KeyGroupsStateHandle keyGroupsStateHandle: actualPartitionedKeyGroupState) {
@@ -2576,13 +2631,10 @@ public class CheckpointCoordinatorTest {
 				for (KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) {
 					if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
 						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-						try (FSDataInputStream actualInputStream =
-								     oneActualKeyGroupStateHandle.openInputStream()) {
+						try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
 							actualInputStream.seek(actualOffset);
-
 							int actualGroupState = InstantiationUtil.
 									deserializeObject(actualInputStream, Thread.currentThread().getContextClassLoader());
-
 							assertEquals(expectedKeyGroupState, actualGroupState);
 						}
 					}
@@ -2599,16 +2651,7 @@ public class CheckpointCoordinatorTest {
 		for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) {
 			for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
 				OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i);
-				try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
-					for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
-						for (long offset : entry.getValue()) {
-							in.seek(offset);
-							Integer state = InstantiationUtil.
-									deserializeObject(in, Thread.currentThread().getContextClassLoader());
-							expectedResult.add(i + " : " + entry.getKey() + " : " + state);
-						}
-					}
-				}
+				collectResult(i, operatorStateHandle, expectedResult);
 			}
 		}
 		Collections.sort(expectedResult);
@@ -2618,25 +2661,32 @@ public class CheckpointCoordinatorTest {
 			if (collectionList != null) {
 				for (int i = 0; i < collectionList.size(); ++i) {
 					Collection<OperatorStateHandle> stateHandles = collectionList.get(i);
+					Assert.assertNotNull(stateHandles);
 					for (OperatorStateHandle operatorStateHandle : stateHandles) {
-						try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
-							for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
-								for (long offset : entry.getValue()) {
-									in.seek(offset);
-									Integer state = InstantiationUtil.
-											deserializeObject(in, Thread.currentThread().getContextClassLoader());
-									actualResult.add(i + " : " + entry.getKey() + " : " + state);
-								}
-							}
-						}
+						collectResult(i, operatorStateHandle, actualResult);
 					}
 				}
 			}
 		}
+
 		Collections.sort(actualResult);
 		Assert.assertEquals(expectedResult, actualResult);
 	}
 
+	private static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception {
+		try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
+			for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+				for (long offset : entry.getValue()) {
+					in.seek(offset);
+					Integer state = InstantiationUtil.
+							deserializeObject(in, Thread.currentThread().getContextClassLoader());
+					resultCollector.add(opIdx + " : " + entry.getKey() + " : " + state);
+				}
+			}
+		}
+	}
+
 	@Test
 	public void testCreateKeyGroupPartitions() {
 		testCreateKeyGroupPartitions(1, 1);
@@ -2697,7 +2747,7 @@ public class CheckpointCoordinatorTest {
 	}
 
 	private void testCreateKeyGroupPartitions(int maxParallelism, int parallelism) {
-		List<KeyGroupRange> ranges = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism, parallelism);
+		List<KeyGroupRange> ranges = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism, parallelism);
 		for (int i = 0; i < maxParallelism; ++i) {
 			KeyGroupRange range = ranges.get(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, i));
 			if (!range.contains(i)) {
@@ -2743,7 +2793,7 @@ public class CheckpointCoordinatorTest {
 			}
 
 			previousParallelOpInstanceStates.add(
-					new OperatorStateHandle(new FileStateHandle(fakePath, -1), namedStatesToOffsets));
+					new OperatorStateHandle(namedStatesToOffsets, new FileStateHandle(fakePath, -1)));
 		}
 
 		Map<StreamStateHandle, Map<String, List<Long>>> expected = new HashMap<>();

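A note on the key-group partitioning these tests exercise through
StateAssignmentOperation.createKeyGroupPartitions (moved out of
CheckpointCoordinator by this commit): the key-group space [0, maxParallelism)
is split into one contiguous range per subtask. Below is a minimal standalone
sketch of that assignment, assuming the start/end formulas of
KeyGroupRangeAssignment; the class and method names are illustrative, not the
code under test.

	import java.util.ArrayList;
	import java.util.List;

	public class KeyGroupPartitionSketch {

		// Sketch: partition key groups [0, maxParallelism) into `parallelism`
		// contiguous, inclusive ranges {start, end}, one per operator subtask.
		static List<int[]> sketchKeyGroupPartitions(int maxParallelism, int parallelism) {
			List<int[]> ranges = new ArrayList<>(parallelism);
			for (int operatorIndex = 0; operatorIndex < parallelism; ++operatorIndex) {
				int start = (operatorIndex * maxParallelism + parallelism - 1) / parallelism;
				int end = ((operatorIndex + 1) * maxParallelism - 1) / parallelism;
				ranges.add(new int[]{start, end});
			}
			return ranges;
		}

		// Inverse mapping, as asserted by testCreateKeyGroupPartitions above.
		static int operatorIndexForKeyGroup(int maxParallelism, int parallelism, int keyGroup) {
			return keyGroup * parallelism / maxParallelism;
		}
	}

For maxParallelism2 = 13, rescaling vertex 2 from parallelism 2 to 13 (the
scale-out case above) moves from the ranges {[0,6], [7,12]} to one key group
per subtask, which is what the restore verification re-generates with
newKeyGroupPartitions2.
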
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 950526c..359262f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -29,17 +29,18 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.util.SerializableObject;
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
 import org.junit.Test;
 import org.mockito.Mockito;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -65,7 +66,7 @@ public class CheckpointStateRestoreTest {
 			final ChainedStateHandle<StreamStateHandle> serializedState = CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject());
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
 			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
-			final List<KeyGroupsStateHandle> serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+			final KeyGroupsStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
 
 			final JobID jid = new JobID();
 			final JobVertexID statefulId = new JobVertexID();
@@ -115,7 +116,7 @@ public class CheckpointStateRestoreTest {
 			PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next();
 			final long checkpointId = pending.getCheckpointId();
 
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates);
+			SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null, 0L);
 			CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointMetaData, checkpointStateHandles));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointMetaData, checkpointStateHandles));
@@ -131,26 +132,33 @@ public class CheckpointStateRestoreTest {
 
 			// verify that each stateful vertex got the state
 
-			BaseMatcher<CheckpointStateHandles> matcher = new BaseMatcher<CheckpointStateHandles>() {
+			final TaskStateHandles taskStateHandles = new TaskStateHandles(
+					serializedState,
+					Collections.<Collection<OperatorStateHandle>>singletonList(null),
+					Collections.<Collection<OperatorStateHandle>>singletonList(null),
+					Collections.singletonList(serializedKeyGroupStates),
+					null);
+
+			BaseMatcher<TaskStateHandles> matcher = new BaseMatcher<TaskStateHandles>() {
 				@Override
 				public boolean matches(Object o) {
-					if (o instanceof CheckpointStateHandles) {
-						return ((CheckpointStateHandles) o).getNonPartitionedStateHandles().equals(serializedState);
+					if (o instanceof TaskStateHandles) {
+						return o.equals(taskStateHandles);
 					}
 					return false;
 				}
 
 				@Override
 				public void describeTo(Description description) {
-					description.appendValue(serializedState);
+					description.appendValue(taskStateHandles);
 				}
 			};
 
-			verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
-			verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
-			verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
-			verify(statelessExec1, times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), Mockito.<List<Collection<OperatorStateHandle>>>any());
-			verify(statelessExec2, times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), Mockito.<List<Collection<OperatorStateHandle>>>any());
+			verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher));
+			verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher));
+			verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher));
+			verify(statelessExec1, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
+			verify(statelessExec2, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -164,7 +172,7 @@ public class CheckpointStateRestoreTest {
 			final ChainedStateHandle<StreamStateHandle> serializedState = CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject());
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
 			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
-			final List<KeyGroupsStateHandle> serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+			final KeyGroupsStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
 
 			final JobID jid = new JobID();
 			final JobVertexID statefulId = new JobVertexID();
@@ -215,7 +223,8 @@ public class CheckpointStateRestoreTest {
 			final long checkpointId = pending.getCheckpointId();
 
 			// the difference to the test "testSetState" is that one stateful subtask does not report state
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates);
+			SubtaskState checkpointStateHandles =
+					new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null, 0L);
 
 			CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 

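For readers tracking the new state shapes used in the two files above: a
subtask now acknowledges with one SubtaskState bundling legacy, managed/raw
operator, and managed/raw keyed state, and Execution.setInitialState receives
one TaskStateHandles with the matching slots. A minimal sketch of the argument
order, inferred from the call sites in this commit (the slot comments are
assumptions read off those call sites, not the documented signature):

	// Ack side: everything null except legacy chained state and managed keyed
	// state, mirroring the simple restore test above.
	SubtaskState subtaskState = new SubtaskState(
			serializedState,          // ChainedStateHandle<StreamStateHandle>: legacy operator state
			null,                     // ChainedStateHandle<OperatorStateHandle>: managed operator state
			null,                     // ChainedStateHandle<OperatorStateHandle>: raw operator state
			serializedKeyGroupStates, // KeyGroupsStateHandle: managed keyed state
			null,                     // KeyGroupsStateHandle: raw keyed state
			0L);                      // duration

	// Restore side: one collection per chained operator / keyed backend.
	TaskStateHandles taskStateHandles = new TaskStateHandles(
			serializedState,
			Collections.<Collection<OperatorStateHandle>>singletonList(null), // managed operator
			Collections.<Collection<OperatorStateHandle>>singletonList(null), // raw operator
			Collections.singletonList(serializedKeyGroupStates),              // managed keyed
			null);                                                            // raw keyed
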
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index baa0e08..6b0d3f8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -206,7 +206,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 			ChainedStateHandle<StreamStateHandle> stateHandle = CheckpointCoordinatorTest.generateChainedStateHandle(
 					new CheckpointMessagesTest.MyHandle());
 
-			taskState.putState(i, new SubtaskState(stateHandle, 0));
+			taskState.putState(i, new SubtaskState(stateHandle, null, null, null, null, 0L));
 		}
 
 		return new TestCompletedCheckpoint(new JobID(), id, 0, taskGroupStates, props);

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java
index bad836b..508a69d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java
@@ -19,11 +19,13 @@
 package org.apache.flink.runtime.checkpoint.savepoint;
 
 import org.apache.commons.io.output.ByteArrayOutputStream;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.junit.Test;
 
 import java.io.ByteArrayInputStream;
+import java.util.Random;
 
 import static org.junit.Assert.assertEquals;
 
@@ -34,19 +36,23 @@ public class SavepointV1SerializerTest {
 	 */
 	@Test
 	public void testSerializeDeserializeV1() throws Exception {
-		SavepointV1 expected = new SavepointV1(123123, SavepointV1Test.createTaskStates(8, 32));
+		Random r = new Random(42);
+		for (int i = 0; i < 100; ++i) {
+			SavepointV1 expected =
+					new SavepointV1(i + 123123, SavepointV1Test.createTaskStates(1 + r.nextInt(64), 1 + r.nextInt(64)));
 
-		SavepointV1Serializer serializer = SavepointV1Serializer.INSTANCE;
+			SavepointV1Serializer serializer = SavepointV1Serializer.INSTANCE;
 
-		// Serialize
-		ByteArrayOutputStream baos = new ByteArrayOutputStream();
-		serializer.serialize(expected, new DataOutputViewStreamWrapper(baos));
-		byte[] bytes = baos.toByteArray();
+			// Serialize
+			ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos();
+			serializer.serialize(expected, new DataOutputViewStreamWrapper(baos));
+			byte[] bytes = baos.toByteArray();
 
-		// Deserialize
-		ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
-		Savepoint actual = serializer.deserialize(new DataInputViewStreamWrapper(bais));
+			// Deserialize
+			ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
+			Savepoint actual = serializer.deserialize(new DataInputViewStreamWrapper(bais));
 
-		assertEquals(expected, actual);
+			assertEquals(expected, actual);
+		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
index e38e5fb..1ae74ff 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
@@ -32,10 +32,10 @@ import org.junit.Test;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.concurrent.ThreadLocalRandom;
 
 import static org.junit.Assert.assertEquals;
@@ -66,35 +66,83 @@ public class SavepointV1Test {
 		assertTrue(savepoint.getTaskStates().isEmpty());
 	}
 
-	static Collection<TaskState> createTaskStates(int numTaskStates, int numSubtaskStates) throws IOException {
+	static Collection<TaskState> createTaskStates(int numTaskStates, int numSubtasksPerTask) throws IOException {
+
+		Random random = new Random(numTaskStates * 31 + numSubtasksPerTask);
+
 		List<TaskState> taskStates = new ArrayList<>(numTaskStates);
 
-		for (int i = 0; i < numTaskStates; i++) {
-			TaskState taskState = new TaskState(new JobVertexID(), numSubtaskStates, numSubtaskStates, 1);
-			for (int j = 0; j < numSubtaskStates; j++) {
-				StreamStateHandle stateHandle = new TestByteStreamStateHandleDeepCompare("a", "Hello".getBytes());
-				taskState.putState(i, new SubtaskState(
-						new ChainedStateHandle<>(Collections.singletonList(stateHandle)), 0));
-
-				stateHandle = new TestByteStreamStateHandleDeepCompare("b", "Beautiful".getBytes());
-				Map<String, long[]> offsetsMap = new HashMap<>();
-				offsetsMap.put("A", new long[]{0, 10, 20});
-				offsetsMap.put("B", new long[]{30, 40, 50});
-
-				OperatorStateHandle operatorStateHandle =
-						new OperatorStateHandle(stateHandle, offsetsMap);
-
-				taskState.putPartitionableState(
-						i,
-						new ChainedStateHandle<OperatorStateHandle>(
-								Collections.singletonList(operatorStateHandle)));
-			}
+		for (int stateIdx = 0; stateIdx < numTaskStates; ++stateIdx) {
+
+			int chainLength = 1 + random.nextInt(8);
+
+			TaskState taskState = new TaskState(new JobVertexID(), numSubtasksPerTask, 128, chainLength);
+
+			int noNonPartitionableStateAtIndex = random.nextInt(chainLength);
+			int noOperatorStateBackendAtIndex = random.nextInt(chainLength);
+			int noOperatorStateStreamAtIndex = random.nextInt(chainLength);
+
+			boolean hasKeyedBackend = random.nextInt(4) != 0;
+			boolean hasKeyedStream = random.nextInt(4) != 0;
+
+			for (int subtaskIdx = 0; subtaskIdx < numSubtasksPerTask; subtaskIdx++) {
 
-			taskState.putKeyedState(
-					0,
-					new KeyGroupsStateHandle(
+				List<StreamStateHandle> nonPartitionableStates = new ArrayList<>(chainLength);
+				List<OperatorStateHandle> operatorStatesBackend = new ArrayList<>(chainLength);
+				List<OperatorStateHandle> operatorStatesStream = new ArrayList<>(chainLength);
+
+				for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) {
+
+					StreamStateHandle nonPartitionableState =
+							new TestByteStreamStateHandleDeepCompare("a-" + chainIdx, ("Hi-" + chainIdx).getBytes());
+					StreamStateHandle operatorStateBackend =
+							new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes());
+					StreamStateHandle operatorStateStream =
+							new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes());
+					Map<String, long[]> offsetsMap = new HashMap<>();
+					offsetsMap.put("A", new long[]{0, 10, 20});
+					offsetsMap.put("B", new long[]{30, 40, 50});
+
+					if (chainIdx != noNonPartitionableStateAtIndex) {
+						nonPartitionableStates.add(nonPartitionableState);
+					}
+
+					if (chainIdx != noOperatorStateBackendAtIndex) {
+						OperatorStateHandle operatorStateHandleBackend =
+								new OperatorStateHandle(offsetsMap, operatorStateBackend);
+						operatorStatesBackend.add(operatorStateHandleBackend);
+					}
+
+					if (chainIdx != noOperatorStateStreamAtIndex) {
+						OperatorStateHandle operatorStateHandleStream =
+								new OperatorStateHandle(offsetsMap, operatorStateStream);
+						operatorStatesStream.add(operatorStateHandleStream);
+					}
+				}
+
+				KeyGroupsStateHandle keyedStateBackend = null;
+				KeyGroupsStateHandle keyedStateStream = null;
+
+				if (hasKeyedBackend) {
+					keyedStateBackend = new KeyGroupsStateHandle(
 							new KeyGroupRangeOffsets(1, 1, new long[]{42}),
-							new TestByteStreamStateHandleDeepCompare("c", "World".getBytes())));
+							new TestByteStreamStateHandleDeepCompare("c", "Hello".getBytes()));
+				}
+
+				if (hasKeyedStream) {
+					keyedStateStream = new KeyGroupsStateHandle(
+							new KeyGroupRangeOffsets(1, 1, new long[]{23}),
+							new TestByteStreamStateHandleDeepCompare("d", "World".getBytes()));
+				}
+
+				taskState.putState(subtaskIdx, new SubtaskState(
+						new ChainedStateHandle<>(nonPartitionableStates),
+						new ChainedStateHandle<>(operatorStatesBackend),
+						new ChainedStateHandle<>(operatorStatesStream),
+						keyedStateBackend,
+						keyedStateStream,
+						subtaskIdx * 10L));
+			}
 
 			taskStates.add(taskState);
 		}

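Note also the flipped OperatorStateHandle constructor used throughout this
commit: the state-name-to-offsets map now comes first and the delegate stream
handle second. A minimal sketch with the illustrative offsets carried over
from the old test body (streamStateHandle stands in for any available
StreamStateHandle):

	import java.util.HashMap;
	import java.util.Map;

	// Each named operator state maps to the offsets of its redistributable
	// partitions inside the delegate stream.
	Map<String, long[]> offsetsMap = new HashMap<>();
	offsetsMap.put("A", new long[]{0, 10, 20});
	offsetsMap.put("B", new long[]{30, 40, 50});

	// New argument order per this commit: (offsetsMap, streamStateHandle).
	OperatorStateHandle operatorStateHandle =
			new OperatorStateHandle(offsetsMap, streamStateHandle);
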
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
index 2dac87f..50a59a5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
@@ -335,8 +335,7 @@ public class SimpleCheckpointStatsTrackerTest {
 					StreamStateHandle proxy = new StateHandleProxy(new Path(), proxySize);
 
 					SubtaskState subtaskState = new SubtaskState(
-						new ChainedStateHandle<>(Collections.singletonList(proxy)),
-						duration);
+							new ChainedStateHandle<>(Collections.singletonList(proxy)), null, null, null, null, duration);
 
 					taskState.putState(subtaskIndex, subtaskState);
 				}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index 5ec6991..b195858 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
 import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
 import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager;
 import org.apache.flink.runtime.executiongraph.restart.FixedDelayRestartStrategy;
@@ -57,10 +58,8 @@ import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService;
 import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService;
 import org.apache.flink.runtime.messages.JobManagerMessages;
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManager;
 import org.apache.flink.runtime.testingUtils.TestingJobManager;
@@ -85,7 +84,6 @@ import scala.concurrent.duration.FiniteDuration;
 
 import java.util.ArrayDeque;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -446,12 +444,10 @@ public class JobManagerHARecoveryTest {
 
 		@Override
 		public void setInitialState(
-				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState,
-				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
+				TaskStateHandles taskStateHandles) throws Exception {
 			int subtaskIndex = getIndexInSubtaskGroup();
 			if (subtaskIndex < recoveredStates.length) {
-				try (FSDataInputStream in = chainedState.get(0).openInputStream()) {
+				try (FSDataInputStream in = taskStateHandles.getLegacyOperatorState().get(0).openInputStream()) {
 					recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in, getUserCodeClassLoader());
 				}
 			}
@@ -466,9 +462,8 @@ public class JobManagerHARecoveryTest {
 
 				ChainedStateHandle<StreamStateHandle> chainedStateHandle =
 						new ChainedStateHandle<StreamStateHandle>(Collections.singletonList(byteStreamStateHandle));
-
-				CheckpointStateHandles checkpointStateHandles =
-						new CheckpointStateHandles(chainedStateHandle, null, Collections.<KeyGroupsStateHandle>emptyList());
+				SubtaskState checkpointStateHandles =
+						new SubtaskState(chainedStateHandle, null, null, null, null, 0L);
 
 				getEnvironment().acknowledgeCheckpoint(
 						new CheckpointMetaData(checkpointMetaData.getCheckpointId(), -1, 0L, 0L, 0L, 0L),

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index 305625e..3521630 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -23,12 +23,12 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.junit.Test;
@@ -67,11 +67,14 @@ public class CheckpointMessagesTest {
 
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42);
 
-			CheckpointStateHandles checkpointStateHandles =
-					new CheckpointStateHandles(
+			SubtaskState checkpointStateHandles =
+					new SubtaskState(
 							CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()),
-							CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8),
-							CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())));
+							CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
+							null,
+							CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
+							null,
+							0L);
 
 			AcknowledgeCheckpoint withState = new AcknowledgeCheckpoint(
 					new JobID(),

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index 04ba4e5..f2616b5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -26,6 +26,7 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -37,7 +38,6 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 
 import java.util.Collections;
@@ -155,8 +155,7 @@ public class DummyEnvironment implements Environment {
 	}
 
 	@Override
-	public void acknowledgeCheckpoint(
-			CheckpointMetaData checkpointMetaData, CheckpointStateHandles checkpointStateHandles) {
+	public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData, SubtaskState subtaskState) {
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index eb55c4d..08b84cb 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -28,6 +28,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -46,7 +47,6 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.apache.flink.types.Record;
 import org.apache.flink.util.MutableObjectIterator;
@@ -316,8 +316,7 @@ public class MockEnvironment implements Environment {
 	}
 
 	@Override
-	public void acknowledgeCheckpoint(
-			CheckpointMetaData checkpointMetaData, CheckpointStateHandles checkpointStateHandles) {
+	public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData, SubtaskState subtaskState) {
 	}
 
 	@Override

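Both mock environments now accept the consolidated SubtaskState in place of the removed
CheckpointStateHandles and simply discard it. A test double that wants to assert on the
acknowledgement could capture it instead; a minimal sketch (field names are illustrative):

    private CheckpointMetaData lastCheckpoint;
    private SubtaskState lastSubtaskState;

    @Override
    public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData, SubtaskState subtaskState) {
        // keep the last acknowledgement around for assertions
        this.lastCheckpoint = checkpointMetaData;
        this.lastSubtaskState = subtaskState;
    }
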
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java
index 95564cc..fb24712 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java
@@ -45,7 +45,7 @@ public class KeyGroupRangeOffsetTest {
 				keyGroupRangeOffsets.getKeyGroupRange()));
 
 		intersection = keyGroupRangeOffsets.getIntersection(KeyGroupRange.of(11, 13));
-		Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, intersection.getKeyGroupRange());
+		Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, intersection.getKeyGroupRange());
 		Assert.assertFalse(intersection.iterator().hasNext());
 
 		intersection = keyGroupRangeOffsets.getIntersection(KeyGroupRange.of(5, 13));
@@ -129,7 +129,7 @@ public class KeyGroupRangeOffsetTest {
 			Assert.assertFalse(keyGroupRange.getKeyGroupRange().contains(startKeyGroup - 1));
 			Assert.assertFalse(keyGroupRange.getKeyGroupRange().contains(endKeyGroup + 1));
 		} else {
-			Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, keyGroupRange);
+			Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, keyGroupRange);
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java
index ab0c327..94350ad 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java
@@ -37,7 +37,7 @@ public class KeyGroupRangeTest {
 		keyGroupRange1 = KeyGroupRange.of(0, 5);
 		keyGroupRange2 = KeyGroupRange.of(6, 10);
 		intersection = keyGroupRange1.getIntersection(keyGroupRange2);
-		Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, intersection);
+		Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, intersection);
 		Assert.assertEquals(intersection, keyGroupRange2.getIntersection(keyGroupRange1));
 
 		keyGroupRange1 = KeyGroupRange.of(0, 10);
@@ -93,7 +93,7 @@ public class KeyGroupRangeTest {
 			Assert.assertFalse(keyGroupRange.contains(startKeyGroup - 1));
 			Assert.assertFalse(keyGroupRange.contains(endKeyGroup + 1));
 		} else {
-			Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, keyGroupRange);
+			Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, keyGroupRange);
 		}
 	}
 

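Only the constant is renamed here; the intersection contract the assertions rely on is
unchanged. A minimal illustration (sketch):

    KeyGroupRange intersection = KeyGroupRange.of(0, 5).getIntersection(KeyGroupRange.of(6, 10));
    // disjoint ranges collapse to the shared empty singleton under its new name
    Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, intersection);
    Assert.assertEquals(0, intersection.getNumberOfKeyGroups());
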
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
new file mode 100644
index 0000000..0c4ed74
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public class KeyedStateCheckpointOutputStreamTest {
+
+	private static final int STREAM_CAPACITY = 128;
+
+	private static KeyedStateCheckpointOutputStream createStream(KeyGroupRange keyGroupRange) {
+		CheckpointStreamFactory.CheckpointStateOutputStream checkStream =
+				new TestMemoryCheckpointOutputStream(STREAM_CAPACITY);
+		return new KeyedStateCheckpointOutputStream(checkStream, keyGroupRange);
+	}
+
+	private KeyGroupsStateHandle writeAllTestKeyGroups(
+			KeyedStateCheckpointOutputStream stream, KeyGroupRange keyRange) throws Exception {
+
+		DataOutputView dov = new DataOutputViewStreamWrapper(stream);
+		for (int kg : keyRange) {
+			stream.startNewKeyGroup(kg);
+			dov.writeInt(kg);
+		}
+
+		return stream.closeAndGetHandle();
+	}
+
+	@Test
+	public void testCloseNotPropagated() throws Exception {
+		KeyedStateCheckpointOutputStream stream = createStream(new KeyGroupRange(0, 0));
+		TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate();
+		stream.close();
+		Assert.assertFalse(innerStream.isClosed());
+	}
+
+	@Test
+	public void testEmptyKeyedStream() throws Exception {
+		final KeyGroupRange keyRange = new KeyGroupRange(0, 2);
+		KeyedStateCheckpointOutputStream stream = createStream(keyRange);
+		TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate();
+		KeyGroupsStateHandle emptyHandle = stream.closeAndGetHandle();
+		Assert.assertTrue(innerStream.isClosed());
+		Assert.assertNull(emptyHandle);
+	}
+
+	@Test
+	public void testWriteReadRoundtrip() throws Exception {
+		final KeyGroupRange keyRange = new KeyGroupRange(0, 2);
+		KeyedStateCheckpointOutputStream stream = createStream(keyRange);
+		KeyGroupsStateHandle fullHandle = writeAllTestKeyGroups(stream, keyRange);
+		Assert.assertNotNull(fullHandle);
+
+		verifyRead(fullHandle, keyRange);
+	}
+
+	@Test
+	public void testWriteKeyGroupTracking() throws Exception {
+		final KeyGroupRange keyRange = new KeyGroupRange(0, 2);
+		KeyedStateCheckpointOutputStream stream = createStream(keyRange);
+
+		try {
+			stream.startNewKeyGroup(4711);
+			Assert.fail("Expected IllegalArgumentException for a key group outside the assigned range.");
+		} catch (IllegalArgumentException expected) {
+			// expected: 4711 is outside the stream's key group range
+		}
+
+		Assert.assertEquals(-1, stream.getCurrentKeyGroup());
+
+		DataOutputView dov = new DataOutputViewStreamWrapper(stream);
+		int previous = -1;
+		for (int kg : keyRange) {
+			Assert.assertFalse(stream.isKeyGroupAlreadyStarted(kg));
+			Assert.assertFalse(stream.isKeyGroupAlreadyFinished(kg));
+			stream.startNewKeyGroup(kg);
+			if (-1 != previous) {
+				Assert.assertTrue(stream.isKeyGroupAlreadyStarted(previous));
+				Assert.assertTrue(stream.isKeyGroupAlreadyFinished(previous));
+			}
+			Assert.assertTrue(stream.isKeyGroupAlreadyStarted(kg));
+			Assert.assertFalse(stream.isKeyGroupAlreadyFinished(kg));
+			dov.writeInt(kg);
+			previous = kg;
+		}
+
+		KeyGroupsStateHandle fullHandle = stream.closeAndGetHandle();
+
+		verifyRead(fullHandle, keyRange);
+
+		for (int kg : keyRange) {
+			try {
+				stream.startNewKeyGroup(kg);
+				Assert.fail("Expected IOException: the stream was closed by closeAndGetHandle().");
+			} catch (IOException expected) {
+				// expected: no key group can be started on a closed stream
+			}
+		}
+	}
+
+	@Test
+	public void testReadWriteMissingKeyGroups() throws Exception {
+		final KeyGroupRange keyRange = new KeyGroupRange(0, 2);
+		KeyedStateCheckpointOutputStream stream = createStream(keyRange);
+
+		DataOutputView dov = new DataOutputViewStreamWrapper(stream);
+		stream.startNewKeyGroup(1);
+		dov.writeInt(1);
+
+		KeyGroupsStateHandle fullHandle = stream.closeAndGetHandle();
+
+		int count = 0;
+		try (FSDataInputStream in = fullHandle.openInputStream()) {
+			DataInputView div = new DataInputViewStreamWrapper(in);
+			for (int kg : fullHandle.keyGroups()) {
+				long off = fullHandle.getOffsetForKeyGroup(kg);
+				if (off >= 0) {
+					in.seek(off);
+					Assert.assertEquals(1, div.readInt());
+					++count;
+				}
+			}
+		}
+
+		Assert.assertEquals(1, count);
+	}
+
+	private static void verifyRead(KeyGroupsStateHandle fullHandle, KeyGroupRange keyRange) throws IOException {
+		int count = 0;
+		try (FSDataInputStream in = fullHandle.openInputStream()) {
+			DataInputView div = new DataInputViewStreamWrapper(in);
+			for (int kg : fullHandle.keyGroups()) {
+				long off = fullHandle.getOffsetForKeyGroup(kg);
+				in.seek(off);
+				Assert.assertEquals(kg, div.readInt());
+				++count;
+			}
+		}
+
+		Assert.assertEquals(keyRange.getNumberOfKeyGroups(), count);
+	}
+}
\ No newline at end of file

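The new test pins down the raw keyed stream contract: key groups are opened explicitly
with startNewKeyGroup(), close() is not propagated to the inner stream, and a stream
that never started a key group yields a null handle. A hedged sketch of the producing
side (method name and payload are illustrative):

    static KeyGroupsStateHandle writeRawKeyedState(
            KeyedStateCheckpointOutputStream out, KeyGroupRange assignedKeyGroups) throws Exception {

        DataOutputView dov = new DataOutputViewStreamWrapper(out);
        for (int keyGroup : assignedKeyGroups) {
            out.startNewKeyGroup(keyGroup); // each key group must be opened exactly once, in range
            dov.writeInt(keyGroup);         // illustrative payload for this key group
        }
        return out.closeAndGetHandle();     // null if no key group was ever started
    }
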
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
new file mode 100644
index 0000000..c6ef0f0
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public class OperatorStateOutputCheckpointStreamTest {
+
+	private static final int STREAM_CAPACITY = 128;
+
+	private static OperatorStateCheckpointOutputStream createStream() throws IOException {
+		CheckpointStreamFactory.CheckpointStateOutputStream checkStream =
+				new TestMemoryCheckpointOutputStream(STREAM_CAPACITY);
+		return new OperatorStateCheckpointOutputStream(checkStream);
+	}
+
+	private OperatorStateHandle writeAllTestPartitions(
+			OperatorStateCheckpointOutputStream stream, int numPartitions) throws Exception {
+
+		DataOutputView dov = new DataOutputViewStreamWrapper(stream);
+		for (int i = 0; i < numPartitions; ++i) {
+			Assert.assertEquals(i, stream.getNumberOfPartitions());
+			stream.startNewPartition();
+			dov.writeInt(i);
+		}
+
+		return stream.closeAndGetHandle();
+	}
+
+	@Test
+	public void testCloseNotPropagated() throws Exception {
+		OperatorStateCheckpointOutputStream stream = createStream();
+		TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate();
+		stream.close();
+		Assert.assertFalse(innerStream.isClosed());
+		innerStream.close();
+	}
+
+	@Test
+	public void testEmptyOperatorStream() throws Exception {
+		OperatorStateCheckpointOutputStream stream = createStream();
+		TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate();
+		OperatorStateHandle emptyHandle = stream.closeAndGetHandle();
+		Assert.assertTrue(innerStream.isClosed());
+		Assert.assertEquals(0, stream.getNumberOfPartitions());
+		Assert.assertNull(emptyHandle);
+	}
+
+	@Test
+	public void testWriteReadRoundtrip() throws Exception {
+		int numPartitions = 3;
+		OperatorStateCheckpointOutputStream stream = createStream();
+		OperatorStateHandle fullHandle = writeAllTestPartitions(stream, numPartitions);
+		Assert.assertNotNull(fullHandle);
+
+		verifyRead(fullHandle, numPartitions);
+	}
+
+	private static void verifyRead(OperatorStateHandle fullHandle, int numPartitions) throws IOException {
+		int count = 0;
+		try (FSDataInputStream in = fullHandle.openInputStream()) {
+			long[] offsets = fullHandle.getStateNameToPartitionOffsets()
+					.get(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
+
+			Assert.assertNotNull(offsets);
+
+			DataInputView div = new DataInputViewStreamWrapper(in);
+			for (int i = 0; i < numPartitions; ++i) {
+				in.seek(offsets[i]);
+				Assert.assertEquals(i, div.readInt());
+				++count;
+			}
+		}
+
+		Assert.assertEquals(numPartitions, count);
+	}
+
+}
\ No newline at end of file

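Mirroring the keyed variant, partitions are opened with startNewPartition() and their
offsets are registered under DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME.
A hedged sketch of the consuming side, generalizing verifyRead() above to all state
names (method name is illustrative):

    static void readAllPartitions(OperatorStateHandle handle) throws IOException {
        try (FSDataInputStream in = handle.openInputStream()) {
            DataInputView div = new DataInputViewStreamWrapper(in);
            for (Map.Entry<String, long[]> entry : handle.getStateNameToPartitionOffsets().entrySet()) {
                for (long offset : entry.getValue()) {
                    in.seek(offset);             // each offset marks one redistributable partition
                    int payload = div.readInt(); // illustrative payload, as written in the tests
                }
            }
        }
    }
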
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 2f21574..9e835ce 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -38,7 +38,7 @@ import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.runtime.checkpoint.CheckpointCoordinator;
+import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.KvStateID;
@@ -707,11 +707,11 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(0, 0, streamFactory));
 
-		List<KeyGroupsStateHandle> firstHalfKeyGroupStates = CheckpointCoordinator.getKeyGroupsStateHandles(
+		List<KeyGroupsStateHandle> firstHalfKeyGroupStates = StateAssignmentOperation.getKeyGroupsStateHandles(
 				Collections.singletonList(snapshot),
 				KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 0));
 
-		List<KeyGroupsStateHandle> secondHalfKeyGroupStates = CheckpointCoordinator.getKeyGroupsStateHandles(
+		List<KeyGroupsStateHandle> secondHalfKeyGroupStates = StateAssignmentOperation.getKeyGroupsStateHandles(
 				Collections.singletonList(snapshot),
 				KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 1));
 

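The key-group repartitioning helpers now live on StateAssignmentOperation instead of
CheckpointCoordinator; the calling convention is unchanged. Sketch of splitting one
snapshot across subtasks (variable names are illustrative):

    for (int subtaskIndex = 0; subtaskIndex < parallelism; subtaskIndex++) {
        KeyGroupRange range = KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
                maxParallelism, parallelism, subtaskIndex);
        List<KeyGroupsStateHandle> slice = StateAssignmentOperation.getKeyGroupsStateHandles(
                Collections.singletonList(snapshot), range);
        // 'slice' is what the subtask's keyed backend receives on restore
    }
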
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestMemoryCheckpointOutputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestMemoryCheckpointOutputStream.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestMemoryCheckpointOutputStream.java
new file mode 100644
index 0000000..5accc19
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestMemoryCheckpointOutputStream.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
+
+import java.io.IOException;
+
+final class TestMemoryCheckpointOutputStream extends MemCheckpointStreamFactory.MemoryCheckpointOutputStream {
+
+	private boolean closed;
+
+	public TestMemoryCheckpointOutputStream(int maxSize) {
+		super(maxSize);
+		this.closed = false;
+	}
+
+	@Override
+	public void close() {
+		this.closed = true;
+		super.close();
+	}
+
+	public boolean isClosed() {
+		return this.closed;
+	}
+
+	@Override
+	public StreamStateHandle closeAndGetHandle() throws IOException {
+		this.closed = true;
+		return super.closeAndGetHandle();
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index e2abe88..7dd67ed 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -48,6 +48,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.SerializedValue;
 import org.junit.Before;
 import org.junit.Test;
@@ -209,9 +210,7 @@ public class TaskAsyncCallTest {
 		}
 
 		@Override
-		public void setInitialState(ChainedStateHandle<StreamStateHandle> chainedState,
-									List<KeyGroupsStateHandle> keyGroupsState,
-									List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
+		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {
 
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java b/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java
index ac1e3f0..0c0111c 100644
--- a/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java
+++ b/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java
@@ -137,7 +137,7 @@ public class BucketingSinkTest {
 
 		// snapshot but don't call notify to simulate a notify that never
 		// arrives, the sink should move pending files in restore() in that case
-		StreamStateHandle snapshot1 = testHarness.snapshot(0, 0);
+		StreamStateHandle snapshot1 = testHarness.snapshotLegacy(0, 0);
 
 		testHarness = createTestSink(dataDir, clock);
 		testHarness.setup();

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
index 7d6bd76..db092f0 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
@@ -19,13 +19,16 @@ package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.commons.collections.map.LinkedMap;
 import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.ClosureCleaner;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.state.CheckpointListener;
-import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
@@ -37,11 +40,9 @@ import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition
 import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
 import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
 import org.apache.flink.util.SerializedValue;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
@@ -97,7 +98,7 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	 * The assigner is kept in serialized form, to deserialize it into multiple copies */
 	private SerializedValue<AssignerWithPunctuatedWatermarks<T>> punctuatedWatermarkAssigner;
 
-	private transient OperatorStateStore stateStore;
+	private transient ListState<Tuple2<KafkaTopicPartition, Long>> offsetsStateForCheckpoint;
 
 	// ------------------------------------------------------------------------
 	//  runtime state (used individually by each parallel subtask) 
@@ -311,33 +312,33 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	// ------------------------------------------------------------------------
 
 	@Override
-	public void initializeState(OperatorStateStore stateStore) throws Exception {
-
-		this.stateStore = stateStore;
+	public void initializeState(FunctionInitializationContext context) throws Exception {
 
-		ListState<Serializable> offsets =
-				stateStore.getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME);
+		OperatorStateStore stateStore = context.getManagedOperatorStateStore();
+		offsetsStateForCheckpoint = stateStore.getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
 
-		restoreToOffset = new HashMap<>();
+		if (context.isRestored()) {
+			restoreToOffset = new HashMap<>();
+			for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : offsetsStateForCheckpoint.get()) {
+				restoreToOffset.put(kafkaOffset.f0, kafkaOffset.f1);
+			}
 
-		for (Serializable serializable : offsets.get()) {
-			@SuppressWarnings("unchecked")
-			Tuple2<KafkaTopicPartition, Long> kafkaOffset = (Tuple2<KafkaTopicPartition, Long>) serializable;
-			restoreToOffset.put(kafkaOffset.f0, kafkaOffset.f1);
+			LOG.info("Setting restore state in the FlinkKafkaConsumer.");
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("Using the following offsets: {}", restoreToOffset);
+			}
+		} else {
+			LOG.info("No restore state for FlinkKafkaConsumer.");
 		}
-
-		LOG.info("Setting restore state in the FlinkKafkaConsumer: {}", restoreToOffset);
 	}
 
 	@Override
-	public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
+	public void snapshotState(FunctionSnapshotContext context) throws Exception {
 		if (!running) {
-			LOG.debug("storeOperatorState() called on closed source");
+			LOG.debug("snapshotState() called on closed source");
 		} else {
 
-			ListState<Serializable> listState =
-					stateStore.getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME);
-			listState.clear();
+			offsetsStateForCheckpoint.clear();
 
 			final AbstractFetcher<?, ?> fetcher = this.kafkaFetcher;
 			if (fetcher == null) {
@@ -347,14 +348,16 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 				if (restoreToOffset != null) {
 					// the map cannot be asynchronously updated, because only one checkpoint call can happen
 					// on this function at a time: either snapshotState() or notifyCheckpointComplete()
-					pendingOffsetsToCommit.put(checkpointId, restoreToOffset);
+					pendingOffsetsToCommit.put(context.getCheckpointId(), restoreToOffset);
 
 					for (Map.Entry<KafkaTopicPartition, Long> kafkaTopicPartitionLongEntry : restoreToOffset.entrySet()) {
-						listState.add(Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue()));
+						offsetsStateForCheckpoint.add(
+								Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue()));
 					}
 				} else if (subscribedPartitions != null) {
 					for (KafkaTopicPartition subscribedPartition : subscribedPartitions) {
-						listState.add(Tuple2.of(subscribedPartition, KafkaTopicPartitionState.OFFSET_NOT_SET));
+						offsetsStateForCheckpoint.add(
+								Tuple2.of(subscribedPartition, KafkaTopicPartitionState.OFFSET_NOT_SET));
 					}
 				}
 			} else {
@@ -362,10 +365,11 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 
 				// the map cannot be asynchronously updated, because only one checkpoint call can happen
 				// on this function at a time: either snapshotState() or notifyCheckpointComplete()
-				pendingOffsetsToCommit.put(checkpointId, currentOffsets);
+				pendingOffsetsToCommit.put(context.getCheckpointId(), currentOffsets);
 
 				for (Map.Entry<KafkaTopicPartition, Long> kafkaTopicPartitionLongEntry : currentOffsets.entrySet()) {
-					listState.add(Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue()));
+					offsetsStateForCheckpoint.add(
+							Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue()));
 				}
 			}
 

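The consumer is thereby ported to the reworked CheckpointedFunction contract:
initializeState(FunctionInitializationContext) replaces initializeState(OperatorStateStore)
and exposes isRestored(), while snapshotState(FunctionSnapshotContext) replaces
prepareSnapshot(long, long). Reduced to its essentials, the pattern looks roughly like
this (class and field names are illustrative):

    class OffsetTrackingFunction implements CheckpointedFunction {

        private transient ListState<Tuple2<KafkaTopicPartition, Long>> offsetsState;
        private final Map<KafkaTopicPartition, Long> currentOffsets = new HashMap<>();

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            offsetsState = context.getManagedOperatorStateStore()
                    .getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
            if (context.isRestored()) {
                // replay the offsets of the last completed checkpoint
                for (Tuple2<KafkaTopicPartition, Long> offset : offsetsState.get()) {
                    currentOffsets.put(offset.f0, offset.f1);
                }
            }
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            offsetsState.clear(); // rewrite the complete offset map on every checkpoint
            for (Map.Entry<KafkaTopicPartition, Long> e : currentOffsets.entrySet()) {
                offsetsState.add(Tuple2.of(e.getKey(), e.getValue()));
            }
        }
    }
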
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
index 26a695e..bede064 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
@@ -18,10 +18,12 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.metrics.MetricGroup;
-import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.util.SerializableObject;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
@@ -330,12 +332,12 @@ public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> im
 	protected abstract void flush();
 
 	@Override
-	public void initializeState(OperatorStateStore stateStore) throws Exception {
-		this.stateStore = stateStore;
+	public void initializeState(FunctionInitializationContext context) throws Exception {
+		this.stateStore = context.getManagedOperatorStateStore();
 	}
 
 	@Override
-	public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
+	public void snapshotState(FunctionSnapshotContext ctx) throws Exception {
 		if (flushOnCheckpoint) {
 			// flushing is activated: We need to wait until pendingRecords is 0
 			flush();