Posted to commits@flink.apache.org by al...@apache.org on 2016/08/31 17:28:42 UTC

[24/27] flink git commit: [FLINK-4380] Add tests for new Key-Group/Max-Parallelism

[FLINK-4380] Add tests for new Key-Group/Max-Parallelism

This tests the rescaling features in CheckpointCoordinator and
SavepointCoordinator.
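
The key piece of machinery these tests exercise is the mapping between key
groups and operator subtasks. As a minimal standalone sketch, assuming the
arithmetic matches CheckpointCoordinator.createKeyGroupPartitions and
KeyGroupRange.computeOperatorIndexForKeyGroup at the time of this commit (the
class and helper names below are illustrative only):

    import java.util.ArrayList;
    import java.util.List;

    public class KeyGroupPartitionSketch {

        // Contiguous [start, end] key-group range owned by one subtask.
        static int[] rangeForOperator(int maxParallelism, int parallelism, int operatorIndex) {
            int start = (operatorIndex * maxParallelism + parallelism - 1) / parallelism;
            int end = ((operatorIndex + 1) * maxParallelism - 1) / parallelism;
            return new int[] { start, end };
        }

        // Inverse mapping: which subtask owns a given key group.
        static int operatorForKeyGroup(int maxParallelism, int parallelism, int keyGroup) {
            return keyGroup * parallelism / maxParallelism;
        }

        public static void main(String[] args) {
            int maxParallelism = 13;
            int parallelism = 2;
            List<int[]> ranges = new ArrayList<>(parallelism);
            for (int i = 0; i < parallelism; i++) {
                ranges.add(rangeForOperator(maxParallelism, parallelism, i));
            }
            // Round-trip check, as in testCreateKeyGroupPartitions below: every
            // key group must lie in the range of the subtask the inverse names.
            for (int kg = 0; kg < maxParallelism; kg++) {
                int[] r = ranges.get(operatorForKeyGroup(maxParallelism, parallelism, kg));
                if (kg < r[0] || kg > r[1]) {
                    throw new IllegalStateException("key group " + kg + " not in its range");
                }
            }
            System.out.println("all " + maxParallelism + " key groups assigned consistently");
        }
    }

With maxParallelism = 13 and parallelism = 2 this yields the ranges [0, 6] and
[7, 12]; testCreateKeyGroupPartitions below performs the same round-trip check
against CheckpointCoordinator.createKeyGroupPartitions.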


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/516ad011
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/516ad011
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/516ad011

Branch: refs/heads/master
Commit: 516ad011865ca5beece273ca9b985e2861b3435a
Parents: 847ead0
Author: Till Rohrmann <tr...@apache.org>
Authored: Thu Aug 11 12:14:18 2016 +0200
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Wed Aug 31 19:10:01 2016 +0200

----------------------------------------------------------------------
 .../checkpoint/CheckpointCoordinatorTest.java   | 733 ++++++++++++++++++-
 .../runtime/tasks/OneInputStreamTaskTest.java   | 280 ++++++-
 .../test/checkpointing/RescalingITCase.java     |   1 -
 3 files changed, 1007 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
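
Every new coordinator test below follows the same trigger/acknowledge/restore
cycle. As a minimal sketch of that cycle, under the assumption that a
checkpoint completes only once every subtask has acknowledged with its state
(all names below are illustrative, none are Flink API):

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Map;
    import java.util.Set;

    public class AckProtocolSketch {

        public static void main(String[] args) {
            Set<String> pendingAcks = new HashSet<>(Arrays.asList("subtask-0", "subtask-1", "subtask-2"));
            Map<String, String> collectedState = new HashMap<>();

            // Every subtask acknowledges the triggered checkpoint with its state.
            for (String subtask : new HashSet<>(pendingAcks)) {
                collectedState.put(subtask, "state-of-" + subtask);
                pendingAcks.remove(subtask);
            }

            // Only a fully acknowledged checkpoint completes and can be restored;
            // restore hands each subtask its state back, which the tests verify.
            if (!pendingAcks.isEmpty()) {
                throw new IllegalStateException("checkpoint incomplete: " + pendingAcks);
            }
            for (Map.Entry<String, String> e : collectedState.entrySet()) {
                System.out.println("restoring " + e.getValue() + " to " + e.getKey());
            }
        }
    }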


http://git-wip-us.apache.org/repos/asf/flink/blob/516ad011/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 50330fa..495dced 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -18,28 +18,45 @@
 
 package org.apache.flink.runtime.checkpoint;
 
+import com.google.common.collect.Iterables;
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.runtime.checkpoint.savepoint.HeapSavepointStore;
 import org.apache.flink.runtime.checkpoint.stats.DisabledCheckpointStatsTracker;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
+import org.junit.Assert;
 import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
+import scala.concurrent.ExecutionContext;
 import scala.concurrent.Future;
 
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -47,12 +64,14 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -124,7 +143,7 @@ public class CheckpointCoordinatorTest {
 			final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID();
 			final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID();
 			ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1);
-			ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2, ExecutionState.FINISHED);
+			ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2, new JobVertexID(), 1, 1, ExecutionState.FINISHED);
 
 			// create some mock Execution vertices that need to ack the checkpoint
 			final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID();
@@ -1529,7 +1548,7 @@ public class CheckpointCoordinatorTest {
 			coord.startCheckpointScheduler();
 
 			// after a while, there should be exactly as many checkpoints
-			// as concurrently permitted 
+			// as concurrently permitted
 			long now = System.currentTimeMillis();
 			long timeout = now + 60000;
 			long minDuration = now + 100;
@@ -1622,7 +1641,7 @@ public class CheckpointCoordinatorTest {
 			}
 			while (System.currentTimeMillis() < timeout && 
 					coord.getNumberOfPendingCheckpoints() == 0);
-			
+
 			assertTrue(coord.getNumberOfPendingCheckpoints() > 0);
 		}
 		catch (Exception e) {
@@ -1738,4 +1757,712 @@ public class CheckpointCoordinatorTest {
 
 		return vertex;
 	}
+	/**
+	 * Tests that the checkpointed partitioned and non-partitioned state is assigned properly to
+	 * the {@link Execution} upon recovery.
+	 *
+	 * @throws Exception
+	 */
+	@Test
+	public void testRestoreLatestCheckpointedState() throws Exception {
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+		final JobVertexID jobVertexID2 = new JobVertexID();
+		int parallelism1 = 3;
+		int parallelism2 = 2;
+		int maxParallelism1 = 42;
+		int maxParallelism2 = 13;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			maxParallelism2);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]);
+
+		// set up the coordinator and validate the initial state
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			cl,
+			new StandaloneCheckpointIDCounter(),
+			new StandaloneCompletedCheckpointStore(1, cl),
+			new HeapSavepointStore(),
+			new DisabledCheckpointStatsTracker());
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp);
+
+		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		List<KeyGroupRange> keyGroupPartitions1 = coord.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 = coord.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
+			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID1, index);
+			List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				nonPartitionedState,
+				partitionedKeyGroupState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+
+
+		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
+			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID2, index);
+			List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				nonPartitionedState,
+				partitionedKeyGroupState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+
+		assertEquals(1, completedCheckpoints.size());
+
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		tasks.put(jobVertexID1, jobVertex1);
+		tasks.put(jobVertexID2, jobVertex2);
+
+		coord.restoreLatestCheckpointedState(tasks, true, true);
+
+		// verify the restored state
+		verifyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1);
+		verifyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2);
+	}
+
+	/**
+	 * Tests that the checkpoint restoration fails if the max parallelism of the job vertices has
+	 * changed.
+	 *
+	 * @throws Exception
+	 */
+	@Test(expected=IllegalStateException.class)
+	public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws Exception {
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+		final JobVertexID jobVertexID2 = new JobVertexID();
+		int parallelism1 = 3;
+		int parallelism2 = 2;
+		int maxParallelism1 = 42;
+		int maxParallelism2 = 13;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			maxParallelism2);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]);
+
+		// set up the coordinator and validate the initial state
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			cl,
+			new StandaloneCheckpointIDCounter(),
+			new StandaloneCompletedCheckpointStore(1, cl),
+			new HeapSavepointStore(),
+			new DisabledCheckpointStatsTracker());
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp);
+
+		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		List<KeyGroupRange> keyGroupPartitions1 = coord.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 = coord.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
+			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
+			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				valueSizeTuple,
+				keyGroupState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+
+
+		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
+			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID2, index);
+			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				valueSizeTuple,
+				keyGroupState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+
+		assertEquals(1, completedCheckpoints.size());
+
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		int newMaxParallelism1 = 20;
+		int newMaxParallelism2 = 42;
+
+		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			newMaxParallelism1);
+
+		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			newMaxParallelism2);
+
+		tasks.put(jobVertexID1, newJobVertex1);
+		tasks.put(jobVertexID2, newJobVertex2);
+
+		coord.restoreLatestCheckpointedState(tasks, true, true);
+
+		fail("The restoration should have failed because the max parallelism changed.");
+	}
+
+	/**
+	 * Tests that the checkpoint restoration fails if the parallelism of a job vertex with
+	 * non-partitioned state has changed.
+	 *
+	 * @throws Exception
+	 */
+	@Test(expected=IllegalStateException.class)
+	public void testRestoreLatestCheckpointFailureWhenParallelismChanges() throws Exception {
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+		final JobVertexID jobVertexID2 = new JobVertexID();
+		int parallelism1 = 3;
+		int parallelism2 = 2;
+		int maxParallelism1 = 42;
+		int maxParallelism2 = 13;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			maxParallelism2);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]);
+
+		// set up the coordinator and validate the initial state
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			cl,
+			new StandaloneCheckpointIDCounter(),
+			new StandaloneCompletedCheckpointStore(1, cl),
+			new HeapSavepointStore(),
+			new DisabledCheckpointStatsTracker());
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp);
+
+		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		List<KeyGroupRange> keyGroupPartitions1 = coord.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 = coord.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
+			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
+			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
+					jobVertexID1, keyGroupPartitions1.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				valueSizeTuple,
+				keyGroupState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+
+
+		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
+
+			ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID2, index);
+			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
+					jobVertexID2, keyGroupPartitions2.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+					jid,
+					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					state,
+					keyGroupState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+
+		assertEquals(1, completedCheckpoints.size());
+
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		int newParallelism1 = 4;
+		int newParallelism2 = 3;
+
+		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			newParallelism1,
+			maxParallelism1);
+
+		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			newParallelism2,
+			maxParallelism2);
+
+		tasks.put(jobVertexID1, newJobVertex1);
+		tasks.put(jobVertexID2, newJobVertex2);
+
+		coord.restoreLatestCheckpointedState(tasks, true, true);
+
+		fail("The restoration should have failed because the parallelism of an vertex with " +
+			"non-partitioned state changed.");
+	}
+
+	/**
+	 * Tests the checkpoint restoration with changing parallelism of job vertex with partitioned
+	 * state.
+	 *
+	 * @throws Exception
+	 */
+	@Test
+	public void testRestoreLatestCheckpointedStateWithChangingParallelism() throws Exception {
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+		final JobVertexID jobVertexID2 = new JobVertexID();
+		int parallelism1 = 3;
+		int parallelism2 = 2;
+		int maxParallelism1 = 42;
+		int maxParallelism2 = 13;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+				jobVertexID1,
+				parallelism1,
+				maxParallelism1);
+		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
+				jobVertexID2,
+				parallelism2,
+				maxParallelism2);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]);
+
+		// set up the coordinator and validate the initial state
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+				jid,
+				600000,
+				600000,
+				0,
+				Integer.MAX_VALUE,
+				arrayExecutionVertices,
+				arrayExecutionVertices,
+				arrayExecutionVertices,
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(1, cl),
+				new HeapSavepointStore(),
+				new DisabledCheckpointStatsTracker());
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp);
+
+		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		List<KeyGroupRange> keyGroupPartitions1 = coord.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 = coord.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
+			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
+			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+					jid,
+					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					valueSizeTuple,
+					keyGroupState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+
+
+		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
+			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+					jid,
+					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					null,
+					keyGroupState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+
+		assertEquals(1, completedCheckpoints.size());
+
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		int newParallelism2 = 13;
+
+		List<KeyGroupRange> newKeyGroupPartitions2 = coord.createKeyGroupPartitions(maxParallelism2, newParallelism2);
+
+		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
+				jobVertexID1,
+				parallelism1,
+				maxParallelism1);
+
+		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
+				jobVertexID2,
+				newParallelism2,
+				maxParallelism2);
+
+		tasks.put(jobVertexID1, newJobVertex1);
+		tasks.put(jobVertexID2, newJobVertex2);
+		coord.restoreLatestCheckpointedState(tasks, true, true);
+
+		// verify the restored state
+		verifyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
+
+		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
+			List<KeyGroupsStateHandle> originalKeyGroupState = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i));
+
+			ChainedStateHandle<StreamStateHandle> operatorState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
+			List<KeyGroupsStateHandle> keyGroupState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
+
+			assertNull(operatorState);
+			comparePartitionedState(originalKeyGroupState, keyGroupState);
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	//  Utilities
+	// ------------------------------------------------------------------------
+
+	static void sendAckMessageToCoordinator(
+			CheckpointCoordinator coord,
+			long checkpointId,
+			JobID jid,
+			ExecutionJobVertex jobVertex,
+			JobVertexID jobVertexID,
+			List<KeyGroupRange> keyGroupPartitions) throws Exception {
+
+		for (int index = 0; index < jobVertex.getParallelism(); index++) {
+			ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID, index);
+			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
+					jobVertexID,
+					keyGroupPartitions.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+					jid,
+					jobVertex.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					state,
+					keyGroupState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+	}
+
+	public static List<KeyGroupsStateHandle> generateKeyGroupState(
+			JobVertexID jobVertexID,
+			KeyGroupRange keyGroupPartition) throws IOException {
+
+		List<Integer> testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups());
+		// generate state for each key group in the partition
+		for (int keyGroupIndex : keyGroupPartition) {
+			Random random = new Random(jobVertexID.hashCode() + keyGroupIndex);
+			int simulatedStateValue = random.nextInt();
+			testStatesLists.add(simulatedStateValue);
+		}
+
+		return generateKeyGroupState(keyGroupPartition, testStatesLists);
+	}
+
+	public static List<KeyGroupsStateHandle> generateKeyGroupState(KeyGroupRange keyGroupRange, List<? extends Serializable> states) throws IOException {
+		Preconditions.checkArgument(keyGroupRange.getNumberOfKeyGroups() == states.size());
+
+		long[] offsets = new long[keyGroupRange.getNumberOfKeyGroups()];
+		List<byte[]> serializedGroupValues = new ArrayList<>(offsets.length);
+
+		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, offsets);
+
+		int runningGroupsOffset = 0;
+		// generate test state for all keygroups
+		int idx = 0;
+		for (int keyGroup : keyGroupRange) {
+			keyGroupRangeOffsets.setKeyGroupOffset(keyGroup, runningGroupsOffset);
+			byte[] serializedValue = InstantiationUtil.serializeObject(states.get(idx));
+			runningGroupsOffset += serializedValue.length;
+			serializedGroupValues.add(serializedValue);
+			++idx;
+		}
+
+		// write all generated values into a single byte array, indexed by keyGroupRangeOffsets
+		byte[] allSerializedValuesConcatenated = new byte[runningGroupsOffset];
+		runningGroupsOffset = 0;
+		for (byte[] serializedGroupValue : serializedGroupValues) {
+			System.arraycopy(
+					serializedGroupValue,
+					0,
+					allSerializedValuesConcatenated,
+					runningGroupsOffset,
+					serializedGroupValue.length);
+			runningGroupsOffset += serializedGroupValue.length;
+		}
+
+		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
+				allSerializedValuesConcatenated);
+		KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(
+				keyGroupRangeOffsets,
+				allSerializedStatesHandle);
+		List<KeyGroupsStateHandle> keyGroupsStateHandleList = new ArrayList<>();
+		keyGroupsStateHandleList.add(keyGroupsStateHandle);
+		return keyGroupsStateHandleList;
+	}
+
+	public static ChainedStateHandle<StreamStateHandle> generateStateForVertex(
+			JobVertexID jobVertexID,
+			int index) throws IOException {
+
+		Random random = new Random(jobVertexID.hashCode() + index);
+		int value = random.nextInt();
+		return generateChainedStateHandle(value);
+	}
+
+	public static ChainedStateHandle<StreamStateHandle> generateChainedStateHandle(
+			Serializable value) throws IOException {
+		return ChainedStateHandle.wrapSingleHandle(ByteStreamStateHandle.fromSerializable(value));
+	}
+
+	public static ExecutionJobVertex mockExecutionJobVertex(
+		JobVertexID jobVertexID,
+		int parallelism,
+		int maxParallelism) {
+		final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class);
+
+		ExecutionVertex[] executionVertices = new ExecutionVertex[parallelism];
+
+		for (int i = 0; i < parallelism; i++) {
+			executionVertices[i] = mockExecutionVertex(
+				new ExecutionAttemptID(),
+				jobVertexID,
+				parallelism,
+				maxParallelism,
+				ExecutionState.RUNNING);
+
+			when(executionVertices[i].getParallelSubtaskIndex()).thenReturn(i);
+		}
+
+		when(executionJobVertex.getJobVertexId()).thenReturn(jobVertexID);
+		when(executionJobVertex.getTaskVertices()).thenReturn(executionVertices);
+		when(executionJobVertex.getParallelism()).thenReturn(parallelism);
+		when(executionJobVertex.getMaxParallelism()).thenReturn(maxParallelism);
+
+		return executionJobVertex;
+	}
+
+	private static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) {
+		return mockExecutionVertex(
+			attemptID,
+			new JobVertexID(),
+			1,
+			1,
+			ExecutionState.RUNNING);
+	}
+
+	private static ExecutionVertex mockExecutionVertex(
+		ExecutionAttemptID attemptID,
+		JobVertexID jobVertexID,
+		int parallelism,
+		int maxParallelism,
+		ExecutionState state,
+		ExecutionState ... successiveStates) {
+
+		ExecutionVertex vertex = mock(ExecutionVertex.class);
+
+		final Execution exec = spy(new Execution(
+			mock(ExecutionContext.class),
+			vertex,
+			1,
+			1L,
+			null
+		));
+		when(exec.getAttemptId()).thenReturn(attemptID);
+		when(exec.getState()).thenReturn(state, successiveStates);
+
+		when(vertex.getJobvertexId()).thenReturn(jobVertexID);
+		when(vertex.getCurrentExecutionAttempt()).thenReturn(exec);
+		when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
+		when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
+
+		return vertex;
+	}
+
+	public static void verifyStateRestore(
+			JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
+			List<KeyGroupRange> keyGroupPartitions) throws Exception {
+
+		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
+
+			ChainedStateHandle<StreamStateHandle> expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
+			ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = executionJobVertex.
+					getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
+			assertEquals(expectNonPartitionedState.get(0), actualNonPartitionedState.get(0));
+
+			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState = generateKeyGroupState(
+					jobVertexID,
+					keyGroupPartitions.get(i));
+			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState = executionJobVertex.
+					getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
+			comparePartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState);
+		}
+	}
+
+	public static void comparePartitionedState(
+			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
+			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState) throws Exception {
+
+		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.get(0);
+		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getNumberOfKeyGroups();
+		int actualTotalKeyGroups = 0;
+		for (KeyGroupsStateHandle keyGroupsStateHandle : actualPartitionedKeyGroupState) {
+			actualTotalKeyGroups += keyGroupsStateHandle.getNumberOfKeyGroups();
+		}
+
+		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
+
+		FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.getStateHandle().openInputStream();
+		for (int groupId : expectedHeadOpKeyGroupStateHandle.keyGroups()) {
+			long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+			inputStream.seek(offset);
+			int expectedKeyGroupState = InstantiationUtil.deserializeObject(inputStream);
+			for (KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) {
+				if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
+					long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+					FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.getStateHandle().openInputStream();
+					actualInputStream.seek(actualOffset);
+					int actualGroupState = InstantiationUtil.deserializeObject(actualInputStream);
+
+					assertEquals(expectedKeyGroupState, actualGroupState);
+				}
+			}
+		}
+	}
+
+	@Test
+	public void testCreateKeyGroupPartitions() {
+		testCreateKeyGroupPartitions(1, 1);
+		testCreateKeyGroupPartitions(13, 1);
+		testCreateKeyGroupPartitions(13, 2);
+		testCreateKeyGroupPartitions(Short.MAX_VALUE, 1);
+		testCreateKeyGroupPartitions(Short.MAX_VALUE, 13);
+		testCreateKeyGroupPartitions(Short.MAX_VALUE, Short.MAX_VALUE);
+
+		Random r = new Random(1234);
+		for (int k = 0; k < 1000; ++k) {
+			int maxParallelism = 1 + r.nextInt(Short.MAX_VALUE - 1);
+			int parallelism = 1 + r.nextInt(maxParallelism);
+			testCreateKeyGroupPartitions(maxParallelism, parallelism);
+		}
+	}
+
+	private void testCreateKeyGroupPartitions(int maxParallelism, int parallelism) {
+		List<KeyGroupRange> ranges = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism, parallelism);
+		for (int i = 0; i < maxParallelism; ++i) {
+			KeyGroupRange range = ranges.get(KeyGroupRange.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, i));
+			if (!range.contains(i)) {
+				Assert.fail("Could not find expected key-group " + i + " in range " + range);
+			}
+		}
+	}
+
 }
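
The helpers generateKeyGroupState and comparePartitionedState above depend on
one layout convention: all key groups of a range are serialized back-to-back
into a single byte stream, and a per-group offset table allows seeking straight
to any group. A self-contained sketch of that convention in plain java.io
(illustrative names, not the Flink state handles):

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.ObjectInputStream;
    import java.io.ObjectOutputStream;

    public class KeyGroupOffsetsSketch {

        public static void main(String[] args) throws Exception {
            int[] keyGroups = { 7, 8, 9 };                // one subtask's range
            long[] offsets = new long[keyGroups.length];  // where each group starts
            ByteArrayOutputStream blob = new ByteArrayOutputStream();

            for (int i = 0; i < keyGroups.length; i++) {
                offsets[i] = blob.size();                 // record offset before writing
                ObjectOutputStream oos = new ObjectOutputStream(blob);
                oos.writeObject("state-of-group-" + keyGroups[i]);
                oos.flush();
            }
            byte[] allSerializedValuesConcatenated = blob.toByteArray();

            // Random access: seek to key group 8 (index 1) and read only that
            // group, as comparePartitionedState does via getOffsetForKeyGroup.
            ByteArrayInputStream in = new ByteArrayInputStream(allSerializedValuesConcatenated);
            in.skip(offsets[1]);                          // ByteArrayInputStream skips fully
            Object restored = new ObjectInputStream(in).readObject();
            System.out.println(restored);                 // prints state-of-group-8
        }
    }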

http://git-wip-us.apache.org/repos/asf/flink/blob/516ad011/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
index 5fcc59e..f757943 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
@@ -18,26 +18,54 @@
 package org.apache.flink.streaming.runtime.tasks;
 
 
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.graph.StreamEdge;
+import org.apache.flink.streaming.api.graph.StreamNode;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamMap;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.util.TestHarnessUtil;
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.TestLogger;
 import org.junit.Assert;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.powermock.core.classloader.annotations.PowerMockIgnore;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
+import scala.concurrent.duration.Deadline;
+import scala.concurrent.duration.FiniteDuration;
 
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.Random;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
 
 /**
  * Tests for {@link OneInputStreamTask}.
@@ -51,7 +79,7 @@ import java.util.concurrent.ConcurrentLinkedQueue;
 @RunWith(PowerMockRunner.class)
 @PrepareForTest({ResultPartitionWriter.class})
 @PowerMockIgnore({"javax.management.*", "com.sun.jndi.*"})
-public class OneInputStreamTaskTest {
+public class OneInputStreamTaskTest extends TestLogger {
 
 	/**
 	 * This test verifies that open() and close() are correctly called. This test also verifies
@@ -82,7 +110,7 @@ public class OneInputStreamTaskTest {
 
 		testHarness.waitForTaskCompletion();
 
-		Assert.assertTrue("RichFunction methods where not called.", TestOpenCloseMapFunction.closeCalled);
+		assertTrue("RichFunction methods where not called.", TestOpenCloseMapFunction.closeCalled);
 
 		TestHarnessUtil.assertOutputEquals("Output was not correct.",
 				expectedOutput,
@@ -165,7 +193,7 @@ public class OneInputStreamTaskTest {
 		testHarness.waitForTaskCompletion();
 
 		List<String> resultElements = TestHarnessUtil.getRawElementsFromOutput(testHarness.getOutput());
-		Assert.assertEquals(2, resultElements.size());
+		assertEquals(2, resultElements.size());
 	}
 
 	/**
@@ -293,6 +321,252 @@ public class OneInputStreamTaskTest {
 		TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
 	}
 
+	/**
+	 * Tests that the stream operator can snapshot and restore the operator state of chained
+	 * operators.
+	 */
+	@Test
+	public void testSnapshottingAndRestoring() throws Exception {
+		final Deadline deadline = new FiniteDuration(2, TimeUnit.MINUTES).fromNow();
+		final OneInputStreamTask<String, String> streamTask = new OneInputStreamTask<String, String>();
+		final OneInputStreamTaskTestHarness<String, String> testHarness = new OneInputStreamTaskTestHarness<String, String>(streamTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
+		IdentityKeySelector<String> keySelector = new IdentityKeySelector<>();
+		testHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO);
+
+		long checkpointId = 1L;
+		long checkpointTimestamp = 1L;
+		long recoveryTimestamp = 3L;
+		long seed = 2L;
+		int numberChainedTasks = 11;
+
+		StreamConfig streamConfig = testHarness.getStreamConfig();
+
+		configureChainedTestingStreamOperator(streamConfig, numberChainedTasks, seed, recoveryTimestamp);
+
+		AcknowledgeStreamMockEnvironment env = new AcknowledgeStreamMockEnvironment(
+			testHarness.jobConfig,
+			testHarness.taskConfig,
+			testHarness.executionConfig,
+			testHarness.memorySize,
+			new MockInputSplitProvider(),
+			testHarness.bufferSize);
+
+		// reset number of restore calls
+		TestingStreamOperator.numberRestoreCalls = 0;
+
+		testHarness.invoke(env);
+		testHarness.waitForTaskRunning(deadline.timeLeft().toMillis());
+
+		streamTask.triggerCheckpoint(checkpointId, checkpointTimestamp);
+
+		testHarness.endInput();
+		testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());
+
+		// since no state was set, there shouldn't be restore calls
+		assertEquals(0, TestingStreamOperator.numberRestoreCalls);
+
+		assertEquals(checkpointId, env.getCheckpointId());
+
+		final OneInputStreamTask<String, String> restoredTask = new OneInputStreamTask<String, String>();
+		restoredTask.setInitialState(env.getState(), env.getKeyGroupStates());
+
+		final OneInputStreamTaskTestHarness<String, String> restoredTaskHarness = new OneInputStreamTaskTestHarness<String, String>(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
+		restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO);
+
+		StreamConfig restoredTaskStreamConfig = restoredTaskHarness.getStreamConfig();
+
+		configureChainedTestingStreamOperator(restoredTaskStreamConfig, numberChainedTasks, seed, recoveryTimestamp);
+
+		TestingStreamOperator.numberRestoreCalls = 0;
+
+		restoredTaskHarness.invoke();
+		restoredTaskHarness.endInput();
+		restoredTaskHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());
+
+		// restore of every chained operator should have been called
+		assertEquals(numberChainedTasks, TestingStreamOperator.numberRestoreCalls);
+
+		TestingStreamOperator.numberRestoreCalls = 0;
+	}
+
+	//==============================================================================================
+	// Utility functions and classes
+	//==============================================================================================
+
+	private void configureChainedTestingStreamOperator(
+		StreamConfig streamConfig,
+		int numberChainedTasks,
+		long seed,
+		long recoveryTimestamp) {
+
+		Preconditions.checkArgument(numberChainedTasks >= 1, "The operator chain must at least " +
+			"contain one operator.");
+
+		Random random = new Random(seed);
+
+		TestingStreamOperator<Integer, Integer> previousOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp);
+		streamConfig.setStreamOperator(previousOperator);
+
+		// create the chain of operators
+		Map<Integer, StreamConfig> chainedTaskConfigs = new HashMap<>(numberChainedTasks - 1);
+		List<StreamEdge> outputEdges = new ArrayList<>(numberChainedTasks - 1);
+
+		for (int chainedIndex = 1; chainedIndex < numberChainedTasks; chainedIndex++) {
+			TestingStreamOperator<Integer, Integer> chainedOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp);
+			StreamConfig chainedConfig = new StreamConfig(new Configuration());
+			chainedConfig.setStreamOperator(chainedOperator);
+			chainedTaskConfigs.put(chainedIndex, chainedConfig);
+
+			StreamEdge outputEdge = new StreamEdge(
+				new StreamNode(
+					null,
+					chainedIndex - 1,
+					null,
+					null,
+					null,
+					null,
+					null
+				),
+				new StreamNode(
+					null,
+					chainedIndex,
+					null,
+					null,
+					null,
+					null,
+					null
+				),
+				0,
+				Collections.<String>emptyList(),
+				null
+			);
+
+			outputEdges.add(outputEdge);
+		}
+
+		streamConfig.setChainedOutputs(outputEdges);
+		streamConfig.setTransitiveChainedTaskConfigs(chainedTaskConfigs);
+	}
+
+	private static class IdentityKeySelector<IN> implements KeySelector<IN, IN> {
+
+		private static final long serialVersionUID = -3555913664416688425L;
+
+		@Override
+		public IN getKey(IN value) throws Exception {
+			return value;
+		}
+	}
+
+	private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment {
+		private long checkpointId;
+		private ChainedStateHandle<StreamStateHandle> state;
+		private List<KeyGroupsStateHandle> keyGroupStates;
+
+		public long getCheckpointId() {
+			return checkpointId;
+		}
+
+		public ChainedStateHandle<StreamStateHandle> getState() {
+			return state;
+		}
+
+		List<KeyGroupsStateHandle> getKeyGroupStates() {
+			List<KeyGroupsStateHandle> result = new ArrayList<>();
+			for (int i = 0; i < keyGroupStates.size(); i++) {
+				if (keyGroupStates.get(i) != null) {
+					result.add(keyGroupStates.get(i));
+				}
+			}
+			return result;
+		}
+
+		AcknowledgeStreamMockEnvironment(Configuration jobConfig, Configuration taskConfig,
+		                                 ExecutionConfig executionConfig, long memorySize,
+		                                 MockInputSplitProvider inputSplitProvider, int bufferSize) {
+			super(jobConfig, taskConfig, executionConfig, memorySize, inputSplitProvider, bufferSize);
+		}
+
+
+		@Override
+		public void acknowledgeCheckpoint(long checkpointId, ChainedStateHandle<StreamStateHandle> state,
+		                                  List<KeyGroupsStateHandle> keyGroupStates) {
+			this.checkpointId = checkpointId;
+			this.state = state;
+			this.keyGroupStates = keyGroupStates;
+		}
+	}
+
+	private static class TestingStreamOperator<IN, OUT>
+			extends AbstractStreamOperator<OUT> implements OneInputStreamOperator<IN, OUT> {
+
+		private static final long serialVersionUID = 774614855940397174L;
+
+		public static int numberRestoreCalls = 0;
+
+		private final long seed;
+		private final long recoveryTimestamp;
+
+		private transient Random random;
+
+		TestingStreamOperator(long seed, long recoveryTimestamp) {
+			this.seed = seed;
+			this.recoveryTimestamp = recoveryTimestamp;
+		}
+
+		@Override
+		public void processElement(StreamRecord<IN> element) throws Exception {
+
+		}
+
+		@Override
+		public void processWatermark(Watermark mark) throws Exception {
+
+		}
+
+		@Override
+		public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
+			if (random == null) {
+				random = new Random(seed);
+			}
+
+			Serializable functionState = generateFunctionState();
+			Integer operatorState = generateOperatorState();
+
+			InstantiationUtil.serializeObject(out, functionState);
+			InstantiationUtil.serializeObject(out, operatorState);
+		}
+
+		@Override
+		public void restoreState(FSDataInputStream in) throws Exception {
+			numberRestoreCalls++;
+
+			if (random == null) {
+				random = new Random(seed);
+			}
+
+			assertNotNull(in);
+
+			Serializable functionState = InstantiationUtil.deserializeObject(in);
+			Integer operatorState = InstantiationUtil.deserializeObject(in);
+
+			assertEquals(random.nextInt(), functionState);
+			assertEquals(random.nextInt(), (int) operatorState);
+		}
+
+
+		private Serializable generateFunctionState() {
+			return random.nextInt();
+		}
+
+		private Integer generateOperatorState() {
+			return random.nextInt();
+		}
+	}
+
+
 	// This must only be used in one test, otherwise the static fields will be changed
 	// by several tests concurrently
 	private static class TestOpenCloseMapFunction extends RichMapFunction<String, String> {
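
TestingStreamOperator above never stores expected values: snapshotState draws
state from a Random seeded per operator, and restoreState re-seeds an identical
Random to recompute what the deserialized state must be. A standalone sketch of
that seeded-verification trick (illustrative names):

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.ObjectInputStream;
    import java.io.ObjectOutputStream;
    import java.util.Random;

    public class SeededSnapshotSketch {

        static byte[] snapshot(long seed) throws IOException {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(out);
            oos.writeObject(new Random(seed).nextInt());  // the "operator state"
            oos.flush();
            return out.toByteArray();
        }

        static void restoreAndVerify(long seed, byte[] data) throws Exception {
            Object state = new ObjectInputStream(new ByteArrayInputStream(data)).readObject();
            int expected = new Random(seed).nextInt();    // same seed, same sequence
            if (!state.equals(expected)) {
                throw new IllegalStateException("restored state does not match seed");
            }
        }

        public static void main(String[] args) throws Exception {
            long seed = 2L;                               // the seed the test uses
            restoreAndVerify(seed, snapshot(seed));
            System.out.println("restored state verified against re-seeded Random");
        }
    }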

http://git-wip-us.apache.org/repos/asf/flink/blob/516ad011/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index 8d1baeb..39f3086 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -352,7 +352,6 @@ public class RescalingITCase extends TestLogger {
 			for (int key = 0; key < numberKeys; key++) {
 				int keyGroupIndex = keyGroupAssigner.getKeyGroupIndex(key);
 
-//				expectedResult.add(Tuple2.of(keyGroupIndex % parallelism, numberElements * key));
+				expectedResult.add(Tuple2.of(KeyGroupRange.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex), numberElements * key));
 			}
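
A note on the RescalingITCase hunk above: the removed, already commented-out
expectation assumed key groups are spread round-robin (keyGroupIndex %
parallelism), but key groups are assigned in contiguous ranges. For example,
with maxParallelism = 4 and parallelism = 2, key group 2 belongs to subtask
2 * 2 / 4 = 1, while the modulo rule would name subtask 0, so the expected
result has to go through KeyGroupRange.computeOperatorIndexForKeyGroup.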