You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by se...@apache.org on 2016/09/30 12:47:55 UTC

[05/10] flink git commit: [FLINK-4379] [checkpoints] Introduce rescalable operator state

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 9adaa86..c39e436 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -20,7 +20,9 @@ package org.apache.flink.runtime.checkpoint;
 
 import com.google.common.collect.Iterables;
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.checkpoint.savepoint.HeapSavepointStore;
 import org.apache.flink.runtime.checkpoint.stats.DisabledCheckpointStatsTracker;
 import org.apache.flink.runtime.execution.ExecutionState;
@@ -34,21 +36,21 @@ import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
-
 import org.junit.Assert;
 import org.junit.Test;
-
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
-
 import scala.concurrent.ExecutionContext;
 import scala.concurrent.Future;
 
@@ -56,6 +58,8 @@ import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
@@ -1459,7 +1463,7 @@ public class CheckpointCoordinatorTest {
 					maxConcurrentAttempts,
 					new ExecutionVertex[] { triggerVertex },
 					new ExecutionVertex[] { ackVertex },
-					new ExecutionVertex[] { commitVertex }, 
+					new ExecutionVertex[] { commitVertex },
 					new StandaloneCheckpointIDCounter(),
 					new StandaloneCompletedCheckpointStore(2, cl),
 					new HeapSavepointStore(),
@@ -1531,7 +1535,7 @@ public class CheckpointCoordinatorTest {
 					maxConcurrentAttempts, // max two concurrent checkpoints
 					new ExecutionVertex[] { triggerVertex },
 					new ExecutionVertex[] { ackVertex },
-					new ExecutionVertex[] { commitVertex }, 
+					new ExecutionVertex[] { commitVertex },
 					new StandaloneCheckpointIDCounter(),
 					new StandaloneCompletedCheckpointStore(2, cl),
 					new HeapSavepointStore(),
@@ -1791,29 +1795,29 @@ public class CheckpointCoordinatorTest {
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID1, index);
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8);
 			List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
 
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-				jid,
-				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-				checkpointId,
-				nonPartitionedState,
-				partitionedKeyGroupState);
+					jid,
+					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
 
-
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID2, index);
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8);
 			List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-				jid,
-				jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-				checkpointId,
-				nonPartitionedState,
-				partitionedKeyGroupState);
+					jid,
+					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -1895,13 +1899,12 @@ public class CheckpointCoordinatorTest {
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
-
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-				jid,
-				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-				checkpointId,
-				valueSizeTuple,
-				keyGroupState);
+					jid,
+					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -1910,13 +1913,12 @@ public class CheckpointCoordinatorTest {
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID2, index);
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-				jid,
-				jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-				checkpointId,
-				valueSizeTuple,
-				keyGroupState);
+					jid,
+					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2014,12 +2016,12 @@ public class CheckpointCoordinatorTest {
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
 					jobVertexID1, keyGroupPartitions1.get(index));
 
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-				jid,
-				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-				checkpointId,
-				valueSizeTuple,
-				keyGroupState);
+					jid,
+					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2031,12 +2033,12 @@ public class CheckpointCoordinatorTest {
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
 					jobVertexID2, keyGroupPartitions2.get(index));
 
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(state, null, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
-					state,
-					keyGroupState);
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2132,28 +2134,32 @@ public class CheckpointCoordinatorTest {
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8);
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
 
+
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, partitionableState, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
-					valueSizeTuple,
-					keyGroupState);
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
 
 
+		final List<ChainedStateHandle<OperatorStateHandle>> originalPartitionableStates = new ArrayList<>(jobVertex2.getParallelism());
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8);
+			originalPartitionableStates.add(partitionableState);
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(null, partitionableState, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
-					null,
-					keyGroupState);
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2185,22 +2191,49 @@ public class CheckpointCoordinatorTest {
 
 		// verify the restored state
 		verifiyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
-
+		List<List<Collection<OperatorStateHandle>>> actualPartitionableStates = new ArrayList<>(newJobVertex2.getParallelism());
 		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
 			List<KeyGroupsStateHandle> originalKeyGroupState = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i));
 
 			ChainedStateHandle<StreamStateHandle> operatorState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
+			List<Collection<OperatorStateHandle>> partitionableState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
 			List<KeyGroupsStateHandle> keyGroupState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
 
+			actualPartitionableStates.add(partitionableState);
 			assertNull(operatorState);
-			comparePartitionedState(originalKeyGroupState, keyGroupState);
+			compareKeyPartitionedState(originalKeyGroupState, keyGroupState);
 		}
+		comparePartitionableState(originalPartitionableStates, actualPartitionableStates);
 	}
 
 	// ------------------------------------------------------------------------
 	//  Utilities
 	// ------------------------------------------------------------------------
 
