You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tr...@apache.org on 2016/04/26 12:23:42 UTC

[2/2] flink git commit: [FLINK-3756] [state] Add state hierarchy to CheckpointCoordinator

[FLINK-3756] [state] Add state hierarchy to CheckpointCoordinator

This commit introduces a state hierarchy for the StateForTask objects kept
at the CheckpointCoordinator. Task states are now grouped together if they
belong to the same ExecutionJobVertex. The StateForTask objects are now
stored in so-called StateForTaskGroup objects. A StateForTaskGroup object
can also store the key group state handles associated with an ExecutionJobVertex.

Adapt restore methods of CheckpointCoordinator and SavepointCoordinator

Add state size computation

Add comments to createKeyGroupPartitions; add more information to StateForTask.toString

Rename StateForTaskGroup -> TaskState, StateForTask -> SubtaskState, KvStateForTasks -> KeyGroupState

This closes #1883.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0cf04108
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0cf04108
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0cf04108

Branch: refs/heads/master
Commit: 0cf04108f70375d41ebb7c39629db3a081bd2876
Parents: 7abf8ef
Author: Till Rohrmann <tr...@apache.org>
Authored: Fri Apr 8 15:20:06 2016 +0200
Committer: Till Rohrmann <tr...@apache.org>
Committed: Tue Apr 26 12:22:53 2016 +0200

----------------------------------------------------------------------
 .../checkpoint/CheckpointCoordinator.java       | 115 +++++---
 .../runtime/checkpoint/CompletedCheckpoint.java |  65 ++++-
 .../flink/runtime/checkpoint/KeyGroupState.java |  91 +++++++
 .../runtime/checkpoint/PendingCheckpoint.java   |  81 ++++--
 .../checkpoint/SavepointCoordinator.java        |  67 +++--
 .../flink/runtime/checkpoint/StateForTask.java  | 141 ----------
 .../flink/runtime/checkpoint/SubtaskState.java  | 120 +++++++++
 .../flink/runtime/checkpoint/TaskState.java     | 171 ++++++++++++
 .../stats/SimpleCheckpointStatsTracker.java     |  33 ++-
 .../deployment/TaskDeploymentDescriptor.java    |  22 +-
 .../flink/runtime/executiongraph/Execution.java |  17 +-
 .../runtime/executiongraph/ExecutionGraph.java  |   3 +
 .../runtime/executiongraph/ExecutionVertex.java |  25 +-
 .../flink/runtime/jobmanager/JobManager.scala   |   9 +
 .../checkpoint/CheckpointCoordinatorTest.java   | 261 ++++++++++++-------
 .../checkpoint/CheckpointStateRestoreTest.java  |  85 +++---
 .../CompletedCheckpointStoreTest.java           |  34 +--
 ...ExecutionGraphCheckpointCoordinatorTest.java |   1 +
 .../checkpoint/SavepointCoordinatorTest.java    |  55 ++--
 .../stats/SimpleCheckpointStatsTrackerTest.java |  91 +++----
 .../test/checkpointing/SavepointITCase.java     |  52 ++--
 21 files changed, 1034 insertions(+), 505 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 8a856bf..3b6f764 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -37,15 +37,20 @@ import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.util.SerializedValue;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.ArrayDeque;
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.Timer;
 import java.util.TimerTask;
 import java.util.UUID;
