You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by al...@apache.org on 2016/10/20 14:15:22 UTC

[4/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
index d2d7fca..5e9bacc 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
@@ -20,6 +20,8 @@ package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.flink.api.java.tuple.Tuple1;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
 import org.apache.flink.streaming.connectors.kafka.testutils.MockRuntimeContext;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchema;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchemaWrapper;
@@ -112,7 +114,7 @@ public class AtLeastOnceProducerTest {
 		Thread threadB = new Thread(confirmer);
 		threadB.start();
 		// this should block:
-		producer.prepareSnapshot(0, 0);
+		producer.snapshotState(new StateSnapshotContextSynchronousImpl(0, 0));
 		synchronized (threadA) {
 			threadA.notifyAll(); // just in case, to let the test fail faster
 		}
@@ -148,9 +150,9 @@ public class AtLeastOnceProducerTest {
 		}
 
 		@Override
-		public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
+		public void snapshotState(FunctionSnapshotContext ctx) throws Exception {
 			// call the actual snapshot state
-			super.prepareSnapshot(checkpointId, timestamp);
+			super.snapshotState(ctx);
 			// notify test that snapshotting has been done
 			snapshottingFinished.set(true);
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
index 97220c2..9b7eabf 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
@@ -19,19 +19,26 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.commons.collections.map.LinkedMap;
+import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
 import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
 import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
 import org.apache.flink.util.SerializedValue;
+import org.junit.Assert;
 import org.junit.Test;
 import org.mockito.Matchers;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 import java.io.Serializable;
 import java.lang.reflect.Field;
@@ -47,6 +54,8 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -100,7 +109,7 @@ public class FlinkKafkaConsumerBaseTest {
 		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
 		when(operatorStateStore.getOperatorState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
 
-		consumer.prepareSnapshot(17L, 17L);
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(1, 1));
 
 		assertFalse(listState.get().iterator().hasNext());
 		consumer.notifyCheckpointComplete(66L);
@@ -113,24 +122,30 @@ public class FlinkKafkaConsumerBaseTest {
 	public void checkRestoredCheckpointWhenFetcherNotReady() throws Exception {
 		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
 
-		TestingListState<Serializable> expectedState = new TestingListState<>();
-		expectedState.add(Tuple2.of(new KafkaTopicPartition("abc", 13), 16768L));
-		expectedState.add(Tuple2.of(new KafkaTopicPartition("def", 7), 987654321L));
-
 		TestingListState<Serializable> listState = new TestingListState<>();
+		listState.add(Tuple2.of(new KafkaTopicPartition("abc", 13), 16768L));
+		listState.add(Tuple2.of(new KafkaTopicPartition("def", 7), 987654321L));
 
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new LinkedMap(), true);
 
-		when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(expectedState);
-		consumer.initializeState(operatorStateStore);
-
 		when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(listState);
 
-		consumer.prepareSnapshot(17L, 17L);
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+
+		when(initializationContext.getManagedOperatorStateStore()).thenReturn(operatorStateStore);
+		when(initializationContext.isRestored()).thenReturn(true);
+
+		consumer.initializeState(initializationContext);
+
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(17, 17));
+
+		// ensure that the list was cleared and refilled. while this is an implementation detail, we use it here
+		// to figure out that snapshotState() actually did something.
+		Assert.assertTrue(listState.isClearCalled());
 
 		Set<Serializable> expected = new HashSet<>();
 
-		for (Serializable serializable : expectedState.get()) {
+		for (Serializable serializable : listState.get()) {
 			expected.add(serializable);
 		}
 
@@ -155,8 +170,14 @@ public class FlinkKafkaConsumerBaseTest {
 		TestingListState<Serializable> listState = new TestingListState<>();
 		when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(listState);
 
-		consumer.initializeState(operatorStateStore);
-		consumer.prepareSnapshot(17L, 17L);
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+
+		when(initializationContext.getManagedOperatorStateStore()).thenReturn(operatorStateStore);
+		when(initializationContext.isRestored()).thenReturn(false);
+
+		consumer.initializeState(initializationContext);
+
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(17, 17));
 
 		assertFalse(listState.get().iterator().hasNext());
 	}
@@ -165,6 +186,28 @@ public class FlinkKafkaConsumerBaseTest {
 	 * Tests that on snapshots, states and offsets to commit to Kafka are correct
 	 */
 	@Test
+	public void checkUseFetcherWhenNoCheckpoint() throws Exception {
+		// without restored state, run() must not call restoreOffsets() on the fetcher (presumably the dummy's failing mock)
+		FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new LinkedMap(), true);
+		List<KafkaTopicPartition> partitionList = new ArrayList<>(1);
+		partitionList.add(new KafkaTopicPartition("test", 0));
+		consumer.setSubscribedPartitions(partitionList);
+
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		TestingListState<Serializable> listState = new TestingListState<>();
+		when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(listState);
+
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+
+		when(initializationContext.getManagedOperatorStateStore()).thenReturn(operatorStateStore);
+
+		// make the context signal that there is no restored state, then validate that no offsets get restored
+		when(initializationContext.isRestored()).thenReturn(false);
+		consumer.initializeState(initializationContext);
+		consumer.run(mock(SourceFunction.SourceContext.class));
+	}
+
+	@Test
 	@SuppressWarnings("unchecked")
 	public void testSnapshotState() throws Exception {
 
@@ -196,22 +239,23 @@ public class FlinkKafkaConsumerBaseTest {
 
 		OperatorStateStore backend = mock(OperatorStateStore.class);
 
-		TestingListState<Serializable> init = new TestingListState<>();
-		TestingListState<Serializable> listState1 = new TestingListState<>();
-		TestingListState<Serializable> listState2 = new TestingListState<>();
-		TestingListState<Serializable> listState3 = new TestingListState<>();
+		TestingListState<Serializable> listState = new TestingListState<>();
+
+		when(backend.getSerializableListState(Matchers.any(String.class))).thenReturn(listState);
+
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
 
-		when(backend.getSerializableListState(Matchers.any(String.class))).
-				thenReturn(init, listState1, listState2, listState3);
+		when(initializationContext.getManagedOperatorStateStore()).thenReturn(backend);
+		when(initializationContext.isRestored()).thenReturn(false, true, true, true);
 
-		consumer.initializeState(backend);
+		consumer.initializeState(initializationContext);
 
 		// checkpoint 1
-		consumer.prepareSnapshot(138L, 138L);
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(138, 138));
 
 		HashMap<KafkaTopicPartition, Long> snapshot1 = new HashMap<>();
 
-		for (Serializable serializable : listState1.get()) {
+		for (Serializable serializable : listState.get()) {
 			Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 = (Tuple2<KafkaTopicPartition, Long>) serializable;
 			snapshot1.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
 		}
@@ -221,11 +265,11 @@ public class FlinkKafkaConsumerBaseTest {
 		assertEquals(state1, pendingOffsetsToCommit.get(138L));
 
 		// checkpoint 2
-		consumer.prepareSnapshot(140L, 140L);
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(140, 140));
 
 		HashMap<KafkaTopicPartition, Long> snapshot2 = new HashMap<>();
 
-		for (Serializable serializable : listState2.get()) {
+		for (Serializable serializable : listState.get()) {
 			Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 = (Tuple2<KafkaTopicPartition, Long>) serializable;
 			snapshot2.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
 		}
@@ -240,11 +284,11 @@ public class FlinkKafkaConsumerBaseTest {
 		assertTrue(pendingOffsetsToCommit.containsKey(140L));
 
 		// checkpoint 3
-		consumer.prepareSnapshot(141L, 141L);
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(141, 141));
 
 		HashMap<KafkaTopicPartition, Long> snapshot3 = new HashMap<>();
 
-		for (Serializable serializable : listState3.get()) {
+		for (Serializable serializable : listState.get()) {
 			Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 = (Tuple2<KafkaTopicPartition, Long>) serializable;
 			snapshot3.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
 		}
@@ -262,12 +306,12 @@ public class FlinkKafkaConsumerBaseTest {
 		assertEquals(0, pendingOffsetsToCommit.size());
 
 		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
-		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
+		listState = new TestingListState<>();
 		when(operatorStateStore.getOperatorState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
 
 		// create 500 snapshots
 		for (int i = 100; i < 600; i++) {
-			consumer.prepareSnapshot(i, i);
+			consumer.snapshotState(new StateSnapshotContextSynchronousImpl(i, i));
 			listState.clear();
 		}
 		assertEquals(FlinkKafkaConsumerBase.MAX_NUM_PENDING_CHECKPOINTS, pendingOffsetsToCommit.size());
@@ -308,7 +352,7 @@ public class FlinkKafkaConsumerBaseTest {
 
 	// ------------------------------------------------------------------------
 
-	private static final class DummyFlinkKafkaConsumer<T> extends FlinkKafkaConsumerBase<T> {
+	private static class DummyFlinkKafkaConsumer<T> extends FlinkKafkaConsumerBase<T> {
 		private static final long serialVersionUID = 1L;
 
 		@SuppressWarnings("unchecked")
@@ -318,22 +362,37 @@ public class FlinkKafkaConsumerBaseTest {
 
 		@Override
 		protected AbstractFetcher<T, ?> createFetcher(SourceContext<T> sourceContext, List<KafkaTopicPartition> thisSubtaskPartitions, SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic, SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated, StreamingRuntimeContext runtimeContext) throws Exception {
-			return null;
+			AbstractFetcher<T, ?> fetcher = mock(AbstractFetcher.class); // mock fetcher; any restoreOffsets() call fails the test
+			doAnswer(new Answer<Void>() {
+				@Override
+				public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+					Assert.fail("Trying to restore offsets even though there was no restore state.");
+					return null;
+				}
+			}).when(fetcher).restoreOffsets(any(HashMap.class));
+			return fetcher;
 		}
 
 		@Override
 		protected List<KafkaTopicPartition> getKafkaPartitions(List<String> topics) {
 			return Collections.emptyList();
 		}
+
+		@Override
+		public RuntimeContext getRuntimeContext() {
+			return mock(StreamingRuntimeContext.class); // a new, unstubbed mock on every call
+		}
 	}
 
 	private static final class TestingListState<T> implements ListState<T> {
 
 		private final List<T> list = new ArrayList<>();
+		private boolean clearCalled = false;
 
 		@Override
 		public void clear() {
 			list.clear();
+			clearCalled = true;
 		}
 
 		@Override
@@ -345,5 +404,13 @@ public class FlinkKafkaConsumerBaseTest {
 		public void add(T value) throws Exception {
 			list.add(value);
 		}
+
+		public List<T> getList() {
+			return list; // the live backing list, not a copy
+		}
+
+		public boolean isClearCalled() {
+			return clearCalled; // true once clear() has been invoked at least once
+		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
index 4a0fd60..7af5cea 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
@@ -38,7 +38,7 @@ import java.io.Serializable;
  */
 @Deprecated
 @PublicEvolving
-public interface Checkpointed<T extends Serializable> {
+public interface Checkpointed<T extends Serializable> extends CheckpointedRestoring<T> {
 
 	/**
 	 * Gets the current state of the function of operator. The state must reflect the result of all
@@ -56,14 +56,4 @@ public interface Checkpointed<T extends Serializable> {
 	 *                   and to try again with the next checkpoint attempt.
 	 */
 	T snapshotState(long checkpointId, long checkpointTimestamp) throws Exception;
-
-	/**
-	 * Restores the state of the function or operator to that of a previous checkpoint.
-	 * This method is invoked when a function is executed as part of a recovery run.
-	 *
-	 * Note that restoreState() is called before open().
-	 *
-	 * @param state The state to be restored. 
-	 */
-	void restoreState(T state) throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
index 777cb91..37d8244 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
@@ -20,46 +20,48 @@ package org.apache.flink.streaming.api.checkpoint;
 
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 
 /**
  *
  * Similar to @{@link Checkpointed}, this interface must be implemented by functions that have potentially
  * repartitionable state that needs to be checkpointed. Methods from this interface are called upon checkpointing and
- * restoring of state.
+ * initialization of state.
  *
- * On #initializeState the implementing class receives the {@link OperatorStateStore}
- * to store it's state. At least before each snapshot, all state persistent state must be stored in the state store.
+ * On {@link #initializeState(FunctionInitializationContext)} the implementing class receives a
+ * {@link FunctionInitializationContext} which provides access to the {@link OperatorStateStore} (always) and
+ * {@link org.apache.flink.api.common.state.KeyedStateStore} (only for keyed operators). These stores are used to
+ * register managed operator / keyed user states. Furthermore, the context tells whether or not the operator is
+ * being restored.
  *
- * When the backend is received for initialization, the user registers states with the backend via
- * {@link org.apache.flink.api.common.state.StateDescriptor}. Then, all previously stored state is found in the
- * received {@link org.apache.flink.api.common.state.State} (currently only
- * {@link org.apache.flink.api.common.state.ListState} is supported.
  *
- * In #prepareSnapshot, the implementing class must ensure that all operator state is passed to the operator backend,
- * i.e. that the state was stored in the relevant {@link org.apache.flink.api.common.state.State} instances that
- * are requested on restore. Notice that users might want to clear and reinsert the complete state first if incremental
- * updates of the states are not possible.
+ * In {@link #snapshotState(FunctionSnapshotContext)} the implementing class must ensure that all operator / keyed state
+ * is passed to user states that have been registered during initialization, so that it is visible to the system
+ * backends for checkpointing.
+ *
  */
 @PublicEvolving
 public interface CheckpointedFunction {
 
 	/**
+	 * This method is called when a snapshot for a checkpoint is requested. This acts as a hook to the function to
+	 * ensure that all state is exposed by means previously offered through {@link FunctionInitializationContext} when
+	 * the Function was initialized, or offered now by {@link FunctionSnapshotContext} itself.
 	 *
-	 * This method is called when state should be stored for a checkpoint. The state can be registered and written to
-	 * the provided backend.
-	 *
-	 * @param checkpointId Id of the checkpoint to perform
-	 * @param timestamp Timestamp of the checkpoint
+	 * @param context the context for drawing a snapshot of the operator
 	 * @throws Exception
 	 */
-	void prepareSnapshot(long checkpointId, long timestamp) throws Exception;
+	void snapshotState(FunctionSnapshotContext context) throws Exception;
 
 	/**
-	 * This method is called when an operator is opened, so that the function can set the state backend to which it
-	 * hands it's state on snapshot.
+	 * This method is called when an operator is initialized, so that the function can set up its state through
+	 * the provided context. Initialization typically includes registering user states through the state stores
+	 * that the context offers.
 	 *
-	 * @param stateStore the state store to which this function stores it's state
+	 * @param context the context for initializing the operator
 	 * @throws Exception
 	 */
-	void initializeState(OperatorStateStore stateStore) throws Exception;
+	void initializeState(FunctionInitializationContext context) throws Exception;
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java
new file mode 100644
index 0000000..c0dd361
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.checkpoint;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+import java.io.Serializable;
+
+/**
+ * Deprecated interface for restoring state that was written by the legacy {@link Checkpointed} mechanism.
+ * @param <T> type of the restored state.
+ */
+@Deprecated
+@PublicEvolving
+public interface CheckpointedRestoring<T extends Serializable> {
+	/**
+	 * Restores the state of the function or operator to that of a previous checkpoint.
+	 * This method is invoked when a function is executed as part of a recovery run.
+	 *
+	 * <p>Note that {@code restoreState()} is called before {@code open()}.
+	 *
+	 * @param state The state to be restored.
+	 */
+	void restoreState(T state) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index 167dfb0..9184e93 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -22,38 +22,47 @@ import org.apache.commons.io.IOUtils;
 import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.state.KeyedStateStore;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.metrics.Counter;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.DefaultKeyedStateStore;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateInitializationContextImpl;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
+import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.Collection;
+import java.util.ConcurrentModificationException;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.ConcurrentModificationException;
-import java.util.Collection;
-import java.util.concurrent.RunnableFuture;
 
 /**
  * Base class for all stream operators. Operators that contain a user function should extend the class 
@@ -97,6 +106,7 @@ public abstract class AbstractStreamOperator<OUT>
 	private transient StreamingRuntimeContext runtimeContext;
 
 
+
 	// ---------------- key/value state ------------------
 
 	/** key selector used to get the key for the state. Non-null only is the operator uses key/value state */
@@ -106,11 +116,12 @@ public abstract class AbstractStreamOperator<OUT>
 	/** Backend for keyed state. This might be empty if we're not on a keyed stream. */
 	private transient AbstractKeyedStateBackend<?> keyedStateBackend;
 
-	/** Operator state backend */
+	/** Keyed state store view on the keyed backend */
+	private transient DefaultKeyedStateStore keyedStateStore;
+	
+	/** Operator state backend / store */
 	private transient OperatorStateBackend operatorStateBackend;
 
-	private transient Collection<OperatorStateHandle> lazyRestoreStateHandles;
-
 
 	// --------------- Metrics ---------------------------
 
@@ -151,8 +162,61 @@ public abstract class AbstractStreamOperator<OUT>
 	}
 
 	@Override
-	public void restoreState(Collection<OperatorStateHandle> stateHandles) {
-		this.lazyRestoreStateHandles = stateHandles;
+	public final void initializeState(OperatorStateHandles stateHandles) throws Exception {
+		// set up keyed/operator state (using stateHandles when restoring), then invoke initializeState(context)
+		Collection<KeyGroupsStateHandle> keyedStateHandlesRaw = null;
+		Collection<OperatorStateHandle> operatorStateHandlesRaw = null;
+		Collection<OperatorStateHandle> operatorStateHandlesBackend = null;
+
+		boolean restoring = null != stateHandles;
+
+		initKeyedState(); //TODO we should move the actual initialization of this from StreamTask to this class
+
+		if (restoring) {
+
+			// TODO check that there is EITHER old OR new state in handles!
+			restoreStreamCheckpointed(stateHandles);
+
+			// managed operator state goes to the operator state backend, raw operator state to the context
+			operatorStateHandlesBackend = stateHandles.getManagedOperatorState();
+			operatorStateHandlesRaw = stateHandles.getRawOperatorState();
+
+			if (null != getKeyedStateBackend()) {
+				//only use the keyed state if it is meant for us (aka head operator)
+				keyedStateHandlesRaw = stateHandles.getRawKeyedState();
+			}
+		}
+
+		initOperatorState(operatorStateHandlesBackend);
+		// the raw keyed/operator handles below stay null when not restoring
+		StateInitializationContext initializationContext = new StateInitializationContextImpl(
+				restoring, // information whether we restore or start for the first time
+				operatorStateBackend, // access to operator state backend
+				keyedStateStore, // access to keyed state backend
+				keyedStateHandlesRaw, // access to keyed state stream
+				operatorStateHandlesRaw, // access to operator state stream
+				getContainingTask().getCancelables()); // access to register streams for canceling
+
+		initializeState(initializationContext);
+	}
+
+	@Deprecated
+	private void restoreStreamCheckpointed(OperatorStateHandles stateHandles) throws Exception {
+		StreamStateHandle state = stateHandles.getLegacyOperatorState();
+		if (this instanceof StreamCheckpointedOperator && null != state) {
+			// legacy state: a single stream handed to StreamCheckpointedOperator#restoreState
+			LOG.debug("Restore state of task {} in chain ({}).",
+					stateHandles.getOperatorChainIndex(), getContainingTask().getName());
+			// try-with-resources closes the stream even if unregistering fails and keeps restore errors unmasked
+			try (FSDataInputStream is = state.openInputStream()) {
+				getContainingTask().getCancelables().registerClosable(is);
+				try {
+					((StreamCheckpointedOperator) this).restoreState(is);
+				} finally {
+					getContainingTask().getCancelables().unregisterClosable(is);
+				}
+			}
+		}
 	}
 
 	/**
@@ -165,8 +229,7 @@ public abstract class AbstractStreamOperator<OUT>
 	 */
 	@Override
 	public void open() throws Exception {
-		initOperatorState();
-		initKeyedState();
+
 	}
 
 	private void initKeyedState() {
@@ -174,7 +237,6 @@ public abstract class AbstractStreamOperator<OUT>
 			TypeSerializer<Object> keySerializer = config.getStateKeySerializer(getUserCodeClassloader());
 			// create a keyed state backend if there is keyed state, as indicated by the presence of a key serializer
 			if (null != keySerializer) {
-
 				KeyGroupRange subTaskKeyGroupRange = KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
 						container.getEnvironment().getTaskInfo().getNumberOfKeyGroups(),
 						container.getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(),
@@ -184,7 +246,8 @@ public abstract class AbstractStreamOperator<OUT>
 						keySerializer,
 						container.getConfiguration().getNumberOfKeyGroups(getUserCodeClassloader()),
 						subTaskKeyGroupRange);
-
+				
+				this.keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, getExecutionConfig());
 			}
 
 		} catch (Exception e) {
@@ -192,10 +255,10 @@ public abstract class AbstractStreamOperator<OUT>
 		}
 	}
 
-	private void initOperatorState() {
+	private void initOperatorState(Collection<OperatorStateHandle> operatorStateHandles) {
 		try {
 			// create an operator state backend
-			this.operatorStateBackend = container.createOperatorStateBackend(this, lazyRestoreStateHandles);
+			this.operatorStateBackend = container.createOperatorStateBackend(this, operatorStateHandles);
 		} catch (Exception e) {
 			throw new IllegalStateException("Could not initialize operator state backend.", e);
 		}
@@ -238,11 +301,51 @@ public abstract class AbstractStreamOperator<OUT>
 	}
 
 	@Override
-	public RunnableFuture<OperatorStateHandle> snapshotState(
+	public final OperatorSnapshotResult snapshotState(
 			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
 
-		return operatorStateBackend != null ?
-				operatorStateBackend.snapshot(checkpointId, timestamp, streamFactory) : null;
+		KeyGroupRange keyGroupRange = null != keyedStateBackend ?
+				keyedStateBackend.getKeyGroupRange() : KeyGroupRange.EMPTY_KEY_GROUP_RANGE;
+
+		StateSnapshotContextSynchronousImpl snapshotContext = new StateSnapshotContextSynchronousImpl(
+				checkpointId, timestamp, streamFactory, keyGroupRange, getContainingTask().getCancelables());
+
+		snapshotState(snapshotContext);
+
+		OperatorSnapshotResult snapshotInProgress = new OperatorSnapshotResult();
+
+		snapshotInProgress.setKeyedStateRawFuture(snapshotContext.getKeyedStateStreamFuture());
+		snapshotInProgress.setOperatorStateRawFuture(snapshotContext.getOperatorStateStreamFuture());
+
+		if (null != operatorStateBackend) {
+			snapshotInProgress.setOperatorStateManagedFuture(
+					operatorStateBackend.snapshot(checkpointId, timestamp, streamFactory));
+		}
+
+		if (null != keyedStateBackend) {
+			snapshotInProgress.setKeyedStateManagedFuture(
+					keyedStateBackend.snapshot(checkpointId, timestamp, streamFactory));
+		}
+
+		return snapshotInProgress;
+	}
+
+	/**
+	 * Hook for stateful stream operators that want to take part in a snapshot. The default
+	 * implementation does nothing; operators with state override it.
+	 *
+	 * <p>It is called by {@link #snapshotState(long, long, CheckpointStreamFactory)} before the managed
+	 * operator and keyed state backends are snapshotted.
+	 *
+	 * @param context context that provides information and means required for taking a snapshot
+	 * @throws Exception if the operator fails to snapshot its state
+	 */
+	public void snapshotState(StateSnapshotContext context) throws Exception {
+		// no state to snapshot by default
+	}
+
+	/**
+	 * Hook for stream operators whose state can be restored. The default implementation does nothing;
+	 * operators that register or restore state override it.
+	 *
+	 * @param context context that allows to register different states.
+	 * @throws Exception if the operator fails to initialize its state
+	 */
+	public void initializeState(StateInitializationContext context) throws Exception {
+		// nothing to initialize by default
 	}
 
 	@Override
@@ -283,22 +386,12 @@ public abstract class AbstractStreamOperator<OUT>
 		return runtimeContext;
 	}
 
+	/**
+	 * Returns the keyed state backend of this operator. The cast to the requested key type is
+	 * unchecked; callers must ask for the key type the backend was created with.
+	 *
+	 * <p>The backend is no longer created lazily on first access: this returns {@code null} when it
+	 * has not been initialized (for example, when the operator has no key serializer).
+	 */
-	@SuppressWarnings("rawtypes, unchecked")
+	@SuppressWarnings("unchecked")
 	public <K> KeyedStateBackend<K> getKeyedStateBackend() {
-
-		if (null == keyedStateBackend) {
-			initKeyedState();
-		}
-
 		return (KeyedStateBackend<K>) keyedStateBackend;
 	}
 
+	/**
+	 * Returns the operator state backend of this operator. It is no longer created lazily on first
+	 * access, so this returns {@code null} until it has been initialized.
+	 */
 	public OperatorStateBackend getOperatorStateBackend() {
-
-		if (null == operatorStateBackend) {
-			initOperatorState();
-		}
-
 		return operatorStateBackend;
 	}
 
@@ -327,12 +420,12 @@ public abstract class AbstractStreamOperator<OUT>
 	 * @throws Exception Thrown, if the state backend cannot create the key/value state.
 	 */
 	@SuppressWarnings("unchecked")
-	protected <S extends State, N> S getPartitionedState(N namespace, TypeSerializer<N> namespaceSerializer, StateDescriptor<S, ?> stateDescriptor) throws Exception {
-		if (keyedStateBackend != null) {
-			return keyedStateBackend.getPartitionedState(
-					namespace,
-					namespaceSerializer,
-					stateDescriptor);
+	protected <S extends State, N> S getPartitionedState(
+			N namespace, TypeSerializer<N> namespaceSerializer, 
+			StateDescriptor<S, ?> stateDescriptor) throws Exception {
+		
+		if (keyedStateStore != null) {
+			return keyedStateBackend.getPartitionedState(namespace, namespaceSerializer, stateDescriptor);
 		} else {
 			throw new RuntimeException("Cannot create partitioned state. The keyed state " +
 				"backend has not been set. This indicates that the operator is not " +
@@ -343,18 +436,18 @@ public abstract class AbstractStreamOperator<OUT>
+	/**
+	 * Extracts the key of the given record with {@code stateKeySelector1} and makes it the current key.
+	 * Does nothing if that selector is {@code null}.
+	 */
 	@Override
 	@SuppressWarnings({"unchecked", "rawtypes"})
 	public void setKeyContextElement1(StreamRecord record) throws Exception {
-		setRawKeyContextElement(record, stateKeySelector1);
+		setKeyContextElement(record, stateKeySelector1);
 	}
 
+	/**
+	 * Extracts the key of the given record with {@code stateKeySelector2} and makes it the current key.
+	 * Does nothing if that selector is {@code null}.
+	 */
 	@Override
 	@SuppressWarnings({"unchecked", "rawtypes"})
 	public void setKeyContextElement2(StreamRecord record) throws Exception {
-		setRawKeyContextElement(record, stateKeySelector2);
+		setKeyContextElement(record, stateKeySelector2);
 	}
 
+	/**
+	 * Extracts the key from the record's value with the given selector and passes it to
+	 * {@code setKeyContext}. A {@code null} selector is a no-op.
+	 */
-	private void setRawKeyContextElement(StreamRecord record, KeySelector<?, ?> selector) throws Exception {
+	private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector) throws Exception {
 		if (selector != null) {
-			Object key = ((KeySelector) selector).getKey(record.getValue());
+			Object key = selector.getKey(record.getValue());
 			setKeyContext(key);
 		}
 	}
@@ -374,6 +467,10 @@ public abstract class AbstractStreamOperator<OUT>
 		}
 	}
 
+	/**
+	 * Returns the keyed state store of this operator. It is {@code null} until a keyed state backend
+	 * has been created, and no such backend is created when the operator has no key serializer.
+	 */
+	public KeyedStateStore getKeyedStateStore() {
+		return keyedStateStore;
+	}
+
 	// ------------------------------------------------------------------------
 	//  Context and chaining properties
 	// ------------------------------------------------------------------------
@@ -567,4 +664,5 @@ public abstract class AbstractStreamOperator<OUT>
 			output.close();
 		}
 	}
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
index 72f30b8..5e1a252 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
@@ -28,11 +28,12 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.runtime.state.CheckpointListener;
-import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
 import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -42,7 +43,6 @@ import org.apache.flink.util.InstantiationUtil;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.concurrent.RunnableFuture;
 
 import static java.util.Objects.requireNonNull;
 
@@ -73,6 +73,7 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	
 	public AbstractUdfStreamOperator(F userFunction) {
 		this.userFunction = requireNonNull(userFunction);
+		// fail fast if the function implements a disallowed combination of checkpointing interfaces
+		checkUdfCheckpointingPreconditions();
 	}
 
 	/**
@@ -93,22 +94,44 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 		super.setup(containingTask, config, output);
 		
 		FunctionUtils.setFunctionRuntimeContext(userFunction, getRuntimeContext());
+
 	}
 
 	@Override
-	public void open() throws Exception {
-		super.open();
-		
-		FunctionUtils.openFunction(userFunction, new Configuration());
+	public void snapshotState(StateSnapshotContext context) throws Exception {
+		super.snapshotState(context);
 
 		if (userFunction instanceof CheckpointedFunction) {
-			((CheckpointedFunction) userFunction).initializeState(getOperatorStateBackend());
+			((CheckpointedFunction) userFunction).snapshotState(context);
 		} else if (userFunction instanceof ListCheckpointed) {
 			@SuppressWarnings("unchecked")
-			ListCheckpointed<Serializable> listCheckpointedFun = (ListCheckpointed<Serializable>) userFunction;
+			List<Serializable> partitionableState = ((ListCheckpointed<Serializable>) userFunction).
+							snapshotState(context.getCheckpointId(), context.getCheckpointTimestamp());
 
 			ListState<Serializable> listState = getOperatorStateBackend().
-					getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME);
+					getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
+
+			listState.clear();
+
+			for (Serializable statePartition : partitionableState) {
+				listState.add(statePartition);
+			}
+		}
+
+	}
+
+	@Override
+	public void initializeState(StateInitializationContext context) throws Exception {
+		super.initializeState(context);
+
+		if (userFunction instanceof CheckpointedFunction) {
+			((CheckpointedFunction) userFunction).initializeState(context);
+		} else if (context.isRestored() && userFunction instanceof ListCheckpointed) {
+			@SuppressWarnings("unchecked")
+			ListCheckpointed<Serializable> listCheckpointedFun = (ListCheckpointed<Serializable>) userFunction;
+
+			ListState<Serializable> listState = context.getManagedOperatorStateStore().
+					getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
 
 			List<Serializable> list = new ArrayList<>();
 
@@ -122,6 +145,13 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 				throw new Exception("Failed to restore state to function: " + e.getMessage(), e);
 			}
 		}
+
+	}
+
+	/**
+	 * Opens the operator, then opens the user function through {@code FunctionUtils.openFunction}
+	 * with an empty {@link Configuration}.
+	 */
+	@Override
+	public void open() throws Exception {
+		super.open();
+		FunctionUtils.openFunction(userFunction, new Configuration());
 	}
 
 	@Override
@@ -147,6 +177,7 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	@Override
 	public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
 
+
 		if (userFunction instanceof Checkpointed) {
 			@SuppressWarnings("unchecked")
 			Checkpointed<Serializable> chkFunction = (Checkpointed<Serializable>) userFunction;
@@ -169,9 +200,9 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	@Override
 	public void restoreState(FSDataInputStream in) throws Exception {
 
-		if (userFunction instanceof Checkpointed) {
+		if (userFunction instanceof CheckpointedRestoring) {
 			@SuppressWarnings("unchecked")
-			Checkpointed<Serializable> chkFunction = (Checkpointed<Serializable>) userFunction;
+			CheckpointedRestoring<Serializable> chkFunction = (CheckpointedRestoring<Serializable>) userFunction;
 
 			int hasUdfState = in.read();
 
@@ -189,32 +220,6 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	}
 
 	@Override
-	public RunnableFuture<OperatorStateHandle> snapshotState(
-			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
-
-		if (userFunction instanceof CheckpointedFunction) {
-			((CheckpointedFunction) userFunction).prepareSnapshot(checkpointId, timestamp);
-		}
-
-		if (userFunction instanceof ListCheckpointed) {
-			@SuppressWarnings("unchecked")
-			List<Serializable> partitionableState =
-					((ListCheckpointed<Serializable>) userFunction).snapshotState(checkpointId, timestamp);
-
-			ListState<Serializable> listState = getOperatorStateBackend().
-					getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME);
-
-			listState.clear();
-
-			for (Serializable statePartition : partitionableState) {
-				listState.add(statePartition);
-			}
-		}
-
-		return super.snapshotState(checkpointId, timestamp, streamFactory);
-	}
-
-	@Override
 	public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {
 		super.notifyOfCompletedCheckpoint(checkpointId);
 
@@ -251,4 +256,26 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	public Configuration getUserFunctionParameters() {
+		// a new, empty Configuration is returned on every call
 		return new Configuration();
 	}
+
+	/**
+	 * Rejects user functions that combine mutually exclusive checkpointing interfaces.
+	 * {@link CheckpointedFunction} and {@link ListCheckpointed} may not be implemented together, and
+	 * neither may be combined with the legacy {@link Checkpointed} interface.
+	 *
+	 * @throws IllegalStateException if a disallowed combination is detected
+	 */
+	private void checkUdfCheckpointingPreconditions() {
+		final boolean isCheckpointedFunction = userFunction instanceof CheckpointedFunction;
+		final boolean isListCheckpointed = userFunction instanceof ListCheckpointed;
+
+		if (isCheckpointedFunction && isListCheckpointed) {
+			throw new IllegalStateException("User functions are not allowed to implement " +
+					"CheckpointedFunction AND ListCheckpointed.");
+		}
+
+		final boolean usesNewCheckpointInterface = isCheckpointedFunction || isListCheckpointed;
+		if (usesNewCheckpointInterface && userFunction instanceof Checkpointed) {
+			throw new IllegalStateException("User functions are not allowed to implement Checkpointed AND " +
+					"CheckpointedFunction/ListCheckpointed.");
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java
new file mode 100644
index 0000000..52c89f8
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+
+import java.util.concurrent.RunnableFuture;
+
+/**
+ * Result of {@link AbstractStreamOperator#snapshotState}: holds one pending snapshot future per kind of
+ * state (managed or raw, keyed or operator). Any of the futures may be {@code null}, e.g. when the
+ * corresponding state backend does not exist.
+ */
+public class OperatorSnapshotResult {
+
+	/** Pending snapshot of the managed keyed state, or {@code null}. */
+	private RunnableFuture<KeyGroupsStateHandle> keyedStateManagedFuture;
+
+	/** Pending snapshot of the raw keyed state, or {@code null}. */
+	private RunnableFuture<KeyGroupsStateHandle> keyedStateRawFuture;
+
+	/** Pending snapshot of the managed operator state, or {@code null}. */
+	private RunnableFuture<OperatorStateHandle> operatorStateManagedFuture;
+
+	/** Pending snapshot of the raw operator state, or {@code null}. */
+	private RunnableFuture<OperatorStateHandle> operatorStateRawFuture;
+
+	/** Creates a result in which all futures are unset ({@code null}). */
+	public OperatorSnapshotResult() {
+	}
+
+	/** Creates a result from the given futures; no argument is validated, so each may be {@code null}. */
+	public OperatorSnapshotResult(
+			RunnableFuture<KeyGroupsStateHandle> keyedStateManagedFuture,
+			RunnableFuture<KeyGroupsStateHandle> keyedStateRawFuture,
+			RunnableFuture<OperatorStateHandle> operatorStateManagedFuture,
+			RunnableFuture<OperatorStateHandle> operatorStateRawFuture) {
+		this.keyedStateManagedFuture = keyedStateManagedFuture;
+		this.keyedStateRawFuture = keyedStateRawFuture;
+		this.operatorStateManagedFuture = operatorStateManagedFuture;
+		this.operatorStateRawFuture = operatorStateRawFuture;
+	}
+
+	public RunnableFuture<KeyGroupsStateHandle> getKeyedStateManagedFuture() {
+		return this.keyedStateManagedFuture;
+	}
+
+	public void setKeyedStateManagedFuture(RunnableFuture<KeyGroupsStateHandle> future) {
+		this.keyedStateManagedFuture = future;
+	}
+
+	public RunnableFuture<KeyGroupsStateHandle> getKeyedStateRawFuture() {
+		return this.keyedStateRawFuture;
+	}
+
+	public void setKeyedStateRawFuture(RunnableFuture<KeyGroupsStateHandle> future) {
+		this.keyedStateRawFuture = future;
+	}
+
+	public RunnableFuture<OperatorStateHandle> getOperatorStateManagedFuture() {
+		return this.operatorStateManagedFuture;
+	}
+
+	public void setOperatorStateManagedFuture(RunnableFuture<OperatorStateHandle> future) {
+		this.operatorStateManagedFuture = future;
+	}
+
+	public RunnableFuture<OperatorStateHandle> getOperatorStateRawFuture() {
+		return this.operatorStateRawFuture;
+	}
+
+	public void setOperatorStateRawFuture(RunnableFuture<OperatorStateHandle> future) {
+		this.operatorStateRawFuture = future;
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
index fae5fd0..f6e5472 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
@@ -20,14 +20,12 @@ package org.apache.flink.streaming.api.operators;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 
 import java.io.Serializable;
-import java.util.Collection;
-import java.util.concurrent.RunnableFuture;
 
 /**
  * Basic interface for stream operators. Implementers would implement one of
@@ -105,7 +103,7 @@ public interface StreamOperator<OUT> extends Serializable {
 	 * the runnable might already be finished.
 	 * @throws Exception exception that happened during snapshotting.
 	 */
-	RunnableFuture<OperatorStateHandle> snapshotState(
+	OperatorSnapshotResult snapshotState(
 			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception;
 
 	/**
@@ -113,7 +111,7 @@ public interface StreamOperator<OUT> extends Serializable {
 	 *
 	 * @param stateHandles state handles to the operator state.
 	 */
-	void restoreState(Collection<OperatorStateHandle> stateHandles);
+	void initializeState(OperatorStateHandles stateHandles) throws Exception;
 
 	/**
 	 * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager.

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
index cc2e54b..cd0489f 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
@@ -37,8 +37,6 @@ import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
 import java.util.List;
 import java.util.Map;
 
-import static java.util.Objects.requireNonNull;
-
 /**
  * Implementation of the {@link org.apache.flink.api.common.functions.RuntimeContext},
  * for streaming operators.
@@ -108,36 +106,17 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext {
 
 	@Override
 	public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) {
-		requireNonNull(stateProperties, "The state properties must not be null");
-		try {
-			stateProperties.initializeSerializerUnlessSet(getExecutionConfig());
-			return operator.getPartitionedState(stateProperties);
-		} catch (Exception e) {
-			throw new RuntimeException("Error while getting state", e);
-		}
+		checkKeyedStateAccess(stateProperties);
+		return operator.getKeyedStateStore().getState(stateProperties);
 	}
 
 	@Override
 	public <T> ListState<T> getListState(ListStateDescriptor<T> stateProperties) {
-		requireNonNull(stateProperties, "The state properties must not be null");
-		try {
-			stateProperties.initializeSerializerUnlessSet(getExecutionConfig());
-			ListState<T> originalState = operator.getPartitionedState(stateProperties);
-			return new UserFacingListState<T>(originalState);
-		} catch (Exception e) {
-			throw new RuntimeException("Error while getting state", e);
-		}
+		checkKeyedStateAccess(stateProperties);
+		return operator.getKeyedStateStore().getListState(stateProperties);
 	}
 
 	@Override
 	public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties) {
-		requireNonNull(stateProperties, "The state properties must not be null");
-		try {
-			stateProperties.initializeSerializerUnlessSet(getExecutionConfig());
-			return operator.getPartitionedState(stateProperties);
-		} catch (Exception e) {
-			throw new RuntimeException("Error while getting state", e);
-		}
+		checkKeyedStateAccess(stateProperties);
+		return operator.getKeyedStateStore().getReducingState(stateProperties);
 	}
+
+	/**
+	 * Validates a keyed-state access before it is delegated to the operator's keyed state store.
+	 * Without this check, a missing store surfaces as a bare NullPointerException with no hint
+	 * about the cause.
+	 *
+	 * @param stateProperties the state descriptor passed by the user function
+	 * @throws NullPointerException if {@code stateProperties} is null, or if the operator has no keyed
+	 *         state store (it is only created when the operator has a key serializer)
+	 */
+	private void checkKeyedStateAccess(Object stateProperties) {
+		java.util.Objects.requireNonNull(stateProperties, "The state properties must not be null");
+		java.util.Objects.requireNonNull(operator.getKeyedStateStore(),
+				"Keyed state can only be used on a 'keyed stream', i.e., after a 'keyBy()' operation.");
+	}
 
 	// ------------------ expose (read only) relevant information from the stream config -------- //

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/UserFacingListState.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/UserFacingListState.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/UserFacingListState.java
deleted file mode 100644
index a02a204..0000000
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/UserFacingListState.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.operators;
-
-import org.apache.flink.api.common.state.ListState;
-
-import java.util.Collections;
-
-/**
- * Simple wrapper list state that exposes empty state properly as an empty list.
- * 
- * @param <T> The type of elements in the list state.
- */
-class UserFacingListState<T> implements ListState<T> {
-
-	private final ListState<T> originalState;
-
-	private final Iterable<T> emptyState = Collections.emptyList();
-
-	UserFacingListState(ListState<T> originalState) {
-		this.originalState = originalState;
-	}
-
-	// ------------------------------------------------------------------------
-
-	@Override
-	public Iterable<T> get() throws Exception {
-		Iterable<T> original = originalState.get();
-		return original != null ? original : emptyState;
-	}
-
-	@Override
-	public void add(T value) throws Exception {
-		originalState.add(value);
-	}
-
-	@Override
-	public void clear() {
-		originalState.clear();
-	}
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
new file mode 100644
index 0000000..7abf8d9
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.tasks;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
+import org.apache.flink.util.CollectionUtil;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Bundles all state handles that belong to one operator of an operator chain: the legacy operator
+ * state, managed and raw keyed state, and managed and raw operator state.
+ */
+@Internal
+@VisibleForTesting
+public class OperatorStateHandles {
+
+	/** Position of the operator within its operator chain. */
+	private final int operatorChainIndex;
+
+	/** Legacy (non-partitionable) operator state, or {@code null}. */
+	private final StreamStateHandle legacyOperatorState;
+
+	private final Collection<KeyGroupsStateHandle> managedKeyedState;
+	private final Collection<KeyGroupsStateHandle> rawKeyedState;
+	private final Collection<OperatorStateHandle> managedOperatorState;
+	private final Collection<OperatorStateHandle> rawOperatorState;
+
+	/**
+	 * Creates an instance from explicitly given handles. No argument is validated, so any handle
+	 * may be {@code null}.
+	 */
+	public OperatorStateHandles(
+			int operatorChainIndex,
+			StreamStateHandle legacyOperatorState,
+			Collection<KeyGroupsStateHandle> managedKeyedState,
+			Collection<KeyGroupsStateHandle> rawKeyedState,
+			Collection<OperatorStateHandle> managedOperatorState,
+			Collection<OperatorStateHandle> rawOperatorState) {
+
+		this.operatorChainIndex = operatorChainIndex;
+		this.legacyOperatorState = legacyOperatorState;
+		this.managedKeyedState = managedKeyedState;
+		this.rawKeyedState = rawKeyedState;
+		this.managedOperatorState = managedOperatorState;
+		this.rawOperatorState = rawOperatorState;
+	}
+
+	/**
+	 * Extracts the handles of the operator at {@code operatorChainIndex} from the task-wide handles.
+	 * Legacy and operator state are picked by chain index (or left {@code null} when the task has none).
+	 * Keyed state handles are taken from the task as a whole, not indexed by chain position.
+	 */
+	public OperatorStateHandles(TaskStateHandles taskStateHandles, int operatorChainIndex) {
+		Preconditions.checkNotNull(taskStateHandles);
+
+		this.operatorChainIndex = operatorChainIndex;
+
+		ChainedStateHandle<StreamStateHandle> chainedLegacyState = taskStateHandles.getLegacyOperatorState();
+		if (ChainedStateHandle.isNullOrEmpty(chainedLegacyState)) {
+			this.legacyOperatorState = null;
+		} else {
+			this.legacyOperatorState = chainedLegacyState.get(operatorChainIndex);
+		}
+
+		this.rawKeyedState = taskStateHandles.getRawKeyedState();
+		this.managedKeyedState = taskStateHandles.getManagedKeyedState();
+
+		this.managedOperatorState = elementAtOrNull(taskStateHandles.getManagedOperatorState(), operatorChainIndex);
+		this.rawOperatorState = elementAtOrNull(taskStateHandles.getRawOperatorState(), operatorChainIndex);
+	}
+
+	public StreamStateHandle getLegacyOperatorState() {
+		return this.legacyOperatorState;
+	}
+
+	public Collection<KeyGroupsStateHandle> getManagedKeyedState() {
+		return this.managedKeyedState;
+	}
+
+	public Collection<KeyGroupsStateHandle> getRawKeyedState() {
+		return this.rawKeyedState;
+	}
+
+	public Collection<OperatorStateHandle> getManagedOperatorState() {
+		return this.managedOperatorState;
+	}
+
+	public Collection<OperatorStateHandle> getRawOperatorState() {
+		return this.rawOperatorState;
+	}
+
+	public int getOperatorChainIndex() {
+		return this.operatorChainIndex;
+	}
+
+	/**
+	 * Returns {@code list.get(index)}, or {@code null} if the list is {@code null} or empty. A non-empty
+	 * list that is too short still throws {@link IndexOutOfBoundsException}.
+	 */
+	private static <T> T elementAtOrNull(List<T> list, int index) {
+		if (CollectionUtil.isNullOrEmpty(list)) {
+			return null;
+		}
+		return list.get(index);
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 2e6ebf3..eb5fde7 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -23,9 +23,9 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.IllegalConfigurationException;
-import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
@@ -33,7 +33,6 @@ import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.ClosableRegistry;
 import org.apache.flink.runtime.state.KeyGroupRange;
@@ -42,27 +41,29 @@ import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackendFactory;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotResult;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.CollectionUtil;
+import org.apache.flink.util.FutureUtil;
 import org.apache.flink.util.Preconditions;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.util.Arrays;
+import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
@@ -87,13 +88,14 @@ import java.util.concurrent.ThreadFactory;
  *
  * The life cycle of the task is set up as follows:
  * <pre>{@code
- *  -- getOperatorState() -> restores state of all operators in the chain
+ *  -- setInitialState -> provides state of all operators in the chain
  *
  *  -- invoke()
  *        |
  *        +----> Create basic utils (config, etc) and load the chain of operators
  *        +----> operators.setup()
  *        +----> task specific init()
+ *        +----> initialize-operator-states()
  *        +----> open-operators()
  *        +----> run()
  *        +----> close-operators()
@@ -153,12 +155,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 	/** The map of user-defined accumulators of this task */
 	private Map<String, Accumulator<?, ?>> accumulatorMap;
 
-	/** The chained operator state to be restored once the initialization is done */
-	private ChainedStateHandle<StreamStateHandle> lazyRestoreChainedOperatorState;
-
-	private List<KeyGroupsStateHandle> lazyRestoreKeyGroupStates;
-
-	private List<Collection<OperatorStateHandle>> lazyRestoreOperatorState;
+	private TaskStateHandles restoreStateHandles;
 
 
 	/** The currently active background materialization threads */
@@ -251,9 +248,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 			// -------- Invoke --------
 			LOG.debug("Invoking {}", getName());
 
-			// first order of business is to give operators back their state
-			restoreState();
-			lazyRestoreChainedOperatorState = null; // GC friendliness
+			// first order of business is to give operators their state
+			initializeState();
 
 			// we need to make sure that any triggers scheduled in open() cannot be
 			// executed before all operators are opened
@@ -510,60 +506,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 	// ------------------------------------------------------------------------
 
 	@Override
-	public void setInitialState(
-		ChainedStateHandle<StreamStateHandle> chainedState,
-		List<KeyGroupsStateHandle> keyGroupsState,
-		List<Collection<OperatorStateHandle>> partitionableOperatorState) {
-
-		lazyRestoreChainedOperatorState = chainedState;
-		lazyRestoreKeyGroupStates = keyGroupsState;
-		lazyRestoreOperatorState = partitionableOperatorState;
-	}
-
-	private void restoreState() throws Exception {
-		final StreamOperator<?>[] allOperators = operatorChain.getAllOperators();
-
-		if (lazyRestoreChainedOperatorState != null) {
-			Preconditions.checkState(lazyRestoreChainedOperatorState.getLength() == allOperators.length,
-					"Invalid Invalid number of operator states. Found :" + lazyRestoreChainedOperatorState.getLength() +
-							". Expected: " + allOperators.length);
-		}
-
-		if (lazyRestoreOperatorState != null) {
-			Preconditions.checkArgument(lazyRestoreOperatorState.isEmpty()
-							|| lazyRestoreOperatorState.size() == allOperators.length,
-					"Invalid number of operator states. Found :" + lazyRestoreOperatorState.size() +
-							". Expected: " + allOperators.length);
-		}
-
-		for (int i = 0; i < allOperators.length; i++) {
-			StreamOperator<?> operator = allOperators[i];
-
-			if (null != lazyRestoreOperatorState && !lazyRestoreOperatorState.isEmpty()) {
-				operator.restoreState(lazyRestoreOperatorState.get(i));
-			}
-
-			// TODO deprecated code path
-			if (operator instanceof StreamCheckpointedOperator) {
-
-				if (lazyRestoreChainedOperatorState != null) {
-					StreamStateHandle state = lazyRestoreChainedOperatorState.get(i);
-
-					if (state != null) {
-						LOG.debug("Restore state of task {} in chain ({}).", i, getName());
-
-						FSDataInputStream is = state.openInputStream();
-						try {
-							cancelables.registerClosable(is);
-							((StreamCheckpointedOperator) operator).restoreState(is);
-						} finally {
-							cancelables.unregisterClosable(is);
-							is.close();
-						}
-					}
-				}
-			}
-		}
+	public void setInitialState(TaskStateHandles taskStateHandles) {
+		// The handles are only stashed here; a null value means there is nothing to restore.
+		// They are consumed (and afterwards cleared) when the task initializes its operators' state.
+		restoreStateHandles = taskStateHandles;
 	}
 
 	@Override
@@ -600,117 +544,19 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 	private boolean performCheckpoint(CheckpointMetaData checkpointMetaData) throws Exception {
 
-		long checkpointId = checkpointMetaData.getCheckpointId();
-		long timestamp = checkpointMetaData.getTimestamp();
-
-		LOG.debug("Starting checkpoint {} on task {}", checkpointId, getName());
+		LOG.debug("Starting checkpoint {} on task {}", checkpointMetaData.getCheckpointId(), getName());
 
 		synchronized (lock) {
 			if (isRunning) {
 
-				final long startOfSyncPart = System.nanoTime();
-
 				// Since both state checkpointing and downstream barrier emission occurs in this
 				// lock scope, they are an atomic operation regardless of the order in which they occur.
 				// Given this, we immediately emit the checkpoint barriers, so the downstream operators
 				// can start their checkpoint work as soon as possible
-				operatorChain.broadcastCheckpointBarrier(checkpointId, timestamp);
-
-				// now draw the state snapshot
-				final StreamOperator<?>[] allOperators = operatorChain.getAllOperators();
-
-				final List<StreamStateHandle> nonPartitionedStates =
-						Arrays.asList(new StreamStateHandle[allOperators.length]);
-
-				final List<OperatorStateHandle> operatorStates =
-						Arrays.asList(new OperatorStateHandle[allOperators.length]);
-
-				for (int i = 0; i < allOperators.length; i++) {
-					StreamOperator<?> operator = allOperators[i];
-
-					if (operator != null) {
-
-						final String operatorId = createOperatorIdentifier(operator, configuration.getVertexID());
-
-						CheckpointStreamFactory streamFactory =
-								stateBackend.createStreamFactory(getEnvironment().getJobID(), operatorId);
-
-						//TODO deprecated code path
-						if (operator instanceof StreamCheckpointedOperator) {
-
-							CheckpointStreamFactory.CheckpointStateOutputStream outStream =
-									streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);
-
-
-							cancelables.registerClosable(outStream);
-
-							try {
-								((StreamCheckpointedOperator) operator).
-										snapshotState(outStream, checkpointId, timestamp);
-
-								nonPartitionedStates.set(i, outStream.closeAndGetHandle());
-							} finally {
-								cancelables.unregisterClosable(outStream);
-							}
-						}
-
-						RunnableFuture<OperatorStateHandle> handleFuture =
-								operator.snapshotState(checkpointId, timestamp, streamFactory);
-
-						if (null != handleFuture) {
-							//TODO for now we assume there are only synchrous snapshots, no need to start the runnable.
-							if (!handleFuture.isDone()) {
-								throw new IllegalStateException("Currently only supports synchronous snapshots!");
-							}
-
-							operatorStates.set(i, handleFuture.get());
-						}
-					}
-
-				}
-
-				RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture = null;
-
-				if (keyedStateBackend != null) {
-					CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory(
-							getEnvironment().getJobID(),
-							createOperatorIdentifier(headOperator, configuration.getVertexID()));
-
-					keyGroupsStateHandleFuture = keyedStateBackend.snapshot(checkpointId, timestamp, streamFactory);
-				}
-
-				ChainedStateHandle<StreamStateHandle> chainedNonPartitionedStateHandles =
-						new ChainedStateHandle<>(nonPartitionedStates);
-
-				ChainedStateHandle<OperatorStateHandle> chainedPartitionedStateHandles =
-						new ChainedStateHandle<>(operatorStates);
-
-				LOG.debug("Finished synchronous checkpoints for checkpoint {} on task {}", checkpointId, getName());
-
-				final long syncEndNanos = System.nanoTime();
-				final long syncDurationMillis = (syncEndNanos - startOfSyncPart) / 1_000_000;
-
-				checkpointMetaData.setSyncDurationMillis(syncDurationMillis);
-
-				AsyncCheckpointRunnable asyncCheckpointRunnable = new AsyncCheckpointRunnable(
-						"checkpoint-" + checkpointId + "-" + timestamp,
-						this,
-						cancelables,
-						chainedNonPartitionedStateHandles,
-						chainedPartitionedStateHandles,
-						keyGroupsStateHandleFuture,
-						checkpointMetaData,
-						syncEndNanos);
-
-				cancelables.registerClosable(asyncCheckpointRunnable);
-				asyncOperationsThreadPool.submit(asyncCheckpointRunnable);
-
-				if (LOG.isDebugEnabled()) {
-					LOG.debug("{} - finished synchronous part of checkpoint {}." +
-							"Alignment duration: {} ms, snapshot duration {} ms",
-							getName(), checkpointId, checkpointMetaData.getAlignmentDurationNanos() / 1_000_000, syncDurationMillis);
-				}
+				operatorChain.broadcastCheckpointBarrier(
+						checkpointMetaData.getCheckpointId(), checkpointMetaData.getTimestamp());
 
+				checkpointState(checkpointMetaData);
 				return true;
 			} else {
 				return false;
@@ -740,6 +586,59 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		}
 	}
 
+	/**
+	 * Draws the synchronous snapshot of all operators in the chain and hands the remaining
+	 * (asynchronous) work off to the task's async operations thread pool.
+	 */
+	private void checkpointState(CheckpointMetaData checkpointMetaData) throws Exception {
+		new CheckpointingOperation(this, checkpointMetaData).executeCheckpointing();
+	}
+
+	/**
+	 * Hands every operator of the chain its state. If handles were passed to
+	 * {@link #setInitialState(TaskStateHandles)}, they are validated against the chain length,
+	 * distributed to the operators and then released for garbage collection.
+	 */
+	private void initializeState() throws Exception {
+
+		if (restoreStateHandles == null) {
+			// fresh start: there is nothing to restore
+			initializeOperators(false);
+			return;
+		}
+
+		checkRestorePreconditions(operatorChain.getChainLength());
+		initializeOperators(true);
+		restoreStateHandles = null; // free for GC
+	}
+
+	/**
+	 * Calls {@code initializeState} on every non-null operator of the chain.
+	 *
+	 * @param restored if {@code true}, each operator receives the {@link OperatorStateHandles} for its
+	 *                 position in the chain (built from {@code restoreStateHandles}); otherwise {@code null}.
+	 */
+	private void initializeOperators(boolean restored) throws Exception {
+		final StreamOperator<?>[] operators = operatorChain.getAllOperators();
+		for (int index = 0; index < operators.length; index++) {
+			final StreamOperator<?> current = operators[index];
+			if (current == null) {
+				continue;
+			}
+			final OperatorStateHandles handles =
+					restored ? new OperatorStateHandles(restoreStateHandles, index) : null;
+			current.initializeState(handles);
+		}
+	}
+
+	/**
+	 * Verifies that the state handles to restore match the shape of the operator chain.
+	 *
+	 * @param operatorChainLength number of operators in this task's chain
+	 * @throws IllegalStateException if legacy (non-partitionable) operator state is present but its
+	 *                               length differs from the chain length
+	 * @throws IllegalArgumentException if managed operator state is present (non-null, non-empty) but its
+	 *                                  size differs from the chain length
+	 */
+	private void checkRestorePreconditions(int operatorChainLength) {
+
+		ChainedStateHandle<StreamStateHandle> nonPartitionableOperatorStates =
+				restoreStateHandles.getLegacyOperatorState();
+		List<Collection<OperatorStateHandle>> operatorStates =
+				restoreStateHandles.getManagedOperatorState();
+
+		if (nonPartitionableOperatorStates != null) {
+			// the messages name the kind of state so a mismatch can be attributed to the right check
+			Preconditions.checkState(nonPartitionableOperatorStates.getLength() == operatorChainLength,
+					"Invalid number of legacy operator states. Found: " + nonPartitionableOperatorStates.getLength()
+							+ ". Expected: " + operatorChainLength);
+		}
+
+		if (!CollectionUtil.isNullOrEmpty(operatorStates)) {
+			Preconditions.checkArgument(operatorStates.size() == operatorChainLength,
+					"Invalid number of managed operator states. Found: " + operatorStates.size()
+							+ ". Expected: " + operatorChainLength);
+		}
+	}
+
 	// ------------------------------------------------------------------------
 	//  State backend
 	// ------------------------------------------------------------------------
@@ -777,7 +676,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 					try {
 						@SuppressWarnings("rawtypes")
 						Class<? extends StateBackendFactory> clazz =
-								Class.forName(backendName, false, getUserCodeClassLoader()).asSubclass(StateBackendFactory.class);
+								Class.forName(backendName, false, getUserCodeClassLoader()).
+										asSubclass(StateBackendFactory.class);
 
 						stateBackend = clazz.newInstance().createFromConfig(flinkConfig);
 					} catch (ClassNotFoundException e) {
@@ -799,7 +699,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 			StreamOperator<?> op, Collection<OperatorStateHandle> restoreStateHandles) throws Exception {
 
 		Environment env = getEnvironment();
-		String opId = createOperatorIdentifier(op, configuration.getVertexID());
+		String opId = createOperatorIdentifier(op, getConfiguration().getVertexID());
 
 		OperatorStateBackend newBackend = restoreStateHandles == null ?
 				stateBackend.createOperatorStateBackend(env, opId)
@@ -823,7 +723,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 				headOperator,
 				configuration.getVertexID());
 
-		if (lazyRestoreKeyGroupStates != null) {
+		if (null != restoreStateHandles && null != restoreStateHandles.getManagedKeyedState()) {
 			keyedStateBackend = stateBackend.restoreKeyedStateBackend(
 					getEnvironment(),
 					getEnvironment().getJobID(),
@@ -831,10 +731,10 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 					keySerializer,
 					numberOfKeyGroups,
 					keyGroupRange,
-					lazyRestoreKeyGroupStates,
+					restoreStateHandles.getManagedKeyedState(),
 					getEnvironment().getTaskKvStateRegistry());
 
-			lazyRestoreKeyGroupStates = null; // GC friendliness
+			restoreStateHandles = null; // GC friendliness
 		} else {
 			keyedStateBackend = stateBackend.createKeyedStateBackend(
 					getEnvironment(),
@@ -913,62 +813,60 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 	// ------------------------------------------------------------------------
 
-	private static class AsyncCheckpointRunnable implements Runnable, Closeable {
+	private static final class AsyncCheckpointRunnable implements Runnable, Closeable {
 
 		private final StreamTask<?, ?> owner;
 
-		private final ClosableRegistry cancelables;
-
-		private final ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles;
-
-		private final ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles;
+		private final List<OperatorSnapshotResult> snapshotInProgressList;
 
-		private final RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture;
+		RunnableFuture<KeyGroupsStateHandle> futureKeyedBackendStateHandles;
+		RunnableFuture<KeyGroupsStateHandle> futureKeyedStreamStateHandles;
 
-		private final String name;
+		List<StreamStateHandle> nonPartitionedStateHandles;
 
 		private final CheckpointMetaData checkpointMetaData;
 
 		private final long asyncStartNanos;
 
 		AsyncCheckpointRunnable(
-				String name,
 				StreamTask<?, ?> owner,
-				ClosableRegistry cancelables,
-				ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles,
-				ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles,
-				RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture,
+				List<StreamStateHandle> nonPartitionedStateHandles,
+				List<OperatorSnapshotResult> snapshotInProgressList,
 				CheckpointMetaData checkpointMetaData,
-				long asyncStartNanos
-		) {
+				long asyncStartNanos) {
 
-			this.name = name;
-			this.owner = owner;
-			this.cancelables = cancelables;
+			this.owner = Preconditions.checkNotNull(owner);
+			this.snapshotInProgressList = Preconditions.checkNotNull(snapshotInProgressList);
+			this.checkpointMetaData = Preconditions.checkNotNull(checkpointMetaData);
 			this.nonPartitionedStateHandles = nonPartitionedStateHandles;
-			this.partitioneableStateHandles = partitioneableStateHandles;
-			this.keyGroupsStateHandleFuture = keyGroupsStateHandleFuture;
-			this.checkpointMetaData = checkpointMetaData;
 			this.asyncStartNanos = asyncStartNanos;
+
+			if (!snapshotInProgressList.isEmpty()) {
+				// TODO Currently only the head operator of a chain can have keyed state, so simply access it directly.
+				int headIndex = snapshotInProgressList.size() - 1;
+				OperatorSnapshotResult snapshotInProgress = snapshotInProgressList.get(headIndex);
+				if (null != snapshotInProgress) {
+					this.futureKeyedBackendStateHandles = snapshotInProgress.getKeyedStateManagedFuture();
+					this.futureKeyedStreamStateHandles = snapshotInProgress.getKeyedStateRawFuture();
+				}
+			}
 		}
 
 		@Override
 		public void run() {
+
 			try {
 
-				List<KeyGroupsStateHandle> keyedStates = Collections.emptyList();
+				// Keyed state handle future, currently only one (the head) operator can have this
+				KeyGroupsStateHandle keyedStateHandleBackend = FutureUtil.runIfNotDoneAndGet(futureKeyedBackendStateHandles);
+				KeyGroupsStateHandle keyedStateHandleStream = FutureUtil.runIfNotDoneAndGet(futureKeyedStreamStateHandles);
 
-				if (keyGroupsStateHandleFuture != null) {
+				List<OperatorStateHandle> operatorStatesBackend = new ArrayList<>(snapshotInProgressList.size());
+				List<OperatorStateHandle> operatorStatesStream = new ArrayList<>(snapshotInProgressList.size());
 
-					if (!keyGroupsStateHandleFuture.isDone()) {
-						//TODO this currently works because we only have one RunnableFuture
-						keyGroupsStateHandleFuture.run();
-					}
-
-					KeyGroupsStateHandle keyGroupsStateHandle = this.keyGroupsStateHandleFuture.get();
-					if (keyGroupsStateHandle != null) {
-						keyedStates = Collections.singletonList(keyGroupsStateHandle);
-					}
+				for (OperatorSnapshotResult snapshotInProgress : snapshotInProgressList) {
+					operatorStatesBackend.add(FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture()));
+					operatorStatesStream.add(FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture()));
 				}
 
 				final long asyncEndNanos = System.nanoTime();
@@ -976,37 +874,161 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 				checkpointMetaData.setAsyncDurationMillis(asyncDurationMillis);
 
-				if (nonPartitionedStateHandles.isEmpty() && partitioneableStateHandles.isEmpty() && keyedStates.isEmpty()) {
-					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData);
-				} else {
-					CheckpointStateHandles allStateHandles = new CheckpointStateHandles(
-							nonPartitionedStateHandles,
-							partitioneableStateHandles,
-							keyedStates);
+				ChainedStateHandle<StreamStateHandle> chainedNonPartitionedOperatorsState =
+						new ChainedStateHandle<>(nonPartitionedStateHandles);
 
-					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData, allStateHandles);
+				ChainedStateHandle<OperatorStateHandle> chainedOperatorStateBackend =
+						new ChainedStateHandle<>(operatorStatesBackend);
+
+				ChainedStateHandle<OperatorStateHandle> chainedOperatorStateStream =
+						new ChainedStateHandle<>(operatorStatesStream);
+
+				SubtaskState subtaskState = new SubtaskState(
+						chainedNonPartitionedOperatorsState,
+						chainedOperatorStateBackend,
+						chainedOperatorStateStream,
+						keyedStateHandleBackend,
+						keyedStateHandleStream);
+
+				if (subtaskState.hasState()) {
+					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData, subtaskState);
+				} else {
+					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData);
 				}
 
 				if (LOG.isDebugEnabled()) {
-					LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms", 
+					LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms",
 							owner.getName(), checkpointMetaData.getCheckpointId(), asyncDurationMillis);
 				}
-			}
-			catch (Exception e) {
+			} catch (Exception e) {
 				// registers the exception and tries to fail the whole task
 				AsynchronousException asyncException = new AsynchronousException(e);
 				owner.handleAsyncException("Failure in asynchronous checkpoint materialization", asyncException);
-			}
-			finally {
-				cancelables.unregisterClosable(this);
+			} finally {
+				owner.cancelables.unregisterClosable(this);
 			}
 		}
 
 		@Override
 		public void close() {
-			if (keyGroupsStateHandleFuture != null) {
-				keyGroupsStateHandleFuture.cancel(true);
+			//TODO Handle other state futures in case we actually run them. Currently they are just DoneFutures.
+			if (futureKeyedBackendStateHandles != null) {
+				futureKeyedBackendStateHandles.cancel(true);
+			}
+		}
+	}
+
+	/**
+	 * Returns the registry in which this task registers closeables such as in-flight checkpoint
+	 * output streams and asynchronous checkpoint runnables.
+	 */
+	public ClosableRegistry getCancelables() {
+		return this.cancelables;
+	}
+
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Performs the synchronous part of a checkpoint for all operators of the chain and then submits an
+	 * {@link AsyncCheckpointRunnable} that completes the snapshot futures and acknowledges the checkpoint.
+	 */
+	private static final class CheckpointingOperation {
+
+		private final StreamTask<?, ?> owner;
+
+		private final CheckpointMetaData checkpointMetaData;
+
+		private final StreamOperator<?>[] allOperators;
+
+		private long startSyncPartNano;
+		private long startAsyncPartNano;
+
+		// ------------------------
+
+		/** Legacy (non-partitionable) state handles, one entry (possibly null) per operator in chain order. */
+		private final List<StreamStateHandle> nonPartitionedStates;
+
+		/** Snapshot results, one entry per operator in chain order. */
+		private final List<OperatorSnapshotResult> snapshotInProgressList;
+
+		public CheckpointingOperation(StreamTask<?, ?> owner, CheckpointMetaData checkpointMetaData) {
+			this.owner = Preconditions.checkNotNull(owner);
+			this.checkpointMetaData = Preconditions.checkNotNull(checkpointMetaData);
+			this.allOperators = owner.operatorChain.getAllOperators();
+			this.nonPartitionedStates = new ArrayList<>(allOperators.length);
+			this.snapshotInProgressList = new ArrayList<>(allOperators.length);
+		}
+
+		/**
+		 * Snapshots every operator synchronously, records the synchronous duration in the checkpoint
+		 * meta data and hands the rest of the work off to the task's async operations thread pool.
+		 * Invoked from performCheckpoint while it holds the checkpoint lock.
+		 */
+		public void executeCheckpointing() throws Exception {
+
+			startSyncPartNano = System.nanoTime();
+
+			// NOTE(review): unlike initializeOperators(), null entries are not skipped here; this assumes
+			// getAllOperators() never contains null - confirm.
+			for (StreamOperator<?> op : allOperators) {
+
+				// one stream factory per operator, shared by the legacy and the new snapshot path
+				CheckpointStreamFactory streamFactory = createStreamFactory(op);
+				snapshotNonPartitionableState(op, streamFactory);
+
+				OperatorSnapshotResult snapshotInProgress = op.snapshotState(
+						checkpointMetaData.getCheckpointId(), checkpointMetaData.getTimestamp(), streamFactory);
+
+				snapshotInProgressList.add(snapshotInProgress);
 			}
+
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("Finished synchronous checkpoints for checkpoint {} on task {}",
+						checkpointMetaData.getCheckpointId(), owner.getName());
+			}
+
+			startAsyncPartNano = System.nanoTime();
+
+			checkpointMetaData.setSyncDurationMillis((startAsyncPartNano - startSyncPartNano) / 1_000_000);
+
+			runAsyncCheckpointingAndAcknowledge();
+
+			if (LOG.isDebugEnabled()) {
+				// note the space after "{}." - the two literals are concatenated
+				LOG.debug("{} - finished synchronous part of checkpoint {}. " +
+								"Alignment duration: {} ms, snapshot duration {} ms",
+						owner.getName(), checkpointMetaData.getCheckpointId(),
+						checkpointMetaData.getAlignmentDurationNanos() / 1_000_000,
+						checkpointMetaData.getSyncDurationMillis());
+			}
+		}
+
+		/** Creates the checkpoint stream factory for the given operator, keyed by its operator identifier. */
+		private CheckpointStreamFactory createStreamFactory(StreamOperator<?> operator) throws IOException {
+			String operatorId = owner.createOperatorIdentifier(operator, owner.configuration.getVertexID());
+			return owner.stateBackend.createStreamFactory(owner.getEnvironment().getJobID(), operatorId);
+		}
+
+		//TODO deprecated code path
+		/**
+		 * If the operator is a {@link StreamCheckpointedOperator}, writes its legacy state to a checkpoint
+		 * stream obtained from {@code streamFactory}. Always appends exactly one entry (the handle, or
+		 * {@code null}) to {@link #nonPartitionedStates} so that list stays aligned with the operator chain.
+		 */
+		private void snapshotNonPartitionableState(
+				StreamOperator<?> operator,
+				CheckpointStreamFactory streamFactory) throws Exception {
+
+			StreamStateHandle stateHandle = null;
+
+			if (operator instanceof StreamCheckpointedOperator) {
+
+				final long checkpointId = checkpointMetaData.getCheckpointId();
+				final long timestamp = checkpointMetaData.getTimestamp();
+
+				CheckpointStreamFactory.CheckpointStateOutputStream outStream =
+						streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);
+
+				// registered so the stream is tracked in the task's cancelables while it is being written
+				owner.cancelables.registerClosable(outStream);
+
+				try {
+					((StreamCheckpointedOperator) operator).snapshotState(outStream, checkpointId, timestamp);
+					stateHandle = outStream.closeAndGetHandle();
+				} finally {
+					owner.cancelables.unregisterClosable(outStream);
+					outStream.close();
+				}
+			}
+
+			nonPartitionedStates.add(stateHandle);
+		}
+
+		/**
+		 * Registers an {@link AsyncCheckpointRunnable} with the task's cancelables and submits it to the
+		 * task's async operations thread pool; the runnable completes the snapshots and acknowledges.
+		 */
+		public void runAsyncCheckpointingAndAcknowledge() throws IOException {
+			AsyncCheckpointRunnable asyncCheckpointRunnable = new AsyncCheckpointRunnable(
+					owner,
+					nonPartitionedStates,
+					snapshotInProgressList,
+					checkpointMetaData,
+					startAsyncPartNano);
+
+			owner.cancelables.registerClosable(asyncCheckpointRunnable);
+			owner.asyncOperationsThreadPool.submit(asyncCheckpointRunnable);
 		}
 	}
 }