+	/**
+	 * Acknowledges checkpoint {@code checkpointId} for every subtask of {@code jobVertex}, attaching
+	 * generated non-partitioned state and key-group state but no partitionable operator state.
+	 */
+	static void sendAckMessageToCoordinator(
+			CheckpointCoordinator coord,
+			long checkpointId, JobID jid,
+			ExecutionJobVertex jobVertex,
+			JobVertexID jobVertexID,
+			List<KeyGroupRange> keyGroupPartitions) throws Exception {
+
+		for (int subtask = 0; subtask < jobVertex.getParallelism(); ++subtask) {
+			CheckpointStateHandles handles = new CheckpointStateHandles(
+					generateStateForVertex(jobVertexID, subtask),
+					null, // no partitionable state
+					generateKeyGroupState(jobVertexID, keyGroupPartitions.get(subtask)));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(
+					jid,
+					jobVertex.getTaskVertices()[subtask].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					handles));
+		}
+	}
+
 	public static List<KeyGroupsStateHandle> generateKeyGroupState(
 			JobVertexID jobVertexID,
 			KeyGroupRange keyGroupPartition) throws IOException {
@@ -2217,23 +2250,45 @@ public class CheckpointCoordinatorTest {
 		return generateKeyGroupState(keyGroupPartition, testStatesLists);
 	}
 
-	public static List<KeyGroupsStateHandle> generateKeyGroupState(KeyGroupRange keyGroupRange, List< ? extends Serializable> states) throws IOException {
+	public static List<KeyGroupsStateHandle> generateKeyGroupState(
+			KeyGroupRange keyGroupRange,
+			List<? extends Serializable> states) throws IOException {
+
 		Preconditions.checkArgument(keyGroupRange.getNumberOfKeyGroups() == states.size());
 
-		long[] offsets = new long[keyGroupRange.getNumberOfKeyGroups()];
-		List<byte[]> serializedGroupValues = new ArrayList<>(offsets.length);
+		Tuple2<byte[], List<long[]>> serializedDataWithOffsets =
+				serializeTogetherAndTrackOffsets(Collections.<List<? extends Serializable>>singletonList(states));
+
+		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));
+
+		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
+				serializedDataWithOffsets.f0);
+		KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(
+				keyGroupRangeOffsets,
+				allSerializedStatesHandle);
+		List<KeyGroupsStateHandle> keyGroupsStateHandleList = new ArrayList<>();
+		keyGroupsStateHandleList.add(keyGroupsStateHandle);
+		return keyGroupsStateHandleList;
+	}
+
+	public static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets(
+			List<List<? extends Serializable>> serializables) throws IOException {
 
-		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, offsets);
+		List<long[]> offsets = new ArrayList<>(serializables.size());
+		List<byte[]> serializedGroupValues = new ArrayList<>();
 
 		int runningGroupsOffset = 0;
-		// generate test state for all keygroups
-		int idx = 0;
-		for (int keyGroup : keyGroupRange) {
-			keyGroupRangeOffsets.setKeyGroupOffset(keyGroup,runningGroupsOffset);
-			byte[] serializedValue = InstantiationUtil.serializeObject(states.get(idx));
-			runningGroupsOffset += serializedValue.length;
-			serializedGroupValues.add(serializedValue);
-			++idx;
+		for(List<? extends Serializable> list : serializables) {
+
+			long[] currentOffsets = new long[list.size()];
+			offsets.add(currentOffsets);
+
+			for (int i = 0; i < list.size(); ++i) {
+				currentOffsets[i] = runningGroupsOffset;
+				byte[] serializedValue = InstantiationUtil.serializeObject(list.get(i));
+				serializedGroupValues.add(serializedValue);
+				runningGroupsOffset += serializedValue.length;
+			}
 		}
 
 		//write all generated values in a single byte array, which is index by groupOffsetsInFinalByteArray
@@ -2248,15 +2303,7 @@ public class CheckpointCoordinatorTest {
 					serializedGroupValue.length);
 			runningGroupsOffset += serializedGroupValue.length;
 		}
-
-		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
-				allSerializedValuesConcatenated);
-		KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(
-				keyGroupRangeOffsets,
-				allSerializedStatesHandle);
-		List<KeyGroupsStateHandle> keyGroupsStateHandleList = new ArrayList<>();
-		keyGroupsStateHandleList.add(keyGroupsStateHandle);
-		return keyGroupsStateHandleList;
+		return new Tuple2<>(allSerializedValuesConcatenated, offsets);
 	}
 
 	public static ChainedStateHandle<StreamStateHandle> generateStateForVertex(
@@ -2273,6 +2320,55 @@ public class CheckpointCoordinatorTest {
 		return ChainedStateHandle.wrapSingleHandle(ByteStreamStateHandle.fromSerializable(value));
 	}
 
+	/**
+	 * Generates partitionable operator state for one subtask: {@code namedStates} named states
+	 * ("state-0", "state-1", ...), each holding {@code partitionsPerState} pseudo-random ints. The
+	 * seed is derived only from the arguments, so repeated calls with the same arguments reproduce the values.
+	 */
+	public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
+			JobVertexID jobVertexID, int index, int namedStates, int partitionsPerState) throws IOException {
+
+		Map<String, List<? extends Serializable>> statesByName = new HashMap<>(namedStates);
+
+		for (int stateIdx = 0; stateIdx < namedStates; ++stateIdx) {
+			List<Integer> values = new ArrayList<>(partitionsPerState);
+			Random random = new Random(jobVertexID.hashCode() * index + stateIdx * namedStates);
+			while (values.size() < partitionsPerState) {
+				values.add(random.nextInt());
+			}
+			statesByName.put("state-" + stateIdx, values);
+		}
+
+		return generateChainedPartitionableStateHandle(statesByName);
+	}
+
+	/**
+	 * Serializes all partitions of all named states into one byte array and wraps it in a single
+	 * {@link OperatorStateHandle} that maps each state name to the offsets of its partitions.
+	 */
+	public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
+			Map<String, List<? extends Serializable>> states) throws IOException {
+
+		// Record names and values in a single pass so that offsets line up with names by position.
+		List<String> names = new ArrayList<>(states.size());
+		List<List<? extends Serializable>> partitionLists = new ArrayList<>(states.size());
+		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
+			names.add(entry.getKey());
+			partitionLists.add(entry.getValue());
+		}
+
+		Tuple2<byte[], List<long[]>> bytesAndOffsets = serializeTogetherAndTrackOffsets(partitionLists);
+
+		Map<String, long[]> offsetsByName = new HashMap<>(states.size());
+		for (int n = 0; n < names.size(); ++n) {
+			offsetsByName.put(names.get(n), bytesAndOffsets.f1.get(n));
+		}
+
+		OperatorStateHandle handle = new OperatorStateHandle(
+				new ByteStreamStateHandle(bytesAndOffsets.f0), offsetsByName);
+		return ChainedStateHandle.wrapSingleHandle(handle);
+	}
+
 	public static ExecutionJobVertex mockExecutionJobVertex(
 		JobVertexID jobVertexID,
 		int parallelism,
@@ -2348,16 +2444,24 @@ public class CheckpointCoordinatorTest {
 					getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
 			assertEquals(expectNonPartitionedState.get(0), actualNonPartitionedState.get(0));
 
+			ChainedStateHandle<OperatorStateHandle> expectedPartitionableState =
+					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8);
+
+			List<Collection<OperatorStateHandle>> actualPartitionableState = executionJobVertex.
+					getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
+
+			assertEquals(expectedPartitionableState.get(0), actualPartitionableState.get(0).iterator().next());
+
 			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState = generateKeyGroupState(
 					jobVertexID,
 					keyGroupPartitions.get(i));
 			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState = executionJobVertex.
 					getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
-			comparePartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState);
+			compareKeyPartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState);
 		}
 	}
 
