You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ch...@apache.org on 2017/04/28 19:55:37 UTC

[5/6] flink git commit: [Flink-5892] Restore state on operator level

[Flink-5892] Restore state on operator level


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/f7980a7e
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/f7980a7e
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/f7980a7e

Branch: refs/heads/master
Commit: f7980a7e29457753eb3c5b975f3bb4b59d2014f8
Parents: 8045fab
Author: zentol <ch...@apache.org>
Authored: Fri Apr 28 19:40:20 2017 +0200
Committer: zentol <ch...@apache.org>
Committed: Fri Apr 28 20:11:35 2017 +0200

----------------------------------------------------------------------
 docs/ops/upgrading.md                           |   4 +-
 docs/setup/savepoints.md                        |   5 -
 .../checkpoint/savepoint/SavepointV0.java       |   6 +
 .../checkpoint/CheckpointCoordinator.java       |   9 +-
 .../runtime/checkpoint/CompletedCheckpoint.java |  57 +-
 .../flink/runtime/checkpoint/OperatorState.java | 183 ++++++
 .../checkpoint/OperatorSubtaskState.java        | 229 +++++++
 .../runtime/checkpoint/PendingCheckpoint.java   |  87 +--
 .../checkpoint/StateAssignmentOperation.java    | 611 ++++++++++++-------
 .../checkpoint/StateAssignmentOperationV2.java  | 458 --------------
 .../flink/runtime/checkpoint/TaskState.java     |   4 +
 .../runtime/checkpoint/savepoint/Savepoint.java |  12 +
 .../checkpoint/savepoint/SavepointLoader.java   |  54 +-
 .../checkpoint/savepoint/SavepointV1.java       |   6 +
 .../savepoint/SavepointV1Serializer.java        |   4 +-
 .../checkpoint/savepoint/SavepointV2.java       | 155 ++++-
 .../savepoint/SavepointV2Serializer.java        | 114 ++--
 .../executiongraph/ExecutionJobVertex.java      |  63 ++
 .../runtime/jobgraph/InputFormatVertex.java     |   4 +-
 .../flink/runtime/jobgraph/JobVertex.java       |  25 +-
 .../flink/runtime/jobgraph/OperatorID.java      |  45 ++
 ...tCoordinatorExternalizedCheckpointsTest.java |   5 +-
 .../CheckpointCoordinatorFailureTest.java       |  35 +-
 .../CheckpointCoordinatorMasterHooksTest.java   |   5 +-
 .../checkpoint/CheckpointCoordinatorTest.java   | 236 +++++--
 .../checkpoint/CheckpointStateRestoreTest.java  |  55 +-
 .../CompletedCheckpointStoreTest.java           |  59 +-
 .../checkpoint/CompletedCheckpointTest.java     |  47 +-
 .../checkpoint/PendingCheckpointTest.java       |  24 +-
 .../StandaloneCompletedCheckpointStoreTest.java |  10 +-
 ...ZooKeeperCompletedCheckpointStoreITCase.java |  14 +-
 .../ZooKeeperCompletedCheckpointStoreTest.java  |   2 +-
 .../savepoint/CheckpointTestUtils.java          |  97 ++-
 .../savepoint/SavepointLoaderTest.java          |  24 +-
 .../savepoint/SavepointStoreTest.java           |  23 +-
 .../savepoint/SavepointV2SerializerTest.java    |  22 +-
 .../checkpoint/savepoint/SavepointV2Test.java   |  14 +-
 .../executiongraph/LegacyJobVertexIdTest.java   |   4 +-
 .../RecoverableCompletedCheckpointStore.java    |   2 +-
 .../api/graph/StreamGraphHasherV1.java          |  13 -
 .../api/graph/StreamGraphHasherV2.java          |  13 -
 .../api/graph/StreamGraphUserHashHasher.java    |   9 -
 .../api/graph/StreamingJobGraphGenerator.java   |  46 +-
 .../StreamingJobGraphGeneratorNodeHashTest.java |  18 +-
 .../test/checkpointing/SavepointITCase.java     |  54 +-
 .../AbstractOperatorRestoreTestBase.java        | 261 ++++++++
 .../state/operator/restore/ExecutionMode.java   |  31 +
 .../restore/keyed/KeyedComplexChainTest.java    |  61 ++
 .../state/operator/restore/keyed/KeyedJob.java  | 243 ++++++++
 ...AbstractNonKeyedOperatorRestoreTestBase.java |  59 ++
 .../restore/unkeyed/ChainBreakTest.java         |  55 ++
 .../unkeyed/ChainLengthDecreaseTest.java        |  51 ++
 .../unkeyed/ChainLengthIncreaseTest.java        |  56 ++
 .../restore/unkeyed/ChainOrderTest.java         |  54 ++
 .../restore/unkeyed/ChainUnionTest.java         |  53 ++
 .../operator/restore/unkeyed/NonKeyedJob.java   | 198 ++++++
 .../operatorstate/complexKeyed/_metadata        | Bin 0 -> 137490 bytes
 .../resources/operatorstate/nonKeyed/_metadata  | Bin 0 -> 3212 bytes
 58 files changed, 2971 insertions(+), 1117 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/docs/ops/upgrading.md
----------------------------------------------------------------------
diff --git a/docs/ops/upgrading.md b/docs/ops/upgrading.md
index 194d8af..7259a6b 100644
--- a/docs/ops/upgrading.md
+++ b/docs/ops/upgrading.md
@@ -73,6 +73,8 @@ val mappedEvents: DataStream[(Int, Long)] = events
 
 **Note:** Since the operator IDs stored in a savepoint and IDs of operators in the application to start must be equal, it is highly recommended to assign unique IDs to all operators of an application that might be upgraded in the future. This advice applies to all operators, i.e., operators with and without explicitly declared operator state, because some operators have internal state that is not visible to the user. Upgrading an application without assigned operator IDs is significantly more difficult and may only be possible via a low-level workaround using the `setUidHash()` method.
 
+**Important:** As of 1.3.0 this also applies to operators that are part of a chain.
+
 By default all state stored in a savepoint must be matched to the operators of a starting application. However, users can explicitly agree to skip (and thereby discard) state that cannot be matched to an operator when starting a application from a savepoint. Stateful operators for which no state is found in the savepoint are initialized with their default state.
 
 ### Stateful Operators and User Functions
@@ -105,7 +107,7 @@ When upgrading an application by changing its topology, a few things need to be
 * **Adding a stateful operator:** The state of the operator will be initialized with the default state unless it takes over the state of another operator.
 * **Removing a stateful operator:** The state of the removed operator is lost unless another operator takes it over. When starting the upgraded application, you have to explicitly agree to discard the state.
 * **Changing of input and output types of operators:** When adding a new operator before or behind an operator with internal state, you have to ensure that the input or output type of the stateful operator is not modified to preserve the data type of the internal operator state (see above for details).
