Posted to commits@flink.apache.org by al...@apache.org on 2016/10/20 14:15:26 UTC

[8/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State

[FLINK-4844] Partitionable Raw Keyed/Operator State


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/cab9cd44
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/cab9cd44
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/cab9cd44

Branch: refs/heads/master
Commit: cab9cd44eca83ef8cbcd2a2d070d8c79cb037977
Parents: 428419d
Author: Stefan Richter <s....@data-artisans.com>
Authored: Tue Oct 4 10:59:38 2016 +0200
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Thu Oct 20 16:14:21 2016 +0200

----------------------------------------------------------------------
 .../state/RocksDBKeyedStateBackend.java         |   5 +-
 .../streaming/state/RocksDBStateBackend.java    |   3 +-
 .../state/RocksDBAsyncSnapshotTest.java         |  10 +-
 .../flink/api/common/state/KeyedStateStore.java | 159 +++++++
 .../api/common/state/OperatorStateStore.java    |   9 +-
 .../core/fs/local/LocalDataInputStream.java     |  17 +-
 .../org/apache/flink/util/CollectionUtil.java   |  37 ++
 .../java/org/apache/flink/util/FutureUtil.java  |  42 ++
 .../flink/cep/operator/CEPOperatorTest.java     |  12 +-
 .../checkpoint/CheckpointCoordinator.java       | 212 +--------
 .../runtime/checkpoint/CheckpointMetaData.java  |  42 ++
 .../runtime/checkpoint/PendingCheckpoint.java   |  96 ++--
 .../RoundRobinOperatorStateRepartitioner.java   |   2 +-
 .../checkpoint/StateAssignmentOperation.java    | 329 +++++++++++++
 .../flink/runtime/checkpoint/SubtaskState.java  | 181 ++++++-
 .../flink/runtime/checkpoint/TaskState.java     |  90 +---
 .../savepoint/SavepointV1Serializer.java        | 161 ++++---
 .../deployment/TaskDeploymentDescriptor.java    |  36 +-
 .../flink/runtime/execution/Environment.java    |   6 +-
 .../flink/runtime/executiongraph/Execution.java |  43 +-
 .../runtime/executiongraph/ExecutionVertex.java |  24 +-
 .../runtime/fs/hdfs/HadoopDataInputStream.java  |   9 +-
 .../runtime/jobgraph/tasks/StatefulTask.java    |  18 +-
 .../checkpoint/AcknowledgeCheckpoint.java       |  20 +-
 .../flink/runtime/query/KvStateMessage.java     |   4 +-
 .../state/AbstractKeyedStateBackend.java        |   2 +-
 .../runtime/state/AbstractStateBackend.java     |   3 +-
 .../flink/runtime/state/BoundedInputStream.java | 112 +++++
 .../flink/runtime/state/ChainedStateHandle.java |   4 +
 .../runtime/state/CheckpointStateHandles.java   | 103 ----
 .../flink/runtime/state/ClosableRegistry.java   |  48 +-
 .../runtime/state/DefaultKeyedStateStore.java   |  89 ++++
 .../state/DefaultOperatorStateBackend.java      |  57 ++-
 .../state/FunctionInitializationContext.java    |  37 ++
 .../runtime/state/FunctionSnapshotContext.java  |  30 ++
 .../flink/runtime/state/KeyGroupRange.java      |  21 +-
 .../runtime/state/KeyGroupRangeOffsets.java     |   6 +-
 .../KeyGroupStatePartitionStreamProvider.java   |  51 ++
 .../flink/runtime/state/KeyGroupsList.java      |  43 ++
 .../flink/runtime/state/KeyedStateBackend.java  |   4 +-
 .../state/KeyedStateCheckpointOutputStream.java | 108 +++++
 .../state/ManagedInitializationContext.java     |  53 +++
 .../runtime/state/ManagedSnapshotContext.java   |  41 ++
 .../state/NonClosingCheckpointOutputStream.java |  80 ++++
 .../runtime/state/OperatorStateBackend.java     |   4 +-
 .../OperatorStateCheckpointOutputStream.java    |  78 +++
 .../runtime/state/OperatorStateHandle.java      |   4 +-
 ...artitionableCheckpointStateOutputStream.java |  96 ----
 .../flink/runtime/state/SnapshotProvider.java   |  45 --
 .../flink/runtime/state/Snapshotable.java       |  45 ++
 .../state/StateInitializationContext.java       |  52 ++
 .../state/StateInitializationContextImpl.java   | 270 +++++++++++
 .../state/StatePartitionStreamProvider.java     |  62 +++
 .../runtime/state/StateSnapshotContext.java     |  40 ++
 .../StateSnapshotContextSynchronousImpl.java    | 129 +++++
 .../flink/runtime/state/TaskStateHandles.java   | 172 +++++++
 .../runtime/state/UserFacingListState.java      |  57 +++
 .../state/filesystem/FsStateBackend.java        |   4 +-
 .../state/heap/HeapKeyedStateBackend.java       |  13 +-
 .../memory/MemCheckpointStreamFactory.java      |   4 +
 .../state/memory/MemoryStateBackend.java        |   4 +-
 .../ActorGatewayCheckpointResponder.java        |   4 +-
 .../taskmanager/CheckpointResponder.java        |   6 +-
 .../runtime/taskmanager/RuntimeEnvironment.java |   4 +-
 .../apache/flink/runtime/taskmanager/Task.java  |  50 +-
 .../apache/flink/runtime/util/IntArrayList.java |   5 +
 .../flink/runtime/util/LongArrayList.java       |   6 +
 .../runtime/util/NonClosingStreamDecorator.java |  79 ++++
 .../checkpoint/CheckpointCoordinatorTest.java   | 262 ++++++-----
 .../checkpoint/CheckpointStateRestoreTest.java  |  37 +-
 .../CompletedCheckpointStoreTest.java           |   2 +-
 .../savepoint/SavepointV1SerializerTest.java    |  26 +-
 .../checkpoint/savepoint/SavepointV1Test.java   | 100 +++-
 .../stats/SimpleCheckpointStatsTrackerTest.java |   3 +-
 .../jobmanager/JobManagerHARecoveryTest.java    |  17 +-
 .../messages/CheckpointMessagesTest.java        |  13 +-
 .../operators/testutils/DummyEnvironment.java   |   5 +-
 .../operators/testutils/MockEnvironment.java    |   5 +-
 .../runtime/state/KeyGroupRangeOffsetTest.java  |   4 +-
 .../flink/runtime/state/KeyGroupRangeTest.java  |   4 +-
 .../KeyedStateCheckpointOutputStreamTest.java   | 165 +++++++
 ...OperatorStateOutputCheckpointStreamTest.java | 102 ++++
 .../runtime/state/StateBackendTestBase.java     |   6 +-
 .../state/TestMemoryCheckpointOutputStream.java |  49 ++
 .../runtime/taskmanager/TaskAsyncCallTest.java  |   5 +-
 .../fs/bucketing/BucketingSinkTest.java         |   2 +-
 .../kafka/FlinkKafkaConsumerBase.java           |  58 +--
 .../kafka/FlinkKafkaProducerBase.java           |  10 +-
 .../kafka/AtLeastOnceProducerTest.java          |   8 +-
 .../kafka/FlinkKafkaConsumerBaseTest.java       | 127 +++--
 .../streaming/api/checkpoint/Checkpointed.java  |  12 +-
 .../api/checkpoint/CheckpointedFunction.java    |  44 +-
 .../api/checkpoint/CheckpointedRestoring.java   |  41 ++
 .../api/operators/AbstractStreamOperator.java   | 174 +++++--
 .../operators/AbstractUdfStreamOperator.java    | 105 +++--
 .../api/operators/OperatorSnapshotResult.java   |  81 ++++
 .../streaming/api/operators/StreamOperator.java |   8 +-
 .../api/operators/StreamingRuntimeContext.java  |  27 +-
 .../api/operators/UserFacingListState.java      |  57 ---
 .../runtime/tasks/OperatorStateHandles.java     | 109 +++++
 .../streaming/runtime/tasks/StreamTask.java     | 470 ++++++++++---------
 .../AbstractUdfStreamOperatorTest.java          | 219 +++++++++
 .../StateInitializationContextImplTest.java     | 260 ++++++++++
 ...StateSnapshotContextSynchronousImplTest.java |  61 +++
 .../StreamOperatorSnapshotRestoreTest.java      | 214 +++++++++
 .../operators/StreamingRuntimeContextTest.java  |  82 ++--
 .../streaming/runtime/io/BarrierBufferTest.java |   6 +-
 .../runtime/io/BarrierTrackerTest.java          |   6 +-
 .../operators/GenericWriteAheadSinkTest.java    |   6 +-
 .../operators/WriteAheadSinkTestBase.java       |  16 +-
 ...AlignedProcessingTimeWindowOperatorTest.java |   4 +-
 ...AlignedProcessingTimeWindowOperatorTest.java |   4 +-
 .../operators/windowing/WindowOperatorTest.java |  16 +-
 .../tasks/InterruptSensitiveRestoreTest.java    |  18 +-
 .../runtime/tasks/OneInputStreamTaskTest.java   |  66 +--
 .../runtime/tasks/StreamMockEnvironment.java    |   6 +-
 .../runtime/tasks/TwoInputStreamTaskTest.java   |  21 +-
 .../KeyedOneInputStreamOperatorTestHarness.java |  24 +-
 .../util/OneInputStreamOperatorTestHarness.java |  74 ++-
 .../util/TwoInputStreamOperatorTestHarness.java |  19 +
 .../streaming/util/WindowingTestHarness.java    |   2 +-
 .../test/checkpointing/RescalingITCase.java     | 149 ++++--
 .../test/checkpointing/SavepointITCase.java     |   5 +-
 .../streaming/runtime/StateBackendITCase.java   |   4 +-
 .../flink/yarn/cli/FlinkYarnSessionCli.java     |   6 +-
 125 files changed, 5337 insertions(+), 1781 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 7ab35c4..f332d1e 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -65,6 +65,7 @@ import javax.annotation.concurrent.GuardedBy;
 import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
@@ -185,7 +186,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> restoreState
+			Collection<KeyGroupsStateHandle> restoreState
 	) throws Exception {
 
 		this(jobId,
@@ -603,7 +604,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		 * @throws ClassNotFoundException
 		 * @throws RocksDBException
 		 */
-		public void doRestore(List<KeyGroupsStateHandle> keyGroupsStateHandles)
+		public void doRestore(Collection<KeyGroupsStateHandle> keyGroupsStateHandles)
 				throws IOException, ClassNotFoundException, RocksDBException {
 
 			for (KeyGroupsStateHandle keyGroupsStateHandle : keyGroupsStateHandles) {

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index a0c980b..82e7899 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -40,6 +40,7 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.net.URI;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.Random;
 import java.util.UUID;
@@ -258,7 +259,7 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> restoredState,
+			Collection<KeyGroupsStateHandle> restoredState,
 			TaskKvStateRegistry kvStateRegistry) throws Exception {
 
 		lazyInitializeForJob(env, operatorIdentifier);

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
index 8f58075..4d1ab50 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
@@ -28,9 +28,9 @@ import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
@@ -70,7 +70,7 @@ import java.util.concurrent.CancellationException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
 
-import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 
 /**
  * Tests for asynchronous RocksDB Key/Value state checkpoints.
@@ -136,7 +136,7 @@ public class RocksDBAsyncSnapshotTest {
 			@Override
 			public void acknowledgeCheckpoint(
 					CheckpointMetaData checkpointMetaData,
-					CheckpointStateHandles checkpointStateHandles) {
+					SubtaskState checkpointStateHandles) {
 
 				super.acknowledgeCheckpoint(checkpointMetaData);
 
@@ -148,8 +148,8 @@ public class RocksDBAsyncSnapshotTest {
 					e.printStackTrace();
 				}
 
-				// should be only one k/v state
-				assertEquals(1, checkpointStateHandles.getKeyGroupsStateHandle().size());
+				// should be one k/v state
+				assertNotNull(checkpointStateHandles.getManagedKeyedState());
 
 				// we now know that the checkpoint went through
 				ensureCheckpointLatch.trigger();

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
new file mode 100644
index 0000000..89c1240
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * This interface contains methods for registering keyed state with a managed store.
+ */
+@PublicEvolving
+public interface KeyedStateStore {
+
+	/**
+	 * Gets a handle to the system's key/value state. The key/value state is only accessible
+	 * if the function is executed on a KeyedStream. On each access, the state exposes the value
+	 * for the key of the element currently processed by the function.
+	 * Each function may have multiple partitioned states, addressed with different names.
+	 *
+	 * <p>Because the scope of each value is the key of the currently processed element,
+	 * and the elements are distributed by the Flink runtime, the system can transparently
+	 * scale out and redistribute the state and KeyedStream.
+	 *
+	 * <p>The following code example shows how to implement a continuous counter that counts
+	 * how many times elements of a certain key occur, and emits an updated count for that
+	 * element on each occurrence.
+	 *
+	 * <pre>{@code
+	 * DataStream<MyType> stream = ...;
+	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");
+	 *
+	 * keyedStream.map(new RichMapFunction<MyType, Tuple2<MyType, Long>>() {
+	 *
+	 *     private ValueState<Long> state;
+	 *
+	 *     public void open(Configuration cfg) {
+	 *         state = getRuntimeContext().getState(
+	 *                 new ValueStateDescriptor<Long>("count", LongSerializer.INSTANCE, 0L));
+	 *     }
+	 *
+	 *     public Tuple2<MyType, Long> map(MyType value) {
+	 *         long count = state.value() + 1;
+	 *         state.update(count);
+	 *         return new Tuple2<>(value, count);
+	 *     }
+	 * });
+	 * }</pre>
+	 *
+	 * @param stateProperties The descriptor defining the properties of the state.
+	 *
+	 * @param <T> The type of value stored in the state.
+	 *
+	 * @return The partitioned state object.
+	 *
+	 * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+	 *                                       function (function is not part of a KeyedStream).
+	 */
+	@PublicEvolving
+	<T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties);
+
+	/**
+	 * Gets a handle to the system's key/value list state. This state is similar to the state
+	 * accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
+	 * holds lists. One can add elements to the list, or retrieve the list as a whole.
+	 *
+	 * <p>This state is only accessible if the function is executed on a KeyedStream.
+	 *
+	 * <pre>{@code
+	 * DataStream<MyType> stream = ...;
+	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");
+	 *
+	 * keyedStream.flatMap(new RichFlatMapFunction<MyType, MyType>() {
+	 *
+	 *     private ListState<MyType> state;
+	 *
+	 *     public void open(Configuration cfg) {
+	 *         state = getRuntimeContext().getListState(
+	 *                 new ListStateDescriptor<>("myState", MyType.class));
+	 *     }
+	 *
+	 *     public void flatMap(MyType value, Collector<MyType> out) {
+	 *         if (value.isDivider()) {
+	 *             for (MyType t : state.get()) {
+	 *                 out.collect(t);
+	 *             }
+	 *         } else {
+	 *             state.add(value);
+	 *         }
+	 *     }
+	 * });
+	 * }</pre>
+	 *
+	 * @param stateProperties The descriptor defining the properties of the state.
+	 *
+	 * @param <T> The type of value stored in the state.
+	 *
+	 * @return The partitioned state object.
+	 *
+	 * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+	 *                                       function (function is not part of a KeyedStream).
+	 */
+	@PublicEvolving
+	<T> ListState<T> getListState(ListStateDescriptor<T> stateProperties);
+
+	/**
+	 * Gets a handle to the system's key/value reducing state. This state is similar to the state
+	 * accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
+	 * aggregates values.
+	 *
+	 * <p>This state is only accessible if the function is executed on a KeyedStream.
+	 *
+	 * <pre>{@code
+	 * DataStream<MyType> stream = ...;
+	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");
+	 *
+	 * keyedStream.map(new RichMapFunction<MyType, Tuple2<MyType, Long>>() {
+	 *
+	 *     private ReducingState<Long> sum;
+	 *
+	 *     public void open(Configuration cfg) {
+	 *         sum = getRuntimeContext().getReducingState(
+	 *                 new ReducingStateDescriptor<>("sum", (a, b) -> a + b, Long.class));
+	 *     }
+	 *
+	 *     public Tuple2<MyType, Long> map(MyType value) {
+	 *         sum.add(value.count());
+	 *         return new Tuple2<>(value, sum.get());
+	 *     }
+	 * });
+	 *
+	 * }</pre>
+	 *
+	 * @param stateProperties The descriptor defining the properties of the state.
+	 *
+	 * @param <T> The type of value stored in the state.
+	 *
+	 * @return The partitioned state object.
+	 *
+	 * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+	 *                                       function (function is not part of a KeyedStream).
+	 */
+	@PublicEvolving
+	<T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties);
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
index 03c11f6..43dbe51 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
@@ -18,16 +18,17 @@
 
 package org.apache.flink.api.common.state;
 
+import org.apache.flink.annotation.PublicEvolving;
+
 import java.io.Serializable;
 import java.util.Set;
 
 /**
- * Interface for a backend that manages operator state.
+ * This interface contains methods for registering operator state with a managed store.
  */
+@PublicEvolving
 public interface OperatorStateStore {
 
-	String DEFAULT_OPERATOR_STATE_NAME = "_default_";
-
 	/**
 	 * Creates a state descriptor of the given name that uses Java serialization to persist the
 	 * state.
@@ -39,7 +40,7 @@ public interface OperatorStateStore {
 	 * @return A list state using Java serialization to serialize state objects.
 	 * @throws Exception
 	 */
-	ListState<Serializable> getSerializableListState(String stateName) throws Exception;
+	<T extends Serializable> ListState<T> getSerializableListState(String stateName) throws Exception;
 
 	/**
 	 * Creates (or restores) a list state. Each state is registered under a unique name.
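For illustration, a minimal sketch of what the now-generic getSerializableListState signature allows; before this change the caller received a ListState<Serializable> and had to cast every element. The class and the state name below are hypothetical, not part of this commit:

    import org.apache.flink.api.common.state.ListState;
    import org.apache.flink.api.common.state.OperatorStateStore;

    import java.io.Serializable;

    class SerializableListStateExample {
        static class MyEvent implements Serializable {}

        void restore(OperatorStateStore store) throws Exception {
            // "bufferedEvents" is a hypothetical state name; elements now come
            // back typed as MyEvent instead of plain Serializable
            ListState<MyEvent> buffered = store.getSerializableListState("bufferedEvents");
            for (MyEvent event : buffered.get()) {
                // re-process the buffered element (application-specific)
            }
        }
    }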

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java b/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java
index e7b2828..172da79 100644
--- a/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java
+++ b/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java
@@ -18,14 +18,14 @@
 
 package org.apache.flink.core.fs.local;
 
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.FSDataInputStream;
 
 import javax.annotation.Nonnull;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.channels.FileChannel;
 
 /**
  * The <code>LocalDataInputStream</code> class is a wrapper class for a data
@@ -36,6 +36,7 @@ public class LocalDataInputStream extends FSDataInputStream {
 
 	/** The file input stream used to read data from.*/
 	private final FileInputStream fis;
+	private final FileChannel fileChannel;
 
 	/**
 	 * Constructs a new <code>LocalDataInputStream</code> object from a given {@link File} object.
@@ -46,16 +47,19 @@ public class LocalDataInputStream extends FSDataInputStream {
 	 */
 	public LocalDataInputStream(File file) throws IOException {
 		this.fis = new FileInputStream(file);
+		this.fileChannel = fis.getChannel();
 	}
 
 	@Override
 	public void seek(long desired) throws IOException {
-		this.fis.getChannel().position(desired);
+		if (desired != getPos()) {
+			this.fileChannel.position(desired);
+		}
 	}
 
 	@Override
 	public long getPos() throws IOException {
-		return this.fis.getChannel().position();
+		return this.fileChannel.position();
 	}
 
 	@Override
@@ -70,6 +74,7 @@ public class LocalDataInputStream extends FSDataInputStream {
 	
 	@Override
 	public void close() throws IOException {
+		// According to the javadoc, this also closes the channel
 		this.fis.close();
 	}
 	

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java b/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java
new file mode 100644
index 0000000..15d00ae
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.util;
+
+import java.util.Collection;
+import java.util.Map;
+
+public final class CollectionUtil {
+
+	private CollectionUtil() {
+		throw new AssertionError();
+	}
+
+	public static boolean isNullOrEmpty(Collection<?> collection) {
+		return collection == null || collection.isEmpty();
+	}
+
+	public static boolean isNullOrEmpty(Map<?, ?> map) {
+		return map == null || map.isEmpty();
+	}
+}
\ No newline at end of file
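A minimal usage sketch of the new utility (the surrounding context is assumed for illustration, not taken from this commit):

    import org.apache.flink.util.CollectionUtil;

    import java.util.Collection;
    import java.util.Collections;

    class CollectionUtilExample {
        public static void main(String[] args) {
            Collection<String> restored = Collections.emptyList();
            // true for both null and empty, replacing the usual two-clause check
            System.out.println(CollectionUtil.isNullOrEmpty(restored)); // true
        }
    }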

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-core/src/main/java/org/apache/flink/util/FutureUtil.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/util/FutureUtil.java b/flink-core/src/main/java/org/apache/flink/util/FutureUtil.java
new file mode 100644
index 0000000..62d836b
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/util/FutureUtil.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.util;
+
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.RunnableFuture;
+
+public class FutureUtil {
+
+	private FutureUtil() {
+		throw new AssertionError();
+	}
+
+	public static <T> T runIfNotDoneAndGet(RunnableFuture<T> future) throws ExecutionException, InterruptedException {
+
+		if (null == future) {
+			return null;
+		}
+
+		if (!future.isDone()) {
+			future.run();
+		}
+
+		return future.get();
+	}
+}
\ No newline at end of file
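A minimal sketch of the usage pattern runIfNotDoneAndGet supports (the snapshot scenario is assumed for illustration): a RunnableFuture that no executor has picked up yet is simply run in the calling thread before its result is read, so one code path can handle both synchronous and asynchronous cases.

    import org.apache.flink.util.FutureUtil;

    import java.util.concurrent.FutureTask;

    class FutureUtilExample {
        public static void main(String[] args) throws Exception {
            // stands in for a snapshot future that was never handed to an executor
            FutureTask<Integer> snapshotFuture = new FutureTask<>(() -> 42);

            // not done yet, so it is run in the calling thread; prints 42
            System.out.println(FutureUtil.runIfNotDoneAndGet(snapshotFuture));
        }
    }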

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
index 1fd8de8..0f49b13 100644
--- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
+++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
@@ -135,7 +135,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processElement(new StreamRecord<Event>(new Event(42, "foobar", 1.0), 2));
 
 		// simulate snapshot/restore with some elements in internal sorting queue
-		StreamStateHandle snapshot = harness.snapshot(0, 0);
+		StreamStateHandle snapshot = harness.snapshotLegacy(0, 0);
 		harness.close();
 
 		harness = new OneInputStreamOperatorTestHarness<>(
@@ -157,7 +157,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processWatermark(new Watermark(2));
 
 		// simulate snapshot/restore with empty element queue but NFA state
-		StreamStateHandle snapshot2 = harness.snapshot(1, 1);
+		StreamStateHandle snapshot2 = harness.snapshotLegacy(1, 1);
 		harness.close();
 
 		harness = new OneInputStreamOperatorTestHarness<>(
@@ -228,7 +228,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processElement(new StreamRecord<Event>(new Event(42, "foobar", 1.0), 2));
 
 		// simulate snapshot/restore with some elements in internal sorting queue
-		StreamStateHandle snapshot = harness.snapshot(0, 0);
+		StreamStateHandle snapshot = harness.snapshotLegacy(0, 0);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(
@@ -254,7 +254,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processWatermark(new Watermark(2));
 
 		// simulate snapshot/restore with empty element queue but NFA state
-		StreamStateHandle snapshot2 = harness.snapshot(1, 1);
+		StreamStateHandle snapshot2 = harness.snapshotLegacy(1, 1);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(
@@ -337,7 +337,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processElement(new StreamRecord<Event>(new Event(42, "foobar", 1.0), 2));
 
 		// simulate snapshot/restore with some elements in internal sorting queue
-		StreamStateHandle snapshot = harness.snapshot(0, 0);
+		StreamStateHandle snapshot = harness.snapshotLegacy(0, 0);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(
@@ -368,7 +368,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processWatermark(new Watermark(2));
 
 		// simulate snapshot/restore with empty element queue but NFA state
-		StreamStateHandle snapshot2 = harness.snapshot(1, 1);
+		StreamStateHandle snapshot2 = harness.snapshotLegacy(1, 1);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 00028c4..588ba84 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -36,22 +36,10 @@ import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
-import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.ArrayDeque;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
@@ -444,11 +432,11 @@ public class CheckpointCoordinator {
 							// note that checkpoint completion discards the pending checkpoint object
 							if (!checkpoint.isDiscarded()) {
 								LOG.info("Checkpoint " + checkpointID + " expired before completing.");
-	
+
 								checkpoint.abortExpired();
 								pendingCheckpoints.remove(checkpointID);
 								rememberRecentCheckpointId(checkpointID);
-	
+
 								triggerQueuedRequests();
 							}
 						}
@@ -578,7 +566,7 @@ public class CheckpointCoordinator {
 				isPendingCheckpoint = true;
 
 				LOG.info("Discarding checkpoint " + checkpointId
-					+ " because of checkpoint decline from task " + message.getTaskExecutionId());
+						+ " because of checkpoint decline from task " + message.getTaskExecutionId());
 
 				pendingCheckpoints.remove(checkpointId);
 				checkpoint.abortDeclined();
@@ -602,7 +590,7 @@ public class CheckpointCoordinator {
 			} else if (checkpoint != null) {
 				// this should not happen
 				throw new IllegalStateException(
-					"Received message for discarded but non-removed checkpoint " + checkpointId);
+						"Received message for discarded but non-removed checkpoint " + checkpointId);
 			} else {
 				// message is for an unknown checkpoint, or comes too late (checkpoint disposed)
 				if (recentPendingCheckpoints.contains(checkpointId)) {
@@ -660,7 +648,7 @@ public class CheckpointCoordinator {
 
 				if (checkpoint.acknowledgeTask(
 						message.getTaskExecutionId(),
-						message.getCheckpointStateHandles())) {
+						message.getSubtaskState())) {
 					if (checkpoint.isFullyAcknowledged()) {
 						completed = checkpoint.finalizeCheckpoint();
 
@@ -804,199 +792,15 @@ public class CheckpointCoordinator {
 
 			LOG.info("Restoring from latest valid checkpoint: {}.", latest);
 
-			for (Map.Entry<JobVertexID, TaskState> taskGroupStateEntry: latest.getTaskStates().entrySet()) {
-				TaskState taskState = taskGroupStateEntry.getValue();
-				ExecutionJobVertex executionJobVertex = tasks.get(taskGroupStateEntry.getKey());
-
-				if (executionJobVertex != null) {
-					// check that the number of key groups have not changed
-					if (taskState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
-						throw new IllegalStateException("The maximum parallelism (" +
-							taskState.getMaxParallelism() + ") with which the latest " +
-							"checkpoint of the execution job vertex " + executionJobVertex +
-							" has been taken and the current maximum parallelism (" +
-							executionJobVertex.getMaxParallelism() + ") changed. This " +
-							"is currently not supported.");
-					}
-
-
-					int oldParallelism = taskState.getParallelism();
-					int newParallelism = executionJobVertex.getParallelism();
-					boolean parallelismChanged = oldParallelism != newParallelism;
-					boolean hasNonPartitionedState = taskState.hasNonPartitionedState();
-
-					if (hasNonPartitionedState && parallelismChanged) {
-						throw new IllegalStateException("Cannot restore the latest checkpoint because " +
-							"the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " +
-							"state and its parallelism changed. The operator" + executionJobVertex.getJobVertexId() +
-							" has parallelism " + newParallelism + " whereas the corresponding" +
-							"state object has a parallelism of " + oldParallelism);
-					}
-
-					List<KeyGroupRange> keyGroupPartitions = createKeyGroupPartitions(
-							executionJobVertex.getMaxParallelism(),
-							newParallelism);
-					
-					// operator chain index -> list of the stored partitionables states from all parallel instances
-					@SuppressWarnings("unchecked")
-					List<OperatorStateHandle>[] chainParallelStates =
-							new List[taskState.getChainLength()];
-
-					for (int i = 0; i < oldParallelism; ++i) {
-
-						ChainedStateHandle<OperatorStateHandle> partitionableState =
-								taskState.getPartitionableState(i);
-
-						if (partitionableState != null) {
-							for (int j = 0; j < partitionableState.getLength(); ++j) {
-								OperatorStateHandle opParalleState = partitionableState.get(j);
-								if (opParalleState != null) {
-									List<OperatorStateHandle> opParallelStates =
-											chainParallelStates[j];
-									if (opParallelStates == null) {
-										opParallelStates = new ArrayList<>();
-										chainParallelStates[j] = opParallelStates;
-									}
-									opParallelStates.add(opParalleState);
-								}
-							}
-						}
-					}
-
-					// operator chain index -> lists with collected states (one collection for each parallel subtasks)
-					@SuppressWarnings("unchecked")
-					List<Collection<OperatorStateHandle>>[] redistributedParallelStates =
-							new List[taskState.getChainLength()];
-
-					//TODO here we can employ different redistribution strategies for state, e.g. union state. For now we only offer round robin as the default.
-					OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
-
-					for (int i = 0; i < chainParallelStates.length; ++i) {
-						List<OperatorStateHandle> chainOpParallelStates = chainParallelStates[i];
-						if (chainOpParallelStates != null) {
-							//We only redistribute if the parallelism of the operator changed from previous executions
-							if (parallelismChanged) {
-								redistributedParallelStates[i] = repartitioner.repartitionState(
-										chainOpParallelStates,
-										newParallelism);
-							} else {
-								List<Collection<OperatorStateHandle>> repacking = new ArrayList<>(newParallelism);
-								for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) {
-									repacking.add(Collections.singletonList(operatorStateHandle));
-								}
-								redistributedParallelStates[i] = repacking;
-							}
-						}
-					}
-
-					int counter = 0;
-
-					for (int i = 0; i < newParallelism; ++i) {
-
-						// non-partitioned state
-						ChainedStateHandle<StreamStateHandle> state = null;
-
-						if (hasNonPartitionedState) {
-							SubtaskState subtaskState = taskState.getState(i);
-
-							if (subtaskState != null) {
-								// count the number of executions for which we set a state
-								++counter;
-								state = subtaskState.getChainedStateHandle();
-							}
-						}
-
-						// partitionable state
-						@SuppressWarnings("unchecked")
-						Collection<OperatorStateHandle>[] ia = new Collection[taskState.getChainLength()];
-						List<Collection<OperatorStateHandle>> subTaskPartitionableState = Arrays.asList(ia);
-
-						for (int j = 0; j < redistributedParallelStates.length; ++j) {
-							List<Collection<OperatorStateHandle>> redistributedParallelState =
-									redistributedParallelStates[j];
-
-							if (redistributedParallelState != null) {
-								subTaskPartitionableState.set(j, redistributedParallelState.get(i));
-							}
-						}
+			StateAssignmentOperation stateAssignmentOperation =
+					new StateAssignmentOperation(tasks, latest, allOrNothingState);
 
-						// key-partitioned state
-						KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(i);
-
-						// Again, we only repartition if the parallelism changed
-						List<KeyGroupsStateHandle> subtaskKeyGroupStates = parallelismChanged ?
-								getKeyGroupsStateHandles(taskState.getKeyGroupStates(), subtaskKeyGroupIds)
-								: Collections.singletonList(taskState.getKeyGroupState(i));
-
-						Execution currentExecutionAttempt = executionJobVertex
-							.getTaskVertices()[i]
-							.getCurrentExecutionAttempt();
-
-						CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(
-								state,
-								null/*subTaskPartionableState*/, //TODO chose right structure and put redistributed states here
-								subtaskKeyGroupStates);
-
-						currentExecutionAttempt.setInitialState(checkpointStateHandles, subTaskPartitionableState);
-					}
-
-					if (allOrNothingState && counter > 0 && counter < newParallelism) {
-						throw new IllegalStateException("The checkpoint contained state only for " +
-							"a subset of tasks for vertex " + executionJobVertex);
-					}
-				} else {
-					throw new IllegalStateException("There is no execution job vertex for the job" +
-						" vertex ID " + taskGroupStateEntry.getKey());
-				}
-			}
+			stateAssignmentOperation.assignStates();
 
 			return true;
 		}
 	}
 
-	/**
-	 * Determine the subset of {@link KeyGroupsStateHandle KeyGroupsStateHandles} with correct
-	 * key group index for the given subtask {@link KeyGroupRange}.
-	 *
-	 * <p>This is publicly visible to be used in tests.
-	 */
-	public static List<KeyGroupsStateHandle> getKeyGroupsStateHandles(
-			Collection<KeyGroupsStateHandle> allKeyGroupsHandles,
-			KeyGroupRange subtaskKeyGroupIds) {
-
-		List<KeyGroupsStateHandle> subtaskKeyGroupStates = new ArrayList<>();
-
-		for (KeyGroupsStateHandle storedKeyGroup : allKeyGroupsHandles) {
-			KeyGroupsStateHandle intersection = storedKeyGroup.getKeyGroupIntersection(subtaskKeyGroupIds);
-			if (intersection.getNumberOfKeyGroups() > 0) {
-				subtaskKeyGroupStates.add(intersection);
-			}
-		}
-		return subtaskKeyGroupStates;
-	}
-
-	/**
-	 * Groups the available set of key groups into key group partitions. A key group partition is
-	 * the set of key groups which is assigned to the same task. Each set of the returned list
-	 * constitutes a key group partition.
-	 *
-	 * <b>IMPORTANT</b>: The assignment of key groups to partitions has to be in sync with the
-	 * KeyGroupStreamPartitioner.
-	 *
-	 * @param numberKeyGroups Number of available key groups (indexed from 0 to numberKeyGroups - 1)
-	 * @param parallelism Parallelism to generate the key group partitioning for
-	 * @return List of key group partitions
-	 */
-	public static List<KeyGroupRange> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
-		Preconditions.checkArgument(numberKeyGroups >= parallelism);
-		List<KeyGroupRange> result = new ArrayList<>(parallelism);
-
-		for (int i = 0; i < parallelism; ++i) {
-			result.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(numberKeyGroups, parallelism, i));
-		}
-		return result;
-	}
-
 	// --------------------------------------------------------------------------------------------
 	//  Accessors
 	// --------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java
index 6f117f2..2627b22 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java
@@ -59,6 +59,15 @@ public class CheckpointMetaData implements Serializable {
 				asynchronousDurationMillis);
 	}
 
+	public CheckpointMetaData(
+			long checkpointId,
+			long timestamp,
+			CheckpointMetrics metrics) {
+		this.checkpointId = checkpointId;
+		this.timestamp = timestamp;
+		this.metrics = Preconditions.checkNotNull(metrics);
+	}
+
 	public CheckpointMetrics getMetrics() {
 		return metrics;
 	}
@@ -110,4 +119,37 @@ public class CheckpointMetaData implements Serializable {
 	public long getAsyncDurationMillis() {
 		return metrics.getAsyncDurationMillis();
 	}
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o) {
+			return true;
+		}
+		if (o == null || getClass() != o.getClass()) {
+			return false;
+		}
+
+		CheckpointMetaData that = (CheckpointMetaData) o;
+
+		return (checkpointId == that.checkpointId)
+				&& (timestamp == that.timestamp)
+				&& (metrics.equals(that.metrics));
+	}
+
+	@Override
+	public int hashCode() {
+		int result = (int) (checkpointId ^ (checkpointId >>> 32));
+		result = 31 * result + (int) (timestamp ^ (timestamp >>> 32));
+		result = 31 * result + metrics.hashCode();
+		return result;
+	}
+
+	@Override
+	public String toString() {
+		return "CheckpointMetaData{" +
+				"checkpointId=" + checkpointId +
+				", timestamp=" + timestamp +
+				", metrics=" + metrics +
+				'}';
+	}
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index 6f50392..92dca21 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -28,8 +28,6 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -37,7 +35,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
@@ -234,80 +231,61 @@ public class PendingCheckpoint {
 	
 	public boolean acknowledgeTask(
 			ExecutionAttemptID attemptID,
-			CheckpointStateHandles checkpointStateHandles) {
+			SubtaskState checkpointedSubtaskState) {
 
 		synchronized (lock) {
+
 			if (discarded) {
 				return false;
 			}
 
-			ExecutionVertex vertex = notYetAcknowledgedTasks.remove(attemptID);
-
-			if (vertex != null) {
-				if (checkpointStateHandles != null) {
-					List<KeyGroupsStateHandle> keyGroupsState = checkpointStateHandles.getKeyGroupsStateHandle();
-					ChainedStateHandle<StreamStateHandle> nonPartitionedState =
-							checkpointStateHandles.getNonPartitionedStateHandles();
-					ChainedStateHandle<OperatorStateHandle> partitioneableState =
-							checkpointStateHandles.getPartitioneableStateHandles();
-
-					if (nonPartitionedState != null || partitioneableState != null || keyGroupsState != null) {
-
-						JobVertexID jobVertexID = vertex.getJobvertexId();
+			final ExecutionVertex vertex = notYetAcknowledgedTasks.remove(attemptID);
 
-						int subtaskIndex = vertex.getParallelSubtaskIndex();
+			if (vertex == null) {
+				return false;
+			}
 
-						TaskState taskState;
+			if (null != checkpointedSubtaskState && checkpointedSubtaskState.hasState()) {
 
-						if (taskStates.containsKey(jobVertexID)) {
-							taskState = taskStates.get(jobVertexID);
-						} else {
-							//TODO this should go away when we remove chained state, assigning state to operators directly instead
-							int chainLength;
-							if (nonPartitionedState != null) {
-								chainLength = nonPartitionedState.getLength();
-							} else if (partitioneableState != null) {
-								chainLength = partitioneableState.getLength();
-							} else {
-								chainLength = 1;
-							}
+				JobVertexID jobVertexID = vertex.getJobvertexId();
 
-							taskState = new TaskState(
-								jobVertexID,
-								vertex.getTotalNumberOfParallelSubtasks(),
-								vertex.getMaxParallelism(),
-								chainLength);
+				int subtaskIndex = vertex.getParallelSubtaskIndex();
 
-							taskStates.put(jobVertexID, taskState);
-						}
+				TaskState taskState = taskStates.get(jobVertexID);
 
-						long duration = System.currentTimeMillis() - checkpointTimestamp;
-
-						if (nonPartitionedState != null) {
-							taskState.putState(
-									subtaskIndex,
-									new SubtaskState(nonPartitionedState, duration));
-						}
+				if (null == taskState) {
+					ChainedStateHandle<StreamStateHandle> nonPartitionedState =
+							checkpointedSubtaskState.getLegacyOperatorState();
+					ChainedStateHandle<OperatorStateHandle> partitioneableState =
+							checkpointedSubtaskState.getManagedOperatorState();
+					//TODO this should go away when we remove chained state, assigning state to operators directly instead
+					int chainLength;
+					if (nonPartitionedState != null) {
+						chainLength = nonPartitionedState.getLength();
+					} else if (partitioneableState != null) {
+						chainLength = partitioneableState.getLength();
+					} else {
+						chainLength = 1;
+					}
 
-						if(partitioneableState != null && !partitioneableState.isEmpty()) {
-							taskState.putPartitionableState(subtaskIndex, partitioneableState);
-						}
+					taskState = new TaskState(
+							jobVertexID,
+							vertex.getTotalNumberOfParallelSubtasks(),
+							vertex.getMaxParallelism(),
+							chainLength);
 
-						// currently a checkpoint can only contain keyed state
-						// for the head operator
-						if (keyGroupsState != null && !keyGroupsState.isEmpty()) {
-							KeyGroupsStateHandle keyGroupsStateHandle = keyGroupsState.get(0);
-							taskState.putKeyedState(subtaskIndex, keyGroupsStateHandle);
-						}
-					}
+					taskStates.put(jobVertexID, taskState);
 				}
 
-				++numAcknowledgedTasks;
+				long duration = System.currentTimeMillis() - checkpointTimestamp;
+				checkpointedSubtaskState.setDuration(duration);
 
-				return true;
-			} else {
-				return false;
+				taskState.putState(subtaskIndex, checkpointedSubtaskState);
 			}
+
+			++numAcknowledgedTasks;
+
+			return true;
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
index 09a35f6..16a7e27 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
@@ -176,7 +176,7 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
 					Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(parallelOpIdx);
 					OperatorStateHandle psh = mergeMap.get(handleWithOffsets.f0);
 					if (psh == null) {
-						psh = new OperatorStateHandle(handleWithOffsets.f0, new HashMap<String, long[]>());
+						psh = new OperatorStateHandle(new HashMap<String, long[]>(), handleWithOffsets.f0);
 						mergeMap.put(handleWithOffsets.f0, psh);
 					}
 					psh.getStateNameToPartitionOffsets().put(e.getKey(), offs);
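The constructor argument order is flipped here: the offset map now comes first, then the delegate stream handle. A hedged sketch of constructing a handle with the new order (the state name and the wrapping method are hypothetical):

    import org.apache.flink.runtime.state.OperatorStateHandle;
    import org.apache.flink.runtime.state.StreamStateHandle;

    import java.util.HashMap;
    import java.util.Map;

    class OperatorStateHandleExample {
        OperatorStateHandle wrap(StreamStateHandle delegate) {
            Map<String, long[]> stateNameToOffsets = new HashMap<>();
            stateNameToOffsets.put("bufferedEvents", new long[]{0L});
            // new argument order: (offset map, underlying stream handle)
            return new OperatorStateHandle(stateNameToOffsets, delegate);
        }
    }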

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
new file mode 100644
index 0000000..8e2b0bf
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -0,0 +1,329 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This class encapsulates the operation of assigning restored state when restoring from a checkpoint.
+ */
+public class StateAssignmentOperation {
+
+	public StateAssignmentOperation(
+			Map<JobVertexID, ExecutionJobVertex> tasks,
+			CompletedCheckpoint latest,
+			boolean allOrNothingState) {
+
+		this.tasks = tasks;
+		this.latest = latest;
+		this.allOrNothingState = allOrNothingState;
+	}
+
+	private final Map<JobVertexID, ExecutionJobVertex> tasks;
+	private final CompletedCheckpoint latest;
+	private final boolean allOrNothingState;
+
+	public boolean assignStates() throws Exception {
+
+		for (Map.Entry<JobVertexID, TaskState> taskGroupStateEntry : latest.getTaskStates().entrySet()) {
+			TaskState taskState = taskGroupStateEntry.getValue();
+			ExecutionJobVertex executionJobVertex = tasks.get(taskGroupStateEntry.getKey());
+
+			if (executionJobVertex != null) {
+				// check that the number of key groups has not changed
+				if (taskState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
+					throw new IllegalStateException("The maximum parallelism (" +
+							taskState.getMaxParallelism() + ") of the latest checkpoint " +
+							"for the execution job vertex " + executionJobVertex +
+							" differs from the current maximum parallelism (" +
+							executionJobVertex.getMaxParallelism() + "). This " +
+							"is currently not supported.");
+				}
+
+				final int oldParallelism = taskState.getParallelism();
+				final int newParallelism = executionJobVertex.getParallelism();
+				final boolean parallelismChanged = oldParallelism != newParallelism;
+				final boolean hasNonPartitionedState = taskState.hasNonPartitionedState();
+
+				if (hasNonPartitionedState && parallelismChanged) {
+					throw new IllegalStateException("Cannot restore the latest checkpoint because " +
+							"the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " +
+							"state and its parallelism changed. The operator" + executionJobVertex.getJobVertexId() +
+							" has parallelism " + newParallelism + " whereas the corresponding" +
+							"state object has a parallelism of " + oldParallelism);
+				}
+
+				List<KeyGroupRange> keyGroupPartitions = createKeyGroupPartitions(
+						executionJobVertex.getMaxParallelism(),
+						newParallelism);
+
+				final int chainLength = taskState.getChainLength();
+
+				// operator chain idx -> list of the stored op states from all parallel instances for this chain idx
+				@SuppressWarnings("unchecked")
+				List<OperatorStateHandle>[] parallelOpStatesBackend = new List[chainLength];
+				@SuppressWarnings("unchecked")
+				List<OperatorStateHandle>[] parallelOpStatesStream = new List[chainLength];
+
+				List<KeyGroupsStateHandle> parallelKeyedStatesBackend = new ArrayList<>(oldParallelism);
+				List<KeyGroupsStateHandle> parallelKeyedStateStream = new ArrayList<>(oldParallelism);
+
+				int counter = 0;
+				for (int p = 0; p < oldParallelism; ++p) {
+
+					SubtaskState subtaskState = taskState.getState(p);
+
+					if (null != subtaskState) {
+
+						++counter;
+
+						collectParallelStatesByChainOperator(
+								parallelOpStatesBackend, subtaskState.getManagedOperatorState());
+
+						collectParallelStatesByChainOperator(
+								parallelOpStatesStream, subtaskState.getRawOperatorState());
+
+						KeyGroupsStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
+						if (null != keyedStateBackend) {
+							parallelKeyedStatesBackend.add(keyedStateBackend);
+						}
+
+						KeyGroupsStateHandle keyedStateStream = subtaskState.getRawKeyedState();
+						if (null != keyedStateStream) {
+							parallelKeyedStateStream.add(keyedStateStream);
+						}
+					}
+				}
+
+				if (allOrNothingState && counter > 0 && counter < oldParallelism) {
+					throw new IllegalStateException("The checkpoint contained state only for " +
+							"a subset of tasks for vertex " + executionJobVertex);
+				}
+
+				// operator chain index -> lists with collected states (one collection per parallel subtask)
+				@SuppressWarnings("unchecked")
+				List<Collection<OperatorStateHandle>>[] partitionedParallelStatesBackend = new List[chainLength];
+
+				@SuppressWarnings("unchecked")
+				List<Collection<OperatorStateHandle>>[] partitionedParallelStatesStream = new List[chainLength];
+
+				//TODO here we can employ different redistribution strategies for state, e.g. union state.
+				// For now we only offer round robin as the default.
+				OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
+
+				for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) {
+
+					List<OperatorStateHandle> chainOpParallelStatesBackend = parallelOpStatesBackend[chainIdx];
+					List<OperatorStateHandle> chainOpParallelStatesStream = parallelOpStatesStream[chainIdx];
+
+					partitionedParallelStatesBackend[chainIdx] = applyRepartitioner(
+							opStateRepartitioner,
+							chainOpParallelStatesBackend,
+							oldParallelism,
+							newParallelism);
+
+					partitionedParallelStatesStream[chainIdx] = applyRepartitioner(
+							opStateRepartitioner,
+							chainOpParallelStatesStream,
+							oldParallelism,
+							newParallelism);
+				}
+
+				for (int subTaskIdx = 0; subTaskIdx < newParallelism; ++subTaskIdx) {
+					// non-partitioned state
+					ChainedStateHandle<StreamStateHandle> nonPartitionableState = null;
+
+					if (hasNonPartitionedState) {
+						// the parallelism is unchanged in this case (checked above), so the
+						// legacy state maps one-to-one from the old subtask with the same index
+						nonPartitionableState = taskState.getState(subTaskIdx).getLegacyOperatorState();
+					}
+
+					// partitionable state
+					@SuppressWarnings("unchecked")
+					Collection<OperatorStateHandle>[] iab = new Collection[chainLength];
+					@SuppressWarnings("unchecked")
+					Collection<OperatorStateHandle>[] ias = new Collection[chainLength];
+					List<Collection<OperatorStateHandle>> operatorStateFromBackend = Arrays.asList(iab);
+					List<Collection<OperatorStateHandle>> operatorStateFromStream = Arrays.asList(ias);
+
+					for (int chainIdx = 0; chainIdx < partitionedParallelStatesBackend.length; ++chainIdx) {
+						List<Collection<OperatorStateHandle>> redistributedOpStateBackend =
+								partitionedParallelStatesBackend[chainIdx];
+
+						List<Collection<OperatorStateHandle>> redistributedOpStateStream =
+								partitionedParallelStatesStream[chainIdx];
+
+						if (redistributedOpStateBackend != null) {
+							operatorStateFromBackend.set(chainIdx, redistributedOpStateBackend.get(subTaskIdx));
+						}
+
+						if (redistributedOpStateStream != null) {
+							operatorStateFromStream.set(chainIdx, redistributedOpStateStream.get(subTaskIdx));
+						}
+					}
+
+					Execution currentExecutionAttempt = executionJobVertex
+							.getTaskVertices()[subTaskIdx]
+							.getCurrentExecutionAttempt();
+
+					List<KeyGroupsStateHandle> newKeyedStatesBackend;
+					List<KeyGroupsStateHandle> newKeyedStateStream;
+					if (parallelismChanged) {
+						KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(subTaskIdx);
+						newKeyedStatesBackend = getKeyGroupsStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
+						newKeyedStateStream = getKeyGroupsStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
+					} else {
+						SubtaskState subtaskState = taskState.getState(subTaskIdx);
+						KeyGroupsStateHandle oldKeyedStatesBackend = subtaskState.getManagedKeyedState();
+						KeyGroupsStateHandle oldKeyedStatesStream = subtaskState.getRawKeyedState();
+						newKeyedStatesBackend = oldKeyedStatesBackend != null ? Collections.singletonList(oldKeyedStatesBackend) : null;
+						newKeyedStateStream = oldKeyedStatesStream != null ? Collections.singletonList(oldKeyedStatesStream) : null;
+					}
+
+					TaskStateHandles taskStateHandles = new TaskStateHandles(
+							nonPartitionableState,
+							operatorStateFromBackend,
+							operatorStateFromStream,
+							newKeyedStatesBackend,
+							newKeyedStateStream);
+
+					currentExecutionAttempt.setInitialState(taskStateHandles);
+				}
+
+			} else {
+				throw new IllegalStateException("There is no execution job vertex for the job" +
+						" vertex ID " + taskGroupStateEntry.getKey());
+			}
+		}
+
+		return true;
+	}
+
+	/**
+	 * Determines the subset of {@link KeyGroupsStateHandle KeyGroupsStateHandles} with the correct
+	 * key group index for the given subtask {@link KeyGroupRange}.
+	 *
+	 * <p>This is publicly visible to be used in tests.
+	 */
+	public static List<KeyGroupsStateHandle> getKeyGroupsStateHandles(
+			Collection<KeyGroupsStateHandle> allKeyGroupsHandles,
+			KeyGroupRange subtaskKeyGroupIds) {
+
+		List<KeyGroupsStateHandle> subtaskKeyGroupStates = new ArrayList<>();
+
+		for (KeyGroupsStateHandle storedKeyGroup : allKeyGroupsHandles) {
+			KeyGroupsStateHandle intersection = storedKeyGroup.getKeyGroupIntersection(subtaskKeyGroupIds);
+			if (intersection.getNumberOfKeyGroups() > 0) {
+				subtaskKeyGroupStates.add(intersection);
+			}
+		}
+		return subtaskKeyGroupStates;
+	}
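+
+	// Illustrative example: a stored KeyGroupsStateHandle covering key groups
+	// [0, 63], intersected with a subtask range [32, 47], contributes a handle
+	// for exactly [32, 47]; handles with an empty intersection are dropped.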
+
+	/**
+	 * Groups the available set of key groups into key group partitions. A key group partition is
+	 * the set of key groups which is assigned to the same task. Each set of the returned list
+	 * constitutes a key group partition.
+	 * <p>
+	 * <b>IMPORTANT</b>: The assignment of key groups to partitions has to be in sync with the
+	 * KeyGroupStreamPartitioner.
+	 *
+	 * @param numberKeyGroups Number of available key groups (indexed from 0 to numberKeyGroups - 1)
+	 * @param parallelism     Parallelism to generate the key group partitioning for
+	 * @return List of key group partitions
+	 */
+	public static List<KeyGroupRange> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
+		Preconditions.checkArgument(numberKeyGroups >= parallelism);
+		List<KeyGroupRange> result = new ArrayList<>(parallelism);
+
+		for (int i = 0; i < parallelism; ++i) {
+			result.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(numberKeyGroups, parallelism, i));
+		}
+		return result;
+	}
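+
+	// Illustrative example: with numberKeyGroups = 128 and parallelism = 4 this
+	// yields the ranges [0, 31], [32, 63], [64, 95] and [96, 127], one per subtask,
+	// in sync with the runtime assignment of the KeyGroupStreamPartitioner.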
+
+	/**
+	 * Collects the operator states of one parallel subtask into per-chain-operator buckets.
+	 *
+	 * @param chainParallelOpStates array indexed by chain operator index; each entry accumulates
+	 *                              the parallel states for that chain operator across all subtasks.
+	 * @param chainOpState          the chained operator state of a single parallel subtask.
+	 */
+	private static void collectParallelStatesByChainOperator(
+			List<OperatorStateHandle>[] chainParallelOpStates, ChainedStateHandle<OperatorStateHandle> chainOpState) {
+
+		if (null != chainOpState) {
+			for (int chainIdx = 0; chainIdx < chainParallelOpStates.length; ++chainIdx) {
+				OperatorStateHandle operatorState = chainOpState.get(chainIdx);
+
+				if (null != operatorState) {
+
+					List<OperatorStateHandle> opParallelStatesForOneChainOp = chainParallelOpStates[chainIdx];
+
+					if (null == opParallelStatesForOneChainOp) {
+						opParallelStatesForOneChainOp = new ArrayList<>();
+						chainParallelOpStates[chainIdx] = opParallelStatesForOneChainOp;
+					}
+					opParallelStatesForOneChainOp.add(operatorState);
+				}
+			}
+		}
+	}
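+
+	// Illustrative example: for a chain of length 2 restored from old parallelism 3,
+	// chainParallelOpStates[0] accumulates (up to) 3 handles for the head operator,
+	// one per old subtask, and chainParallelOpStates[1] likewise for the second
+	// operator in the chain.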
+
+	private static List<Collection<OperatorStateHandle>> applyRepartitioner(
+			OperatorStateRepartitioner opStateRepartitioner,
+			List<OperatorStateHandle> chainOpParallelStates,
+			int oldParallelism,
+			int newParallelism) {
+
+		if (chainOpParallelStates == null) {
+			return null;
+		}
+
+		// We only redistribute if the parallelism of the operator has changed from previous executions.
+		if (newParallelism != oldParallelism) {
+
+			return opStateRepartitioner.repartitionState(
+					chainOpParallelStates,
+					newParallelism);
+		} else {
+
+			List<Collection<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism);
+			for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) {
+				repackStream.add(Collections.singletonList(operatorStateHandle));
+			}
+			return repackStream;
+		}
+	}
+}
\ No newline at end of file

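A note on the repartitioning path above: applyRepartitioner invokes the round-robin
repartitioner only when the parallelism actually changed; otherwise every old subtask's
handles are repacked one-to-one. A minimal standalone sketch of the round-robin idea
(simplified: it deals out whole handles, whereas the actual
RoundRobinOperatorStateRepartitioner can also split state at the per-name partition
offsets recorded in each OperatorStateHandle):

    import java.util.ArrayList;
    import java.util.Collection;
    import java.util.List;

    // Simplified sketch, not the production repartitioner: whole handles are
    // dealt out round-robin across the new parallelism.
    public class RoundRobinRedistributionSketch {

        public static <H> List<Collection<H>> redistribute(List<H> handles, int newParallelism) {
            List<Collection<H>> result = new ArrayList<>(newParallelism);
            for (int i = 0; i < newParallelism; i++) {
                result.add(new ArrayList<H>());
            }
            for (int i = 0; i < handles.size(); i++) {
                // handle i goes to new subtask (i mod newParallelism)
                result.get(i % newParallelism).add(handles.get(i));
            }
            return result;
        }
    }

For example, redistributing five handles h0..h4 to newParallelism = 2 yields
subtask 0 -> {h0, h2, h4} and subtask 1 -> {h1, h3}.
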
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
index 2aa0491..9b9a810 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -19,10 +19,13 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
@@ -34,10 +37,31 @@ public class SubtaskState implements StateObject {
 
 	private static final long serialVersionUID = -2394696997971923995L;
 
-	private static final Logger LOG = LoggerFactory.getLogger(SubtaskState.class);
+	/**
+	 * Legacy (non-repartitionable) operator state.
+	 */
+	@Deprecated
+	private final ChainedStateHandle<StreamStateHandle> legacyOperatorState;
 
-	/** The state of the parallel operator */
-	private final ChainedStateHandle<StreamStateHandle> chainedStateHandle;
+	/**
+	 * Snapshot from the {@link org.apache.flink.runtime.state.OperatorStateBackend}.
+	 */
+	private final ChainedStateHandle<OperatorStateHandle> managedOperatorState;
+
+	/**
+	 * Snapshot written using {@link org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream}.
+	 */
+	private final ChainedStateHandle<OperatorStateHandle> rawOperatorState;
+
+	/**
+	 * Snapshot from {@link org.apache.flink.runtime.state.KeyedStateBackend}.
+	 */
+	private final KeyGroupsStateHandle managedKeyedState;
+
+	/**
+	 * Snapshot written using {@link org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream}.
+	 */
+	private final KeyGroupsStateHandle rawKeyedState;
 
 	/**
 	 * The state size. This is also part of the deserialized state handle.
@@ -46,26 +70,76 @@ public class SubtaskState implements StateObject {
 	 */
 	private final long stateSize;
 
-	/** The duration of the checkpoint (ack timestamp - trigger timestamp). */
-	private final long duration;
-	
+	/**
+	 * The duration of the checkpoint (ack timestamp - trigger timestamp).
+	 */
+	private long duration;
+
+	public SubtaskState(
+			ChainedStateHandle<StreamStateHandle> legacyOperatorState,
+			ChainedStateHandle<OperatorStateHandle> managedOperatorState,
+			ChainedStateHandle<OperatorStateHandle> rawOperatorState,
+			KeyGroupsStateHandle managedKeyedState,
+			KeyGroupsStateHandle rawKeyedState) {
+		this(legacyOperatorState,
+				managedOperatorState,
+				rawOperatorState,
+				managedKeyedState,
+				rawKeyedState,
+				0L);
+	}
+
 	public SubtaskState(
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle,
+			ChainedStateHandle<StreamStateHandle> legacyOperatorState,
+			ChainedStateHandle<OperatorStateHandle> managedOperatorState,
+			ChainedStateHandle<OperatorStateHandle> rawOperatorState,
+			KeyGroupsStateHandle managedKeyedState,
+			KeyGroupsStateHandle rawKeyedState,
 			long duration) {
 
-		this.chainedStateHandle = checkNotNull(chainedStateHandle, "State");
+		this.legacyOperatorState = checkNotNull(legacyOperatorState, "State");
+		this.managedOperatorState = managedOperatorState;
+		this.rawOperatorState = rawOperatorState;
+		this.managedKeyedState = managedKeyedState;
+		this.rawKeyedState = rawKeyedState;
 		this.duration = duration;
 		try {
-			stateSize = chainedStateHandle.getStateSize();
+			long calculatedStateSize = getSizeNullSafe(legacyOperatorState);
+			calculatedStateSize += getSizeNullSafe(managedOperatorState);
+			calculatedStateSize += getSizeNullSafe(rawOperatorState);
+			calculatedStateSize += getSizeNullSafe(managedKeyedState);
+			calculatedStateSize += getSizeNullSafe(rawKeyedState);
+			stateSize = calculatedStateSize;
 		} catch (Exception e) {
 			throw new RuntimeException("Failed to get state size.", e);
 		}
 	}
 
+	private static long getSizeNullSafe(StateObject stateObject) throws Exception {
+		return stateObject != null ? stateObject.getStateSize() : 0L;
+	}
+
 	// --------------------------------------------------------------------------------------------
-	
-	public ChainedStateHandle<StreamStateHandle> getChainedStateHandle() {
-		return chainedStateHandle;
+
+	@Deprecated
+	public ChainedStateHandle<StreamStateHandle> getLegacyOperatorState() {
+		return legacyOperatorState;
+	}
+
+	public ChainedStateHandle<OperatorStateHandle> getManagedOperatorState() {
+		return managedOperatorState;
+	}
+
+	public ChainedStateHandle<OperatorStateHandle> getRawOperatorState() {
+		return rawOperatorState;
+	}
+
+	public KeyGroupsStateHandle getManagedKeyedState() {
+		return managedKeyedState;
+	}
+
+	public KeyGroupsStateHandle getRawKeyedState() {
+		return rawKeyedState;
 	}
 
 	@Override
@@ -79,35 +153,94 @@ public class SubtaskState implements StateObject {
 
 	@Override
 	public void discardState() throws Exception {
-		chainedStateHandle.discardState();
+		StateUtil.bestEffortDiscardAllStateObjects(
+				Arrays.asList(
+						legacyOperatorState,
+						managedOperatorState,
+						rawOperatorState,
+						managedKeyedState,
+						rawKeyedState));
+	}
+
+	public void setDuration(long duration) {
+		this.duration = duration;
 	}
 
 	// --------------------------------------------------------------------------------------------
 
 	@Override
 	public boolean equals(Object o) {
 		if (this == o) {
 			return true;
 		}
-		else if (o instanceof SubtaskState) {
-			SubtaskState that = (SubtaskState) o;
-			return this.chainedStateHandle.equals(that.chainedStateHandle) && stateSize == that.stateSize &&
-				duration == that.duration;
+		if (o == null || getClass() != o.getClass()) {
+			return false;
 		}
-		else {
+
+		SubtaskState that = (SubtaskState) o;
+
+		if (stateSize != that.stateSize) {
 			return false;
 		}
+		if (duration != that.duration) {
+			return false;
+		}
+		if (legacyOperatorState != null ?
+				!legacyOperatorState.equals(that.legacyOperatorState)
+				: that.legacyOperatorState != null) {
+			return false;
+		}
+		if (managedOperatorState != null ?
+				!managedOperatorState.equals(that.managedOperatorState)
+				: that.managedOperatorState != null) {
+			return false;
+		}
+		if (rawOperatorState != null ?
+				!rawOperatorState.equals(that.rawOperatorState)
+				: that.rawOperatorState != null) {
+			return false;
+		}
+		if (managedKeyedState != null ?
+				!managedKeyedState.equals(that.managedKeyedState)
+				: that.managedKeyedState != null) {
+			return false;
+		}
+		return rawKeyedState != null ?
+				rawKeyedState.equals(that.rawKeyedState)
+				: that.rawKeyedState == null;
+	}
+
+	public boolean hasState() {
+		return (null != legacyOperatorState && !legacyOperatorState.isEmpty())
+				|| (null != managedOperatorState && !managedOperatorState.isEmpty())
+				|| (null != rawOperatorState && !rawOperatorState.isEmpty())
+				|| null != managedKeyedState
+				|| null != rawKeyedState;
+	}
 
 	@Override
 	public int hashCode() {
-		return (int) (this.stateSize ^ this.stateSize >>> 32) +
-			31 * ((int) (this.duration ^ this.duration >>> 32) +
-				31 * chainedStateHandle.hashCode());
+		int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0;
+		result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0);
+		result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0);
+		result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0);
+		result = 31 * result + (rawKeyedState != null ? rawKeyedState.hashCode() : 0);
+		result = 31 * result + (int) (stateSize ^ (stateSize >>> 32));
+		result = 31 * result + (int) (duration ^ (duration >>> 32));
+		return result;
 	}
 
 	@Override
 	public String toString() {
-		return String.format("SubtaskState(Size: %d, Duration: %d, State: %s)", stateSize, duration, chainedStateHandle);
+		return "SubtaskState{" +
+				"chainedStateHandle=" + legacyOperatorState +
+				", operatorStateFromBackend=" + managedOperatorState +
+				", operatorStateFromStream=" + rawOperatorState +
+				", keyedStateFromBackend=" + managedKeyedState +
+				", keyedStateHandleFromStream=" + rawKeyedState +
+				", stateSize=" + stateSize +
+				", duration=" + duration +
+				'}';
 	}
 }

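Constructing the widened SubtaskState is a straight call with the five handle kinds.
A minimal sketch (the handle variables are assumed to be produced elsewhere by the
snapshotting machinery; only the legacy handle is mandatory, per the checkNotNull above):

    // Sketch under the above assumptions; null means "no snapshot of that kind".
    SubtaskState subtaskState = new SubtaskState(
            legacyOperatorState,   // ChainedStateHandle<StreamStateHandle>, non-null
            managedOperatorState,  // ChainedStateHandle<OperatorStateHandle>, nullable
            rawOperatorState,      // ChainedStateHandle<OperatorStateHandle>, nullable
            managedKeyedState,     // KeyGroupsStateHandle, nullable
            rawKeyedState);        // KeyGroupsStateHandle, nullable

    long stateSize = subtaskState.getStateSize(); // null-safe sum over all five parts
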
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
index 7e4eded..3cdc5e9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
@@ -18,11 +18,7 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import com.google.common.collect.Iterables;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.util.Preconditions;
@@ -49,12 +45,6 @@ public class TaskState implements StateObject {
 	/** handles to non-partitioned states, subtaskindex -> subtaskstate */
 	private final Map<Integer, SubtaskState> subtaskStates;
 
-	/** handles to partitionable states, subtaskindex -> partitionable state */
-	private final Map<Integer, ChainedStateHandle<OperatorStateHandle>> partitionableStates;
-
-	/** handles to key-partitioned states, subtaskindex -> keyed state */
-	private final Map<Integer, KeyGroupsStateHandle> keyGroupsStateHandles;
-
 
 	/** parallelism of the operator when it was checkpointed */
 	private final int parallelism;
@@ -62,6 +52,7 @@ public class TaskState implements StateObject {
 	/** maximum parallelism of the operator when the job was first created */
 	private final int maxParallelism;
 
+	/** length of the operator chain */
 	private final int chainLength;
 
 	public TaskState(JobVertexID jobVertexID, int parallelism, int maxParallelism, int chainLength) {
@@ -73,8 +64,6 @@ public class TaskState implements StateObject {
 		this.jobVertexID = jobVertexID;
 
 		this.subtaskStates = new HashMap<>(parallelism);
-		this.partitionableStates = new HashMap<>(parallelism);
-		this.keyGroupsStateHandles = new HashMap<>(parallelism);
 
 		this.parallelism = parallelism;
 		this.maxParallelism = maxParallelism;
@@ -96,32 +85,6 @@ public class TaskState implements StateObject {
 		}
 	}
 
-	public void putPartitionableState(
-			int subtaskIndex,
-			ChainedStateHandle<OperatorStateHandle> partitionableState) {
-
-		Preconditions.checkNotNull(partitionableState);
-
-		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
-			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
-					" exceeds the maximum number of sub tasks " + subtaskStates.size());
-		} else {
-			partitionableStates.put(subtaskIndex, partitionableState);
-		}
-	}
-
-	public void putKeyedState(int subtaskIndex, KeyGroupsStateHandle keyGroupsStateHandle) {
-		Preconditions.checkNotNull(keyGroupsStateHandle);
-
-		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
-			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
-					" exceeds the maximum number of sub tasks " + subtaskStates.size());
-		} else {
-			keyGroupsStateHandles.put(subtaskIndex, keyGroupsStateHandle);
-		}
-	}
-
-
 	public SubtaskState getState(int subtaskIndex) {
 		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
 			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
@@ -131,24 +94,6 @@ public class TaskState implements StateObject {
 		}
 	}
 
-	public ChainedStateHandle<OperatorStateHandle> getPartitionableState(int subtaskIndex) {
-		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
-			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
-					" exceeds the maximum number of sub tasks " + subtaskStates.size());
-		} else {
-			return partitionableStates.get(subtaskIndex);
-		}
-	}
-
-	public KeyGroupsStateHandle getKeyGroupState(int subtaskIndex) {
-		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
-			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
-					" exceeds the maximum number of sub tasks " + keyGroupsStateHandles.size());
-		} else {
-			return keyGroupsStateHandles.get(subtaskIndex);
-		}
-	}
-
 	public Collection<SubtaskState> getStates() {
 		return subtaskStates.values();
 	}
@@ -169,13 +114,9 @@ public class TaskState implements StateObject {
 		return chainLength;
 	}
 
-	public Collection<KeyGroupsStateHandle> getKeyGroupStates() {
-		return keyGroupsStateHandles.values();
-	}
-
 	public boolean hasNonPartitionedState() {
 		for(SubtaskState sts : subtaskStates.values()) {
-			if (sts != null && !sts.getChainedStateHandle().isEmpty()) {
+			if (sts != null && !sts.getLegacyOperatorState().isEmpty()) {
 				return true;
 			}
 		}
@@ -184,8 +125,7 @@ public class TaskState implements StateObject {
 
 	@Override
 	public void discardState() throws Exception {
-		StateUtil.bestEffortDiscardAllStateObjects(
-				Iterables.concat(subtaskStates.values(), partitionableStates.values(), keyGroupsStateHandles.values()));
+		StateUtil.bestEffortDiscardAllStateObjects(subtaskStates.values());
 	}
 
 
@@ -198,16 +138,6 @@ public class TaskState implements StateObject {
 			if (subtaskState != null) {
 				result += subtaskState.getStateSize();
 			}
-
-			ChainedStateHandle<OperatorStateHandle> partitionableState = partitionableStates.get(i);
-			if (partitionableState != null) {
-				result += partitionableState.getStateSize();
-			}
-
-			KeyGroupsStateHandle keyGroupsState = keyGroupsStateHandles.get(i);
-			if (keyGroupsState != null) {
-				result += keyGroupsState.getStateSize();
-			}
 		}
 
 		return result;
@@ -220,9 +150,7 @@ public class TaskState implements StateObject {
 
 			return jobVertexID.equals(other.jobVertexID)
 					&& parallelism == other.parallelism
-					&& subtaskStates.equals(other.subtaskStates)
-					&& partitionableStates.equals(other.partitionableStates)
-					&& keyGroupsStateHandles.equals(other.keyGroupsStateHandles);
+					&& subtaskStates.equals(other.subtaskStates);
 		} else {
 			return false;
 		}
@@ -230,18 +158,10 @@ public class TaskState implements StateObject {
 
 	@Override
 	public int hashCode() {
-		return parallelism + 31 * Objects.hash(jobVertexID, subtaskStates, partitionableStates, keyGroupsStateHandles);
+		return parallelism + 31 * Objects.hash(jobVertexID, subtaskStates);
 	}
 
 	public Map<Integer, SubtaskState> getSubtaskStates() {
 		return Collections.unmodifiableMap(subtaskStates);
 	}
-
-	public Map<Integer, KeyGroupsStateHandle> getKeyGroupsStateHandles() {
-		return Collections.unmodifiableMap(keyGroupsStateHandles);
-	}
-
-	public Map<Integer, ChainedStateHandle<OperatorStateHandle>> getPartitionableStates() {
-		return partitionableStates;
-	}
 }
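
With the parallel maps gone from TaskState, every per-subtask lookup goes through
SubtaskState. A short before/after sketch of the access pattern (getters as introduced
in the diffs above):

    SubtaskState s = taskState.getState(subtaskIndex);
    // formerly s.getChainedStateHandle():
    ChainedStateHandle<StreamStateHandle> legacy = s.getLegacyOperatorState();
    // formerly taskState.getPartitionableState(subtaskIndex):
    ChainedStateHandle<OperatorStateHandle> managedOp = s.getManagedOperatorState();
    ChainedStateHandle<OperatorStateHandle> rawOp = s.getRawOperatorState();
    // formerly taskState.getKeyGroupState(subtaskIndex):
    KeyGroupsStateHandle managedKeyed = s.getManagedKeyedState();
    KeyGroupsStateHandle rawKeyed = s.getRawKeyedState();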