-	public static void comparePartitionedState(
+	public static void compareKeyPartitionedState(
 			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
 			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState) throws Exception {
 
@@ -2370,22 +2474,68 @@ public class CheckpointCoordinatorTest {
 
 		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
 
-		FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.getStateHandle().openInputStream();
-		for(int groupId : expectedHeadOpKeyGroupStateHandle.keyGroups()) {
-			long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-			inputStream.seek(offset);
-			int expectedKeyGroupState = InstantiationUtil.deserializeObject(inputStream);
-			for(KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) {
-				if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
-					long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-					FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.getStateHandle().openInputStream();
-					actualInputStream.seek(actualOffset);
-					int actualGroupState = InstantiationUtil.deserializeObject(actualInputStream);
-
-					assertEquals(expectedKeyGroupState, actualGroupState);
+		try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.getStateHandle().openInputStream()) {
+			for (int groupId : expectedHeadOpKeyGroupStateHandle.keyGroups()) {
+				long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+				inputStream.seek(offset);
+				int expectedKeyGroupState = InstantiationUtil.deserializeObject(inputStream);
+				for (KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) {
+					if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
+						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+						try (FSDataInputStream actualInputStream =
+								     oneActualKeyGroupStateHandle.getStateHandle().openInputStream()) {
+							actualInputStream.seek(actualOffset);
+							int actualGroupState = InstantiationUtil.deserializeObject(actualInputStream);
+							assertEquals(expectedKeyGroupState, actualGroupState);
+						}
+					}
+				}
+			}
+		}
+	}
+
+	public static void comparePartitionableState(
+			List<ChainedStateHandle<OperatorStateHandle>> expected,
+			List<List<Collection<OperatorStateHandle>>> actual) throws Exception {
+
+		List<String> expectedResult = new ArrayList<>();
+		for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) {
+			for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
+				OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i);
+				try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
+					for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+						for (long offset : entry.getValue()) {
+							in.seek(offset);
+							Integer state = InstantiationUtil.deserializeObject(in);
+							expectedResult.add(i + " : " + entry.getKey() + " : " + state);
+						}
+					}
 				}
 			}
 		}
+		Collections.sort(expectedResult);
+
+		List<String> actualResult = new ArrayList<>();
+		for (List<Collection<OperatorStateHandle>> collectionList : actual) {
+			if (collectionList != null) {
+				for (int i = 0; i < collectionList.size(); ++i) {
+					Collection<OperatorStateHandle> stateHandles = collectionList.get(i);
+					for (OperatorStateHandle operatorStateHandle : stateHandles) {
+						try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
+							for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+								for (long offset : entry.getValue()) {
+									in.seek(offset);
+									Integer state = InstantiationUtil.deserializeObject(in);
+									actualResult.add(i + " : " + entry.getKey() + " : " + state);
+								}
+							}
+						}
+					}
+				}
+			}
+		}
+		Collections.sort(actualResult);
+		Assert.assertEquals(expectedResult, actualResult);
 	}
 
 	@Test
@@ -2415,4 +2565,117 @@ public class CheckpointCoordinatorTest {
 		}
 	}
 