@@ -146,12 +151,15 @@ public class CheckpointCoordinator {
 	/** Helper for tracking checkpoint statistics  */
 	private final CheckpointStatsTracker statsTracker;
 
+	protected final int numberKeyGroups; // total number of key groups; passed to createKeyGroupPartitions when restoring state
+
 	// --------------------------------------------------------------------------------------------
 
 	public CheckpointCoordinator(
 			JobID job,
 			long baseInterval,
 			long checkpointTimeout,
+			int numberKeyGroups,
 			ExecutionVertex[] tasksToTrigger,
 			ExecutionVertex[] tasksToWaitFor,
 			ExecutionVertex[] tasksToCommitTo,
@@ -160,7 +168,7 @@ public class CheckpointCoordinator {
 			CompletedCheckpointStore completedCheckpointStore,
 			RecoveryMode recoveryMode) throws Exception {
 
-		this(job, baseInterval, checkpointTimeout, 0L, Integer.MAX_VALUE,
+		this(job, baseInterval, checkpointTimeout, 0L, Integer.MAX_VALUE, numberKeyGroups,
 				tasksToTrigger, tasksToWaitFor, tasksToCommitTo,
 				userClassLoader, checkpointIDCounter, completedCheckpointStore, recoveryMode,
 				new DisabledCheckpointStatsTracker());
@@ -172,6 +180,7 @@ public class CheckpointCoordinator {
 			long checkpointTimeout,
 			long minPauseBetweenCheckpoints,
 			int maxConcurrentCheckpointAttempts,
+			int numberKeyGroups,
 			ExecutionVertex[] tasksToTrigger,
 			ExecutionVertex[] tasksToWaitFor,
 			ExecutionVertex[] tasksToCommitTo,
@@ -238,6 +247,8 @@ public class CheckpointCoordinator {
 		else {
 			this.shutdownHook = null;
 		}
+
+		this.numberKeyGroups = numberKeyGroups;
 	}
 
 	// --------------------------------------------------------------------------------------------
@@ -539,7 +550,6 @@ public class CheckpointCoordinator {
 
 		final long checkpointId = message.getCheckpointId();
 
-		CompletedCheckpoint completed = null;
 		PendingCheckpoint checkpoint;
 
 		// Flag indicating whether the ack message was for a known pending
@@ -640,7 +650,11 @@ public class CheckpointCoordinator {
 			if (checkpoint != null && !checkpoint.isDiscarded()) {
 				isPendingCheckpoint = true;
 
-				if (checkpoint.acknowledgeTask(message.getTaskExecutionId(), message.getState(), message.getStateSize())) {
+				if (checkpoint.acknowledgeTask(
+					message.getTaskExecutionId(),
+					message.getState(),
+					message.getStateSize(),
+					null)) { // TODO: Give KV-state to the acknowledgeTask method
 					if (checkpoint.isFullyAcknowledged()) {
 						completed = checkpoint.toCompletedCheckpoint();
 
@@ -648,7 +662,15 @@ public class CheckpointCoordinator {
 
 						LOG.info("Completed checkpoint " + checkpointId + " (in " +
 								completed.getDuration() + " ms)");
-						LOG.debug(completed.getStates().toString());
+
+						if (LOG.isDebugEnabled()) { // only build the per-vertex summary when it will actually be logged
+							StringBuilder builder = new StringBuilder();
+							for (Map.Entry<JobVertexID, TaskState> entry: completed.getTaskStates().entrySet()) {
+								builder.append("JobVertexID: ").append(entry.getKey()).append(" {").append(entry.getValue()).append("} "); // separator keeps entries apart
+							}
+
+							LOG.debug(builder.toString());
+						}
 
 						pendingCheckpoints.remove(checkpointId);
 						rememberRecentCheckpointId(checkpointId);
@@ -781,36 +803,46 @@ public class CheckpointCoordinator {
 
 			long recoveryTimestamp = System.currentTimeMillis();
 
-			if (allOrNothingState) {
-				Map<ExecutionJobVertex, Integer> stateCounts = new HashMap<ExecutionJobVertex, Integer>();
+			for (Map.Entry<JobVertexID, TaskState> taskGroupStateEntry: latest.getTaskStates().entrySet()) {
+				TaskState taskState = taskGroupStateEntry.getValue();
+				ExecutionJobVertex executionJobVertex = tasks.get(taskGroupStateEntry.getKey());
+
+				if (executionJobVertex != null) {
+					// subtask states are restored by index, so the parallelism must not have changed
+					if (taskState.getParallelism() != executionJobVertex.getParallelism()) {
+						throw new IllegalStateException("Cannot restore the latest checkpoint because " +
+							"the parallelism changed. The operator " + executionJobVertex.getJobVertexId() +
+							" has parallelism " + executionJobVertex.getParallelism() + " whereas the corresponding " +
+							"state object has a parallelism of " + taskState.getParallelism());
+					}
+
+					int counter = 0;
+
+					List<Set<Integer>> keyGroupPartitions = createKeyGroupPartitions(numberKeyGroups, executionJobVertex.getParallelism());
 
-				for (StateForTask state : latest.getStates()) {
-					ExecutionJobVertex vertex = tasks.get(state.getOperatorId());
-					Execution exec = vertex.getTaskVertices()[state.getSubtask()].getCurrentExecutionAttempt();
-					exec.setInitialState(state.getState(), recoveryTimestamp);
+					for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
+						SubtaskState subtaskState = taskState.getState(i);
+						SerializedValue<StateHandle<?>> state = null;
+
+						if (subtaskState != null) {
+							// count the subtasks that receive a state (checked by the all-or-nothing test below)
+							counter++;
+							state = subtaskState.getState();
+						}
 
-					Integer count = stateCounts.get(vertex);
-					if (count != null) {
-						stateCounts.put(vertex, count+1);
-					} else {
-						stateCounts.put(vertex, 1);
+						Map<Integer, SerializedValue<StateHandle<?>>> kvStateForTaskMap = taskState.getUnwrappedKvStates(keyGroupPartitions.get(i));
+
+						Execution currentExecutionAttempt = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt();
+						currentExecutionAttempt.setInitialState(state, kvStateForTaskMap, recoveryTimestamp);
 					}
-				}
 
-				// validate that either all task vertices have state, or none
-				for (Map.Entry<ExecutionJobVertex, Integer> entry : stateCounts.entrySet()) {
-					ExecutionJobVertex vertex = entry.getKey();
-					if (entry.getValue() != vertex.getParallelism()) {
-						throw new IllegalStateException(
-								"The checkpoint contained state only for a subset of tasks for vertex " + vertex);
+					if (allOrNothingState && counter > 0 && counter < executionJobVertex.getParallelism()) {
+						throw new IllegalStateException("The checkpoint contained state only for " +
+							"a subset of tasks for vertex " + executionJobVertex);
 					}
-				}
-			}
-			else {
-				for (StateForTask state : latest.getStates()) {
-					ExecutionJobVertex vertex = tasks.get(state.getOperatorId());
-					Execution exec = vertex.getTaskVertices()[state.getSubtask()].getCurrentExecutionAttempt();
-					exec.setInitialState(state.getState(), recoveryTimestamp);
+				} else {
+					throw new IllegalStateException("There is no execution job vertex for the job" +
+						" vertex ID " + taskGroupStateEntry.getKey());
 				}
 			}
 
@@ -818,6 +850,31 @@ public class CheckpointCoordinator {
 		}
 	}
 
+	/**
+	 * Groups the available key groups into one key group partition per parallel subtask. Key groups
+	 * are assigned round-robin: key group k belongs to the partition at index k % parallelism, and
+	 * the i-th set of the returned list is the key group partition of the i-th subtask.
+	 *
+	 * @param numberKeyGroups Number of available key groups (indexed from 0 to numberKeyGroups - 1)
+	 * @param parallelism Parallelism to generate the key group partitioning for
+	 * @return List of parallelism key group partitions; a partition is empty if its index is >= numberKeyGroups
+	 */
+	protected List<Set<Integer>> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
+		ArrayList<Set<Integer>> result = new ArrayList<>(parallelism);
+
+		for (int p = 0; p < parallelism; p++) {
+			HashSet<Integer> keyGroupPartition = new HashSet<>();
+
+			for (int k = p; k < numberKeyGroups; k += parallelism) { // visits every k with k % parallelism == p
+				keyGroupPartition.add(k);
+			}
+
+			result.add(keyGroupPartition);
+		}
+
+		return result;
+	}
+
 	// --------------------------------------------------------------------------------------------
 	//  Accessors
 	// --------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
index bce2ffd..f899ae1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
@@ -19,10 +19,12 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.util.Preconditions;
 
 import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.Map;
+import java.util.Objects;
 
 /**
  * A successful checkpoint describes a checkpoint after all required tasks acknowledged it (with their state)
@@ -42,20 +44,21 @@ public class CompletedCheckpoint implements Serializable {
 	/** The duration of the checkpoint (completion timestamp - trigger timestamp). */
 	private final long duration;
 
-	private final ArrayList<StateForTask> states;
+	/** States of the task groups belonging to this checkpoint, keyed by the JobVertexID of each group */
+	private final Map<JobVertexID, TaskState> taskStates;
 
 	public CompletedCheckpoint(
-			JobID job,
-			long checkpointID,
-			long timestamp,
-			long completionTimestamp,
-			ArrayList<StateForTask> states) {
+		JobID job,
+		long checkpointID,
+		long timestamp,
+		long completionTimestamp,
+		Map<JobVertexID, TaskState> taskStates) {
 
 		this.job = job;
 		this.checkpointID = checkpointID;
 		this.timestamp = timestamp;
 		this.duration = completionTimestamp - timestamp;
-		this.states = states;
+		this.taskStates = Preconditions.checkNotNull(taskStates);
 	}
 
 	public JobID getJobId() {
@@ -74,20 +77,56 @@ public class CompletedCheckpoint implements Serializable {
 		return duration;
 	}
 
-	public List<StateForTask> getStates() {
-		return states;
+	public long getStateSize() { // sum of the state sizes of all task states
+		long result = 0L;
+
+		for (TaskState taskState : taskStates.values()) {
+			result += taskState.getStateSize();
+		}
+
+		return result;
+	}
+
+	public Map<JobVertexID, TaskState> getTaskStates() {
+		return taskStates;
+	}
+
+	public TaskState getTaskState(JobVertexID jobVertexID) { // null if no state was recorded for the vertex
+		return taskStates.get(jobVertexID);
 	}
 
 	// --------------------------------------------------------------------------------------------
 	
 	public void discard(ClassLoader userClassLoader) {
-		for(StateForTask state: states){
+		for (TaskState state: taskStates.values()) {
 			state.discard(userClassLoader);
 		}
-		states.clear();
+
+		taskStates.clear();
 	}
 
 	// --------------------------------------------------------------------------------------------
+
+	@Override
+	public boolean equals(Object obj) {
+		if (obj instanceof CompletedCheckpoint) {
+			CompletedCheckpoint other = (CompletedCheckpoint) obj;
+
+			return Objects.equals(job, other.job) && checkpointID == other.checkpointID && // null-safe, like hashCode
+				timestamp == other.timestamp && duration == other.duration &&
+				taskStates.equals(other.taskStates);
+		} else {
+			return false;
+		}
+	}
+
+	@Override
+	public int hashCode() {
+		return (int) (this.checkpointID ^ this.checkpointID >>> 32) +
+			31 * ((int) (this.timestamp ^ this.timestamp >>> 32) +
+				31 * ((int) (this.duration ^ this.duration >>> 32) +
+					31 * Objects.hash(job, taskStates)));
+	}
 	
 	@Override
 	public String toString() {

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/KeyGroupState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/KeyGroupState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/KeyGroupState.java
new file mode 100644
index 0000000..f510151
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/KeyGroupState.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.util.SerializedValue;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Serializable;
+
+/**
+ * Simple container for the serialized state handle of a single key group, plus its state size and duration.
+ *
+ * The key group state handle is kept in serialized form because it can contain user code classes
+ * which might not be available on the JobManager.
+ */
+public class KeyGroupState implements Serializable {
+	private static final long serialVersionUID = -5926696455438467634L;
+
+	private static final Logger LOG = LoggerFactory.getLogger(KeyGroupState.class);
+
+	private final SerializedValue<StateHandle<?>> keyGroupState; // deserialized only in discard(), with the given class loader
+
+	private final long stateSize; // reported state size; NOTE(review): PendingCheckpoint currently always passes 0L
+
+	private final long duration; // as computed by PendingCheckpoint: currentTimeMillis() minus the checkpoint trigger timestamp
+
+	public KeyGroupState(SerializedValue<StateHandle<?>> keyGroupState, long stateSize, long duration) {
+		this.keyGroupState = keyGroupState;
+
+		this.stateSize = stateSize;
+
+		this.duration = duration;
+	}
+
+	public SerializedValue<StateHandle<?>> getKeyGroupState() {
+		return keyGroupState;
+	}
+
+	public long getDuration() {
+		return duration;
+	}
+
+	public long getStateSize() {
+		return stateSize;
+	}
+
+	public void discard(ClassLoader classLoader) { // best-effort: any failure is logged, not rethrown
+		try {
+			keyGroupState.deserializeValue(classLoader).discardState();
+		} catch (Exception e) {
+			LOG.warn("Failed to discard checkpoint state: " + this, e); // NOTE(review): no toString() override, so 'this' prints only the identity hash
+		}
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (obj instanceof KeyGroupState) {
+			KeyGroupState other = (KeyGroupState) obj;
+
+			return keyGroupState.equals(other.keyGroupState) && stateSize == other.stateSize && // NOTE(review): NPE if keyGroupState is null; the constructor does not reject null
+				duration == other.duration;
+		} else {
+			return false;
+		}
+	}
+
+	@Override
+	public int hashCode() {
+		return (int) (this.stateSize ^ this.stateSize >>> 32) +
+			31 * ((int) (this.duration ^ this.duration >>> 32) +
+				31 * keyGroupState.hashCode());
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index 2aa3469..850440e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -18,13 +18,13 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import java.util.ArrayList;
-import java.util.List;
+import java.util.HashMap;
 import java.util.Map;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.util.SerializedValue;
 
@@ -45,9 +45,9 @@ public class PendingCheckpoint {
 	private final long checkpointId;
 	
 	private final long checkpointTimestamp;
-	
-	private final List<StateForTask> collectedStates;
-	
+
+	private final Map<JobVertexID, TaskState> taskStates; // per-vertex states collected from acknowledgements
+
 	private final Map<ExecutionAttemptID, ExecutionVertex> notYetAcknowledgedTasks;
 	
 	private int numAcknowledgedTasks;
@@ -71,7 +71,7 @@ public class PendingCheckpoint {
 		this.checkpointTimestamp = checkpointTimestamp;
 		
 		this.notYetAcknowledgedTasks = verticesToConfirm;
-		this.collectedStates = new ArrayList<StateForTask>(notYetAcknowledgedTasks.size());
+		this.taskStates = new HashMap<>();
 	}
 	
 	// --------------------------------------------------------------------------------------------
@@ -96,7 +96,11 @@ public class PendingCheckpoint {
 	public int getNumberOfAcknowledgedTasks() {
 		return numAcknowledgedTasks;
 	}
-	
+
+	public Map<JobVertexID, TaskState> getTaskStates() { // NOTE(review): returns the internal map that acknowledgeTask mutates under lock
+		return taskStates;
+	}
+
 	public boolean isFullyAcknowledged() {
 		return this.notYetAcknowledgedTasks.isEmpty() && !discarded;
 	}
@@ -105,18 +109,18 @@ public class PendingCheckpoint {
 		return discarded;
 	}
 	
-	public List<StateForTask> getCollectedStates() {
-		return collectedStates;
-	}
-	
 	public CompletedCheckpoint toCompletedCheckpoint() {
 		synchronized (lock) {
 			if (discarded) {
 				throw new IllegalStateException("pending checkpoint is discarded");
 			}
 			if (notYetAcknowledgedTasks.isEmpty()) {
-				CompletedCheckpoint completed =  new CompletedCheckpoint(jobId, checkpointId,
-						checkpointTimestamp, System.currentTimeMillis(), new ArrayList<StateForTask>(collectedStates));
+				CompletedCheckpoint completed = new CompletedCheckpoint(
+					jobId,
+					checkpointId,
+					checkpointTimestamp,
+					System.currentTimeMillis(),
+					new HashMap<>(taskStates)); // copy: dispose() below clears taskStates
 				dispose(null, false);
 				
 				return completed;
@@ -130,7 +134,8 @@ public class PendingCheckpoint {
 	public boolean acknowledgeTask(
 			ExecutionAttemptID attemptID,
 			SerializedValue<StateHandle<?>> state,
-			long stateSize) {
+			long stateSize,
+			Map<Integer, SerializedValue<StateHandle<?>>> kvState) {
 
 		synchronized (lock) {
 			if (discarded) {
@@ -139,13 +144,43 @@ public class PendingCheckpoint {
 			
 			ExecutionVertex vertex = notYetAcknowledgedTasks.remove(attemptID);
 			if (vertex != null) {
-				if (state != null) {
-					collectedStates.add(new StateForTask(
-							state,
-							stateSize,
-							vertex.getJobvertexId(),
+				if (state != null || kvState != null) { // only record a TaskState if the task sent some state
+
+					JobVertexID jobVertexID = vertex.getJobvertexId();
+
+					TaskState taskState;
+
+					if (taskStates.containsKey(jobVertexID)) {
+						taskState = taskStates.get(jobVertexID);
+					} else {
+						taskState = new TaskState(jobVertexID, vertex.getTotalNumberOfParallelSubtasks());
+						taskStates.put(jobVertexID, taskState);
+					}
+
+					long duration = System.currentTimeMillis() - checkpointTimestamp; // elapsed time since the checkpoint was triggered
+
+					if (state != null) {
+						taskState.putState(
 							vertex.getParallelSubtaskIndex(),
-							System.currentTimeMillis() - checkpointTimestamp));
+							new SubtaskState(
+								state,
+								stateSize,
+								duration
+							)
+						);
+					}
+
+					if (kvState != null) {
+						for (Map.Entry<Integer, SerializedValue<StateHandle<?>>> entry : kvState.entrySet()) {
+							taskState.putKvState(
+								entry.getKey(),
+								new KeyGroupState(
+									entry.getValue(),
+									0L, // TODO: kvState carries no per-key-group size information yet
+									duration
+								));
+						}
+					}
 				}
 				numAcknowledgedTasks++;
 				return true;
@@ -168,11 +203,11 @@ public class PendingCheckpoint {
 			discarded = true;
 			numAcknowledgedTasks = -1;
 			if (releaseState) {
-				for (StateForTask state : collectedStates) {
-					state.discard(userClassLoader);
+				for (TaskState taskState : taskStates.values()) {
+					taskState.discard(userClassLoader);
 				}
 			}
-			collectedStates.clear();
+			taskStates.clear();
 			notYetAcknowledgedTasks.clear();
 		}
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SavepointCoordinator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SavepointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SavepointCoordinator.java
index 5008932..034eefe 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SavepointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SavepointCoordinator.java
@@ -30,6 +30,8 @@ import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.instance.AkkaActorGateway;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobmanager.RecoveryMode;
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.util.SerializedValue;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.concurrent.Future;
@@ -38,6 +40,7 @@ import scala.concurrent.Promise;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
 
@@ -80,6 +83,7 @@ public class SavepointCoordinator extends CheckpointCoordinator {
 			JobID jobId,
 			long baseInterval,
 			long checkpointTimeout,
+			int numberKeyGroups,
 			ExecutionVertex[] tasksToTrigger,
 			ExecutionVertex[] tasksToWaitFor,
 			ExecutionVertex[] tasksToCommitTo,
@@ -93,6 +97,7 @@ public class SavepointCoordinator extends CheckpointCoordinator {
 				checkpointTimeout,
 				0L,
 				Integer.MAX_VALUE,
+				numberKeyGroups,
 				tasksToTrigger,
 				tasksToWaitFor,
 				tasksToCommitTo,
@@ -193,35 +198,55 @@ public class SavepointCoordinator extends CheckpointCoordinator {
 
 			// Set the initial state of all tasks
 			LOG.debug("Rolling back individual operators.");
-			for (StateForTask state : checkpoint.getStates()) {
-				LOG.debug("Rolling back subtask {} of operator {}.",
-						state.getSubtask(), state.getOperatorId());
 
-				ExecutionJobVertex vertex = tasks.get(state.getOperatorId());
+			for (Map.Entry<JobVertexID, TaskState> taskStateEntry: checkpoint.getTaskStates().entrySet()) {
+				TaskState taskState = taskStateEntry.getValue();
+				ExecutionJobVertex executionJobVertex = tasks.get(taskStateEntry.getKey());
+
+				if (executionJobVertex != null) {
+					if (executionJobVertex.getParallelism() != taskState.getParallelism()) { // subtask states are mapped by index, so parallelism must match
+						String msg = String.format("Failed to rollback to savepoint %s. " +
+								"Parallelism mismatch between savepoint state and new program. " +
+								"Cannot map operator %s with parallelism %d to new program with " +
+								"parallelism %d. This indicates that the program has been changed " +
+								"in a non-compatible way after the savepoint.",
+							checkpoint,
+							taskStateEntry.getKey(),
+							taskState.getParallelism(),
+							executionJobVertex.getParallelism());
+
+						throw new IllegalStateException(msg);
+					}
+
+					List<Set<Integer>> keyGroupPartitions = createKeyGroupPartitions( // entry i holds the key groups of subtask i
+						numberKeyGroups,
+						executionJobVertex.getParallelism());
+
+					for (int i = 0; i < executionJobVertex.getTaskVertices().length; i++) {
+						SubtaskState subtaskState = taskState.getState(i);
+						SerializedValue<StateHandle<?>> state = null;
+
+						if (subtaskState != null) {
+							state = subtaskState.getState();
+						}
 
-				if (vertex == null) {
+						Map<Integer, SerializedValue<StateHandle<?>>> kvStateForTaskMap = taskState
+							.getUnwrappedKvStates(keyGroupPartitions.get(i));
+
+						Execution currentExecutionAttempt = executionJobVertex
+							.getTaskVertices()[i]
+							.getCurrentExecutionAttempt();
+
+						currentExecutionAttempt.setInitialState(state, kvStateForTaskMap, recoveryTimestamp); // state may be null for subtasks without operator state
+					}
+				} else {
 					String msg = String.format("Failed to rollback to savepoint %s. " +
 							"Cannot map old state for task %s to the new program. " +
 							"This indicates that the program has been changed in a " +
 							"non-compatible way  after the savepoint.", checkpoint,
-							state.getOperatorId());
+						taskStateEntry.getKey());
 					throw new IllegalStateException(msg);
 				}
-
-				if (state.getSubtask() >= vertex.getParallelism()) {
-					String msg = String.format("Failed to rollback to savepoint %s. " +
-							"Parallelism mismatch between savepoint state and new program. " +
-							"Cannot map subtask %d of operator %s to new program with " +
-							"parallelism %d. This indicates that the program has been changed " +
-							"in a non-compatible way after the savepoint.", checkpoint,
-							state.getSubtask(), state.getOperatorId(), vertex.getParallelism());
-					throw new IllegalStateException(msg);
-				}
-
-				Execution exec = vertex.getTaskVertices()[state.getSubtask()]
-						.getCurrentExecutionAttempt();
-
-				exec.setInitialState(state.getState(), recoveryTimestamp);
 			}
 
 			// Reset the checkpoint ID counter

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateForTask.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateForTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateForTask.java
deleted file mode 100644
index 57fdccd..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateForTask.java
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.checkpoint;
-
-import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.state.StateHandle;
-import org.apache.flink.util.SerializedValue;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.Serializable;
-
-import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkNotNull;
-
-/**
- * Simple bean to describe the state belonging to a parallel operator.
- * Since we hold the state across execution attempts, we identify a task by its
- * JobVertexId and subtask index.
- * 
- * The state itself is kept in serialized from, since the checkpoint coordinator itself
- * is never looking at it anyways and only sends it back out in case of a recovery.
- * Furthermore, the state may involve user-defined classes that are not accessible without
- * the respective classloader.
- */
-public class StateForTask implements Serializable {
-
-	private static final long serialVersionUID = -2394696997971923995L;
-
-	private static final Logger LOG = LoggerFactory.getLogger(StateForTask.class);
-
-	/** The state of the parallel operator */
-	private final SerializedValue<StateHandle<?>> state;
-
-	/**
-	 * The state size. This is also part of the deserialized state handle.
-	 * We store it here in order to not deserialize the state handle when
-	 * gathering stats.
-	 */
-	private final long stateSize;
-
-	/** The vertex id of the parallel operator */
-	private final JobVertexID operatorId;
-	
-	/** The index of the parallel subtask */
-	private final int subtask;
-
-	/** The duration of the acknowledged (ack timestamp - trigger timestamp). */
-	private final long duration;
-	
-	public StateForTask(
-			SerializedValue<StateHandle<?>> state,
-			long stateSize,
-			JobVertexID operatorId,
-			int subtask,
-			long duration) {
-
-		this.state = checkNotNull(state, "State");
-		// Sanity check and don't fail checkpoint because of this.
-		this.stateSize = stateSize >= 0 ? stateSize : 0;
-		this.operatorId = checkNotNull(operatorId, "Operator ID");
-
-		checkArgument(subtask >= 0, "Negative subtask index");
-		this.subtask = subtask;
-
-		this.duration = duration;
-	}
-
-	// --------------------------------------------------------------------------------------------
-	
-	public SerializedValue<StateHandle<?>> getState() {
-		return state;
-	}
-
-	public long getStateSize() {
-		return stateSize;
-	}
-
-	public JobVertexID getOperatorId() {
-		return operatorId;
-	}
-
-	public int getSubtask() {
-		return subtask;
-	}
-
-	public long getDuration() {
-		return duration;
-	}
-
-	public void discard(ClassLoader userClassLoader) {
-		try {
-			state.deserializeValue(userClassLoader).discardState();
-		} catch (Exception e) {
-			LOG.warn("Failed to discard checkpoint state: " + this, e);
-		}
-	}
-
-	// --------------------------------------------------------------------------------------------
-
-	@Override
-	public boolean equals(Object o) {
-		if (this == o) {
-			return true;
-		}
-		else if (o instanceof StateForTask) {
-			StateForTask that = (StateForTask) o;
-			return this.subtask == that.subtask && this.operatorId.equals(that.operatorId)
-					&& this.state.equals(that.state);
-		}
-		else {
-			return false;
-		}
-	}
-
-	@Override
-	public int hashCode() {
-		return state.hashCode() + 31 * operatorId.hashCode() + 43 * subtask;
-	}
-
-	@Override
-	public String toString() {
-		return String.format("StateForTask %s-%d : %s", operatorId, subtask, state);
-	}
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
new file mode 100644
index 0000000..2ad83b8
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.util.SerializedValue;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Serializable;
+
+import static com.google.common.base.Preconditions.checkNotNull;
+
+/**
+ * Simple bean to describe the state belonging to a single parallel subtask of an operator.
+ * It is part of the {@link TaskState}.
+ *
+ * <p>The state itself is kept in serialized form, since the checkpoint coordinator itself
+ * is never looking at it anyway and only sends it back out in case of a recovery.
+ * Furthermore, the state may involve user-defined classes that are not accessible without
+ * the respective classloader.
+ */
+public class SubtaskState implements Serializable {
+
+	private static final long serialVersionUID = -2394696997971923995L;
+
+	private static final Logger LOG = LoggerFactory.getLogger(SubtaskState.class);
+
+	/** The serialized state handle of the parallel subtask. */
+	private final SerializedValue<StateHandle<?>> state;
+
+	/**
+	 * The state size. This is also part of the deserialized state handle.
+	 * We store it here in order to not deserialize the state handle when
+	 * gathering stats.
+	 */
+	private final long stateSize;
+
+	/** The duration of the acknowledgement (ack timestamp - trigger timestamp). */
+	private final long duration;
+
+	/**
+	 * Creates the state of a single parallel subtask.
+	 *
+	 * @param state Serialized state handle of the subtask; must not be null
+	 * @param stateSize Size of the state; negative values are clamped to 0
+	 * @param duration Ack timestamp minus trigger timestamp of the checkpoint
+	 * @throws NullPointerException if {@code state} is null
+	 */
+	public SubtaskState(
+			SerializedValue<StateHandle<?>> state,
+			long stateSize,
+			long duration) {
+
+		this.state = checkNotNull(state, "State");
+		// Sanity check: a bogus size must not fail the checkpoint, so clamp it to 0.
+		this.stateSize = stateSize >= 0 ? stateSize : 0;
+		this.duration = duration;
+	}
+
+	// --------------------------------------------------------------------------------------------
+
+	/** Returns the serialized state handle of this subtask. */
+	public SerializedValue<StateHandle<?>> getState() {
+		return state;
+	}
+
+	/** Returns the state size; never negative. */
+	public long getStateSize() {
+		return stateSize;
+	}
+
+	/** Returns the ack timestamp minus the trigger timestamp of the checkpoint. */
+	public long getDuration() {
+		return duration;
+	}
+
+	/**
+	 * Deserializes the state handle with the given class loader and discards the state.
+	 * This is best-effort: failures are logged and not rethrown.
+	 *
+	 * @param userClassLoader Class loader used to deserialize the state handle
+	 */
+	public void discard(ClassLoader userClassLoader) {
+		try {
+			state.deserializeValue(userClassLoader).discardState();
+		} catch (Exception e) {
+			LOG.warn("Failed to discard checkpoint state: {}", this, e);
+		}
+	}
+
+	// --------------------------------------------------------------------------------------------
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o) {
+			return true;
+		}
+		else if (o instanceof SubtaskState) {
+			SubtaskState that = (SubtaskState) o;
+			return this.state.equals(that.state) && stateSize == that.stateSize &&
+				duration == that.duration;
+		}
+		else {
+			return false;
+		}
+	}
+
+	@Override
+	public int hashCode() {
+		return (int) (this.stateSize ^ (this.stateSize >>> 32)) +
+			31 * ((int) (this.duration ^ (this.duration >>> 32)) +
+				31 * state.hashCode());
+	}
+
+	@Override
+	public String toString() {
+		return String.format("SubtaskState(Size: %d, Duration: %d, State: %s)", stateSize, duration, state);
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
new file mode 100644
index 0000000..2d57021
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.util.SerializedValue;
+
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+
+/**
+ * Simple container class which contains the task state and key-value state handles for the sub
+ * tasks of a {@link org.apache.flink.runtime.jobgraph.JobVertex}.
+ *
+ * <p>This class basically groups all tasks and key groups belonging to the same job vertex together.
+ */
+public class TaskState implements Serializable {
+
+	private static final long serialVersionUID = -4845578005863201810L;
+
+	/** ID of the job vertex whose subtask and key group states are grouped here. */
+	private final JobVertexID jobVertexID;
+
+	/** Map of subtask states which can be accessed by their subtask index. */
+	private final Map<Integer, SubtaskState> subtaskStates;
+
+	/** Map of key-value states which can be accessed by their key group index. */
+	private final Map<Integer, KeyGroupState> kvStates;
+
+	/** Parallelism of the operator when it was checkpointed. */
+	private final int parallelism;
+
+	/**
+	 * Creates an empty task state.
+	 *
+	 * @param jobVertexID ID of the job vertex this state belongs to
+	 * @param parallelism Parallelism of the operator when it was checkpointed; valid subtask
+	 *                    indices are {@code [0, parallelism)}
+	 */
+	public TaskState(JobVertexID jobVertexID, int parallelism) {
+		this.jobVertexID = jobVertexID;
+		this.subtaskStates = new HashMap<>(parallelism);
+		this.kvStates = new HashMap<>();
+		this.parallelism = parallelism;
+	}
+
+	public JobVertexID getJobVertexID() {
+		return jobVertexID;
+	}
+
+	/**
+	 * Stores the state of the given subtask, replacing any state previously stored for it.
+	 *
+	 * @param subtaskIndex Index of the subtask
+	 * @param subtaskState State of the subtask
+	 * @throws IndexOutOfBoundsException if {@code subtaskIndex} is not in {@code [0, parallelism)}
+	 */
+	public void putState(int subtaskIndex, SubtaskState subtaskState) {
+		checkSubtaskIndex(subtaskIndex);
+		subtaskStates.put(subtaskIndex, subtaskState);
+	}
+
+	/**
+	 * Returns the state of the given subtask.
+	 *
+	 * @param subtaskIndex Index of the subtask
+	 * @return State of the subtask, or {@code null} if none has been stored for it
+	 * @throws IndexOutOfBoundsException if {@code subtaskIndex} is not in {@code [0, parallelism)}
+	 */
+	public SubtaskState getState(int subtaskIndex) {
+		checkSubtaskIndex(subtaskIndex);
+		return subtaskStates.get(subtaskIndex);
+	}
+
+	/** Returns a view of all stored subtask states (backed by the internal map). */
+	public Collection<SubtaskState> getStates() {
+		return subtaskStates.values();
+	}
+
+	/** Returns the summed size of all stored subtask states and key group states. */
+	public long getStateSize() {
+		long result = 0L;
+
+		for (SubtaskState subtaskState : subtaskStates.values()) {
+			result += subtaskState.getStateSize();
+		}
+
+		for (KeyGroupState keyGroupState : kvStates.values()) {
+			result += keyGroupState.getStateSize();
+		}
+
+		return result;
+	}
+
+	/** Returns the number of subtasks for which a state has been stored. */
+	public int getNumberCollectedStates() {
+		return subtaskStates.size();
+	}
+
+	public int getParallelism() {
+		return parallelism;
+	}
+
+	/** Stores the state of the given key group, replacing any state previously stored for it. */
+	public void putKvState(int keyGroupId, KeyGroupState keyGroupState) {
+		kvStates.put(keyGroupId, keyGroupState);
+	}
+
+	/** Returns the state of the given key group, or {@code null} if none has been stored. */
+	public KeyGroupState getKvState(int keyGroupId) {
+		return kvStates.get(keyGroupId);
+	}
+
+	/**
+	 * Retrieve the set of key-value state key groups specified by the given key group partition set.
+	 * The key groups are returned as a map where the key group index maps to the serialized state
+	 * handle of the key group. Key groups for which no state has been stored are omitted.
+	 *
+	 * @param keyGroupPartition Set of key group indices
+	 * @return Map of serialized key group state handles indexed by their key group index.
+	 */
+	public Map<Integer, SerializedValue<StateHandle<?>>> getUnwrappedKvStates(Set<Integer> keyGroupPartition) {
+		Map<Integer, SerializedValue<StateHandle<?>>> result = new HashMap<>(keyGroupPartition.size());
+
+		for (Integer keyGroupId : keyGroupPartition) {
+			KeyGroupState keyGroupState = kvStates.get(keyGroupId);
+
+			if (keyGroupState != null) {
+				result.put(keyGroupId, keyGroupState.getKeyGroupState());
+			}
+		}
+
+		return result;
+	}
+
+	/** Returns the number of key groups for which a state has been stored. */
+	public int getNumberCollectedKvStates() {
+		return kvStates.size();
+	}
+
+	/**
+	 * Discards all stored subtask and key group states by delegating to their discard methods.
+	 *
+	 * @param classLoader Class loader passed on to the individual discard calls
+	 */
+	public void discard(ClassLoader classLoader) {
+		for (SubtaskState subtaskState : subtaskStates.values()) {
+			subtaskState.discard(classLoader);
+		}
+
+		for (KeyGroupState keyGroupState : kvStates.values()) {
+			keyGroupState.discard(classLoader);
+		}
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (this == obj) {
+			return true;
+		} else if (obj instanceof TaskState) {
+			TaskState other = (TaskState) obj;
+
+			return Objects.equals(jobVertexID, other.jobVertexID) && parallelism == other.parallelism &&
+				subtaskStates.equals(other.subtaskStates) && kvStates.equals(other.kvStates);
+		} else {
+			return false;
+		}
+	}
+
+	@Override
+	public int hashCode() {
+		return parallelism + 31 * Objects.hash(jobVertexID, subtaskStates, kvStates);
+	}
+
+	/**
+	 * Checks that the given subtask index lies within {@code [0, parallelism)}.
+	 *
+	 * @throws IndexOutOfBoundsException if the index is out of bounds
+	 */
+	private void checkSubtaskIndex(int subtaskIndex) {
+		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
+			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
+				" is out of bounds. It must lie within [0, " + parallelism + ").");
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTracker.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTracker.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTracker.java
index aea18e9..a3d9a0e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTracker.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTracker.java
@@ -19,8 +19,9 @@
 package org.apache.flink.runtime.checkpoint.stats;
 
 import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
-import org.apache.flink.runtime.checkpoint.StateForTask;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.checkpoint.TaskState;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import scala.Option;
 
@@ -133,29 +134,27 @@ public class SimpleCheckpointStatsTracker implements CheckpointStatsTracker {
 		}
 
 		synchronized (statsLock) {
-			long overallStateSize = 0;
+			long overallStateSize = checkpoint.getStateSize();
 
 			// Operator stats
 			Map<JobVertexID, long[][]> statsForSubTasks = new HashMap<>();
 
-			for (StateForTask state : checkpoint.getStates()) {
-				// Job-level checkpoint size is sum of all state sizes
-				overallStateSize += state.getStateSize();
+			for (Map.Entry<JobVertexID, TaskState> taskStateEntry: checkpoint.getTaskStates().entrySet()) {
+				JobVertexID jobVertexID = taskStateEntry.getKey();
+				TaskState taskState = taskStateEntry.getValue();
 
-				// Subtask stats
-				JobVertexID opId = state.getOperatorId();
-				long[][] statsPerSubtask = statsForSubTasks.get(opId);
+				int parallelism = taskParallelism.get(jobVertexID);
+				long[][] statsPerSubtask = new long[parallelism][2];
 
-				if (statsPerSubtask == null) {
-					int parallelism = taskParallelism.get(opId);
-					statsPerSubtask = new long[parallelism][2];
-					statsForSubTasks.put(opId, statsPerSubtask);
-				}
+				statsForSubTasks.put(jobVertexID, statsPerSubtask);
+
+				for (int i = 0; i < Math.min(taskState.getParallelism(), parallelism); i++) {
+					SubtaskState subtaskState = taskState.getState(i);
 
-				int subTaskIndex = state.getSubtask();
-				if (subTaskIndex < statsPerSubtask.length) {
-					statsPerSubtask[subTaskIndex][0] = state.getDuration();
-					statsPerSubtask[subTaskIndex][1] = state.getStateSize();
+					if (subtaskState != null) {
+						statsPerSubtask[i][0] = subtaskState.getDuration();
+						statsPerSubtask[i][1] = subtaskState.getStateSize();
+					}
 				}
 			}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
index 73c37b2..948f6af 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
@@ -160,9 +160,25 @@ public final class TaskDeploymentDescriptor implements Serializable {
 		List<URL> requiredClasspaths,
 		int targetSlotNumber) {
 
-		this(jobID, vertexID, executionId, executionConfig, taskName, indexInSubtaskGroup,
-				numberOfSubtasks, attemptNumber, jobConfiguration, taskConfiguration, invokableClassName,
-				producedPartitions, inputGates, requiredJarFiles, requiredClasspaths, targetSlotNumber, null, -1);
+		this(
+			jobID,
+			vertexID,
+			executionId,
+			executionConfig,
+			taskName,
+			indexInSubtaskGroup,
+			numberOfSubtasks,
+			attemptNumber,
+			jobConfiguration,
+			taskConfiguration,
+			invokableClassName,
+			producedPartitions,
+			inputGates,
+			requiredJarFiles,
+			requiredClasspaths,
+			targetSlotNumber,
+			null,
+			-1);
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
index 6d5832b..a14248e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
@@ -137,6 +137,8 @@ public class Execution implements Serializable {
 	private volatile InstanceConnectionInfo assignedResourceLocation; // for the archived execution
 
 	private SerializedValue<StateHandle<?>> operatorState;
+
+	private Map<Integer, SerializedValue<StateHandle<?>>> operatorKvState;
 	
 	private long recoveryTimestamp;
 
@@ -235,11 +237,16 @@ public class Execution implements Serializable {
 		partialInputChannelDeploymentDescriptors = null;
 	}
 
-	public void setInitialState(SerializedValue<StateHandle<?>> initialState, long recoveryTimestamp) {
+	public void setInitialState(
+		SerializedValue<StateHandle<?>> initialState,
+		Map<Integer, SerializedValue<StateHandle<?>>> initialKvState,
+		long recoveryTimestamp) {
+
 		if (state != ExecutionState.CREATED) {
 			throw new IllegalArgumentException("Can only assign operator state when execution attempt is in CREATED");
 		}
 		this.operatorState = initialState;
+		this.operatorKvState = initialKvState;
 		this.recoveryTimestamp = recoveryTimestamp;
 	}
 
@@ -364,7 +371,13 @@ public class Execution implements Serializable {
 						attemptNumber, slot.getInstance().getInstanceConnectionInfo().getHostname()));
 			}
 
-			final TaskDeploymentDescriptor deployment = vertex.createDeploymentDescriptor(attemptId, slot, operatorState, recoveryTimestamp, attemptNumber);
+			final TaskDeploymentDescriptor deployment = vertex.createDeploymentDescriptor(
+				attemptId,
+				slot,
+				operatorState,
+				operatorKvState,
+				recoveryTimestamp,
+				attemptNumber);
 
 			// register this execution at the execution graph, to receive call backs
 			vertex.getExecutionGraph().registerExecution(this);

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
index 2ad7832..9ee8ee5 100755
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
@@ -345,6 +345,7 @@ public class ExecutionGraph implements Serializable {
 			long checkpointTimeout,
 			long minPauseBetweenCheckpoints,
 			int maxConcurrentCheckpoints,
+			int numberKeyGroups,
 			List<ExecutionJobVertex> verticesToTrigger,
 			List<ExecutionJobVertex> verticesToWaitFor,
 			List<ExecutionJobVertex> verticesToCommitTo,
@@ -380,6 +381,7 @@ public class ExecutionGraph implements Serializable {
 				checkpointTimeout,
 				minPauseBetweenCheckpoints,
 				maxConcurrentCheckpoints,
+				numberKeyGroups,
 				tasksToTrigger,
 				tasksToWaitFor,
 				tasksToCommitTo,
@@ -399,6 +401,7 @@ public class ExecutionGraph implements Serializable {
 				jobID,
 				interval,
 				checkpointTimeout,
+				numberKeyGroups,
 				tasksToTrigger,
 				tasksToWaitFor,
 				tasksToCommitTo,

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
index 80430bc..4d27423 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
@@ -637,6 +637,7 @@ public class ExecutionVertex implements Serializable {
 			ExecutionAttemptID executionId,
 			SimpleSlot targetSlot,
 			SerializedValue<StateHandle<?>> operatorState,
+			Map<Integer, SerializedValue<StateHandle<?>>> operatorKvState,
 			long recoveryTimestamp,
 			int attemptNumber) {
 
@@ -670,11 +671,25 @@ public class ExecutionVertex implements Serializable {
 		List<BlobKey> jarFiles = getExecutionGraph().getRequiredJarFiles();
 		List<URL> classpaths = getExecutionGraph().getRequiredClasspaths();
 
-		return new TaskDeploymentDescriptor(getJobId(), getJobvertexId(), executionId, config, getTaskName(),
-				subTaskIndex, getTotalNumberOfParallelSubtasks(), attemptNumber, getExecutionGraph().getJobConfiguration(),
-				jobVertex.getJobVertex().getConfiguration(), jobVertex.getJobVertex().getInvokableClassName(),
-				producedPartitions, consumedPartitions, jarFiles, classpaths, targetSlot.getRoot().getSlotNumber(),
-				operatorState, recoveryTimestamp);
+		return new TaskDeploymentDescriptor(
+			getJobId(),
+			getJobvertexId(),
+			executionId,
+			config,
+			getTaskName(),
+			subTaskIndex,
+			getTotalNumberOfParallelSubtasks(),
+			attemptNumber,
+			getExecutionGraph().getJobConfiguration(),
+			jobVertex.getJobVertex().getConfiguration(),
+			jobVertex.getJobVertex().getInvokableClassName(),
+			producedPartitions,
+			consumedPartitions,
+			jarFiles,
+			classpaths,
+			targetSlot.getRoot().getSlotNumber(),
+			operatorState,
+			recoveryTimestamp);
 	}
 
 	// --------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
index a8f1a0a..9f0482e 100644
--- a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
+++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
@@ -1196,11 +1196,20 @@ class JobManager(
               new SimpleCheckpointStatsTracker(historySize, ackVertices)
             }
 
+          val jobParallelism = jobGraph.getExecutionConfig.getParallelism()
+
+          val parallelism = if (jobParallelism == ExecutionConfig.PARALLELISM_AUTO_MAX) {
+            numSlots
+          } else {
+            jobGraph.getExecutionConfig.getParallelism
+          }
+
           executionGraph.enableSnapshotCheckpointing(
             snapshotSettings.getCheckpointInterval,
             snapshotSettings.getCheckpointTimeout,
             snapshotSettings.getMinPauseBetweenCheckpoints,
             snapshotSettings.getMaxConcurrentCheckpoints,
+            parallelism,
             triggerVertices,
             ackVertices,
             confirmVertices,

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 266cf52..62af42b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -80,12 +80,17 @@ public class CheckpointCoordinatorTest {
 
 			// set up the coordinator and validate the initial state
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid, 600000, 600000,
-					new ExecutionVertex[] { triggerVertex1, triggerVertex2 },
-					new ExecutionVertex[] { ackVertex1, ackVertex2 },
-					new ExecutionVertex[] {}, cl,
-					new StandaloneCheckpointIDCounter(),
-					new StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE);
+				jid,
+				600000,
+				600000,
+				42,
+				new ExecutionVertex[] { triggerVertex1, triggerVertex2 },
+				new ExecutionVertex[] { ackVertex1, ackVertex2 },
+				new ExecutionVertex[] {},
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(1, cl),
+				RecoveryMode.STANDALONE);
 
 			// nothing should be happening
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
@@ -126,12 +131,17 @@ public class CheckpointCoordinatorTest {
 
 			// set up the coordinator and validate the initial state
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid, 600000, 600000,
-					new ExecutionVertex[] { triggerVertex1, triggerVertex2 },
-					new ExecutionVertex[] { ackVertex1, ackVertex2 },
-					new ExecutionVertex[] {}, cl,
-					new StandaloneCheckpointIDCounter(),
-					new StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE);
+				jid,
+				600000,
+				600000,
+				42,
+				new ExecutionVertex[] { triggerVertex1, triggerVertex2 },
+				new ExecutionVertex[] { ackVertex1, ackVertex2 },
+				new ExecutionVertex[] {},
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(1, cl),
+				RecoveryMode.STANDALONE);
 
 			// nothing should be happening
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
@@ -170,11 +180,17 @@ public class CheckpointCoordinatorTest {
 
 			// set up the coordinator and validate the initial state
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid, 600000, 600000,
-					new ExecutionVertex[] { triggerVertex1, triggerVertex2 },
-					new ExecutionVertex[] { ackVertex1, ackVertex2 },
-					new ExecutionVertex[] {}, cl, new StandaloneCheckpointIDCounter(), new
-					StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE);
+				jid,
+				600000,
+				600000,
+				42,
+				new ExecutionVertex[] { triggerVertex1, triggerVertex2 },
+				new ExecutionVertex[] { ackVertex1, ackVertex2 },
+				new ExecutionVertex[] {},
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(1, cl),
+				RecoveryMode.STANDALONE);
 
 			// nothing should be happening
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
@@ -214,13 +230,17 @@ public class CheckpointCoordinatorTest {
 
 			// set up the coordinator and validate the initial state
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-				jid, 600000, 600000,
+				jid,
+				600000,
+				600000,
+				42,
 				new ExecutionVertex[] { vertex1, vertex2 },
 				new ExecutionVertex[] { vertex1, vertex2 },
 				new ExecutionVertex[] { vertex1, vertex2 },
 				cl,
 				new StandaloneCheckpointIDCounter(),
-				new StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE);
+				new StandaloneCompletedCheckpointStore(1, cl),
+				RecoveryMode.STANDALONE);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -241,7 +261,7 @@ public class CheckpointCoordinatorTest {
 			assertEquals(jid, checkpoint.getJobId());
 			assertEquals(2, checkpoint.getNumberOfNonAcknowledgedTasks());
 			assertEquals(0, checkpoint.getNumberOfAcknowledgedTasks());
-			assertEquals(0, checkpoint.getCollectedStates().size());
+			assertEquals(0, checkpoint.getTaskStates().size());
 			assertFalse(checkpoint.isDiscarded());
 			assertFalse(checkpoint.isFullyAcknowledged());
 
@@ -283,7 +303,7 @@ public class CheckpointCoordinatorTest {
 			assertEquals(jid, checkpointNew.getJobId());
 			assertEquals(2, checkpointNew.getNumberOfNonAcknowledgedTasks());
 			assertEquals(0, checkpointNew.getNumberOfAcknowledgedTasks());
-			assertEquals(0, checkpointNew.getCollectedStates().size());
+			assertEquals(0, checkpointNew.getTaskStates().size());
 			assertFalse(checkpointNew.isDiscarded());
 			assertFalse(checkpointNew.isFullyAcknowledged());
 			assertNotEquals(checkpoint.getCheckpointId(), checkpointNew.getCheckpointId());
@@ -333,13 +353,17 @@ public class CheckpointCoordinatorTest {
 
 			// set up the coordinator and validate the initial state
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-				jid, 600000, 600000,
+				jid,
+				600000,
+				600000,
+				42,
 				new ExecutionVertex[] { vertex1, vertex2 },
 				new ExecutionVertex[] { vertex1, vertex2 },
 				new ExecutionVertex[] { vertex1, vertex2 },
 				cl,
 				new StandaloneCheckpointIDCounter(),
-				new StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE);
+				new StandaloneCompletedCheckpointStore(1, cl),
+				RecoveryMode.STANDALONE);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -366,7 +390,7 @@ public class CheckpointCoordinatorTest {
 			assertEquals(jid, checkpoint1.getJobId());
 			assertEquals(2, checkpoint1.getNumberOfNonAcknowledgedTasks());
 			assertEquals(0, checkpoint1.getNumberOfAcknowledgedTasks());
-			assertEquals(0, checkpoint1.getCollectedStates().size());
+			assertEquals(0, checkpoint1.getTaskStates().size());
 			assertFalse(checkpoint1.isDiscarded());
 			assertFalse(checkpoint1.isFullyAcknowledged());
 
@@ -376,7 +400,7 @@ public class CheckpointCoordinatorTest {
 			assertEquals(jid, checkpoint2.getJobId());
 			assertEquals(2, checkpoint2.getNumberOfNonAcknowledgedTasks());
 			assertEquals(0, checkpoint2.getNumberOfAcknowledgedTasks());
-			assertEquals(0, checkpoint2.getCollectedStates().size());
+			assertEquals(0, checkpoint2.getTaskStates().size());
 			assertFalse(checkpoint2.isDiscarded());
 			assertFalse(checkpoint2.isFullyAcknowledged());
 
@@ -415,7 +439,7 @@ public class CheckpointCoordinatorTest {
 			assertEquals(jid, checkpointNew.getJobId());
 			assertEquals(2, checkpointNew.getNumberOfNonAcknowledgedTasks());
 			assertEquals(0, checkpointNew.getNumberOfAcknowledgedTasks());
-			assertEquals(0, checkpointNew.getCollectedStates().size());
+			assertEquals(0, checkpointNew.getTaskStates().size());
 			assertFalse(checkpointNew.isDiscarded());
 			assertFalse(checkpointNew.isFullyAcknowledged());
 			assertNotEquals(checkpoint1.getCheckpointId(), checkpointNew.getCheckpointId());
@@ -448,12 +472,17 @@ public class CheckpointCoordinatorTest {
 
 			// set up the coordinator and validate the initial state
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid, 600000, 600000,
-					new ExecutionVertex[] { vertex1, vertex2 },
-					new ExecutionVertex[] { vertex1, vertex2 },
-					new ExecutionVertex[] { vertex1, vertex2 }, cl,
-					new StandaloneCheckpointIDCounter(),
-					new StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE);
+				jid,
+				600000,
+				600000,
+				42,
+				new ExecutionVertex[] { vertex1, vertex2 },
+				new ExecutionVertex[] { vertex1, vertex2 },
+				new ExecutionVertex[] { vertex1, vertex2 },
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(1, cl),
+				RecoveryMode.STANDALONE);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -474,7 +503,7 @@ public class CheckpointCoordinatorTest {
 			assertEquals(jid, checkpoint.getJobId());
 			assertEquals(2, checkpoint.getNumberOfNonAcknowledgedTasks());
 			assertEquals(0, checkpoint.getNumberOfAcknowledgedTasks());
-			assertEquals(0, checkpoint.getCollectedStates().size());
+			assertEquals(0, checkpoint.getTaskStates().size());
 			assertFalse(checkpoint.isDiscarded());
 			assertFalse(checkpoint.isFullyAcknowledged());
 
@@ -521,7 +550,7 @@ public class CheckpointCoordinatorTest {
 			assertEquals(jid, success.getJobId());
 			assertEquals(timestamp, success.getTimestamp());
 			assertEquals(checkpoint.getCheckpointId(), success.getCheckpointID());
-			assertTrue(success.getStates().isEmpty());
+			assertTrue(success.getTaskStates().isEmpty());
 
 			// ---------------
 			// trigger another checkpoint and see that this one replaces the other checkpoint
@@ -540,7 +569,7 @@ public class CheckpointCoordinatorTest {
 			assertEquals(jid, successNew.getJobId());
 			assertEquals(timestampNew, successNew.getTimestamp());
 			assertEquals(checkpointIdNew, successNew.getCheckpointID());
-			assertTrue(successNew.getStates().isEmpty());
+			assertTrue(successNew.getTaskStates().isEmpty());
 
 			// validate that the relevant tasks got a confirmation message
 			{
@@ -592,12 +621,17 @@ public class CheckpointCoordinatorTest {
 
 			// set up the coordinator and validate the initial state
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid, 600000, 600000,
-					new ExecutionVertex[] { triggerVertex1, triggerVertex2 },
-					new ExecutionVertex[] { ackVertex1, ackVertex2, ackVertex3 },
-					new ExecutionVertex[] { commitVertex }, cl,
-					new StandaloneCheckpointIDCounter(),
-					new StandaloneCompletedCheckpointStore(2, cl), RecoveryMode.STANDALONE);
+				jid,
+				600000,
+				600000,
+				42,
+				new ExecutionVertex[] { triggerVertex1, triggerVertex2 },
+				new ExecutionVertex[] { ackVertex1, ackVertex2, ackVertex3 },
+				new ExecutionVertex[] { commitVertex },
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(2, cl),
+				RecoveryMode.STANDALONE);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -677,13 +711,13 @@ public class CheckpointCoordinatorTest {
 			assertEquals(checkpointId1, sc1.getCheckpointID());
 			assertEquals(timestamp1, sc1.getTimestamp());
 			assertEquals(jid, sc1.getJobId());
-			assertTrue(sc1.getStates().isEmpty());
+			assertTrue(sc1.getTaskStates().isEmpty());
 
 			CompletedCheckpoint sc2 = scs.get(1);
 			assertEquals(checkpointId2, sc2.getCheckpointID());
 			assertEquals(timestamp2, sc2.getTimestamp());
 			assertEquals(jid, sc2.getJobId());
-			assertTrue(sc2.getStates().isEmpty());
+			assertTrue(sc2.getTaskStates().isEmpty());
 
 			coord.shutdown();
 		}
@@ -721,12 +755,17 @@ public class CheckpointCoordinatorTest {
 
 			// set up the coordinator and validate the initial state
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid, 600000, 600000,
-					new ExecutionVertex[] { triggerVertex1, triggerVertex2 },
-					new ExecutionVertex[] { ackVertex1, ackVertex2, ackVertex3 },
-					new ExecutionVertex[] { commitVertex }, cl,
-					new StandaloneCheckpointIDCounter(),
-					new StandaloneCompletedCheckpointStore(10, cl), RecoveryMode.STANDALONE);
+				jid,
+				600000,
+				600000,
+				42,
+				new ExecutionVertex[] { triggerVertex1, triggerVertex2 },
+				new ExecutionVertex[] { ackVertex1, ackVertex2, ackVertex3 },
+				new ExecutionVertex[] { commitVertex },
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(10, cl),
+				RecoveryMode.STANDALONE);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -793,7 +832,7 @@ public class CheckpointCoordinatorTest {
 			assertEquals(checkpointId2, success.getCheckpointID());
 			assertEquals(timestamp2, success.getTimestamp());
 			assertEquals(jid, success.getJobId());
-			assertTrue(success.getStates().isEmpty());
+			assertTrue(success.getTaskStates().isEmpty());
 
 			// the first confirm message should be out
 			verify(commitVertex, times(1)).sendMessageToCurrentExecution(
@@ -836,12 +875,17 @@ public class CheckpointCoordinatorTest {
 			// the timeout for the checkpoint is a 200 milliseconds
 
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid, 600000, 200,
-					new ExecutionVertex[] { triggerVertex },
-					new ExecutionVertex[] { ackVertex1, ackVertex2 },
-					new ExecutionVertex[] { commitVertex }, cl,
-					new StandaloneCheckpointIDCounter(),
-					new StandaloneCompletedCheckpointStore(2, cl), RecoveryMode.STANDALONE);
+				jid,
+				600000,
+				200,
+				42,
+				new ExecutionVertex[] { triggerVertex },
+				new ExecutionVertex[] { ackVertex1, ackVertex2 },
+				new ExecutionVertex[] { commitVertex },
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(2, cl),
+				RecoveryMode.STANDALONE);
 
 			// trigger a checkpoint, partially acknowledged
 			assertTrue(coord.triggerCheckpoint(timestamp));
@@ -898,11 +942,17 @@ public class CheckpointCoordinatorTest {
 			ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
 
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid, 200000, 200000,
-					new ExecutionVertex[] { triggerVertex },
-					new ExecutionVertex[] { ackVertex1, ackVertex2 },
-					new ExecutionVertex[] { commitVertex }, cl, new StandaloneCheckpointIDCounter
-					(), new StandaloneCompletedCheckpointStore(2, cl), RecoveryMode.STANDALONE);
+				jid,
+				200000,
+				200000,
+				42,
+				new ExecutionVertex[] { triggerVertex },
+				new ExecutionVertex[] { ackVertex1, ackVertex2 },
+				new ExecutionVertex[] { commitVertex },
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(2, cl),
+				RecoveryMode.STANDALONE);
 
 			assertTrue(coord.triggerCheckpoint(timestamp));
 
@@ -970,13 +1020,17 @@ public class CheckpointCoordinatorTest {
 			}).when(triggerVertex).sendMessageToCurrentExecution(any(Serializable.class), any(ExecutionAttemptID.class));
 			
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid,
-					10,		// periodic interval is 10 ms
-					200000,	// timeout is very long (200 s)
-					new ExecutionVertex[] { triggerVertex },
-					new ExecutionVertex[] { ackVertex },
-					new ExecutionVertex[] { commitVertex }, cl, new StandaloneCheckpointIDCounter
-					(), new StandaloneCompletedCheckpointStore(2, cl), RecoveryMode.STANDALONE);
+				jid,
+				10,		// periodic interval is 10 ms
+				200000,	// timeout is very long (200 s)
+				42,
+				new ExecutionVertex[] { triggerVertex },
+				new ExecutionVertex[] { ackVertex },
+				new ExecutionVertex[] { commitVertex },
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(2, cl),
+				RecoveryMode.STANDALONE);
 
 			
 			coord.startCheckpointScheduler();
@@ -1061,10 +1115,14 @@ public class CheckpointCoordinatorTest {
 				200000,	// timeout is very long (200 s)
 				500,	// 500ms delay between checkpoints
 				10,
+				42,
 				new ExecutionVertex[] { vertex1 },
 				new ExecutionVertex[] { vertex1 },
-				new ExecutionVertex[] { vertex1 }, cl, new StandaloneCheckpointIDCounter
-				(), new StandaloneCompletedCheckpointStore(2, cl), RecoveryMode.STANDALONE,
+				new ExecutionVertex[] { vertex1 },
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(2, cl),
+				RecoveryMode.STANDALONE,
 				new DisabledCheckpointStatsTracker());
 
 			coord.startCheckpointScheduler();
@@ -1149,16 +1207,17 @@ public class CheckpointCoordinatorTest {
 			}).when(triggerVertex).sendMessageToCurrentExecution(any(Serializable.class), any(ExecutionAttemptID.class));
 
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid,
-					10,		// periodic interval is 10 ms
-					200000,	// timeout is very long (200 s)
-					0L,		// no extra delay
-					maxConcurrentAttempts,
-					new ExecutionVertex[] { triggerVertex },
-					new ExecutionVertex[] { ackVertex },
-					new ExecutionVertex[] { commitVertex }, cl, new StandaloneCheckpointIDCounter
-					(), new StandaloneCompletedCheckpointStore(2, cl), RecoveryMode.STANDALONE,
-					new DisabledCheckpointStatsTracker());
+				jid,
+				10,		// periodic interval is 10 ms
+				200000,	// timeout is very long (200 s)
+				0L,		// no extra delay
+				maxConcurrentAttempts,
+				42,
+				new ExecutionVertex[] { triggerVertex },
+				new ExecutionVertex[] { ackVertex },
+				new ExecutionVertex[] { commitVertex }, cl, new StandaloneCheckpointIDCounter
+				(), new StandaloneCompletedCheckpointStore(2, cl), RecoveryMode.STANDALONE,
+				new DisabledCheckpointStatsTracker());
 
 			coord.startCheckpointScheduler();
 
@@ -1219,16 +1278,17 @@ public class CheckpointCoordinatorTest {
 			ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
 
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid,
-					10,		// periodic interval is 10 ms
-					200000,	// timeout is very long (200 s)
-					0L,		// no extra delay
-					maxConcurrentAttempts, // max two concurrent checkpoints
-					new ExecutionVertex[] { triggerVertex },
-					new ExecutionVertex[] { ackVertex },
-					new ExecutionVertex[] { commitVertex }, cl, new StandaloneCheckpointIDCounter
-					(), new StandaloneCompletedCheckpointStore(2, cl), RecoveryMode.STANDALONE,
-					new DisabledCheckpointStatsTracker());
+				jid,
+				10,		// periodic interval is 10 ms
+				200000,	// timeout is very long (200 s)
+				0L,		// no extra delay
+				maxConcurrentAttempts, // max two concurrent checkpoints
+				42,
+				new ExecutionVertex[] { triggerVertex },
+				new ExecutionVertex[] { ackVertex },
+				new ExecutionVertex[] { commitVertex }, cl, new StandaloneCheckpointIDCounter
+				(), new StandaloneCompletedCheckpointStore(2, cl), RecoveryMode.STANDALONE,
+				new DisabledCheckpointStatsTracker());
 
 			coord.startCheckpointScheduler();
 
@@ -1298,16 +1358,17 @@ public class CheckpointCoordinatorTest {
 					});
 			
 			CheckpointCoordinator coord = new CheckpointCoordinator(
-					jid,
-					10,		// periodic interval is 10 ms
-					200000,	// timeout is very long (200 s)
-					0L,		// no extra delay
-					2, // max two concurrent checkpoints
-					new ExecutionVertex[] { triggerVertex },
-					new ExecutionVertex[] { ackVertex },
-					new ExecutionVertex[] { commitVertex }, cl, new StandaloneCheckpointIDCounter(),
-					new StandaloneCompletedCheckpointStore(2, cl), RecoveryMode.STANDALONE,
-					new DisabledCheckpointStatsTracker());
+				jid,
+				10,		// periodic interval is 10 ms
+				200000,	// timeout is very long (200 s)
+				0L,		// no extra delay
+				2, // max two concurrent checkpoints
+				42,
+				new ExecutionVertex[] { triggerVertex },
+				new ExecutionVertex[] { ackVertex },
+				new ExecutionVertex[] { commitVertex }, cl, new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(2, cl), RecoveryMode.STANDALONE,
+				new DisabledCheckpointStatsTracker());
 			
 			coord.startCheckpointScheduler();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index c35fd22..68cd145 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -64,11 +64,11 @@ public class CheckpointStateRestoreTest {
 			Execution statelessExec1 = mockExecution();
 			Execution statelessExec2 = mockExecution();
 
-			ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0);
-			ExecutionVertex stateful2 = mockExecutionVertex(statefulExec2, statefulId, 1);
-			ExecutionVertex stateful3 = mockExecutionVertex(statefulExec3, statefulId, 2);
-			ExecutionVertex stateless1 = mockExecutionVertex(statelessExec1, statelessId, 0);
-			ExecutionVertex stateless2 = mockExecutionVertex(statelessExec2, statelessId, 1);
+			ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0, 3);
+			ExecutionVertex stateful2 = mockExecutionVertex(statefulExec2, statefulId, 1, 3);
+			ExecutionVertex stateful3 = mockExecutionVertex(statefulExec3, statefulId, 2, 3);
+			ExecutionVertex stateless1 = mockExecutionVertex(statelessExec1, statelessId, 0, 2);
+			ExecutionVertex stateless2 = mockExecutionVertex(statelessExec2, statelessId, 1, 2);
 
 			ExecutionJobVertex stateful = mockExecutionJobVertex(statefulId,
 					new ExecutionVertex[] { stateful1, stateful2, stateful3 });
@@ -80,12 +80,18 @@ public class CheckpointStateRestoreTest {
 			map.put(statelessId, stateless);
 
 
-			CheckpointCoordinator coord = new CheckpointCoordinator(jid, 200000L, 200000L,
-					new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
-					new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
-					new ExecutionVertex[0], cl,
-					new StandaloneCheckpointIDCounter(),
-					new StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE);
+			CheckpointCoordinator coord = new CheckpointCoordinator(
+				jid,
+				200000L,
+				200000L,
+				42,
+				new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
+				new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
+				new ExecutionVertex[0],
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(1, cl),
+				RecoveryMode.STANDALONE);
 
 			// create ourselves a checkpoint with state
 			final long timestamp = 34623786L;
@@ -107,11 +113,11 @@ public class CheckpointStateRestoreTest {
 			coord.restoreLatestCheckpointedState(map, true, false);
 
 			// verify that each stateful vertex got the state
-			verify(statefulExec1, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.anyLong());
-			verify(statefulExec2, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.anyLong());
-			verify(statefulExec3, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.anyLong());
-			verify(statelessExec1, times(0)).setInitialState(Mockito.<SerializedValue<StateHandle<?>>>any(), Mockito.anyLong());
-			verify(statelessExec2, times(0)).setInitialState(Mockito.<SerializedValue<StateHandle<?>>>any(), Mockito.anyLong());
+			verify(statefulExec1, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<Map<Integer, SerializedValue<StateHandle<?>>>>any(), Mockito.anyLong());
+			verify(statefulExec2, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<Map<Integer, SerializedValue<StateHandle<?>>>>any(), Mockito.anyLong());
+			verify(statefulExec3, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<Map<Integer, SerializedValue<StateHandle<?>>>>any(), Mockito.anyLong());
+			verify(statelessExec1, times(0)).setInitialState(Mockito.<SerializedValue<StateHandle<?>>>any(), Mockito.<Map<Integer, SerializedValue<StateHandle<?>>>>any(), Mockito.anyLong());
+			verify(statelessExec2, times(0)).setInitialState(Mockito.<SerializedValue<StateHandle<?>>>any(), Mockito.<Map<Integer, SerializedValue<StateHandle<?>>>>any(), Mockito.anyLong());
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -135,11 +141,11 @@ public class CheckpointStateRestoreTest {
 			Execution statelessExec1 = mockExecution();
 			Execution statelessExec2 = mockExecution();
 
-			ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0);
-			ExecutionVertex stateful2 = mockExecutionVertex(statefulExec2, statefulId, 1);
-			ExecutionVertex stateful3 = mockExecutionVertex(statefulExec3, statefulId, 2);
-			ExecutionVertex stateless1 = mockExecutionVertex(statelessExec1, statelessId, 0);
-			ExecutionVertex stateless2 = mockExecutionVertex(statelessExec2, statelessId, 1);
+			ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0, 3);
+			ExecutionVertex stateful2 = mockExecutionVertex(statefulExec2, statefulId, 1, 3);
+			ExecutionVertex stateful3 = mockExecutionVertex(statefulExec3, statefulId, 2, 3);
+			ExecutionVertex stateless1 = mockExecutionVertex(statelessExec1, statelessId, 0, 2);
+			ExecutionVertex stateless2 = mockExecutionVertex(statelessExec2, statelessId, 1, 2);
 
 			ExecutionJobVertex stateful = mockExecutionJobVertex(statefulId,
 					new ExecutionVertex[] { stateful1, stateful2, stateful3 });
@@ -151,12 +157,18 @@ public class CheckpointStateRestoreTest {
 			map.put(statelessId, stateless);
 
 
-			CheckpointCoordinator coord = new CheckpointCoordinator(jid, 200000L, 200000L,
-					new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
-					new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
-					new ExecutionVertex[0], cl,
-					new StandaloneCheckpointIDCounter(),
-					new StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE);
+			CheckpointCoordinator coord = new CheckpointCoordinator(
+				jid,
+				200000L,
+				200000L,
+				42,
+				new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
+				new ExecutionVertex[] { stateful1, stateful2, stateful3, stateless1, stateless2 },
+				new ExecutionVertex[0],
+				cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(1, cl),
+				RecoveryMode.STANDALONE);
 
 			// create ourselves a checkpoint with state
 			final long timestamp = 34623786L;
@@ -193,12 +205,16 @@ public class CheckpointStateRestoreTest {
 	@Test
 	public void testNoCheckpointAvailable() {
 		try {
-			CheckpointCoordinator coord = new CheckpointCoordinator(new JobID(), 200000L, 200000L,
-					new ExecutionVertex[] { mock(ExecutionVertex.class) },
-					new ExecutionVertex[] { mock(ExecutionVertex.class) },
-					new ExecutionVertex[0], cl,
-					new StandaloneCheckpointIDCounter(),
-					new StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE);
+			CheckpointCoordinator coord = new CheckpointCoordinator(
+				new JobID(),
+				200000L,
+				200000L,
+				42,
+				new ExecutionVertex[] { mock(ExecutionVertex.class) },
+				new ExecutionVertex[] { mock(ExecutionVertex.class) },
+				new ExecutionVertex[0], cl,
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(1, cl), RecoveryMode.STANDALONE);
 
 			try {
 				coord.restoreLatestCheckpointedState(new HashMap<JobVertexID, ExecutionJobVertex>(), true, false);
@@ -227,11 +243,12 @@ public class CheckpointStateRestoreTest {
 		return mock;
 	}
 
-	private ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask) {
+	private ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask, int parallelism) {
 		ExecutionVertex mock = mock(ExecutionVertex.class);
 		when(mock.getJobvertexId()).thenReturn(vertexId);
 		when(mock.getParallelSubtaskIndex()).thenReturn(subtask);
 		when(mock.getCurrentExecutionAttempt()).thenReturn(execution);
+		when(mock.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
 		return mock;
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index 33026cc..0276a1f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -27,8 +27,9 @@ import org.apache.flink.util.TestLogger;
 import org.junit.Test;
 
 import java.io.IOException;
-import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 
 import static org.junit.Assert.assertEquals;
@@ -197,31 +198,22 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 
 		JobVertexID jvid = new JobVertexID();
 
-		ArrayList<StateForTask> taskStates = new ArrayList<>();
+		Map<JobVertexID, TaskState> taskGroupStates = new HashMap<>();
+		TaskState taskState = new TaskState(jvid, numberOfStates);
+		taskGroupStates.put(jvid, taskState);
 
 		for (int i = 0; i < numberOfStates; i++) {
 			SerializedValue<StateHandle<?>> stateHandle = new SerializedValue<StateHandle<?>>(
 					new CheckpointMessagesTest.MyHandle());
 
-			taskStates.add(new StateForTask(stateHandle, 0, jvid, i, 0));
+			taskState.putState(i, new SubtaskState(stateHandle, 0, 0));
 		}
 
-		return new TestCheckpoint(new JobID(), id, 0, taskStates);
+		return new TestCheckpoint(new JobID(), id, 0, taskGroupStates);
 	}
 
 	private void verifyCheckpoint(CompletedCheckpoint expected, CompletedCheckpoint actual) {
-		assertEquals(expected.getJobId(), actual.getJobId());
-		assertEquals(expected.getCheckpointID(), actual.getCheckpointID());
-		assertEquals(expected.getTimestamp(), actual.getTimestamp());
-
-		List<StateForTask> expectedStates = expected.getStates();
-		List<StateForTask> actualStates = actual.getStates();
-
-		assertEquals(expectedStates.size(), actualStates.size());
-
-		for (int i = 0; i < expectedStates.size(); i++) {
-			assertEquals(expectedStates.get(i), actualStates.get(i));
-		}
+		assertEquals(expected, actual);
 	}
 
 	/**
@@ -241,12 +233,12 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		private transient ClassLoader discardClassLoader;
 
 		public TestCheckpoint(
-				JobID jobId,
-				long checkpointId,
-				long timestamp,
-				ArrayList<StateForTask> states) {
+			JobID jobId,
+			long checkpointId,
+			long timestamp,
+			Map<JobVertexID, TaskState> taskGroupStates) {
 
-			super(jobId, checkpointId, timestamp, Long.MAX_VALUE, states);
+			super(jobId, checkpointId, timestamp, Long.MAX_VALUE, taskGroupStates);
 		}
 
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/0cf04108/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
index d1c74d6..965556f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
@@ -65,6 +65,7 @@ public class ExecutionGraphCheckpointCoordinatorTest {
 					100,
 					100,
 					1,
+					42,
 					Collections.<ExecutionJobVertex>emptyList(),
 					Collections.<ExecutionJobVertex>emptyList(),
 					Collections.<ExecutionJobVertex>emptyList(),