-* **Changing operator chaining:** Operators can be chained together for improved performance. However, chaining can limit the ability of an application to be upgraded if a chain contains a stateful operator that is not the first operator of the chain. In such a case, it is not possible to break the chain such that the stateful operator is moved out of the chain. It is also not possible to append or inject an existing stateful operator into a chain. The chaining behavior can be changed by modifying the parallelism of a chained operator or by adding or removing explicit operator chaining instructions. 
+* **Changing operator chaining:** Operators can be chained together for improved performance. When restoring from a savepoint taken with Flink 1.3.0 or later, it is possible to modify chains while preserving state consistency. It is possible to break the chain such that a stateful operator is moved out of the chain. It is also possible to append or inject a new or existing stateful operator into a chain, or to modify the operator order within a chain. However, when upgrading a savepoint taken with an earlier version to 1.3.0, it is paramount that the topology did not change with regard to chaining. All operators that are part of a chain should be assigned an ID as described in the [Matching Operator State](#matching-operator-state) section above.
 
 ## Upgrading the Flink Framework Version
 

http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/docs/setup/savepoints.md
----------------------------------------------------------------------
diff --git a/docs/setup/savepoints.md b/docs/setup/savepoints.md
index 4bdc43f..eada9b4 100644
--- a/docs/setup/savepoints.md
+++ b/docs/setup/savepoints.md
@@ -185,8 +185,3 @@ If you did not assign IDs, the auto generated IDs of the stateful operators will
 If the savepoint was triggered with Flink >= 1.2.0 and using no deprecated state API like `Checkpointed`, you can simply restore the program from a savepoint and specify a new parallelism.
 
 If you are resuming from a savepoint triggered with Flink < 1.2.0 or using now deprecated APIs you first have to migrate your job and savepoint to Flink 1.2.0 before being able to change the parallelism. See the [upgrading jobs and Flink versions guide]({{ site.baseurl }}/ops/upgrading.html).
-
-## Current limitations
-
-- **Chaining**: Chained operators are identified by the ID of the first task. It's not possible to manually assign an ID to an intermediate chained task, e.g. in the chain `[  a -> b -> c ]` only **a** can have its ID assigned manually, but not **b** or **c**. To work around this, you can [manually define the task chains](index.html#task-chaining-and-resource-groups). If you rely on the automatic ID assignment, a change in the chaining behaviour will also change the IDs.
-

http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java
index f3ec1cf..7888d2f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java
+++ b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java
@@ -20,6 +20,7 @@ package org.apache.flink.migration.runtime.checkpoint.savepoint;
 
 import org.apache.flink.migration.runtime.checkpoint.TaskState;
 import org.apache.flink.runtime.checkpoint.MasterState;
+import org.apache.flink.runtime.checkpoint.OperatorState;
 import org.apache.flink.runtime.checkpoint.savepoint.Savepoint;
 import org.apache.flink.util.Preconditions;
 
@@ -72,6 +73,11 @@ public class SavepointV0 implements Savepoint {
 	}
 
 	@Override
+	public Collection<OperatorState> getOperatorStates() {
+		return null;
+	}
+
+	@Override
 	public void dispose() throws Exception {
 		//NOP
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index fb6cc72..96add06 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -36,6 +36,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.executiongraph.JobStatusListener;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
@@ -892,7 +893,7 @@ public class CheckpointCoordinator {
 		if (LOG.isDebugEnabled()) {
 			StringBuilder builder = new StringBuilder();
 			builder.append("Checkpoint state: ");
-			for (TaskState state : completedCheckpoint.getTaskStates().values()) {
+			for (OperatorState state : completedCheckpoint.getOperatorStates().values()) {
 				builder.append(state);
 				builder.append(", ");
 			}
@@ -1017,11 +1018,11 @@ public class CheckpointCoordinator {
 			LOG.info("Restoring from latest valid checkpoint: {}.", latest);
 
 			// re-assign the task states
-
-			final Map<JobVertexID, TaskState> taskStates = latest.getTaskStates();
+			final Map<OperatorID, OperatorState> operatorStates = latest.getOperatorStates();
 
 			StateAssignmentOperation stateAssignmentOperation =
-					new StateAssignmentOperation(LOG, tasks, taskStates, allowNonRestoredState);
+					new StateAssignmentOperation(tasks, operatorStates, allowNonRestoredState);
+
 			stateAssignmentOperation.assignStates();
 
 			// call master hooks for restore

http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
index bb49b45..1ab5b41 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
@@ -18,10 +18,9 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.jobgraph.JobStatus;
-import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -94,8 +93,8 @@ public class CompletedCheckpoint implements Serializable {
 	/** The duration of the checkpoint (completion timestamp - trigger timestamp). */
 	private final long duration;
 
-	/** States of the different task groups belonging to this checkpoint */
-	private final HashMap<JobVertexID, TaskState> taskStates;
+	/** States of the different operators belonging to this checkpoint, keyed by operator ID */
+	private final Map<OperatorID, OperatorState> operatorStates;
 
 	/** Properties for this checkpoint. */
 	private final CheckpointProperties props;
@@ -117,38 +116,12 @@ public class CompletedCheckpoint implements Serializable {
 
 	// ------------------------------------------------------------------------
 
-	@VisibleForTesting
-	CompletedCheckpoint(
-			JobID job,
-			long checkpointID,
-			long timestamp,
-			long completionTimestamp,
-			Map<JobVertexID, TaskState> taskStates) {
-
-		this(job, checkpointID, timestamp, completionTimestamp, taskStates,
-				Collections.<MasterState>emptyList(),
-				CheckpointProperties.forStandardCheckpoint());
-	}
-
-	public CompletedCheckpoint(
-			JobID job,
-			long checkpointID,
-			long timestamp,
-			long completionTimestamp,
-			Map<JobVertexID, TaskState> taskStates,
-			@Nullable Collection<MasterState> masterHookStates,
-			CheckpointProperties props) {
-
-		this(job, checkpointID, timestamp, completionTimestamp, taskStates, 
-				masterHookStates, props, null, null);
-	}
-
 	public CompletedCheckpoint(
 			JobID job,
 			long checkpointID,
 			long timestamp,
 			long completionTimestamp,
-			Map<JobVertexID, TaskState> taskStates,
+			Map<OperatorID, OperatorState> operatorStates,
 			@Nullable Collection<MasterState> masterHookStates,
 			CheckpointProperties props,
 			@Nullable StreamStateHandle externalizedMetadata,
@@ -171,7 +144,7 @@ public class CompletedCheckpoint implements Serializable {
 
 		// we create copies here, to make sure we have no shared mutable
 		// data structure with the "outside world"
-		this.taskStates = new HashMap<>(checkNotNull(taskStates));
+		this.operatorStates = new HashMap<>(checkNotNull(operatorStates));
 		this.masterHookStates = masterHookStates == null || masterHookStates.isEmpty() ?
 				Collections.<MasterState>emptyList() :
 				new ArrayList<>(masterHookStates);
@@ -239,19 +212,15 @@ public class CompletedCheckpoint implements Serializable {
 	public long getStateSize() {
 		long result = 0L;
 
-		for (TaskState taskState : taskStates.values()) {
-			result += taskState.getStateSize();
+		for (OperatorState operatorState : operatorStates.values()) {
+			result += operatorState.getStateSize();
 		}
 
 		return result;
 	}
 
-	public Map<JobVertexID, TaskState> getTaskStates() {
-		return Collections.unmodifiableMap(taskStates);
-	}
-
-	public TaskState getTaskState(JobVertexID jobVertexID) {
-		return taskStates.get(jobVertexID);
+	public Map<OperatorID, OperatorState> getOperatorStates() {
+		return operatorStates;
 	}
 
 	public Collection<MasterState> getMasterHookStates() {
@@ -288,7 +257,7 @@ public class CompletedCheckpoint implements Serializable {
 	 * @param sharedStateRegistry The registry where shared states are registered
 	 */
 	public void registerSharedStates(SharedStateRegistry sharedStateRegistry) {
-		sharedStateRegistry.registerAll(taskStates.values());
+		sharedStateRegistry.registerAll(operatorStates.values());
 	}
 
 	// --------------------------------------------------------------------------------------------
@@ -338,7 +307,7 @@ public class CompletedCheckpoint implements Serializable {
 		protected void doDiscardPrivateState() {
 			// discard private state objects
 			try {
-				StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
+				StateUtil.bestEffortDiscardAllStateObjects(operatorStates.values());
 			} catch (Exception e) {
 				storedException = ExceptionUtils.firstOrSuppressed(e, storedException);
 			}
@@ -353,7 +322,7 @@ public class CompletedCheckpoint implements Serializable {
 		}
 
 		protected void clearTaskStatesAndNotifyDiscardCompleted() {
-			taskStates.clear();
+			operatorStates.clear();
 			// to be null-pointer safe, copy reference to stack
 			CompletedCheckpointStats.DiscardCallback discardCallback =
 				CompletedCheckpoint.this.discardCallback;
@@ -392,7 +361,7 @@ public class CompletedCheckpoint implements Serializable {
 
 		@Override
 		protected void doDiscardSharedState() {
-			sharedStateRegistry.unregisterAll(taskStates.values());
+			sharedStateRegistry.unregisterAll(operatorStates.values());
 		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java
new file mode 100644
index 0000000..aa676e7
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.CompositeStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Simple container class which contains the raw/managed/legacy operator state and key-group state handles for the
+ * subtasks of an operator.
+ */
+public class OperatorState implements CompositeStateHandle {
+
+	private static final long serialVersionUID = -4845578005863201810L;
+
+	/** id of the operator */
+	private final OperatorID operatorID;
+
+	/** handles to the states of the individual subtasks, subtask index -> subtask state */
+	private final Map<Integer, OperatorSubtaskState> operatorSubtaskStates;
+
+	/** parallelism of the operator when it was checkpointed */
+	private final int parallelism;
+
+	/** maximum parallelism of the operator when the job was first created */
+	private final int maxParallelism;
+
+	public OperatorState(OperatorID operatorID, int parallelism, int maxParallelism) {
+		Preconditions.checkArgument(
+			parallelism <= maxParallelism,
+			"Parallelism " + parallelism + " is not smaller or equal to max parallelism " + maxParallelism + ".");
+
+		this.operatorID = operatorID;
+
+		this.operatorSubtaskStates = new HashMap<>(parallelism);
+
+		this.parallelism = parallelism;
+		this.maxParallelism = maxParallelism;
+	}
+
+	public OperatorID getOperatorID() {
+		return operatorID;
+	}
+
+	public void putState(int subtaskIndex, OperatorSubtaskState subtaskState) {
+		Preconditions.checkNotNull(subtaskState);
+
+		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
+			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
+				" exceeds the maximum number of sub tasks " + operatorSubtaskStates.size());
+		} else {
+			operatorSubtaskStates.put(subtaskIndex, subtaskState);
+		}
+	}
+
+	public OperatorSubtaskState getState(int subtaskIndex) {
+		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
+			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
+				" exceeds the maximum number of sub tasks " + operatorSubtaskStates.size());
+		} else {
+			return operatorSubtaskStates.get(subtaskIndex);
+		}
+	}
+
+	public Collection<OperatorSubtaskState> getStates() {
+		return operatorSubtaskStates.values();
+	}
+
+	public int getNumberCollectedStates() {
+		return operatorSubtaskStates.size();
+	}
+
+	public int getParallelism() {
+		return parallelism;
+	}
+
+	public int getMaxParallelism() {
+		return maxParallelism;
+	}
+
+	public boolean hasNonPartitionedState() {
+		for (OperatorSubtaskState sts : operatorSubtaskStates.values()) {
+			if (sts != null && sts.getLegacyOperatorState() != null) {
+				return true;
+			}
+		}
+		return false;
+	}
+
+	@Override
+	public void discardState() throws Exception {
+		for (OperatorSubtaskState operatorSubtaskState : operatorSubtaskStates.values()) {
+			operatorSubtaskState.discardState();
+		}
+	}
+
+	@Override
+	public void registerSharedStates(SharedStateRegistry sharedStateRegistry) {
+		for (OperatorSubtaskState operatorSubtaskState : operatorSubtaskStates.values()) {
+			operatorSubtaskState.registerSharedStates(sharedStateRegistry);
+		}
+	}
+
+	@Override
+	public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) {
+		for (OperatorSubtaskState operatorSubtaskState : operatorSubtaskStates.values()) {
+			operatorSubtaskState.unregisterSharedStates(sharedStateRegistry);
+		}
+	}
+
+	@Override
+	public long getStateSize() {
+		long result = 0L;
+
+		for (int i = 0; i < parallelism; i++) {
+			OperatorSubtaskState operatorSubtaskState = operatorSubtaskStates.get(i);
+			if (operatorSubtaskState != null) {
+				result += operatorSubtaskState.getStateSize();
+			}
+		}
+
+		return result;
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (obj instanceof OperatorState) {
+			OperatorState other = (OperatorState) obj;
+
+			return operatorID.equals(other.operatorID)
+				&& parallelism == other.parallelism
+				&& operatorSubtaskStates.equals(other.operatorSubtaskStates);
+		} else {
+			return false;
+		}
+	}
+
+	@Override
+	public int hashCode() {
+		return parallelism + 31 * Objects.hash(operatorID, operatorSubtaskStates);
+	}
+
+	public Map<Integer, OperatorSubtaskState> getSubtaskStates() {
+		return Collections.unmodifiableMap(operatorSubtaskStates);
+	}
+
+	@Override
+	public String toString() {
+		// KvStates are always null in 1.1. Don't print this as it might
+		// confuse users that don't care about how we store it internally.
+		return "OperatorState(" +
+			"operatorID: " + operatorID +
+			", parallelism: " + parallelism +
+			", maxParallelism: " + maxParallelism +
+			", sub task states: " + operatorSubtaskStates.size() +
+			", total size (bytes): " + getStateSize() +
+			')';
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
new file mode 100644
index 0000000..863816a
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.state.CompositeStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.StateUtil;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
+
+/**
+ * Container for the state of one parallel subtask of an operator. This is part of the {@link OperatorState}.
+ */
+public class OperatorSubtaskState implements CompositeStateHandle {
+
+	private static final Logger LOG = LoggerFactory.getLogger(OperatorSubtaskState.class);
+
+	private static final long serialVersionUID = -2394696997971923995L;
+
+	/**
+	 * Legacy (non-repartitionable) operator state.
+	 *
+	 * @deprecated Non-repartitionable operator state that has been deprecated.
+	 * Can be removed when we remove the APIs for non-repartitionable operator state.
+	 */
+	@Deprecated
+	private final StreamStateHandle legacyOperatorState;
+
+	/**
+	 * Snapshot from the {@link org.apache.flink.runtime.state.OperatorStateBackend}.
+	 */
+	private final OperatorStateHandle managedOperatorState;
+
+	/**
+	 * Snapshot written using {@link org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream}.
+	 */
+	private final OperatorStateHandle rawOperatorState;
+
+	/**
+	 * Snapshot from {@link org.apache.flink.runtime.state.KeyedStateBackend}.
+	 */
+	private final KeyedStateHandle managedKeyedState;
+
+	/**
+	 * Snapshot written using {@link org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream}.
+	 */
+	private final KeyedStateHandle rawKeyedState;
+
+	/**
+	 * The state size. This is also part of the deserialized state handle.
+	 * We store it here in order to not deserialize the state handle when
+	 * gathering stats.
+	 */
+	private final long stateSize;
+
+	public OperatorSubtaskState(
+		StreamStateHandle legacyOperatorState,
+		OperatorStateHandle managedOperatorState,
+		OperatorStateHandle rawOperatorState,
+		KeyedStateHandle managedKeyedState,
+		KeyedStateHandle rawKeyedState) {
+
+		this.legacyOperatorState = legacyOperatorState;
+		this.managedOperatorState = managedOperatorState;
+		this.rawOperatorState = rawOperatorState;
+		this.managedKeyedState = managedKeyedState;
+		this.rawKeyedState = rawKeyedState;
+
+		try {
+			long calculateStateSize = getSizeNullSafe(legacyOperatorState);
+			calculateStateSize += getSizeNullSafe(managedOperatorState);
+			calculateStateSize += getSizeNullSafe(rawOperatorState);
+			calculateStateSize += getSizeNullSafe(managedKeyedState);
+			calculateStateSize += getSizeNullSafe(rawKeyedState);
+			stateSize = calculateStateSize;
+		} catch (Exception e) {
+			throw new RuntimeException("Failed to get state size.", e);
+		}
+	}
+
+	private static long getSizeNullSafe(StateObject stateObject) throws Exception {
+		return stateObject != null ? stateObject.getStateSize() : 0L;
+	}
+
+	// --------------------------------------------------------------------------------------------
+
+	/**
+	 * @deprecated Non-repartitionable operator state that has been deprecated.
+	 * Can be removed when we remove the APIs for non-repartitionable operator state.
+	 */
+	@Deprecated
+	public StreamStateHandle getLegacyOperatorState() {
+		return legacyOperatorState;
+	}
+
+	public OperatorStateHandle getManagedOperatorState() {
+		return managedOperatorState;
+	}
+
+	public OperatorStateHandle getRawOperatorState() {
+		return rawOperatorState;
+	}
+
+	public KeyedStateHandle getManagedKeyedState() {
+		return managedKeyedState;
+	}
+
+	public KeyedStateHandle getRawKeyedState() {
+		return rawKeyedState;
+	}
+
+	@Override
+	public void discardState() {
+		try {
+			StateUtil.bestEffortDiscardAllStateObjects(
+				Arrays.asList(
+					legacyOperatorState,
+					managedOperatorState,
+					rawOperatorState,
+					managedKeyedState,
+					rawKeyedState));
+		} catch (Exception e) {
+			LOG.warn("Error while discarding operator states.", e);
+		}
+	}
+
+	@Override
+	public void registerSharedStates(SharedStateRegistry sharedStateRegistry) {
+		// No shared states
+	}
+
+	@Override
+	public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) {
+		// No shared states
+	}
+
+	@Override
+	public long getStateSize() {
+		return stateSize;
+	}
+
+	// --------------------------------------------------------------------------------------------
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o) {
+			return true;
+		}
+		if (o == null || getClass() != o.getClass()) {
+			return false;
+		}
+
+		OperatorSubtaskState that = (OperatorSubtaskState) o;
+
+		if (stateSize != that.stateSize) {
+			return false;
+		}
+
+		if (legacyOperatorState != null ?
+			!legacyOperatorState.equals(that.legacyOperatorState)
+			: that.legacyOperatorState != null) {
+			return false;
+		}
+		if (managedOperatorState != null ?
+			!managedOperatorState.equals(that.managedOperatorState)
+			: that.managedOperatorState != null) {
+			return false;
+		}
+		if (rawOperatorState != null ?
+			!rawOperatorState.equals(that.rawOperatorState)
+			: that.rawOperatorState != null) {
+			return false;
+		}
+		if (managedKeyedState != null ?
+			!managedKeyedState.equals(that.managedKeyedState)
+			: that.managedKeyedState != null) {
+			return false;
+		}
+		return rawKeyedState != null ?
+			rawKeyedState.equals(that.rawKeyedState)
+			: that.rawKeyedState == null;
+
+	}
+
+	@Override
+	public int hashCode() {
+		int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0;
+		result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0);
+		result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0);
+		result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0);
+		result = 31 * result + (rawKeyedState != null ? rawKeyedState.hashCode() : 0);
+		result = 31 * result + (int) (stateSize ^ (stateSize >>> 32));
+		return result;
+	}
+
+	@Override
+	public String toString() {
+		return "SubtaskState{" +
+			"legacyState=" + legacyOperatorState +
+			", operatorStateFromBackend=" + managedOperatorState +
+			", operatorStateFromStream=" + rawOperatorState +
+			", keyedStateFromBackend=" + managedKeyedState +
+			", keyedStateFromStream=" + rawKeyedState +
+			", stateSize=" + stateSize +
+			'}';
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index cc3dce2..370032a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -26,8 +26,9 @@ import org.apache.flink.runtime.concurrent.Future;
 import org.apache.flink.runtime.concurrent.impl.FlinkCompletableFuture;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
-import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -87,7 +88,7 @@ public class PendingCheckpoint {
 
 	private final long checkpointTimestamp;
 
-	private final Map<JobVertexID, TaskState> taskStates;
+	private final Map<OperatorID, OperatorState> operatorStates;
 
 	private final Map<ExecutionAttemptID, ExecutionVertex> notYetAcknowledgedTasks;
 
@@ -146,7 +147,7 @@ public class PendingCheckpoint {
 		this.targetDirectory = targetDirectory;
 		this.executor = Preconditions.checkNotNull(executor);
 
-		this.taskStates = new HashMap<>();
+		this.operatorStates = new HashMap<>();
 		this.masterState = new ArrayList<>();
 		this.acknowledgedTasks = new HashSet<>(verticesToConfirm.size());
 		this.onCompletionPromise = new FlinkCompletableFuture<>();
@@ -178,8 +179,8 @@ public class PendingCheckpoint {
 		return numAcknowledgedTasks;
 	}
 
-	public Map<JobVertexID, TaskState> getTaskStates() {
-		return taskStates;
+	public Map<OperatorID, OperatorState> getOperatorStates() {
+		return operatorStates; // live map: filled as subtasks acknowledge, cleared when the pending checkpoint's states are discarded
 	}
 
 	public boolean isFullyAcknowledged() {
@@ -261,7 +262,7 @@ public class PendingCheckpoint {
 			// make sure we fulfill the promise with an exception if something fails
 			try {
 				// externalize the metadata
-				final Savepoint savepoint = new SavepointV2(checkpointId, taskStates.values());
+				final Savepoint savepoint = new SavepointV2(checkpointId, operatorStates.values(), masterState);
 
 				// TEMP FIX - The savepoint store is strictly typed to file systems currently
 				//            but the checkpoints think more generic. we need to work with file handles
@@ -326,7 +327,7 @@ public class PendingCheckpoint {
 				checkpointId,
 				checkpointTimestamp,
 				System.currentTimeMillis(),
-				taskStates,
+				operatorStates,
 				masterState,
 				props,
 				externalMetadata,
@@ -380,41 +381,53 @@ public class PendingCheckpoint {
 				acknowledgedTasks.add(executionAttemptId);
 			}
 
-			JobVertexID jobVertexID = vertex.getJobvertexId();
+			List<OperatorID> operatorIDs = vertex.getJobVertex().getOperatorIDs();
 			int subtaskIndex = vertex.getParallelSubtaskIndex();
 			long ackTimestamp = System.currentTimeMillis();
 
 			long stateSize = 0;
-			if (null != subtaskState) {
-				TaskState taskState = taskStates.get(jobVertexID);
-
-				if (null == taskState) {
-					@SuppressWarnings("deprecation")
-					ChainedStateHandle<StreamStateHandle> nonPartitionedState = 
-							subtaskState.getLegacyOperatorState();
-					ChainedStateHandle<OperatorStateHandle> partitioneableState =
-							subtaskState.getManagedOperatorState();
-					//TODO this should go away when we remove chained state, assigning state to operators directly instead
-					int chainLength;
-					if (nonPartitionedState != null) {
-						chainLength = nonPartitionedState.getLength();
-					} else if (partitioneableState != null) {
-						chainLength = partitioneableState.getLength();
-					} else {
-						chainLength = 1;
-					}
+			if (subtaskState != null) {
+				stateSize = subtaskState.getStateSize();
 
-					taskState = new TaskState(
-							jobVertexID,
+				@SuppressWarnings("deprecation")
+				ChainedStateHandle<StreamStateHandle> nonPartitionedState =
+					subtaskState.getLegacyOperatorState();
+				ChainedStateHandle<OperatorStateHandle> partitioneableState =
+					subtaskState.getManagedOperatorState();
+				ChainedStateHandle<OperatorStateHandle> rawOperatorState =
+					subtaskState.getRawOperatorState();
+
+				// break task state apart into separate operator states
+				for (int x = 0; x < operatorIDs.size(); x++) {
+					OperatorID operatorID = operatorIDs.get(x);
+					OperatorState operatorState = operatorStates.get(operatorID);
+
+					if (operatorState == null) {
+						operatorState = new OperatorState(
+							operatorID,
 							vertex.getTotalNumberOfParallelSubtasks(),
-							vertex.getMaxParallelism(),
-							chainLength);
+							vertex.getMaxParallelism());
+						operatorStates.put(operatorID, operatorState);
+					}
 
-					taskStates.put(jobVertexID, taskState);
-				}
+					KeyedStateHandle managedKeyedState = null;
+					KeyedStateHandle rawKeyedState = null;
 
-				taskState.putState(subtaskIndex, subtaskState);
-				stateSize = subtaskState.getStateSize();
+					// only the head operator retains the keyed state
+					if (x == operatorIDs.size() - 1) {
+						managedKeyedState = subtaskState.getManagedKeyedState();
+						rawKeyedState = subtaskState.getRawKeyedState();
+					}
+
+					OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(
+							nonPartitionedState != null ? nonPartitionedState.get(x) : null,
+							partitioneableState != null ? partitioneableState.get(x) : null,
+							rawOperatorState != null ? rawOperatorState.get(x) : null,
+							managedKeyedState,
+							rawKeyedState);
+
+					operatorState.putState(subtaskIndex, operatorSubtaskState);
+				}
 			}
 
 			++numAcknowledgedTasks;
@@ -435,7 +448,7 @@ public class PendingCheckpoint {
 					metrics.getBytesBufferedInAlignment(),
 					alignmentDurationMillis);
 
-				statsCallback.reportSubtaskStats(jobVertexID, subtaskStateStats);
+				statsCallback.reportSubtaskStats(vertex.getJobvertexId(), subtaskStateStats);
 			}
 
 			return TaskAcknowledgeResult.SUCCESS;
@@ -530,12 +543,12 @@ public class PendingCheckpoint {
 							// discard the private states.
 							// unregistered shared states are still considered private at this point.
 							try {
-								StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
+								StateUtil.bestEffortDiscardAllStateObjects(operatorStates.values());
 							} catch (Throwable t) {
 								LOG.warn("Could not properly dispose the private states in the pending checkpoint {} of job {}.",
 									checkpointId, jobId, t);
 							} finally {
-								taskStates.clear();
+								operatorStates.clear();
 							}
 						}
 					});

http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index ac70e1a..1042d5a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -18,9 +18,11 @@
 
 package org.apache.flink.runtime.checkpoint;
 
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
@@ -31,277 +33,400 @@ import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 /**
  * This class encapsulates the operation of assigning restored state when restoring from a checkpoint.
  */
 public class StateAssignmentOperation {
 
-	private final Logger logger;
+	private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);
+
 	private final Map<JobVertexID, ExecutionJobVertex> tasks;
-	private final Map<JobVertexID, TaskState> taskStates;
+	private final Map<OperatorID, OperatorState> operatorStates; // restored state, keyed by operator (no longer by job vertex)
 	private final boolean allowNonRestoredState;

 	public StateAssignmentOperation(
-			Logger logger,
 			Map<JobVertexID, ExecutionJobVertex> tasks,
-			Map<JobVertexID, TaskState> taskStates,
+			Map<OperatorID, OperatorState> operatorStates, // state to restore, keyed by operator ID
 			boolean allowNonRestoredState) {

-		this.logger = Preconditions.checkNotNull(logger);
 		this.tasks = Preconditions.checkNotNull(tasks);
-		this.taskStates = Preconditions.checkNotNull(taskStates);
+		this.operatorStates = Preconditions.checkNotNull(operatorStates);
 		this.allowNonRestoredState = allowNonRestoredState;
 	}
 
 	public boolean assignStates() throws Exception {
-
-		// this tracks if we find missing node hash ids and already use secondary mappings
-		boolean expandedToLegacyIds = false;
-
+		Map<OperatorID, OperatorState> localOperators = new HashMap<>(operatorStates); // working copy; entries are removed as they are matched
 		Map<JobVertexID, ExecutionJobVertex> localTasks = this.tasks;

-		for (Map.Entry<JobVertexID, TaskState> taskGroupStateEntry : taskStates.entrySet()) {
-
-			TaskState taskState = taskGroupStateEntry.getValue();
-
-			//----------------------------------------find vertex for state---------------------------------------------
-
-			ExecutionJobVertex executionJobVertex = localTasks.get(taskGroupStateEntry.getKey());
-
-			// on the first time we can not find the execution job vertex for an id, we also consider alternative ids,
-			// for example as generated from older flink versions, to provide backwards compatibility.
-			if (executionJobVertex == null && !expandedToLegacyIds) {
-				localTasks = ExecutionJobVertex.includeLegacyJobVertexIDs(localTasks);
-				executionJobVertex = localTasks.get(taskGroupStateEntry.getKey());
-				expandedToLegacyIds = true;
-				logger.info("Could not find ExecutionJobVertex. Including legacy JobVertexIDs in search.");
-			}
-
-			if (executionJobVertex == null) {
-				if (allowNonRestoredState) {
-					logger.info("Skipped checkpoint state for operator {}.", taskState.getJobVertexID());
-					continue;
+		checkStateMappingCompleteness(allowNonRestoredState, operatorStates, tasks);
+
+		for (Map.Entry<JobVertexID, ExecutionJobVertex> task : localTasks.entrySet()) {
+			final ExecutionJobVertex executionJobVertex = task.getValue();
+
+			// find the states of all operators chained in this task; operators without state get an empty OperatorState
+			List<OperatorID> operatorIDs = executionJobVertex.getOperatorIDs();
+			List<OperatorID> altOperatorIDs = executionJobVertex.getUserDefinedOperatorIDs();
+			List<OperatorState> operatorStates = new ArrayList<>(); // NOTE(review): shadows the operatorStates field
+			boolean statelessTask = true;
+			for (int x = 0; x < operatorIDs.size(); x++) {
+				OperatorID operatorID = altOperatorIDs.get(x) == null // prefer the user-defined operator ID when one is set
+					? operatorIDs.get(x)
+					: altOperatorIDs.get(x);
+
+				OperatorState operatorState = localOperators.remove(operatorID); // remove() so a restored state is matched at most once
+				if (operatorState == null) {
+					operatorState = new OperatorState(
+						operatorID,
+						executionJobVertex.getParallelism(),
+						executionJobVertex.getMaxParallelism());
 				} else {
-					throw new IllegalStateException("There is no execution job vertex for the job" +
-							" vertex ID " + taskGroupStateEntry.getKey());
+					statelessTask = false;
 				}
+				operatorStates.add(operatorState);
+			}
+			if (statelessTask) { // skip tasks where no operator has any state
+				continue;
 			}

-			checkParallelismPreconditions(taskState, executionJobVertex);
-
-			assignTaskStatesToOperatorInstances(taskState, executionJobVertex);
+			assignAttemptState(task.getValue(), operatorStates);
 		}

 		return true;
 	}
 
-	private void checkParallelismPreconditions(TaskState taskState, ExecutionJobVertex executionJobVertex) {
-		//----------------------------------------max parallelism preconditions-------------------------------------
+	private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<OperatorState> operatorStates) {

-		// check that the number of key groups have not changed or if we need to override it to satisfy the restored state
-		if (taskState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
+		List<OperatorID> operatorIDs = executionJobVertex.getOperatorIDs();

-			if (!executionJobVertex.isMaxParallelismConfigured()) {
-				// if the max parallelism was not explicitly specified by the user, we derive it from the state
+		// 1. validate the (max) parallelism against the restored state, then read the new parallelism
+		checkParallelismPreconditions(operatorStates, executionJobVertex);
+
+		int newParallelism = executionJobVertex.getParallelism();
+
+		List<KeyGroupRange> keyGroupPartitions = createKeyGroupPartitions(
+			executionJobVertex.getMaxParallelism(),
+			newParallelism);
+
+		// 2. redistribute the partitionable (managed and raw) operator state
+		/*
+		 *
+		 * Redistribute ManagedOperatorStates and RawOperatorStates from old parallelism to new parallelism.
+		 *
+		 * The old ManagedOperatorStates with old parallelism 3:
+		 *
+		 *       parallelism0 parallelism1 parallelism2
+		 * op0   state0,0     state0,1     state0,2
+		 * op1   (no state)
+		 * op2   state2,0     state2,1     state2,2
+		 * op3   state3,0     state3,1     state3,2
+		 *
+		 * The new ManagedOperatorStates with new parallelism 4:
+		 *
+		 *       parallelism0 parallelism1 parallelism2 parallelism3
+		 * op0   state0,0     state0,1     state0,2     state0,3
+		 * op1   (no state)
+		 * op2   state2,0     state2,1     state2,2     state2,3
+		 * op3   state3,0     state3,1     state3,2     state3,3
+		 */
+		List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates = new ArrayList<>();
+		List<List<Collection<OperatorStateHandle>>> newRawOperatorStates = new ArrayList<>();
+
+		reDistributePartitionableStates(operatorStates, newParallelism, newManagedOperatorStates, newRawOperatorStates);
+
+
+		// 3. compute the TaskStateHandles of every subtask of the executionJobVertex
+		/*
+		 * All state handles an executionJobVertex needs to restore form a matrix:
+		 *
+		 *       parallelism0 parallelism1 parallelism2 parallelism3
+		 * op0   sh(0,0)      sh(0,1)      sh(0,2)      sh(0,3)
+		 * op1   sh(1,0)      sh(1,1)      sh(1,2)      sh(1,3)
+		 * op2   sh(2,0)      sh(2,1)      sh(2,2)      sh(2,3)
+		 * op3   sh(3,0)      sh(3,1)      sh(3,2)      sh(3,3)
+		 *
+		 * It is computed column by column, i.e. one subtask at a time.
+		 *
+		 */
+		for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) {
+
+			Execution currentExecutionAttempt = executionJobVertex.getTaskVertices()[subTaskIndex]
+				.getCurrentExecutionAttempt();
+
+			List<StreamStateHandle> subNonPartitionableState = new ArrayList<>();
+
+			Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> subKeyedState = null;
+
+			List<Collection<OperatorStateHandle>> subManagedOperatorState = new ArrayList<>();
+			List<Collection<OperatorStateHandle>> subRawOperatorState = new ArrayList<>();
+
+
+			for (int operatorIndex = 0; operatorIndex < operatorIDs.size(); operatorIndex++) {
+				OperatorState operatorState = operatorStates.get(operatorIndex);
+				int oldParallelism = operatorState.getParallelism();
+
+				// non-partitioned (legacy) operator state
+
+				reAssignSubNonPartitionedStates(
+					operatorState,
+					subTaskIndex,
+					newParallelism,
+					oldParallelism,
+					subNonPartitionableState);
+
+				// partitionable (managed and raw) operator state
+				reAssignSubPartitionableState(newManagedOperatorStates,
+					newRawOperatorStates,
+					subTaskIndex,
+					operatorIndex,
+					subManagedOperatorState,
+					subRawOperatorState);
+
+				// keyed state: only the last entry of operatorIDs (the head operator) carries it
+				if (operatorIndex == operatorIDs.size() - 1) {
+					subKeyedState = reAssignSubKeyedStates(operatorState,
+						keyGroupPartitions,
+						subTaskIndex,
+						newParallelism,
+						oldParallelism);

-				if (logger.isDebugEnabled()) {
-					logger.debug("Overriding maximum parallelism for JobVertex " + executionJobVertex.getJobVertexId()
-							+ " from " + executionJobVertex.getMaxParallelism() + " to " + taskState.getMaxParallelism());
 				}
+			}

-				executionJobVertex.setMaxParallelism(taskState.getMaxParallelism());
-			} else {
-				// if the max parallelism was explicitly specified, we complain on mismatch
-				throw new IllegalStateException("The maximum parallelism (" +
-						taskState.getMaxParallelism() + ") with which the latest " +
-						"checkpoint of the execution job vertex " + executionJobVertex +
-						" has been taken and the current maximum parallelism (" +
-						executionJobVertex.getMaxParallelism() + ") changed. This " +
-						"is currently not supported.");
+
+			// only set an initial state if this subtask received any state at all
+			if (!allElementsAreNull(subNonPartitionableState) ||
+				!allElementsAreNull(subManagedOperatorState) ||
+				!allElementsAreNull(subRawOperatorState) ||
+				subKeyedState != null) {
+
+				TaskStateHandles taskStateHandles = new TaskStateHandles(
+
+					new ChainedStateHandle<>(subNonPartitionableState),
+					subManagedOperatorState,
+					subRawOperatorState,
+					subKeyedState != null ? subKeyedState.f0 : null,
+					subKeyedState != null ? subKeyedState.f1 : null);
+
+				currentExecutionAttempt.setInitialState(taskStateHandles);
 			}
 		}
+	}
 
-		//----------------------------------------parallelism preconditions-----------------------------------------
 
-		final int oldParallelism = taskState.getParallelism();
-		final int newParallelism = executionJobVertex.getParallelism();
+	public void checkParallelismPreconditions(List<OperatorState> operatorStates, ExecutionJobVertex executionJobVertex) {

-		if (taskState.hasNonPartitionedState() && (oldParallelism != newParallelism)) {
-			throw new IllegalStateException("Cannot restore the latest checkpoint because " +
-					"the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " +
-					"state and its parallelism changed. The operator " + executionJobVertex.getJobVertexId() +
-					" has parallelism " + newParallelism + " whereas the corresponding " +
-					"state object has a parallelism of " + oldParallelism);
+		for (OperatorState operatorState : operatorStates) { // validate each chained operator's state against this task
+			checkParallelismPreconditions(operatorState, executionJobVertex);
 		}
 	}
 
-	private static void assignTaskStatesToOperatorInstances(
-			TaskState taskState, ExecutionJobVertex executionJobVertex) {
 
-		final int oldParallelism = taskState.getParallelism();
-		final int newParallelism = executionJobVertex.getParallelism();
+	private void reAssignSubPartitionableState(
+			List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates,
+			List<List<Collection<OperatorStateHandle>>> newRawOperatorStates,
+			int subTaskIndex, int operatorIndex,
+			List<Collection<OperatorStateHandle>> subManagedOperatorState,
+			List<Collection<OperatorStateHandle>> subRawOperatorState) {

-		List<KeyGroupRange> keyGroupPartitions = createKeyGroupPartitions(
-				executionJobVertex.getMaxParallelism(),
-				newParallelism);
-
-		final int chainLength = taskState.getChainLength();
+		// Append this subtask's share of the redistributed managed and raw state of the operator.
+		// A null per-operator list means that operator has no state of that kind; a null slot is
+		// appended instead so both output lists stay index-aligned with the operator chain.
+		List<Collection<OperatorStateHandle>> managedForOperator = newManagedOperatorStates.get(operatorIndex);
+		subManagedOperatorState.add(
+			managedForOperator != null ? managedForOperator.get(subTaskIndex) : null);
+
+		List<Collection<OperatorStateHandle>> rawForOperator = newRawOperatorStates.get(operatorIndex);
+		subRawOperatorState.add(
+			rawForOperator != null ? rawForOperator.get(subTaskIndex) : null);

-		// operator chain idx -> list of the stored op states from all parallel instances for this chain idx
-		@SuppressWarnings("unchecked")
-		List<OperatorStateHandle>[] parallelOpStatesBackend = new List[chainLength];
-		@SuppressWarnings("unchecked")
-		List<OperatorStateHandle>[] parallelOpStatesStream = new List[chainLength];
-
-		List<KeyedStateHandle> parallelKeyedStatesBackend = new ArrayList<>(oldParallelism);
-		List<KeyedStateHandle> parallelKeyedStateStream = new ArrayList<>(oldParallelism);
+	}
 
-		for (int p = 0; p < oldParallelism; ++p) {
-			SubtaskState subtaskState = taskState.getState(p);
+	private Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> reAssignSubKeyedStates(
+			OperatorState operatorState,
+			List<KeyGroupRange> keyGroupPartitions,
+			int subTaskIndex,
+			int newParallelism,
+			int oldParallelism) {
+		// Returns the (managed, raw) keyed state for one new subtask, or null if it receives neither.
+		Collection<KeyedStateHandle> managed = null;
+		Collection<KeyedStateHandle> raw = null;
+
+		if (newParallelism != oldParallelism) {
+			// rescaling: gather the parts of all old handles that overlap this subtask's key-group range
+			KeyGroupRange targetRange = keyGroupPartitions.get(subTaskIndex);
+			managed = getManagedKeyedStateHandles(operatorState, targetRange);
+			raw = getRawKeyedStateHandles(operatorState, targetRange);
+		} else {
+			// unchanged parallelism: subtask i keeps exactly the keyed state of old subtask i
+			OperatorSubtaskState previous = operatorState.getState(subTaskIndex);
+			if (previous != null) {
+				KeyedStateHandle previousManaged = previous.getManagedKeyedState();
+				KeyedStateHandle previousRaw = previous.getRawKeyedState();
+				managed = previousManaged == null ? null : Collections.singletonList(previousManaged);
+				raw = previousRaw == null ? null : Collections.singletonList(previousRaw);
+			}
+		}
+
+		if (managed == null && raw == null) {
+			return null;
+		}
+		return new Tuple2<>(managed, raw);
+	}
 
-			if (null != subtaskState) {
-				collectParallelStatesByChainOperator(
-						parallelOpStatesBackend, subtaskState.getManagedOperatorState());
 
-				collectParallelStatesByChainOperator(
-						parallelOpStatesStream, subtaskState.getRawOperatorState());
+	private <X> boolean allElementsAreNull(List<X> nonPartitionableStates) {
+		// Vacuously true for an empty list; the scan stops at the first non-null entry.
+		boolean allNull = true;
+		for (int i = 0; allNull && i < nonPartitionableStates.size(); i++) {
+			allNull = nonPartitionableStates.get(i) == null;
+		}
+		return allNull;
+	}
 
-				KeyedStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
-				if (null != keyedStateBackend) {
-					parallelKeyedStatesBackend.add(keyedStateBackend);
-				}
 
-				KeyedStateHandle keyedStateStream = subtaskState.getRawKeyedState();
-				if (null != keyedStateStream) {
-					parallelKeyedStateStream.add(keyedStateStream);
-				}
+	private void reAssignSubNonPartitionedStates(
+			OperatorState operatorState,
+			int subTaskIndex,
+			int newParallelism,
+			int oldParallelism,
+		List<StreamStateHandle> subNonPartitionableState) {
+		if (oldParallelism == newParallelism) { // legacy state is only handed over 1:1 when the parallelism is unchanged
+			if (operatorState.getState(subTaskIndex) != null) {
+				subNonPartitionableState.add(operatorState.getState(subTaskIndex).getLegacyOperatorState());
+			} else {
+				subNonPartitionableState.add(null);
 			}
+		} else { // parallelism changed: no legacy state is assigned, only a null placeholder for this operator
+			subNonPartitionableState.add(null);
 		}
+	}
 
-		// operator chain index -> lists with collected states (one collection for each parallel subtasks)
-		@SuppressWarnings("unchecked")
-		List<Collection<OperatorStateHandle>>[] partitionedParallelStatesBackend = new List[chainLength];
+	private void reDistributePartitionableStates(
+			List<OperatorState> operatorStates, int newParallelism,
+			List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates,
+			List<List<Collection<OperatorStateHandle>>> newRawOperatorStates) {

-		@SuppressWarnings("unchecked")
-		List<Collection<OperatorStateHandle>>[] partitionedParallelStatesStream = new List[chainLength];
+		// collect the old partitionable (managed and raw) state, one list per operator
+		List<List<OperatorStateHandle>> oldManagedOperatorStates = new ArrayList<>();
+		List<List<OperatorStateHandle>> oldRawOperatorStates = new ArrayList<>();

-		//TODO here we can employ different redistribution strategies for state, e.g. union state.
-		// For now we only offer round robin as the default.
-		OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
+		collectPartionableStates(operatorStates, oldManagedOperatorStates, oldRawOperatorStates);

-		for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) {

-			List<OperatorStateHandle> chainOpParallelStatesBackend = parallelOpStatesBackend[chainIdx];
-			List<OperatorStateHandle> chainOpParallelStatesStream = parallelOpStatesStream[chainIdx];
+		// redistribute each operator's state to the new parallelism using RoundRobinOperatorStateRepartitioner
+		OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;

-			partitionedParallelStatesBackend[chainIdx] = applyRepartitioner(
-					opStateRepartitioner,
-					chainOpParallelStatesBackend,
-					oldParallelism,
-					newParallelism);
+		for (int operatorIndex = 0; operatorIndex < operatorStates.size(); operatorIndex++) {
+			int oldParallelism = operatorStates.get(operatorIndex).getParallelism();
+			newManagedOperatorStates.add(applyRepartitioner(opStateRepartitioner,
+				oldManagedOperatorStates.get(operatorIndex), oldParallelism, newParallelism));
+			newRawOperatorStates.add(applyRepartitioner(opStateRepartitioner,
+				oldRawOperatorStates.get(operatorIndex), oldParallelism, newParallelism));

-			partitionedParallelStatesStream[chainIdx] = applyRepartitioner(
-					opStateRepartitioner,
-					chainOpParallelStatesStream,
-					oldParallelism,
-					newParallelism);
 		}
+	}
 
-		for (int subTaskIdx = 0; subTaskIdx < newParallelism; ++subTaskIdx) {
-			// non-partitioned state
-			ChainedStateHandle<StreamStateHandle> nonPartitionableState = null;
-
-			if (oldParallelism == newParallelism) {
-				if (taskState.getState(subTaskIdx) != null) {
-					nonPartitionableState = taskState.getState(subTaskIdx).getLegacyOperatorState();
-				}
-			}
 
-			// partitionable state
-			@SuppressWarnings("unchecked")
-			Collection<OperatorStateHandle>[] iab = new Collection[chainLength];
-			@SuppressWarnings("unchecked")
-			Collection<OperatorStateHandle>[] ias = new Collection[chainLength];
-			List<Collection<OperatorStateHandle>> operatorStateFromBackend = Arrays.asList(iab);
-			List<Collection<OperatorStateHandle>> operatorStateFromStream = Arrays.asList(ias);
+	private void collectPartionableStates(
+			List<OperatorState> operatorStates,
+			List<List<OperatorStateHandle>> managedOperatorStates,
+			List<List<OperatorStateHandle>> rawOperatorStates) {

-			for (int chainIdx = 0; chainIdx < partitionedParallelStatesBackend.length; ++chainIdx) {
-				List<Collection<OperatorStateHandle>> redistributedOpStateBackend =
-						partitionedParallelStatesBackend[chainIdx];
+		for (OperatorState operatorState : operatorStates) {
+			List<OperatorStateHandle> managedOperatorState = null; // stays null if no subtask of this operator has managed operator state
+			List<OperatorStateHandle> rawOperatorState = null; // stays null if no subtask of this operator has raw operator state

-				List<Collection<OperatorStateHandle>> redistributedOpStateStream =
-						partitionedParallelStatesStream[chainIdx];
+			for (int i = 0; i < operatorState.getParallelism(); i++) {
+				OperatorSubtaskState operatorSubtaskState = operatorState.getState(i);
+				if (operatorSubtaskState != null) {
+					if (operatorSubtaskState.getManagedOperatorState() != null) {
+						if (managedOperatorState == null) {
+							managedOperatorState = new ArrayList<>();
+						}
+						managedOperatorState.add(operatorSubtaskState.getManagedOperatorState());
+					}

-				if (redistributedOpStateBackend != null) {
-					operatorStateFromBackend.set(chainIdx, redistributedOpStateBackend.get(subTaskIdx));
+					if (operatorSubtaskState.getRawOperatorState() != null) {
+						if (rawOperatorState == null) {
+							rawOperatorState = new ArrayList<>();
+						}
+						rawOperatorState.add(operatorSubtaskState.getRawOperatorState());
+					}
 				}

-				if (redistributedOpStateStream != null) {
-					operatorStateFromStream.set(chainIdx, redistributedOpStateStream.get(subTaskIdx));
-				}
 			}
+			managedOperatorStates.add(managedOperatorState); // exactly one entry (possibly null) per operator, in input order
+			rawOperatorStates.add(rawOperatorState);
+		}
+	}
 
-			Execution currentExecutionAttempt = executionJobVertex
-					.getTaskVertices()[subTaskIdx]
-					.getCurrentExecutionAttempt();
-
-			List<KeyedStateHandle> newKeyedStatesBackend;
-			List<KeyedStateHandle> newKeyedStateStream;
-			if (oldParallelism == newParallelism) {
-				SubtaskState subtaskState = taskState.getState(subTaskIdx);
-				if (subtaskState != null) {
-					KeyedStateHandle oldKeyedStatesBackend = subtaskState.getManagedKeyedState();
-					KeyedStateHandle oldKeyedStatesStream = subtaskState.getRawKeyedState();
-					newKeyedStatesBackend = oldKeyedStatesBackend != null ? Collections.singletonList(
-							oldKeyedStatesBackend) : null;
-					newKeyedStateStream = oldKeyedStatesStream != null ? Collections.singletonList(
-							oldKeyedStatesStream) : null;
-				} else {
-					newKeyedStatesBackend = null;
-					newKeyedStateStream = null;
-				}
-			} else {
-				KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(subTaskIdx);
-				newKeyedStatesBackend = getKeyedStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
-				newKeyedStateStream = getKeyedStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
-			}
 
-			TaskStateHandles taskStateHandles = new TaskStateHandles(
-					nonPartitionableState,
-					operatorStateFromBackend,
-					operatorStateFromStream,
-					newKeyedStatesBackend,
-					newKeyedStateStream);
+	/**
+	 * Collects the intersections of the {@link KeyedStateHandle managed keyed state handles} of all
+	 * subtasks of the given {@link OperatorState} with the given {@link KeyGroupRange}.
+	 *
+	 * @param operatorState        all state handles of an operator
+	 * @param subtaskKeyGroupRange the KeyGroupRange of a subtask
+	 * @return the non-null intersected managed keyed state handles, or {@code null} if there are none
+	 */
+	public static List<KeyedStateHandle> getManagedKeyedStateHandles(
+			OperatorState operatorState,
+			KeyGroupRange subtaskKeyGroupRange) {
+
+		List<KeyedStateHandle> subtaskKeyedStateHandles = null;
+
+		for (int i = 0; i < operatorState.getParallelism(); i++) {
+			if (operatorState.getState(i) != null && operatorState.getState(i).getManagedKeyedState() != null) {
+				KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getManagedKeyedState().getIntersection(subtaskKeyGroupRange);

-			currentExecutionAttempt.setInitialState(taskStateHandles);
+				if (intersectedKeyedStateHandle != null) {
+					if (subtaskKeyedStateHandles == null) {
+						subtaskKeyedStateHandles = new ArrayList<>();
+					}
+					subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
+				}
+			}
 		}
+
+		return subtaskKeyedStateHandles;
 	}
 
 	/**
-	 * Determine the subset of {@link KeyGroupsStateHandle KeyGroupsStateHandles} with correct
-	 * key group index for the given subtask {@link KeyGroupRange}.
-	 * <p>
-	 * <p>This is publicly visible to be used in tests.
+	 * Collect {@link KeyGroupsStateHandle  rawKeyedStateHandles} which have intersection with given
+	 * {@link KeyGroupRange} from {@link TaskState operatorState}
+	 *
+	 * @param operatorState        all state handles of a operator
+	 * @param subtaskKeyGroupRange the KeyGroupRange of a subtask
+	 * @return all rawKeyedStateHandles which have intersection with given KeyGroupRange
 	 */
-	public static List<KeyedStateHandle> getKeyedStateHandles(
-			Collection<? extends KeyedStateHandle> keyedStateHandles,
-			KeyGroupRange subtaskKeyGroupRange) {
+	public static List<KeyedStateHandle> getRawKeyedStateHandles(
+		OperatorState operatorState,
+		KeyGroupRange subtaskKeyGroupRange) {
 
-		List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>();
+		List<KeyedStateHandle> subtaskKeyedStateHandles = null;
 
-		for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
-			KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange);
+		for (int i = 0; i < operatorState.getParallelism(); i++) {
+			if (operatorState.getState(i) != null && operatorState.getState(i).getRawKeyedState() != null) {
+				KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getRawKeyedState().getIntersection(subtaskKeyGroupRange);
 
-			if (intersectedKeyedStateHandle != null) {
-				subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
+				if (intersectedKeyedStateHandle != null) {
+					if (subtaskKeyedStateHandles == null) {
+						subtaskKeyedStateHandles = new ArrayList<>();
+					}
+					subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
+				}
 			}
 		}
 
@@ -331,37 +456,90 @@ public class StateAssignmentOperation {
 	}
 
 	/**
-	 * @param chainParallelOpStates array = chain ops, array[idx] = parallel states for this chain op.
-	 * @param chainOpState the operator chain
+	 * Verifies parallelism and maxParallelism conditions for restoring state, throwing an IllegalStateException if one is violated.
+	 *
+	 * @param operatorState      state to restore
+	 * @param executionJobVertex task for which the state should be restored; its max parallelism may be overwritten
 	 */
-	private static void collectParallelStatesByChainOperator(
-			List<OperatorStateHandle>[] chainParallelOpStates, ChainedStateHandle<OperatorStateHandle> chainOpState) {
+	private static void checkParallelismPreconditions(OperatorState operatorState, ExecutionJobVertex executionJobVertex) {
+		//----------------------------------------max parallelism preconditions-------------------------------------
 
-		if (null != chainOpState) {
+		// the number of key groups (max parallelism) must not change, unless it was not configured and can be taken from the state
+		if (operatorState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
 
-			int chainLength = chainOpState.getLength();
-			Preconditions.checkState(chainLength >= chainParallelOpStates.length,
-					"Found more states than operators in the chain. Chain length: " + chainLength +
-							", States: " + chainParallelOpStates.length);
+			if (!executionJobVertex.isMaxParallelismConfigured()) {
+				// if the max parallelism was not explicitly specified by the user, we derive it from the state
 
-			for (int chainIdx = 0; chainIdx < chainParallelOpStates.length; ++chainIdx) {
-				OperatorStateHandle operatorState = chainOpState.get(chainIdx);
+				LOG.debug("Overriding maximum parallelism for JobVertex {} from {} to {}",
+					executionJobVertex.getJobVertexId(), executionJobVertex.getMaxParallelism(), operatorState.getMaxParallelism());
 
-				if (null != operatorState) {
+				executionJobVertex.setMaxParallelism(operatorState.getMaxParallelism());
+			} else {
+				// if the max parallelism was explicitly specified, a mismatch is rejected as unsupported
+				throw new IllegalStateException("The maximum parallelism (" +
+					operatorState.getMaxParallelism() + ") with which the latest " +
+					"checkpoint of the execution job vertex " + executionJobVertex +
+					" has been taken and the current maximum parallelism (" +
+					executionJobVertex.getMaxParallelism() + ") changed. This " +
+					"is currently not supported.");
+			}
+		}
 
-					List<OperatorStateHandle> opParallelStatesForOneChainOp = chainParallelOpStates[chainIdx];
+		//----------------------------------------parallelism preconditions-----------------------------------------
 
-					if (null == opParallelStatesForOneChainOp) {
-						opParallelStatesForOneChainOp = new ArrayList<>();
-						chainParallelOpStates[chainIdx] = opParallelStatesForOneChainOp;
-					}
-					opParallelStatesForOneChainOp.add(operatorState);
+		final int oldParallelism = operatorState.getParallelism();
+		final int newParallelism = executionJobVertex.getParallelism();
+		// state that is not partitioned cannot be restored once the parallelism has changed
+		if (operatorState.hasNonPartitionedState() && (oldParallelism != newParallelism)) {
+			throw new IllegalStateException("Cannot restore the latest checkpoint because " +
+				"the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " +
+				"state and its parallelism changed. The operator " + executionJobVertex.getJobVertexId() +
+				" has parallelism " + newParallelism + " whereas the corresponding " +
+				"state object has a parallelism of " + oldParallelism);
+		}
+	}
+
+	/**
+	 * Verifies that every operator state belongs to an operator contained in one of the given execution job vertices.
+	 *
+	 * @param allowNonRestoredState if true, unmapped state is skipped and logged; if false, an IllegalStateException is thrown
+	 * @param operatorStates operator states to map
+	 * @param tasks execution job vertices whose operator IDs the states are matched against
+	 */
+	private static void checkStateMappingCompleteness(
+			boolean allowNonRestoredState,
+			Map<OperatorID, OperatorState> operatorStates,
+			Map<JobVertexID, ExecutionJobVertex> tasks) {
+		// collect the IDs of all operators of all job vertices
+		Set<OperatorID> allOperatorIDs = new HashSet<>();
+		for (ExecutionJobVertex executionJobVertex : tasks.values()) {
+			allOperatorIDs.addAll(executionJobVertex.getOperatorIDs());
+		}
+		for (Map.Entry<OperatorID, OperatorState> operatorGroupStateEntry : operatorStates.entrySet()) {
+			OperatorState operatorState = operatorGroupStateEntry.getValue();
+			// a state is mapped only if some job vertex contains the operator it was keyed under
+
+			if (!allOperatorIDs.contains(operatorGroupStateEntry.getKey())) {
+				if (allowNonRestoredState) {
+					LOG.info("Skipped checkpoint state for operator {}.", operatorState.getOperatorID());
+				} else {
+					throw new IllegalStateException("There is no operator for the state " + operatorState.getOperatorID());
 				}
 			}
 		}
 	}
 
-	private static List<Collection<OperatorStateHandle>> applyRepartitioner(
+	/**
+	 * Repartitions the given operator state using the given {@link OperatorStateRepartitioner} with respect to the new
+	 * parallelism.
+	 *
+	 * @param opStateRepartitioner  partitioner to use
+	 * @param chainOpParallelStates state to repartition
+	 * @param oldParallelism        parallelism with which the state is currently partitioned
+	 * @param newParallelism        parallelism with which the state should be partitioned
+	 * @return repartitioned state
+	 */
+	public static List<Collection<OperatorStateHandle>> applyRepartitioner(
 			OperatorStateRepartitioner opStateRepartitioner,
 			List<OperatorStateHandle> chainOpParallelStates,
 			int oldParallelism,
@@ -399,4 +577,27 @@ public class StateAssignmentOperation {
 			return repackStream;
 		}
 	}
-}
\ No newline at end of file
+
+	/**
+	 * Determines the intersections of the given keyed state handles with the given subtask
+	 * {@link KeyGroupRange}. Handles whose intersection is {@code null} are omitted.
+	 *
+	 * <p>This is publicly visible to be used in tests.
+	 *
+	 * @param keyedStateHandles    keyed state handles to intersect with the subtask's key group range
+	 * @param subtaskKeyGroupRange the KeyGroupRange of a subtask
+	 * @return a new list of the non-null intersections, in the iteration order of {@code keyedStateHandles}
+	 */
+	public static List<KeyedStateHandle> getKeyedStateHandles(
+			Collection<? extends KeyedStateHandle> keyedStateHandles,
+			KeyGroupRange subtaskKeyGroupRange) {
+		List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>(keyedStateHandles.size());
+		for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
+			KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange);
+			if (intersectedKeyedStateHandle != null) {
+				subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
+			}
+		}
+		return subtaskKeyedStateHandles;
+	}
+}