+
+	@Test
+	public void testPartitionableStateRepartitioning() {
+		// fixed seed: the 10000 randomized configurations are the same on every run
+		Random r = new Random(42);
+		for (int run = 0; run < 10000; ++run) {
+			// each dimension is drawn uniformly from [1, 9]
+			final int oldParallelism = 1 + r.nextInt(9);
+			final int newParallelism = 1 + r.nextInt(9);
+			final int numNamedStates = 1 + r.nextInt(9);
+			final int maxPartitionsPerState = 1 + r.nextInt(9);
+			doTestPartitionableStateRepartitioning(
+					r, oldParallelism, newParallelism,
+					numNamedStates, maxPartitionsPerState);
+		}
+	}
+
+	/**
+	 * Builds {@code oldParallelism} operator state handles with a random number of partitions per named
+	 * state, repartitions them to {@code newParallelism} subtasks with the round-robin repartitioner and
+	 * checks that no partition is lost or duplicated and that subtask loads differ by at most one.
+	 */
+	private void doTestPartitionableStateRepartitioning(
+			Random r, int oldParallelism, int newParallelism, int numNamedStates, int maxPartitionsPerState) {
+
+		List<OperatorStateHandle> previousParallelOpInstanceStates = new ArrayList<>(oldParallelism);
+		for (int i = 0; i < oldParallelism; ++i) {
+			Path fakePath = new Path("/fake-" + i);
+			Map<String, long[]> namedStatesToOffsets = new HashMap<>();
+			int off = 0;
+			for (int s = 0; s < numNamedStates; ++s) {
+				// 1 + nextInt(max) lies in [1, maxPartitionsPerState]: every named state has a partition
+				long[] offs = new long[1 + r.nextInt(maxPartitionsPerState)];
+				for (int o = 0; o < offs.length; ++o) {
+					offs[o] = off++;
+				}
+				namedStatesToOffsets.put("State-" + s, offs);
+			}
+
+			previousParallelOpInstanceStates.add(
+					new OperatorStateHandle(new FileStateHandle(fakePath, -1), namedStatesToOffsets));
+		}
+
+		// expected: per delegate handle, each state name mapped to all of its (ascending) offsets
+		Map<StreamStateHandle, Map<String, List<Long>>> expected = new HashMap<>();
+		int expectedTotalPartitions = 0;
+		for (OperatorStateHandle psh : previousParallelOpInstanceStates) {
+			Map<String, long[]> offsMap = psh.getStateNameToPartitionOffsets();
+			Map<String, List<Long>> offsMapWithList = new HashMap<>(offsMap.size());
+			for (Map.Entry<String, long[]> e : offsMap.entrySet()) {
+				long[] offs = e.getValue();
+				expectedTotalPartitions += offs.length;
+				List<Long> offsList = new ArrayList<>(offs.length);
+				for (long offset : offs) {
+					offsList.add(offset);
+				}
+				offsMapWithList.put(e.getKey(), offsList);
+			}
+			expected.put(psh.getDelegateStateHandle(), offsMapWithList);
+		}
+
+		OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
+		List<Collection<OperatorStateHandle>> pshs =
+				repartitioner.repartitionState(previousParallelOpInstanceStates, newParallelism);
+		// fail with a descriptive message rather than an IndexOutOfBoundsException in the loop below
+		Assert.assertEquals("Unexpected number of repartitioned states", newParallelism, pshs.size());
+
+		Map<StreamStateHandle, Map<String, List<Long>>> actual = new HashMap<>();
+		int minCount = Integer.MAX_VALUE;
+		int maxCount = 0;
+		int actualTotalPartitions = 0;
+		for (int p = 0; p < newParallelism; ++p) {
+			int partitionCount = 0;
+
+			Collection<OperatorStateHandle> pshc = pshs.get(p);
+			for (OperatorStateHandle sh : pshc) {
+				for (Map.Entry<String, long[]> namedState : sh.getStateNameToPartitionOffsets().entrySet()) {
+					Map<String, List<Long>> x = actual.get(sh.getDelegateStateHandle());
+					if (x == null) {
+						x = new HashMap<>();
+						actual.put(sh.getDelegateStateHandle(), x);
+					}
+					List<Long> actualOffs = x.get(namedState.getKey());
+					if (actualOffs == null) {
+						actualOffs = new ArrayList<>();
+						x.put(namedState.getKey(), actualOffs);
+					}
+					long[] add = namedState.getValue();
+					for (long offset : add) {
+						actualOffs.add(offset);
+					}
+					partitionCount += add.length;
+				}
+			}
+
+			minCount = Math.min(minCount, partitionCount);
+			maxCount = Math.max(maxCount, partitionCount);
+			actualTotalPartitions += partitionCount;
+		}
+
+		// the repartitioner may reorder offsets; sort them so the comparison with 'expected' ignores order
+		for (Map<String, List<Long>> v : actual.values()) {
+			for (List<Long> l : v.values()) {
+				Collections.sort(l);
+			}
+		}
+
+		int maxLoadDiff = maxCount - minCount;
+		Assert.assertTrue("Difference in partition load is > 1 : " + maxLoadDiff, maxLoadDiff <= 1);
+		Assert.assertEquals(expectedTotalPartitions, actualTotalPartitions);
+		Assert.assertEquals(expected, actual);
+	}
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index a4896aa..bb78b6a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -29,14 +29,18 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.util.SerializableObject;
-
+import org.hamcrest.BaseMatcher;
+import org.hamcrest.Description;
 import org.junit.Test;
 import org.mockito.Mockito;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -112,9 +116,11 @@ public class CheckpointStateRestoreTest {
 			PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next();
 			final long checkpointId = pending.getCheckpointId();
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates));
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates);
+
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, checkpointStateHandles));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, checkpointStateHandles));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, checkpointStateHandles));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));
 
