Posted to commits@flink.apache.org by pn...@apache.org on 2019/10/28 08:22:31 UTC

[flink] 01/08: [hotfix] Extract utils of CheckpointCoordinatorTest into a separate utils class and correct codestyle

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 126cff6ab4e433a69c774e126b2e85335e30ee7a
Author: ifndef-SleePy <mm...@gmail.com>
AuthorDate: Wed Sep 18 17:30:04 2019 +0800

    [hotfix] Extract utils of CheckpointCoordinatorTest into a separate utils class and correct codestyle
---
 .../CheckpointCoordinatorFailureTest.java          |   5 +-
 .../CheckpointCoordinatorMasterHooksTest.java      |   4 +-
 .../checkpoint/CheckpointCoordinatorTest.java      | 517 +++------------------
 .../CheckpointCoordinatorTestingUtils.java         | 472 +++++++++++++++++++
 .../checkpoint/CheckpointStateRestoreTest.java     |   4 +-
 .../runtime/messages/CheckpointMessagesTest.java   |  13 +-
 6 files changed, 542 insertions(+), 473 deletions(-)
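
For callers, the extraction amounts to a one-line import swap; a minimal before/after sketch (Java), mirroring the hunks below:

    // before: the helpers were static members of the test class itself
    import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest.mockExecutionVertex;

    // after: the helpers live in the shared testing-utils class
    import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;

    // call sites stay unchanged
    final ExecutionVertex vertex = mockExecutionVertex(new ExecutionAttemptID());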

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index da181ba..b6b7930 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -49,6 +49,9 @@ import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+/**
+ * Tests for failure of checkpoint coordinator.
+ */
 @RunWith(PowerMockRunner.class)
 @PrepareForTest(PendingCheckpoint.class)
 public class CheckpointCoordinatorFailureTest extends TestLogger {
@@ -62,7 +65,7 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		JobID jid = new JobID();
 
 		final ExecutionAttemptID executionAttemptId = new ExecutionAttemptID();
-		final ExecutionVertex vertex = CheckpointCoordinatorTest.mockExecutionVertex(executionAttemptId);
+		final ExecutionVertex vertex = CheckpointCoordinatorTestingUtils.mockExecutionVertex(executionAttemptId);
 
 		final long triggerTimestamp = 1L;
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
index 48b6583..8453cba 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
@@ -48,7 +48,7 @@ import java.util.List;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executor;
 
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest.mockExecutionVertex;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -353,7 +353,7 @@ public class CheckpointCoordinatorMasterHooksTest {
 	// ------------------------------------------------------------------------
 
 	/**
-	 * This test makes sure that the checkpoint is already registered by the time
-	 * that the hooks are called
+	 * This test makes sure that the checkpoint is already registered by the time
+	 * that the hooks are called.
 	 */
 	@Test
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 98d80fd..034fac4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -19,9 +19,7 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.execution.ExecutionState;
@@ -39,7 +37,6 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.IncrementalRemoteKeyedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
-import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
@@ -55,8 +52,6 @@ import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLo
 import org.apache.flink.runtime.testutils.CommonTestUtils;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
 import org.apache.flink.util.ExceptionUtils;
-import org.apache.flink.util.InstantiationUtil;
-import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.SerializableObject;
 import org.apache.flink.util.TestLogger;
 
@@ -76,7 +71,6 @@ import org.mockito.stubbing.Answer;
 import org.mockito.verification.VerificationMode;
 
 import java.io.IOException;
-import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -91,11 +85,20 @@ import java.util.UUID;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executor;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.compareKeyedState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.comparePartitionableState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateChainedPartitionableStateHandle;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateKeyGroupState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generatePartitionableStateHandle;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecution;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionJobVertex;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockSubtaskState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.verifyStateRestore;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
@@ -439,7 +442,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertFalse(checkpoint.isDiscarded());
 			assertFalse(checkpoint.isFullyAcknowledged());
 
-
 			// decline checkpoint from the other task, this should cancel the checkpoint
 			// and trigger a new one
 			coord.receiveDeclineMessage(new DeclineCheckpoint(jid, attemptID1, checkpointId), TASK_MANAGER_LOCATION_INFO);
@@ -923,19 +925,19 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId());
 			OperatorID opID3 = OperatorID.fromJobVertexID(ackVertex3.getJobvertexId());
 
-			TaskStateSnapshot taskOperatorSubtaskStates1_1 = spy(new TaskStateSnapshot());
-			TaskStateSnapshot taskOperatorSubtaskStates1_2 = spy(new TaskStateSnapshot());
-			TaskStateSnapshot taskOperatorSubtaskStates1_3 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates11 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates12 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates13 = spy(new TaskStateSnapshot());
 
-			OperatorSubtaskState subtaskState1_1 = mock(OperatorSubtaskState.class);
-			OperatorSubtaskState subtaskState1_2 = mock(OperatorSubtaskState.class);
-			OperatorSubtaskState subtaskState1_3 = mock(OperatorSubtaskState.class);
-			taskOperatorSubtaskStates1_1.putSubtaskStateByOperatorID(opID1, subtaskState1_1);
-			taskOperatorSubtaskStates1_2.putSubtaskStateByOperatorID(opID2, subtaskState1_2);
-			taskOperatorSubtaskStates1_3.putSubtaskStateByOperatorID(opID3, subtaskState1_3);
+			OperatorSubtaskState subtaskState11 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState12 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState13 = mock(OperatorSubtaskState.class);
+			taskOperatorSubtaskStates11.putSubtaskStateByOperatorID(opID1, subtaskState11);
+			taskOperatorSubtaskStates12.putSubtaskStateByOperatorID(opID2, subtaskState12);
+			taskOperatorSubtaskStates13.putSubtaskStateByOperatorID(opID3, subtaskState13);
 
 			// acknowledge one of the three tasks
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_2), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates12), TASK_MANAGER_LOCATION_INFO);
 
 			// start the second checkpoint
 			// trigger the first checkpoint. this should succeed
@@ -953,17 +955,17 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			}
 			long checkpointId2 = pending2.getCheckpointId();
 
-			TaskStateSnapshot taskOperatorSubtaskStates2_1 = spy(new TaskStateSnapshot());
-			TaskStateSnapshot taskOperatorSubtaskStates2_2 = spy(new TaskStateSnapshot());
-			TaskStateSnapshot taskOperatorSubtaskStates2_3 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates21 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates22 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates23 = spy(new TaskStateSnapshot());
 
-			OperatorSubtaskState subtaskState2_1 = mock(OperatorSubtaskState.class);
-			OperatorSubtaskState subtaskState2_2 = mock(OperatorSubtaskState.class);
-			OperatorSubtaskState subtaskState2_3 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState21 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState22 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState23 = mock(OperatorSubtaskState.class);
 
-			taskOperatorSubtaskStates2_1.putSubtaskStateByOperatorID(opID1, subtaskState2_1);
-			taskOperatorSubtaskStates2_2.putSubtaskStateByOperatorID(opID2, subtaskState2_2);
-			taskOperatorSubtaskStates2_3.putSubtaskStateByOperatorID(opID3, subtaskState2_3);
+			taskOperatorSubtaskStates21.putSubtaskStateByOperatorID(opID1, subtaskState21);
+			taskOperatorSubtaskStates22.putSubtaskStateByOperatorID(opID2, subtaskState22);
+			taskOperatorSubtaskStates23.putSubtaskStateByOperatorID(opID3, subtaskState23);
 
 			// trigger messages should have been sent
 			verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class));
@@ -972,13 +974,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			// we acknowledge one more task from the first checkpoint and the second
 			// checkpoint completely. The second checkpoint should then subsume the first checkpoint
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_3), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates23), TASK_MANAGER_LOCATION_INFO);
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_1), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates21), TASK_MANAGER_LOCATION_INFO);
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_1), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates11), TASK_MANAGER_LOCATION_INFO);
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_2), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates22), TASK_MANAGER_LOCATION_INFO);
 
 			// now, the second checkpoint should be confirmed, and the first discarded
 			// actually both pending checkpoints are discarded, and the second has been transformed
@@ -990,13 +992,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// validate that all received subtask states in the first checkpoint have been discarded
-			verify(subtaskState1_1, times(1)).discardState();
-			verify(subtaskState1_2, times(1)).discardState();
+			verify(subtaskState11, times(1)).discardState();
+			verify(subtaskState12, times(1)).discardState();
 
 			// validate that all subtask states in the second checkpoint are not discarded
-			verify(subtaskState2_1, never()).discardState();
-			verify(subtaskState2_2, never()).discardState();
-			verify(subtaskState2_3, never()).discardState();
+			verify(subtaskState21, never()).discardState();
+			verify(subtaskState22, never()).discardState();
+			verify(subtaskState23, never()).discardState();
 
 			// validate the committed checkpoints
 			List<CompletedCheckpoint> scs = coord.getSuccessfulCheckpoints();
@@ -1010,15 +1012,15 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2));
 
 			// send the last remaining ack for the first checkpoint. This should not do anything
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_3), TASK_MANAGER_LOCATION_INFO);
-			verify(subtaskState1_3, times(1)).discardState();
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates13), TASK_MANAGER_LOCATION_INFO);
+			verify(subtaskState13, times(1)).discardState();
 
 			coord.shutdown(JobStatus.FINISHED);
 
 			// validate that the states in the second checkpoint have been discarded
-			verify(subtaskState2_1, times(1)).discardState();
-			verify(subtaskState2_2, times(1)).discardState();
-			verify(subtaskState2_3, times(1)).discardState();
+			verify(subtaskState21, times(1)).discardState();
+			verify(subtaskState22, times(1)).discardState();
+			verify(subtaskState23, times(1)).discardState();
 
 		}
 		catch (Exception e) {
@@ -1377,7 +1379,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			coord.stopCheckpointScheduler();
 
-
 			// for 400 ms, no further calls may come.
 			// there may be the case that one trigger was fired and about to
 			// acquire the lock, such that after cancelling it will still do
@@ -1385,7 +1386,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			int numCallsSoFar = numCalls.get();
 			Thread.sleep(400);
 			assertTrue(numCallsSoFar == numCalls.get() ||
-					numCallsSoFar+1 == numCalls.get());
+					numCallsSoFar + 1 == numCalls.get());
 
 			// start another sequence of periodic scheduling
 			numCalls.set(0);
@@ -2200,7 +2201,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 	 *
 	 * @throws Exception
 	 */
-	@Test(expected=IllegalStateException.class)
+	@Test(expected = IllegalStateException.class)
 	public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws Exception {
 		final JobID jid = new JobID();
 		final long timestamp = System.currentTimeMillis();
@@ -2275,7 +2276,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
 		}
 
-
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
 			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, null, keyGroupState, null);
@@ -2391,9 +2391,9 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			long checkpointId = checkpointIDCounter.getLast();
 
-			KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
+			KeyGroupRange keyGroupRange = KeyGroupRange.of(0, 0);
 			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
-			KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+			KeyedStateHandle serializedKeyGroupStates = generateKeyGroupState(keyGroupRange, testStates);
 
 			TaskStateSnapshot subtaskStatesForCheckpoint = new TaskStateSnapshot();
 
@@ -2416,10 +2416,9 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			timestamp = System.currentTimeMillis();
 			CompletableFuture<CompletedCheckpoint> savepointFuture = coord.triggerSavepoint(timestamp, savepointDir);
 
-
 			KeyGroupRange keyGroupRangeForSavepoint = KeyGroupRange.of(1, 1);
 			List<SerializableObject> testStatesForSavepoint = Collections.singletonList(new SerializableObject());
-			KeyedStateHandle serializedKeyGroupStatesForSavepoint = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRangeForSavepoint, testStatesForSavepoint);
+			KeyedStateHandle serializedKeyGroupStatesForSavepoint = generateKeyGroupState(keyGroupRangeForSavepoint, testStatesForSavepoint);
 
 			TaskStateSnapshot subtaskStatesForSavepoint = new TaskStateSnapshot();
 
@@ -2677,16 +2676,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		OperatorID operatorID = OperatorID.fromJobVertexID(jobVertexID);
 		return new Tuple2<>(jobVertexID, operatorID);
 	}
-	
+
 	/**
-	 * old topology
+	 * <p>
+	 * old topology.
 	 * [operator1,operator2] * parallelism1 -> [operator3,operator4] * parallelism2
+	 * </p>
 	 *
-	 *
+	 * <p>
 	 * new topology
 	 *
 	 * [operator5,operator1,operator3] * newParallelism1 -> [operator3, operator6] * newParallelism2
-	 *
+	 * </p>
 	 * scaleType:
 	 * 0  increase parallelism
 	 * 1  decrease parallelism
@@ -2956,7 +2957,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			Collection<KeyedStateHandle> keyedStateBackend = headOpState.getManagedKeyedState();
 			Collection<KeyedStateHandle> keyGroupStateRaw = headOpState.getRawKeyedState();
 
-
 			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
 			compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
 		}
@@ -3019,383 +3019,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		}
 	}
 
-	// ------------------------------------------------------------------------
-	//  Utilities
-	// ------------------------------------------------------------------------
-
-	public static KeyGroupsStateHandle generateKeyGroupState(
-			JobVertexID jobVertexID,
-			KeyGroupRange keyGroupPartition, boolean rawState) throws IOException {
-
-		List<Integer> testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups());
-
-		// generate state for one keygroup
-		for (int keyGroupIndex : keyGroupPartition) {
-			int vertexHash = jobVertexID.hashCode();
-			int seed = rawState ? (vertexHash * (31 + keyGroupIndex)) : (vertexHash + keyGroupIndex);
-			Random random = new Random(seed);
-			int simulatedStateValue = random.nextInt();
-			testStatesLists.add(simulatedStateValue);
-		}
-
-		return generateKeyGroupState(keyGroupPartition, testStatesLists);
-	}
-
-	public static KeyGroupsStateHandle generateKeyGroupState(
-			KeyGroupRange keyGroupRange,
-			List<? extends Serializable> states) throws IOException {
-
-		Preconditions.checkArgument(keyGroupRange.getNumberOfKeyGroups() == states.size());
-
-		Tuple2<byte[], List<long[]>> serializedDataWithOffsets =
-				serializeTogetherAndTrackOffsets(Collections.<List<? extends Serializable>>singletonList(states));
-
-		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));
-
-		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
-				String.valueOf(UUID.randomUUID()),
-				serializedDataWithOffsets.f0);
-
-		return new KeyGroupsStateHandle(keyGroupRangeOffsets, allSerializedStatesHandle);
-	}
-
-	public static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets(
-			List<List<? extends Serializable>> serializables) throws IOException {
-
-		List<long[]> offsets = new ArrayList<>(serializables.size());
-		List<byte[]> serializedGroupValues = new ArrayList<>();
-
-		int runningGroupsOffset = 0;
-		for(List<? extends Serializable> list : serializables) {
-
-			long[] currentOffsets = new long[list.size()];
-			offsets.add(currentOffsets);
-
-			for (int i = 0; i < list.size(); ++i) {
-				currentOffsets[i] = runningGroupsOffset;
-				byte[] serializedValue = InstantiationUtil.serializeObject(list.get(i));
-				serializedGroupValues.add(serializedValue);
-				runningGroupsOffset += serializedValue.length;
-			}
-		}
-
-		//write all generated values in a single byte array, which is index by groupOffsetsInFinalByteArray
-		byte[] allSerializedValuesConcatenated = new byte[runningGroupsOffset];
-		runningGroupsOffset = 0;
-		for (byte[] serializedGroupValue : serializedGroupValues) {
-			System.arraycopy(
-					serializedGroupValue,
-					0,
-					allSerializedValuesConcatenated,
-					runningGroupsOffset,
-					serializedGroupValue.length);
-			runningGroupsOffset += serializedGroupValue.length;
-		}
-		return new Tuple2<>(allSerializedValuesConcatenated, offsets);
-	}
-
-	public static OperatorStateHandle generatePartitionableStateHandle(
-		JobVertexID jobVertexID,
-		int index,
-		int namedStates,
-		int partitionsPerState,
-		boolean rawState) throws IOException {
-
-		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
-
-		for (int i = 0; i < namedStates; ++i) {
-			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
-			// generate state
-			int seed = jobVertexID.hashCode() * index + i * namedStates;
-			if (rawState) {
-				seed = (seed + 1) * 31;
-			}
-			Random random = new Random(seed);
-			for (int j = 0; j < partitionsPerState; ++j) {
-				int simulatedStateValue = random.nextInt();
-				testStatesLists.add(simulatedStateValue);
-			}
-			statesListsMap.put("state-" + i, testStatesLists);
-		}
-
-		return generatePartitionableStateHandle(statesListsMap);
-	}
-
-	public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
-			JobVertexID jobVertexID,
-			int index,
-			int namedStates,
-			int partitionsPerState,
-			boolean rawState) throws IOException {
-
-		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
-
-		for (int i = 0; i < namedStates; ++i) {
-			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
-			// generate state
-			int seed = jobVertexID.hashCode() * index + i * namedStates;
-			if (rawState) {
-				seed = (seed + 1) * 31;
-			}
-			Random random = new Random(seed);
-			for (int j = 0; j < partitionsPerState; ++j) {
-				int simulatedStateValue = random.nextInt();
-				testStatesLists.add(simulatedStateValue);
-			}
-			statesListsMap.put("state-" + i, testStatesLists);
-		}
-
-		return ChainedStateHandle.wrapSingleHandle(generatePartitionableStateHandle(statesListsMap));
-	}
-
-	private static OperatorStateHandle generatePartitionableStateHandle(
-		Map<String, List<? extends Serializable>> states) throws IOException {
-
-		List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size());
-
-		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
-			namedStateSerializables.add(entry.getValue());
-		}
-
-		Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables);
-
-		Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(states.size());
-
-		int idx = 0;
-		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
-			offsetsMap.put(
-				entry.getKey(),
-				new OperatorStateHandle.StateMetaInfo(
-					serializationWithOffsets.f1.get(idx),
-					OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
-			++idx;
-		}
-
-		ByteStreamStateHandle streamStateHandle = new ByteStreamStateHandle(
-			String.valueOf(UUID.randomUUID()),
-			serializationWithOffsets.f0);
-
-		return new OperatorStreamStateHandle(offsetsMap, streamStateHandle);
-	}
-
-	static ExecutionJobVertex mockExecutionJobVertex(
-			JobVertexID jobVertexID,
-			int parallelism,
-			int maxParallelism) {
-
-		return mockExecutionJobVertex(
-			jobVertexID,
-			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
-			parallelism,
-			maxParallelism
-		);
-	}
-
-	static ExecutionJobVertex mockExecutionJobVertex(
-		JobVertexID jobVertexID,
-		List<OperatorID> jobVertexIDs,
-		int parallelism,
-		int maxParallelism) {
-		final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class);
-
-		ExecutionVertex[] executionVertices = new ExecutionVertex[parallelism];
-
-		for (int i = 0; i < parallelism; i++) {
-			executionVertices[i] = mockExecutionVertex(
-				new ExecutionAttemptID(),
-				jobVertexID,
-				jobVertexIDs,
-				parallelism,
-				maxParallelism,
-				ExecutionState.RUNNING);
-
-			when(executionVertices[i].getParallelSubtaskIndex()).thenReturn(i);
-		}
-
-		when(executionJobVertex.getJobVertexId()).thenReturn(jobVertexID);
-		when(executionJobVertex.getTaskVertices()).thenReturn(executionVertices);
-		when(executionJobVertex.getParallelism()).thenReturn(parallelism);
-		when(executionJobVertex.getMaxParallelism()).thenReturn(maxParallelism);
-		when(executionJobVertex.isMaxParallelismConfigured()).thenReturn(true);
-		when(executionJobVertex.getOperatorIDs()).thenReturn(jobVertexIDs);
-		when(executionJobVertex.getUserDefinedOperatorIDs()).thenReturn(Arrays.asList(new OperatorID[jobVertexIDs.size()]));
-
-		return executionJobVertex;
-	}
-
-	static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) {
-		JobVertexID jobVertexID = new JobVertexID();
-		return mockExecutionVertex(
-			attemptID,
-			jobVertexID,
-			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
-			1,
-			1,
-			ExecutionState.RUNNING);
-	}
-
-	private static ExecutionVertex mockExecutionVertex(
-		ExecutionAttemptID attemptID,
-		JobVertexID jobVertexID,
-		List<OperatorID> jobVertexIDs,
-		int parallelism,
-		int maxParallelism,
-		ExecutionState state,
-		ExecutionState ... successiveStates) {
-
-		ExecutionVertex vertex = mock(ExecutionVertex.class);
-
-		final Execution exec = spy(new Execution(
-			mock(Executor.class),
-			vertex,
-			1,
-			1L,
-			1L,
-			Time.milliseconds(500L)
-		));
-		when(exec.getAttemptId()).thenReturn(attemptID);
-		when(exec.getState()).thenReturn(state, successiveStates);
-
-		when(vertex.getJobvertexId()).thenReturn(jobVertexID);
-		when(vertex.getCurrentExecutionAttempt()).thenReturn(exec);
-		when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
-		when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
-
-		ExecutionJobVertex jobVertex = mock(ExecutionJobVertex.class);
-		when(jobVertex.getOperatorIDs()).thenReturn(jobVertexIDs);
-		
-		when(vertex.getJobVertex()).thenReturn(jobVertex);
-
-		return vertex;
-	}
-
-	static TaskStateSnapshot mockSubtaskState(
-		JobVertexID jobVertexID,
-		int index,
-		KeyGroupRange keyGroupRange) throws IOException {
-
-		OperatorStateHandle partitionableState = generatePartitionableStateHandle(jobVertexID, index, 2, 8, false);
-		KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false);
-
-		TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot());
-		OperatorSubtaskState subtaskState = spy(new OperatorSubtaskState(
-			partitionableState, null, partitionedKeyGroupState, null)
-		);
-
-		subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), subtaskState);
-
-		return subtaskStates;
-	}
-
-	public static void verifyStateRestore(
-			JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
-			List<KeyGroupRange> keyGroupPartitions) throws Exception {
-
-		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
-
-			JobManagerTaskRestore taskRestore = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
-			Assert.assertEquals(1L, taskRestore.getRestoreCheckpointId());
-			TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot();
-
-			OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID));
-
-			ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend =
-					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);
-
-			assertTrue(CommonTestUtils.isStreamContentEqual(
-					expectedOpStateBackend.get(0).openInputStream(),
-					operatorState.getManagedOperatorState().iterator().next().openInputStream()));
-
-			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
-					jobVertexID, keyGroupPartitions.get(i), false);
-			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), operatorState.getManagedKeyedState());
-		}
-	}
-
-	public static void compareKeyedState(
-			Collection<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
-			Collection<? extends KeyedStateHandle> actualPartitionedKeyGroupState) throws Exception {
-
-		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next();
-		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
-		int actualTotalKeyGroups = 0;
-		for(KeyedStateHandle keyedStateHandle: actualPartitionedKeyGroupState) {
-			assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle);
-
-			actualTotalKeyGroups += keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
-		}
-
-		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
-
-		try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.openInputStream()) {
-			for (int groupId : expectedHeadOpKeyGroupStateHandle.getKeyGroupRange()) {
-				long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-				inputStream.seek(offset);
-				int expectedKeyGroupState =
-						InstantiationUtil.deserializeObject(inputStream, Thread.currentThread().getContextClassLoader());
-				for (KeyedStateHandle oneActualKeyedStateHandle : actualPartitionedKeyGroupState) {
-
-					assertTrue(oneActualKeyedStateHandle instanceof KeyGroupsStateHandle);
-
-					KeyGroupsStateHandle oneActualKeyGroupStateHandle = (KeyGroupsStateHandle) oneActualKeyedStateHandle;
-					if (oneActualKeyGroupStateHandle.getKeyGroupRange().contains(groupId)) {
-						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-						try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
-							actualInputStream.seek(actualOffset);
-							int actualGroupState = InstantiationUtil.
-									deserializeObject(actualInputStream, Thread.currentThread().getContextClassLoader());
-							assertEquals(expectedKeyGroupState, actualGroupState);
-						}
-					}
-				}
-			}
-		}
-	}
-
-	public static void comparePartitionableState(
-			List<ChainedStateHandle<OperatorStateHandle>> expected,
-			List<List<Collection<OperatorStateHandle>>> actual) throws Exception {
-
-		List<String> expectedResult = new ArrayList<>();
-		for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) {
-			for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
-				OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i);
-				collectResult(i, operatorStateHandle, expectedResult);
-			}
-		}
-		Collections.sort(expectedResult);
-
-		List<String> actualResult = new ArrayList<>();
-		for (List<Collection<OperatorStateHandle>> collectionList : actual) {
-			if (collectionList != null) {
-				for (int i = 0; i < collectionList.size(); ++i) {
-					Collection<OperatorStateHandle> stateHandles = collectionList.get(i);
-					Assert.assertNotNull(stateHandles);
-					for (OperatorStateHandle operatorStateHandle : stateHandles) {
-						collectResult(i, operatorStateHandle, actualResult);
-					}
-				}
-			}
-		}
-
-		Collections.sort(actualResult);
-		Assert.assertEquals(expectedResult, actualResult);
-	}
-
-	private static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception {
-		try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
-			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
-				for (long offset : entry.getValue().getOffsets()) {
-					in.seek(offset);
-					Integer state = InstantiationUtil.
-							deserializeObject(in, Thread.currentThread().getContextClassLoader());
-					resultCollector.add(opIdx + " : " + entry.getKey() + " : " + state);
-				}
-			}
-		}
-	}
-
-
 	@Test
 	public void testCreateKeyGroupPartitions() {
 		testCreateKeyGroupPartitions(1, 1);
@@ -4071,36 +3694,4 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
 		}
 	}
-
-	private Execution mockExecution() {
-		Execution mock = mock(Execution.class);
-		when(mock.getAttemptId()).thenReturn(new ExecutionAttemptID());
-		when(mock.getState()).thenReturn(ExecutionState.RUNNING);
-		return mock;
-	}
-
-	private ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask, int parallelism) {
-		ExecutionVertex mock = mock(ExecutionVertex.class);
-		when(mock.getJobvertexId()).thenReturn(vertexId);
-		when(mock.getParallelSubtaskIndex()).thenReturn(subtask);
-		when(mock.getCurrentExecutionAttempt()).thenReturn(execution);
-		when(mock.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
-		when(mock.getMaxParallelism()).thenReturn(parallelism);
-		return mock;
-	}
-
-	private ExecutionJobVertex mockExecutionJobVertex(JobVertexID id, ExecutionVertex[] vertices) {
-		ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
-		when(vertex.getParallelism()).thenReturn(vertices.length);
-		when(vertex.getMaxParallelism()).thenReturn(vertices.length);
-		when(vertex.getJobVertexId()).thenReturn(id);
-		when(vertex.getTaskVertices()).thenReturn(vertices);
-		when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorID.fromJobVertexID(id)));
-		when(vertex.getUserDefinedOperatorIDs()).thenReturn(Collections.<OperatorID>singletonList(null));
-
-		for (ExecutionVertex v : vertices) {
-			when(v.getJobVertex()).thenReturn(vertex);
-		}
-		return vertex;
-	}
 }
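
The utilities removed above are re-added, essentially unchanged, in the new class below. As a rough sketch of what serializeTogetherAndTrackOffsets produces (a hypothetical fragment assuming only the signatures visible in this diff; the method is package-private, so it is callable from tests in the same package):

    List<List<? extends Serializable>> groups =
        Collections.<List<? extends Serializable>>singletonList(Arrays.asList(1, 2, 3));
    // f0: all values serialized and concatenated into a single byte[]
    // f1: per input list, a long[] of starting offsets into f0
    Tuple2<byte[], List<long[]>> serialized =
        CheckpointCoordinatorTestingUtils.serializeTogetherAndTrackOffsets(groups);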
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
new file mode 100644
index 0000000..9578ac7
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
@@ -0,0 +1,472 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.api.common.time.Time;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.OperatorStreamStateHandle;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.testutils.CommonTestUtils;
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
+
+import org.junit.Assert;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.UUID;
+import java.util.concurrent.Executor;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+/**
+ * Testing utils for checkpoint coordinator.
+ */
+public class CheckpointCoordinatorTestingUtils {
+
+	public static OperatorStateHandle generatePartitionableStateHandle(
+		JobVertexID jobVertexID,
+		int index,
+		int namedStates,
+		int partitionsPerState,
+		boolean rawState) throws IOException {
+
+		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
+
+		for (int i = 0; i < namedStates; ++i) {
+			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
+			// generate state
+			int seed = jobVertexID.hashCode() * index + i * namedStates;
+			if (rawState) {
+				seed = (seed + 1) * 31;
+			}
+			Random random = new Random(seed);
+			for (int j = 0; j < partitionsPerState; ++j) {
+				int simulatedStateValue = random.nextInt();
+				testStatesLists.add(simulatedStateValue);
+			}
+			statesListsMap.put("state-" + i, testStatesLists);
+		}
+
+		return generatePartitionableStateHandle(statesListsMap);
+	}
+
+	static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
+		JobVertexID jobVertexID,
+		int index,
+		int namedStates,
+		int partitionsPerState,
+		boolean rawState) throws IOException {
+
+		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
+
+		for (int i = 0; i < namedStates; ++i) {
+			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
+			// generate state
+			int seed = jobVertexID.hashCode() * index + i * namedStates;
+			if (rawState) {
+				seed = (seed + 1) * 31;
+			}
+			Random random = new Random(seed);
+			for (int j = 0; j < partitionsPerState; ++j) {
+				int simulatedStateValue = random.nextInt();
+				testStatesLists.add(simulatedStateValue);
+			}
+			statesListsMap.put("state-" + i, testStatesLists);
+		}
+
+		return ChainedStateHandle.wrapSingleHandle(generatePartitionableStateHandle(statesListsMap));
+	}
+
+	static OperatorStateHandle generatePartitionableStateHandle(
+		Map<String, List<? extends Serializable>> states) throws IOException {
+
+		List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size());
+
+		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
+			namedStateSerializables.add(entry.getValue());
+		}
+
+		Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables);
+
+		Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(states.size());
+
+		int idx = 0;
+		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
+			offsetsMap.put(
+				entry.getKey(),
+				new OperatorStateHandle.StateMetaInfo(
+					serializationWithOffsets.f1.get(idx),
+					OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+			++idx;
+		}
+
+		ByteStreamStateHandle streamStateHandle = new ByteStreamStateHandle(
+			String.valueOf(UUID.randomUUID()),
+			serializationWithOffsets.f0);
+
+		return new OperatorStreamStateHandle(offsetsMap, streamStateHandle);
+	}
+
+	static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets(
+		List<List<? extends Serializable>> serializables) throws IOException {
+
+		List<long[]> offsets = new ArrayList<>(serializables.size());
+		List<byte[]> serializedGroupValues = new ArrayList<>();
+
+		int runningGroupsOffset = 0;
+		for (List<? extends Serializable> list : serializables) {
+
+			long[] currentOffsets = new long[list.size()];
+			offsets.add(currentOffsets);
+
+			for (int i = 0; i < list.size(); ++i) {
+				currentOffsets[i] = runningGroupsOffset;
+				byte[] serializedValue = InstantiationUtil.serializeObject(list.get(i));
+				serializedGroupValues.add(serializedValue);
+				runningGroupsOffset += serializedValue.length;
+			}
+		}
+
+		//write all generated values in a single byte array, which is index by groupOffsetsInFinalByteArray
+		byte[] allSerializedValuesConcatenated = new byte[runningGroupsOffset];
+		runningGroupsOffset = 0;
+		for (byte[] serializedGroupValue : serializedGroupValues) {
+			System.arraycopy(
+				serializedGroupValue,
+				0,
+				allSerializedValuesConcatenated,
+				runningGroupsOffset,
+				serializedGroupValue.length);
+			runningGroupsOffset += serializedGroupValue.length;
+		}
+		return new Tuple2<>(allSerializedValuesConcatenated, offsets);
+	}
+
+	public static void verifyStateRestore(
+		JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
+		List<KeyGroupRange> keyGroupPartitions) throws Exception {
+
+		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
+
+			JobManagerTaskRestore taskRestore = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
+			Assert.assertEquals(1L, taskRestore.getRestoreCheckpointId());
+			TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot();
+
+			OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID));
+
+			ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend =
+				generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);
+
+			assertTrue(CommonTestUtils.isStreamContentEqual(
+				expectedOpStateBackend.get(0).openInputStream(),
+				operatorState.getManagedOperatorState().iterator().next().openInputStream()));
+
+			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
+				jobVertexID, keyGroupPartitions.get(i), false);
+			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), operatorState.getManagedKeyedState());
+		}
+	}
+
+	static void compareKeyedState(
+		Collection<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
+		Collection<? extends KeyedStateHandle> actualPartitionedKeyGroupState) throws Exception {
+
+		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next();
+		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
+		int actualTotalKeyGroups = 0;
+		for (KeyedStateHandle keyedStateHandle: actualPartitionedKeyGroupState) {
+			assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle);
+
+			actualTotalKeyGroups += keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
+		}
+
+		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
+
+		try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.openInputStream()) {
+			for (int groupId : expectedHeadOpKeyGroupStateHandle.getKeyGroupRange()) {
+				long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+				inputStream.seek(offset);
+				int expectedKeyGroupState =
+					InstantiationUtil.deserializeObject(inputStream, Thread.currentThread().getContextClassLoader());
+				for (KeyedStateHandle oneActualKeyedStateHandle : actualPartitionedKeyGroupState) {
+
+					assertTrue(oneActualKeyedStateHandle instanceof KeyGroupsStateHandle);
+
+					KeyGroupsStateHandle oneActualKeyGroupStateHandle = (KeyGroupsStateHandle) oneActualKeyedStateHandle;
+					if (oneActualKeyGroupStateHandle.getKeyGroupRange().contains(groupId)) {
+						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+						try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
+							actualInputStream.seek(actualOffset);
+							int actualGroupState = InstantiationUtil.
+								deserializeObject(actualInputStream, Thread.currentThread().getContextClassLoader());
+							assertEquals(expectedKeyGroupState, actualGroupState);
+						}
+					}
+				}
+			}
+		}
+	}
+
+	static void comparePartitionableState(
+		List<ChainedStateHandle<OperatorStateHandle>> expected,
+		List<List<Collection<OperatorStateHandle>>> actual) throws Exception {
+
+		List<String> expectedResult = new ArrayList<>();
+		for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) {
+			for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
+				OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i);
+				collectResult(i, operatorStateHandle, expectedResult);
+			}
+		}
+		Collections.sort(expectedResult);
+
+		List<String> actualResult = new ArrayList<>();
+		for (List<Collection<OperatorStateHandle>> collectionList : actual) {
+			if (collectionList != null) {
+				for (int i = 0; i < collectionList.size(); ++i) {
+					Collection<OperatorStateHandle> stateHandles = collectionList.get(i);
+					Assert.assertNotNull(stateHandles);
+					for (OperatorStateHandle operatorStateHandle : stateHandles) {
+						collectResult(i, operatorStateHandle, actualResult);
+					}
+				}
+			}
+		}
+
+		Collections.sort(actualResult);
+		Assert.assertEquals(expectedResult, actualResult);
+	}
+
+	static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception {
+		try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
+			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+				for (long offset : entry.getValue().getOffsets()) {
+					in.seek(offset);
+					Integer state = InstantiationUtil.
+						deserializeObject(in, Thread.currentThread().getContextClassLoader());
+					resultCollector.add(opIdx + " : " + entry.getKey() + " : " + state);
+				}
+			}
+		}
+	}
+
+	static ExecutionJobVertex mockExecutionJobVertex(
+		JobVertexID jobVertexID,
+		int parallelism,
+		int maxParallelism) {
+
+		return mockExecutionJobVertex(
+			jobVertexID,
+			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
+			parallelism,
+			maxParallelism
+		);
+	}
+
+	static ExecutionJobVertex mockExecutionJobVertex(
+		JobVertexID jobVertexID,
+		List<OperatorID> jobVertexIDs,
+		int parallelism,
+		int maxParallelism) {
+		final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class);
+
+		ExecutionVertex[] executionVertices = new ExecutionVertex[parallelism];
+
+		for (int i = 0; i < parallelism; i++) {
+			executionVertices[i] = mockExecutionVertex(
+				new ExecutionAttemptID(),
+				jobVertexID,
+				jobVertexIDs,
+				parallelism,
+				maxParallelism,
+				ExecutionState.RUNNING);
+
+			when(executionVertices[i].getParallelSubtaskIndex()).thenReturn(i);
+		}
+
+		when(executionJobVertex.getJobVertexId()).thenReturn(jobVertexID);
+		when(executionJobVertex.getTaskVertices()).thenReturn(executionVertices);
+		when(executionJobVertex.getParallelism()).thenReturn(parallelism);
+		when(executionJobVertex.getMaxParallelism()).thenReturn(maxParallelism);
+		when(executionJobVertex.isMaxParallelismConfigured()).thenReturn(true);
+		when(executionJobVertex.getOperatorIDs()).thenReturn(jobVertexIDs);
+		when(executionJobVertex.getUserDefinedOperatorIDs()).thenReturn(Arrays.asList(new OperatorID[jobVertexIDs.size()]));
+
+		return executionJobVertex;
+	}
+
+	static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) {
+		JobVertexID jobVertexID = new JobVertexID();
+		return mockExecutionVertex(
+			attemptID,
+			jobVertexID,
+			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
+			1,
+			1,
+			ExecutionState.RUNNING);
+	}
+
+	static ExecutionVertex mockExecutionVertex(
+		ExecutionAttemptID attemptID,
+		JobVertexID jobVertexID,
+		List<OperatorID> jobVertexIDs,
+		int parallelism,
+		int maxParallelism,
+		ExecutionState state,
+		ExecutionState ... successiveStates) {
+
+		ExecutionVertex vertex = mock(ExecutionVertex.class);
+
+		final Execution exec = spy(new Execution(
+			mock(Executor.class),
+			vertex,
+			1,
+			1L,
+			1L,
+			Time.milliseconds(500L)
+		));
+		when(exec.getAttemptId()).thenReturn(attemptID);
+		when(exec.getState()).thenReturn(state, successiveStates);
+
+		when(vertex.getJobvertexId()).thenReturn(jobVertexID);
+		when(vertex.getCurrentExecutionAttempt()).thenReturn(exec);
+		when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
+		when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
+
+		ExecutionJobVertex jobVertex = mock(ExecutionJobVertex.class);
+		when(jobVertex.getOperatorIDs()).thenReturn(jobVertexIDs);
+
+		when(vertex.getJobVertex()).thenReturn(jobVertex);
+
+		return vertex;
+	}
+
+	static TaskStateSnapshot mockSubtaskState(
+		JobVertexID jobVertexID,
+		int index,
+		KeyGroupRange keyGroupRange) throws IOException {
+
+		OperatorStateHandle partitionableState = generatePartitionableStateHandle(jobVertexID, index, 2, 8, false);
+		KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false);
+
+		TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot());
+		OperatorSubtaskState subtaskState = spy(new OperatorSubtaskState(
+			partitionableState, null, partitionedKeyGroupState, null)
+		);
+
+		subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), subtaskState);
+
+		return subtaskStates;
+	}
+
+	public static KeyGroupsStateHandle generateKeyGroupState(
+		JobVertexID jobVertexID,
+		KeyGroupRange keyGroupPartition, boolean rawState) throws IOException {
+
+		List<Integer> testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups());
+
+		// generate state for one keygroup
+		for (int keyGroupIndex : keyGroupPartition) {
+			int vertexHash = jobVertexID.hashCode();
+			int seed = rawState ? (vertexHash * (31 + keyGroupIndex)) : (vertexHash + keyGroupIndex);
+			Random random = new Random(seed);
+			int simulatedStateValue = random.nextInt();
+			testStatesLists.add(simulatedStateValue);
+		}
+
+		return generateKeyGroupState(keyGroupPartition, testStatesLists);
+	}
+
+	public static KeyGroupsStateHandle generateKeyGroupState(
+		KeyGroupRange keyGroupRange,
+		List<? extends Serializable> states) throws IOException {
+
+		Preconditions.checkArgument(keyGroupRange.getNumberOfKeyGroups() == states.size());
+
+		Tuple2<byte[], List<long[]>> serializedDataWithOffsets =
+			serializeTogetherAndTrackOffsets(Collections.<List<? extends Serializable>>singletonList(states));
+
+		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));
+
+		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
+			String.valueOf(UUID.randomUUID()),
+			serializedDataWithOffsets.f0);
+
+		return new KeyGroupsStateHandle(keyGroupRangeOffsets, allSerializedStatesHandle);
+	}
+
+	static Execution mockExecution() {
+		Execution mock = mock(Execution.class);
+		when(mock.getAttemptId()).thenReturn(new ExecutionAttemptID());
+		when(mock.getState()).thenReturn(ExecutionState.RUNNING);
+		return mock;
+	}
+
+	static ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask, int parallelism) {
+		ExecutionVertex mock = mock(ExecutionVertex.class);
+		when(mock.getJobvertexId()).thenReturn(vertexId);
+		when(mock.getParallelSubtaskIndex()).thenReturn(subtask);
+		when(mock.getCurrentExecutionAttempt()).thenReturn(execution);
+		when(mock.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
+		when(mock.getMaxParallelism()).thenReturn(parallelism);
+		return mock;
+	}
+
+	static ExecutionJobVertex mockExecutionJobVertex(JobVertexID id, ExecutionVertex[] vertices) {
+		ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
+		when(vertex.getParallelism()).thenReturn(vertices.length);
+		when(vertex.getMaxParallelism()).thenReturn(vertices.length);
+		when(vertex.getJobVertexId()).thenReturn(id);
+		when(vertex.getTaskVertices()).thenReturn(vertices);
+		when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorID.fromJobVertexID(id)));
+		when(vertex.getUserDefinedOperatorIDs()).thenReturn(Collections.<OperatorID>singletonList(null));
+
+		for (ExecutionVertex v : vertices) {
+			when(v.getJobVertex()).thenReturn(vertex);
+		}
+		return vertex;
+	}
+}
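
A minimal usage sketch of the public generateKeyGroupState variant, mirroring the call sites updated in the remaining hunks (hypothetical standalone fragment; the call throws IOException):

    // one serializable test object per key group in the range [0, 0]
    KeyGroupRange keyGroupRange = KeyGroupRange.of(0, 0);
    List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
    KeyedStateHandle serializedKeyGroupStates =
        CheckpointCoordinatorTestingUtils.generateKeyGroupState(keyGroupRange, testStates);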
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index af2e5a3..42849b5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -79,9 +79,9 @@ public class CheckpointStateRestoreTest {
 	public void testSetState() {
 		try {
 
-			KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
+			KeyGroupRange keyGroupRange = KeyGroupRange.of(0, 0);
 			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
-			final KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+			final KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTestingUtils.generateKeyGroupState(keyGroupRange, testStates);
 
 			final JobID jid = new JobID();
 			final JobVertexID statefulId = new JobVertexID();
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index f18b4c8..0fb46e0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -21,7 +21,7 @@ package org.apache.flink.runtime.messages;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.testutils.CommonTestUtils;
-import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest;
+import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
@@ -45,6 +45,9 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.fail;
 
+/**
+ * Tests for checkpoint messages.
+ */
 public class CheckpointMessagesTest {
 
 	@Test
@@ -69,15 +72,15 @@ public class CheckpointMessagesTest {
 			AcknowledgeCheckpoint noState = new AcknowledgeCheckpoint(
 					new JobID(), new ExecutionAttemptID(), 569345L);
 
-			KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42);
+			KeyGroupRange keyGroupRange = KeyGroupRange.of(42, 42);
 
 			TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot();
 			checkpointStateHandles.putSubtaskStateByOperatorID(
 				new OperatorID(),
 				new OperatorSubtaskState(
-					CheckpointCoordinatorTest.generatePartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
+					CheckpointCoordinatorTestingUtils.generatePartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
 					null,
-					CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
+					CheckpointCoordinatorTestingUtils.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
 					null
 				)
 			);
@@ -105,7 +108,7 @@ public class CheckpointMessagesTest {
 		assertNotNull(copy.toString());
 	}
 
-	public static class MyHandle implements StreamStateHandle {
+	private static class MyHandle implements StreamStateHandle {
 
 		private static final long serialVersionUID = 8128146204128728332L;