@@ -125,11 +131,27 @@ public class CheckpointStateRestoreTest {
 			coord.restoreLatestCheckpointedState(map, true, false);
 
 			// verify that each stateful vertex got the state
-			verify(statefulExec1, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<List<KeyGroupsStateHandle>>any());
-			verify(statefulExec2, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<List<KeyGroupsStateHandle>>any());
-			verify(statefulExec3, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<List<KeyGroupsStateHandle>>any());
-			verify(statelessExec1, times(0)).setInitialState(Mockito.<ChainedStateHandle<StreamStateHandle>>any(), Mockito.<List<KeyGroupsStateHandle>>any());
-			verify(statelessExec2, times(0)).setInitialState(Mockito.<ChainedStateHandle<StreamStateHandle>>any(), Mockito.<List<KeyGroupsStateHandle>>any());
+
+			BaseMatcher<CheckpointStateHandles> matcher = new BaseMatcher<CheckpointStateHandles>() {
+				@Override
+				public boolean matches(Object o) {
+					if (o instanceof CheckpointStateHandles) {
+						return ((CheckpointStateHandles) o).getNonPartitionedStateHandles().equals(serializedState);
+					}
+					return false;
+				}
+
+				@Override
+				public void describeTo(Description description) {
+					description.appendValue(serializedState);
+				}
+			};
+
+			verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
+			verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
+			verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
+			verify(statelessExec1, times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), Mockito.<List<Collection<OperatorStateHandle>>>any());
+			verify(statelessExec2, times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), Mockito.<List<Collection<OperatorStateHandle>>>any());
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -193,9 +215,11 @@ public class CheckpointStateRestoreTest {
 			final long checkpointId = pending.getCheckpointId();
 
 			// the difference to the test "testSetState" is that one stateful subtask does not report state
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates));
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates);
+
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, checkpointStateHandles));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, checkpointStateHandles));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index 6182ffd..289f5c3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -197,7 +197,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		JobVertexID jvid = new JobVertexID();
 
 		Map<JobVertexID, TaskState> taskGroupStates = new HashMap<>();
-		TaskState taskState = new TaskState(jvid, numberOfStates, numberOfStates);
+		TaskState taskState = new TaskState(jvid, numberOfStates, numberOfStates, 1);
 		taskGroupStates.put(jvid, taskState);
 
 		for (int i = 0; i < numberOfStates; i++) {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
index fd4e02d..b8126e9 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
@@ -106,7 +106,7 @@ public class PendingCheckpointTest {
 		PendingCheckpoint pending = createPendingCheckpoint();
 		PendingCheckpointTest.setTaskState(pending, state);
 
-		pending.acknowledgeTask(ATTEMPT_ID, null, null);
+		pending.acknowledgeTask(ATTEMPT_ID, null);
 
 		CompletedCheckpoint checkpoint = pending.finalizeCheckpoint();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
index 7258545..3701359 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
@@ -117,7 +117,7 @@ public class PendingSavepointTest {
 
 		Future<String> future = pending.getCompletionFuture();
 
-		pending.acknowledgeTask(ATTEMPT_ID, null, null);
+		pending.acknowledgeTask(ATTEMPT_ID, null);
 
 		CompletedCheckpoint checkpoint = pending.finalizeCheckpoint();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
index 6a8d072..9fbe574 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
@@ -186,10 +186,5 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		public long getStateSize() throws IOException {
 			return 0;
 		}
-
-		@Override
-		public void close() throws IOException {
-			
-		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
index ef10032..c82be18 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
@@ -24,6 +24,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.junit.Test;
@@ -32,7 +33,9 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.ThreadLocalRandom;
 
 import static org.junit.Assert.assertEquals;
@@ -67,17 +70,30 @@ public class SavepointV1Test {
 		List<TaskState> taskStates = new ArrayList<>(numTaskStates);
 
 		for (int i = 0; i < numTaskStates; i++) {
-			TaskState taskState = new TaskState(new JobVertexID(), numSubtaskStates, numSubtaskStates);
+			TaskState taskState = new TaskState(new JobVertexID(), numSubtaskStates, numSubtaskStates, 1);
 			for (int j = 0; j < numSubtaskStates; j++) {
 				StreamStateHandle stateHandle = new ByteStreamStateHandle("Hello".getBytes());
 				taskState.putState(i, new SubtaskState(
 						new ChainedStateHandle<>(Collections.singletonList(stateHandle)), 0));
+
+				stateHandle = new ByteStreamStateHandle("Beautiful".getBytes());
+				Map<String, long[]> offsetsMap = new HashMap<>();
+				offsetsMap.put("A", new long[]{0, 10, 20});
+				offsetsMap.put("B", new long[]{30, 40, 50});
+
+				OperatorStateHandle operatorStateHandle =
+						new OperatorStateHandle(stateHandle, offsetsMap);
+
+				taskState.putPartitionableState(
+						i,
+						new ChainedStateHandle<OperatorStateHandle>(
+								Collections.singletonList(operatorStateHandle)));
 			}
 
 			taskState.putKeyedState(
 					0,
 					new KeyGroupsStateHandle(
-							new KeyGroupRangeOffsets(1,1, new long[] {42}), new ByteStreamStateHandle("Hello".getBytes())));
+							new KeyGroupRangeOffsets(1,1, new long[] {42}), new ByteStreamStateHandle("World".getBytes())));
 
 			taskStates.add(taskState);
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
index 504143b..1e95732 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
@@ -319,7 +319,7 @@ public class SimpleCheckpointStatsTrackerTest {
 				JobVertexID operatorId = operatorIds[operatorIndex];
 				int parallelism = operatorParallelism[operatorIndex];
 
-				TaskState taskState = new TaskState(operatorId, parallelism, maxParallelism);
+				TaskState taskState = new TaskState(operatorId, parallelism, maxParallelism, 1);
 
 				taskGroupStates.put(operatorId, taskState);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index ef8e3bd..9b12cac 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -26,6 +26,7 @@ import akka.testkit.JavaTestKit;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.runtime.akka.AkkaUtils;
 import org.apache.flink.runtime.akka.ListeningBehaviour;
 import org.apache.flink.runtime.blob.BlobServer;
@@ -54,9 +55,11 @@ import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService;
 import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService;
 import org.apache.flink.runtime.messages.JobManagerMessages;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.RetrievableStreamStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManager;
 import org.apache.flink.runtime.testingUtils.TestingJobManager;
@@ -80,6 +83,7 @@ import scala.concurrent.duration.FiniteDuration;
 
 import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -441,10 +445,15 @@ public class JobManagerHARecoveryTest {
 		private int completedCheckpoints = 0;
 
 		@Override
-		public void setInitialState(ChainedStateHandle<StreamStateHandle> chainedState, List<KeyGroupsStateHandle> keyGroupsState) throws Exception {
+		public void setInitialState(
+			ChainedStateHandle<StreamStateHandle> chainedState,
+			List<KeyGroupsStateHandle> keyGroupsState,
+			List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
 			int subtaskIndex = getIndexInSubtaskGroup();
 			if (subtaskIndex < recoveredStates.length) {
-				recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(chainedState.get(0).openInputStream());
+				try (FSDataInputStream in = chainedState.get(0).openInputStream()) {
+					recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in);
+				}
 			}
 		}
 
@@ -456,11 +465,12 @@ public class JobManagerHARecoveryTest {
 
 				RetrievableStreamStateHandle<Long> state = new RetrievableStreamStateHandle<Long>(byteStreamStateHandle);
 				ChainedStateHandle<StreamStateHandle> chainedStateHandle = new ChainedStateHandle<StreamStateHandle>(Collections.singletonList(state));
+				CheckpointStateHandles checkpointStateHandles =
+						new CheckpointStateHandles(chainedStateHandle, null, Collections.<KeyGroupsStateHandle>emptyList());
 
 				getEnvironment().acknowledgeCheckpoint(
 						checkpointId,
-						chainedStateHandle,
-						Collections.<KeyGroupsStateHandle>emptyList(),
+						checkpointStateHandles,
 						0L, 0L, 0L, 0L);
 				return true;
 			} catch (Exception ex) {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index 6a6ac64..4873335 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -23,11 +23,12 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.junit.Test;
 
@@ -65,13 +66,17 @@ public class CheckpointMessagesTest {
 
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42);
 
+			CheckpointStateHandles checkpointStateHandles =
+					new CheckpointStateHandles(
+							CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()),
+							CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8),
+							CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())));
+
 			AcknowledgeCheckpoint withState = new AcknowledgeCheckpoint(
 					new JobID(),
 					new ExecutionAttemptID(),
 					87658976143L,
-					CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()),
-					CheckpointCoordinatorTest.generateKeyGroupState(
-							keyGroupRange, Collections.singletonList(new MyHandle())));
+					checkpointStateHandles);
 
 			testSerializabilityEqualsHashCode(noState);
 			testSerializabilityEqualsHashCode(withState);
@@ -83,7 +88,6 @@ public class CheckpointMessagesTest {
 
 	private static void testSerializabilityEqualsHashCode(Serializable o) throws IOException {
 		Object copy = CommonTestUtils.createCopySerializable(o);
-		System.out.println(o.getClass() +" "+copy.getClass());
 		assertEquals(o, copy);
 		assertEquals(o.hashCode(), copy.hashCode());
 		assertNotNull(o.toString());
@@ -117,9 +121,6 @@ public class CheckpointMessagesTest {
 		}
 
 		@Override
-		public void close() throws IOException {}
-
-		@Override
 		public FSDataInputStream openInputStream() throws IOException {
 			return null;
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index a857d1b..c855230 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
@@ -162,7 +163,7 @@ public class DummyEnvironment implements Environment {
 	@Override
 	public void acknowledgeCheckpoint(
 			long checkpointId,
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle, List<KeyGroupsStateHandle> keyGroupStateHandles,
+			CheckpointStateHandles checkpointStateHandles,
 			long synchronousDurationMillis, long asynchronousDurationMillis,
 			long bytesBufferedInAlignment, long alignmentDurationNanos) {
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index 75e88eb..c3ed6c0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -46,6 +46,7 @@ import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
@@ -323,7 +324,7 @@ public class MockEnvironment implements Environment {
 	@Override
 	public void acknowledgeCheckpoint(
 			long checkpointId,
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle, List<KeyGroupsStateHandle> keyGroupStateHandles,
+			CheckpointStateHandles checkpointStateHandles,
 			long synchronousDurationMillis, long asynchronousDurationMillis,
 			long bytesBufferedInAlignment, long alignmentDurationNanos) {
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
index 1039568..4279635 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
@@ -32,8 +32,8 @@ import org.apache.flink.runtime.query.netty.KvStateClient;
 import org.apache.flink.runtime.query.netty.KvStateServer;
 import org.apache.flink.runtime.query.netty.UnknownKvStateID;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.heap.HeapValueState;
@@ -246,7 +246,7 @@ public class QueryableStateClientTest {
 		MemoryStateBackend backend = new MemoryStateBackend();
 		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 
-		KeyedStateBackend<Integer> keyedStateBackend = backend.createKeyedStateBackend(dummyEnv,
+		AbstractKeyedStateBackend<Integer> keyedStateBackend = backend.createKeyedStateBackend(dummyEnv,
 				new JobID(),
 				"test_op",
 				IntSerializer.INSTANCE,

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
index c8fb4bb..0db8b31 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
@@ -41,9 +41,9 @@ import org.apache.flink.runtime.query.KvStateServerAddress;
 import org.apache.flink.runtime.query.netty.message.KvStateRequest;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.KvState;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
@@ -538,7 +538,8 @@ public class KvStateClientTest {
 		KvStateRegistry dummyRegistry = new KvStateRegistry();
 		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 		dummyEnv.setKvStateRegistry(dummyRegistry);
-		KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+
+		AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
 				dummyEnv,
 				new JobID(),
 				"test_op",

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
index 7e6d713..ed4a822 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
@@ -38,6 +38,7 @@ import org.apache.flink.runtime.query.netty.message.KvStateRequestFailure;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestResult;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateBackend;
@@ -92,7 +93,7 @@ public class KvStateServerHandlerTest {
 		AbstractStateBackend abstractBackend = new MemoryStateBackend();
 		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 		dummyEnv.setKvStateRegistry(registry);
-		KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+		AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
 				dummyEnv,
 				new JobID(),
 				"test_op",
@@ -490,7 +491,7 @@ public class KvStateServerHandlerTest {
 		AbstractStateBackend abstractBackend = new MemoryStateBackend();
 		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 		dummyEnv.setKvStateRegistry(registry);
-		KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+		AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
 				dummyEnv,
 				new JobID(),
 				"test_op",
@@ -586,7 +587,7 @@ public class KvStateServerHandlerTest {
 		AbstractStateBackend abstractBackend = new MemoryStateBackend();
 		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 		dummyEnv.setKvStateRegistry(registry);
-		KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+		AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
 				dummyEnv,
 				new JobID(),
 				"test_op",

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
index e92fb10..b1c4a9f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
@@ -41,9 +41,9 @@ import org.apache.flink.runtime.query.KvStateServerAddress;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestResult;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
@@ -91,7 +91,7 @@ public class KvStateServerTest {
 			AbstractStateBackend abstractBackend = new MemoryStateBackend();
 			DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 			dummyEnv.setKvStateRegistry(registry);
-			KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+			AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
 					dummyEnv,
 					new JobID(),
 					"test_op",

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
deleted file mode 100644
index e613105..0000000
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
+++ /dev/null
@@ -1,97 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state;
-
-import org.junit.Test;
-
-import java.io.Closeable;
-import java.io.IOException;
-
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
-
-public class AbstractCloseableHandleTest {
-
-	@Test
-	public void testRegisterThenClose() throws Exception {
-		Closeable closeable = mock(Closeable.class);
-
-		AbstractCloseableHandle handle = new CloseableHandle();
-		assertFalse(handle.isClosed());
-
-		// no immediate closing
-		handle.registerCloseable(closeable);
-		verify(closeable, times(0)).close();
-		assertFalse(handle.isClosed());
-
-		// close forwarded once
-		handle.close();
-		verify(closeable, times(1)).close();
-		assertTrue(handle.isClosed());
-
-		// no repeated closing
-		handle.close();
-		verify(closeable, times(1)).close();
-		assertTrue(handle.isClosed());
-	}
-
-	@Test
-	public void testCloseThenRegister() throws Exception {
-		Closeable closeable = mock(Closeable.class);
-
-		AbstractCloseableHandle handle = new CloseableHandle();
-		assertFalse(handle.isClosed());
-
-		// close the handle before setting the closeable
-		handle.close();
-		assertTrue(handle.isClosed());
-
-		// immediate closing
-		try {
-			handle.registerCloseable(closeable);
-			fail("this should throw an excepion");
-		} catch (IOException e) {
-			// expected
-			assertTrue(e.getMessage().contains("closed"));
-		}
-
-		// should still have called "close" on the Closeable
-		verify(closeable, times(1)).close();
-		assertTrue(handle.isClosed());
-
-		// no repeated closing
-		handle.close();
-		verify(closeable, times(1)).close();
-		assertTrue(handle.isClosed());
-	}
-
-	// ------------------------------------------------------------------------
-
-	private static final class CloseableHandle extends AbstractCloseableHandle {
-		private static final long serialVersionUID = 1L;
-
-		@Override
-		public void discardState() {}
-
-		@Override
-		public long getStateSize() {
-			return 0;
-		}
-	}
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
index bc0b9c3..0b04ebc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
@@ -20,16 +20,11 @@ package org.apache.flink.runtime.state;
 
 import org.apache.commons.io.FileUtils;
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.testutils.CommonTestUtils;
-import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
-
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
-
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -39,9 +34,12 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.net.URI;
 import java.util.Random;
-import java.util.UUID;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 
@@ -188,18 +186,21 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 	}
 
 	private static void validateBytesInStream(InputStream is, byte[] data) throws IOException {
-		byte[] holder = new byte[data.length];
+		try {
+			byte[] holder = new byte[data.length];
 
-		int pos = 0;
-		int read;
-		while (pos < holder.length && (read = is.read(holder, pos, holder.length - pos)) != -1) {
-			pos += read;
-		}
+			int pos = 0;
+			int read;
+			while (pos < holder.length && (read = is.read(holder, pos, holder.length - pos)) != -1) {
+				pos += read;
+			}
 
-		assertEquals("not enough data", holder.length, pos);
-		assertEquals("too much data", -1, is.read());
-		assertArrayEquals("wrong data", data, holder);
-		is.close();
+			assertEquals("not enough data", holder.length, pos);
+			assertEquals("too much data", -1, is.read());
+			assertArrayEquals("wrong data", data, holder);
+		} finally {
+			is.close();
+		}
 	}
 
 	@Test

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
index 944938b..ac6adff 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
@@ -19,8 +19,6 @@
 package org.apache.flink.runtime.state;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.junit.Test;
 
@@ -29,7 +27,10 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.util.HashMap;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 /**
  * Tests for the {@link org.apache.flink.runtime.state.memory.MemoryStateBackend}.
@@ -105,10 +106,10 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 
 			assertNotNull(handle);
 
-			ObjectInputStream ois = new ObjectInputStream(handle.openInputStream());
-			assertEquals(state, ois.readObject());
-			assertTrue(ois.available() <= 0);
-			ois.close();
+			try (ObjectInputStream ois = new ObjectInputStream(handle.openInputStream())) {
+				assertEquals(state, ois.readObject());
+				assertTrue(ois.available() <= 0);
+			}
 		}
 		catch (Exception e) {
 			e.printStackTrace();

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
new file mode 100644
index 0000000..56c8987
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.java.typeutils.runtime.JavaSerializer;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.Iterator;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for the {@link OperatorStateBackend} created by a {@link MemoryStateBackend}:
+ * state registration, list-state access and snapshot / restore round trips.
+ */
+public class OperatorStateBackendTest {
+
+	/** Backend used to create and restore the operator state backends under test. */
+	private final AbstractStateBackend abstractStateBackend = new MemoryStateBackend(1024);
+
+	private OperatorStateBackend createNewOperatorStateBackend() throws Exception {
+		return abstractStateBackend.createOperatorStateBackend(null, "test-operator");
+	}
+
+	@Test
+	public void testCreateNew() throws Exception {
+		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
+		assertNotNull(operatorStateBackend);
+		assertTrue(operatorStateBackend.getRegisteredStateNames().isEmpty());
+	}
+
+	@Test
+	public void testRegisterStates() throws Exception {
+		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
+		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
+		ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
+
+		ListState<Serializable> listState1 = operatorStateBackend.getPartitionableState(stateDescriptor1);
+		assertNotNull(listState1);
+		assertEquals(1, operatorStateBackend.getRegisteredStateNames().size());
+		assertStateContents(listState1);
+		listState1.add(42);
+		listState1.add(4711);
+		assertStateContents(listState1, 42, 4711);
+
+		ListState<Serializable> listState2 = operatorStateBackend.getPartitionableState(stateDescriptor2);
+		assertNotNull(listState2);
+		assertEquals(2, operatorStateBackend.getRegisteredStateNames().size());
+		// a freshly registered state must start out empty (checked on its own iterator)
+		assertStateContents(listState2);
+		listState2.add(7);
+		listState2.add(13);
+		listState2.add(23);
+		assertStateContents(listState2, 7, 13, 23);
+
+		// requesting the state for an already registered descriptor again yields the same
+		// underlying state: additions through either handle are visible through both
+		ListState<Serializable> listState1b = operatorStateBackend.getPartitionableState(stateDescriptor1);
+		assertNotNull(listState1b);
+		listState1b.add(123);
+		assertStateContents(listState1b, 42, 4711, 123);
+		assertStateContents(listState1, 42, 4711, 123);
+		assertStateContents(listState1b, 42, 4711, 123);
+	}
+
+	@Test
+	public void testSnapshotRestore() throws Exception {
+		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
+		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
+		ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
+		ListState<Serializable> listState1 = operatorStateBackend.getPartitionableState(stateDescriptor1);
+		ListState<Serializable> listState2 = operatorStateBackend.getPartitionableState(stateDescriptor2);
+
+		listState1.add(42);
+		listState1.add(4711);
+
+		listState2.add(7);
+		listState2.add(13);
+		listState2.add(23);
+
+		CheckpointStreamFactory streamFactory = abstractStateBackend.createStreamFactory(new JobID(), "testOperator");
+		OperatorStateHandle stateHandle = operatorStateBackend.snapshot(1, 1, streamFactory).get();
+
+		OperatorStateBackend restoredBackend = null;
+		try {
+			operatorStateBackend.dispose();
+
+			restoredBackend = abstractStateBackend.restoreOperatorStateBackend(
+					null, "testOperator", Collections.singletonList(stateHandle));
+
+			// restored states are not listed as registered until they are requested again
+			assertEquals(0, restoredBackend.getRegisteredStateNames().size());
+
+			ListState<Serializable> restoredState1 = restoredBackend.getPartitionableState(stateDescriptor1);
+			ListState<Serializable> restoredState2 = restoredBackend.getPartitionableState(stateDescriptor2);
+
+			assertEquals(2, restoredBackend.getRegisteredStateNames().size());
+
+			assertStateContents(restoredState1, 42, 4711);
+			assertStateContents(restoredState2, 7, 13, 23);
+		} finally {
+			try {
+				if (restoredBackend != null) {
+					restoredBackend.dispose();
+				}
+			} finally {
+				stateHandle.discardState();
+			}
+		}
+	}
+
+	/**
+	 * Asserts that the given list state contains exactly the expected elements, in order.
+	 *
+	 * @param state the list state to inspect
+	 * @param expected the expected elements; passing none asserts that the state is empty
+	 */
+	private static void assertStateContents(ListState<Serializable> state, Serializable... expected) throws Exception {
+		Iterator<Serializable> it = state.get().iterator();
+		for (Serializable element : expected) {
+			assertTrue("list state has fewer elements than expected", it.hasNext());
+			assertEquals(element, it.next());
+		}
+		assertFalse("list state has more elements than expected", it.hasNext());
+	}
+}