You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by al...@apache.org on 2016/10/20 14:15:19 UTC

[1/8] flink git commit: [FLINK-4842] Introduce test to enforce order of operator / udf lifecycles

Repository: flink
Updated Branches:
  refs/heads/master 428419d59 -> 1e475c768


[FLINK-4842] Introduce test to enforce order of operator / udf lifecycles


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/1e475c76
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/1e475c76
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/1e475c76

Branch: refs/heads/master
Commit: 1e475c768ae0d7e13746a3ca6aa258141016d419
Parents: cab9cd4
Author: Stefan Richter <s....@data-artisans.com>
Authored: Thu Oct 13 11:32:19 2016 +0200
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Thu Oct 20 16:14:21 2016 +0200

----------------------------------------------------------------------
 .../AbstractUdfStreamOperatorLifecycleTest.java | 293 +++++++++++++++++++
 .../AbstractUdfStreamOperatorTest.java          | 219 --------------
 .../streaming/runtime/tasks/StreamTaskTest.java |   4 +-
 3 files changed, 295 insertions(+), 221 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/1e475c76/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java
new file mode 100644
index 0000000..cbb833b
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorLifecycleTest.java
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.functions.RichFunction;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.taskmanager.Task;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.SourceStreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskTest;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * This test secures the lifecycle of AbstractUdfStreamOperator, including its UDF handling.
+ */
+public class AbstractUdfStreamOperatorLifecycleTest {
+
+	private static final List<String> EXPECTED_CALL_ORDER_FULL = Arrays.asList(
+			"OPERATOR::setup",
+			"UDF::setRuntimeContext",
+			"OPERATOR::initializeState",
+			"OPERATOR::open",
+			"UDF::open",
+			"OPERATOR::run",
+			"UDF::run",
+			"OPERATOR::snapshotState",
+			"OPERATOR::close",
+			"UDF::close",
+			"OPERATOR::dispose");
+
+	private static final List<String> EXPECTED_CALL_ORDER_CANCEL_RUNNING = Arrays.asList(
+			"OPERATOR::setup",
+			"UDF::setRuntimeContext",
+			"OPERATOR::initializeState",
+			"OPERATOR::open",
+			"UDF::open",
+			"OPERATOR::run",
+			"UDF::run",
+			"OPERATOR::cancel",
+			"UDF::cancel",
+			"OPERATOR::dispose",
+			"UDF::close");
+
+	private static final String ALL_METHODS_STREAM_OPERATOR = "[close[], dispose[], getChainingStrategy[], " +
+			"getMetricGroup[], initializeState[class org.apache.flink.streaming.runtime.tasks.OperatorStateHandles], " +
+			"notifyOfCompletedCheckpoint[long], open[], setChainingStrategy[class " +
+			"org.apache.flink.streaming.api.operators.ChainingStrategy], setKeyContextElement1[class " +
+			"org.apache.flink.streaming.runtime.streamrecord.StreamRecord], " +
+			"setKeyContextElement2[class org.apache.flink.streaming.runtime.streamrecord.StreamRecord], " +
+			"setup[class org.apache.flink.streaming.runtime.tasks.StreamTask, class " +
+			"org.apache.flink.streaming.api.graph.StreamConfig, interface " +
+			"org.apache.flink.streaming.api.operators.Output], snapshotState[long, long, " +
+			"interface org.apache.flink.runtime.state.CheckpointStreamFactory]]";
+
+	private static final String ALL_METHODS_RICH_FUNCTION = "[close[], getIterationRuntimeContext[], getRuntimeContext[]" +
+			", open[class org.apache.flink.configuration.Configuration], setRuntimeContext[interface " +
+			"org.apache.flink.api.common.functions.RuntimeContext]]";
+
+	private static final List<String> ACTUAL_ORDER_TRACKING =
+			Collections.synchronizedList(new ArrayList<String>(EXPECTED_CALL_ORDER_FULL.size()));
+
+	@Test
+	public void testAllMethodsRegisteredInTest() {
+		List<String> methodsWithSignatureString = new ArrayList<>();
+		for (Method method : StreamOperator.class.getMethods()) {
+			methodsWithSignatureString.add(method.getName() + Arrays.toString(method.getParameterTypes()));
+		}
+		Collections.sort(methodsWithSignatureString);
+		Assert.assertEquals("It seems like new methods have been introduced to " + StreamOperator.class +
+				". Please register them with this test and ensure to document their position in the lifecycle " +
+				"(if applicable).", ALL_METHODS_STREAM_OPERATOR, methodsWithSignatureString.toString());
+
+		methodsWithSignatureString = new ArrayList<>();
+		for (Method method : RichFunction.class.getMethods()) {
+			methodsWithSignatureString.add(method.getName() + Arrays.toString(method.getParameterTypes()));
+		}
+		Collections.sort(methodsWithSignatureString);
+		Assert.assertEquals("It seems like new methods have been introduced to " + RichFunction.class +
+				". Please register them with this test and ensure to document their position in the lifecycle " +
+				"(if applicable).", ALL_METHODS_RICH_FUNCTION, methodsWithSignatureString.toString());
+	}
+
+	@Test
+	public void testLifeCycleFull() throws Exception {
+		ACTUAL_ORDER_TRACKING.clear();
+
+		Configuration taskManagerConfig = new Configuration();
+		StreamConfig cfg = new StreamConfig(new Configuration());
+		MockSourceFunction srcFun = new MockSourceFunction();
+
+		cfg.setStreamOperator(new LifecycleTrackingStreamSource(srcFun, true));
+		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+		Task task = StreamTaskTest.createTask(SourceStreamTask.class, cfg, taskManagerConfig);
+
+		task.startTaskThread();
+
+		LifecycleTrackingStreamSource.runStarted.await();
+
+		// wait for clean termination
+		task.getExecutingThread().join();
+		assertEquals(ExecutionState.FINISHED, task.getExecutionState());
+		assertEquals(EXPECTED_CALL_ORDER_FULL, ACTUAL_ORDER_TRACKING);
+	}
+
+	@Test
+	public void testLifeCycleCancel() throws Exception {
+		ACTUAL_ORDER_TRACKING.clear();
+
+		Configuration taskManagerConfig = new Configuration();
+		StreamConfig cfg = new StreamConfig(new Configuration());
+		MockSourceFunction srcFun = new MockSourceFunction();
+		cfg.setStreamOperator(new LifecycleTrackingStreamSource(srcFun, false));
+		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+		Task task = StreamTaskTest.createTask(SourceStreamTask.class, cfg, taskManagerConfig);
+
+		task.startTaskThread();
+		LifecycleTrackingStreamSource.runStarted.await();
+
+		// this should cancel the task even though it is blocked on runFinish
+		task.cancelExecution();
+
+		// wait for clean termination
+		task.getExecutingThread().join();
+		assertEquals(ExecutionState.CANCELED, task.getExecutionState());
+		assertEquals(EXPECTED_CALL_ORDER_CANCEL_RUNNING, ACTUAL_ORDER_TRACKING);
+	}
+
+	private static class MockSourceFunction extends RichSourceFunction<Long> {
+
+		private static final long serialVersionUID = 1L;
+
+		@Override
+		public void run(SourceContext<Long> ctx) {
+			ACTUAL_ORDER_TRACKING.add("UDF::run");
+		}
+
+		@Override
+		public void cancel() {
+			ACTUAL_ORDER_TRACKING.add("UDF::cancel");
+		}
+
+		@Override
+		public void setRuntimeContext(RuntimeContext t) {
+			ACTUAL_ORDER_TRACKING.add("UDF::setRuntimeContext");
+			super.setRuntimeContext(t);
+		}
+
+		@Override
+		public void open(Configuration parameters) throws Exception {
+			ACTUAL_ORDER_TRACKING.add("UDF::open");
+			super.open(parameters);
+		}
+
+		@Override
+		public void close() throws Exception {
+			ACTUAL_ORDER_TRACKING.add("UDF::close");
+			super.close();
+		}
+	}
+
+	private static class LifecycleTrackingStreamSource<OUT, SRC extends SourceFunction<OUT>>
+			extends StreamSource<OUT, SRC> implements Serializable {
+
+		private static final long serialVersionUID = 2431488948886850562L;
+		private transient Thread testCheckpointer;
+
+		private final boolean simulateCheckpointing;
+
+		static OneShotLatch runStarted;
+		static OneShotLatch runFinish;
+
+		public LifecycleTrackingStreamSource(SRC sourceFunction, boolean simulateCheckpointing) {
+			super(sourceFunction);
+			this.simulateCheckpointing = simulateCheckpointing;
+			runStarted = new OneShotLatch();
+			runFinish = new OneShotLatch();
+		}
+
+		@Override
+		public void run(Object lockingObject, Output<StreamRecord<OUT>> collector) throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::run");
+			super.run(lockingObject, collector);
+			runStarted.trigger();
+			runFinish.await();
+		}
+
+		@Override
+		public void setup(StreamTask<?, ?> containingTask, StreamConfig config, Output<StreamRecord<OUT>> output) {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::setup");
+			super.setup(containingTask, config, output);
+			if (simulateCheckpointing) {
+				testCheckpointer = new Thread() {
+					@Override
+					public void run() {
+						long id = 0;
+						while (true) {
+							try {
+								Thread.sleep(50);
+								if (getContainingTask().isCanceled() || getContainingTask().triggerCheckpoint(
+										new CheckpointMetaData(id++, System.currentTimeMillis()))) {
+									LifecycleTrackingStreamSource.runFinish.trigger();
+									break;
+								}
+							} catch (Exception e) {
+								e.printStackTrace();
+								Assert.fail();
+							}
+						}
+					}
+				};
+				testCheckpointer.start();
+			}
+		}
+
+		@Override
+		public void snapshotState(StateSnapshotContext context) throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::snapshotState");
+			super.snapshotState(context);
+		}
+
+		@Override
+		public void initializeState(StateInitializationContext context) throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::initializeState");
+			super.initializeState(context);
+		}
+
+		@Override
+		public void open() throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::open");
+			super.open();
+		}
+
+		@Override
+		public void close() throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::close");
+			super.close();
+		}
+
+		@Override
+		public void cancel() {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::cancel");
+			super.cancel();
+		}
+
+		@Override
+		public void dispose() throws Exception {
+			ACTUAL_ORDER_TRACKING.add("OPERATOR::dispose");
+			super.dispose();
+			if (simulateCheckpointing) {
+				testCheckpointer.join();
+			}
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/1e475c76/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java
deleted file mode 100644
index f5d633c..0000000
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java
+++ /dev/null
@@ -1,219 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.operators;
-
-import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.functions.RuntimeContext;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.blob.BlobKey;
-import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
-import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
-import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
-import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
-import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
-import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.filecache.FileCache;
-import org.apache.flink.runtime.io.disk.iomanager.IOManager;
-import org.apache.flink.runtime.io.network.NetworkEnvironment;
-import org.apache.flink.runtime.io.network.netty.PartitionStateChecker;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
-import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
-import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
-import org.apache.flink.runtime.memory.MemoryManager;
-import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
-import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.StateInitializationContext;
-import org.apache.flink.runtime.state.StateSnapshotContext;
-import org.apache.flink.runtime.taskmanager.CheckpointResponder;
-import org.apache.flink.runtime.taskmanager.Task;
-import org.apache.flink.runtime.taskmanager.TaskManagerConnection;
-import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
-import org.apache.flink.streaming.api.TimeCharacteristic;
-import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
-import org.apache.flink.streaming.api.functions.source.SourceFunction;
-import org.apache.flink.streaming.api.graph.StreamConfig;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.runtime.tasks.SourceStreamTask;
-import org.apache.flink.streaming.runtime.tasks.StreamTask;
-import org.apache.flink.util.SerializedValue;
-import org.junit.Test;
-
-import java.io.Serializable;
-import java.net.URL;
-import java.util.Collections;
-import java.util.concurrent.Executor;
-
-import static org.junit.Assert.assertEquals;
-import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-public class AbstractUdfStreamOperatorTest {
-
-	@Test
-	public void testLifeCycle() throws Exception {
-
-		Configuration taskManagerConfig = new Configuration();
-
-		StreamConfig cfg = new StreamConfig(new Configuration());
-		cfg.setStreamOperator(new LifecycleTrackingStreamSource(new MockSourceFunction()));
-		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
-
-		Task task = createTask(SourceStreamTask.class, cfg, taskManagerConfig);
-
-		task.startTaskThread();
-
-		// wait for clean termination
-		task.getExecutingThread().join();
-		assertEquals(ExecutionState.FINISHED, task.getExecutionState());
-	}
-
-	private static class MockSourceFunction extends RichSourceFunction<Long> {
-
-		private static final long serialVersionUID = 1L;
-
-		@Override
-		public void run(SourceContext<Long> ctx) {
-		}
-
-		@Override
-		public void cancel() {
-		}
-
-		@Override
-		public void setRuntimeContext(RuntimeContext t) {
-			System.out.println("!setRuntimeContext");
-			super.setRuntimeContext(t);
-		}
-
-		@Override
-		public void open(Configuration parameters) throws Exception {
-			System.out.println("!open");
-			super.open(parameters);
-		}
-
-		@Override
-		public void close() throws Exception {
-			System.out.println("!close");
-			super.close();
-		}
-	}
-
-	private Task createTask(
-			Class<? extends AbstractInvokable> invokable,
-			StreamConfig taskConfig,
-			Configuration taskManagerConfig) throws Exception {
-
-		LibraryCacheManager libCache = mock(LibraryCacheManager.class);
-		when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader());
-
-		ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
-		ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class);
-		PartitionStateChecker partitionStateChecker = mock(PartitionStateChecker.class);
-		Executor executor = mock(Executor.class);
-
-		NetworkEnvironment network = mock(NetworkEnvironment.class);
-		when(network.getResultPartitionManager()).thenReturn(partitionManager);
-		when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC);
-		when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class)))
-				.thenReturn(mock(TaskKvStateRegistry.class));
-
-		TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(
-				new JobID(), "Job Name", new JobVertexID(), new ExecutionAttemptID(),
-				new SerializedValue<>(new ExecutionConfig()),
-				"Test Task", 1, 0, 1, 0,
-				new Configuration(),
-				taskConfig.getConfiguration(),
-				invokable.getName(),
-				Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
-				Collections.<InputGateDeploymentDescriptor>emptyList(),
-				Collections.<BlobKey>emptyList(),
-				Collections.<URL>emptyList(),
-				0);
-
-		return new Task(
-				tdd,
-				mock(MemoryManager.class),
-				mock(IOManager.class),
-				network,
-				mock(BroadcastVariableManager.class),
-				mock(TaskManagerConnection.class),
-				mock(InputSplitProvider.class),
-				mock(CheckpointResponder.class),
-				libCache,
-				mock(FileCache.class),
-				new TaskManagerRuntimeInfo("localhost", taskManagerConfig, System.getProperty("java.io.tmpdir")),
-				new UnregisteredTaskMetricsGroup(),
-				consumableNotifier,
-				partitionStateChecker,
-				executor);
-	}
-
-	static class LifecycleTrackingStreamSource<OUT, SRC extends SourceFunction<OUT>>
-			extends StreamSource<OUT, SRC> implements Serializable {
-
-		//private transient final AtomicInteger currentState;
-
-		private static final long serialVersionUID = 2431488948886850562L;
-
-		public LifecycleTrackingStreamSource(SRC sourceFunction) {
-			super(sourceFunction);
-		}
-
-		@Override
-		public void setup(StreamTask<?, ?> containingTask, StreamConfig config, Output<StreamRecord<OUT>> output) {
-			System.out.println("setup");
-			super.setup(containingTask, config, output);
-		}
-
-		@Override
-		public void snapshotState(StateSnapshotContext context) throws Exception {
-			System.out.println("snapshotState");
-			super.snapshotState(context);
-		}
-
-		@Override
-		public void initializeState(StateInitializationContext context) throws Exception {
-			System.out.println("initializeState");
-			super.initializeState(context);
-		}
-
-		@Override
-		public void open() throws Exception {
-			System.out.println("open");
-			super.open();
-		}
-
-		@Override
-		public void close() throws Exception {
-			System.out.println("close");
-			super.close();
-		}
-
-		@Override
-		public void dispose() throws Exception {
-			super.dispose();
-			System.out.println("dispose");
-		}
-	}
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/1e475c76/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index 8aae19f..94f6d5a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -200,13 +200,13 @@ public class StreamTaskTest {
 		}
 	}
 
-	private Task createTask(
+	public static Task createTask(
 			Class<? extends AbstractInvokable> invokable,
 			StreamConfig taskConfig,
 			Configuration taskManagerConfig) throws Exception {
 
 		LibraryCacheManager libCache = mock(LibraryCacheManager.class);
-		when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader());
+		when(libCache.getClassLoader(any(JobID.class))).thenReturn(StreamTaskTest.class.getClassLoader());
 		
 		ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
 		ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class);


[5/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State

Posted by al...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index bbe10d1..fd425f3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -38,13 +38,13 @@ import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
@@ -1847,15 +1847,15 @@ public class CheckpointCoordinatorTest {
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 
-		List<KeyGroupRange> keyGroupPartitions1 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
-		List<KeyGroupRange> keyGroupPartitions2 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID1, index);
-			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8);
-			List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false);
+			KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
 
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState);
+			SubtaskState checkpointStateHandles = new SubtaskState(nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -1867,9 +1867,9 @@ public class CheckpointCoordinatorTest {
 
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID2, index);
-			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8);
-			List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState);
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false);
+			KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
+			SubtaskState checkpointStateHandles = new SubtaskState(nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -1952,13 +1952,13 @@ public class CheckpointCoordinatorTest {
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 
-		List<KeyGroupRange> keyGroupPartitions1 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
-		List<KeyGroupRange> keyGroupPartitions2 = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
+			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -1971,8 +1971,8 @@ public class CheckpointCoordinatorTest {
 
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID2, index);
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
+			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -2067,17 +2067,17 @@ public class CheckpointCoordinatorTest {
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 
-		List<KeyGroupRange> keyGroupPartitions1 = 
-				CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
-		List<KeyGroupRange> keyGroupPartitions2 = 
-				CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+		List<KeyGroupRange> keyGroupPartitions1 =
+				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 =
+				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
-					jobVertexID1, keyGroupPartitions1.get(index));
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(
+					jobVertexID1, keyGroupPartitions1.get(index), false);
 
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
+			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -2091,10 +2091,10 @@ public class CheckpointCoordinatorTest {
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 
 			ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID2, index);
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
-					jobVertexID2, keyGroupPartitions2.get(index));
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(
+					jobVertexID2, keyGroupPartitions2.get(index), false);
 
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(state, null, keyGroupState);
+			SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -2132,24 +2132,36 @@ public class CheckpointCoordinatorTest {
 			"non-partitioned state changed.");
 	}
 
+	@Test
+	public void testRestoreLatestCheckpointedStateScaleIn() throws Exception {
+		testRestoreLatestCheckpointedStateWithChangingParallelism(false);
+	}
+
+	@Test
+	public void testRestoreLatestCheckpointedStateScaleOut() throws Exception {
+		testRestoreLatestCheckpointedStateWithChangingParallelism(true); // scaleOut: vertex 2 parallelism 2 -> 13
+	}
+
 	/**
 	 * Tests the checkpoint restoration with changing parallelism of job vertex with partitioned
 	 * state.
 	 *
 	 * @throws Exception
 	 */
-	@Test
-	public void testRestoreLatestCheckpointedStateWithChangingParallelism() throws Exception {
+	private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean scaleOut) throws Exception {
 		final JobID jid = new JobID();
 		final long timestamp = System.currentTimeMillis();
 
 		final JobVertexID jobVertexID1 = new JobVertexID();
 		final JobVertexID jobVertexID2 = new JobVertexID();
 		int parallelism1 = 3;
-		int parallelism2 = 2;
+		int parallelism2 = scaleOut ? 2 : 13;
+
 		int maxParallelism1 = 42;
 		int maxParallelism2 = 13;
 
+		int newParallelism2 = scaleOut ? 13 : 2;
+
 		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
 				jobVertexID1,
 				parallelism1,
@@ -2190,18 +2202,20 @@ public class CheckpointCoordinatorTest {
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 
-		List<KeyGroupRange> keyGroupPartitions1 = 
-				CheckpointCoordinator.createKeyGroupPartitions(maxParallelism1, parallelism1);
-		List<KeyGroupRange> keyGroupPartitions2 = 
-				CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, parallelism2);
+		List<KeyGroupRange> keyGroupPartitions1 =
+				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 =
+				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
+		//vertex 1
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
-			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8);
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
+			ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false);
+			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
+			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true);
 
 
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, partitionableState, keyGroupState);
+			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw , 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -2211,13 +2225,19 @@ public class CheckpointCoordinatorTest {
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
 
-
-		final List<ChainedStateHandle<OperatorStateHandle>> originalPartitionableStates = new ArrayList<>(jobVertex2.getParallelism());
+		//vertex 2
+		final List<ChainedStateHandle<OperatorStateHandle>> expectedOpStatesBackend = new ArrayList<>(jobVertex2.getParallelism());
+		final List<ChainedStateHandle<OperatorStateHandle>> expectedOpStatesRaw = new ArrayList<>(jobVertex2.getParallelism());
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
-			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8);
-			originalPartitionableStates.add(partitionableState);
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(null, partitionableState, keyGroupState);
+			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
+			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true);
+			ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false);
+			ChainedStateHandle<OperatorStateHandle> opStateRaw = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, true);
+			expectedOpStatesBackend.add(opStateBackend);
+			expectedOpStatesRaw.add(opStateRaw);
+			SubtaskState checkpointStateHandles =
+					new SubtaskState(new ChainedStateHandle<>(
+							Collections.<StreamStateHandle>singletonList(null)), opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw, 0);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
@@ -2233,16 +2253,15 @@ public class CheckpointCoordinatorTest {
 
 		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
 
-		int newParallelism2 = 13;
-
-		List<KeyGroupRange> newKeyGroupPartitions2 = 
-				CheckpointCoordinator.createKeyGroupPartitions(maxParallelism2, newParallelism2);
+		List<KeyGroupRange> newKeyGroupPartitions2 =
+				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, newParallelism2);
 
 		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
 				jobVertexID1,
 				parallelism1,
 				maxParallelism1);
 
+		// rescale vertex 2
 		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
 				jobVertexID2,
 				newParallelism2,
@@ -2254,19 +2273,28 @@ public class CheckpointCoordinatorTest {
 
 		// verify the restored state
 		verifiyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
-		List<List<Collection<OperatorStateHandle>>> actualPartitionableStates = new ArrayList<>(newJobVertex2.getParallelism());
+		List<List<Collection<OperatorStateHandle>>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism());
+		List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism());
 		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
-			List<KeyGroupsStateHandle> originalKeyGroupState = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i));
+			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false);
+			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true);
+
+			TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
 
-			ChainedStateHandle<StreamStateHandle> operatorState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
-			List<Collection<OperatorStateHandle>> partitionableState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
-			List<KeyGroupsStateHandle> keyGroupState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
+			ChainedStateHandle<StreamStateHandle> operatorState = taskStateHandles.getLegacyOperatorState();
+			List<Collection<OperatorStateHandle>> opStateBackend = taskStateHandles.getManagedOperatorState();
+			List<Collection<OperatorStateHandle>> opStateRaw = taskStateHandles.getRawOperatorState();
+			Collection<KeyGroupsStateHandle> keyGroupStateBackend = taskStateHandles.getManagedKeyedState();
+			Collection<KeyGroupsStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState();
 
-			actualPartitionableStates.add(partitionableState);
+			actualOpStatesBackend.add(opStateBackend);
+			actualOpStatesRaw.add(opStateRaw);
 			assertNull(operatorState);
-			compareKeyPartitionedState(originalKeyGroupState, keyGroupState);
+			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyGroupStateBackend);
+			compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
 		}
-		comparePartitionableState(originalPartitionableStates, actualPartitionableStates);
+		comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend);
+		comparePartitionableState(expectedOpStatesRaw, actualOpStatesRaw);
 	}
 
 	/**
@@ -2320,15 +2348,41 @@ public class CheckpointCoordinatorTest {
 	//  Utilities
 	// ------------------------------------------------------------------------
 
-	public static List<KeyGroupsStateHandle> generateKeyGroupState(
+	static void sendAckMessageToCoordinator(
+			CheckpointCoordinator coord,
+			long checkpointId, JobID jid,
+			ExecutionJobVertex jobVertex,
+			JobVertexID jobVertexID,
+			List<KeyGroupRange> keyGroupPartitions) throws Exception {
+
+		for (int index = 0; index < jobVertex.getParallelism(); index++) {
+			ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID, index);
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(
+					jobVertexID,
+					keyGroupPartitions.get(index), false);
+
+			SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null, 0);
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+					jid,
+					jobVertex.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					new CheckpointMetaData(checkpointId, 0L),
+					checkpointStateHandles);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+	}
+
+	public static KeyGroupsStateHandle generateKeyGroupState(
 			JobVertexID jobVertexID,
-			KeyGroupRange keyGroupPartition) throws IOException {
+			KeyGroupRange keyGroupPartition, boolean rawState) throws IOException {
 
 		List<Integer> testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups());
 
 		// generate state for one keygroup
 		for (int keyGroupIndex : keyGroupPartition) {
-			Random random = new Random(jobVertexID.hashCode() + keyGroupIndex);
+			int vertexHash = jobVertexID.hashCode();
+			int seed = rawState ? (vertexHash * (31 + keyGroupIndex)) : (vertexHash + keyGroupIndex);
+			Random random = new Random(seed);
 			int simulatedStateValue = random.nextInt();
 			testStatesLists.add(simulatedStateValue);
 		}
@@ -2336,7 +2390,7 @@ public class CheckpointCoordinatorTest {
 		return generateKeyGroupState(keyGroupPartition, testStatesLists);
 	}
 
-	public static List<KeyGroupsStateHandle> generateKeyGroupState(
+	public static KeyGroupsStateHandle generateKeyGroupState(
 			KeyGroupRange keyGroupRange,
 			List<? extends Serializable> states) throws IOException {
 
@@ -2353,9 +2407,7 @@ public class CheckpointCoordinatorTest {
 		KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(
 				keyGroupRangeOffsets,
 				allSerializedStatesHandle);
-		List<KeyGroupsStateHandle> keyGroupsStateHandleList = new ArrayList<>();
-		keyGroupsStateHandleList.add(keyGroupsStateHandle);
-		return keyGroupsStateHandleList;
+		return keyGroupsStateHandle;
 	}
 
 	public static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets(
@@ -2412,14 +2464,19 @@ public class CheckpointCoordinatorTest {
 			JobVertexID jobVertexID,
 			int index,
 			int namedStates,
-			int partitionsPerState) throws IOException {
+			int partitionsPerState,
+			boolean rawState) throws IOException {
 
 		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
 
 		for (int i = 0; i < namedStates; ++i) {
 			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
 			// generate state
-			Random random = new Random(jobVertexID.hashCode() * index + i * namedStates);
+			int seed = jobVertexID.hashCode() * index + i * namedStates;
+			if (rawState) {
+				seed = (seed + 1) * 31;
+			}
+			Random random = new Random(seed);
 			for (int j = 0; j < partitionsPerState; ++j) {
 				int simulatedStateValue = random.nextInt();
 				testStatesLists.add(simulatedStateValue);
@@ -2454,7 +2511,7 @@ public class CheckpointCoordinatorTest {
 				serializationWithOffsets.f0);
 
 		OperatorStateHandle operatorStateHandle =
-				new OperatorStateHandle(streamStateHandle, offsetsMap);
+				new OperatorStateHandle(offsetsMap, streamStateHandle);
 		return ChainedStateHandle.wrapSingleHandle(operatorStateHandle);
 	}
 
@@ -2528,37 +2585,35 @@ public class CheckpointCoordinatorTest {
 
 		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
 
+			TaskStateHandles taskStateHandles = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
+
 			ChainedStateHandle<StreamStateHandle> expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
-			ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = executionJobVertex.
-					getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
+			ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = taskStateHandles.getLegacyOperatorState();
 			assertTrue(CommonTestUtils.isSteamContentEqual(
 					expectNonPartitionedState.get(0).openInputStream(),
 					actualNonPartitionedState.get(0).openInputStream()));
 
-			ChainedStateHandle<OperatorStateHandle> expectedPartitionableState =
-					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8);
+			ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend =
+					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);
 
-			List<Collection<OperatorStateHandle>> actualPartitionableState = executionJobVertex.
-					getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
+			List<Collection<OperatorStateHandle>> actualPartitionableState = taskStateHandles.getManagedOperatorState();
 
 			assertTrue(CommonTestUtils.isSteamContentEqual(
-					expectedPartitionableState.get(0).openInputStream(),
+					expectedOpStateBackend.get(0).openInputStream(),
 					actualPartitionableState.get(0).iterator().next().openInputStream()));
 
-			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState = generateKeyGroupState(
-					jobVertexID,
-					keyGroupPartitions.get(i));
-			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState = executionJobVertex.
-					getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
-			compareKeyPartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState);
+			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
+					jobVertexID, keyGroupPartitions.get(i), false);
+			Collection<KeyGroupsStateHandle> actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
+			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), actualPartitionedKeyGroupState);
 		}
 	}
 
-	public static void compareKeyPartitionedState(
-			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
-			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState) throws Exception {
+	public static void compareKeyedState(
+			Collection<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
+			Collection<KeyGroupsStateHandle> actualPartitionedKeyGroupState) throws Exception {
 
-		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.get(0);
+		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next();
 		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getNumberOfKeyGroups();
 		int actualTotalKeyGroups = 0;
 		for(KeyGroupsStateHandle keyGroupsStateHandle: actualPartitionedKeyGroupState) {
@@ -2576,13 +2631,10 @@ public class CheckpointCoordinatorTest {
 				for (KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) {
 					if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
 						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-						try (FSDataInputStream actualInputStream =
-								     oneActualKeyGroupStateHandle.openInputStream()) {
+						try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
 							actualInputStream.seek(actualOffset);
-
 							int actualGroupState = InstantiationUtil.
 									deserializeObject(actualInputStream, Thread.currentThread().getContextClassLoader());
-
 							assertEquals(expectedKeyGroupState, actualGroupState);
 						}
 					}
@@ -2599,16 +2651,7 @@ public class CheckpointCoordinatorTest {
 		for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) {
 			for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
 				OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i);
-				try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
-					for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
-						for (long offset : entry.getValue()) {
-							in.seek(offset);
-							Integer state = InstantiationUtil.
-									deserializeObject(in, Thread.currentThread().getContextClassLoader());
-							expectedResult.add(i + " : " + entry.getKey() + " : " + state);
-						}
-					}
-				}
+				collectResult(i, operatorStateHandle, expectedResult);
 			}
 		}
 		Collections.sort(expectedResult);
@@ -2618,25 +2661,32 @@ public class CheckpointCoordinatorTest {
 			if (collectionList != null) {
 				for (int i = 0; i < collectionList.size(); ++i) {
 					Collection<OperatorStateHandle> stateHandles = collectionList.get(i);
+					Assert.assertNotNull(stateHandles);
 					for (OperatorStateHandle operatorStateHandle : stateHandles) {
-						try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
-							for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
-								for (long offset : entry.getValue()) {
-									in.seek(offset);
-									Integer state = InstantiationUtil.
-											deserializeObject(in, Thread.currentThread().getContextClassLoader());
-									actualResult.add(i + " : " + entry.getKey() + " : " + state);
-								}
-							}
-						}
+						collectResult(i, operatorStateHandle, actualResult);
 					}
 				}
 			}
 		}
+
 		Collections.sort(actualResult);
 		Assert.assertEquals(expectedResult, actualResult);
 	}
 
+	private static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception {
+		try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
+			for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+				for (long offset : entry.getValue()) {
+					in.seek(offset);
+					Integer state = InstantiationUtil.
+							deserializeObject(in, Thread.currentThread().getContextClassLoader());
+					resultCollector.add(opIdx + " : " + entry.getKey() + " : " + state);
+				}
+			}
+		}
+	}
+
+
 	@Test
 	public void testCreateKeyGroupPartitions() {
 		testCreateKeyGroupPartitions(1, 1);
@@ -2697,7 +2747,7 @@ public class CheckpointCoordinatorTest {
 	}
 
 	private void testCreateKeyGroupPartitions(int maxParallelism, int parallelism) {
-		List<KeyGroupRange> ranges = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism, parallelism);
+		List<KeyGroupRange> ranges = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism, parallelism);
 		for (int i = 0; i < maxParallelism; ++i) {
 			KeyGroupRange range = ranges.get(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, i));
 			if (!range.contains(i)) {
@@ -2743,7 +2793,7 @@ public class CheckpointCoordinatorTest {
 			}
 
 			previousParallelOpInstanceStates.add(
-					new OperatorStateHandle(new FileStateHandle(fakePath, -1), namedStatesToOffsets));
+					new OperatorStateHandle(namedStatesToOffsets, new FileStateHandle(fakePath, -1)));
 		}
 
 		Map<StreamStateHandle, Map<String, List<Long>>> expected = new HashMap<>();

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 950526c..359262f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -29,17 +29,18 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.util.SerializableObject;
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
 import org.junit.Test;
 import org.mockito.Mockito;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -65,7 +66,7 @@ public class CheckpointStateRestoreTest {
 			final ChainedStateHandle<StreamStateHandle> serializedState = CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject());
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
 			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
-			final List<KeyGroupsStateHandle> serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+			final KeyGroupsStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
 
 			final JobID jid = new JobID();
 			final JobVertexID statefulId = new JobVertexID();
@@ -115,7 +116,7 @@ public class CheckpointStateRestoreTest {
 			PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next();
 			final long checkpointId = pending.getCheckpointId();
 
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates);
+			SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null, 0L);
 			CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointMetaData, checkpointStateHandles));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointMetaData, checkpointStateHandles));
@@ -131,26 +132,33 @@ public class CheckpointStateRestoreTest {
 
 			// verify that each stateful vertex got the state
 
-			BaseMatcher<CheckpointStateHandles> matcher = new BaseMatcher<CheckpointStateHandles>() {
+			final TaskStateHandles taskStateHandles = new TaskStateHandles(
+					serializedState,
+					Collections.<Collection<OperatorStateHandle>>singletonList(null),
+					Collections.<Collection<OperatorStateHandle>>singletonList(null),
+					Collections.singletonList(serializedKeyGroupStates),
+					null);
+
+			BaseMatcher<TaskStateHandles> matcher = new BaseMatcher<TaskStateHandles>() {
 				@Override
 				public boolean matches(Object o) {
-					if (o instanceof CheckpointStateHandles) {
-						return ((CheckpointStateHandles) o).getNonPartitionedStateHandles().equals(serializedState);
+					if (o instanceof TaskStateHandles) {
+						return o.equals(taskStateHandles);
 					}
 					return false;
 				}
 
 				@Override
 				public void describeTo(Description description) {
-					description.appendValue(serializedState);
+					description.appendValue(taskStateHandles);
 				}
 			};
 
-			verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
-			verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
-			verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
-			verify(statelessExec1, times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), Mockito.<List<Collection<OperatorStateHandle>>>any());
-			verify(statelessExec2, times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), Mockito.<List<Collection<OperatorStateHandle>>>any());
+			verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher));
+			verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher));
+			verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher));
+			verify(statelessExec1, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
+			verify(statelessExec2, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -164,7 +172,7 @@ public class CheckpointStateRestoreTest {
 			final ChainedStateHandle<StreamStateHandle> serializedState = CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject());
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
 			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
-			final List<KeyGroupsStateHandle> serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+			final KeyGroupsStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
 
 			final JobID jid = new JobID();
 			final JobVertexID statefulId = new JobVertexID();
@@ -215,7 +223,8 @@ public class CheckpointStateRestoreTest {
 			final long checkpointId = pending.getCheckpointId();
 
 			// the difference to the test "testSetState" is that one stateful subtask does not report state
-			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates);
+			SubtaskState checkpointStateHandles =
+					new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null, 0L);
 
 			CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index baa0e08..6b0d3f8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -206,7 +206,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 			ChainedStateHandle<StreamStateHandle> stateHandle = CheckpointCoordinatorTest.generateChainedStateHandle(
 					new CheckpointMessagesTest.MyHandle());
 
-			taskState.putState(i, new SubtaskState(stateHandle, 0));
+			taskState.putState(i, new SubtaskState(stateHandle, null, null, null, null, 0L));
 		}
 
 		return new TestCompletedCheckpoint(new JobID(), id, 0, taskGroupStates, props);

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java
index bad836b..508a69d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java
@@ -19,11 +19,13 @@
 package org.apache.flink.runtime.checkpoint.savepoint;
 
 import org.apache.commons.io.output.ByteArrayOutputStream;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.junit.Test;
 
 import java.io.ByteArrayInputStream;
+import java.util.Random;
 
 import static org.junit.Assert.assertEquals;
 
@@ -34,19 +36,23 @@ public class SavepointV1SerializerTest {
 	 */
 	@Test
 	public void testSerializeDeserializeV1() throws Exception {
-		SavepointV1 expected = new SavepointV1(123123, SavepointV1Test.createTaskStates(8, 32));
+		Random r = new Random(42);
+		for (int i = 0; i < 100; ++i) {
+			SavepointV1 expected =
+					new SavepointV1(i+ 123123, SavepointV1Test.createTaskStates(1 + r.nextInt(64), 1 + r.nextInt(64)));
 
-		SavepointV1Serializer serializer = SavepointV1Serializer.INSTANCE;
+			SavepointV1Serializer serializer = SavepointV1Serializer.INSTANCE;
 
-		// Serialize
-		ByteArrayOutputStream baos = new ByteArrayOutputStream();
-		serializer.serialize(expected, new DataOutputViewStreamWrapper(baos));
-		byte[] bytes = baos.toByteArray();
+			// Serialize
+			ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos();
+			serializer.serialize(expected, new DataOutputViewStreamWrapper(baos));
+			byte[] bytes = baos.toByteArray();
 
-		// Deserialize
-		ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
-		Savepoint actual = serializer.deserialize(new DataInputViewStreamWrapper(bais));
+			// Deserialize
+			ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
+			Savepoint actual = serializer.deserialize(new DataInputViewStreamWrapper(bais));
 
-		assertEquals(expected, actual);
+			assertEquals(expected, actual);
+		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
index e38e5fb..1ae74ff 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
@@ -32,10 +32,10 @@ import org.junit.Test;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.concurrent.ThreadLocalRandom;
 
 import static org.junit.Assert.assertEquals;
@@ -66,35 +66,83 @@ public class SavepointV1Test {
 		assertTrue(savepoint.getTaskStates().isEmpty());
 	}
 
-	static Collection<TaskState> createTaskStates(int numTaskStates, int numSubtaskStates) throws IOException {
+	static Collection<TaskState> createTaskStates(int numTaskStates, int numSubtasksPerTask) throws IOException {
+
+		Random random = new Random(numTaskStates * 31 + numSubtasksPerTask);
+
 		List<TaskState> taskStates = new ArrayList<>(numTaskStates);
 
-		for (int i = 0; i < numTaskStates; i++) {
-			TaskState taskState = new TaskState(new JobVertexID(), numSubtaskStates, numSubtaskStates, 1);
-			for (int j = 0; j < numSubtaskStates; j++) {
-				StreamStateHandle stateHandle = new TestByteStreamStateHandleDeepCompare("a", "Hello".getBytes());
-				taskState.putState(i, new SubtaskState(
-						new ChainedStateHandle<>(Collections.singletonList(stateHandle)), 0));
-
-				stateHandle = new TestByteStreamStateHandleDeepCompare("b", "Beautiful".getBytes());
-				Map<String, long[]> offsetsMap = new HashMap<>();
-				offsetsMap.put("A", new long[]{0, 10, 20});
-				offsetsMap.put("B", new long[]{30, 40, 50});
-
-				OperatorStateHandle operatorStateHandle =
-						new OperatorStateHandle(stateHandle, offsetsMap);
-
-				taskState.putPartitionableState(
-						i,
-						new ChainedStateHandle<OperatorStateHandle>(
-								Collections.singletonList(operatorStateHandle)));
-			}
+		for (int stateIdx = 0; stateIdx < numTaskStates; ++stateIdx) {
+
+			int chainLength = 1 + random.nextInt(8);
+
+			TaskState taskState = new TaskState(new JobVertexID(), numSubtasksPerTask, 128, chainLength);
+
+			int noNonPartitionableStateAtIndex = random.nextInt(chainLength);
+			int noOperatorStateBackendAtIndex = random.nextInt(chainLength);
+			int noOperatorStateStreamAtIndex = random.nextInt(chainLength);
+
+			boolean hasKeyedBackend = random.nextInt(4) != 0;
+			boolean hasKeyedStream = random.nextInt(4) != 0;
+
+			for (int subtaskIdx = 0; subtaskIdx < numSubtasksPerTask; subtaskIdx++) {
 
-			taskState.putKeyedState(
-					0,
-					new KeyGroupsStateHandle(
+				List<StreamStateHandle> nonPartitionableStates = new ArrayList<>(chainLength);
+				List<OperatorStateHandle> operatorStatesBackend = new ArrayList<>(chainLength);
+				List<OperatorStateHandle> operatorStatesStream = new ArrayList<>(chainLength);
+
+				for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) {
+
+					StreamStateHandle nonPartitionableState =
+							new TestByteStreamStateHandleDeepCompare("a-" + chainIdx, ("Hi-" + chainIdx).getBytes());
+					StreamStateHandle operatorStateBackend =
+							new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes());
+					StreamStateHandle operatorStateStream =
+							new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes());
+					Map<String, long[]> offsetsMap = new HashMap<>();
+					offsetsMap.put("A", new long[]{0, 10, 20});
+					offsetsMap.put("B", new long[]{30, 40, 50});
+
+					if (chainIdx != noNonPartitionableStateAtIndex) {
+						nonPartitionableStates.add(nonPartitionableState);
+					}
+
+					if (chainIdx != noOperatorStateBackendAtIndex) {
+						OperatorStateHandle operatorStateHandleBackend =
+								new OperatorStateHandle(offsetsMap, operatorStateBackend);
+						operatorStatesBackend.add(operatorStateHandleBackend);
+					}
+
+					if (chainIdx != noOperatorStateStreamAtIndex) {
+						OperatorStateHandle operatorStateHandleStream =
+								new OperatorStateHandle(offsetsMap, operatorStateStream);
+						operatorStatesStream.add(operatorStateHandleStream);
+					}
+				}
+
+				KeyGroupsStateHandle keyedStateBackend = null;
+				KeyGroupsStateHandle keyedStateStream = null;
+
+				if (hasKeyedBackend) {
+					keyedStateBackend = new KeyGroupsStateHandle(
 							new KeyGroupRangeOffsets(1, 1, new long[]{42}),
-							new TestByteStreamStateHandleDeepCompare("c", "World".getBytes())));
+							new TestByteStreamStateHandleDeepCompare("c", "Hello".getBytes()));
+				}
+
+				if (hasKeyedStream) {
+					keyedStateStream = new KeyGroupsStateHandle(
+							new KeyGroupRangeOffsets(1, 1, new long[]{23}),
+							new TestByteStreamStateHandleDeepCompare("d", "World".getBytes()));
+				}
+
+				taskState.putState(subtaskIdx, new SubtaskState(
+						new ChainedStateHandle<>(nonPartitionableStates),
+						new ChainedStateHandle<>(operatorStatesBackend),
+						new ChainedStateHandle<>(operatorStatesStream),
+						keyedStateStream,
+						keyedStateBackend,
+						subtaskIdx * 10L));
+			}
 
 			taskStates.add(taskState);
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
index 2dac87f..50a59a5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
@@ -335,8 +335,7 @@ public class SimpleCheckpointStatsTrackerTest {
 					StreamStateHandle proxy = new StateHandleProxy(new Path(), proxySize);
 
 					SubtaskState subtaskState = new SubtaskState(
-						new ChainedStateHandle<>(Collections.singletonList(proxy)),
-						duration);
+							new ChainedStateHandle<>(Collections.singletonList(proxy)), null, null, null, null, duration);
 
 					taskState.putState(subtaskIndex, subtaskState);
 				}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index 5ec6991..b195858 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
 import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
 import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager;
 import org.apache.flink.runtime.executiongraph.restart.FixedDelayRestartStrategy;
@@ -57,10 +58,8 @@ import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService;
 import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService;
 import org.apache.flink.runtime.messages.JobManagerMessages;
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManager;
 import org.apache.flink.runtime.testingUtils.TestingJobManager;
@@ -85,7 +84,6 @@ import scala.concurrent.duration.FiniteDuration;
 
 import java.util.ArrayDeque;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -446,12 +444,10 @@ public class JobManagerHARecoveryTest {
 
 		@Override
 		public void setInitialState(
-				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState,
-				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
+				TaskStateHandles taskStateHandles) throws Exception {
 			int subtaskIndex = getIndexInSubtaskGroup();
 			if (subtaskIndex < recoveredStates.length) {
-				try (FSDataInputStream in = chainedState.get(0).openInputStream()) {
+				try (FSDataInputStream in = taskStateHandles.getLegacyOperatorState().get(0).openInputStream()) {
 					recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in, getUserCodeClassLoader());
 				}
 			}
@@ -466,9 +462,8 @@ public class JobManagerHARecoveryTest {
 
 				ChainedStateHandle<StreamStateHandle> chainedStateHandle =
 						new ChainedStateHandle<StreamStateHandle>(Collections.singletonList(byteStreamStateHandle));
-
-				CheckpointStateHandles checkpointStateHandles =
-						new CheckpointStateHandles(chainedStateHandle, null, Collections.<KeyGroupsStateHandle>emptyList());
+				SubtaskState checkpointStateHandles =
+						new SubtaskState(chainedStateHandle, null, null, null, null, 0L);
 
 				getEnvironment().acknowledgeCheckpoint(
 						new CheckpointMetaData(checkpointMetaData.getCheckpointId(), -1, 0L, 0L, 0L, 0L),

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index 305625e..3521630 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -23,12 +23,12 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.junit.Test;
@@ -67,11 +67,14 @@ public class CheckpointMessagesTest {
 
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42);
 
-			CheckpointStateHandles checkpointStateHandles =
-					new CheckpointStateHandles(
+			SubtaskState checkpointStateHandles =
+					new SubtaskState(
 							CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()),
-							CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8),
-							CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())));
+							CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
+							null,
+							CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
+							null,
+							0L);
 
 			AcknowledgeCheckpoint withState = new AcknowledgeCheckpoint(
 					new JobID(),

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index 04ba4e5..f2616b5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -26,6 +26,7 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -37,7 +38,6 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 
 import java.util.Collections;
@@ -155,8 +155,7 @@ public class DummyEnvironment implements Environment {
 	}
 
 	@Override
-	public void acknowledgeCheckpoint(
-			CheckpointMetaData checkpointMetaData, CheckpointStateHandles checkpointStateHandles) {
+	public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData, SubtaskState subtaskState) {
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index eb55c4d..08b84cb 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -28,6 +28,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -46,7 +47,6 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.apache.flink.types.Record;
 import org.apache.flink.util.MutableObjectIterator;
@@ -316,8 +316,7 @@ public class MockEnvironment implements Environment {
 	}
 
 	@Override
-	public void acknowledgeCheckpoint(
-			CheckpointMetaData checkpointMetaData, CheckpointStateHandles checkpointStateHandles) {
+	public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData, SubtaskState subtaskState) {
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java
index 95564cc..fb24712 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeOffsetTest.java
@@ -45,7 +45,7 @@ public class KeyGroupRangeOffsetTest {
 				keyGroupRangeOffsets.getKeyGroupRange()));
 
 		intersection = keyGroupRangeOffsets.getIntersection(KeyGroupRange.of(11, 13));
-		Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, intersection.getKeyGroupRange());
+		Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, intersection.getKeyGroupRange());
 		Assert.assertFalse(intersection.iterator().hasNext());
 
 		intersection = keyGroupRangeOffsets.getIntersection(KeyGroupRange.of(5, 13));
@@ -129,7 +129,7 @@ public class KeyGroupRangeOffsetTest {
 			Assert.assertFalse(keyGroupRange.getKeyGroupRange().contains(startKeyGroup - 1));
 			Assert.assertFalse(keyGroupRange.getKeyGroupRange().contains(endKeyGroup + 1));
 		} else {
-			Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, keyGroupRange);
+			Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, keyGroupRange);
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java
index ab0c327..94350ad 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupRangeTest.java
@@ -37,7 +37,7 @@ public class KeyGroupRangeTest {
 		keyGroupRange1 = KeyGroupRange.of(0,5);
 		keyGroupRange2 = KeyGroupRange.of(6,10);
 		intersection =keyGroupRange1.getIntersection(keyGroupRange2);
-		Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, intersection);
+		Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, intersection);
 		Assert.assertEquals(intersection, keyGroupRange2.getIntersection(keyGroupRange1));
 
 		keyGroupRange1 = KeyGroupRange.of(0, 10);
@@ -93,7 +93,7 @@ public class KeyGroupRangeTest {
 			Assert.assertFalse(keyGroupRange.contains(startKeyGroup - 1));
 			Assert.assertFalse(keyGroupRange.contains(endKeyGroup + 1));
 		} else {
-			Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP, keyGroupRange);
+			Assert.assertEquals(KeyGroupRange.EMPTY_KEY_GROUP_RANGE, keyGroupRange);
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
new file mode 100644
index 0000000..0c4ed74
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStreamTest.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+/**
+ * Tests for {@link KeyedStateCheckpointOutputStream}: close semantics, key-group bookkeeping
+ * and write/read round trips through the resulting {@link KeyGroupsStateHandle}.
+ */
+public class KeyedStateCheckpointOutputStreamTest {
+
+	/** Capacity handed to the in-memory stream that backs every stream under test. */
+	private static final int STREAM_CAPACITY = 128;
+
+	/** Wraps a fresh in-memory test stream into a keyed stream for the given key-group range. */
+	private static KeyedStateCheckpointOutputStream createStream(KeyGroupRange keyGroupRange) {
+		CheckpointStreamFactory.CheckpointStateOutputStream checkStream =
+				new TestMemoryCheckpointOutputStream(STREAM_CAPACITY);
+		return new KeyedStateCheckpointOutputStream(checkStream, keyGroupRange);
+	}
+
+	/**
+	 * Starts every key-group of {@code keyRange} in iteration order, writes the key-group id as
+	 * its payload and closes the stream.
+	 *
+	 * @return the handle returned by {@code closeAndGetHandle()}
+	 */
+	private static KeyGroupsStateHandle writeAllTestKeyGroups(
+			KeyedStateCheckpointOutputStream stream, KeyGroupRange keyRange) throws Exception {
+
+		DataOutputView dov = new DataOutputViewStreamWrapper(stream);
+		for (int kg : keyRange) {
+			stream.startNewKeyGroup(kg);
+			dov.writeInt(kg);
+		}
+
+		return stream.closeAndGetHandle();
+	}
+
+	/** {@code close()} on the wrapper must not close the delegate stream. */
+	@Test
+	public void testCloseNotPropagated() throws Exception {
+		KeyedStateCheckpointOutputStream stream = createStream(new KeyGroupRange(0, 0));
+		TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate();
+		stream.close();
+		Assert.assertFalse(innerStream.isClosed());
+		// the delegate was deliberately left open by the wrapper, so release it here
+		innerStream.close();
+	}
+
+	/** Without any started key-group, closing closes the delegate and yields a null handle. */
+	@Test
+	public void testEmptyKeyedStream() throws Exception {
+		final KeyGroupRange keyRange = new KeyGroupRange(0, 2);
+		KeyedStateCheckpointOutputStream stream = createStream(keyRange);
+		TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate();
+		KeyGroupsStateHandle emptyHandle = stream.closeAndGetHandle();
+		Assert.assertTrue(innerStream.isClosed());
+		Assert.assertNull(emptyHandle);
+	}
+
+	/** Everything written for a full key-group range can be read back via the handle. */
+	@Test
+	public void testWriteReadRoundtrip() throws Exception {
+		final KeyGroupRange keyRange = new KeyGroupRange(0, 2);
+		KeyedStateCheckpointOutputStream stream = createStream(keyRange);
+		KeyGroupsStateHandle fullHandle = writeAllTestKeyGroups(stream, keyRange);
+		Assert.assertNotNull(fullHandle);
+
+		verifyRead(fullHandle, keyRange);
+	}
+
+	/**
+	 * Checks the started/finished flags while key-groups are written one after another, and that
+	 * a key-group outside the range or a key-group started after closing is rejected.
+	 */
+	@Test
+	public void testWriteKeyGroupTracking() throws Exception {
+		final KeyGroupRange keyRange = new KeyGroupRange(0, 2);
+		KeyedStateCheckpointOutputStream stream = createStream(keyRange);
+
+		try {
+			stream.startNewKeyGroup(4711);
+			Assert.fail("Starting a key-group outside of the range must fail.");
+		} catch (IllegalArgumentException expected) {
+			// good
+		}
+
+		// after the rejected call the stream must still report -1 as its current key-group
+		Assert.assertEquals(-1, stream.getCurrentKeyGroup());
+
+		DataOutputView dov = new DataOutputViewStreamWrapper(stream);
+		int previous = -1;
+		for (int kg : keyRange) {
+			Assert.assertFalse(stream.isKeyGroupAlreadyStarted(kg));
+			Assert.assertFalse(stream.isKeyGroupAlreadyFinished(kg));
+			stream.startNewKeyGroup(kg);
+			if (-1 != previous) {
+				// starting the next key-group finishes the one before it
+				Assert.assertTrue(stream.isKeyGroupAlreadyStarted(previous));
+				Assert.assertTrue(stream.isKeyGroupAlreadyFinished(previous));
+			}
+			Assert.assertTrue(stream.isKeyGroupAlreadyStarted(kg));
+			Assert.assertFalse(stream.isKeyGroupAlreadyFinished(kg));
+			dov.writeInt(kg);
+			previous = kg;
+		}
+
+		KeyGroupsStateHandle fullHandle = stream.closeAndGetHandle();
+
+		verifyRead(fullHandle, keyRange);
+
+		for (int kg : keyRange) {
+			try {
+				stream.startNewKeyGroup(kg);
+				Assert.fail("Starting a key-group on a closed stream must fail.");
+			} catch (IOException ex) {
+				// required
+			}
+		}
+	}
+
+	/**
+	 * Writes only key-group 1 of the range created as {@code KeyGroupRange(0, 2)} and expects
+	 * exactly one key-group of the handle to report a non-negative offset, with the written
+	 * value readable at that offset.
+	 */
+	@Test
+	public void testReadWriteMissingKeyGroups() throws Exception {
+		final KeyGroupRange keyRange = new KeyGroupRange(0, 2);
+		KeyedStateCheckpointOutputStream stream = createStream(keyRange);
+
+		DataOutputView dov = new DataOutputViewStreamWrapper(stream);
+		stream.startNewKeyGroup(1);
+		dov.writeInt(1);
+
+		KeyGroupsStateHandle fullHandle = stream.closeAndGetHandle();
+		Assert.assertNotNull(fullHandle);
+
+		int count = 0;
+		try (FSDataInputStream in = fullHandle.openInputStream()) {
+			DataInputView div = new DataInputViewStreamWrapper(in);
+			for (int kg : fullHandle.keyGroups()) {
+				long off = fullHandle.getOffsetForKeyGroup(kg);
+				// this test treats a negative offset as "key-group not written"
+				if (off >= 0) {
+					in.seek(off);
+					Assert.assertEquals(1, div.readInt());
+					++count;
+				}
+			}
+		}
+
+		Assert.assertEquals(1, count);
+	}
+
+	/**
+	 * Reads one int at the offset of every key-group in the handle, expects it to equal the
+	 * key-group id, and expects as many key-groups as {@code keyRange} contains.
+	 */
+	private static void verifyRead(KeyGroupsStateHandle fullHandle, KeyGroupRange keyRange) throws IOException {
+		Assert.assertNotNull(fullHandle);
+
+		int count = 0;
+		try (FSDataInputStream in = fullHandle.openInputStream()) {
+			DataInputView div = new DataInputViewStreamWrapper(in);
+			for (int kg : fullHandle.keyGroups()) {
+				long off = fullHandle.getOffsetForKeyGroup(kg);
+				in.seek(off);
+				Assert.assertEquals(kg, div.readInt());
+				++count;
+			}
+		}
+
+		Assert.assertEquals(keyRange.getNumberOfKeyGroups(), count);
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
new file mode 100644
index 0000000..c6ef0f0
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+/**
+ * Tests for {@link OperatorStateCheckpointOutputStream}: close semantics, partition counting
+ * and a write/read round trip through the resulting {@link OperatorStateHandle}.
+ */
+public class OperatorStateOutputCheckpointStreamTest {
+
+	/** Capacity handed to the in-memory stream that backs every stream under test. */
+	private static final int STREAM_CAPACITY = 128;
+
+	/** Wraps a fresh in-memory test stream into an operator state stream. */
+	private static OperatorStateCheckpointOutputStream createStream() throws IOException {
+		CheckpointStreamFactory.CheckpointStateOutputStream checkStream =
+				new TestMemoryCheckpointOutputStream(STREAM_CAPACITY);
+		return new OperatorStateCheckpointOutputStream(checkStream);
+	}
+
+	/**
+	 * Starts {@code numPartitions} partitions, writes the partition index as payload of each and
+	 * closes the stream. Before a partition is started, the stream has to report exactly the
+	 * number of partitions started so far.
+	 *
+	 * @return the handle returned by {@code closeAndGetHandle()}
+	 */
+	private static OperatorStateHandle writeAllTestPartitions(
+			OperatorStateCheckpointOutputStream stream, int numPartitions) throws Exception {
+
+		DataOutputView dov = new DataOutputViewStreamWrapper(stream);
+		for (int i = 0; i < numPartitions; ++i) {
+			Assert.assertEquals(i, stream.getNumberOfPartitions());
+			stream.startNewPartition();
+			dov.writeInt(i);
+		}
+
+		return stream.closeAndGetHandle();
+	}
+
+	/** {@code close()} on the wrapper must not close the delegate stream. */
+	@Test
+	public void testCloseNotPropagated() throws Exception {
+		OperatorStateCheckpointOutputStream stream = createStream();
+		TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate();
+		stream.close();
+		Assert.assertFalse(innerStream.isClosed());
+		// the delegate was deliberately left open by the wrapper, so release it here
+		innerStream.close();
+	}
+
+	/** Without any partition, closing closes the delegate and yields a null handle. */
+	@Test
+	public void testEmptyOperatorStream() throws Exception {
+		OperatorStateCheckpointOutputStream stream = createStream();
+		TestMemoryCheckpointOutputStream innerStream = (TestMemoryCheckpointOutputStream) stream.getDelegate();
+		OperatorStateHandle emptyHandle = stream.closeAndGetHandle();
+		Assert.assertTrue(innerStream.isClosed());
+		Assert.assertEquals(0, stream.getNumberOfPartitions());
+		Assert.assertNull(emptyHandle);
+	}
+
+	/** All partitions written to the stream can be read back via the handle. */
+	@Test
+	public void testWriteReadRoundtrip() throws Exception {
+		int numPartitions = 3;
+		OperatorStateCheckpointOutputStream stream = createStream();
+		OperatorStateHandle fullHandle = writeAllTestPartitions(stream, numPartitions);
+		Assert.assertNotNull(fullHandle);
+
+		verifyRead(fullHandle, numPartitions);
+	}
+
+	/**
+	 * Looks up the offsets registered under the default operator state name, seeks to each of
+	 * the first {@code numPartitions} offsets and expects to read the partition index there.
+	 */
+	private static void verifyRead(OperatorStateHandle fullHandle, int numPartitions) throws IOException {
+		int count = 0;
+		try (FSDataInputStream in = fullHandle.openInputStream()) {
+			long[] offsets = fullHandle.getStateNameToPartitionOffsets().
+					get(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
+
+			Assert.assertNotNull(offsets);
+			// fail with a readable message instead of an ArrayIndexOutOfBoundsException below
+			Assert.assertTrue(
+					"Expected at least " + numPartitions + " offsets but found " + offsets.length,
+					offsets.length >= numPartitions);
+
+			DataInputView div = new DataInputViewStreamWrapper(in);
+			for (int i = 0; i < numPartitions; ++i) {
+				in.seek(offsets[i]);
+				Assert.assertEquals(i, div.readInt());
+				++count;
+			}
+		}
+
+		Assert.assertEquals(numPartitions, count);
+	}
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 2f21574..9e835ce 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -38,7 +38,7 @@ import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.runtime.checkpoint.CheckpointCoordinator;
+import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.KvStateID;
@@ -707,11 +707,11 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(0, 0, streamFactory));
 
-		List<KeyGroupsStateHandle> firstHalfKeyGroupStates = CheckpointCoordinator.getKeyGroupsStateHandles(
+		List<KeyGroupsStateHandle> firstHalfKeyGroupStates = StateAssignmentOperation.getKeyGroupsStateHandles(
 				Collections.singletonList(snapshot),
 				KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 0));
 
-		List<KeyGroupsStateHandle> secondHalfKeyGroupStates = CheckpointCoordinator.getKeyGroupsStateHandles(
+		List<KeyGroupsStateHandle> secondHalfKeyGroupStates = StateAssignmentOperation.getKeyGroupsStateHandles(
 				Collections.singletonList(snapshot),
 				KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 1));
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestMemoryCheckpointOutputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestMemoryCheckpointOutputStream.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestMemoryCheckpointOutputStream.java
new file mode 100644
index 0000000..5accc19
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestMemoryCheckpointOutputStream.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
+
+import java.io.IOException;
+
+/** Test-only in-memory checkpoint stream that remembers whether close() or closeAndGetHandle() was called. */
+final class TestMemoryCheckpointOutputStream extends MemCheckpointStreamFactory.MemoryCheckpointOutputStream {
+	private boolean closeRequested = false;
+
+	public TestMemoryCheckpointOutputStream(int maxSize) {
+		super(maxSize);
+	}
+
+	/** Returns true once close() or closeAndGetHandle() has been invoked on this stream. */
+	public boolean isClosed() {
+		return closeRequested;
+	}
+
+	@Override
+	public void close() {
+		closeRequested = true;
+		super.close();
+	}
+
+	@Override
+	public StreamStateHandle closeAndGetHandle() throws IOException {
+		closeRequested = true;
+		return super.closeAndGetHandle();
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index e2abe88..7dd67ed 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -48,6 +48,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.SerializedValue;
 import org.junit.Before;
 import org.junit.Test;
@@ -209,9 +210,7 @@ public class TaskAsyncCallTest {
 		}
 
 		@Override
-		public void setInitialState(ChainedStateHandle<StreamStateHandle> chainedState,
-									List<KeyGroupsStateHandle> keyGroupsState,
-									List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
+		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {
 
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java b/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java
index ac1e3f0..0c0111c 100644
--- a/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java
+++ b/flink-streaming-connectors/flink-connector-filesystem/src/test/java/org/apache/flink/streaming/connectors/fs/bucketing/BucketingSinkTest.java
@@ -137,7 +137,7 @@ public class BucketingSinkTest {
 
 		// snapshot but don't call notify to simulate a notify that never
 		// arrives, the sink should move pending files in restore() in that case
-		StreamStateHandle snapshot1 = testHarness.snapshot(0, 0);
+		StreamStateHandle snapshot1 = testHarness.snapshotLegacy(0, 0);
 
 		testHarness = createTestSink(dataDir, clock);
 		testHarness.setup();

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
index 7d6bd76..db092f0 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
@@ -19,13 +19,16 @@ package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.commons.collections.map.LinkedMap;
 import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.ClosureCleaner;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.state.CheckpointListener;
-import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
@@ -37,11 +40,9 @@ import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition
 import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
 import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
 import org.apache.flink.util.SerializedValue;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
@@ -97,7 +98,7 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	 * The assigner is kept in serialized form, to deserialize it into multiple copies */
 	private SerializedValue<AssignerWithPunctuatedWatermarks<T>> punctuatedWatermarkAssigner;
 
-	private transient OperatorStateStore stateStore;
+	private transient ListState<Tuple2<KafkaTopicPartition, Long>> offsetsStateForCheckpoint;
 
 	// ------------------------------------------------------------------------
 	//  runtime state (used individually by each parallel subtask) 
@@ -311,33 +312,33 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	// ------------------------------------------------------------------------
 
 	@Override
-	public void initializeState(OperatorStateStore stateStore) throws Exception {
-
-		this.stateStore = stateStore;
+	public void initializeState(FunctionInitializationContext context) throws Exception {
 
-		ListState<Serializable> offsets =
-				stateStore.getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME);
+		OperatorStateStore stateStore = context.getManagedOperatorStateStore();
+		offsetsStateForCheckpoint = stateStore.getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
 
-		restoreToOffset = new HashMap<>();
+		if (context.isRestored()) {
+			restoreToOffset = new HashMap<>();
+			for (Tuple2<KafkaTopicPartition, Long> kafkaOffset : offsetsStateForCheckpoint.get()) {
+				restoreToOffset.put(kafkaOffset.f0, kafkaOffset.f1);
+			}
 
-		for (Serializable serializable : offsets.get()) {
-			@SuppressWarnings("unchecked")
-			Tuple2<KafkaTopicPartition, Long> kafkaOffset = (Tuple2<KafkaTopicPartition, Long>) serializable;
-			restoreToOffset.put(kafkaOffset.f0, kafkaOffset.f1);
+			LOG.info("Setting restore state in the FlinkKafkaConsumer.");
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("Using the following offsets: {}", restoreToOffset);
+			}
+		} else {
+			LOG.info("No restore state for FlinkKafkaConsumer.");
 		}
-
-		LOG.info("Setting restore state in the FlinkKafkaConsumer: {}", restoreToOffset);
 	}
 
 	@Override
-	public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
+	public void snapshotState(FunctionSnapshotContext context) throws Exception {
 		if (!running) {
-			LOG.debug("storeOperatorState() called on closed source");
+			LOG.debug("snapshotState() called on closed source");
 		} else {
 
-			ListState<Serializable> listState =
-					stateStore.getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME);
-			listState.clear();
+			offsetsStateForCheckpoint.clear();
 
 			final AbstractFetcher<?, ?> fetcher = this.kafkaFetcher;
 			if (fetcher == null) {
@@ -347,14 +348,16 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 				if (restoreToOffset != null) {
 					// the map cannot be asynchronously updated, because only one checkpoint call can happen
 					// on this function at a time: either snapshotState() or notifyCheckpointComplete()
-					pendingOffsetsToCommit.put(checkpointId, restoreToOffset);
+					pendingOffsetsToCommit.put(context.getCheckpointId(), restoreToOffset);
 
 					for (Map.Entry<KafkaTopicPartition, Long> kafkaTopicPartitionLongEntry : restoreToOffset.entrySet()) {
-						listState.add(Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue()));
+						offsetsStateForCheckpoint.add(
+								Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue()));
 					}
 				} else if (subscribedPartitions != null) {
 					for (KafkaTopicPartition subscribedPartition : subscribedPartitions) {
-						listState.add(Tuple2.of(subscribedPartition, KafkaTopicPartitionState.OFFSET_NOT_SET));
+						offsetsStateForCheckpoint.add(
+								Tuple2.of(subscribedPartition, KafkaTopicPartitionState.OFFSET_NOT_SET));
 					}
 				}
 			} else {
@@ -362,10 +365,11 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 
 				// the map cannot be asynchronously updated, because only one checkpoint call can happen
 				// on this function at a time: either snapshotState() or notifyCheckpointComplete()
-				pendingOffsetsToCommit.put(checkpointId, currentOffsets);
+				pendingOffsetsToCommit.put(context.getCheckpointId(), currentOffsets);
 
 				for (Map.Entry<KafkaTopicPartition, Long> kafkaTopicPartitionLongEntry : currentOffsets.entrySet()) {
-					listState.add(Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue()));
+					offsetsStateForCheckpoint.add(
+							Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue()));
 				}
 			}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
index 26a695e..bede064 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
@@ -18,10 +18,12 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.metrics.MetricGroup;
-import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.util.SerializableObject;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
@@ -330,12 +332,12 @@ public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> im
 	protected abstract void flush();
 
 	@Override
-	public void initializeState(OperatorStateStore stateStore) throws Exception {
-		this.stateStore = stateStore;
+	public void initializeState(FunctionInitializationContext context) throws Exception {
+		this.stateStore = context.getManagedOperatorStateStore();
 	}
 
 	@Override
-	public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
+	public void snapshotState(FunctionSnapshotContext ctx) throws Exception {
 		if (flushOnCheckpoint) {
 			// flushing is activated: We need to wait until pendingRecords is 0
 			flush();


[3/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State

Posted by al...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java
new file mode 100644
index 0000000..f5d633c
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.blob.BlobKey;
+import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
+import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.filecache.FileCache;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.io.network.NetworkEnvironment;
+import org.apache.flink.runtime.io.network.netty.PartitionStateChecker;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
+import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
+import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.taskmanager.CheckpointResponder;
+import org.apache.flink.runtime.taskmanager.Task;
+import org.apache.flink.runtime.taskmanager.TaskManagerConnection;
+import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.SourceStreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.util.SerializedValue;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.net.URL;
+import java.util.Collections;
+import java.util.concurrent.Executor;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class AbstractUdfStreamOperatorTest {
+	/** Starts a SourceStreamTask around a LifecycleTrackingStreamSource and expects the task to end in FINISHED; lifecycle calls are only printed, their order is not asserted. */
+	@Test
+	public void testLifeCycle() throws Exception {
+
+		Configuration taskManagerConfig = new Configuration();
+
+		StreamConfig cfg = new StreamConfig(new Configuration());
+		cfg.setStreamOperator(new LifecycleTrackingStreamSource(new MockSourceFunction()));
+		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+		Task task = createTask(SourceStreamTask.class, cfg, taskManagerConfig);
+
+		task.startTaskThread();
+
+		// wait for clean termination
+		task.getExecutingThread().join();
+		assertEquals(ExecutionState.FINISHED, task.getExecutionState());
+	}
+
+	private static class MockSourceFunction extends RichSourceFunction<Long> {
+		// run() and cancel() are no-ops; setRuntimeContext/open/close print a '!'-prefixed marker, then delegate to super.
+		private static final long serialVersionUID = 1L;
+
+		@Override
+		public void run(SourceContext<Long> ctx) {
+		}
+
+		@Override
+		public void cancel() {
+		}
+
+		@Override
+		public void setRuntimeContext(RuntimeContext t) {
+			System.out.println("!setRuntimeContext");
+			super.setRuntimeContext(t);
+		}
+
+		@Override
+		public void open(Configuration parameters) throws Exception {
+			System.out.println("!open");
+			super.open(parameters);
+		}
+
+		@Override
+		public void close() throws Exception {
+			System.out.println("!close");
+			super.close();
+		}
+	}
+
+	private Task createTask(
+			Class<? extends AbstractInvokable> invokable,
+			StreamConfig taskConfig,
+			Configuration taskManagerConfig) throws Exception {
+		// Most collaborators are Mockito mocks; the mocked library cache is stubbed to return this test's class loader.
+		LibraryCacheManager libCache = mock(LibraryCacheManager.class);
+		when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader());
+
+		ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
+		ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class);
+		PartitionStateChecker partitionStateChecker = mock(PartitionStateChecker.class);
+		Executor executor = mock(Executor.class);
+
+		NetworkEnvironment network = mock(NetworkEnvironment.class);
+		when(network.getResultPartitionManager()).thenReturn(partitionManager);
+		when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC);
+		when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class)))
+				.thenReturn(mock(TaskKvStateRegistry.class));
+
+		TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(
+				new JobID(), "Job Name", new JobVertexID(), new ExecutionAttemptID(),
+				new SerializedValue<>(new ExecutionConfig()),
+				"Test Task", 1, 0, 1, 0,
+				new Configuration(),
+				taskConfig.getConfiguration(),
+				invokable.getName(),
+				Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
+				Collections.<InputGateDeploymentDescriptor>emptyList(),
+				Collections.<BlobKey>emptyList(),
+				Collections.<URL>emptyList(),
+				0);
+
+		return new Task(
+				tdd,
+				mock(MemoryManager.class),
+				mock(IOManager.class),
+				network,
+				mock(BroadcastVariableManager.class),
+				mock(TaskManagerConnection.class),
+				mock(InputSplitProvider.class),
+				mock(CheckpointResponder.class),
+				libCache,
+				mock(FileCache.class),
+				new TaskManagerRuntimeInfo("localhost", taskManagerConfig, System.getProperty("java.io.tmpdir")),
+				new UnregisteredTaskMetricsGroup(),
+				consumableNotifier,
+				partitionStateChecker,
+				executor);
+	}
+
+	static class LifecycleTrackingStreamSource<OUT, SRC extends SourceFunction<OUT>>
+			extends StreamSource<OUT, SRC> implements Serializable {
+		// Prints the name of each overridden lifecycle method; dispose() prints after delegating to super, the others before.
+		// NOTE(review): calls are only printed, not recorded, so the test cannot yet assert their order.
+
+		private static final long serialVersionUID = 2431488948886850562L;
+
+		public LifecycleTrackingStreamSource(SRC sourceFunction) {
+			super(sourceFunction);
+		}
+
+		@Override
+		public void setup(StreamTask<?, ?> containingTask, StreamConfig config, Output<StreamRecord<OUT>> output) {
+			System.out.println("setup");
+			super.setup(containingTask, config, output);
+		}
+
+		@Override
+		public void snapshotState(StateSnapshotContext context) throws Exception {
+			System.out.println("snapshotState");
+			super.snapshotState(context);
+		}
+
+		@Override
+		public void initializeState(StateInitializationContext context) throws Exception {
+			System.out.println("initializeState");
+			super.initializeState(context);
+		}
+
+		@Override
+		public void open() throws Exception {
+			System.out.println("open");
+			super.open();
+		}
+
+		@Override
+		public void close() throws Exception {
+			System.out.println("close");
+			super.close();
+		}
+
+		@Override
+		public void dispose() throws Exception {
+			super.dispose();
+			System.out.println("dispose");
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
new file mode 100644
index 0000000..75c2261
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.state.ClosableRegistry;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContextImpl;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.util.LongArrayList;
+import org.apache.flink.util.Preconditions;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.mockito.Mockito.mock;
+
+public class StateInitializationContextImplTest {
+	// Number of keyed-state handles and of operator-state handles that setUp() creates.
+	static final int NUM_HANDLES = 10;
+
+	private StateInitializationContextImpl initializationContext;
+	private ClosableRegistry closableRegistry;
+	// Filled by setUp(): count of key-group ints written, and the set of operator-state ints written.
+	private int writtenKeyGroups;
+	private Set<Integer> writtenOperatorStates;
+
+	@Before
+	public void setUp() throws Exception {
+		// Builds in-memory raw keyed/operator state handles and hands them to the context under test.
+
+		this.writtenKeyGroups = 0;
+		this.writtenOperatorStates = new HashSet<>();
+
+		this.closableRegistry = new ClosableRegistry();
+		OperatorStateStore stateStore = mock(OperatorStateStore.class);
+
+		ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(64);
+		// Keyed handle i covers KeyGroupRange(prev, prev + i % 4) (the empty range for i == 9) and stores each key-group id as an int at its recorded offset.
+		List<KeyGroupsStateHandle> keyGroupsStateHandles = new ArrayList<>(NUM_HANDLES);
+		int prev = 0;
+		for (int i = 0; i < NUM_HANDLES; ++i) {
+			out.reset();
+			int size = i % 4;
+			int end = prev + size;
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+			KeyGroupRangeOffsets offsets =
+					new KeyGroupRangeOffsets(i == 9 ? KeyGroupRange.EMPTY_KEY_GROUP_RANGE : new KeyGroupRange(prev, end));
+			prev = end + 1;
+			for (int kg : offsets.getKeyGroupRange()) {
+				offsets.setKeyGroupOffset(kg, out.getPosition());
+				dov.writeInt(kg);
+				++writtenKeyGroups;
+			}
+
+			KeyGroupsStateHandle handle =
+					new KeyGroupsStateHandle(offsets, new ByteStateHandleCloseChecking("kg-" + i, out.toByteArray()));
+
+			keyGroupsStateHandles.add(handle);
+		}
+
+		List<OperatorStateHandle> operatorStateHandles = new ArrayList<>(NUM_HANDLES);
+		// Operator handle i stores i % 4 ints of value i * NUM_HANDLES + s, with one offset per int under the default state name.
+		for (int i = 0; i < NUM_HANDLES; ++i) {
+			int size = i % 4;
+			out.reset();
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+			LongArrayList offsets = new LongArrayList(size);
+			for (int s = 0; s < size; ++s) {
+				offsets.add(out.getPosition());
+				int val = i * NUM_HANDLES + s;
+				dov.writeInt(val);
+				writtenOperatorStates.add(val);
+			}
+
+			Map<String, long[]> offsetsMap = new HashMap<>();
+			offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, offsets.toArray());
+			OperatorStateHandle operatorStateHandle =
+					new OperatorStateHandle(offsetsMap, new ByteStateHandleCloseChecking("os-" + i, out.toByteArray()));
+			operatorStateHandles.add(operatorStateHandle);
+		}
+
+		this.initializationContext =
+				new StateInitializationContextImpl(
+						true,
+						stateStore,
+						mock(KeyedStateStore.class),
+						keyGroupsStateHandles,
+						operatorStateHandles,
+						closableRegistry);
+	}
+
+	@Test
+	public void getOperatorStateStreams() throws Exception {
+		// NOTE(review): empty body - this test asserts nothing and always passes; TODO implement or remove.
+	}
+
+	@Test
+	public void getKeyedStateStreams() throws Exception {
+		// Each raw keyed-state stream must start with its own key-group id; stream count must equal key groups written.
+		int readKeyGroupCount = 0;
+
+		for (KeyGroupStatePartitionStreamProvider stateStreamProvider
+				: initializationContext.getRawKeyedStateInputs()) {
+
+			Assert.assertNotNull(stateStreamProvider);
+
+			try (InputStream is = stateStreamProvider.getStream()) {
+				DataInputView div = new DataInputViewStreamWrapper(is);
+				int val = div.readInt();
+				++readKeyGroupCount;
+				Assert.assertEquals(stateStreamProvider.getKeyGroupId(), val);
+			}
+		}
+
+		Assert.assertEquals(writtenKeyGroups, readKeyGroupCount);
+	}
+
+	@Test
+	public void getOperatorStateStore() throws Exception {
+		// NOTE(review): despite the name, this reads getRawOperatorStateInputs(): ints must be distinct and equal the written set.
+		Set<Integer> readStatesCount = new HashSet<>();
+
+		for (StatePartitionStreamProvider statePartitionStreamProvider
+				: initializationContext.getRawOperatorStateInputs()) {
+
+			Assert.assertNotNull(statePartitionStreamProvider);
+
+			try (InputStream is = statePartitionStreamProvider.getStream()) {
+				DataInputView div = new DataInputViewStreamWrapper(is);
+				Assert.assertTrue(readStatesCount.add(div.readInt()));
+			}
+		}
+
+		Assert.assertEquals(writtenOperatorStates, readStatesCount);
+	}
+
+	@Test
+	public void close() throws Exception {
+		// Closes the context after stopCount reads; passes only if an IOException then reaches the outer catch with count == stopCount.
+		int count = 0;
+		int stopCount = NUM_HANDLES / 2;
+		boolean isClosed = false;
+
+
+		try {
+			for (KeyGroupStatePartitionStreamProvider stateStreamProvider
+					: initializationContext.getRawKeyedStateInputs()) {
+				Assert.assertNotNull(stateStreamProvider);
+
+				if (count == stopCount) {
+					initializationContext.close();
+					isClosed = true;
+				}
+
+				try (InputStream is = stateStreamProvider.getStream()) {
+					DataInputView div = new DataInputViewStreamWrapper(is);
+					try {
+						int val = div.readInt();
+						Assert.assertEquals(stateStreamProvider.getKeyGroupId(), val);
+						if (isClosed) {
+							Assert.fail("Close was ignored: stream");
+						}
+						++count;
+					} catch (IOException ioex) {
+						if (!isClosed) {
+							throw ioex;
+						}
+					}
+				}
+			}
+			Assert.fail("Close was ignored: registry");
+		} catch (IOException iex) {
+			Assert.assertTrue(isClosed);
+			Assert.assertEquals(stopCount, count);
+		}
+
+	}
+
+	static final class ByteStateHandleCloseChecking extends ByteStreamStateHandle {
+		// Handle whose input stream throws IOException("Stream closed") from read() after close() was called on it.
+		private static final long serialVersionUID = -6201941296931334140L;
+
+		public ByteStateHandleCloseChecking(String handleName, byte[] data) {
+			super(handleName, data);
+		}
+
+		@Override
+		public FSDataInputStream openInputStream() throws IOException {
+			return new FSDataInputStream() {
+				private int index = 0;
+				private boolean closed = false;
+
+				@Override
+				public void seek(long desired) throws IOException {
+					Preconditions.checkArgument(desired >= 0 && desired < Integer.MAX_VALUE);
+					index = (int) desired;
+				}
+
+				@Override
+				public long getPos() throws IOException {
+					return index;
+				}
+
+				@Override
+				public int read() throws IOException {
+					if (closed) {
+						throw new IOException("Stream closed");
+					}
+					return index < data.length ? data[index++] & 0xFF : -1;
+				}
+
+				@Override
+				public void close() throws IOException {
+					super.close();
+					this.closed = true;
+				}
+			};
+		}
+	}
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateSnapshotContextSynchronousImplTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateSnapshotContextSynchronousImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateSnapshotContextSynchronousImplTest.java
new file mode 100644
index 0000000..0ee839e
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateSnapshotContextSynchronousImplTest.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.ClosableRegistry;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class StateSnapshotContextSynchronousImplTest {
+	// Checks the metadata getters and the raw-state output stream getters of StateSnapshotContextSynchronousImpl.
+	private StateSnapshotContextSynchronousImpl snapshotContext; // instance under test, assigned in setUp()
+	// Builds the context from 42, 4711, a MemCheckpointStreamFactory(1024), KeyGroupRange(0, 2) and a new ClosableRegistry.
+	@Before
+	public void setUp() throws Exception {
+		ClosableRegistry closableRegistry = new ClosableRegistry();
+		CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(1024);
+		KeyGroupRange keyGroupRange = new KeyGroupRange(0, 2);
+		this.snapshotContext = new StateSnapshotContextSynchronousImpl(42, 4711, streamFactory, keyGroupRange, closableRegistry);
+	}
+	// The first two constructor arguments must come back as checkpoint id and checkpoint timestamp.
+	@Test
+	public void testMetaData() {
+		Assert.assertEquals(42, snapshotContext.getCheckpointId());
+		Assert.assertEquals(4711, snapshotContext.getCheckpointTimestamp());
+	}
+	// Requesting the raw keyed state output must return a non-null stream.
+	@Test
+	public void testCreateRawKeyedStateOutput() throws Exception {
+		KeyedStateCheckpointOutputStream stream = snapshotContext.getRawKeyedOperatorStateOutput();
+		Assert.assertNotNull(stream);
+	}
+	// Requesting the raw operator state output must return a non-null stream.
+	@Test
+	public void testCreateRawOperatorStateOutput() throws Exception {
+		OperatorStateCheckpointOutputStream stream = snapshotContext.getRawOperatorStateOutput();
+		Assert.assertNotNull(stream);
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java
new file mode 100644
index 0000000..ada0b86
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
+import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
+import org.apache.flink.util.FutureUtil;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.InputStream;
+import java.util.BitSet;
+import java.util.Collections;
+
+public class StreamOperatorSnapshotRestoreTest {
+	// Round trip: snapshot an operator's managed and raw keyed/operator state, then feed the handles to a second operator that asserts on them.
+	@Test
+	public void testOperatorStatesSnapshotRestore() throws Exception {
+
+		//-------------------------------------------------------------------------- snapshot
+
+		TestOneInputStreamOperator op = new TestOneInputStreamOperator(false);
+		// The key selector returns the element itself, so every value is its own key.
+		KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
+				new KeyedOneInputStreamOperatorTestHarness<>(op, new KeySelector<Integer, Integer>() {
+					@Override
+					public Integer getKey(Integer value) throws Exception {
+						return value;
+					}
+				}, TypeInformation.of(Integer.class));
+
+		testHarness.open();
+		// Feed the values 0..9; in write mode the operator stores value + 1 per key and appends the value to the list state.
+		for (int i = 0; i < 10; ++i) {
+			testHarness.processElement(new StreamRecord<>(i));
+		}
+
+		OperatorSnapshotResult snapshotInProgress = testHarness.snapshot(1L, 1L);
+		// Resolve the four futures of the snapshot result into state handles.
+		KeyGroupsStateHandle keyedManaged =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateManagedFuture());
+		KeyGroupsStateHandle keyedRaw =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateRawFuture());
+
+		OperatorStateHandle opManaged =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture());
+		OperatorStateHandle opRaw =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture());
+
+		testHarness.close();
+
+		//-------------------------------------------------------------------------- restore
+		// A second operator instance and harness, this time in verify mode (verifyRestore = true).
+		op = new TestOneInputStreamOperator(true);
+		testHarness = new KeyedOneInputStreamOperatorTestHarness<>(op, new KeySelector<Integer, Integer>() {
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}, TypeInformation.of(Integer.class));
+		// Hand the four handles taken above to the new harness before it is opened.
+		testHarness.initializeState(new OperatorStateHandles(
+				0,
+				null,
+				Collections.singletonList(keyedManaged),
+				Collections.singletonList(keyedRaw),
+				Collections.singletonList(opManaged),
+				Collections.singletonList(opRaw)));
+
+		testHarness.open();
+		// Re-send 0..9 so processElement can compare each key's state against value + 1.
+		for (int i = 0; i < 10; ++i) {
+			testHarness.processElement(new StreamRecord<>(i));
+		}
+
+		testHarness.close();
+	}
+
+	static class TestOneInputStreamOperator
+			extends AbstractStreamOperator<Integer>
+			implements OneInputStreamOperator<Integer, Integer> {
+		// Writes state when verifyRestore is false; asserts on previously written state when it is true.
+		private static final long serialVersionUID = -8942866418598856475L;
+
+		public TestOneInputStreamOperator(boolean verifyRestore) {
+			this.verifyRestore = verifyRestore;
+		}
+
+		private boolean verifyRestore;
+		private ValueState<Integer> keyedState; // assigned in initializeState()
+		private ListState<Integer> opState; // assigned in initializeState()
+
+		@Override
+		public void processElement(StreamRecord<Integer> element) throws Exception {
+			if (verifyRestore) {
+				// check restored managed keyed state
+				long exp = element.getValue() + 1;
+				long act = keyedState.value();
+				Assert.assertEquals(exp, act);
+			} else {
+				// write managed keyed state that goes into snapshot
+				keyedState.update(element.getValue() + 1);
+				// write managed operator state that goes into snapshot
+				opState.add(element.getValue());
+			}
+		}
+
+		@Override
+		public void processWatermark(Watermark mark) throws Exception {
+			// no-op: the body is empty, so nothing is done with the watermark
+		}
+
+		@Override
+		public void snapshotState(StateSnapshotContext context) throws Exception {
+			// Both raw outputs are written through a DataOutputView wrapper around the checkpoint stream.
+			KeyedStateCheckpointOutputStream out = context.getRawKeyedOperatorStateOutput();
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+
+			// write raw keyed state that goes into snapshot
+			int count = 0;
+			for (int kg : out.getKeyGroupList()) {
+				out.startNewKeyGroup(kg);
+				dov.writeInt(kg + 2);
+				++count;
+			}
+			// The stream must have listed exactly MAX_PARALLELISM key groups.
+			Assert.assertEquals(KeyedOneInputStreamOperatorTestHarness.MAX_PARALLELISM, count);
+
+			// write raw operator state that goes into snapshot
+			OperatorStateCheckpointOutputStream outOp = context.getRawOperatorStateOutput();
+			dov = new DataOutputViewStreamWrapper(outOp);
+			for (int i = 0; i < 13; ++i) {
+				outOp.startNewPartition();
+				dov.writeInt(42 + i);
+			}
+		}
+
+		@Override
+		public void initializeState(StateInitializationContext context) throws Exception {
+			// The context must report "restored" exactly when this instance is in verify mode.
+			Assert.assertEquals(verifyRestore, context.isRestored());
+			// Obtain the managed keyed value state and the managed operator list state by name.
+			keyedState = context.getManagedKeyedStateStore().getState(new ValueStateDescriptor<>("managed-keyed", Integer.class, 0));
+			opState = context.getManagedOperatorStateStore().getSerializableListState("managed-op-state");
+
+			if (context.isRestored()) {
+				// check restored raw keyed state
+				int count = 0;
+				for (KeyGroupStatePartitionStreamProvider streamProvider : context.getRawKeyedStateInputs()) {
+					try (InputStream in = streamProvider.getStream()) {
+						DataInputView div = new DataInputViewStreamWrapper(in);
+						Assert.assertEquals(streamProvider.getKeyGroupId() + 2, div.readInt()); // snapshotState wrote kg + 2 for each key group
+						++count;
+					}
+				}
+				Assert.assertEquals(KeyedOneInputStreamOperatorTestHarness.MAX_PARALLELISM, count);
+
+				// check restored managed operator state
+				BitSet check = new BitSet(10);
+				for (int v : opState.get()) {
+					check.set(v);
+				}
+				// Ten distinct values are expected, matching the ten elements added in write mode.
+				Assert.assertEquals(10, check.cardinality());
+
+				// check restored raw operator state
+				check = new BitSet(13);
+				for (StatePartitionStreamProvider streamProvider : context.getRawOperatorStateInputs()) {
+					try (InputStream in = streamProvider.getStream()) {
+						DataInputView div = new DataInputViewStreamWrapper(in);
+						check.set(div.readInt() - 42); // partitions were written as 42 + i, so this maps back to i
+					}
+				}
+				Assert.assertEquals(13, check.cardinality());
+			}
+		}
+	}
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
index 02409a3..155a16f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
@@ -37,11 +37,14 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.DefaultKeyedStateStore;
 import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.junit.Test;
+import org.mockito.Matchers;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
@@ -53,6 +56,7 @@ import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -160,18 +164,24 @@ public class StreamingRuntimeContextTest {
 			final AtomicReference<Object> ref, final ExecutionConfig config) throws Exception {
 		
 		AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class);
+
+		KeyedStateBackend keyedStateBackend= mock(KeyedStateBackend.class);
+
+		DefaultKeyedStateStore keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, config);
+
 		when(operatorMock.getExecutionConfig()).thenReturn(config);
-		
-		when(operatorMock.getPartitionedState(any(StateDescriptor.class))).thenAnswer(
-				new Answer<Object>() {
-					
-					@Override
-					public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
-						ref.set(invocationOnMock.getArguments()[0]);
-						return null;
-					}
-				});
-		
+
+		doAnswer(new Answer<Object>() {
+
+			@Override
+			public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
+				ref.set(invocationOnMock.getArguments()[2]);
+				return null;
+			}
+		}).when(keyedStateBackend).getPartitionedState(Matchers.any(), any(TypeSerializer.class), any(StateDescriptor.class));
+
+		when(operatorMock.getKeyedStateStore()).thenReturn(keyedStateStore);
+
 		return operatorMock;
 	}
 
@@ -179,29 +189,35 @@ public class StreamingRuntimeContextTest {
 	private static AbstractStreamOperator<?> createPlainMockOp() throws Exception {
 
 		AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class);
-		when(operatorMock.getExecutionConfig()).thenReturn(new ExecutionConfig());
-
-		when(operatorMock.getPartitionedState(any(ListStateDescriptor.class))).thenAnswer(
-				new Answer<ListState<String>>() {
-
-					@Override
-					public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
-						ListStateDescriptor<String> descr =
-								(ListStateDescriptor<String>) invocationOnMock.getArguments()[0];
-
-						AbstractKeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend(
-								new DummyEnvironment("test_task", 1, 0),
-								new JobID(),
-								"test_op",
-								IntSerializer.INSTANCE,
-								1,
-								new KeyGroupRange(0, 0),
-								new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
-						backend.setCurrentKey(0);
-						return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr);
-					}
-				});
+		ExecutionConfig config = new ExecutionConfig();
+
+		KeyedStateBackend keyedStateBackend= mock(KeyedStateBackend.class);
+
+		DefaultKeyedStateStore keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, config);
+
+		when(operatorMock.getExecutionConfig()).thenReturn(config);
 
+		doAnswer(new Answer<ListState<String>>() {
+
+			@Override
+			public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
+				ListStateDescriptor<String> descr =
+						(ListStateDescriptor<String>) invocationOnMock.getArguments()[2];
+
+				AbstractKeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend(
+						new DummyEnvironment("test_task", 1, 0),
+						new JobID(),
+						"test_op",
+						IntSerializer.INSTANCE,
+						1,
+						new KeyGroupRange(0, 0),
+						new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+				backend.setCurrentKey(0);
+				return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr);
+			}
+		}).when(keyedStateBackend).getPartitionedState(Matchers.any(), any(TypeSerializer.class), any(ListStateDescriptor.class));
+
+		when(operatorMock.getKeyedStateStore()).thenReturn(keyedStateStore);
 		return operatorMock;
 	}
 	

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
index 59242e8..f2fc876 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
@@ -32,6 +32,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -973,10 +974,7 @@ public class BarrierBufferTest {
 		}
 
 		@Override
-		public void setInitialState(
-				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState,
-				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
+		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {
 			throw new UnsupportedOperationException("should never be called");
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
index b6d0450..7cfbb66 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.junit.Test;
 
 import java.util.Arrays;
@@ -366,10 +367,7 @@ public class BarrierTrackerTest {
 		}
 
 		@Override
-		public void setInitialState(
-				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState,
-				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
+		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {
 
 			throw new UnsupportedOperationException("should never be called");
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
index d1ba489..e5e26e9 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
@@ -129,7 +129,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 			elementCounter++;
 		}
 
-		testHarness.snapshot(0, 0);
+		testHarness.snapshotLegacy(0, 0);
 		testHarness.notifyOfCompletedCheckpoint(0);
 
 		//isCommitted should have failed, thus sendValues() should never have been called
@@ -140,7 +140,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 			elementCounter++;
 		}
 
-		testHarness.snapshot(1, 0);
+		testHarness.snapshotLegacy(1, 0);
 		testHarness.notifyOfCompletedCheckpoint(1);
 
 		//previous CP should be retried, but will fail the CP commit. Second CP should be skipped.
@@ -151,7 +151,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 			elementCounter++;
 		}
 
-		testHarness.snapshot(2, 0);
+		testHarness.snapshotLegacy(2, 0);
 		testHarness.notifyOfCompletedCheckpoint(2);
 
 		//all CP's should be retried and succeed; since one CP was written twice we have 2 * 10 + 10 + 10 = 40 values

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
index b7203b5..3d1e6e8 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
@@ -66,7 +66,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -74,7 +74,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -82,7 +82,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		verifyResultsIdealCircumstances(testHarness, sink);
@@ -105,7 +105,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -113,14 +113,14 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 
 		for (int x = 0; x < 20; x++) {
 			testHarness.processElement(new StreamRecord<>(generateValue(elementCounter, 2)));
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		verifyResultsDataPersistenceUponMissedNotify(testHarness, sink);
@@ -143,7 +143,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		StreamStateHandle latestSnapshot = testHarness.snapshot(snapshotCount++, 0);
+		StreamStateHandle latestSnapshot = testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -166,7 +166,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		verifyResultsDataDiscardingUponRestore(testHarness, sink);

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
index 51e61a1..e96109e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
@@ -522,7 +522,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 
 			// draw a snapshot and dispose the window
 			int beforeSnapShot = testHarness.getOutput().size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			List<Integer> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
@@ -611,7 +611,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 			// draw a snapshot
 			List<Integer> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int beforeSnapShot = testHarness.getOutput().size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
index 12a842f..802329b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
@@ -634,7 +634,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			// draw a snapshot
 			List<Tuple2<Integer, Integer>> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int beforeSnapShot = resultAtSnapshot.size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
 
@@ -727,7 +727,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			// draw a snapshot
 			List<Tuple2<Integer, Integer>> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int beforeSnapShot = resultAtSnapshot.size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java
index ba803e3..2b0b915 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java
@@ -125,7 +125,7 @@ public class WindowOperatorTest extends TestLogger {
 		TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new Tuple2ResultSortComparator());
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -255,7 +255,7 @@ public class WindowOperatorTest extends TestLogger {
 		TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new Tuple2ResultSortComparator());
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -396,7 +396,7 @@ public class WindowOperatorTest extends TestLogger {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 2), 1000));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -465,7 +465,7 @@ public class WindowOperatorTest extends TestLogger {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 3), 2500));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -543,7 +543,7 @@ public class WindowOperatorTest extends TestLogger {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 2), 1000));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -641,7 +641,7 @@ public class WindowOperatorTest extends TestLogger {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 33), 1000));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -796,7 +796,7 @@ public class WindowOperatorTest extends TestLogger {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 1), 1999));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 
 		testHarness.close();
 
@@ -884,7 +884,7 @@ public class WindowOperatorTest extends TestLogger {
 		operator.processingTimeTimersQueue.add(timer2);
 		operator.processingTimeTimersQueue.add(timer3);
 		
-		StreamStateHandle snapshot = testHarness.snapshot(0, 0);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0, 0);
 
 		WindowOperator<String, Tuple2<String, Integer>, Tuple2<String, Integer>, Tuple2<String, Integer>, TimeWindow> otherOperator = new WindowOperator<>(
 				SlidingEventTimeWindows.of(Time.of(WINDOW_SIZE, TimeUnit.SECONDS), Time.of(WINDOW_SLIDE, TimeUnit.SECONDS)),

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index b5b6582..ee5a203 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -45,6 +45,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskManagerConnection;
@@ -124,8 +125,17 @@ public class InterruptSensitiveRestoreTest {
 			StreamStateHandle state) throws IOException {
 
 		ChainedStateHandle<StreamStateHandle> operatorState = new ChainedStateHandle<>(Collections.singletonList(state));
-		List<KeyGroupsStateHandle> keyGroupState = Collections.emptyList();
-		List<Collection<OperatorStateHandle>> partitionableOperatorState = Collections.emptyList();
+		List<KeyGroupsStateHandle> keyGroupStateFromBackend = Collections.emptyList();
+		List<KeyGroupsStateHandle> keyGroupStateFromStream = Collections.emptyList();
+		List<Collection<OperatorStateHandle>> operatorStateBackend = Collections.emptyList();
+		List<Collection<OperatorStateHandle>> operatorStateStream = Collections.emptyList();
+
+		TaskStateHandles taskStateHandles = new TaskStateHandles(
+				operatorState,
+				operatorStateBackend,
+				operatorStateStream,
+				keyGroupStateFromBackend,
+				keyGroupStateFromStream);
 
 		return new TaskDeploymentDescriptor(
 				new JobID(),
@@ -143,9 +153,7 @@ public class InterruptSensitiveRestoreTest {
 				Collections.<BlobKey>emptyList(),
 				Collections.<URL>emptyList(),
 				0,
-				operatorState,
-				keyGroupState,
-				partitionableOperatorState);
+				taskStateHandles);
 	}
 
 	private static Task createTask(TaskDeploymentDescriptor tdd) throws IOException {

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
index 1b2b723..3dd2ed7 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
@@ -31,15 +31,13 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
-import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.graph.StreamNode;
@@ -64,8 +62,6 @@ import scala.concurrent.duration.FiniteDuration;
 
 import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -74,7 +70,6 @@ import java.util.Map;
 import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.TimeUnit;
 
 import static org.junit.Assert.assertEquals;
@@ -390,7 +385,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());
 
 		final OneInputStreamTask<String, String> restoredTask = new OneInputStreamTask<String, String>();
-		restoredTask.setInitialState(env.getState(), env.getKeyGroupStates(), env.getPartitionableOperatorState());
+		restoredTask.setInitialState(new TaskStateHandles(env.getCheckpointStateHandles()));
 
 		final OneInputStreamTaskTestHarness<String, String> restoredTaskHarness = new OneInputStreamTaskTestHarness<String, String>(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
 		restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO);
@@ -482,9 +477,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 
 	private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment {
 		private volatile long checkpointId;
-		private volatile ChainedStateHandle<StreamStateHandle> state;
-		private volatile List<KeyGroupsStateHandle> keyGroupStates;
-		private volatile List<Collection<OperatorStateHandle>> partitionableOperatorState;
+		private volatile SubtaskState checkpointStateHandles;
 
 		private final OneShotLatch checkpointLatch = new OneShotLatch();
 
@@ -492,24 +485,6 @@ public class OneInputStreamTaskTest extends TestLogger {
 			return checkpointId;
 		}
 
-		public ChainedStateHandle<StreamStateHandle> getState() {
-			return state;
-		}
-
-		List<KeyGroupsStateHandle> getKeyGroupStates() {
-			List<KeyGroupsStateHandle> result = new ArrayList<>();
-			for (KeyGroupsStateHandle keyGroupState : keyGroupStates) {
-				if (keyGroupState != null) {
-					result.add(keyGroupState);
-				}
-			}
-			return result;
-		}
-
-		List<Collection<OperatorStateHandle>> getPartitionableOperatorState() {
-			return partitionableOperatorState;
-		}
-
 		AcknowledgeStreamMockEnvironment(
 				Configuration jobConfig, Configuration taskConfig,
 				ExecutionConfig executionConfig, long memorySize,
@@ -521,26 +496,20 @@ public class OneInputStreamTaskTest extends TestLogger {
 		@Override
 		public void acknowledgeCheckpoint(
 				CheckpointMetaData checkpointMetaData,
-				CheckpointStateHandles checkpointStateHandles) {
+				SubtaskState checkpointStateHandles) {
 
 			this.checkpointId = checkpointMetaData.getCheckpointId();
-			if(checkpointStateHandles != null) {
-				this.state = checkpointStateHandles.getNonPartitionedStateHandles();
-				this.keyGroupStates = checkpointStateHandles.getKeyGroupsStateHandle();
-				ChainedStateHandle<OperatorStateHandle> chainedStateHandle = checkpointStateHandles.getPartitioneableStateHandles();
-				Collection<OperatorStateHandle>[] ia = new Collection[chainedStateHandle.getLength()];
-				this.partitionableOperatorState = Arrays.asList(ia);
-
-				for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
-					partitionableOperatorState.set(i, Collections.singletonList(chainedStateHandle.get(i)));
-				}
-			}
+			this.checkpointStateHandles = checkpointStateHandles;
 			checkpointLatch.trigger();
 		}
 
 		public OneShotLatch getCheckpointLatch() {
 			return checkpointLatch;
 		}
+
+		public SubtaskState getCheckpointStateHandles() {
+			return checkpointStateHandles;
+		}
 	}
 
 	private static class TestingStreamOperator<IN, OUT>
@@ -580,9 +549,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		}
 
 		@Override
-		public RunnableFuture<OperatorStateHandle> snapshotState(
-				long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
-
+		public void snapshotState(StateSnapshotContext context) throws Exception {
 			ListState<Integer> partitionableState =
 					getOperatorStateBackend().getOperatorState(TEST_DESCRIPTOR);
 			partitionableState.clear();
@@ -591,7 +558,11 @@ public class OneInputStreamTaskTest extends TestLogger {
 			partitionableState.add(4711);
 
 			++numberSnapshotCalls;
-			return super.snapshotState(checkpointId, timestamp, streamFactory);
+		}
+
+		@Override
+		public void initializeState(StateInitializationContext context) throws Exception {
+
 		}
 
 		TestingStreamOperator(long seed, long recoveryTimestamp) {
@@ -643,7 +614,6 @@ public class OneInputStreamTaskTest extends TestLogger {
 			assertEquals(random.nextInt(), (int) operatorState);
 		}
 
-
 		private Serializable generateFunctionState() {
 			return random.nextInt();
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
index f852682..36ecf59 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
@@ -28,6 +28,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
@@ -50,7 +51,6 @@ import org.apache.flink.runtime.plugable.DeserializationDelegate;
 import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -123,7 +123,7 @@ public class StreamMockEnvironment implements Environment {
 
 	public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, long memorySize,
 								 MockInputSplitProvider inputSplitProvider, int bufferSize) {
-		this(jobConfig, taskConfig, null, memorySize, inputSplitProvider, bufferSize);
+		this(jobConfig, taskConfig, new ExecutionConfig(), memorySize, inputSplitProvider, bufferSize);
 	}
 
 	public void addInputGate(InputGate gate) {
@@ -313,7 +313,7 @@ public class StreamMockEnvironment implements Environment {
 
 	@Override
 	public void acknowledgeCheckpoint(
-			CheckpointMetaData checkpointMetaData, CheckpointStateHandles checkpointStateHandles) {
+			CheckpointMetaData checkpointMetaData, SubtaskState subtaskState) {
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
index b9211b1..1bb3fb0 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
@@ -214,6 +214,17 @@ public class TwoInputStreamTaskTest {
 		expectedOutput.add(new StreamRecord<String>("111", initialTime));
 
 		testHarness.waitForInputProcessing();
+
+		// Wait to allow input to end up in the output.
+		// TODO Use count down latches instead as a cleaner solution
+		for (int i = 0; i < 20; ++i) {
+			if (testHarness.getOutput().size() >= expectedOutput.size()) {
+				break;
+			} else {
+				Thread.sleep(100);
+			}
+		}
+
 		// we should not yet see the barrier, only the two elements from non-blocked input
 		TestHarnessUtil.assertOutputEquals("Output was not correct.",
 				testHarness.getOutput(),
@@ -224,17 +235,17 @@ public class TwoInputStreamTaskTest {
 		testHarness.processEvent(new CheckpointBarrier(0, 0), 1, 1);
 
 		testHarness.waitForInputProcessing();
+		testHarness.endInput();
+		testHarness.waitForTaskCompletion();
 
 		// now we should see the barrier and after that the buffered elements
 		expectedOutput.add(new CheckpointBarrier(0, 0));
 		expectedOutput.add(new StreamRecord<String>("Hello-0-0", initialTime));
-		TestHarnessUtil.assertOutputEquals("Output was not correct.",
-				testHarness.getOutput(),
-				expectedOutput);
 
-		testHarness.endInput();
+		TestHarnessUtil.assertOutputEquals("Output was not correct.",
+				expectedOutput,
+				testHarness.getOutput());
 
-		testHarness.waitForTaskCompletion();
 
 		List<String> resultElements = TestHarnessUtil.getRawElementsFromOutput(testHarness.getOutput());
 		Assert.assertEquals(4, resultElements.size());

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index 5275a39..41968e6 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -31,14 +31,16 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
-import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.TestTimeServiceProvider;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.concurrent.RunnableFuture;
 
@@ -60,7 +62,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 
 	// when we restore we keep the state here so that we can call restore
 	// when the operator requests the keyed state backend
-	private KeyGroupsStateHandle restoredKeyedState = null;
+	private Collection<KeyGroupsStateHandle> restoredKeyedState = null;
 
 	public KeyedOneInputStreamOperatorTestHarness(
 			OneInputStreamOperator<IN, OUT> operator,
@@ -138,7 +140,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 								keySerializer,
 								numberOfKeyGroups,
 								keyGroupRange,
-								Collections.singletonList(restoredKeyedState),
+								restoredKeyedState,
 								mockTask.getEnvironment().getTaskKvStateRegistry());
 						restoredKeyedState = null;
 						return keyedStateBackend;
@@ -154,7 +156,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 	 *
 	 */
 	@Override
-	public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exception {
+	public StreamStateHandle snapshotLegacy(long checkpointId, long timestamp) throws Exception {
 		// simply use an in-memory handle
 		MemoryStateBackend backend = new MemoryStateBackend();
 
@@ -185,7 +187,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 	}
 
 	/**
-	 * 
+	 *
 	 */
 	@Override
 	public void restore(StreamStateHandle snapshot) throws Exception {
@@ -198,7 +200,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 			byte keyedStatePresent = (byte) inStream.read();
 			if (keyedStatePresent == 1) {
 				ObjectInputStream ois = new ObjectInputStream(inStream);
-				this.restoredKeyedState = (KeyGroupsStateHandle) ois.readObject();
+				this.restoredKeyedState = Collections.singletonList((KeyGroupsStateHandle) ois.readObject());
 			}
 		}
 	}
@@ -208,8 +210,16 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 	 */
 	public void close() throws Exception {
 		super.close();
-		if(keyedStateBackend != null) {
+		if (keyedStateBackend != null) {
 			keyedStateBackend.dispose();
 		}
 	}
+
+	@Override
+	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
+		if (null != operatorStateHandles) {
+			this.restoredKeyedState = operatorStateHandles.getManagedKeyedState();
+		}
+		super.initializeState(operatorStateHandles);
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
index d1622ff..9f8d223 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
@@ -28,22 +28,26 @@ import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.ClosableRegistry;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
-import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotResult;
 import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.DefaultTimeServiceProvider;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.runtime.tasks.TestTimeServiceProvider;
 import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
-
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
@@ -65,7 +69,7 @@ import static org.mockito.Mockito.when;
  */
 public class OneInputStreamOperatorTestHarness<IN, OUT> {
 
-	protected static final int MAX_PARALLELISM = 10;
+	public static final int MAX_PARALLELISM = 10;
 
 	final OneInputStreamOperator<IN, OUT> operator;
 
@@ -81,6 +85,8 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 
 	StreamTask<?, ?> mockTask;
 
+	ClosableRegistry closableRegistry;
+
 	// use this as default for tests
 	AbstractStateBackend stateBackend = new MemoryStateBackend();
 
@@ -88,6 +94,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 	 * Whether setup() was called on the operator. This is reset when calling close().
 	 */
 	private boolean setupCalled = false;
+	private boolean initializeCalled = false;
 
 	private volatile boolean wasFailedExternally = false;
 
@@ -121,6 +128,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 		this.config.setCheckpointingEnabled(true);
 		this.executionConfig = executionConfig;
 		this.checkpointLock = checkpointLock;
+		this.closableRegistry = new ClosableRegistry();
 
 		final Environment env = new MockEnvironment("MockTwoInputTask", 3 * 1024 * 1024, new MockInputSplitProvider(), 1024, underlyingConfig, executionConfig, MAX_PARALLELISM, 1, 0);
 		mockTask = mock(StreamTask.class);
@@ -132,6 +140,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 		when(mockTask.getEnvironment()).thenReturn(env);
 		when(mockTask.getExecutionConfig()).thenReturn(executionConfig);
 		when(mockTask.getUserCodeClassLoader()).thenReturn(this.getClass().getClassLoader());
+		when(mockTask.getCancelables()).thenReturn(this.closableRegistry);
 
 		doAnswer(new Answer<Void>() {
 			@Override
@@ -154,6 +163,26 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 			throw new RuntimeException(e.getMessage(), e);
 		}
 
+		try {
+			doAnswer(new Answer<OperatorStateBackend>() {
+				@Override
+				public OperatorStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
+					final StreamOperator<?> operator = (StreamOperator<?>) invocationOnMock.getArguments()[0];
+					final Collection<OperatorStateHandle> stateHandles = (Collection<OperatorStateHandle>) invocationOnMock.getArguments()[1];
+					OperatorStateBackend osb;
+					if (null == stateHandles) {
+						osb = stateBackend.createOperatorStateBackend(env, operator.getClass().getSimpleName());
+					} else {
+						osb = stateBackend.restoreOperatorStateBackend(env, operator.getClass().getSimpleName(), stateHandles);
+					}
+					mockTask.getCancelables().registerClosable(osb);
+					return osb;
+				}
+			}).when(mockTask).createOperatorStateBackend(any(StreamOperator.class), any(Collection.class));
+		} catch (Exception e) {
+			throw new RuntimeException(e.getMessage(), e);
+		}
+
 		timeServiceProvider = testTimeProvider != null ? testTimeProvider :
 			new DefaultTimeServiceProvider(mockTask, this.checkpointLock);
 
@@ -199,8 +228,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 	}
 
 	/**
-	 * Calls
-	 * {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)} ()}
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}.
 	 */
 	public void setup() throws Exception {
 		operator.setup(mockTask, config, new MockOutput());
@@ -208,21 +236,48 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 	}
 
 	/**
-	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#open()}. This also
-	 * calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}.
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)} first
 	 * if it was not called before.
 	 */
-	public void open() throws Exception {
+	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
 		if (!setupCalled) {
 			setup();
 		}
+		operator.initializeState(operatorStateHandles);
+		initializeCalled = true;
+	}
+
+	/**
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#open()}.
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)} if it
+	 * was not called before.
+	 */
+	public void open() throws Exception {
+		if (!initializeCalled) {
+			initializeState(null);
+		}
 		operator.open();
 	}
 
 	/**
 	 *
 	 */
-	public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exception {
+	public OperatorSnapshotResult snapshot(long checkpointId, long timestamp) throws Exception {
+
+		CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory(
+				new JobID(),
+				"test_op");
+
+		return operator.snapshotState(checkpointId, timestamp, streamFactory);
+	}
+
+	/**
+	 *
+	 */
+	@Deprecated
+	public StreamStateHandle snapshotLegacy(long checkpointId, long timestamp) throws Exception {
+
 		CheckpointStreamFactory.CheckpointStateOutputStream outStream = stateBackend.createStreamFactory(
 				new JobID(),
 				"test_op").createCheckpointStateOutputStream(checkpointId, timestamp);
@@ -244,6 +299,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 	/**
 	 *
 	 */
+	@Deprecated
 	public void restore(StreamStateHandle snapshot) throws Exception {
 		if(operator instanceof StreamCheckpointedOperator) {
 			try (FSDataInputStream in = snapshot.openInputStream()) {

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java
index 32b4c77..7df6848 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java
@@ -25,12 +25,14 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.ClosableRegistry;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 
 import java.util.concurrent.ConcurrentLinkedQueue;
@@ -56,6 +58,10 @@ public class TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> {
 
 	final Object checkpointLock;
 
+	final ClosableRegistry closableRegistry;
+
+	boolean initializeCalled = false;
+
 	public TwoInputStreamOperatorTestHarness(TwoInputStreamOperator<IN1, IN2, OUT> operator) {
 		this(operator, new StreamConfig(new Configuration()));
 	}
@@ -65,6 +71,7 @@ public class TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> {
 		this.outputList = new ConcurrentLinkedQueue<Object>();
 		this.executionConfig = new ExecutionConfig();
 		this.checkpointLock = new Object();
+		this.closableRegistry = new ClosableRegistry();
 
 		Environment env = new MockEnvironment("MockTwoInputTask", 3 * 1024 * 1024, new MockInputSplitProvider(), 1024);
 		StreamTask<?, ?> mockTask = mock(StreamTask.class);
@@ -73,6 +80,7 @@ public class TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> {
 		when(mockTask.getConfiguration()).thenReturn(config);
 		when(mockTask.getEnvironment()).thenReturn(env);
 		when(mockTask.getExecutionConfig()).thenReturn(executionConfig);
+		when(mockTask.getCancelables()).thenReturn(this.closableRegistry);
 
 		operator.setup(mockTask, new StreamConfig(new Configuration()), new MockOutput());
 	}
@@ -86,11 +94,22 @@ public class TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> {
 		return outputList;
 	}
 
+	/**
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}.
+	 */
+	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
+		operator.initializeState(operatorStateHandles);
+		initializeCalled = true;
+	}
 
 	/**
 	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#open()}.
 	 */
 	public void open() throws Exception {
+		if(!initializeCalled) {
+			initializeState(mock(OperatorStateHandles.class));
+		}
+
 		operator.open();
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java
index ab8b70f..a4e26f0 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java
@@ -165,7 +165,7 @@ public class WindowingTestHarness<K, IN, W extends Window> {
 	 * Takes a snapshot of the current state of the operator. This can be used to test fault-tolerance.
 	 */
 	public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exception {
-		return testHarness.snapshot(checkpointId, timestamp);
+		return testHarness.snapshotLegacy(checkpointId, timestamp);
 	}
 
 	/**


[7/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State

Posted by al...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
index 666176b..89f1f42 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
@@ -80,46 +80,13 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 				dos.writeInt(taskState.getMaxParallelism());
 				dos.writeInt(taskState.getChainLength());
 
-				// Sub task non-partitionable states
+				// Sub task states
 				Map<Integer, SubtaskState> subtaskStateMap = taskState.getSubtaskStates();
 				dos.writeInt(subtaskStateMap.size());
 				for (Map.Entry<Integer, SubtaskState> entry : subtaskStateMap.entrySet()) {
 					dos.writeInt(entry.getKey());
-
-					SubtaskState subtaskState = entry.getValue();
-					ChainedStateHandle<StreamStateHandle> chainedStateHandle = subtaskState.getChainedStateHandle();
-					dos.writeInt(chainedStateHandle.getLength());
-					for (int j = 0; j < chainedStateHandle.getLength(); ++j) {
-						StreamStateHandle stateHandle = chainedStateHandle.get(j);
-						serializeStreamStateHandle(stateHandle, dos);
-					}
-
-					dos.writeLong(subtaskState.getDuration());
+					serializeSubtaskState(entry.getValue(), dos);
 				}
-
-				// Sub task partitionable states
-				Map<Integer, ChainedStateHandle<OperatorStateHandle>> partitionableStatesMap = taskState.getPartitionableStates();
-				dos.writeInt(partitionableStatesMap.size());
-
-				for (Map.Entry<Integer, ChainedStateHandle<OperatorStateHandle>> entry : partitionableStatesMap.entrySet()) {
-					dos.writeInt(entry.getKey());
-
-					ChainedStateHandle<OperatorStateHandle> chainedStateHandle = entry.getValue();
-					dos.writeInt(chainedStateHandle.getLength());
-					for (int j = 0; j < chainedStateHandle.getLength(); ++j) {
-						OperatorStateHandle stateHandle = chainedStateHandle.get(j);
-						serializePartitionableStateHandle(stateHandle, dos);
-					}
-				}
-
-				// Keyed state
-				Map<Integer, KeyGroupsStateHandle> keyGroupsStateHandles = taskState.getKeyGroupsStateHandles();
-				dos.writeInt(keyGroupsStateHandles.size());
-				for (Map.Entry<Integer, KeyGroupsStateHandle> entry : keyGroupsStateHandles.entrySet()) {
-					dos.writeInt(entry.getKey());
-					serializeKeyGroupStateHandle(entry.getValue(), dos);
-				}
-
 			}
 		} catch (Exception e) {
 			throw new IOException(e);
@@ -149,50 +116,99 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 
 			for (int j = 0; j < numSubTaskStates; j++) {
 				int subtaskIndex = dis.readInt();
-				int chainedStateHandleSize = dis.readInt();
-				List<StreamStateHandle> streamStateHandleList = new ArrayList<>(chainedStateHandleSize);
-				for (int k = 0; k < chainedStateHandleSize; ++k) {
-					StreamStateHandle streamStateHandle = deserializeStreamStateHandle(dis);
-					streamStateHandleList.add(streamStateHandle);
-				}
-
-				long duration = dis.readLong();
-				ChainedStateHandle<StreamStateHandle> chainedStateHandle = new ChainedStateHandle<>(streamStateHandleList);
-				SubtaskState subtaskState = new SubtaskState(chainedStateHandle, duration);
+				SubtaskState subtaskState = deserializeSubtaskState(dis);
 				taskState.putState(subtaskIndex, subtaskState);
 			}
+		}
 
-			int numPartitionableOpStates = dis.readInt();
+		return new SavepointV1(checkpointId, taskStates);
+	}
 
-			for (int j = 0; j < numPartitionableOpStates; j++) {
-				int subtaskIndex = dis.readInt();
-				int chainedStateHandleSize = dis.readInt();
-				List<OperatorStateHandle> streamStateHandleList = new ArrayList<>(chainedStateHandleSize);
+	public static void serializeSubtaskState(SubtaskState subtaskState, DataOutputStream dos) throws IOException {
 
-				for (int k = 0; k < chainedStateHandleSize; ++k) {
-					OperatorStateHandle streamStateHandle = deserializePartitionableStateHandle(dis);
-					streamStateHandleList.add(streamStateHandle);
-				}
+		dos.writeLong(subtaskState.getDuration());
 
-				ChainedStateHandle<OperatorStateHandle> chainedStateHandle =
-						new ChainedStateHandle<>(streamStateHandleList);
+		ChainedStateHandle<StreamStateHandle> nonPartitionableState = subtaskState.getLegacyOperatorState();
 
-				taskState.putPartitionableState(subtaskIndex, chainedStateHandle);
-			}
+		int len = nonPartitionableState != null ? nonPartitionableState.getLength() : 0;
+		dos.writeInt(len);
+		for (int i = 0; i < len; ++i) {
+			StreamStateHandle stateHandle = nonPartitionableState.get(i);
+			serializeStreamStateHandle(stateHandle, dos);
+		}
 
-			// Key group states
-			int numKeyGroupStates = dis.readInt();
-			for (int j = 0; j < numKeyGroupStates; j++) {
-				int keyGroupIndex = dis.readInt();
+		ChainedStateHandle<OperatorStateHandle> operatorStateBackend = subtaskState.getManagedOperatorState();
 
-				KeyGroupsStateHandle keyGroupsStateHandle = deserializeKeyGroupStateHandle(dis);
-				if (keyGroupsStateHandle != null) {
-					taskState.putKeyedState(keyGroupIndex, keyGroupsStateHandle);
-				}
-			}
+		len = operatorStateBackend != null ? operatorStateBackend.getLength() : 0;
+		dos.writeInt(len);
+		for (int i = 0; i < len; ++i) {
+			OperatorStateHandle stateHandle = operatorStateBackend.get(i);
+			serializeOperatorStateHandle(stateHandle, dos);
 		}
 
-		return new SavepointV1(checkpointId, taskStates);
+		ChainedStateHandle<OperatorStateHandle> operatorStateFromStream = subtaskState.getRawOperatorState();
+
+		len = operatorStateFromStream != null ? operatorStateFromStream.getLength() : 0;
+		dos.writeInt(len);
+		for (int i = 0; i < len; ++i) {
+			OperatorStateHandle stateHandle = operatorStateFromStream.get(i);
+			serializeOperatorStateHandle(stateHandle, dos);
+		}
+
+		KeyGroupsStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
+		serializeKeyGroupStateHandle(keyedStateBackend, dos);
+
+		KeyGroupsStateHandle keyedStateStream = subtaskState.getRawKeyedState();
+		serializeKeyGroupStateHandle(keyedStateStream, dos);
+
+	}
+
+	public static SubtaskState deserializeSubtaskState(DataInputStream dis) throws IOException {
+
+		long duration = dis.readLong();
+
+		int len = dis.readInt();
+		List<StreamStateHandle> nonPartitionableState = new ArrayList<>(len);
+		for (int i = 0; i < len; ++i) {
+			StreamStateHandle streamStateHandle = deserializeStreamStateHandle(dis);
+			nonPartitionableState.add(streamStateHandle);
+		}
+
+
+		len = dis.readInt();
+		List<OperatorStateHandle> operatorStateBackend = new ArrayList<>(len);
+		for (int i = 0; i < len; ++i) {
+			OperatorStateHandle streamStateHandle = deserializeOperatorStateHandle(dis);
+			operatorStateBackend.add(streamStateHandle);
+		}
+
+		len = dis.readInt();
+		List<OperatorStateHandle> operatorStateStream = new ArrayList<>(len);
+		for (int i = 0; i < len; ++i) {
+			OperatorStateHandle streamStateHandle = deserializeOperatorStateHandle(dis);
+			operatorStateStream.add(streamStateHandle);
+		}
+
+		KeyGroupsStateHandle keyedStateBackend = deserializeKeyGroupStateHandle(dis);
+
+		KeyGroupsStateHandle keyedStateStream = deserializeKeyGroupStateHandle(dis);
+
+		ChainedStateHandle<StreamStateHandle> nonPartitionableStateChain =
+				new ChainedStateHandle<>(nonPartitionableState);
+
+		ChainedStateHandle<OperatorStateHandle> operatorStateBackendChain =
+				new ChainedStateHandle<>(operatorStateBackend);
+
+		ChainedStateHandle<OperatorStateHandle> operatorStateStreamChain =
+				new ChainedStateHandle<>(operatorStateStream);
+
+		return new SubtaskState(
+				nonPartitionableStateChain,
+				operatorStateBackendChain,
+				operatorStateStreamChain,
+				keyedStateBackend,
+				keyedStateStream,
+				duration);
 	}
 
 	public static void serializeKeyGroupStateHandle(
@@ -231,7 +247,7 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 		}
 	}
 
-	public static void serializePartitionableStateHandle(
+	public static void serializeOperatorStateHandle(
 			OperatorStateHandle stateHandle, DataOutputStream dos) throws IOException {
 
 		if (stateHandle != null) {
@@ -252,7 +268,7 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 		}
 	}
 
-	public static OperatorStateHandle deserializePartitionableStateHandle(
+	public static OperatorStateHandle deserializeOperatorStateHandle(
 			DataInputStream dis) throws IOException {
 
 		final int type = dis.readByte();
@@ -270,13 +286,14 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 				offsetsMap.put(key, offsets);
 			}
 			StreamStateHandle stateHandle = deserializeStreamStateHandle(dis);
-			return new OperatorStateHandle(stateHandle, offsetsMap);
+			return new OperatorStateHandle(offsetsMap, stateHandle);
 		} else {
 			throw new IllegalStateException("Reading invalid OperatorStateHandle, type: " + type);
 		}
 	}
 
-	public static void serializeStreamStateHandle(StreamStateHandle stateHandle, DataOutputStream dos) throws IOException {
+	public static void serializeStreamStateHandle(
+			StreamStateHandle stateHandle, DataOutputStream dos) throws IOException {
 
 		if (stateHandle == null) {
 			dos.writeByte(NULL_HANDLE);

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
index 7bbdb2a..bf31e51 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
@@ -25,10 +25,7 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.blob.BlobKey;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.SerializedValue;
 
 import java.io.Serializable;
@@ -36,7 +33,6 @@ import java.net.URL;
 import java.util.Collection;
 import java.util.List;
 
-
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
@@ -95,13 +91,7 @@ public final class TaskDeploymentDescriptor implements Serializable {
 	/** The list of classpaths required to run this task. */
 	private final List<URL> requiredClasspaths;
 
-	/** Handle to the non-partitioned state of the operator chain */
-	private final ChainedStateHandle<StreamStateHandle> operatorState;
-
-	/** Handle to the key-grouped state of the head operator in the chain */
-	private final List<KeyGroupsStateHandle> keyGroupState;
-
-	private final List<Collection<OperatorStateHandle>> partitionableOperatorState;
+	private final TaskStateHandles taskStateHandles;
 
 	/** The execution configuration (see {@link ExecutionConfig}) related to the specific job. */
 	private final SerializedValue<ExecutionConfig> serializedExecutionConfig;
@@ -128,9 +118,7 @@ public final class TaskDeploymentDescriptor implements Serializable {
 		List<BlobKey> requiredJarFiles,
 		List<URL> requiredClasspaths,
 		int targetSlotNumber,
-		ChainedStateHandle<StreamStateHandle> operatorState,
-		List<KeyGroupsStateHandle> keyGroupState,
-		List<Collection<OperatorStateHandle>> partitionableOperatorStateHandles) {
+		TaskStateHandles taskStateHandles) {
 
 		checkArgument(indexInSubtaskGroup >= 0);
 		checkArgument(numberOfSubtasks > indexInSubtaskGroup);
@@ -155,9 +143,7 @@ public final class TaskDeploymentDescriptor implements Serializable {
 		this.requiredJarFiles = checkNotNull(requiredJarFiles);
 		this.requiredClasspaths = checkNotNull(requiredClasspaths);
 		this.targetSlotNumber = targetSlotNumber;
-		this.operatorState = operatorState;
-		this.keyGroupState = keyGroupState;
-		this.partitionableOperatorState = partitionableOperatorStateHandles;
+		this.taskStateHandles = taskStateHandles;
 	}
 
 	public TaskDeploymentDescriptor(
@@ -199,8 +185,6 @@ public final class TaskDeploymentDescriptor implements Serializable {
 			requiredJarFiles,
 			requiredClasspaths,
 			targetSlotNumber,
-			null,
-			null,
 			null);
 	}
 
@@ -346,15 +330,7 @@ public final class TaskDeploymentDescriptor implements Serializable {
 		return strBuilder.toString();
 	}
 
-	public ChainedStateHandle<StreamStateHandle> getOperatorState() {
-		return operatorState;
-	}
-
-	public List<KeyGroupsStateHandle> getKeyGroupState() {
-		return keyGroupState;
-	}
-
-	public List<Collection<OperatorStateHandle>> getPartitionableOperatorState() {
-		return partitionableOperatorState;
+	public TaskStateHandles getTaskStateHandles() {
+		return taskStateHandles;
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
index f0ff918..af1a640 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
@@ -26,6 +26,7 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
@@ -35,7 +36,6 @@ import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KvState;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 
@@ -171,12 +171,12 @@ public interface Environment {
 	 * the checkpoint with the give checkpoint-ID. This method does include
 	 * the given state in the checkpoint.
 	 *
-	 * @param checkpointStateHandles All state handles for the checkpointed state
 	 * @param checkpointMetaData the meta data for this checkpoint
+	 * @param subtaskState All state handles for the checkpointed state
 	 */
 	void acknowledgeCheckpoint(
 			CheckpointMetaData checkpointMetaData,
-			CheckpointStateHandles checkpointStateHandles);
+			SubtaskState subtaskState);
 
 	/**
 	 * Marks task execution failed for an external reason (a reason other than the task code itself

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
index 0b56931..17e0df1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
@@ -34,7 +34,6 @@ import org.apache.flink.runtime.deployment.PartialInputChannelDeploymentDescript
 import org.apache.flink.runtime.deployment.ResultPartitionLocation;
 import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
 import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.instance.SimpleSlot;
 import org.apache.flink.runtime.instance.SlotProvider;
@@ -47,19 +46,14 @@ import org.apache.flink.runtime.jobmanager.scheduler.ScheduledUnit;
 import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
 import org.apache.flink.runtime.messages.Messages;
 import org.apache.flink.runtime.messages.TaskMessages.TaskOperationResult;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.util.ExceptionUtils;
 import org.slf4j.Logger;
-
 import scala.concurrent.ExecutionContext;
 import scala.concurrent.duration.FiniteDuration;
 
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Callable;
@@ -136,12 +130,7 @@ public class Execution implements AccessExecution, Archiveable<ArchivedExecution
 
 	private volatile TaskManagerLocation assignedResourceLocation; // for the archived execution
 
-	private ChainedStateHandle<StreamStateHandle> chainedStateHandle;
-
-	private List<Collection<OperatorStateHandle>> chainedPartitionableStateHandle;
-
-	private List<KeyGroupsStateHandle> keyGroupsStateHandles;
-	
+	private TaskStateHandles taskStateHandles;
 
 	/** The execution context which is used to execute futures. */
 	private ExecutionContext executionContext;
@@ -232,39 +221,27 @@ public class Execution implements AccessExecution, Archiveable<ArchivedExecution
 		return this.stateTimestamps[state.ordinal()];
 	}
 
-	public ChainedStateHandle<StreamStateHandle> getChainedStateHandle() {
-		return chainedStateHandle;
-	}
-
-	public List<KeyGroupsStateHandle> getKeyGroupsStateHandles() {
-		return keyGroupsStateHandles;
-	}
-
-	public List<Collection<OperatorStateHandle>> getChainedPartitionableStateHandle() {
-		return chainedPartitionableStateHandle;
-	}
-
 	public boolean isFinished() {
 		return state.isTerminal();
 	}
 
+	public TaskStateHandles getTaskStateHandles() {
+		return taskStateHandles;
+	}
+
 	/**
 	 * Sets the initial state for the execution. The serialized state is then shipped via the
 	 * {@link TaskDeploymentDescriptor} to the TaskManagers.
 	 *
 	 * @param checkpointStateHandles all checkpointed operator state
 	 */
-	public void setInitialState(CheckpointStateHandles checkpointStateHandles, List<Collection<OperatorStateHandle>> chainedPartitionableStateHandle) {
+	public void setInitialState(TaskStateHandles checkpointStateHandles) {
 
 		if (state != ExecutionState.CREATED) {
 			throw new IllegalArgumentException("Can only assign operator state when execution attempt is in CREATED");
 		}
 
-		if(checkpointStateHandles != null) {
-			this.chainedStateHandle = checkpointStateHandles.getNonPartitionedStateHandles();
-			this.chainedPartitionableStateHandle = chainedPartitionableStateHandle;
-			this.keyGroupsStateHandles = checkpointStateHandles.getKeyGroupsStateHandle();
-		}
+		this.taskStateHandles = checkpointStateHandles;
 	}
 
 	// --------------------------------------------------------------------------------------------
@@ -390,9 +367,7 @@ public class Execution implements AccessExecution, Archiveable<ArchivedExecution
 			final TaskDeploymentDescriptor deployment = vertex.createDeploymentDescriptor(
 				attemptId,
 				slot,
-				chainedStateHandle,
-				keyGroupsStateHandles,
-				chainedPartitionableStateHandle,
+				taskStateHandles,
 				attemptNumber);
 
 			// register this execution at the execution graph, to receive call backs

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
index 96af91e..b647385 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
@@ -18,7 +18,9 @@
 
 package org.apache.flink.runtime.executiongraph;
 
+import org.apache.flink.api.common.Archiveable;
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.JobException;
 import org.apache.flink.runtime.blob.BlobKey;
 import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
@@ -27,36 +29,28 @@ import org.apache.flink.runtime.deployment.PartialInputChannelDeploymentDescript
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
 import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.api.common.Archiveable;
-import org.apache.flink.runtime.instance.SlotProvider;
-import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.instance.SimpleSlot;
+import org.apache.flink.runtime.instance.SlotProvider;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobEdge;
-import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobmanager.scheduler.CoLocationConstraint;
 import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup;
 import org.apache.flink.runtime.jobmanager.scheduler.NoResourceAvailableException;
+import org.apache.flink.runtime.state.TaskStateHandles;
+import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.SerializedValue;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
-
 import org.slf4j.Logger;
-
 import scala.concurrent.duration.FiniteDuration;
 
 import java.io.Serializable;
 import java.net.URL;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
@@ -622,9 +616,7 @@ public class ExecutionVertex implements AccessExecutionVertex, Archiveable<Archi
 	TaskDeploymentDescriptor createDeploymentDescriptor(
 			ExecutionAttemptID executionId,
 			SimpleSlot targetSlot,
-			ChainedStateHandle<StreamStateHandle> operatorState,
-			List<KeyGroupsStateHandle> keyGroupStates,
-			List<Collection<OperatorStateHandle>> partitionableOperatorStateHandle,
+			TaskStateHandles taskStateHandles,
 			int attemptNumber) {
 
 		// Produced intermediate results
@@ -676,9 +668,7 @@ public class ExecutionVertex implements AccessExecutionVertex, Archiveable<Archi
 			jarFiles,
 			classpaths,
 			targetSlot.getRoot().getSlotNumber(),
-			operatorState,
-			keyGroupStates,
-			partitionableOperatorStateHandle);
+			taskStateHandles);
 	}
 
 	// --------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopDataInputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopDataInputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopDataInputStream.java
index 8893ba4..47b63f0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopDataInputStream.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/fs/hdfs/HadoopDataInputStream.java
@@ -18,11 +18,10 @@
 
 package org.apache.flink.runtime.fs.hdfs;
 
-import java.io.IOException;
-
 import org.apache.flink.core.fs.FSDataInputStream;
 
 import javax.annotation.Nonnull;
+import java.io.IOException;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
@@ -46,7 +45,11 @@ public final class HadoopDataInputStream extends FSDataInputStream {
 
 	@Override
 	public void seek(long desired) throws IOException {
-		fsDataInputStream.seek(desired);
+		// This optimization prevents some distributed FS implementations from performing expensive seeks when they
+		// are actually not needed
+		if (desired != getPos()) {
+			fsDataInputStream.seek(desired);
+		}
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
index e1d15e2..b0c3730 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
@@ -19,13 +19,7 @@
 package org.apache.flink.runtime.jobgraph.tasks;
 
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
-
-import java.util.Collection;
-import java.util.List;
+import org.apache.flink.runtime.state.TaskStateHandles;
 
 /**
  * This interface must be implemented by any invokable that has recoverable state and participates
@@ -37,15 +31,9 @@ public interface StatefulTask {
 	 * Sets the initial state of the operator, upon recovery. The initial state is typically
 	 * a snapshot of the state from a previous execution.
 	 *
-	 * TODO this should use @{@link org.apache.flink.runtime.state.CheckpointStateHandles} after redoing chained state.
-	 *
-	 * @param chainedState Handle for the chained operator states.
-	 * @param keyGroupsState Handle for key group states.
+	 * @param taskStateHandles All state handles for the task.
 	 */
-	void setInitialState(
-		ChainedStateHandle<StreamStateHandle> chainedState,
-		List<KeyGroupsStateHandle> keyGroupsState,
-		List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception;
+	void setInitialState(TaskStateHandles taskStateHandles) throws Exception;
 
 	/**
 	 * This method is called to trigger a checkpoint, asynchronously by the checkpoint

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
index ac14d3a..c63bac5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
@@ -20,8 +20,8 @@ package org.apache.flink.runtime.messages.checkpoint;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 
@@ -38,7 +38,7 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 	private static final long serialVersionUID = -7606214777192401493L;
 
 
-	private final CheckpointStateHandles checkpointStateHandles;
+	private final SubtaskState subtaskState;
 
 	private final CheckpointMetaData checkpointMetaData;
 
@@ -55,11 +55,11 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 			JobID job,
 			ExecutionAttemptID taskExecutionId,
 			CheckpointMetaData checkpointMetaData,
-			CheckpointStateHandles checkpointStateHandles) {
+			SubtaskState subtaskState) {
 
 		super(job, taskExecutionId, checkpointMetaData.getCheckpointId());
 
-		this.checkpointStateHandles = checkpointStateHandles;
+		this.subtaskState = subtaskState;
 		this.checkpointMetaData = checkpointMetaData;
 		// these may be "-1", in case the values are unknown or not set
 		checkArgument(checkpointMetaData.getSyncDurationMillis() >= -1);
@@ -72,8 +72,8 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 	//  properties
 	// ------------------------------------------------------------------------
 
-	public CheckpointStateHandles getCheckpointStateHandles() {
-		return checkpointStateHandles;
+	public SubtaskState getSubtaskState() {
+		return subtaskState;
 	}
 
 	public long getSynchronousDurationMillis() {
@@ -107,21 +107,21 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 		}
 
 		AcknowledgeCheckpoint that = (AcknowledgeCheckpoint) o;
-		return checkpointStateHandles != null ?
-				checkpointStateHandles.equals(that.checkpointStateHandles) : that.checkpointStateHandles == null;
+		return subtaskState != null ?
+				subtaskState.equals(that.subtaskState) : that.subtaskState == null;
 
 	}
 
 	@Override
 	public int hashCode() {
 		int result = super.hashCode();
-		result = 31 * result + (checkpointStateHandles != null ? checkpointStateHandles.hashCode() : 0);
+		result = 31 * result + (subtaskState != null ? subtaskState.hashCode() : 0);
 		return result;
 	}
 
 	@Override
 	public String toString() {
 		return String.format("Confirm Task Checkpoint %d for (%s/%s) - state=%s",
-				getCheckpointId(), getJob(), getTaskExecutionId(), checkpointStateHandles);
+				getCheckpointId(), getJob(), getTaskExecutionId(), subtaskState);
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java
index 857b8b3..9ac2d44 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/KvStateMessage.java
@@ -130,7 +130,7 @@ public interface KvStateMessage extends Serializable {
 
 			this.jobId = Preconditions.checkNotNull(jobId, "JobID");
 			this.jobVertexId = Preconditions.checkNotNull(jobVertexId, "JobVertexID");
-			Preconditions.checkArgument(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP);
+			Preconditions.checkArgument(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP_RANGE);
 			this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange);
 			this.registrationName = Preconditions.checkNotNull(registrationName, "Registration name");
 			this.kvStateId = Preconditions.checkNotNull(kvStateId, "KvStateID");
@@ -236,7 +236,7 @@ public interface KvStateMessage extends Serializable {
 
 			this.jobId = Preconditions.checkNotNull(jobId, "JobID");
 			this.jobVertexId = Preconditions.checkNotNull(jobVertexId, "JobVertexID");
-			Preconditions.checkArgument(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP);
+			Preconditions.checkArgument(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP_RANGE);
 			this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange);
 			this.registrationName = Preconditions.checkNotNull(registrationName, "Registration name");
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
index 7ca3b38..e5d9b2b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -50,7 +50,7 @@ import java.util.List;
  * @param <K> Type of the key by which state is keyed.
  */
 public abstract class AbstractKeyedStateBackend<K>
-		implements KeyedStateBackend<K>, SnapshotProvider<KeyGroupsStateHandle>, Closeable {
+		implements KeyedStateBackend<K>, Snapshotable<KeyGroupsStateHandle>, Closeable {
 
 	/** {@link TypeSerializer} for our key. */
 	protected final TypeSerializer<K> keySerializer;

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
index c683a02..1b53f1a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
@@ -25,7 +25,6 @@ import org.apache.flink.runtime.query.TaskKvStateRegistry;
 
 import java.io.IOException;
 import java.util.Collection;
-import java.util.List;
 
 /**
  * A state backend defines how state is stored and snapshotted during checkpoints.
@@ -70,7 +69,7 @@ public abstract class AbstractStateBackend implements java.io.Serializable {
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> restoredState,
+			Collection<KeyGroupsStateHandle> restoredState,
 			TaskKvStateRegistry kvStateRegistry
 	) throws Exception;
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/BoundedInputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/BoundedInputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/BoundedInputStream.java
new file mode 100644
index 0000000..aeb0ce8
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/BoundedInputStream.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.FSDataInputStream;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+/**
+ * Wrapper around a FSDataInputStream to limit the maximum read offset.
+ *
+ * Based on the implementation from org.apache.commons.io.input.BoundedInputStream
+ */
+public class BoundedInputStream extends InputStream {
+	private final FSDataInputStream delegate;
+	private long endOffsetExclusive;
+	private long position;
+	private long mark;
+
+	public BoundedInputStream(FSDataInputStream delegate, long endOffsetExclusive) throws IOException {
+		this.position = delegate.getPos();
+		this.mark = -1L;
+		this.endOffsetExclusive = endOffsetExclusive;
+		this.delegate = delegate;
+	}
+
+	public int read() throws IOException {
+		if (endOffsetExclusive >= 0L && position >= endOffsetExclusive) {
+			return -1;
+		} else {
+			int result = delegate.read();
+			++position;
+			return result;
+		}
+	}
+
+	public int read(byte[] b) throws IOException {
+		return read(b, 0, b.length);
+	}
+
+	public int read(byte[] b, int off, int len) throws IOException {
+		if (endOffsetExclusive >= 0L && position >= endOffsetExclusive) {
+			return -1;
+		} else {
+			long maxRead = endOffsetExclusive >= 0L ? Math.min((long) len, endOffsetExclusive - position) : (long) len;
+			int bytesRead = delegate.read(b, off, (int) maxRead);
+			if (bytesRead == -1) {
+				return -1;
+			} else {
+				position += (long) bytesRead;
+				return bytesRead;
+			}
+		}
+	}
+
+	public long skip(long n) throws IOException {
+		long toSkip = endOffsetExclusive >= 0L ? Math.min(n, endOffsetExclusive - position) : n;
+		long skippedBytes = delegate.skip(toSkip);
+		position += skippedBytes;
+		return skippedBytes;
+	}
+
+	public int available() throws IOException {
+		return endOffsetExclusive >= 0L && position >= endOffsetExclusive ? 0 : delegate.available();
+	}
+
+	public String toString() {
+		return delegate.toString();
+	}
+
+	public void close() throws IOException {
+		delegate.close();
+	}
+
+	public synchronized void reset() throws IOException {
+		delegate.reset();
+		position = mark;
+	}
+
+	public synchronized void mark(int readlimit) {
+		delegate.mark(readlimit);
+		mark = position;
+	}
+
+	public long getEndOffsetExclusive() {
+		return endOffsetExclusive;
+	}
+
+	public void setEndOffsetExclusive(long endOffsetExclusive) {
+		this.endOffsetExclusive = endOffsetExclusive;
+	}
+
+	public boolean markSupported() {
+		return delegate.markSupported();
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java
index c6904c0..a807428 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java
@@ -123,4 +123,8 @@ public class ChainedStateHandle<T extends StateObject> implements StateObject {
 	public static <T extends StateObject> ChainedStateHandle<T> wrapSingleHandle(T stateHandleToWrap) {
 		return new ChainedStateHandle<T>(Collections.singletonList(stateHandleToWrap));
 	}
+
+	public static boolean isNullOrEmpty(ChainedStateHandle<?> chainedStateHandle) {
+		return chainedStateHandle == null || chainedStateHandle.isEmpty();
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStateHandles.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStateHandles.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStateHandles.java
deleted file mode 100644
index 9daf963..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStateHandles.java
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state;
-
-import java.io.Serializable;
-import java.util.List;
-
-/**
- * Container state handles that contains all state handles from the different state types of a checkpointed state.
- * TODO This will be changed in the future if we get rid of chained state and instead connect state directly to individual operators in a chain.
- */
-public class CheckpointStateHandles implements Serializable {
-
-	private static final long serialVersionUID = 3252351989995L;
-
-	private final ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles;
-
-	private final ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles;
-
-	private final List<KeyGroupsStateHandle> keyGroupsStateHandle;
-
-	public CheckpointStateHandles(
-			ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles,
-			ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles,
-			List<KeyGroupsStateHandle> keyGroupsStateHandle) {
-
-		this.nonPartitionedStateHandles = nonPartitionedStateHandles;
-		this.partitioneableStateHandles = partitioneableStateHandles;
-		this.keyGroupsStateHandle = keyGroupsStateHandle;
-	}
-
-	public ChainedStateHandle<StreamStateHandle> getNonPartitionedStateHandles() {
-		return nonPartitionedStateHandles;
-	}
-
-	public ChainedStateHandle<OperatorStateHandle> getPartitioneableStateHandles() {
-		return partitioneableStateHandles;
-	}
-
-	public List<KeyGroupsStateHandle> getKeyGroupsStateHandle() {
-		return keyGroupsStateHandle;
-	}
-
-	@Override
-	public boolean equals(Object o) {
-		if (this == o) {
-			return true;
-		}
-		if (!(o instanceof CheckpointStateHandles)) {
-			return false;
-		}
-
-		CheckpointStateHandles that = (CheckpointStateHandles) o;
-
-		if (nonPartitionedStateHandles != null ?
-				!nonPartitionedStateHandles.equals(that.nonPartitionedStateHandles)
-				: that.nonPartitionedStateHandles != null) {
-			return false;
-		}
-
-		if (partitioneableStateHandles != null ?
-				!partitioneableStateHandles.equals(that.partitioneableStateHandles)
-				: that.partitioneableStateHandles != null) {
-			return false;
-		}
-		return keyGroupsStateHandle != null ?
-				keyGroupsStateHandle.equals(that.keyGroupsStateHandle) : that.keyGroupsStateHandle == null;
-
-	}
-
-	@Override
-	public int hashCode() {
-		int result = nonPartitionedStateHandles != null ? nonPartitionedStateHandles.hashCode() : 0;
-		result = 31 * result + (partitioneableStateHandles != null ? partitioneableStateHandles.hashCode() : 0);
-		result = 31 * result + (keyGroupsStateHandle != null ? keyGroupsStateHandle.hashCode() : 0);
-		return result;
-	}
-
-	@Override
-	public String toString() {
-		return "CheckpointStateHandles{" +
-				"nonPartitionedStateHandles=" + nonPartitionedStateHandles +
-				", partitioneableStateHandles=" + partitioneableStateHandles +
-				", keyGroupsStateHandle=" + keyGroupsStateHandle +
-				'}';
-	}
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java
index 26d6192..b5f7dad 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java
@@ -25,6 +25,13 @@ import java.io.IOException;
 import java.util.HashSet;
 import java.util.Set;
 
+/**
+ * This class allows registering instances of {@link Closeable}, which are all closed when this registry is closed.
+ * <p>
+ * Registering to an already closed registry will throw an exception and close the provided {@link Closeable}.
+ * <p>
+ * All methods in this class are thread-safe.
+ */
 public class ClosableRegistry implements Closeable {
 
 	private final Set<Closeable> registeredCloseables;
@@ -35,7 +42,15 @@ public class ClosableRegistry implements Closeable {
 		this.closed = false;
 	}
 
-	public boolean registerClosable(Closeable closeable) {
+	/**
+	 * Registers a {@link Closeable} with the registry. In case the registry is already closed, this method closes the
+	 * passed {@link Closeable} and throws an {@link IOException}.
+	 *
+	 * @param closeable Closeable to register
+	 * @return true if the Closeable was newly added to the registry
+	 * @throws IOException if the registry was already closed
+	 */
+	public boolean registerClosable(Closeable closeable) throws IOException {
 
 		if (null == closeable) {
 			return false;
@@ -43,13 +58,20 @@ public class ClosableRegistry implements Closeable {
 
 		synchronized (getSynchronizationLock()) {
 			if (closed) {
-				throw new IllegalStateException("Cannot register Closable, registry is already closed.");
+				IOUtils.closeQuietly(closeable);
+				throw new IOException("Cannot register Closable, registry is already closed. Closed passed closable.");
 			}
 
 			return registeredCloseables.add(closeable);
 		}
 	}
 
+	/**
+	 * Removes a {@link Closeable} from the registry.
+	 *
+	 * @param closeable instance to remove from the registry.
+	 * @return true, if the instance was actually registered and now removed
+	 */
 	public boolean unregisterClosable(Closeable closeable) {
 
 		if (null == closeable) {
@@ -63,22 +85,24 @@ public class ClosableRegistry implements Closeable {
 
 	@Override
 	public void close() throws IOException {
+		synchronized (getSynchronizationLock()) {
 
-		if (!registeredCloseables.isEmpty()) {
-
-			synchronized (getSynchronizationLock()) {
+			for (Closeable closeable : registeredCloseables) {
+				IOUtils.closeQuietly(closeable);
+			}
 
-				for (Closeable closeable : registeredCloseables) {
-					IOUtils.closeQuietly(closeable);
-				}
+			registeredCloseables.clear();
+			closed = true;
+		}
+	}
 
-				registeredCloseables.clear();
-				closed = true;
-			}
+	public boolean isClosed() {
+		synchronized (getSynchronizationLock()) {
+			return closed;
 		}
 	}
 
 	private Object getSynchronizationLock() {
 		return registeredCloseables;
 	}
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java
new file mode 100644
index 0000000..776f4b8
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.ReducingState;
+import org.apache.flink.api.common.state.ReducingStateDescriptor;
+import org.apache.flink.api.common.state.State;
+import org.apache.flink.api.common.state.StateDescriptor;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.util.Preconditions;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Default implementation of KeyedStateStore that currently forwards state registration to a {@link RuntimeContext}.
+ */
+public class DefaultKeyedStateStore implements KeyedStateStore {
+
+	private final KeyedStateBackend<?> keyedStateBackend;
+	private final ExecutionConfig executionConfig;
+
+	public DefaultKeyedStateStore(KeyedStateBackend<?> keyedStateBackend, ExecutionConfig executionConfig) {
+		this.keyedStateBackend = Preconditions.checkNotNull(keyedStateBackend);
+		this.executionConfig = Preconditions.checkNotNull(executionConfig);
+	}
+
+	@Override
+	public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) {
+		requireNonNull(stateProperties, "The state properties must not be null");
+		try {
+			stateProperties.initializeSerializerUnlessSet(executionConfig);
+			return getPartitionedState(stateProperties);
+		} catch (Exception e) {
+			throw new RuntimeException("Error while getting state", e);
+		}
+	}
+
+	@Override
+	public <T> ListState<T> getListState(ListStateDescriptor<T> stateProperties) {
+		requireNonNull(stateProperties, "The state properties must not be null");
+		try {
+			stateProperties.initializeSerializerUnlessSet(executionConfig);
+			ListState<T> originalState = getPartitionedState(stateProperties);
+			return new UserFacingListState<>(originalState);
+		} catch (Exception e) {
+			throw new RuntimeException("Error while getting state", e);
+		}
+	}
+
+	@Override
+	public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties) {
+		requireNonNull(stateProperties, "The state properties must not be null");
+		try {
+			stateProperties.initializeSerializerUnlessSet(executionConfig);
+			return getPartitionedState(stateProperties);
+		} catch (Exception e) {
+			throw new RuntimeException("Error while getting state", e);
+		}
+	}
+
+	private <S extends State> S getPartitionedState(StateDescriptor<S, ?> stateDescriptor) throws Exception {
+		return keyedStateBackend.getPartitionedState(
+				VoidNamespace.INSTANCE,
+				VoidNamespaceSerializer.INSTANCE,
+				stateDescriptor);
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
index b1ab7e3..2f5d3cb 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -20,7 +20,6 @@ package org.apache.flink.runtime.state;
 
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
@@ -45,6 +44,9 @@ import java.util.concurrent.RunnableFuture;
  */
 public class DefaultOperatorStateBackend implements OperatorStateBackend {
 
+	/** The default namespace for state in cases where no state name is provided */
+	public static final String DEFAULT_OPERATOR_STATE_NAME = "_default_";
+	
 	private final Map<String, PartitionableListState<?>> registeredStates;
 	private final Collection<OperatorStateHandle> restoreSnapshots;
 	private final ClosableRegistry closeStreamOnCancelRegistry;
@@ -72,15 +74,12 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 	public DefaultOperatorStateBackend(ClassLoader userClassLoader) {
 		this(userClassLoader, null);
 	}
-
+	@SuppressWarnings("unchecked")
 	@Override
-	public ListState<Serializable> getSerializableListState(String stateName) throws Exception {
-		return getOperatorState(new ListStateDescriptor<>(stateName, javaSerializer));
+	public <T extends Serializable> ListState<T> getSerializableListState(String stateName) throws Exception {
+		return (ListState<T>) getOperatorState(new ListStateDescriptor<>(stateName, javaSerializer));
 	}
-
-	/**
-	 * @see OperatorStateStore
-	 */
+	
 	@Override
 	public <S> ListState<S> getOperatorState(
 			ListStateDescriptor<S> stateDescriptor) throws IOException {
@@ -102,8 +101,9 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 			// Try to restore previous state if state handles to snapshots are provided
 			if (restoreSnapshots != null) {
 				for (OperatorStateHandle stateHandle : restoreSnapshots) {
-
-					long[] offsets = stateHandle.getStateNameToPartitionOffsets().get(name);
+					//TODO we could be even more gc friendly by removing handles from the collection once the map is empty
+					// search and remove to be gc friendly
+					long[] offsets = stateHandle.getStateNameToPartitionOffsets().remove(name);
 
 					if (offsets != null) {
 
@@ -130,10 +130,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 
 		return partitionableListState;
 	}
-
-	/**
-	 * @see SnapshotProvider
-	 */
+	
 	@Override
 	public RunnableFuture<OperatorStateHandle> snapshot(
 			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
@@ -159,7 +156,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 				writtenStatesMetaData.put(entry.getKey(), partitionOffsets);
 			}
 
-			OperatorStateHandle handle = new OperatorStateHandle(out.closeAndGetHandle(), writtenStatesMetaData);
+			OperatorStateHandle handle = new OperatorStateHandle(writtenStatesMetaData, out.closeAndGetHandle());
 
 			return new DoneFuture<>(handle);
 		} finally {
@@ -170,48 +167,59 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 
 	@Override
 	public void dispose() {
-
+		registeredStates.clear();
 	}
 
 	static final class PartitionableListState<S> implements ListState<S> {
 
-		private final List<S> listState;
+		private final List<S> internalList;
 		private final TypeSerializer<S> partitionStateSerializer;
 
 		public PartitionableListState(TypeSerializer<S> partitionStateSerializer) {
-			this.listState = new ArrayList<>();
+			this.internalList = new ArrayList<>();
 			this.partitionStateSerializer = Preconditions.checkNotNull(partitionStateSerializer);
 		}
 
 		@Override
 		public void clear() {
-			listState.clear();
+			internalList.clear();
 		}
 
 		@Override
 		public Iterable<S> get() {
-			return listState;
+			return internalList;
 		}
 
 		@Override
 		public void add(S value) {
-			listState.add(value);
+			internalList.add(value);
 		}
 
 		public long[] write(FSDataOutputStream out) throws IOException {
 
-			long[] partitionOffsets = new long[listState.size()];
+			long[] partitionOffsets = new long[internalList.size()];
 
 			DataOutputView dov = new DataOutputViewStreamWrapper(out);
 
-			for (int i = 0; i < listState.size(); ++i) {
-				S element = listState.get(i);
+			for (int i = 0; i < internalList.size(); ++i) {
+				S element = internalList.get(i);
 				partitionOffsets[i] = out.getPos();
 				partitionStateSerializer.serialize(element, dov);
 			}
 
 			return partitionOffsets;
 		}
+
+		public List<S> getInternalList() {
+			return internalList;
+		}
+
+		@Override
+		public String toString() {
+			return "PartitionableListState{" +
+					"listState=" + internalList +
+					'}';
+		}
 	}
 
 	@Override
@@ -223,5 +231,6 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 	public void close() throws IOException {
 		closeStreamOnCancelRegistry.close();
 	}
+
 }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionInitializationContext.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionInitializationContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionInitializationContext.java
new file mode 100644
index 0000000..ff632ef
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionInitializationContext.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * This interface provides a context in which user functions can initialize by registering to managed state (i.e. state
+ * that is managed by state backends).
+ *
+ * <p>
+ * Operator state is available to all functions, while keyed state is only available for functions after keyBy.
+ *
+ * <p>
+ * For the purpose of initialization, the context signals if the state is empty or was restored from a previous
+ * execution.
+ *
+ */
+@PublicEvolving
+public interface FunctionInitializationContext extends ManagedInitializationContext {
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionSnapshotContext.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionSnapshotContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionSnapshotContext.java
new file mode 100644
index 0000000..571b881
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/FunctionSnapshotContext.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * This interface provides a context in which user functions that use managed state (i.e. state that is managed by state
+ * backends) can participate in a snapshot. As snapshots of the backends themselves are taken by the system, this
+ * interface mainly provides meta information about the checkpoint.
+ */
+@PublicEvolving
+public interface FunctionSnapshotContext extends ManagedSnapshotContext {
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRange.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRange.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRange.java
index 3a9d3d0..32151db 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRange.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRange.java
@@ -27,10 +27,12 @@ import java.util.Iterator;
  * This class defines a range of key-group indexes. Key-groups are the granularity into which the keyspace of a job
  * is partitioned for keyed state-handling in state backends. The boundaries of the range are inclusive.
  */
-public class KeyGroupRange implements Iterable<Integer>, Serializable {
+public class KeyGroupRange implements KeyGroupsList, Serializable {
+
+	private static final long serialVersionUID = 4869121477592070607L;
 
 	/** The empty key-group */
-	public static final KeyGroupRange EMPTY_KEY_GROUP = new KeyGroupRange();
+	public static final KeyGroupRange EMPTY_KEY_GROUP_RANGE = new KeyGroupRange();
 
 	private final int startKeyGroup;
 	private final int endKeyGroup;
@@ -64,6 +66,7 @@ public class KeyGroupRange implements Iterable<Integer>, Serializable {
 	 * @param keyGroup Key-group to check for inclusion.
 	 * @return True, only if the key-group is in the range.
 	 */
+	@Override
 	public boolean contains(int keyGroup) {
 		return keyGroup >= startKeyGroup && keyGroup <= endKeyGroup;
 	}
@@ -77,13 +80,14 @@ public class KeyGroupRange implements Iterable<Integer>, Serializable {
 	public KeyGroupRange getIntersection(KeyGroupRange other) {
 		int start = Math.max(startKeyGroup, other.startKeyGroup);
 		int end = Math.min(endKeyGroup, other.endKeyGroup);
-		return start <= end ? new KeyGroupRange(start, end) : EMPTY_KEY_GROUP;
+		return start <= end ? new KeyGroupRange(start, end) : EMPTY_KEY_GROUP_RANGE;
 	}
 
 	/**
 	 *
 	 * @return The number of key-groups in the range
 	 */
+	@Override
 	public int getNumberOfKeyGroups() {
 		return 1 + endKeyGroup - startKeyGroup;
 	}
@@ -105,6 +109,14 @@ public class KeyGroupRange implements Iterable<Integer>, Serializable {
 	}
 
 	@Override
+	public int getKeyGroupId(int idx) {
+		if (idx < 0 || idx >= getNumberOfKeyGroups()) { // valid indexes are 0 .. getNumberOfKeyGroups() - 1
+			throw new IndexOutOfBoundsException("Key group index out of bounds: " + idx);
+		}
+		return startKeyGroup + idx;
+	}
+
+	@Override
 	public boolean equals(Object o) {
 		if (this == o) {
 			return true;
@@ -172,7 +184,6 @@ public class KeyGroupRange implements Iterable<Integer>, Serializable {
 	 * @return the key-group from start to end or an empty key-group range.
 	 */
 	public static KeyGroupRange of(int startKeyGroup, int endKeyGroup) {
-		return startKeyGroup <= endKeyGroup ? new KeyGroupRange(startKeyGroup, endKeyGroup) : EMPTY_KEY_GROUP;
+		return startKeyGroup <= endKeyGroup ? new KeyGroupRange(startKeyGroup, endKeyGroup) : EMPTY_KEY_GROUP_RANGE;
 	}
-
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java
index 8e7207e..d425278 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java
@@ -137,7 +137,11 @@ public class KeyGroupRangeOffsets implements Iterable<Tuple2<Integer, Long>> , S
 	}
 
 	private int computeKeyGroupIndex(int keyGroup) {
-		return keyGroup - keyGroupRange.getStartKeyGroup();
+		int idx = keyGroup - keyGroupRange.getStartKeyGroup();
+		if (idx < 0 || idx >= offsets.length) {
+			throw new IllegalArgumentException("Key group " + keyGroup + " is not in " + keyGroupRange + ".");
+		}
+		return idx;
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupStatePartitionStreamProvider.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupStatePartitionStreamProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupStatePartitionStreamProvider.java
new file mode 100644
index 0000000..2a91f0f
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupStatePartitionStreamProvider.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+/**
+ * This class provides access to an input stream that contains state data for one key group and the key group id.
+ */
+@PublicEvolving
+public class KeyGroupStatePartitionStreamProvider extends StatePartitionStreamProvider {
+
+	/** Key group that corresponds to the data in the provided stream */
+	private final int keyGroupId;
+
+	public KeyGroupStatePartitionStreamProvider(IOException creationException, int keyGroupId) {
+		super(creationException);
+		this.keyGroupId = keyGroupId;
+	}
+
+	public KeyGroupStatePartitionStreamProvider(InputStream stream, int keyGroupId) {
+		super(stream);
+		this.keyGroupId = keyGroupId;
+	}
+
+	/**
+	 * Returns the key group that corresponds to the data in the provided stream.
+	 */
+	public int getKeyGroupId() {
+		return keyGroupId;
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsList.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsList.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsList.java
new file mode 100644
index 0000000..928ebf3
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsList.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+/**
+ * This interface offers ordered random read access to multiple key group ids.
+ */
+public interface KeyGroupsList extends Iterable<Integer> {
+
+	/**
+	 * Returns the number of key group ids in the list.
+	 */
+	int getNumberOfKeyGroups();
+
+	/**
+	 * Returns the key group id at index {@code idx}, which must be in [0, {@link #getNumberOfKeyGroups()}).
+	 *
+	 * @param idx the index into the list
+	 * @return key group id at the given index
+	 */
+	int getKeyGroupId(int idx);
+
+	/**
+	 * Returns true, if the given key group id is contained in the list, otherwise false.
+	 */
+	boolean contains(int keyGroupId);
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
index 7293a84..03f584e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
@@ -54,9 +54,9 @@ public interface KeyedStateBackend<K> {
 	int getNumberOfKeyGroups();
 
 	/**
-	 * Returns the key group range for this backend.
+	 * Returns the key groups for this backend.
 	 */
-	KeyGroupRange getKeyGroupRange();
+	KeyGroupsList getKeyGroupRange();
 
 	/**
 	 * {@link TypeSerializer} for the state backend key type.

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStream.java
new file mode 100644
index 0000000..2121574
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateCheckpointOutputStream.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * Checkpoint output stream that allows to write raw keyed state in a partitioned way, split into key-groups.
+ */
+@PublicEvolving
+public final class KeyedStateCheckpointOutputStream extends NonClosingCheckpointOutputStream<KeyGroupsStateHandle> {
+
+	public static final long NO_OFFSET_SET = -1L;
+	public static final int NO_CURRENT_KEY_GROUP = -1;
+
+	private int currentKeyGroup;
+	private final KeyGroupRangeOffsets keyGroupRangeOffsets;
+
+	public KeyedStateCheckpointOutputStream(
+			CheckpointStreamFactory.CheckpointStateOutputStream delegate, KeyGroupRange keyGroupRange) {
+
+		super(delegate);
+		Preconditions.checkNotNull(keyGroupRange);
+		Preconditions.checkArgument(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP_RANGE);
+
+		this.currentKeyGroup = NO_CURRENT_KEY_GROUP;
+		long[] emptyOffsets = new long[keyGroupRange.getNumberOfKeyGroups()];
+		// mark offsets as currently not set
+		Arrays.fill(emptyOffsets, NO_OFFSET_SET);
+		this.keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, emptyOffsets);
+	}
+
+	@Override
+	public void close() throws IOException {
+		// users should not be able to actually close the stream, it is closed by the system.
+		// TODO if we want to support async writes, this call could trigger a callback to the snapshot context that a handle is available.
+	}
+
+	/**
+	 * Returns a list of all key-groups which can be written to this stream.
+	 */
+	public KeyGroupsList getKeyGroupList() {
+		return keyGroupRangeOffsets.getKeyGroupRange();
+	}
+
+	/**
+	 * User code can call this method to signal that it begins to write a new key group with the given key group id.
+	 * This id must be within the {@link KeyGroupsList} provided by the stream. Each key-group can only be started once
+	 * and is considered final/immutable as soon as this method is called again.
+	 */
+	public void startNewKeyGroup(int keyGroupId) throws IOException {
+		if (isKeyGroupAlreadyStarted(keyGroupId)) {
+			throw new IOException("Key group " + keyGroupId + " already registered!");
+		}
+		keyGroupRangeOffsets.setKeyGroupOffset(keyGroupId, delegate.getPos());
+		currentKeyGroup = keyGroupId;
+	}
+
+	/**
+	 * Returns true, if the key group with the given id was already started. The key group might not yet be finished,
+	 * if its id is equal to the return value of {@link #getCurrentKeyGroup()}.
+	 */
+	public boolean isKeyGroupAlreadyStarted(int keyGroupId) {
+		return NO_OFFSET_SET != keyGroupRangeOffsets.getKeyGroupOffset(keyGroupId);
+	}
+
+	/**
+	 * Returns true if the key group is already completely written and immutable. It was started and since then another
+	 * key group has been started.
+	 */
+	public boolean isKeyGroupAlreadyFinished(int keyGroupId) {
+		return isKeyGroupAlreadyStarted(keyGroupId) && keyGroupId != getCurrentKeyGroup();
+	}
+
+	/**
+	 * Returns the key group that is currently being written. The key group was started but not yet finished, i.e. data
+	 * can still be added. If no key group was started, this returns {@link #NO_CURRENT_KEY_GROUP}.
+	 */
+	public int getCurrentKeyGroup() {
+		return currentKeyGroup;
+	}
+
+	@Override
+	KeyGroupsStateHandle closeAndGetHandle() throws IOException {
+		StreamStateHandle streamStateHandle = delegate.closeAndGetHandle();
+		return streamStateHandle != null ? new KeyGroupsStateHandle(keyGroupRangeOffsets, streamStateHandle) : null;
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedInitializationContext.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedInitializationContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedInitializationContext.java
new file mode 100644
index 0000000..abc528b
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedInitializationContext.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.api.common.state.OperatorStateStore;
+
+/**
+ * This interface provides a context in which operators can initialize by registering to managed state (i.e. state that
+ * is managed by state backends).
+ *
+ * <p>
+ * Operator state is available to all operators, while keyed state is only available for operators after keyBy.
+ *
+ * <p>
+ * For the purpose of initialization, the context signals if the state is empty (new operator) or was restored from
+ * a previous execution of this operator.
+ *
+ */
+public interface ManagedInitializationContext {
+
+	/**
+	 * Returns true, if some managed state was restored from the snapshot of a previous execution.
+	 */
+	boolean isRestored();
+
+	/**
+	 * Returns an interface that allows for registering operator state with the backend.
+	 */
+	OperatorStateStore getManagedOperatorStateStore();
+
+	/**
+	 * Returns an interface that allows for registering keyed state with the backend.
+	 */
+	KeyedStateStore getManagedKeyedStateStore();
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedSnapshotContext.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedSnapshotContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedSnapshotContext.java
new file mode 100644
index 0000000..14156a6
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ManagedSnapshotContext.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * This interface provides a context in which operators that use managed state (i.e. state that is managed by state
+ * backends) can perform a snapshot. As snapshots of the backends themselves are taken by the system, this interface
+ * mainly provides meta information about the checkpoint.
+ */
+@PublicEvolving
+public interface ManagedSnapshotContext {
+
+	/**
+	 * Returns the Id of the checkpoint for which the snapshot is taken.
+	 */
+	long getCheckpointId();
+
+	/**
+	 * Returns the timestamp of the checkpoint for which the snapshot is taken.
+	 */
+	long getCheckpointTimestamp();
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/NonClosingCheckpointOutputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/NonClosingCheckpointOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/NonClosingCheckpointOutputStream.java
new file mode 100644
index 0000000..f7f4bdb
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/NonClosingCheckpointOutputStream.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Abstract class to implement custom checkpoint output streams which should not be closable for user code.
+ * 
+ * @param <T> type of the returned state handle.
+ */
+public abstract class NonClosingCheckpointOutputStream<T extends StreamStateHandle> extends OutputStream {
+
+	protected final CheckpointStreamFactory.CheckpointStateOutputStream delegate;
+
+	public NonClosingCheckpointOutputStream(
+			CheckpointStreamFactory.CheckpointStateOutputStream delegate) {
+		this.delegate = Preconditions.checkNotNull(delegate);
+	}
+
+	@Override
+	public void flush() throws IOException {
+		delegate.flush();
+	}
+
+	@Override
+	public void write(int b) throws IOException {
+		delegate.write(b);
+	}
+
+	@Override
+	public void write(byte[] b) throws IOException {
+		delegate.write(b);
+	}
+
+	@Override
+	public void write(byte[] b, int off, int len) throws IOException {
+		delegate.write(b, off, len);
+	}
+
+	@Override
+	public void close() throws IOException {
+		// users should not be able to actually close the stream, it is closed by the system.
+		// TODO if we want to support async writes, this call could trigger a callback to the snapshot context that a handle is available.
+	}
+
+
+	/**
+	 * This method should not be public so as to not expose internals to user code.
+	 */
+	CheckpointStreamFactory.CheckpointStateOutputStream getDelegate() {
+		return delegate;
+	}
+
+	/**
+	 * This method should not be public so as to not expose internals to user code. Closes the underlying stream and
+	 * returns a state handle.
+	 */
+	abstract T closeAndGetHandle() throws IOException;
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java
index 83e6369..aee5226 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java
@@ -24,10 +24,10 @@ import java.io.Closeable;
 
 /**
  * Interface that combines both, the user facing {@link OperatorStateStore} interface and the system interface
- * {@link SnapshotProvider}
+ * {@link Snapshotable}
  *
  */
-public interface OperatorStateBackend extends OperatorStateStore, SnapshotProvider<OperatorStateHandle>, Closeable {
+public interface OperatorStateBackend extends OperatorStateStore, Snapshotable<OperatorStateHandle>, Closeable {
 
 	/**
 	 * Disposes the backend and releases all resources.

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java
new file mode 100644
index 0000000..eaa9fd9
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.runtime.util.LongArrayList;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Checkpoint output stream that allows to write raw operator state in a partitioned way.
+ */
+@PublicEvolving
+public final class OperatorStateCheckpointOutputStream
+		extends NonClosingCheckpointOutputStream<OperatorStateHandle> {
+
+	private LongArrayList partitionOffsets;
+	private final long initialPosition;
+
+	public OperatorStateCheckpointOutputStream(
+			CheckpointStreamFactory.CheckpointStateOutputStream delegate) throws IOException {
+
+		super(delegate);
+		this.partitionOffsets = new LongArrayList(16);
+		this.initialPosition = delegate.getPos();
+	}
+
+	/**
+	 * User code can call this method to signal that it begins to write a new partition of operator state.
+	 * Each previously written partition is considered final/immutable as soon as this method is called again.
+	 */
+	public void startNewPartition() throws IOException {
+		partitionOffsets.add(delegate.getPos());
+	}
+
+	/**
+	 * Closes the delegate and wraps its handle (null is passed through); not public so as to not expose internals to user code.
+	 */
+	@Override
+	OperatorStateHandle closeAndGetHandle() throws IOException {
+		StreamStateHandle streamStateHandle = delegate.closeAndGetHandle();
+
+		if (null == streamStateHandle) {
+			return null;
+		}
+		// NOTE(review): getPos() is queried after closeAndGetHandle() above -- confirm the delegate still reports a valid position.
+		if (partitionOffsets.isEmpty() && delegate.getPos() > initialPosition) {
+			startNewPartition();
+		}
+		// all recorded partition offsets are stored under the default operator state name
+		Map<String, long[]> offsetsMap = new HashMap<>(1);
+		offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, partitionOffsets.toArray());
+
+		return new OperatorStateHandle(offsetsMap, streamStateHandle);
+	}
+
+	public int getNumberOfPartitions() {
+		return partitionOffsets.size();
+	}
+}
\ No newline at end of file


[2/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State

Posted by al...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index 0e513fa..95115d6 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -22,6 +22,7 @@ import io.netty.util.internal.ConcurrentSet;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.java.functions.KeySelector;
@@ -32,17 +33,21 @@ import org.apache.flink.runtime.client.JobExecutionException;
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.messages.JobManagerMessages;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
 import org.apache.flink.runtime.testingUtils.TestingCluster;
 import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
 import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 import org.junit.AfterClass;
@@ -71,7 +76,6 @@ import static org.junit.Assert.fail;
 
 /**
  * TODO : parameterize to test all different state backends!
- * TODO: reactivate ignored test as soon as savepoints work with deactivated checkpoints.
  */
 public class RescalingITCase extends TestLogger {
 
@@ -79,6 +83,10 @@ public class RescalingITCase extends TestLogger {
 	private static final int slotsPerTaskManager = 2;
 	private static final int numSlots = numTaskManagers * slotsPerTaskManager;
 
+	enum OperatorCheckpointMethod {
+		NON_PARTITIONED, CHECKPOINTED_FUNCTION, LIST_CHECKPOINTED
+	}
+
 	private static TestingCluster cluster;
 
 	@ClassRule
@@ -242,7 +250,7 @@ public class RescalingITCase extends TestLogger {
 		try {
 			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
 
-			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, false);
+			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
 
 			jobID = jobGraph.getJobID();
 
@@ -280,7 +288,7 @@ public class RescalingITCase extends TestLogger {
 			// job successfully removed
 			jobID = null;
 
-			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, false);
+			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
 
 			scaledJobGraph.setSavepointPath(savepointPath);
 
@@ -433,12 +441,22 @@ public class RescalingITCase extends TestLogger {
 
 	@Test
 	public void testSavepointRescalingInPartitionedOperatorState() throws Exception {
-		testSavepointRescalingPartitionedOperatorState(false);
+		testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
 	}
 
 	@Test
 	public void testSavepointRescalingOutPartitionedOperatorState() throws Exception {
-		testSavepointRescalingPartitionedOperatorState(true);
+		testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
+	}
+
+	@Test
+	public void testSavepointRescalingInPartitionedOperatorStateList() throws Exception {
+		testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.LIST_CHECKPOINTED);
+	}
+
+	@Test
+	public void testSavepointRescalingOutPartitionedOperatorStateList() throws Exception {
+		testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.LIST_CHECKPOINTED);
 	}
 
 
@@ -446,7 +464,7 @@ public class RescalingITCase extends TestLogger {
 	 * Tests rescaling of partitioned operator state. More specific, we test the mechanism with {@link ListCheckpointed}
 	 * as it subsumes {@link org.apache.flink.streaming.api.checkpoint.CheckpointedFunction}.
 	 */
-	public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut) throws Exception {
+	public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut, OperatorCheckpointMethod checkpointMethod) throws Exception {
 		final int parallelism = scaleOut ? numSlots : numSlots / 2;
 		final int parallelism2 = scaleOut ? numSlots / 2 : numSlots;
 		final int maxParallelism = 13;
@@ -459,13 +477,18 @@ public class RescalingITCase extends TestLogger {
 
 		int counterSize = Math.max(parallelism, parallelism2);
 
-		PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
-		PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize];
+		if(checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) {
+			PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
+			PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize];
+		} else {
+			PartitionedStateSourceListCheckpointed.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
+			PartitionedStateSourceListCheckpointed.CHECK_CORRECT_RESTORE = new int[counterSize];
+		}
 
 		try {
 			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
 
-			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, true);
+			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, checkpointMethod);
 
 			jobID = jobGraph.getJobID();
 
@@ -504,7 +527,7 @@ public class RescalingITCase extends TestLogger {
 			// job successfully removed
 			jobID = null;
 
-			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, true);
+			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, checkpointMethod);
 
 			scaledJobGraph.setSavepointPath(savepointPath);
 
@@ -515,12 +538,22 @@ public class RescalingITCase extends TestLogger {
 			int sumExp = 0;
 			int sumAct = 0;
 
-			for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
-				sumExp += c;
-			}
+			if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) {
+				for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
+					sumExp += c;
+				}
+
+				for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
+					sumAct += c;
+				}
+			} else {
+				for (int c : PartitionedStateSourceListCheckpointed.CHECK_CORRECT_SNAPSHOT) {
+					sumExp += c;
+				}
 
-			for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
-				sumAct += c;
+				for (int c : PartitionedStateSourceListCheckpointed.CHECK_CORRECT_RESTORE) {
+					sumAct += c;
+				}
 			}
 
 			assertEquals(sumExp, sumAct);
@@ -543,7 +576,7 @@ public class RescalingITCase extends TestLogger {
 	//------------------------------------------------------------------------------------------------------------------
 
 	private static JobGraph createJobGraphWithOperatorState(
-			int parallelism, int maxParallelism, boolean partitionedOperatorState) {
+			int parallelism, int maxParallelism, OperatorCheckpointMethod checkpointMethod) {
 
 		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
 		env.setParallelism(parallelism);
@@ -553,8 +586,23 @@ public class RescalingITCase extends TestLogger {
 
 		StateSourceBase.workStartedLatch = new CountDownLatch(1);
 
-		DataStream<Integer> input = env.addSource(
-				partitionedOperatorState ? new PartitionedStateSource() : new NonPartitionedStateSource());
+		SourceFunction<Integer> src;
+
+		switch (checkpointMethod) {
+			case CHECKPOINTED_FUNCTION:
+				src = new PartitionedStateSource();
+				break;
+			case LIST_CHECKPOINTED:
+				src = new PartitionedStateSourceListCheckpointed();
+				break;
+			case NON_PARTITIONED:
+				src = new NonPartitionedStateSource();
+				break;
+			default:
+				throw new IllegalArgumentException();
+		}
+
+		DataStream<Integer> input = env.addSource(src);
 
 		input.addSink(new DiscardingSink<Integer>());
 
@@ -711,7 +759,7 @@ public class RescalingITCase extends TestLogger {
 		}
 	}
 
-	private static class SubtaskIndexFlatMapper extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>> {
+	private static class SubtaskIndexFlatMapper extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>> implements CheckpointedFunction {
 
 		private static final long serialVersionUID = 5273172591283191348L;
 
@@ -727,12 +775,6 @@ public class RescalingITCase extends TestLogger {
 		}
 
 		@Override
-		public void open(Configuration configuration) {
-			counter = getRuntimeContext().getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
-			sum = getRuntimeContext().getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
-		}
-
-		@Override
 		public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out) throws Exception {
 
 			int count = counter.value() + 1;
@@ -746,6 +788,17 @@ public class RescalingITCase extends TestLogger {
 				workCompletedLatch.countDown();
 			}
 		}
+
+		@Override
+		public void snapshotState(FunctionSnapshotContext context) throws Exception {
+			//all managed, nothing to do.
+		}
+
+		@Override
+		public void initializeState(FunctionInitializationContext context) throws Exception {
+			counter = context.getManagedKeyedStateStore().getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
+			sum = context.getManagedKeyedStateStore().getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
+		}
 	}
 
 	private static class CollectionSink<IN> implements SinkFunction<IN> {
@@ -817,9 +870,9 @@ public class RescalingITCase extends TestLogger {
 		}
 	}
 
-	private static class PartitionedStateSource extends StateSourceBase implements ListCheckpointed<Integer> {
+	private static class PartitionedStateSourceListCheckpointed extends StateSourceBase implements ListCheckpointed<Integer> {
 
-		private static final long serialVersionUID = -359715965103593462L;
+		private static final long serialVersionUID = -4357864582992546L;
 		private static final int NUM_PARTITIONS = 7;
 
 		private static int[] CHECK_CORRECT_SNAPSHOT;
@@ -853,4 +906,46 @@ public class RescalingITCase extends TestLogger {
 			CHECK_CORRECT_RESTORE[getRuntimeContext().getIndexOfThisSubtask()] = counter;
 		}
 	}
+
+	private static class PartitionedStateSource extends StateSourceBase implements CheckpointedFunction {
+		// CheckpointedFunction variant: snapshots the counter as NUM_PARTITIONS entries of an operator list state.
+		private static final long serialVersionUID = -359715965103593462L;
+		private static final int NUM_PARTITIONS = 7;
+
+		private ListState<Integer> counterPartitions;
+
+		private static int[] CHECK_CORRECT_SNAPSHOT;
+		private static int[] CHECK_CORRECT_RESTORE;
+
+		/** Records the counter for this subtask and writes it, split into NUM_PARTITIONS parts, to the list state. */
+		@Override
+		public void snapshotState(FunctionSnapshotContext context) throws Exception {
+			counterPartitions.clear(); // reset first: add() appends, so parts of an earlier snapshot would be summed again on restore
+			CHECK_CORRECT_SNAPSHOT[getRuntimeContext().getIndexOfThisSubtask()] = counter;
+			// the parts sum up to counter: each gets counter / NUM_PARTITIONS, the first (counter % NUM_PARTITIONS) get one more
+			int div = counter / NUM_PARTITIONS;
+			int mod = counter % NUM_PARTITIONS;
+
+			for (int i = 0; i < NUM_PARTITIONS; ++i) {
+				int partitionValue = div;
+				if (mod > 0) {
+					--mod;
+					++partitionValue;
+				}
+				counterPartitions.add(partitionValue);
+			}
+		}
+		/** Obtains the list state and, if the context reports a restore, adds all of its entries to the counter. */
+		@Override
+		public void initializeState(FunctionInitializationContext context) throws Exception {
+			this.counterPartitions =
+					context.getManagedOperatorStateStore().getSerializableListState("counter_partitions");
+			if (context.isRestored()) {
+				for (int v : counterPartitions.get()) {
+					counter += v;
+				}
+				CHECK_CORRECT_RESTORE[getRuntimeContext().getIndexOfThisSubtask()] = counter;
+			}
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
index 92e1f41..fc48719 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
@@ -338,7 +338,8 @@ public class SavepointITCase extends TestLogger {
 
 					assertNotNull(subtaskState);
 					errMsg = "Initial operator state mismatch.";
-					assertEquals(errMsg, subtaskState.getChainedStateHandle(), tdd.getOperatorState());
+					assertEquals(errMsg, subtaskState.getLegacyOperatorState(),
+							tdd.getTaskStateHandles().getLegacyOperatorState());
 				}
 			}
 
@@ -364,7 +365,7 @@ public class SavepointITCase extends TestLogger {
 
 			for (TaskState stateForTaskGroup : savepoint.getTaskStates()) {
 				for (SubtaskState subtaskState : stateForTaskGroup.getStates()) {
-					ChainedStateHandle<StreamStateHandle> streamTaskState = subtaskState.getChainedStateHandle();
+					ChainedStateHandle<StreamStateHandle> streamTaskState = subtaskState.getLegacyOperatorState();
 
 					for (int i = 0; i < streamTaskState.getLength(); i++) {
 						if (streamTaskState.get(i) != null) {

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
index 2a635ab..963d18a 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
@@ -38,7 +38,7 @@ import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase;
 import org.junit.Test;
 
 import java.io.IOException;
-import java.util.List;
+import java.util.Collection;
 
 import static org.junit.Assert.fail;
 
@@ -119,7 +119,7 @@ public class StateBackendITCase extends StreamingMultipleProgramsTestBase {
 				TypeSerializer<K> keySerializer,
 				int numberOfKeyGroups,
 				KeyGroupRange keyGroupRange,
-				List<KeyGroupsStateHandle> restoredState,
+				Collection<KeyGroupsStateHandle> restoredState,
 				TaskKvStateRegistry kvStateRegistry) throws Exception {
 			throw new SuccessException();
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java
----------------------------------------------------------------------
diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java b/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java
index 88708b6..7ce040b 100644
--- a/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java
+++ b/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java
@@ -31,12 +31,12 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.GlobalConfiguration;
 import org.apache.flink.configuration.HighAvailabilityOptions;
 import org.apache.flink.configuration.IllegalConfigurationException;
+import org.apache.flink.runtime.clusterframework.ApplicationStatus;
+import org.apache.flink.runtime.clusterframework.messages.GetClusterStatusResponse;
 import org.apache.flink.runtime.security.SecurityContext;
 import org.apache.flink.yarn.AbstractYarnClusterDescriptor;
-import org.apache.flink.yarn.YarnClusterDescriptor;
 import org.apache.flink.yarn.YarnClusterClient;
-import org.apache.flink.runtime.clusterframework.ApplicationStatus;
-import org.apache.flink.runtime.clusterframework.messages.GetClusterStatusResponse;
+import org.apache.flink.yarn.YarnClusterDescriptor;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.yarn.util.ConverterUtils;
 import org.slf4j.Logger;


[4/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State

Posted by al...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
index d2d7fca..5e9bacc 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
@@ -20,6 +20,8 @@ package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.flink.api.java.tuple.Tuple1;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
 import org.apache.flink.streaming.connectors.kafka.testutils.MockRuntimeContext;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchema;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchemaWrapper;
@@ -112,7 +114,7 @@ public class AtLeastOnceProducerTest {
 		Thread threadB = new Thread(confirmer);
 		threadB.start();
 		// this should block:
-		producer.prepareSnapshot(0, 0);
+		producer.snapshotState(new StateSnapshotContextSynchronousImpl(0, 0));
 		synchronized (threadA) {
 			threadA.notifyAll(); // just in case, to let the test fail faster
 		}
@@ -148,9 +150,9 @@ public class AtLeastOnceProducerTest {
 		}
 
 		@Override
-		public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
+		public void snapshotState(FunctionSnapshotContext ctx) throws Exception {
 			// call the actual snapshot state
-			super.prepareSnapshot(checkpointId, timestamp);
+			super.snapshotState(ctx);
 			// notify test that snapshotting has been done
 			snapshottingFinished.set(true);
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
index 97220c2..9b7eabf 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
@@ -19,19 +19,26 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.commons.collections.map.LinkedMap;
+import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
 import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
 import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
 import org.apache.flink.util.SerializedValue;
+import org.junit.Assert;
 import org.junit.Test;
 import org.mockito.Matchers;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 import java.io.Serializable;
 import java.lang.reflect.Field;
@@ -47,6 +54,8 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -100,7 +109,7 @@ public class FlinkKafkaConsumerBaseTest {
 		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
 		when(operatorStateStore.getOperatorState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
 
-		consumer.prepareSnapshot(17L, 17L);
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(1, 1));
 
 		assertFalse(listState.get().iterator().hasNext());
 		consumer.notifyCheckpointComplete(66L);
@@ -113,24 +122,30 @@ public class FlinkKafkaConsumerBaseTest {
 	public void checkRestoredCheckpointWhenFetcherNotReady() throws Exception {
 		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
 
-		TestingListState<Serializable> expectedState = new TestingListState<>();
-		expectedState.add(Tuple2.of(new KafkaTopicPartition("abc", 13), 16768L));
-		expectedState.add(Tuple2.of(new KafkaTopicPartition("def", 7), 987654321L));
-
 		TestingListState<Serializable> listState = new TestingListState<>();
+		listState.add(Tuple2.of(new KafkaTopicPartition("abc", 13), 16768L));
+		listState.add(Tuple2.of(new KafkaTopicPartition("def", 7), 987654321L));
 
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new LinkedMap(), true);
 
-		when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(expectedState);
-		consumer.initializeState(operatorStateStore);
-
 		when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(listState);
 
-		consumer.prepareSnapshot(17L, 17L);
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+
+		when(initializationContext.getManagedOperatorStateStore()).thenReturn(operatorStateStore);
+		when(initializationContext.isRestored()).thenReturn(true);
+
+		consumer.initializeState(initializationContext);
+
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(17, 17));
+
+		// Ensure that the list was cleared and refilled. While this is an implementation detail, we use it here
+		// to verify that snapshotState() actually did something.
+		Assert.assertTrue(listState.isClearCalled());
 
 		Set<Serializable> expected = new HashSet<>();
 
-		for (Serializable serializable : expectedState.get()) {
+		for (Serializable serializable : listState.get()) {
 			expected.add(serializable);
 		}
 
@@ -155,8 +170,14 @@ public class FlinkKafkaConsumerBaseTest {
 		TestingListState<Serializable> listState = new TestingListState<>();
 		when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(listState);
 
-		consumer.initializeState(operatorStateStore);
-		consumer.prepareSnapshot(17L, 17L);
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+
+		when(initializationContext.getManagedOperatorStateStore()).thenReturn(operatorStateStore);
+		when(initializationContext.isRestored()).thenReturn(false);
+
+		consumer.initializeState(initializationContext);
+
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(17, 17));
 
 		assertFalse(listState.get().iterator().hasNext());
 	}
@@ -165,6 +186,28 @@ public class FlinkKafkaConsumerBaseTest {
 	 * Tests that on snapshots, states and offsets to commit to Kafka are correct
 	 */
 	@Test
+	public void checkUseFetcherWhenNoCheckpoint() throws Exception {
+
+		FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new LinkedMap(), true);
+		List<KafkaTopicPartition> partitionList = new ArrayList<>(1);
+		partitionList.add(new KafkaTopicPartition("test", 0));
+		consumer.setSubscribedPartitions(partitionList);
+
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		TestingListState<Serializable> listState = new TestingListState<>();
+		when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(listState);
+
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+
+		when(initializationContext.getManagedOperatorStateStore()).thenReturn(operatorStateStore);
+
+		// make the context signal that there is no restored state, then validate that no offsets are restored
+		when(initializationContext.isRestored()).thenReturn(false);
+		consumer.initializeState(initializationContext);
+		consumer.run(mock(SourceFunction.SourceContext.class));
+	}
+
+	@Test
 	@SuppressWarnings("unchecked")
 	public void testSnapshotState() throws Exception {
 
@@ -196,22 +239,23 @@ public class FlinkKafkaConsumerBaseTest {
 
 		OperatorStateStore backend = mock(OperatorStateStore.class);
 
-		TestingListState<Serializable> init = new TestingListState<>();
-		TestingListState<Serializable> listState1 = new TestingListState<>();
-		TestingListState<Serializable> listState2 = new TestingListState<>();
-		TestingListState<Serializable> listState3 = new TestingListState<>();
+		TestingListState<Serializable> listState = new TestingListState<>();
+
+		when(backend.getSerializableListState(Matchers.any(String.class))).thenReturn(listState);
+
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
 
-		when(backend.getSerializableListState(Matchers.any(String.class))).
-				thenReturn(init, listState1, listState2, listState3);
+		when(initializationContext.getManagedOperatorStateStore()).thenReturn(backend);
+		when(initializationContext.isRestored()).thenReturn(false, true, true, true);
 
-		consumer.initializeState(backend);
+		consumer.initializeState(initializationContext);
 
 		// checkpoint 1
-		consumer.prepareSnapshot(138L, 138L);
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(138, 138));
 
 		HashMap<KafkaTopicPartition, Long> snapshot1 = new HashMap<>();
 
-		for (Serializable serializable : listState1.get()) {
+		for (Serializable serializable : listState.get()) {
 			Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 = (Tuple2<KafkaTopicPartition, Long>) serializable;
 			snapshot1.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
 		}
@@ -221,11 +265,11 @@ public class FlinkKafkaConsumerBaseTest {
 		assertEquals(state1, pendingOffsetsToCommit.get(138L));
 
 		// checkpoint 2
-		consumer.prepareSnapshot(140L, 140L);
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(140, 140));
 
 		HashMap<KafkaTopicPartition, Long> snapshot2 = new HashMap<>();
 
-		for (Serializable serializable : listState2.get()) {
+		for (Serializable serializable : listState.get()) {
 			Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 = (Tuple2<KafkaTopicPartition, Long>) serializable;
 			snapshot2.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
 		}
@@ -240,11 +284,11 @@ public class FlinkKafkaConsumerBaseTest {
 		assertTrue(pendingOffsetsToCommit.containsKey(140L));
 
 		// checkpoint 3
-		consumer.prepareSnapshot(141L, 141L);
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(141, 141));
 
 		HashMap<KafkaTopicPartition, Long> snapshot3 = new HashMap<>();
 
-		for (Serializable serializable : listState3.get()) {
+		for (Serializable serializable : listState.get()) {
 			Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 = (Tuple2<KafkaTopicPartition, Long>) serializable;
 			snapshot3.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
 		}
@@ -262,12 +306,12 @@ public class FlinkKafkaConsumerBaseTest {
 		assertEquals(0, pendingOffsetsToCommit.size());
 
 		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
-		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
+		listState = new TestingListState<>();
 		when(operatorStateStore.getOperatorState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
 
 		// create 500 snapshots
 		for (int i = 100; i < 600; i++) {
-			consumer.prepareSnapshot(i, i);
+			consumer.snapshotState(new StateSnapshotContextSynchronousImpl(i, i));
 			listState.clear();
 		}
 		assertEquals(FlinkKafkaConsumerBase.MAX_NUM_PENDING_CHECKPOINTS, pendingOffsetsToCommit.size());
@@ -308,7 +352,7 @@ public class FlinkKafkaConsumerBaseTest {
 
 	// ------------------------------------------------------------------------
 
-	private static final class DummyFlinkKafkaConsumer<T> extends FlinkKafkaConsumerBase<T> {
+	private static class DummyFlinkKafkaConsumer<T> extends FlinkKafkaConsumerBase<T> {
 		private static final long serialVersionUID = 1L;
 
 		@SuppressWarnings("unchecked")
@@ -318,22 +362,37 @@ public class FlinkKafkaConsumerBaseTest {
 
 		@Override
 		protected AbstractFetcher<T, ?> createFetcher(SourceContext<T> sourceContext, List<KafkaTopicPartition> thisSubtaskPartitions, SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic, SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated, StreamingRuntimeContext runtimeContext) throws Exception {
-			return null;
+			AbstractFetcher<T, ?> fetcher = mock(AbstractFetcher.class);
+			doAnswer(new Answer() {
+				@Override
+				public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
+					Assert.fail("Trying to restore offsets even though there was no restore state.");
+					return null;
+				}
+			}).when(fetcher).restoreOffsets(any(HashMap.class));
+			return fetcher;
 		}
 
 		@Override
 		protected List<KafkaTopicPartition> getKafkaPartitions(List<String> topics) {
 			return Collections.emptyList();
 		}
+
+		@Override
+		public RuntimeContext getRuntimeContext() {
+			return mock(StreamingRuntimeContext.class);
+		}
 	}
 
 	private static final class TestingListState<T> implements ListState<T> {
 
 		private final List<T> list = new ArrayList<>();
+		private boolean clearCalled = false;
 
 		@Override
 		public void clear() {
 			list.clear();
+			clearCalled = true;
 		}
 
 		@Override
@@ -345,5 +404,13 @@ public class FlinkKafkaConsumerBaseTest {
 		public void add(T value) throws Exception {
 			list.add(value);
 		}
+
+		public List<T> getList() {
+			return list;
+		}
+
+		public boolean isClearCalled() {
+			return clearCalled;
+		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
index 4a0fd60..7af5cea 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
@@ -38,7 +38,7 @@ import java.io.Serializable;
  */
 @Deprecated
 @PublicEvolving
-public interface Checkpointed<T extends Serializable> {
+public interface Checkpointed<T extends Serializable> extends CheckpointedRestoring<T> {
 
 	/**
 	 * Gets the current state of the function of operator. The state must reflect the result of all
@@ -56,14 +56,4 @@ public interface Checkpointed<T extends Serializable> {
 	 *                   and to try again with the next checkpoint attempt.
 	 */
 	T snapshotState(long checkpointId, long checkpointTimestamp) throws Exception;
-
-	/**
-	 * Restores the state of the function or operator to that of a previous checkpoint.
-	 * This method is invoked when a function is executed as part of a recovery run.
-	 *
-	 * Note that restoreState() is called before open().
-	 *
-	 * @param state The state to be restored. 
-	 */
-	void restoreState(T state) throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
index 777cb91..37d8244 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
@@ -20,46 +20,48 @@ package org.apache.flink.streaming.api.checkpoint;
 
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 
 /**
  *
  * Similar to @{@link Checkpointed}, this interface must be implemented by functions that have potentially
  * repartitionable state that needs to be checkpointed. Methods from this interface are called upon checkpointing and
- * restoring of state.
+ * initialization of state.
  *
- * On #initializeState the implementing class receives the {@link OperatorStateStore}
- * to store it's state. At least before each snapshot, all state persistent state must be stored in the state store.
+ * On {@link #initializeState(FunctionInitializationContext)} the implementing class receives a
+ * {@link FunctionInitializationContext} which provides access to the {@link OperatorStateStore} (all) and
+ * {@link org.apache.flink.api.common.state.KeyedStateStore} (only for keyed operators). Those allow registering
+ * managed operator / keyed user states. Furthermore, the context provides information on whether or not the
+ * operator was restored.
  *
- * When the backend is received for initialization, the user registers states with the backend via
- * {@link org.apache.flink.api.common.state.StateDescriptor}. Then, all previously stored state is found in the
- * received {@link org.apache.flink.api.common.state.State} (currently only
- * {@link org.apache.flink.api.common.state.ListState} is supported.
  *
- * In #prepareSnapshot, the implementing class must ensure that all operator state is passed to the operator backend,
- * i.e. that the state was stored in the relevant {@link org.apache.flink.api.common.state.State} instances that
- * are requested on restore. Notice that users might want to clear and reinsert the complete state first if incremental
- * updates of the states are not possible.
+ * In {@link #snapshotState(FunctionSnapshotContext)} the implementing class must ensure that all operator / keyed state
+ * is passed to user states that have been registered during initialization, so that it is visible to the system
+ * backends for checkpointing.
+ *
  */
 @PublicEvolving
 public interface CheckpointedFunction {
 
 	/**
+	 * This method is called when a snapshot for a checkpoint is requested. This acts as a hook to the function to
+	 * ensure that all state is exposed by means previously offered through {@link FunctionInitializationContext} when
+	 * the Function was initialized, or offered now by {@link FunctionSnapshotContext} itself.
 	 *
-	 * This method is called when state should be stored for a checkpoint. The state can be registered and written to
-	 * the provided backend.
-	 *
-	 * @param checkpointId Id of the checkpoint to perform
-	 * @param timestamp Timestamp of the checkpoint
+	 * @param context the context for drawing a snapshot of the operator
 	 * @throws Exception
 	 */
-	void prepareSnapshot(long checkpointId, long timestamp) throws Exception;
+	void snapshotState(FunctionSnapshotContext context) throws Exception;
 
 	/**
-	 * This method is called when an operator is opened, so that the function can set the state backend to which it
-	 * hands it's state on snapshot.
+	 * This method is called when an operator is initialized, so that the function can set up its state through
+	 * the provided context. Initialization typically includes registering user states through the state stores
+	 * that the context offers.
 	 *
-	 * @param stateStore the state store to which this function stores it's state
+	 * @param context the context for initializing the operator
 	 * @throws Exception
 	 */
-	void initializeState(OperatorStateStore stateStore) throws Exception;
+	void initializeState(FunctionInitializationContext context) throws Exception;
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java
new file mode 100644
index 0000000..c0dd361
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedRestoring.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.checkpoint;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+import java.io.Serializable;
+
+/**
+ * This deprecated interface contains the methods for restoring from the legacy checkpointing mechanism of state.
+ * @param <T> type of the restored state.
+ */
+@Deprecated
+@PublicEvolving
+public interface CheckpointedRestoring<T extends Serializable> {
+	/**
+	 * Restores the state of the function or operator to that of a previous checkpoint.
+	 * This method is invoked when a function is executed as part of a recovery run.
+	 *
+	 * Note that restoreState() is called before open().
+	 *
+	 * @param state The state to be restored.
+	 */
+	void restoreState(T state) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index 167dfb0..9184e93 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -22,38 +22,47 @@ import org.apache.commons.io.IOUtils;
 import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.state.KeyedStateStore;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.metrics.Counter;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.DefaultKeyedStateStore;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateInitializationContextImpl;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
+import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.Collection;
+import java.util.ConcurrentModificationException;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.ConcurrentModificationException;
-import java.util.Collection;
-import java.util.concurrent.RunnableFuture;
 
 /**
  * Base class for all stream operators. Operators that contain a user function should extend the class 
@@ -97,6 +106,7 @@ public abstract class AbstractStreamOperator<OUT>
 	private transient StreamingRuntimeContext runtimeContext;
 
 
+
 	// ---------------- key/value state ------------------
 
 	/** key selector used to get the key for the state. Non-null only if the operator uses key/value state */
@@ -106,11 +116,12 @@ public abstract class AbstractStreamOperator<OUT>
 	/** Backend for keyed state. This might be empty if we're not on a keyed stream. */
 	private transient AbstractKeyedStateBackend<?> keyedStateBackend;
 
-	/** Operator state backend */
+	/** Keyed state store view on the keyed backend */
+	private transient DefaultKeyedStateStore keyedStateStore;
+	
+	/** Operator state backend / store */
 	private transient OperatorStateBackend operatorStateBackend;
 
-	private transient Collection<OperatorStateHandle> lazyRestoreStateHandles;
-
 
 	// --------------- Metrics ---------------------------
 
@@ -151,8 +162,61 @@ public abstract class AbstractStreamOperator<OUT>
 	}
 
 	@Override
-	public void restoreState(Collection<OperatorStateHandle> stateHandles) {
-		this.lazyRestoreStateHandles = stateHandles;
+	public final void initializeState(OperatorStateHandles stateHandles) throws Exception {
+
+		Collection<KeyGroupsStateHandle> keyedStateHandlesRaw = null;
+		Collection<OperatorStateHandle> operatorStateHandlesRaw = null;
+		Collection<OperatorStateHandle> operatorStateHandlesBackend = null;
+
+		boolean restoring = null != stateHandles;
+
+		initKeyedState(); //TODO we should move the actual initialization of this from StreamTask to this class
+
+		if (restoring) {
+
+			// TODO check that there is EITHER old OR new state in handles!
+			restoreStreamCheckpointed(stateHandles);
+
+			//pass directly
+			operatorStateHandlesBackend = stateHandles.getManagedOperatorState();
+			operatorStateHandlesRaw = stateHandles.getRawOperatorState();
+
+			if (null != getKeyedStateBackend()) {
+				//only use the keyed state if it is meant for us (aka head operator)
+				keyedStateHandlesRaw = stateHandles.getRawKeyedState();
+			}
+		}
+
+		initOperatorState(operatorStateHandlesBackend);
+
+		StateInitializationContext initializationContext = new StateInitializationContextImpl(
+				restoring, // information whether we restore or start for the first time
+				operatorStateBackend, // access to operator state backend
+				keyedStateStore, // access to keyed state backend
+				keyedStateHandlesRaw, // access to keyed state stream
+				operatorStateHandlesRaw, // access to operator state stream
+				getContainingTask().getCancelables()); // access to register streams for canceling
+
+		initializeState(initializationContext);
+	}
+
+	@Deprecated
+	private void restoreStreamCheckpointed(OperatorStateHandles stateHandles) throws Exception {
+		StreamStateHandle state = stateHandles.getLegacyOperatorState();
+		if (this instanceof StreamCheckpointedOperator && null != state) {
+
+			LOG.debug("Restore state of task {} in chain ({}).",
+					stateHandles.getOperatorChainIndex(), getContainingTask().getName());
+
+			FSDataInputStream is = state.openInputStream();
+			try {
+				getContainingTask().getCancelables().registerClosable(is);
+				((StreamCheckpointedOperator) this).restoreState(is);
+			} finally {
+				getContainingTask().getCancelables().unregisterClosable(is);
+				is.close();
+			}
+		}
 	}
 
 	/**
@@ -165,8 +229,7 @@ public abstract class AbstractStreamOperator<OUT>
 	 */
 	@Override
 	public void open() throws Exception {
-		initOperatorState();
-		initKeyedState();
+
 	}
 
 	private void initKeyedState() {
@@ -174,7 +237,6 @@ public abstract class AbstractStreamOperator<OUT>
 			TypeSerializer<Object> keySerializer = config.getStateKeySerializer(getUserCodeClassloader());
 			// create a keyed state backend if there is keyed state, as indicated by the presence of a key serializer
 			if (null != keySerializer) {
-
 				KeyGroupRange subTaskKeyGroupRange = KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
 						container.getEnvironment().getTaskInfo().getNumberOfKeyGroups(),
 						container.getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(),
@@ -184,7 +246,8 @@ public abstract class AbstractStreamOperator<OUT>
 						keySerializer,
 						container.getConfiguration().getNumberOfKeyGroups(getUserCodeClassloader()),
 						subTaskKeyGroupRange);
-
+				
+				this.keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, getExecutionConfig());
 			}
 
 		} catch (Exception e) {
@@ -192,10 +255,10 @@ public abstract class AbstractStreamOperator<OUT>
 		}
 	}
 
-	private void initOperatorState() {
+	private void initOperatorState(Collection<OperatorStateHandle> operatorStateHandles) {
 		try {
 			// create an operator state backend
-			this.operatorStateBackend = container.createOperatorStateBackend(this, lazyRestoreStateHandles);
+			this.operatorStateBackend = container.createOperatorStateBackend(this, operatorStateHandles);
 		} catch (Exception e) {
 			throw new IllegalStateException("Could not initialize operator state backend.", e);
 		}
@@ -238,11 +301,51 @@ public abstract class AbstractStreamOperator<OUT>
 	}
 
 	@Override
-	public RunnableFuture<OperatorStateHandle> snapshotState(
+	public final OperatorSnapshotResult snapshotState(
 			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
 
-		return operatorStateBackend != null ?
-				operatorStateBackend.snapshot(checkpointId, timestamp, streamFactory) : null;
+		KeyGroupRange keyGroupRange = null != keyedStateBackend ?
+				keyedStateBackend.getKeyGroupRange() : KeyGroupRange.EMPTY_KEY_GROUP_RANGE;
+
+		StateSnapshotContextSynchronousImpl snapshotContext = new StateSnapshotContextSynchronousImpl(
+				checkpointId, timestamp, streamFactory, keyGroupRange, getContainingTask().getCancelables());
+
+		snapshotState(snapshotContext);
+
+		OperatorSnapshotResult snapshotInProgress = new OperatorSnapshotResult();
+
+		snapshotInProgress.setKeyedStateRawFuture(snapshotContext.getKeyedStateStreamFuture());
+		snapshotInProgress.setOperatorStateRawFuture(snapshotContext.getOperatorStateStreamFuture());
+
+		if (null != operatorStateBackend) {
+			snapshotInProgress.setOperatorStateManagedFuture(
+					operatorStateBackend.snapshot(checkpointId, timestamp, streamFactory));
+		}
+
+		if (null != keyedStateBackend) {
+			snapshotInProgress.setKeyedStateManagedFuture(
+					keyedStateBackend.snapshot(checkpointId, timestamp, streamFactory));
+		}
+
+		return snapshotInProgress;
+	}
+
+	/**
+	 * Stream operators with state that want to participate in a snapshot need to override this hook method.
+	 *
+	 * @param context context that provides information and means required for taking a snapshot
+	 */
+	public void snapshotState(StateSnapshotContext context) throws Exception {
+
+	}
+
+	/**
+	 * Stream operators with state which can be restored need to override this hook method.
+	 *
+	 * @param context context that allows registering different states.
+	 */
+	public void initializeState(StateInitializationContext context) throws Exception {
+
 	}
 
 	@Override
@@ -283,22 +386,12 @@ public abstract class AbstractStreamOperator<OUT>
 		return runtimeContext;
 	}
 
-	@SuppressWarnings("rawtypes, unchecked")
+	@SuppressWarnings("unchecked")
 	public <K> KeyedStateBackend<K> getKeyedStateBackend() {
-
-		if (null == keyedStateBackend) {
-			initKeyedState();
-		}
-
 		return (KeyedStateBackend<K>) keyedStateBackend;
 	}
 
 	public OperatorStateBackend getOperatorStateBackend() {
-
-		if (null == operatorStateBackend) {
-			initOperatorState();
-		}
-
 		return operatorStateBackend;
 	}
 
@@ -327,12 +420,12 @@ public abstract class AbstractStreamOperator<OUT>
 	 * @throws Exception Thrown, if the state backend cannot create the key/value state.
 	 */
 	@SuppressWarnings("unchecked")
-	protected <S extends State, N> S getPartitionedState(N namespace, TypeSerializer<N> namespaceSerializer, StateDescriptor<S, ?> stateDescriptor) throws Exception {
-		if (keyedStateBackend != null) {
-			return keyedStateBackend.getPartitionedState(
-					namespace,
-					namespaceSerializer,
-					stateDescriptor);
+	protected <S extends State, N> S getPartitionedState(
+			N namespace, TypeSerializer<N> namespaceSerializer, 
+			StateDescriptor<S, ?> stateDescriptor) throws Exception {
+		
+		if (keyedStateStore != null) {
+			return keyedStateBackend.getPartitionedState(namespace, namespaceSerializer, stateDescriptor);
 		} else {
 			throw new RuntimeException("Cannot create partitioned state. The keyed state " +
 				"backend has not been set. This indicates that the operator is not " +
@@ -343,18 +436,18 @@ public abstract class AbstractStreamOperator<OUT>
 	@Override
 	@SuppressWarnings({"unchecked", "rawtypes"})
 	public void setKeyContextElement1(StreamRecord record) throws Exception {
-		setRawKeyContextElement(record, stateKeySelector1);
+		setKeyContextElement(record, stateKeySelector1);
 	}
 
 	@Override
 	@SuppressWarnings({"unchecked", "rawtypes"})
 	public void setKeyContextElement2(StreamRecord record) throws Exception {
-		setRawKeyContextElement(record, stateKeySelector2);
+		setKeyContextElement(record, stateKeySelector2);
 	}
 
-	private void setRawKeyContextElement(StreamRecord record, KeySelector<?, ?> selector) throws Exception {
+	private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector) throws Exception {
 		if (selector != null) {
-			Object key = ((KeySelector) selector).getKey(record.getValue());
+			Object key = selector.getKey(record.getValue());
 			setKeyContext(key);
 		}
 	}
@@ -374,6 +467,10 @@ public abstract class AbstractStreamOperator<OUT>
 		}
 	}
 
+	public KeyedStateStore getKeyedStateStore() {
+		return keyedStateStore;
+	}
+
 	// ------------------------------------------------------------------------
 	//  Context and chaining properties
 	// ------------------------------------------------------------------------
@@ -567,4 +664,5 @@ public abstract class AbstractStreamOperator<OUT>
 			output.close();
 		}
 	}
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
index 72f30b8..5e1a252 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
@@ -28,11 +28,12 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.runtime.state.CheckpointListener;
-import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
 import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -42,7 +43,6 @@ import org.apache.flink.util.InstantiationUtil;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.concurrent.RunnableFuture;
 
 import static java.util.Objects.requireNonNull;
 
@@ -73,6 +73,7 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	
 	public AbstractUdfStreamOperator(F userFunction) {
 		this.userFunction = requireNonNull(userFunction);
+		checkUdfCheckpointingPreconditions();
 	}
 
 	/**
@@ -93,22 +94,44 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 		super.setup(containingTask, config, output);
 		
 		FunctionUtils.setFunctionRuntimeContext(userFunction, getRuntimeContext());
+
 	}
 
 	@Override
-	public void open() throws Exception {
-		super.open();
-		
-		FunctionUtils.openFunction(userFunction, new Configuration());
+	public void snapshotState(StateSnapshotContext context) throws Exception {
+		super.snapshotState(context);
 
 		if (userFunction instanceof CheckpointedFunction) {
-			((CheckpointedFunction) userFunction).initializeState(getOperatorStateBackend());
+			((CheckpointedFunction) userFunction).snapshotState(context);
 		} else if (userFunction instanceof ListCheckpointed) {
 			@SuppressWarnings("unchecked")
-			ListCheckpointed<Serializable> listCheckpointedFun = (ListCheckpointed<Serializable>) userFunction;
+			List<Serializable> partitionableState = ((ListCheckpointed<Serializable>) userFunction).
+							snapshotState(context.getCheckpointId(), context.getCheckpointTimestamp());
 
 			ListState<Serializable> listState = getOperatorStateBackend().
-					getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME);
+					getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
+
+			listState.clear();
+
+			for (Serializable statePartition : partitionableState) {
+				listState.add(statePartition);
+			}
+		}
+
+	}
+
+	@Override
+	public void initializeState(StateInitializationContext context) throws Exception {
+		super.initializeState(context);
+
+		if (userFunction instanceof CheckpointedFunction) {
+			((CheckpointedFunction) userFunction).initializeState(context);
+		} else if (context.isRestored() && userFunction instanceof ListCheckpointed) {
+			@SuppressWarnings("unchecked")
+			ListCheckpointed<Serializable> listCheckpointedFun = (ListCheckpointed<Serializable>) userFunction;
+
+			ListState<Serializable> listState = context.getManagedOperatorStateStore().
+					getSerializableListState(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
 
 			List<Serializable> list = new ArrayList<>();
 
@@ -122,6 +145,13 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 				throw new Exception("Failed to restore state to function: " + e.getMessage(), e);
 			}
 		}
+
+	}
+
+	@Override
+	public void open() throws Exception {
+		super.open();
+		FunctionUtils.openFunction(userFunction, new Configuration());
 	}
 
 	@Override
@@ -147,6 +177,7 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	@Override
 	public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
 
+
 		if (userFunction instanceof Checkpointed) {
 			@SuppressWarnings("unchecked")
 			Checkpointed<Serializable> chkFunction = (Checkpointed<Serializable>) userFunction;
@@ -169,9 +200,9 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	@Override
 	public void restoreState(FSDataInputStream in) throws Exception {
 
-		if (userFunction instanceof Checkpointed) {
+		if (userFunction instanceof CheckpointedRestoring) {
 			@SuppressWarnings("unchecked")
-			Checkpointed<Serializable> chkFunction = (Checkpointed<Serializable>) userFunction;
+			CheckpointedRestoring<Serializable> chkFunction = (CheckpointedRestoring<Serializable>) userFunction;
 
 			int hasUdfState = in.read();
 
@@ -189,32 +220,6 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	}
 
 	@Override
-	public RunnableFuture<OperatorStateHandle> snapshotState(
-			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
-
-		if (userFunction instanceof CheckpointedFunction) {
-			((CheckpointedFunction) userFunction).prepareSnapshot(checkpointId, timestamp);
-		}
-
-		if (userFunction instanceof ListCheckpointed) {
-			@SuppressWarnings("unchecked")
-			List<Serializable> partitionableState =
-					((ListCheckpointed<Serializable>) userFunction).snapshotState(checkpointId, timestamp);
-
-			ListState<Serializable> listState = getOperatorStateBackend().
-					getSerializableListState(OperatorStateStore.DEFAULT_OPERATOR_STATE_NAME);
-
-			listState.clear();
-
-			for (Serializable statePartition : partitionableState) {
-				listState.add(statePartition);
-			}
-		}
-
-		return super.snapshotState(checkpointId, timestamp, streamFactory);
-	}
-
-	@Override
 	public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {
 		super.notifyOfCompletedCheckpoint(checkpointId);
 
@@ -251,4 +256,26 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	public Configuration getUserFunctionParameters() {
 		return new Configuration();
 	}
+
+	private void checkUdfCheckpointingPreconditions() {
+
+		boolean newCheckpointInferface = false;
+
+		if (userFunction instanceof CheckpointedFunction) {
+			newCheckpointInferface = true;
+		}
+
+		if (userFunction instanceof ListCheckpointed) {
+			if (newCheckpointInferface) {
+				throw new IllegalStateException("User functions are not allowed to implement " +
+						"CheckpointedFunction AND ListCheckpointed.");
+			}
+			newCheckpointInferface = true;
+		}
+
+		if (newCheckpointInferface && userFunction instanceof Checkpointed) {
+			throw new IllegalStateException("User functions are not allowed to implement Checkpointed AND " +
+					"CheckpointedFunction/ListCheckpointed.");
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java
new file mode 100644
index 0000000..52c89f8
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OperatorSnapshotResult.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+
+import java.util.concurrent.RunnableFuture;
+
+/**
+ * Result of {@link AbstractStreamOperator#snapshotState}.
+ */
+public class OperatorSnapshotResult {
+
+	private RunnableFuture<KeyGroupsStateHandle> keyedStateManagedFuture;
+	private RunnableFuture<KeyGroupsStateHandle> keyedStateRawFuture;
+	private RunnableFuture<OperatorStateHandle> operatorStateManagedFuture;
+	private RunnableFuture<OperatorStateHandle> operatorStateRawFuture;
+
+	public OperatorSnapshotResult() {
+	}
+
+	public OperatorSnapshotResult(
+			RunnableFuture<KeyGroupsStateHandle> keyedStateManagedFuture,
+			RunnableFuture<KeyGroupsStateHandle> keyedStateRawFuture,
+			RunnableFuture<OperatorStateHandle> operatorStateManagedFuture,
+			RunnableFuture<OperatorStateHandle> operatorStateRawFuture) {
+		this.keyedStateManagedFuture = keyedStateManagedFuture;
+		this.keyedStateRawFuture = keyedStateRawFuture;
+		this.operatorStateManagedFuture = operatorStateManagedFuture;
+		this.operatorStateRawFuture = operatorStateRawFuture;
+	}
+
+	public RunnableFuture<KeyGroupsStateHandle> getKeyedStateManagedFuture() {
+		return keyedStateManagedFuture;
+	}
+
+	public void setKeyedStateManagedFuture(RunnableFuture<KeyGroupsStateHandle> keyedStateManagedFuture) {
+		this.keyedStateManagedFuture = keyedStateManagedFuture;
+	}
+
+	public RunnableFuture<KeyGroupsStateHandle> getKeyedStateRawFuture() {
+		return keyedStateRawFuture;
+	}
+
+	public void setKeyedStateRawFuture(RunnableFuture<KeyGroupsStateHandle> keyedStateRawFuture) {
+		this.keyedStateRawFuture = keyedStateRawFuture;
+	}
+
+	public RunnableFuture<OperatorStateHandle> getOperatorStateManagedFuture() {
+		return operatorStateManagedFuture;
+	}
+
+	public void setOperatorStateManagedFuture(RunnableFuture<OperatorStateHandle> operatorStateManagedFuture) {
+		this.operatorStateManagedFuture = operatorStateManagedFuture;
+	}
+
+	public RunnableFuture<OperatorStateHandle> getOperatorStateRawFuture() {
+		return operatorStateRawFuture;
+	}
+
+	public void setOperatorStateRawFuture(RunnableFuture<OperatorStateHandle> operatorStateRawFuture) {
+		this.operatorStateRawFuture = operatorStateRawFuture;
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
index fae5fd0..f6e5472 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
@@ -20,14 +20,12 @@ package org.apache.flink.streaming.api.operators;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 
 import java.io.Serializable;
-import java.util.Collection;
-import java.util.concurrent.RunnableFuture;
 
 /**
  * Basic interface for stream operators. Implementers would implement one of
@@ -105,7 +103,7 @@ public interface StreamOperator<OUT> extends Serializable {
 	 * the runnable might already be finished.
 	 * @throws Exception exception that happened during snapshotting.
 	 */
-	RunnableFuture<OperatorStateHandle> snapshotState(
+	OperatorSnapshotResult snapshotState(
 			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception;
 
 	/**
@@ -113,7 +111,7 @@ public interface StreamOperator<OUT> extends Serializable {
 	 *
 	 * @param stateHandles state handles to the operator state.
 	 */
-	void restoreState(Collection<OperatorStateHandle> stateHandles);
+	void initializeState(OperatorStateHandles stateHandles) throws Exception;
 
 	/**
 	 * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager.

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
index cc2e54b..cd0489f 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
@@ -37,8 +37,6 @@ import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
 import java.util.List;
 import java.util.Map;
 
-import static java.util.Objects.requireNonNull;
-
 /**
  * Implementation of the {@link org.apache.flink.api.common.functions.RuntimeContext},
  * for streaming operators.
@@ -108,36 +106,17 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext {
 
 	@Override
 	public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) {
-		requireNonNull(stateProperties, "The state properties must not be null");
-		try {
-			stateProperties.initializeSerializerUnlessSet(getExecutionConfig());
-			return operator.getPartitionedState(stateProperties);
-		} catch (Exception e) {
-			throw new RuntimeException("Error while getting state", e);
-		}
+		return operator.getKeyedStateStore().getState(stateProperties);
 	}
 
 	@Override
 	public <T> ListState<T> getListState(ListStateDescriptor<T> stateProperties) {
-		requireNonNull(stateProperties, "The state properties must not be null");
-		try {
-			stateProperties.initializeSerializerUnlessSet(getExecutionConfig());
-			ListState<T> originalState = operator.getPartitionedState(stateProperties);
-			return new UserFacingListState<T>(originalState);
-		} catch (Exception e) {
-			throw new RuntimeException("Error while getting state", e);
-		}
+		return operator.getKeyedStateStore().getListState(stateProperties);
 	}
 
 	@Override
 	public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties) {
-		requireNonNull(stateProperties, "The state properties must not be null");
-		try {
-			stateProperties.initializeSerializerUnlessSet(getExecutionConfig());
-			return operator.getPartitionedState(stateProperties);
-		} catch (Exception e) {
-			throw new RuntimeException("Error while getting state", e);
-		}
+		return operator.getKeyedStateStore().getReducingState(stateProperties);
 	}
 
 	// ------------------ expose (read only) relevant information from the stream config -------- //

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/UserFacingListState.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/UserFacingListState.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/UserFacingListState.java
deleted file mode 100644
index a02a204..0000000
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/UserFacingListState.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.operators;
-
-import org.apache.flink.api.common.state.ListState;
-
-import java.util.Collections;
-
-/**
- * Simple wrapper list state that exposes empty state properly as an empty list.
- * 
- * @param <T> The type of elements in the list state.
- */
-class UserFacingListState<T> implements ListState<T> {
-
-	private final ListState<T> originalState;
-
-	private final Iterable<T> emptyState = Collections.emptyList();
-
-	UserFacingListState(ListState<T> originalState) {
-		this.originalState = originalState;
-	}
-
-	// ------------------------------------------------------------------------
-
-	@Override
-	public Iterable<T> get() throws Exception {
-		Iterable<T> original = originalState.get();
-		return original != null ? original : emptyState;
-	}
-
-	@Override
-	public void add(T value) throws Exception {
-		originalState.add(value);
-	}
-
-	@Override
-	public void clear() {
-		originalState.clear();
-	}
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
new file mode 100644
index 0000000..7abf8d9
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorStateHandles.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.tasks;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
+import org.apache.flink.util.CollectionUtil;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * This class holds all state handles for one operator.
+ */
+@Internal
+@VisibleForTesting
+public class OperatorStateHandles {
+
+	private final int operatorChainIndex;
+
+	private final StreamStateHandle legacyOperatorState;
+
+	private final Collection<KeyGroupsStateHandle> managedKeyedState;
+	private final Collection<KeyGroupsStateHandle> rawKeyedState;
+	private final Collection<OperatorStateHandle> managedOperatorState;
+	private final Collection<OperatorStateHandle> rawOperatorState;
+
+	public OperatorStateHandles(
+			int operatorChainIndex,
+			StreamStateHandle legacyOperatorState,
+			Collection<KeyGroupsStateHandle> managedKeyedState,
+			Collection<KeyGroupsStateHandle> rawKeyedState,
+			Collection<OperatorStateHandle> managedOperatorState,
+			Collection<OperatorStateHandle> rawOperatorState) {
+
+		this.operatorChainIndex = operatorChainIndex;
+		this.legacyOperatorState = legacyOperatorState;
+		this.managedKeyedState = managedKeyedState;
+		this.rawKeyedState = rawKeyedState;
+		this.managedOperatorState = managedOperatorState;
+		this.rawOperatorState = rawOperatorState;
+	}
+
+	public OperatorStateHandles(TaskStateHandles taskStateHandles, int operatorChainIndex) {
+		Preconditions.checkNotNull(taskStateHandles);
+
+		this.operatorChainIndex = operatorChainIndex;
+
+		ChainedStateHandle<StreamStateHandle> legacyState = taskStateHandles.getLegacyOperatorState();
+		this.legacyOperatorState = ChainedStateHandle.isNullOrEmpty(legacyState) ?
+				null : legacyState.get(operatorChainIndex);
+
+		this.rawKeyedState = taskStateHandles.getRawKeyedState();
+		this.managedKeyedState = taskStateHandles.getManagedKeyedState();
+
+		this.managedOperatorState = getSafeItemAtIndexOrNull(taskStateHandles.getManagedOperatorState(), operatorChainIndex);
+		this.rawOperatorState = getSafeItemAtIndexOrNull(taskStateHandles.getRawOperatorState(), operatorChainIndex);
+	}
+
+	public StreamStateHandle getLegacyOperatorState() {
+		return legacyOperatorState;
+	}
+
+	public Collection<KeyGroupsStateHandle> getManagedKeyedState() {
+		return managedKeyedState;
+	}
+
+	public Collection<KeyGroupsStateHandle> getRawKeyedState() {
+		return rawKeyedState;
+	}
+
+	public Collection<OperatorStateHandle> getManagedOperatorState() {
+		return managedOperatorState;
+	}
+
+	public Collection<OperatorStateHandle> getRawOperatorState() {
+		return rawOperatorState;
+	}
+
+	public int getOperatorChainIndex() {
+		return operatorChainIndex;
+	}
+
+	private static <T> T getSafeItemAtIndexOrNull(List<T> list, int idx) {
+		return CollectionUtil.isNullOrEmpty(list) ? null : list.get(idx);
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 2e6ebf3..eb5fde7 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -23,9 +23,9 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.IllegalConfigurationException;
-import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
@@ -33,7 +33,6 @@ import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.ClosableRegistry;
 import org.apache.flink.runtime.state.KeyGroupRange;
@@ -42,27 +41,29 @@ import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackendFactory;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotResult;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.CollectionUtil;
+import org.apache.flink.util.FutureUtil;
 import org.apache.flink.util.Preconditions;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.util.Arrays;
+import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
@@ -87,13 +88,14 @@ import java.util.concurrent.ThreadFactory;
  *
  * The life cycle of the task is set up as follows:
  * <pre>{@code
- *  -- getOperatorState() -> restores state of all operators in the chain
+ *  -- setInitialState -> provides state of all operators in the chain
  *
  *  -- invoke()
  *        |
  *        +----> Create basic utils (config, etc) and load the chain of operators
  *        +----> operators.setup()
  *        +----> task specific init()
+ *        +----> initialize-operator-states()
  *        +----> open-operators()
  *        +----> run()
  *        +----> close-operators()
@@ -153,12 +155,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 	/** The map of user-defined accumulators of this task */
 	private Map<String, Accumulator<?, ?>> accumulatorMap;
 
-	/** The chained operator state to be restored once the initialization is done */
-	private ChainedStateHandle<StreamStateHandle> lazyRestoreChainedOperatorState;
-
-	private List<KeyGroupsStateHandle> lazyRestoreKeyGroupStates;
-
-	private List<Collection<OperatorStateHandle>> lazyRestoreOperatorState;
+	private TaskStateHandles restoreStateHandles;
 
 
 	/** The currently active background materialization threads */
@@ -251,9 +248,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 			// -------- Invoke --------
 			LOG.debug("Invoking {}", getName());
 
-			// first order of business is to give operators back their state
-			restoreState();
-			lazyRestoreChainedOperatorState = null; // GC friendliness
+			// first order of business is to give operators their state
+			initializeState();
 
 			// we need to make sure that any triggers scheduled in open() cannot be
 			// executed before all operators are opened
@@ -510,60 +506,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 	// ------------------------------------------------------------------------
 
 	@Override
-	public void setInitialState(
-		ChainedStateHandle<StreamStateHandle> chainedState,
-		List<KeyGroupsStateHandle> keyGroupsState,
-		List<Collection<OperatorStateHandle>> partitionableOperatorState) {
-
-		lazyRestoreChainedOperatorState = chainedState;
-		lazyRestoreKeyGroupStates = keyGroupsState;
-		lazyRestoreOperatorState = partitionableOperatorState;
-	}
-
-	private void restoreState() throws Exception {
-		final StreamOperator<?>[] allOperators = operatorChain.getAllOperators();
-
-		if (lazyRestoreChainedOperatorState != null) {
-			Preconditions.checkState(lazyRestoreChainedOperatorState.getLength() == allOperators.length,
-					"Invalid Invalid number of operator states. Found :" + lazyRestoreChainedOperatorState.getLength() +
-							". Expected: " + allOperators.length);
-		}
-
-		if (lazyRestoreOperatorState != null) {
-			Preconditions.checkArgument(lazyRestoreOperatorState.isEmpty()
-							|| lazyRestoreOperatorState.size() == allOperators.length,
-					"Invalid number of operator states. Found :" + lazyRestoreOperatorState.size() +
-							". Expected: " + allOperators.length);
-		}
-
-		for (int i = 0; i < allOperators.length; i++) {
-			StreamOperator<?> operator = allOperators[i];
-
-			if (null != lazyRestoreOperatorState && !lazyRestoreOperatorState.isEmpty()) {
-				operator.restoreState(lazyRestoreOperatorState.get(i));
-			}
-
-			// TODO deprecated code path
-			if (operator instanceof StreamCheckpointedOperator) {
-
-				if (lazyRestoreChainedOperatorState != null) {
-					StreamStateHandle state = lazyRestoreChainedOperatorState.get(i);
-
-					if (state != null) {
-						LOG.debug("Restore state of task {} in chain ({}).", i, getName());
-
-						FSDataInputStream is = state.openInputStream();
-						try {
-							cancelables.registerClosable(is);
-							((StreamCheckpointedOperator) operator).restoreState(is);
-						} finally {
-							cancelables.unregisterClosable(is);
-							is.close();
-						}
-					}
-				}
-			}
-		}
+	public void setInitialState(TaskStateHandles taskStateHandles) {
+		this.restoreStateHandles = taskStateHandles;
 	}
 
 	@Override
@@ -600,117 +544,19 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 	private boolean performCheckpoint(CheckpointMetaData checkpointMetaData) throws Exception {
 
-		long checkpointId = checkpointMetaData.getCheckpointId();
-		long timestamp = checkpointMetaData.getTimestamp();
-
-		LOG.debug("Starting checkpoint {} on task {}", checkpointId, getName());
+		LOG.debug("Starting checkpoint {} on task {}", checkpointMetaData.getCheckpointId(), getName());
 
 		synchronized (lock) {
 			if (isRunning) {
 
-				final long startOfSyncPart = System.nanoTime();
-
 				// Since both state checkpointing and downstream barrier emission occurs in this
 				// lock scope, they are an atomic operation regardless of the order in which they occur.
 				// Given this, we immediately emit the checkpoint barriers, so the downstream operators
 				// can start their checkpoint work as soon as possible
-				operatorChain.broadcastCheckpointBarrier(checkpointId, timestamp);
-
-				// now draw the state snapshot
-				final StreamOperator<?>[] allOperators = operatorChain.getAllOperators();
-
-				final List<StreamStateHandle> nonPartitionedStates =
-						Arrays.asList(new StreamStateHandle[allOperators.length]);
-
-				final List<OperatorStateHandle> operatorStates =
-						Arrays.asList(new OperatorStateHandle[allOperators.length]);
-
-				for (int i = 0; i < allOperators.length; i++) {
-					StreamOperator<?> operator = allOperators[i];
-
-					if (operator != null) {
-
-						final String operatorId = createOperatorIdentifier(operator, configuration.getVertexID());
-
-						CheckpointStreamFactory streamFactory =
-								stateBackend.createStreamFactory(getEnvironment().getJobID(), operatorId);
-
-						//TODO deprecated code path
-						if (operator instanceof StreamCheckpointedOperator) {
-
-							CheckpointStreamFactory.CheckpointStateOutputStream outStream =
-									streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);
-
-
-							cancelables.registerClosable(outStream);
-
-							try {
-								((StreamCheckpointedOperator) operator).
-										snapshotState(outStream, checkpointId, timestamp);
-
-								nonPartitionedStates.set(i, outStream.closeAndGetHandle());
-							} finally {
-								cancelables.unregisterClosable(outStream);
-							}
-						}
-
-						RunnableFuture<OperatorStateHandle> handleFuture =
-								operator.snapshotState(checkpointId, timestamp, streamFactory);
-
-						if (null != handleFuture) {
-							//TODO for now we assume there are only synchrous snapshots, no need to start the runnable.
-							if (!handleFuture.isDone()) {
-								throw new IllegalStateException("Currently only supports synchronous snapshots!");
-							}
-
-							operatorStates.set(i, handleFuture.get());
-						}
-					}
-
-				}
-
-				RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture = null;
-
-				if (keyedStateBackend != null) {
-					CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory(
-							getEnvironment().getJobID(),
-							createOperatorIdentifier(headOperator, configuration.getVertexID()));
-
-					keyGroupsStateHandleFuture = keyedStateBackend.snapshot(checkpointId, timestamp, streamFactory);
-				}
-
-				ChainedStateHandle<StreamStateHandle> chainedNonPartitionedStateHandles =
-						new ChainedStateHandle<>(nonPartitionedStates);
-
-				ChainedStateHandle<OperatorStateHandle> chainedPartitionedStateHandles =
-						new ChainedStateHandle<>(operatorStates);
-
-				LOG.debug("Finished synchronous checkpoints for checkpoint {} on task {}", checkpointId, getName());
-
-				final long syncEndNanos = System.nanoTime();
-				final long syncDurationMillis = (syncEndNanos - startOfSyncPart) / 1_000_000;
-
-				checkpointMetaData.setSyncDurationMillis(syncDurationMillis);
-
-				AsyncCheckpointRunnable asyncCheckpointRunnable = new AsyncCheckpointRunnable(
-						"checkpoint-" + checkpointId + "-" + timestamp,
-						this,
-						cancelables,
-						chainedNonPartitionedStateHandles,
-						chainedPartitionedStateHandles,
-						keyGroupsStateHandleFuture,
-						checkpointMetaData,
-						syncEndNanos);
-
-				cancelables.registerClosable(asyncCheckpointRunnable);
-				asyncOperationsThreadPool.submit(asyncCheckpointRunnable);
-
-				if (LOG.isDebugEnabled()) {
-					LOG.debug("{} - finished synchronous part of checkpoint {}." +
-							"Alignment duration: {} ms, snapshot duration {} ms",
-							getName(), checkpointId, checkpointMetaData.getAlignmentDurationNanos() / 1_000_000, syncDurationMillis);
-				}
+				operatorChain.broadcastCheckpointBarrier(
+						checkpointMetaData.getCheckpointId(), checkpointMetaData.getTimestamp());
 
+				checkpointState(checkpointMetaData);
 				return true;
 			} else {
 				return false;
@@ -740,6 +586,59 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		}
 	}
 
+	private void checkpointState(CheckpointMetaData checkpointMetaData) throws Exception {
+		CheckpointingOperation checkpointingOperation = new CheckpointingOperation(this, checkpointMetaData);
+		checkpointingOperation.executeCheckpointing();
+	}
+
+	private void initializeState() throws Exception {
+
+		boolean restored = null != restoreStateHandles;
+
+		if (restored) {
+
+			checkRestorePreconditions(operatorChain.getChainLength());
+			initializeOperators(true);
+			restoreStateHandles = null; // free for GC
+		} else {
+			initializeOperators(false);
+		}
+	}
+
+	private void initializeOperators(boolean restored) throws Exception {
+		StreamOperator<?>[] allOperators = operatorChain.getAllOperators();
+		for (int chainIdx = 0; chainIdx < allOperators.length; ++chainIdx) {
+			StreamOperator<?> operator = allOperators[chainIdx];
+			if (null != operator) {
+				if (restored) {
+					operator.initializeState(new OperatorStateHandles(restoreStateHandles, chainIdx));
+				} else {
+					operator.initializeState(null);
+				}
+			}
+		}
+	}
+
+	/**
+	 * Verifies that the legacy and the managed operator state found in {@code restoreStateHandles}
+	 * (when present) contain exactly {@code operatorChainLength} entries; dereferences the field unchecked.
+	 */
+	private void checkRestorePreconditions(int operatorChainLength) {
+		ChainedStateHandle<StreamStateHandle> nonPartitionableOperatorStates =
+				restoreStateHandles.getLegacyOperatorState();
+		List<Collection<OperatorStateHandle>> operatorStates = restoreStateHandles.getManagedOperatorState();
+		if (nonPartitionableOperatorStates != null) {
+			Preconditions.checkState(nonPartitionableOperatorStates.getLength() == operatorChainLength,
+					"Invalid number of operator states. Found :" + nonPartitionableOperatorStates.getLength()
+							+ ". Expected: " + operatorChainLength);
+		}
+		if (!CollectionUtil.isNullOrEmpty(operatorStates)) {
+			Preconditions.checkArgument(operatorStates.size() == operatorChainLength,
+					"Invalid number of operator states. Found :" + operatorStates.size() +
+							". Expected: " + operatorChainLength);
+		}
+	}
+
 	// ------------------------------------------------------------------------
 	//  State backend
 	// ------------------------------------------------------------------------
@@ -777,7 +676,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 					try {
 						@SuppressWarnings("rawtypes")
 						Class<? extends StateBackendFactory> clazz =
-								Class.forName(backendName, false, getUserCodeClassLoader()).asSubclass(StateBackendFactory.class);
+								Class.forName(backendName, false, getUserCodeClassLoader()).
+										asSubclass(StateBackendFactory.class);
 
 						stateBackend = clazz.newInstance().createFromConfig(flinkConfig);
 					} catch (ClassNotFoundException e) {
@@ -799,7 +699,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 			StreamOperator<?> op, Collection<OperatorStateHandle> restoreStateHandles) throws Exception {
 
 		Environment env = getEnvironment();
-		String opId = createOperatorIdentifier(op, configuration.getVertexID());
+		String opId = createOperatorIdentifier(op, getConfiguration().getVertexID());
 
 		OperatorStateBackend newBackend = restoreStateHandles == null ?
 				stateBackend.createOperatorStateBackend(env, opId)
@@ -823,7 +723,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 				headOperator,
 				configuration.getVertexID());
 
-		if (lazyRestoreKeyGroupStates != null) {
+		if (null != restoreStateHandles && null != restoreStateHandles.getManagedKeyedState()) {
 			keyedStateBackend = stateBackend.restoreKeyedStateBackend(
 					getEnvironment(),
 					getEnvironment().getJobID(),
@@ -831,10 +731,10 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 					keySerializer,
 					numberOfKeyGroups,
 					keyGroupRange,
-					lazyRestoreKeyGroupStates,
+					restoreStateHandles.getManagedKeyedState(),
 					getEnvironment().getTaskKvStateRegistry());
 
-			lazyRestoreKeyGroupStates = null; // GC friendliness
+			restoreStateHandles = null; // GC friendliness
 		} else {
 			keyedStateBackend = stateBackend.createKeyedStateBackend(
 					getEnvironment(),
@@ -913,62 +813,60 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 	// ------------------------------------------------------------------------
 
-	private static class AsyncCheckpointRunnable implements Runnable, Closeable {
+	private static final class AsyncCheckpointRunnable implements Runnable, Closeable {
 
 		private final StreamTask<?, ?> owner;
 
-		private final ClosableRegistry cancelables;
-
-		private final ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles;
-
-		private final ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles;
+		private final List<OperatorSnapshotResult> snapshotInProgressList;
 
-		private final RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture;
+		RunnableFuture<KeyGroupsStateHandle> futureKeyedBackendStateHandles;
+		RunnableFuture<KeyGroupsStateHandle> futureKeyedStreamStateHandles;
 
-		private final String name;
+		List<StreamStateHandle> nonPartitionedStateHandles;
 
 		private final CheckpointMetaData checkpointMetaData;
 
 		private final long asyncStartNanos;
 
 		AsyncCheckpointRunnable(
-				String name,
 				StreamTask<?, ?> owner,
-				ClosableRegistry cancelables,
-				ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles,
-				ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles,
-				RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture,
+				List<StreamStateHandle> nonPartitionedStateHandles,
+				List<OperatorSnapshotResult> snapshotInProgressList,
 				CheckpointMetaData checkpointMetaData,
-				long asyncStartNanos
-		) {
+				long asyncStartNanos) {
 
-			this.name = name;
-			this.owner = owner;
-			this.cancelables = cancelables;
+			this.owner = Preconditions.checkNotNull(owner);
+			this.snapshotInProgressList = Preconditions.checkNotNull(snapshotInProgressList);
+			this.checkpointMetaData = Preconditions.checkNotNull(checkpointMetaData);
 			this.nonPartitionedStateHandles = nonPartitionedStateHandles;
-			this.partitioneableStateHandles = partitioneableStateHandles;
-			this.keyGroupsStateHandleFuture = keyGroupsStateHandleFuture;
-			this.checkpointMetaData = checkpointMetaData;
 			this.asyncStartNanos = asyncStartNanos;
+
+			if (!snapshotInProgressList.isEmpty()) {
+				// TODO Currently only the head operator of a chain can have keyed state, so simply access it directly.
+				int headIndex = snapshotInProgressList.size() - 1;
+				OperatorSnapshotResult snapshotInProgress = snapshotInProgressList.get(headIndex);
+				if (null != snapshotInProgress) {
+					this.futureKeyedBackendStateHandles = snapshotInProgress.getKeyedStateManagedFuture();
+					this.futureKeyedStreamStateHandles = snapshotInProgress.getKeyedStateRawFuture();
+				}
+			}
 		}
 
 		@Override
 		public void run() {
+
 			try {
 
-				List<KeyGroupsStateHandle> keyedStates = Collections.emptyList();
+				// Keyed state handle future, currently only one (the head) operator can have this
+				KeyGroupsStateHandle keyedStateHandleBackend = FutureUtil.runIfNotDoneAndGet(futureKeyedBackendStateHandles);
+				KeyGroupsStateHandle keyedStateHandleStream = FutureUtil.runIfNotDoneAndGet(futureKeyedStreamStateHandles);
 
-				if (keyGroupsStateHandleFuture != null) {
+				List<OperatorStateHandle> operatorStatesBackend = new ArrayList<>(snapshotInProgressList.size());
+				List<OperatorStateHandle> operatorStatesStream = new ArrayList<>(snapshotInProgressList.size());
 
-					if (!keyGroupsStateHandleFuture.isDone()) {
-						//TODO this currently works because we only have one RunnableFuture
-						keyGroupsStateHandleFuture.run();
-					}
-
-					KeyGroupsStateHandle keyGroupsStateHandle = this.keyGroupsStateHandleFuture.get();
-					if (keyGroupsStateHandle != null) {
-						keyedStates = Collections.singletonList(keyGroupsStateHandle);
-					}
+				for (OperatorSnapshotResult snapshotInProgress : snapshotInProgressList) {
+					operatorStatesBackend.add(FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture()));
+					operatorStatesStream.add(FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture()));
 				}
 
 				final long asyncEndNanos = System.nanoTime();
@@ -976,37 +874,161 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 				checkpointMetaData.setAsyncDurationMillis(asyncDurationMillis);
 
-				if (nonPartitionedStateHandles.isEmpty() && partitioneableStateHandles.isEmpty() && keyedStates.isEmpty()) {
-					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData);
-				} else {
-					CheckpointStateHandles allStateHandles = new CheckpointStateHandles(
-							nonPartitionedStateHandles,
-							partitioneableStateHandles,
-							keyedStates);
+				ChainedStateHandle<StreamStateHandle> chainedNonPartitionedOperatorsState =
+						new ChainedStateHandle<>(nonPartitionedStateHandles);
 
-					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData, allStateHandles);
+				ChainedStateHandle<OperatorStateHandle> chainedOperatorStateBackend =
+						new ChainedStateHandle<>(operatorStatesBackend);
+
+				ChainedStateHandle<OperatorStateHandle> chainedOperatorStateStream =
+						new ChainedStateHandle<>(operatorStatesStream);
+
+				SubtaskState subtaskState = new SubtaskState(
+						chainedNonPartitionedOperatorsState,
+						chainedOperatorStateBackend,
+						chainedOperatorStateStream,
+						keyedStateHandleBackend,
+						keyedStateHandleStream);
+
+				if (subtaskState.hasState()) {
+					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData, subtaskState);
+				} else {
+					owner.getEnvironment().acknowledgeCheckpoint(checkpointMetaData);
 				}
 
 				if (LOG.isDebugEnabled()) {
-					LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms", 
+					LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms",
 							owner.getName(), checkpointMetaData.getCheckpointId(), asyncDurationMillis);
 				}
-			}
-			catch (Exception e) {
+			} catch (Exception e) {
 				// registers the exception and tries to fail the whole task
 				AsynchronousException asyncException = new AsynchronousException(e);
 				owner.handleAsyncException("Failure in asynchronous checkpoint materialization", asyncException);
-			}
-			finally {
-				cancelables.unregisterClosable(this);
+			} finally {
+				owner.cancelables.unregisterClosable(this);
 			}
 		}
 
 		@Override
 		public void close() {
-			if (keyGroupsStateHandleFuture != null) {
-				keyGroupsStateHandleFuture.cancel(true);
+			//TODO Handle other state futures in case we actually run them. Currently they are just DoneFutures.
+			if (futureKeyedBackendStateHandles != null) {
+				futureKeyedBackendStateHandles.cancel(true);
+			}
+		}
+	}
+
+	public ClosableRegistry getCancelables() {
+		return cancelables;
+	}
+
+	// ------------------------------------------------------------------------
+
+	private static final class CheckpointingOperation {
+
+		private final StreamTask<?, ?> owner;
+
+		private final CheckpointMetaData checkpointMetaData;
+
+		private final StreamOperator<?>[] allOperators;
+
+		private long startSyncPartNano;
+		private long startAsyncPartNano;
+
+		// ------------------------
+
+		private CheckpointStreamFactory streamFactory;
+
+		private final List<StreamStateHandle> nonPartitionedStates;
+		private final List<OperatorSnapshotResult> snapshotInProgressList;
+
+		public CheckpointingOperation(StreamTask<?, ?> owner, CheckpointMetaData checkpointMetaData) {
+			this.owner = Preconditions.checkNotNull(owner);
+			this.checkpointMetaData = Preconditions.checkNotNull(checkpointMetaData);
+			this.allOperators = owner.operatorChain.getAllOperators();
+			this.nonPartitionedStates = new ArrayList<>(allOperators.length);
+			this.snapshotInProgressList = new ArrayList<>(allOperators.length);
+		}
+
+		/**
+		 * Runs the synchronous checkpoint part for every operator in {@code allOperators} by calling
+		 * createStreamFactory, snapshotNonPartitionableState and the operator's snapshotState, stores
+		 * the synchronous duration in the checkpoint meta data and then submits the asynchronous part.
+		 */
+		public void executeCheckpointing() throws Exception {
+			startSyncPartNano = System.nanoTime();
+
+			for (StreamOperator<?> op : allOperators) {
+				createStreamFactory(op);
+				snapshotNonPartitionableState(op);
+				OperatorSnapshotResult snapshotInProgress =
+						op.snapshotState(checkpointMetaData.getCheckpointId(), checkpointMetaData.getTimestamp(), streamFactory);
+				snapshotInProgressList.add(snapshotInProgress);
 			}
+
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("Finished synchronous checkpoints for checkpoint {} on task {}",
+						checkpointMetaData.getCheckpointId(), owner.getName());
+			}
+
+			startAsyncPartNano = System.nanoTime();
+			checkpointMetaData.setSyncDurationMillis((startAsyncPartNano - startSyncPartNano) / 1_000_000);
+			runAsyncCheckpointingAndAcknowledge();
+
+			if (LOG.isDebugEnabled()) {
+				// The trailing space keeps the two concatenated sentences apart in the rendered log line.
+				LOG.debug("{} - finished synchronous part of checkpoint {}. " +
+								"Alignment duration: {} ms, snapshot duration {} ms",
+						owner.getName(), checkpointMetaData.getCheckpointId(),
+						checkpointMetaData.getAlignmentDurationNanos() / 1_000_000,
+						checkpointMetaData.getSyncDurationMillis());
+			}
+		}
+
+		private void createStreamFactory(StreamOperator<?> operator) throws IOException {
+			String operatorId = owner.createOperatorIdentifier(operator, owner.configuration.getVertexID());
+			this.streamFactory = owner.stateBackend.createStreamFactory(owner.getEnvironment().getJobID(), operatorId);
+		}
+
+		//TODO deprecated code path
+		private void snapshotNonPartitionableState(StreamOperator<?> operator) throws Exception {
+
+			StreamStateHandle stateHandle = null;
+
+			if (operator instanceof StreamCheckpointedOperator) {
+
+				CheckpointStreamFactory.CheckpointStateOutputStream outStream =
+						streamFactory.createCheckpointStateOutputStream(
+								checkpointMetaData.getCheckpointId(), checkpointMetaData.getTimestamp());
+
+				owner.cancelables.registerClosable(outStream);
+
+				try {
+					((StreamCheckpointedOperator) operator).
+							snapshotState(
+									outStream,
+									checkpointMetaData.getCheckpointId(),
+									checkpointMetaData.getTimestamp());
+
+					stateHandle = outStream.closeAndGetHandle();
+				} finally {
+					owner.cancelables.unregisterClosable(outStream);
+					outStream.close();
+				}
+			}
+			nonPartitionedStates.add(stateHandle);
+		}
+
+		public void runAsyncCheckpointingAndAcknowledge() throws IOException {
+			AsyncCheckpointRunnable asyncCheckpointRunnable = new AsyncCheckpointRunnable(
+					owner,
+					nonPartitionedStates,
+					snapshotInProgressList,
+					checkpointMetaData,
+					startAsyncPartNano);
+
+			owner.cancelables.registerClosable(asyncCheckpointRunnable);
+			owner.asyncOperationsThreadPool.submit(asyncCheckpointRunnable);
 		}
 	}
 }


[6/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State

Posted by al...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
index 3e2d713..1ad41ea 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
@@ -38,8 +38,8 @@ public class OperatorStateHandle implements StreamStateHandle {
 	private final StreamStateHandle delegateStateHandle;
 
 	public OperatorStateHandle(
-			StreamStateHandle delegateStateHandle,
-			Map<String, long[]> stateNameToPartitionOffsets) {
+			Map<String, long[]> stateNameToPartitionOffsets,
+			StreamStateHandle delegateStateHandle) {
 
 		this.delegateStateHandle = Preconditions.checkNotNull(delegateStateHandle);
 		this.stateNameToPartitionOffsets = Preconditions.checkNotNull(stateNameToPartitionOffsets);

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java
deleted file mode 100644
index 065f9c2..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state;
-
-import org.apache.flink.core.fs.FSDataOutputStream;
-import org.apache.flink.util.Preconditions;
-
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
-
-public class PartitionableCheckpointStateOutputStream extends FSDataOutputStream {
-
-	private final Map<String, long[]> stateNameToPartitionOffsets;
-	private final CheckpointStreamFactory.CheckpointStateOutputStream delegate;
-
-	public PartitionableCheckpointStateOutputStream(CheckpointStreamFactory.CheckpointStateOutputStream delegate) {
-		this.delegate = Preconditions.checkNotNull(delegate);
-		this.stateNameToPartitionOffsets = new HashMap<>();
-	}
-
-	@Override
-	public long getPos() throws IOException {
-		return delegate.getPos();
-	}
-
-	@Override
-	public void flush() throws IOException {
-		delegate.flush();
-	}
-
-	@Override
-	public void sync() throws IOException {
-		delegate.sync();
-	}
-
-	@Override
-	public void write(int b) throws IOException {
-		delegate.write(b);
-	}
-
-	@Override
-	public void write(byte[] b) throws IOException {
-		delegate.write(b);
-	}
-
-	@Override
-	public void write(byte[] b, int off, int len) throws IOException {
-		delegate.write(b, off, len);
-	}
-
-	@Override
-	public void close() throws IOException {
-		delegate.close();
-	}
-
-	public OperatorStateHandle closeAndGetHandle() throws IOException {
-		StreamStateHandle streamStateHandle = delegate.closeAndGetHandle();
-		return new OperatorStateHandle(streamStateHandle, stateNameToPartitionOffsets);
-	}
-
-	public void startNewPartition(String stateName) throws IOException {
-		long[] offs = stateNameToPartitionOffsets.get(stateName);
-		if (offs == null) {
-			offs = new long[1];
-		} else {
-			//TODO maybe we can use some primitive array list here instead of an array to avoid resize on each call.
-			offs = Arrays.copyOf(offs, offs.length + 1);
-		}
-
-		offs[offs.length - 1] = getPos();
-		stateNameToPartitionOffsets.put(stateName, offs);
-	}
-
-	public static PartitionableCheckpointStateOutputStream wrap(
-			CheckpointStreamFactory.CheckpointStateOutputStream stream) {
-		return new PartitionableCheckpointStateOutputStream(stream);
-	}
-}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java
deleted file mode 100644
index c47fedd..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state;
-
-import java.util.concurrent.RunnableFuture;
-
-/**
- * Interface for operations that can perform snapshots of their state.
- *
- * @param <S> Generic type of the state object that is created as handle to snapshots.
- */
-public interface SnapshotProvider<S extends StateObject> {
-
-	/**
-	 * Operation that writes a snapshot into a stream that is provided by the given {@link CheckpointStreamFactory} and
-	 * returns a @{@link RunnableFuture} that gives a state handle to the snapshot. It is up to the implementation if
-	 * the operation is performed synchronous or asynchronous. In the later case, the returned Runnable must be executed
-	 * first before obtaining the handle.
-	 *
-	 * @param checkpointId  The ID of the checkpoint.
-	 * @param timestamp     The timestamp of the checkpoint.
-	 * @param streamFactory The factory that we can use for writing our state to streams.
-	 * @return A runnable future that will yield a {@link StateObject}.
-	 */
-	RunnableFuture<S> snapshot(
-			long checkpointId,
-			long timestamp,
-			CheckpointStreamFactory streamFactory) throws Exception;
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
new file mode 100644
index 0000000..2aa282d
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import java.util.concurrent.RunnableFuture;
+
+/**
+ * Interface for operations that can perform snapshots of their state.
+ *
+ * @param <S> Generic type of the state object that is created as handle to snapshots.
+ */
+public interface Snapshotable<S extends StateObject> {
+
+	/**
+	 * Operation that writes a snapshot into a stream that is provided by the given {@link CheckpointStreamFactory} and
+	 * returns a {@link RunnableFuture} that gives a state handle to the snapshot. It is up to the implementation if
+	 * the operation is performed synchronously or asynchronously. In the latter case, the returned Runnable must be executed
+	 * first before obtaining the handle.
+	 *
+	 * @param checkpointId  The ID of the checkpoint.
+	 * @param timestamp     The timestamp of the checkpoint.
+	 * @param streamFactory The factory that we can use for writing our state to streams.
+	 * @return A runnable future that will yield a {@link StateObject}.
+	 */
+	RunnableFuture<S> snapshot(
+			long checkpointId,
+			long timestamp,
+			CheckpointStreamFactory streamFactory) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContext.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContext.java
new file mode 100644
index 0000000..a066739
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContext.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * This interface provides a context in which operators can initialize by registering to managed state (i.e. state that
+ * is managed by state backends) or iterating over streams of state partitions written as raw state in a previous
+ * snapshot.
+ *
+ * <p>
+ * Similar to the managed state from {@link ManagedInitializationContext} and in general, raw operator state is
+ * available to all operators, while raw keyed state is only available for operators after keyBy.
+ *
+ * <p>
+ * For the purpose of initialization, the context signals if all state is empty (new operator) or if any state was
+ * restored from a previous execution of this operator.
+ *
+ */
+@PublicEvolving
+public interface StateInitializationContext extends FunctionInitializationContext {
+
+	/**
+	 * Returns an iterable to obtain input streams for previously stored operator state partitions that are assigned to
+	 * this operator.
+	 */
+	Iterable<StatePartitionStreamProvider> getRawOperatorStateInputs();
+
+	/**
+	 * Returns an iterable to obtain input streams for previously stored keyed state partitions that are assigned to
+	 * this operator.
+	 */
+	Iterable<KeyGroupStatePartitionStreamProvider> getRawKeyedStateInputs();
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
new file mode 100644
index 0000000..8fbde05
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Iterator;
+
+/**
+ * Default implementation of {@link StateInitializationContext}.
+ */
+public class StateInitializationContextImpl implements StateInitializationContext {
+
+	/** Closable registry to participate in the operator's cancel/close methods */
+	private final ClosableRegistry closableRegistry;
+
+	/** Signal whether any state to restore was found */
+	private final boolean restored;
+
+	private final OperatorStateStore operatorStateStore;
+	private final Collection<OperatorStateHandle> operatorStateHandles;
+
+	private final KeyedStateStore keyedStateStore;
+	private final Collection<KeyGroupsStateHandle> keyGroupsStateHandles;
+
+	private final Iterable<KeyGroupStatePartitionStreamProvider> keyedStateIterable;
+
+	public StateInitializationContextImpl(
+			boolean restored,
+			OperatorStateStore operatorStateStore,
+			KeyedStateStore keyedStateStore,
+			Collection<KeyGroupsStateHandle> keyGroupsStateHandles,
+			Collection<OperatorStateHandle> operatorStateHandles,
+			ClosableRegistry closableRegistry) {
+
+		this.restored = restored;
+		this.closableRegistry = Preconditions.checkNotNull(closableRegistry);
+		this.operatorStateStore = operatorStateStore;
+		this.keyedStateStore = keyedStateStore;
+		this.operatorStateHandles = operatorStateHandles;
+		this.keyGroupsStateHandles = keyGroupsStateHandles;
+
+		this.keyedStateIterable = keyGroupsStateHandles == null ?
+				null
+				: new Iterable<KeyGroupStatePartitionStreamProvider>() {
+			@Override
+			public Iterator<KeyGroupStatePartitionStreamProvider> iterator() {
+				return new KeyGroupStreamIterator(getKeyGroupsStateHandles().iterator(), getClosableRegistry());
+			}
+		};
+	}
+
+	@Override
+	public boolean isRestored() {
+		return restored;
+	}
+
+	public Collection<OperatorStateHandle> getOperatorStateHandles() {
+		return operatorStateHandles;
+	}
+
+	public Collection<KeyGroupsStateHandle> getKeyGroupsStateHandles() {
+		return keyGroupsStateHandles;
+	}
+
+	public ClosableRegistry getClosableRegistry() {
+		return closableRegistry;
+	}
+
+	@Override
+	public Iterable<StatePartitionStreamProvider> getRawOperatorStateInputs() {
+		if (null != operatorStateHandles) {
+			return new Iterable<StatePartitionStreamProvider>() {
+				@Override
+				public Iterator<StatePartitionStreamProvider> iterator() {
+					return new OperatorStateStreamIterator(
+							DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME,
+							getOperatorStateHandles().iterator(), getClosableRegistry());
+				}
+			};
+		} else {
+			return Collections.emptyList();
+		}
+	}
+
+	@Override
+	public Iterable<KeyGroupStatePartitionStreamProvider> getRawKeyedStateInputs() {
+		if(null == keyedStateStore) {
+			throw new IllegalStateException("Attempt to access keyed state from non-keyed operator.");
+		}
+
+		if (null != keyGroupsStateHandles) {
+			return keyedStateIterable;
+		} else {
+			return Collections.emptyList();
+		}
+	}
+
+	@Override
+	public OperatorStateStore getManagedOperatorStateStore() {
+		return operatorStateStore;
+	}
+
+	@Override
+	public KeyedStateStore getManagedKeyedStateStore() {
+		return keyedStateStore;
+	}
+
+	public void close() {
+		IOUtils.closeQuietly(closableRegistry);
+	}
+
+	private static class KeyGroupStreamIterator implements Iterator<KeyGroupStatePartitionStreamProvider> {
+
+		private final Iterator<KeyGroupsStateHandle> stateHandleIterator;
+		private final ClosableRegistry closableRegistry;
+
+		private KeyGroupsStateHandle currentStateHandle;
+		private FSDataInputStream currentStream;
+		private Iterator<Tuple2<Integer, Long>> currentOffsetsIterator;
+
+		public KeyGroupStreamIterator(
+				Iterator<KeyGroupsStateHandle> stateHandleIterator, ClosableRegistry closableRegistry) {
+
+			this.stateHandleIterator = Preconditions.checkNotNull(stateHandleIterator);
+			this.closableRegistry = Preconditions.checkNotNull(closableRegistry);
+		}
+
+		@Override
+		public boolean hasNext() {
+			if (null != currentStateHandle && currentOffsetsIterator.hasNext()) {
+				return true;
+			} else {
+				while (stateHandleIterator.hasNext()) {
+					currentStateHandle = stateHandleIterator.next();
+					if (currentStateHandle.getNumberOfKeyGroups() > 0) {
+						currentOffsetsIterator = currentStateHandle.getGroupRangeOffsets().iterator();
+						closableRegistry.unregisterClosable(currentStream);
+						IOUtils.closeQuietly(currentStream);
+						currentStream = null;
+						return true;
+					}
+				}
+				return false;
+			}
+		}
+
+		private void openStream() throws IOException {
+			FSDataInputStream stream = currentStateHandle.openInputStream();
+			closableRegistry.registerClosable(stream);
+			currentStream = stream;
+		}
+
+		@Override
+		public KeyGroupStatePartitionStreamProvider next() {
+			Tuple2<Integer, Long> keyGroupOffset = currentOffsetsIterator.next();
+			try {
+				if (null == currentStream) {
+					openStream();
+				}
+				currentStream.seek(keyGroupOffset.f1);
+				return new KeyGroupStatePartitionStreamProvider(currentStream, keyGroupOffset.f0);
+			} catch (IOException ioex) {
+				return new KeyGroupStatePartitionStreamProvider(ioex, keyGroupOffset.f0);
+			}
+		}
+
+		@Override
+		public void remove() {
+			throw new UnsupportedOperationException("Read only Iterator");
+		}
+	}
+
+	private static class OperatorStateStreamIterator implements Iterator<StatePartitionStreamProvider> {
+
+		private final String stateName; //TODO since we only support a single named state in raw, this could be dropped
+
+		private final Iterator<OperatorStateHandle> stateHandleIterator;
+		private final ClosableRegistry closableRegistry;
+
+		private OperatorStateHandle currentStateHandle;
+		private FSDataInputStream currentStream;
+		private long[] offsets;
+		private int offPos;
+
+		public OperatorStateStreamIterator(
+				String stateName,
+				Iterator<OperatorStateHandle> stateHandleIterator,
+				ClosableRegistry closableRegistry) {
+
+			this.stateName = Preconditions.checkNotNull(stateName);
+			this.stateHandleIterator = Preconditions.checkNotNull(stateHandleIterator);
+			this.closableRegistry = Preconditions.checkNotNull(closableRegistry);
+		}
+
+		@Override
+		public boolean hasNext() {
+			if (null != currentStateHandle && offPos < offsets.length) {
+				return true;
+			} else {
+				while (stateHandleIterator.hasNext()) {
+					currentStateHandle = stateHandleIterator.next();
+					long[] offsets = currentStateHandle.getStateNameToPartitionOffsets().get(stateName);
+					if (null != offsets && offsets.length > 0) {
+
+						this.offsets = offsets;
+						this.offPos = 0;
+
+						closableRegistry.unregisterClosable(currentStream);
+						IOUtils.closeQuietly(currentStream);
+						currentStream = null;
+
+						return true;
+					}
+				}
+				return false;
+			}
+		}
+
+		private void openStream() throws IOException {
+			FSDataInputStream stream = currentStateHandle.openInputStream();
+			closableRegistry.registerClosable(stream);
+			currentStream = stream;
+		}
+
+		@Override
+		public StatePartitionStreamProvider next() {
+			long offset = offsets[offPos++];
+			try {
+				if (null == currentStream) {
+					openStream();
+				}
+				currentStream.seek(offset);
+
+				return new StatePartitionStreamProvider(currentStream);
+			} catch (IOException ioex) {
+				return new StatePartitionStreamProvider(ioex);
+			}
+		}
+
+		@Override
+		public void remove() {
+			throw new UnsupportedOperationException("Read only Iterator");
+		}
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/StatePartitionStreamProvider.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StatePartitionStreamProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StatePartitionStreamProvider.java
new file mode 100644
index 0000000..8b07da8
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StatePartitionStreamProvider.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.runtime.util.NonClosingStreamDecorator;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+/**
+ * This class provides access to input streams that contain data of one state partition of a partitionable state.
+ *
+ * TODO use a bounded stream that fails fast if the limit is exceeded on corrupted reads.
+ */
+@PublicEvolving
+public class StatePartitionStreamProvider {
+
+	/** A ready-made stream that contains data for one state partition */
+	private final InputStream stream;
+
+	/** Holds potential exception that happened when actually trying to create the stream */
+	private final IOException creationException;
+
+	public StatePartitionStreamProvider(IOException creationException) {
+		this.creationException = Preconditions.checkNotNull(creationException);
+		this.stream = null;
+	}
+
+	public StatePartitionStreamProvider(InputStream stream) {
+		this.stream = new NonClosingStreamDecorator(Preconditions.checkNotNull(stream));
+		this.creationException = null;
+	}
+
+
+	/**
+	 * Returns a stream with the data of one state partition.
+	 */
+	public InputStream getStream() throws IOException {
+		if (creationException != null) {
+			throw new IOException(creationException);
+		}
+		return stream;
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContext.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContext.java
new file mode 100644
index 0000000..4dbbeaf
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContext.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * This interface provides a context in which operators that use managed (i.e. state that is managed by state
+ * backends) or raw (i.e. the operator can write its own state streams) state can perform a snapshot.
+ */
+@PublicEvolving
+public interface StateSnapshotContext extends FunctionSnapshotContext {
+
+	/**
+	 * Returns an output stream for keyed state
+	 */
+	KeyedStateCheckpointOutputStream getRawKeyedOperatorStateOutput() throws Exception;
+
+	/**
+	 * Returns an output stream for operator state
+	 */
+	OperatorStateCheckpointOutputStream getRawOperatorStateOutput() throws Exception;
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
new file mode 100644
index 0000000..d632529
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotContextSynchronousImpl.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.concurrent.RunnableFuture;
+
+/**
+ * This class is a default implementation for StateSnapshotContext.
+ */
+public class StateSnapshotContextSynchronousImpl implements StateSnapshotContext {
+	
+	private final long checkpointId;
+	private final long checkpointTimestamp;
+	
+	/** Factory for the checkpoint stream */
+	private final CheckpointStreamFactory streamFactory;
+	
+	/** Key group range for the operator that created this context. Only for keyed operators */
+	private final KeyGroupRange keyGroupRange;
+
+	/**
+	 * Registry for opened streams to participate in the lifecycle of the stream task. Hence, this registry should be 
+	 * obtained from and managed by the stream task.
+	 */
+	private final ClosableRegistry closableRegistry;
+
+	private KeyedStateCheckpointOutputStream keyedStateCheckpointOutputStream;
+	private OperatorStateCheckpointOutputStream operatorStateCheckpointOutputStream;
+
+	@VisibleForTesting
+	public StateSnapshotContextSynchronousImpl(long checkpointId, long checkpointTimestamp) {
+		this.checkpointId = checkpointId;
+		this.checkpointTimestamp = checkpointTimestamp;
+		this.streamFactory = null;
+		this.keyGroupRange = KeyGroupRange.EMPTY_KEY_GROUP_RANGE;
+		this.closableRegistry = null;
+	}
+
+
+	public StateSnapshotContextSynchronousImpl(
+			long checkpointId,
+			long checkpointTimestamp,
+			CheckpointStreamFactory streamFactory,
+			KeyGroupRange keyGroupRange,
+			ClosableRegistry closableRegistry) {
+
+		this.checkpointId = checkpointId;
+		this.checkpointTimestamp = checkpointTimestamp;
+		this.streamFactory = Preconditions.checkNotNull(streamFactory);
+		this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange);
+		this.closableRegistry = Preconditions.checkNotNull(closableRegistry);
+	}
+
+	@Override
+	public long getCheckpointId() {
+		return checkpointId;
+	}
+
+	@Override
+	public long getCheckpointTimestamp() {
+		return checkpointTimestamp;
+	}
+
+	private CheckpointStreamFactory.CheckpointStateOutputStream openAndRegisterNewStream() throws Exception {
+		CheckpointStreamFactory.CheckpointStateOutputStream cout =
+				streamFactory.createCheckpointStateOutputStream(checkpointId, checkpointTimestamp);
+
+		closableRegistry.registerClosable(cout);
+		return cout;
+	}
+
+	@Override
+	public KeyedStateCheckpointOutputStream getRawKeyedOperatorStateOutput() throws Exception {
+		if (null == keyedStateCheckpointOutputStream) {
+			Preconditions.checkState(keyGroupRange != KeyGroupRange.EMPTY_KEY_GROUP_RANGE, "Not a keyed operator");
+			keyedStateCheckpointOutputStream = new KeyedStateCheckpointOutputStream(openAndRegisterNewStream(), keyGroupRange);
+		}
+		return keyedStateCheckpointOutputStream;
+	}
+
+	@Override
+	public OperatorStateCheckpointOutputStream getRawOperatorStateOutput() throws Exception {
+		if (null == operatorStateCheckpointOutputStream) {
+			operatorStateCheckpointOutputStream = new OperatorStateCheckpointOutputStream(openAndRegisterNewStream());
+		}
+		return operatorStateCheckpointOutputStream;
+	}
+
+	public RunnableFuture<KeyGroupsStateHandle> getKeyedStateStreamFuture() throws IOException {
+		return closeAndUnregisterStreamToObtainStateHandle(keyedStateCheckpointOutputStream);
+	}
+
+	public RunnableFuture<OperatorStateHandle> getOperatorStateStreamFuture() throws IOException {
+		return closeAndUnregisterStreamToObtainStateHandle(operatorStateCheckpointOutputStream);
+	}
+
+	private <T extends StreamStateHandle> RunnableFuture<T> closeAndUnregisterStreamToObtainStateHandle(
+			NonClosingCheckpointOutputStream<T> stream) throws IOException {
+		if (null == stream) {
+			return null;
+		}
+
+		closableRegistry.unregisterClosable(stream.getDelegate());
+
+		// for now we only support synchronous writing
+		return new DoneFuture<>(stream.closeAndGetHandle());
+	}
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
new file mode 100644
index 0000000..ecd6399
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskStateHandles.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.util.CollectionUtil;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * This class encapsulates all state handles for a task.
+ */
+public class TaskStateHandles implements Serializable {
+
+	public static final TaskStateHandles EMPTY = new TaskStateHandles(); // built by the no-arg constructor: all five handles are null
+
+	private static final long serialVersionUID = 267686583583579359L;
+
+	/** State handle with the (non-partitionable) legacy operator state. */
+	@Deprecated
+	private final ChainedStateHandle<StreamStateHandle> legacyOperatorState;
+
+	/** Collection of handles which represent the managed keyed state of the head operator */
+	private final Collection<KeyGroupsStateHandle> managedKeyedState;
+
+	/** Collection of handles which represent the raw/streamed keyed state of the head operator */
+	private final Collection<KeyGroupsStateHandle> rawKeyedState;
+
+	/** Outer list represents the operator chain, each collection holds handles for managed state of a single operator */
+	private final List<Collection<OperatorStateHandle>> managedOperatorState;
+
+	/** Outer list represents the operator chain, each collection holds handles for raw/streamed state of a single operator */
+	private final List<Collection<OperatorStateHandle>> rawOperatorState;
+
+	public TaskStateHandles() {
+		this(null, null, null, null, null); // every handle is left null (not an empty collection)
+	}
+
+	public TaskStateHandles(SubtaskState checkpointStateHandles) {
+		this(checkpointStateHandles.getLegacyOperatorState(), // legacy state is passed through unchanged
+				transform(checkpointStateHandles.getManagedOperatorState()),
+				transform(checkpointStateHandles.getRawOperatorState()),
+				transform(checkpointStateHandles.getManagedKeyedState()),
+				transform(checkpointStateHandles.getRawKeyedState()));
+	}
+
+	public TaskStateHandles(
+			ChainedStateHandle<StreamStateHandle> legacyOperatorState,
+			List<Collection<OperatorStateHandle>> managedOperatorState,
+			List<Collection<OperatorStateHandle>> rawOperatorState,
+			Collection<KeyGroupsStateHandle> managedKeyedState,
+			Collection<KeyGroupsStateHandle> rawKeyedState) {
+
+		this.legacyOperatorState = legacyOperatorState; // arguments are stored as given: no null checks, no defensive copies
+		this.managedKeyedState = managedKeyedState;
+		this.rawKeyedState = rawKeyedState;
+		this.managedOperatorState = managedOperatorState;
+		this.rawOperatorState = rawOperatorState;
+	}
+
+	@Deprecated
+	public ChainedStateHandle<StreamStateHandle> getLegacyOperatorState() {
+		return legacyOperatorState;
+	}
+
+	public Collection<KeyGroupsStateHandle> getManagedKeyedState() {
+		return managedKeyedState;
+	}
+
+	public Collection<KeyGroupsStateHandle> getRawKeyedState() {
+		return rawKeyedState;
+	}
+
+	public List<Collection<OperatorStateHandle>> getRawOperatorState() {
+		return rawOperatorState;
+	}
+
+	public List<Collection<OperatorStateHandle>> getManagedOperatorState() {
+		return managedOperatorState;
+	}
+
+	public boolean hasState() { // true as soon as any one of the five handles passes its isNullOrEmpty check
+		return !ChainedStateHandle.isNullOrEmpty(legacyOperatorState)
+				|| !CollectionUtil.isNullOrEmpty(managedKeyedState)
+				|| !CollectionUtil.isNullOrEmpty(rawKeyedState)
+				|| !CollectionUtil.isNullOrEmpty(rawOperatorState)
+				|| !CollectionUtil.isNullOrEmpty(managedOperatorState);
+	}
+
+	private static List<Collection<OperatorStateHandle>> transform(ChainedStateHandle<OperatorStateHandle> in) {
+		if (null == in) {
+			return Collections.emptyList(); // a missing chain becomes an empty list rather than null
+		}
+		List<Collection<OperatorStateHandle>> out = new ArrayList<>(in.getLength());
+		for (int i = 0; i < in.getLength(); ++i) {
+			OperatorStateHandle osh = in.get(i);
+			out.add(osh != null ? Collections.singletonList(osh) : null); // null entry keeps index i aligned with chain position i
+		}
+		return out;
+	}
+
+	private static List<KeyGroupsStateHandle> transform(KeyGroupsStateHandle in) { // single optional handle -> list of zero or one elements
+		return in == null ? Collections.<KeyGroupsStateHandle>emptyList() : Collections.singletonList(in);
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o) {
+			return true;
+		}
+		if (o == null || getClass() != o.getClass()) { // exact class match: instances of different classes never compare equal
+			return false;
+		}
+
+		TaskStateHandles that = (TaskStateHandles) o;
+
+		if (legacyOperatorState != null ?
+				!legacyOperatorState.equals(that.legacyOperatorState)
+				: that.legacyOperatorState != null) {
+			return false;
+		}
+		if (managedKeyedState != null ?
+				!managedKeyedState.equals(that.managedKeyedState)
+				: that.managedKeyedState != null) {
+			return false;
+		}
+		if (rawKeyedState != null ?
+				!rawKeyedState.equals(that.rawKeyedState)
+				: that.rawKeyedState != null) {
+			return false;
+		}
+
+		if (rawOperatorState != null ?
+				!rawOperatorState.equals(that.rawOperatorState)
+				: that.rawOperatorState != null) {
+			return false;
+		}
+		return managedOperatorState != null ?
+				managedOperatorState.equals(that.managedOperatorState)
+				: that.managedOperatorState == null;
+	}
+
+	@Override
+	public int hashCode() { // combines the same five fields that equals() compares; null contributes 0
+		int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0;
+		result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0);
+		result = 31 * result + (rawKeyedState != null ? rawKeyedState.hashCode() : 0);
+		result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0);
+		result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0);
+		return result;
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/UserFacingListState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/UserFacingListState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/UserFacingListState.java
new file mode 100644
index 0000000..71026c6
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/UserFacingListState.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.state.ListState;
+
+import java.util.Collections;
+
+/**
+ * Simple wrapper list state that exposes empty state properly as an empty list.
+ * 
+ * @param <T> The type of elements in the list state.
+ */
+class UserFacingListState<T> implements ListState<T> {
+
+	private final ListState<T> originalState; // the wrapped state; add() and clear() are forwarded to it unchanged
+
+	private final Iterable<T> emptyState = Collections.emptyList(); // handed out by get() in place of null
+
+	UserFacingListState(ListState<T> originalState) {
+		this.originalState = originalState;
+	}
+
+	// ------------------------------------------------------------------------
+
+	@Override
+	public Iterable<T> get() throws Exception {
+		Iterable<T> original = originalState.get();
+		return original != null ? original : emptyState; // a null result from the wrapped state is replaced by the empty list
+	}
+
+	@Override
+	public void add(T value) throws Exception {
+		originalState.add(value);
+	}
+
+	@Override
+	public void clear() {
+		originalState.clear();
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
index e027632..4e15cd5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
@@ -36,7 +36,7 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.net.URI;
 import java.net.URISyntaxException;
-import java.util.List;
+import java.util.Collection;
 
 /**
  * The file state backend is a state backend that stores the state of streaming jobs in a file system.
@@ -199,7 +199,7 @@ public class FsStateBackend extends AbstractStateBackend {
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> restoredState,
+			Collection<KeyGroupsStateHandle> restoredState,
 			TaskKvStateRegistry kvStateRegistry) throws Exception {
 		return new HeapKeyedStateBackend<>(
 				kvStateRegistry,

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index b283494..56be46f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -50,8 +50,8 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 import java.util.concurrent.RunnableFuture;
 
@@ -94,7 +94,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			ClassLoader userCodeClassLoader,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> restoredState) throws Exception {
+			Collection<KeyGroupsStateHandle> restoredState) throws Exception {
 		super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange);
 
 		LOG.info("Initializing heap keyed state backend from snapshot.");
@@ -248,7 +248,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@SuppressWarnings({"unchecked"})
-	private void restorePartitionedState(List<KeyGroupsStateHandle> state) throws Exception {
+	private void restorePartitionedState(Collection<KeyGroupsStateHandle> state) throws Exception {
 
 		int numRegisteredKvStates = 0;
 		Map<Integer, String> kvStatesById = new HashMap<>();
@@ -259,13 +259,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				continue;
 			}
 
-			FSDataInputStream fsDataInputStream = null;
+			FSDataInputStream fsDataInputStream = keyGroupsHandle.openInputStream();
+			cancelStreamRegistry.registerClosable(fsDataInputStream);
 
 			try {
-
-				fsDataInputStream = keyGroupsHandle.openInputStream();
-				cancelStreamRegistry.registerClosable(fsDataInputStream);
-
 				DataInputViewStreamWrapper inView = new DataInputViewStreamWrapper(fsDataInputStream);
 
 				int numKvStates = inView.readShort();

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
index 028f8c8..30de638 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
@@ -127,6 +127,10 @@ public class MemCheckpointStreamFactory implements CheckpointStreamFactory {
 			return os.getPosition();
 		}
 
+		public boolean isClosed() {
+			return closed;
+		}
+
 		/**
 		 * Closes the stream and returns the byte array containing the stream's data.
 		 * @return The byte array containing the stream's data.

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
index 1772dbe..33f03ad 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
@@ -30,7 +30,7 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 
 import java.io.IOException;
-import java.util.List;
+import java.util.Collection;
 
 /**
  * A {@link AbstractStateBackend} that stores all its data and checkpoints in memory and has no
@@ -100,7 +100,7 @@ public class MemoryStateBackend extends AbstractStateBackend {
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> restoredState,
+			Collection<KeyGroupsStateHandle> restoredState,
 			TaskKvStateRegistry kvStateRegistry) throws Exception {
 
 		return new HeapKeyedStateBackend<>(

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
index 6f1bf7b..38defcc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
@@ -20,11 +20,11 @@ package org.apache.flink.runtime.taskmanager;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.util.Preconditions;
 
 /**
@@ -43,7 +43,7 @@ public class ActorGatewayCheckpointResponder implements CheckpointResponder {
 			JobID jobID,
 			ExecutionAttemptID executionAttemptID,
 			CheckpointMetaData checkpointMetaData,
-			CheckpointStateHandles checkpointStateHandles) {
+			SubtaskState checkpointStateHandles) {
 
 		AcknowledgeCheckpoint message = new AcknowledgeCheckpoint(
 				jobID, executionAttemptID, checkpointMetaData,

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
index 4fa20e6..7dbb76c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
@@ -20,8 +20,8 @@ package org.apache.flink.runtime.taskmanager;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 
 /**
  * Responder for checkpoint acknowledge and decline messages in the {@link Task}.
@@ -35,7 +35,7 @@ public interface CheckpointResponder {
 	 *             Job ID of the running job
 	 * @param executionAttemptID
 	 *             Execution attempt ID of the running task
-	 * @param checkpointStateHandles
+	 * @param subtaskState
 	 *             State handles for the checkpoint
 	 * @param checkpointMetaData
 	 *             Meta data for this checkpoint
@@ -45,7 +45,7 @@ public interface CheckpointResponder {
 		JobID jobID,
 		ExecutionAttemptID executionAttemptID,
 		CheckpointMetaData checkpointMetaData,
-		CheckpointStateHandles checkpointStateHandles);
+		SubtaskState subtaskState);
 
 	/**
 	 * Declines the given checkpoint.

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
index f6720e7..fa69a60 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
@@ -26,6 +26,7 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -36,7 +37,6 @@ import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 
 import java.util.Map;
 import java.util.concurrent.Future;
@@ -245,7 +245,7 @@ public class RuntimeEnvironment implements Environment {
 	@Override
 	public void acknowledgeCheckpoint(
 			CheckpointMetaData checkpointMetaData,
-			CheckpointStateHandles checkpointStateHandles) {
+			SubtaskState checkpointStateHandles) {
 
 
 		checkpointResponder.acknowledgeCheckpoint(

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index 02a41b5..bd522bd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -25,16 +25,11 @@ import org.apache.flink.api.common.cache.DistributedCache;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
-import org.apache.flink.runtime.concurrent.BiFunction;
-import org.apache.flink.runtime.io.network.PartitionState;
-import org.apache.flink.runtime.io.network.netty.PartitionStateChecker;
-import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
-import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.blob.BlobKey;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.concurrent.BiFunction;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
@@ -46,23 +41,24 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.filecache.FileCache;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.network.NetworkEnvironment;
+import org.apache.flink.runtime.io.network.PartitionState;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.netty.PartitionStateChecker;
 import org.apache.flink.runtime.io.network.partition.ResultPartition;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
+import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
 import org.apache.flink.runtime.jobgraph.tasks.StoppableTask;
 import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
-
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.SerializedValue;
 import org.slf4j.Logger;
@@ -70,7 +66,6 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.net.URL;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -233,18 +228,10 @@ public class Task implements Runnable, TaskActions {
 	private volatile ExecutorService asyncCallDispatcher;
 
 	/**
-	 * The handle to the chained operator state that the task was initialized with. Will be set
+	 * The handles to the states that the task was initialized with. Will be set
 	 * to null after the initialization, to be memory friendly.
 	 */
-	private volatile ChainedStateHandle<StreamStateHandle> chainedOperatorState;
-
-	/**
-	 * The handle to the key group state that the task was initialized with. Will be set
-	 * to null after the initialization, to be memory friendly.
-	 */
-	private volatile List<KeyGroupsStateHandle> keyGroupStates;
-
-	private volatile List<Collection<OperatorStateHandle>> partitionableOperatorState;
+	private volatile TaskStateHandles taskStateHandles;
 
 	/** Initialized from the Flink configuration. May also be set at the ExecutionConfig */
 	private long taskCancellationInterval;
@@ -280,10 +267,8 @@ public class Task implements Runnable, TaskActions {
 		this.requiredJarFiles = checkNotNull(tdd.getRequiredJarFiles());
 		this.requiredClasspaths = checkNotNull(tdd.getRequiredClasspaths());
 		this.nameOfInvokableClass = checkNotNull(tdd.getInvokableClassName());
-		this.chainedOperatorState = tdd.getOperatorState();
 		this.serializedExecutionConfig = checkNotNull(tdd.getSerializedExecutionConfig());
-		this.keyGroupStates = tdd.getKeyGroupState();
-		this.partitionableOperatorState = tdd.getPartitionableOperatorState();
+		this.taskStateHandles = tdd.getTaskStateHandles();
 
 		this.taskCancellationInterval = jobConfiguration.getLong(
 			ConfigConstants.TASK_CANCELLATION_INTERVAL_MILLIS,
@@ -570,20 +555,19 @@ public class Task implements Runnable, TaskActions {
 			// the state into the task. the state is non-empty if this is an execution
 			// of a task that failed but had backuped state from a checkpoint
 
-			if (chainedOperatorState != null || keyGroupStates != null || partitionableOperatorState != null) {
+			if (null != taskStateHandles) {
 				if (invokable instanceof StatefulTask) {
 					StatefulTask op = (StatefulTask) invokable;
-					op.setInitialState(chainedOperatorState, keyGroupStates, partitionableOperatorState);
+					op.setInitialState(taskStateHandles);
 				} else {
 					throw new IllegalStateException("Found operator state for a non-stateful task invokable");
 				}
+				// be memory and GC friendly - since the code stays in invoke() for a potentially long time,
+				// we clear the reference to the state handle
+				//noinspection UnusedAssignment
+				taskStateHandles = null;
 			}
 
-			// be memory and GC friendly - since the code stays in invoke() for a potentially long time,
-			// we clear the reference to the state handle
-			//noinspection UnusedAssignment
-			this.chainedOperatorState = null;
-			this.keyGroupStates = null;
 
 			// ----------------------------------------------------------------
 			//  actual task core work

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/util/IntArrayList.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/IntArrayList.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/IntArrayList.java
index 27d958a..ce52eac 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/util/IntArrayList.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/IntArrayList.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.util;
 
+import java.util.Arrays;
 import java.util.NoSuchElementException;
 
 /**
@@ -69,6 +70,10 @@ public class IntArrayList {
 		}
 	}
 
+	public int[] toArray() {
+		return Arrays.copyOf(array, size); // new array of length 'size'; the backing array itself is not exposed
+	}
+
 	public static final IntArrayList EMPTY = new IntArrayList(0) {
 		
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/util/LongArrayList.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/LongArrayList.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/LongArrayList.java
index e653209..f2d9556 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/util/LongArrayList.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/LongArrayList.java
@@ -18,6 +18,8 @@
 
 package org.apache.flink.runtime.util;
 
+import java.util.Arrays;
+
 /**
  * Minimal implementation of an array-backed list of longs
  */
@@ -61,6 +63,10 @@ public class LongArrayList {
 	public boolean isEmpty() {
 		return (size==0);
 	}
+
+	public long[] toArray() {
+		return Arrays.copyOf(array, size); // new array of length 'size'; the backing array itself is not exposed
+	}
 	
 	private void grow(int length) {
 		if(length > array.length) {

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/util/NonClosingStreamDecorator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/NonClosingStreamDecorator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/NonClosingStreamDecorator.java
new file mode 100644
index 0000000..ba7bc79
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/NonClosingStreamDecorator.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.util;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+/**
+ * Decorator for input streams that forwards every call to the wrapped stream except {@link InputStream#close()}, which is ignored.
+ */
+public class NonClosingStreamDecorator extends InputStream {
+
+	private final InputStream delegate;
+
+	public NonClosingStreamDecorator(InputStream delegate) {
+		this.delegate = delegate;
+	}
+
+	@Override
+	public int read() throws IOException {
+		return delegate.read();
+	}
+
+	@Override
+	public int read(byte[] b) throws IOException {
+		return delegate.read(b);
+	}
+
+	@Override
+	public int read(byte[] b, int off, int len) throws IOException {
+		return delegate.read(b, off, len);
+	}
+
+	@Override
+	public long skip(long n) throws IOException {
+		return delegate.skip(n);
+	}
+
+	@Override
+	public int available() throws IOException {
+		return delegate.available();
+	}
+
+	@Override
+	public void close() throws IOException {
+		// intentionally a no-op: closing the wrapped stream is left to whoever owns it
+	}
+
+	@Override
+	public void mark(int readlimit) {
+		delegate.mark(readlimit);
+	}
+
+	@Override
+	public void reset() throws IOException {
+		delegate.reset();
+	}
+
+	@Override
+	public boolean markSupported() {
+		return delegate.markSupported();
+	}
+}
\ No newline at end of file


[8/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State

Posted by al...@apache.org.
[FLINK-4844] Partitionable Raw Keyed/Operator State


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/cab9cd44
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/cab9cd44
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/cab9cd44

Branch: refs/heads/master
Commit: cab9cd44eca83ef8cbcd2a2d070d8c79cb037977
Parents: 428419d
Author: Stefan Richter <s....@data-artisans.com>
Authored: Tue Oct 4 10:59:38 2016 +0200
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Thu Oct 20 16:14:21 2016 +0200

----------------------------------------------------------------------
 .../state/RocksDBKeyedStateBackend.java         |   5 +-
 .../streaming/state/RocksDBStateBackend.java    |   3 +-
 .../state/RocksDBAsyncSnapshotTest.java         |  10 +-
 .../flink/api/common/state/KeyedStateStore.java | 159 +++++++
 .../api/common/state/OperatorStateStore.java    |   9 +-
 .../core/fs/local/LocalDataInputStream.java     |  17 +-
 .../org/apache/flink/util/CollectionUtil.java   |  37 ++
 .../java/org/apache/flink/util/FutureUtil.java  |  42 ++
 .../flink/cep/operator/CEPOperatorTest.java     |  12 +-
 .../checkpoint/CheckpointCoordinator.java       | 212 +--------
 .../runtime/checkpoint/CheckpointMetaData.java  |  42 ++
 .../runtime/checkpoint/PendingCheckpoint.java   |  96 ++--
 .../RoundRobinOperatorStateRepartitioner.java   |   2 +-
 .../checkpoint/StateAssignmentOperation.java    | 329 +++++++++++++
 .../flink/runtime/checkpoint/SubtaskState.java  | 181 ++++++-
 .../flink/runtime/checkpoint/TaskState.java     |  90 +---
 .../savepoint/SavepointV1Serializer.java        | 161 ++++---
 .../deployment/TaskDeploymentDescriptor.java    |  36 +-
 .../flink/runtime/execution/Environment.java    |   6 +-
 .../flink/runtime/executiongraph/Execution.java |  43 +-
 .../runtime/executiongraph/ExecutionVertex.java |  24 +-
 .../runtime/fs/hdfs/HadoopDataInputStream.java  |   9 +-
 .../runtime/jobgraph/tasks/StatefulTask.java    |  18 +-
 .../checkpoint/AcknowledgeCheckpoint.java       |  20 +-
 .../flink/runtime/query/KvStateMessage.java     |   4 +-
 .../state/AbstractKeyedStateBackend.java        |   2 +-
 .../runtime/state/AbstractStateBackend.java     |   3 +-
 .../flink/runtime/state/BoundedInputStream.java | 112 +++++
 .../flink/runtime/state/ChainedStateHandle.java |   4 +
 .../runtime/state/CheckpointStateHandles.java   | 103 ----
 .../flink/runtime/state/ClosableRegistry.java   |  48 +-
 .../runtime/state/DefaultKeyedStateStore.java   |  89 ++++
 .../state/DefaultOperatorStateBackend.java      |  57 ++-
 .../state/FunctionInitializationContext.java    |  37 ++
 .../runtime/state/FunctionSnapshotContext.java  |  30 ++
 .../flink/runtime/state/KeyGroupRange.java      |  21 +-
 .../runtime/state/KeyGroupRangeOffsets.java     |   6 +-
 .../KeyGroupStatePartitionStreamProvider.java   |  51 ++
 .../flink/runtime/state/KeyGroupsList.java      |  43 ++
 .../flink/runtime/state/KeyedStateBackend.java  |   4 +-
 .../state/KeyedStateCheckpointOutputStream.java | 108 +++++
 .../state/ManagedInitializationContext.java     |  53 +++
 .../runtime/state/ManagedSnapshotContext.java   |  41 ++
 .../state/NonClosingCheckpointOutputStream.java |  80 ++++
 .../runtime/state/OperatorStateBackend.java     |   4 +-
 .../OperatorStateCheckpointOutputStream.java    |  78 +++
 .../runtime/state/OperatorStateHandle.java      |   4 +-
 ...artitionableCheckpointStateOutputStream.java |  96 ----
 .../flink/runtime/state/SnapshotProvider.java   |  45 --
 .../flink/runtime/state/Snapshotable.java       |  45 ++
 .../state/StateInitializationContext.java       |  52 ++
 .../state/StateInitializationContextImpl.java   | 270 +++++++++++
 .../state/StatePartitionStreamProvider.java     |  62 +++
 .../runtime/state/StateSnapshotContext.java     |  40 ++
 .../StateSnapshotContextSynchronousImpl.java    | 129 +++++
 .../flink/runtime/state/TaskStateHandles.java   | 172 +++++++
 .../runtime/state/UserFacingListState.java      |  57 +++
 .../state/filesystem/FsStateBackend.java        |   4 +-
 .../state/heap/HeapKeyedStateBackend.java       |  13 +-
 .../memory/MemCheckpointStreamFactory.java      |   4 +
 .../state/memory/MemoryStateBackend.java        |   4 +-
 .../ActorGatewayCheckpointResponder.java        |   4 +-
 .../taskmanager/CheckpointResponder.java        |   6 +-
 .../runtime/taskmanager/RuntimeEnvironment.java |   4 +-
 .../apache/flink/runtime/taskmanager/Task.java  |  50 +-
 .../apache/flink/runtime/util/IntArrayList.java |   5 +
 .../flink/runtime/util/LongArrayList.java       |   6 +
 .../runtime/util/NonClosingStreamDecorator.java |  79 ++++
 .../checkpoint/CheckpointCoordinatorTest.java   | 262 ++++++-----
 .../checkpoint/CheckpointStateRestoreTest.java  |  37 +-
 .../CompletedCheckpointStoreTest.java           |   2 +-
 .../savepoint/SavepointV1SerializerTest.java    |  26 +-
 .../checkpoint/savepoint/SavepointV1Test.java   | 100 +++-
 .../stats/SimpleCheckpointStatsTrackerTest.java |   3 +-
 .../jobmanager/JobManagerHARecoveryTest.java    |  17 +-
 .../messages/CheckpointMessagesTest.java        |  13 +-
 .../operators/testutils/DummyEnvironment.java   |   5 +-
 .../operators/testutils/MockEnvironment.java    |   5 +-
 .../runtime/state/KeyGroupRangeOffsetTest.java  |   4 +-
 .../flink/runtime/state/KeyGroupRangeTest.java  |   4 +-
 .../KeyedStateCheckpointOutputStreamTest.java   | 165 +++++++
 ...OperatorStateOutputCheckpointStreamTest.java | 102 ++++
 .../runtime/state/StateBackendTestBase.java     |   6 +-
 .../state/TestMemoryCheckpointOutputStream.java |  49 ++
 .../runtime/taskmanager/TaskAsyncCallTest.java  |   5 +-
 .../fs/bucketing/BucketingSinkTest.java         |   2 +-
 .../kafka/FlinkKafkaConsumerBase.java           |  58 +--
 .../kafka/FlinkKafkaProducerBase.java           |  10 +-
 .../kafka/AtLeastOnceProducerTest.java          |   8 +-
 .../kafka/FlinkKafkaConsumerBaseTest.java       | 127 +++--
 .../streaming/api/checkpoint/Checkpointed.java  |  12 +-
 .../api/checkpoint/CheckpointedFunction.java    |  44 +-
 .../api/checkpoint/CheckpointedRestoring.java   |  41 ++
 .../api/operators/AbstractStreamOperator.java   | 174 +++++--
 .../operators/AbstractUdfStreamOperator.java    | 105 +++--
 .../api/operators/OperatorSnapshotResult.java   |  81 ++++
 .../streaming/api/operators/StreamOperator.java |   8 +-
 .../api/operators/StreamingRuntimeContext.java  |  27 +-
 .../api/operators/UserFacingListState.java      |  57 ---
 .../runtime/tasks/OperatorStateHandles.java     | 109 +++++
 .../streaming/runtime/tasks/StreamTask.java     | 470 ++++++++++---------
 .../AbstractUdfStreamOperatorTest.java          | 219 +++++++++
 .../StateInitializationContextImplTest.java     | 260 ++++++++++
 ...StateSnapshotContextSynchronousImplTest.java |  61 +++
 .../StreamOperatorSnapshotRestoreTest.java      | 214 +++++++++
 .../operators/StreamingRuntimeContextTest.java  |  82 ++--
 .../streaming/runtime/io/BarrierBufferTest.java |   6 +-
 .../runtime/io/BarrierTrackerTest.java          |   6 +-
 .../operators/GenericWriteAheadSinkTest.java    |   6 +-
 .../operators/WriteAheadSinkTestBase.java       |  16 +-
 ...AlignedProcessingTimeWindowOperatorTest.java |   4 +-
 ...AlignedProcessingTimeWindowOperatorTest.java |   4 +-
 .../operators/windowing/WindowOperatorTest.java |  16 +-
 .../tasks/InterruptSensitiveRestoreTest.java    |  18 +-
 .../runtime/tasks/OneInputStreamTaskTest.java   |  66 +--
 .../runtime/tasks/StreamMockEnvironment.java    |   6 +-
 .../runtime/tasks/TwoInputStreamTaskTest.java   |  21 +-
 .../KeyedOneInputStreamOperatorTestHarness.java |  24 +-
 .../util/OneInputStreamOperatorTestHarness.java |  74 ++-
 .../util/TwoInputStreamOperatorTestHarness.java |  19 +
 .../streaming/util/WindowingTestHarness.java    |   2 +-
 .../test/checkpointing/RescalingITCase.java     | 149 ++++--
 .../test/checkpointing/SavepointITCase.java     |   5 +-
 .../streaming/runtime/StateBackendITCase.java   |   4 +-
 .../flink/yarn/cli/FlinkYarnSessionCli.java     |   6 +-
 125 files changed, 5337 insertions(+), 1781 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 7ab35c4..f332d1e 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -65,6 +65,7 @@ import javax.annotation.concurrent.GuardedBy;
 import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
@@ -185,7 +186,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> restoreState
+			Collection<KeyGroupsStateHandle> restoreState
 	) throws Exception {
 
 		this(jobId,
@@ -603,7 +604,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		 * @throws ClassNotFoundException
 		 * @throws RocksDBException
 		 */
-		public void doRestore(List<KeyGroupsStateHandle> keyGroupsStateHandles)
+		public void doRestore(Collection<KeyGroupsStateHandle> keyGroupsStateHandles)
 				throws IOException, ClassNotFoundException, RocksDBException {
 
 			for (KeyGroupsStateHandle keyGroupsStateHandle : keyGroupsStateHandles) {

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index a0c980b..82e7899 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -40,6 +40,7 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.net.URI;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.Random;
 import java.util.UUID;
@@ -258,7 +259,7 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			List<KeyGroupsStateHandle> restoredState,
+			Collection<KeyGroupsStateHandle> restoredState,
 			TaskKvStateRegistry kvStateRegistry) throws Exception {
 
 		lazyInitializeForJob(env, operatorIdentifier);

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
index 8f58075..4d1ab50 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
@@ -28,9 +28,9 @@ import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
@@ -70,7 +70,7 @@ import java.util.concurrent.CancellationException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
 
-import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 
 /**
  * Tests for asynchronous RocksDB Key/Value state checkpoints.
@@ -136,7 +136,7 @@ public class RocksDBAsyncSnapshotTest {
 			@Override
 			public void acknowledgeCheckpoint(
 					CheckpointMetaData checkpointMetaData,
-					CheckpointStateHandles checkpointStateHandles) {
+					SubtaskState checkpointStateHandles) {
 
 				super.acknowledgeCheckpoint(checkpointMetaData);
 
@@ -148,8 +148,8 @@ public class RocksDBAsyncSnapshotTest {
 					e.printStackTrace();
 				}
 
-				// should be only one k/v state
-				assertEquals(1, checkpointStateHandles.getKeyGroupsStateHandle().size());
+				// should be one k/v state
+				assertNotNull(checkpointStateHandles.getManagedKeyedState());
 
 				// we now know that the checkpoint went through
 				ensureCheckpointLatch.trigger();

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
new file mode 100644
index 0000000..89c1240
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * This interface contains methods for registering keyed state with a managed store.
+ */
+@PublicEvolving
+public interface KeyedStateStore {
+
+	/**
+	 * Gets a handle to the system's key/value state. The key/value state is only accessible
+	 * if the function is executed on a KeyedStream. On each access, the state exposes the value
+	 * for the key of the element currently processed by the function.
+	 * Each function may have multiple partitioned states, addressed with different names.
+	 *
+	 * <p>Because the scope of each value is the key of the currently processed element,
+	 * and the elements are distributed by the Flink runtime, the system can transparently
+	 * scale out and redistribute the state and KeyedStream.
+	 *
+	 * <p>The following code example shows how to implement a continuous counter that counts
+	 * how many times elements of a certain key occur, and emits an updated count for that
+	 * element on each occurrence.
+	 *
+	 * <pre>{@code
+	 * DataStream<MyType> stream = ...;
+	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");
+	 *
+	 * keyedStream.map(new RichMapFunction<MyType, Tuple2<MyType, Long>>() {
+	 *
+	 *     private ValueState<Long> state;
+	 *
+	 *     public void open(Configuration cfg) {
+	 *         state = getRuntimeContext().getState(
+	 *                 new ValueStateDescriptor<Long>("count", LongSerializer.INSTANCE, 0L));
+	 *     }
+	 *
+	 *     public Tuple2<MyType, Long> map(MyType value) {
+	 *         long count = state.value() + 1;
+	 *         state.update(count);
+	 *         return new Tuple2<>(value, count);
+	 *     }
+	 * });
+	 * }</pre>
+	 *
+	 * @param stateProperties The descriptor defining the properties of the state.
+	 *
+	 * @param <T> The type of value stored in the state.
+	 *
+	 * @return The partitioned state object.
+	 *
+	 * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+	 *                                       function (function is not part of a KeyedStream).
+	 */
+	@PublicEvolving
+	<T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties);
+
+	/**
+	 * Gets a handle to the system's key/value list state. This state is similar to the state
+	 * accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
+	 * holds lists. One can add elements to the list, or retrieve the list as a whole.
+	 *
+	 * <p>This state is only accessible if the function is executed on a KeyedStream.
+	 *
+	 * <pre>{@code
+	 * DataStream<MyType> stream = ...;
+	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");
+	 *
+	 * keyedStream.flatMap(new RichFlatMapFunction<MyType, MyType>() {
+	 *
+	 *     private ListState<MyType> state;
+	 *
+	 *     public void open(Configuration cfg) {
+	 *         state = getRuntimeContext().getListState(
+	 *                 new ListStateDescriptor<>("myState", MyType.class));
+	 *     }
+	 *
+	 *     public void flatMap(MyType value, Collector<MyType> out) {
+	 *         if (value.isDivider()) {
+	 *             for (MyType t : state.get()) {
+	 *                 out.collect(t);
+	 *             }
+	 *         } else {
+	 *             state.add(value);
+	 *         }
+	 *     }
+	 * });
+	 * }</pre>
+	 *
+	 * @param stateProperties The descriptor defining the properties of the state.
+	 *
+	 * @param <T> The type of value stored in the state.
+	 *
+	 * @return The partitioned state object.
+	 *
+	 * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+	 *                                       function (function is not part of a KeyedStream).
+	 */
+	@PublicEvolving
+	<T> ListState<T> getListState(ListStateDescriptor<T> stateProperties);
+
+	/**
+	 * Gets a handle to the system's key/value reducing state. This state is similar to the state
+	 * accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
+	 * aggregates values.
+	 *
+	 * <p>This state is only accessible if the function is executed on a KeyedStream.
+	 *
+	 * <pre>{@code
+	 * DataStream<MyType> stream = ...;
+	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");
+	 *
+	 * keyedStream.map(new RichMapFunction<MyType, Tuple2<MyType, Long>>() {
+	 *
+	 *     private ReducingState<Long> sum;
+	 *
+	 *     public void open(Configuration cfg) {
+	 *         sum = getRuntimeContext().getReducingState(
+	 *                 new ReducingStateDescriptor<>("sum", (a, b) -> a + b, Long.class));
+	 *     }
+	 *
+	 *     public Tuple2<MyType, Long> map(MyType value) {
+	 *         sum.add(value.count());
+	 *         return new Tuple2<>(value, sum.get());
+	 *     }
+	 * });
+	 *
+	 * }</pre>
+	 *
+	 * @param stateProperties The descriptor defining the properties of the state.
+	 *
+	 * @param <T> The type of value stored in the state.
+	 *
+	 * @return The partitioned state object.
+	 *
+	 * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+	 *                                       function (function is not part of a KeyedStream).
+	 */
+	@PublicEvolving
+	<T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties);
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
index 03c11f6..43dbe51 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
@@ -18,16 +18,17 @@
 
 package org.apache.flink.api.common.state;
 
+import org.apache.flink.annotation.PublicEvolving;
+
 import java.io.Serializable;
 import java.util.Set;
 
 /**
- * Interface for a backend that manages operator state.
+ * This interface contains methods for registering operator state with a managed store.
  */
+@PublicEvolving
 public interface OperatorStateStore {
 
-	String DEFAULT_OPERATOR_STATE_NAME = "_default_";
-
 	/**
 	 * Creates a state descriptor of the given name that uses Java serialization to persist the
 	 * state.
@@ -39,7 +40,7 @@ public interface OperatorStateStore {
 	 * @return A list state using Java serialization to serialize state objects.
 	 * @throws Exception
 	 */
-	ListState<Serializable> getSerializableListState(String stateName) throws Exception;
+	<T extends Serializable> ListState<T> getSerializableListState(String stateName) throws Exception;
 
 	/**
 	 * Creates (or restores) a list state. Each state is registered under a unique name.

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java b/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java
index e7b2828..172da79 100644
--- a/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java
+++ b/flink-core/src/main/java/org/apache/flink/core/fs/local/LocalDataInputStream.java
@@ -18,14 +18,14 @@
 
 package org.apache.flink.core.fs.local;
 
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.fs.FSDataInputStream;
 
 import javax.annotation.Nonnull;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.channels.FileChannel;
 
 /**
  * The <code>LocalDataInputStream</code> class is a wrapper class for a data
@@ -36,6 +36,7 @@ public class LocalDataInputStream extends FSDataInputStream {
 
 	/** The file input stream used to read data from.*/
 	private final FileInputStream fis;
+	private final FileChannel fileChannel;
 
 	/**
 	 * Constructs a new <code>LocalDataInputStream</code> object from a given {@link File} object.
@@ -46,16 +47,19 @@ public class LocalDataInputStream extends FSDataInputStream {
 	 */
 	public LocalDataInputStream(File file) throws IOException {
 		this.fis = new FileInputStream(file);
+		this.fileChannel = fis.getChannel();
 	}
 
 	@Override
 	public void seek(long desired) throws IOException {
-		this.fis.getChannel().position(desired);
+		if (desired != getPos()) {
+			this.fileChannel.position(desired);
+		}
 	}
 
 	@Override
 	public long getPos() throws IOException {
-		return this.fis.getChannel().position();
+		return this.fileChannel.position();
 	}
 
 	@Override
@@ -70,6 +74,7 @@ public class LocalDataInputStream extends FSDataInputStream {
 	
 	@Override
 	public void close() throws IOException {
+		// According to javadoc, this also closes the channel
 		this.fis.close();
 	}
 	

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java b/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java
new file mode 100644
index 0000000..15d00ae
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.util;
+
+import java.util.Collection;
+import java.util.Map;
+
+public final class CollectionUtil {
+
+	private CollectionUtil() {
+		throw new AssertionError();
+	}
+
+	public static boolean isNullOrEmpty(Collection<?> collection) {
+		return collection == null || collection.isEmpty();
+	}
+
+	public static boolean isNullOrEmpty(Map<?, ?> map) {
+		return map == null || map.isEmpty();
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-core/src/main/java/org/apache/flink/util/FutureUtil.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/util/FutureUtil.java b/flink-core/src/main/java/org/apache/flink/util/FutureUtil.java
new file mode 100644
index 0000000..62d836b
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/util/FutureUtil.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.util;
+
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.RunnableFuture;
+
+public class FutureUtil {
+
+	private FutureUtil() {
+		throw new AssertionError();
+	}
+
+	public static <T> T runIfNotDoneAndGet(RunnableFuture<T> future) throws ExecutionException, InterruptedException {
+
+		if (null == future) {
+			return null;
+		}
+
+		if (!future.isDone()) {
+			future.run();
+		}
+
+		return future.get();
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
index 1fd8de8..0f49b13 100644
--- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
+++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
@@ -135,7 +135,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processElement(new StreamRecord<Event>(new Event(42, "foobar", 1.0), 2));
 
 		// simulate snapshot/restore with some elements in internal sorting queue
-		StreamStateHandle snapshot = harness.snapshot(0, 0);
+		StreamStateHandle snapshot = harness.snapshotLegacy(0, 0);
 		harness.close();
 
 		harness = new OneInputStreamOperatorTestHarness<>(
@@ -157,7 +157,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processWatermark(new Watermark(2));
 
 		// simulate snapshot/restore with empty element queue but NFA state
-		StreamStateHandle snapshot2 = harness.snapshot(1, 1);
+		StreamStateHandle snapshot2 = harness.snapshotLegacy(1, 1);
 		harness.close();
 
 		harness = new OneInputStreamOperatorTestHarness<>(
@@ -228,7 +228,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processElement(new StreamRecord<Event>(new Event(42, "foobar", 1.0), 2));
 
 		// simulate snapshot/restore with some elements in internal sorting queue
-		StreamStateHandle snapshot = harness.snapshot(0, 0);
+		StreamStateHandle snapshot = harness.snapshotLegacy(0, 0);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(
@@ -254,7 +254,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processWatermark(new Watermark(2));
 
 		// simulate snapshot/restore with empty element queue but NFA state
-		StreamStateHandle snapshot2 = harness.snapshot(1, 1);
+		StreamStateHandle snapshot2 = harness.snapshotLegacy(1, 1);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(
@@ -337,7 +337,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processElement(new StreamRecord<Event>(new Event(42, "foobar", 1.0), 2));
 
 		// simulate snapshot/restore with some elements in internal sorting queue
-		StreamStateHandle snapshot = harness.snapshot(0, 0);
+		StreamStateHandle snapshot = harness.snapshotLegacy(0, 0);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(
@@ -368,7 +368,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processWatermark(new Watermark(2));
 
 		// simulate snapshot/restore with empty element queue but NFA state
-		StreamStateHandle snapshot2 = harness.snapshot(1, 1);
+		StreamStateHandle snapshot2 = harness.snapshotLegacy(1, 1);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 00028c4..588ba84 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -36,22 +36,10 @@ import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
-import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.ArrayDeque;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
@@ -444,11 +432,11 @@ public class CheckpointCoordinator {
 							// note that checkpoint completion discards the pending checkpoint object
 							if (!checkpoint.isDiscarded()) {
 								LOG.info("Checkpoint " + checkpointID + " expired before completing.");
-	
+
 								checkpoint.abortExpired();
 								pendingCheckpoints.remove(checkpointID);
 								rememberRecentCheckpointId(checkpointID);
-	
+
 								triggerQueuedRequests();
 							}
 						}
@@ -578,7 +566,7 @@ public class CheckpointCoordinator {
 				isPendingCheckpoint = true;
 
 				LOG.info("Discarding checkpoint " + checkpointId
-					+ " because of checkpoint decline from task " + message.getTaskExecutionId());
+						+ " because of checkpoint decline from task " + message.getTaskExecutionId());
 
 				pendingCheckpoints.remove(checkpointId);
 				checkpoint.abortDeclined();
@@ -602,7 +590,7 @@ public class CheckpointCoordinator {
 			} else if (checkpoint != null) {
 				// this should not happen
 				throw new IllegalStateException(
-					"Received message for discarded but non-removed checkpoint " + checkpointId);
+						"Received message for discarded but non-removed checkpoint " + checkpointId);
 			} else {
 				// message is for an unknown checkpoint, or comes too late (checkpoint disposed)
 				if (recentPendingCheckpoints.contains(checkpointId)) {
@@ -660,7 +648,7 @@ public class CheckpointCoordinator {
 
 				if (checkpoint.acknowledgeTask(
 						message.getTaskExecutionId(),
-						message.getCheckpointStateHandles())) {
+						message.getSubtaskState())) {
 					if (checkpoint.isFullyAcknowledged()) {
 						completed = checkpoint.finalizeCheckpoint();
 
@@ -804,199 +792,15 @@ public class CheckpointCoordinator {
 
 			LOG.info("Restoring from latest valid checkpoint: {}.", latest);
 
-			for (Map.Entry<JobVertexID, TaskState> taskGroupStateEntry: latest.getTaskStates().entrySet()) {
-				TaskState taskState = taskGroupStateEntry.getValue();
-				ExecutionJobVertex executionJobVertex = tasks.get(taskGroupStateEntry.getKey());
-
-				if (executionJobVertex != null) {
-					// check that the number of key groups have not changed
-					if (taskState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
-						throw new IllegalStateException("The maximum parallelism (" +
-							taskState.getMaxParallelism() + ") with which the latest " +
-							"checkpoint of the execution job vertex " + executionJobVertex +
-							" has been taken and the current maximum parallelism (" +
-							executionJobVertex.getMaxParallelism() + ") changed. This " +
-							"is currently not supported.");
-					}
-
-
-					int oldParallelism = taskState.getParallelism();
-					int newParallelism = executionJobVertex.getParallelism();
-					boolean parallelismChanged = oldParallelism != newParallelism;
-					boolean hasNonPartitionedState = taskState.hasNonPartitionedState();
-
-					if (hasNonPartitionedState && parallelismChanged) {
-						throw new IllegalStateException("Cannot restore the latest checkpoint because " +
-							"the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " +
-							"state and its parallelism changed. The operator" + executionJobVertex.getJobVertexId() +
-							" has parallelism " + newParallelism + " whereas the corresponding" +
-							"state object has a parallelism of " + oldParallelism);
-					}
-
-					List<KeyGroupRange> keyGroupPartitions = createKeyGroupPartitions(
-							executionJobVertex.getMaxParallelism(),
-							newParallelism);
-					
-					// operator chain index -> list of the stored partitionables states from all parallel instances
-					@SuppressWarnings("unchecked")
-					List<OperatorStateHandle>[] chainParallelStates =
-							new List[taskState.getChainLength()];
-
-					for (int i = 0; i < oldParallelism; ++i) {
-
-						ChainedStateHandle<OperatorStateHandle> partitionableState =
-								taskState.getPartitionableState(i);
-
-						if (partitionableState != null) {
-							for (int j = 0; j < partitionableState.getLength(); ++j) {
-								OperatorStateHandle opParalleState = partitionableState.get(j);
-								if (opParalleState != null) {
-									List<OperatorStateHandle> opParallelStates =
-											chainParallelStates[j];
-									if (opParallelStates == null) {
-										opParallelStates = new ArrayList<>();
-										chainParallelStates[j] = opParallelStates;
-									}
-									opParallelStates.add(opParalleState);
-								}
-							}
-						}
-					}
-
-					// operator chain index -> lists with collected states (one collection for each parallel subtasks)
-					@SuppressWarnings("unchecked")
-					List<Collection<OperatorStateHandle>>[] redistributedParallelStates =
-							new List[taskState.getChainLength()];
-
-					//TODO here we can employ different redistribution strategies for state, e.g. union state. For now we only offer round robin as the default.
-					OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
-
-					for (int i = 0; i < chainParallelStates.length; ++i) {
-						List<OperatorStateHandle> chainOpParallelStates = chainParallelStates[i];
-						if (chainOpParallelStates != null) {
-							//We only redistribute if the parallelism of the operator changed from previous executions
-							if (parallelismChanged) {
-								redistributedParallelStates[i] = repartitioner.repartitionState(
-										chainOpParallelStates,
-										newParallelism);
-							} else {
-								List<Collection<OperatorStateHandle>> repacking = new ArrayList<>(newParallelism);
-								for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) {
-									repacking.add(Collections.singletonList(operatorStateHandle));
-								}
-								redistributedParallelStates[i] = repacking;
-							}
-						}
-					}
-
-					int counter = 0;
-
-					for (int i = 0; i < newParallelism; ++i) {
-
-						// non-partitioned state
-						ChainedStateHandle<StreamStateHandle> state = null;
-
-						if (hasNonPartitionedState) {
-							SubtaskState subtaskState = taskState.getState(i);
-
-							if (subtaskState != null) {
-								// count the number of executions for which we set a state
-								++counter;
-								state = subtaskState.getChainedStateHandle();
-							}
-						}
-
-						// partitionable state
-						@SuppressWarnings("unchecked")
-						Collection<OperatorStateHandle>[] ia = new Collection[taskState.getChainLength()];
-						List<Collection<OperatorStateHandle>> subTaskPartitionableState = Arrays.asList(ia);
-
-						for (int j = 0; j < redistributedParallelStates.length; ++j) {
-							List<Collection<OperatorStateHandle>> redistributedParallelState =
-									redistributedParallelStates[j];
-
-							if (redistributedParallelState != null) {
-								subTaskPartitionableState.set(j, redistributedParallelState.get(i));
-							}
-						}
+			StateAssignmentOperation stateAssignmentOperation =
+					new StateAssignmentOperation(tasks, latest, allOrNothingState);
 
-						// key-partitioned state
-						KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(i);
-
-						// Again, we only repartition if the parallelism changed
-						List<KeyGroupsStateHandle> subtaskKeyGroupStates = parallelismChanged ?
-								getKeyGroupsStateHandles(taskState.getKeyGroupStates(), subtaskKeyGroupIds)
-								: Collections.singletonList(taskState.getKeyGroupState(i));
-
-						Execution currentExecutionAttempt = executionJobVertex
-							.getTaskVertices()[i]
-							.getCurrentExecutionAttempt();
-
-						CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(
-								state,
-								null/*subTaskPartionableState*/, //TODO chose right structure and put redistributed states here
-								subtaskKeyGroupStates);
-
-						currentExecutionAttempt.setInitialState(checkpointStateHandles, subTaskPartitionableState);
-					}
-
-					if (allOrNothingState && counter > 0 && counter < newParallelism) {
-						throw new IllegalStateException("The checkpoint contained state only for " +
-							"a subset of tasks for vertex " + executionJobVertex);
-					}
-				} else {
-					throw new IllegalStateException("There is no execution job vertex for the job" +
-						" vertex ID " + taskGroupStateEntry.getKey());
-				}
-			}
+			stateAssignmentOperation.assignStates();
 
 			return true;
 		}
 	}
 
-	/**
-	 * Determine the subset of {@link KeyGroupsStateHandle KeyGroupsStateHandles} with correct
-	 * key group index for the given subtask {@link KeyGroupRange}.
-	 *
-	 * <p>This is publicly visible to be used in tests.
-	 */
-	public static List<KeyGroupsStateHandle> getKeyGroupsStateHandles(
-			Collection<KeyGroupsStateHandle> allKeyGroupsHandles,
-			KeyGroupRange subtaskKeyGroupIds) {
-
-		List<KeyGroupsStateHandle> subtaskKeyGroupStates = new ArrayList<>();
-
-		for (KeyGroupsStateHandle storedKeyGroup : allKeyGroupsHandles) {
-			KeyGroupsStateHandle intersection = storedKeyGroup.getKeyGroupIntersection(subtaskKeyGroupIds);
-			if (intersection.getNumberOfKeyGroups() > 0) {
-				subtaskKeyGroupStates.add(intersection);
-			}
-		}
-		return subtaskKeyGroupStates;
-	}
-
-	/**
-	 * Groups the available set of key groups into key group partitions. A key group partition is
-	 * the set of key groups which is assigned to the same task. Each set of the returned list
-	 * constitutes a key group partition.
-	 *
-	 * <b>IMPORTANT</b>: The assignment of key groups to partitions has to be in sync with the
-	 * KeyGroupStreamPartitioner.
-	 *
-	 * @param numberKeyGroups Number of available key groups (indexed from 0 to numberKeyGroups - 1)
-	 * @param parallelism Parallelism to generate the key group partitioning for
-	 * @return List of key group partitions
-	 */
-	public static List<KeyGroupRange> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
-		Preconditions.checkArgument(numberKeyGroups >= parallelism);
-		List<KeyGroupRange> result = new ArrayList<>(parallelism);
-
-		for (int i = 0; i < parallelism; ++i) {
-			result.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(numberKeyGroups, parallelism, i));
-		}
-		return result;
-	}
-
 	// --------------------------------------------------------------------------------------------
 	//  Accessors
 	// --------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java
index 6f117f2..2627b22 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointMetaData.java
@@ -59,6 +59,15 @@ public class CheckpointMetaData implements Serializable {
 				asynchronousDurationMillis);
 	}
 
+	/** Creates meta data from a checkpoint id, a timestamp and an existing metrics object, which is passed through {@code Preconditions.checkNotNull}. */
+	public CheckpointMetaData(long checkpointId, long timestamp, CheckpointMetrics metrics) {
+		// run the null check before any field is assigned
+		this.metrics = Preconditions.checkNotNull(metrics);
+
+		this.checkpointId = checkpointId;
+		this.timestamp = timestamp;
+	}
+
 	public CheckpointMetrics getMetrics() {
 		return metrics;
 	}
@@ -110,4 +119,37 @@ public class CheckpointMetaData implements Serializable {
 	public long getAsyncDurationMillis() {
 		return metrics.getAsyncDurationMillis();
 	}
+
+	/**
+	 * Two instances are equal when they are of the same class and agree on checkpoint id,
+	 * timestamp and metrics.
+	 */
+	@Override
+	public boolean equals(Object o) {
+		if (this == o) {
+			return true;
+		} else if (o == null || getClass() != o.getClass()) {
+			return false;
+		}
+		final CheckpointMetaData other = (CheckpointMetaData) o;
+		return checkpointId == other.checkpointId && timestamp == other.timestamp
+				&& metrics.equals(other.metrics);
+	}
+
+	/** Combines the hashes of checkpoint id, timestamp and metrics (in that order) with a multiplier of 31. */
+	@Override
+	public int hashCode() {
+		final int idHash = (int) (checkpointId ^ (checkpointId >>> 32));
+		final int timestampHash = (int) (timestamp ^ (timestamp >>> 32));
+		return 31 * (31 * idHash + timestampHash) + metrics.hashCode();
+	}
+
+	/** Returns a textual representation listing checkpoint id, timestamp and metrics. */
+	@Override
+	public String toString() {
+		final StringBuilder text = new StringBuilder("CheckpointMetaData{");
+		text.append("checkpointId=").append(checkpointId);
+		text.append(", timestamp=").append(timestamp);
+		return text.append(", metrics=").append(metrics).append('}').toString();
+	}
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index 6f50392..92dca21 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -28,8 +28,6 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -37,7 +35,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
@@ -234,80 +231,61 @@ public class PendingCheckpoint {
 	
 	public boolean acknowledgeTask(
 			ExecutionAttemptID attemptID,
-			CheckpointStateHandles checkpointStateHandles) {
+			SubtaskState checkpointedSubtaskState) {
 
 		synchronized (lock) {
+
 			if (discarded) {
 				return false;
 			}
 
-			ExecutionVertex vertex = notYetAcknowledgedTasks.remove(attemptID);
-
-			if (vertex != null) {
-				if (checkpointStateHandles != null) {
-					List<KeyGroupsStateHandle> keyGroupsState = checkpointStateHandles.getKeyGroupsStateHandle();
-					ChainedStateHandle<StreamStateHandle> nonPartitionedState =
-							checkpointStateHandles.getNonPartitionedStateHandles();
-					ChainedStateHandle<OperatorStateHandle> partitioneableState =
-							checkpointStateHandles.getPartitioneableStateHandles();
-
-					if (nonPartitionedState != null || partitioneableState != null || keyGroupsState != null) {
-
-						JobVertexID jobVertexID = vertex.getJobvertexId();
+			final ExecutionVertex vertex = notYetAcknowledgedTasks.remove(attemptID);
 
-						int subtaskIndex = vertex.getParallelSubtaskIndex();
+			if (vertex == null) {
+				return false;
+			}
 
-						TaskState taskState;
+			if (null != checkpointedSubtaskState && checkpointedSubtaskState.hasState()) {
 
-						if (taskStates.containsKey(jobVertexID)) {
-							taskState = taskStates.get(jobVertexID);
-						} else {
-							//TODO this should go away when we remove chained state, assigning state to operators directly instead
-							int chainLength;
-							if (nonPartitionedState != null) {
-								chainLength = nonPartitionedState.getLength();
-							} else if (partitioneableState != null) {
-								chainLength = partitioneableState.getLength();
-							} else {
-								chainLength = 1;
-							}
+				JobVertexID jobVertexID = vertex.getJobvertexId();
 
-							taskState = new TaskState(
-								jobVertexID,
-								vertex.getTotalNumberOfParallelSubtasks(),
-								vertex.getMaxParallelism(),
-								chainLength);
+				int subtaskIndex = vertex.getParallelSubtaskIndex();
 
-							taskStates.put(jobVertexID, taskState);
-						}
+				TaskState taskState = taskStates.get(jobVertexID);
 
-						long duration = System.currentTimeMillis() - checkpointTimestamp;
-
-						if (nonPartitionedState != null) {
-							taskState.putState(
-									subtaskIndex,
-									new SubtaskState(nonPartitionedState, duration));
-						}
+				if (null == taskState) {
+					ChainedStateHandle<StreamStateHandle> nonPartitionedState =
+							checkpointedSubtaskState.getLegacyOperatorState();
+					ChainedStateHandle<OperatorStateHandle> partitioneableState =
+							checkpointedSubtaskState.getManagedOperatorState();
+					//TODO this should go away when we remove chained state, assigning state to operators directly instead
+					int chainLength;
+					if (nonPartitionedState != null) {
+						chainLength = nonPartitionedState.getLength();
+					} else if (partitioneableState != null) {
+						chainLength = partitioneableState.getLength();
+					} else {
+						chainLength = 1;
+					}
 
-						if(partitioneableState != null && !partitioneableState.isEmpty()) {
-							taskState.putPartitionableState(subtaskIndex, partitioneableState);
-						}
+					taskState = new TaskState(
+							jobVertexID,
+							vertex.getTotalNumberOfParallelSubtasks(),
+							vertex.getMaxParallelism(),
+							chainLength);
 
-						// currently a checkpoint can only contain keyed state
-						// for the head operator
-						if (keyGroupsState != null && !keyGroupsState.isEmpty()) {
-							KeyGroupsStateHandle keyGroupsStateHandle = keyGroupsState.get(0);
-							taskState.putKeyedState(subtaskIndex, keyGroupsStateHandle);
-						}
-					}
+					taskStates.put(jobVertexID, taskState);
 				}
 
-				++numAcknowledgedTasks;
+				long duration = System.currentTimeMillis() - checkpointTimestamp;
+				checkpointedSubtaskState.setDuration(duration);
 
-				return true;
-			} else {
-				return false;
+				taskState.putState(subtaskIndex, checkpointedSubtaskState);
 			}
+
+			++numAcknowledgedTasks;
+
+			return true;
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
index 09a35f6..16a7e27 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
@@ -176,7 +176,7 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
 					Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(parallelOpIdx);
 					OperatorStateHandle psh = mergeMap.get(handleWithOffsets.f0);
 					if (psh == null) {
-						psh = new OperatorStateHandle(handleWithOffsets.f0, new HashMap<String, long[]>());
+						psh = new OperatorStateHandle(new HashMap<String, long[]>(), handleWithOffsets.f0);
 						mergeMap.put(handleWithOffsets.f0, psh);
 					}
 					psh.getStateNameToPartitionOffsets().put(e.getKey(), offs);

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
new file mode 100644
index 0000000..8e2b0bf
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -0,0 +1,329 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This class encapsulates the operation of assigning restored state when restoring from a checkpoint.
+ */
+public class StateAssignmentOperation {
+
+	private final Map<JobVertexID, ExecutionJobVertex> tasks;
+	private final CompletedCheckpoint latest;
+	private final boolean allOrNothingState;
+
+	/** Creates the operation that assigns the task states of {@code latest} to the job vertices in {@code tasks}. */
+	public StateAssignmentOperation(
+			Map<JobVertexID, ExecutionJobVertex> tasks,
+			CompletedCheckpoint latest,
+			boolean allOrNothingState) {
+		this.tasks = tasks;
+		this.latest = latest;
+		this.allOrNothingState = allOrNothingState;
+	}
+
+	/**
+	 * Assigns the state stored in the latest completed checkpoint to the current execution
+	 * attempt of every subtask of each job vertex for which the checkpoint holds a task state.
+	 *
+	 * @return always {@code true}; every failure is signalled by an exception
+	 * @throws IllegalStateException if the checkpointed task states cannot be mapped onto the current job vertices
+	 */
+	public boolean assignStates() throws Exception {
+		for (Map.Entry<JobVertexID, TaskState> taskGroupStateEntry : latest.getTaskStates().entrySet()) {
+			TaskState taskState = taskGroupStateEntry.getValue();
+			ExecutionJobVertex executionJobVertex = tasks.get(taskGroupStateEntry.getKey());
+
+			if (executionJobVertex != null) {
+				// check that the number of key groups have not changed
+				if (taskState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
+					throw new IllegalStateException("The maximum parallelism (" +
+							taskState.getMaxParallelism() + ") with which the latest " +
+							"checkpoint of the execution job vertex " + executionJobVertex +
+							" has been taken and the current maximum parallelism (" +
+							executionJobVertex.getMaxParallelism() + ") changed. This " +
+							"is currently not supported.");
+				}
+
+				final int oldParallelism = taskState.getParallelism();
+				final int newParallelism = executionJobVertex.getParallelism();
+				final boolean parallelismChanged = oldParallelism != newParallelism;
+				final boolean hasNonPartitionedState = taskState.hasNonPartitionedState();
+
+				if (hasNonPartitionedState && parallelismChanged) {
+					throw new IllegalStateException("Cannot restore the latest checkpoint because " +
+							"the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " +
+							"state and its parallelism changed. The operator " + executionJobVertex.getJobVertexId() +
+							" has parallelism " + newParallelism + " whereas the corresponding " +
+							"state object has a parallelism of " + oldParallelism);
+				}
+
+				List<KeyGroupRange> keyGroupPartitions = createKeyGroupPartitions(
+						executionJobVertex.getMaxParallelism(),
+						newParallelism);
+
+				final int chainLength = taskState.getChainLength();
+
+				// operator chain idx -> list of the stored op states from all parallel instances for this chain idx
+				@SuppressWarnings("unchecked")
+				List<OperatorStateHandle>[] parallelOpStatesBackend = new List[chainLength];
+				@SuppressWarnings("unchecked")
+				List<OperatorStateHandle>[] parallelOpStatesStream = new List[chainLength];
+
+				List<KeyGroupsStateHandle> parallelKeyedStatesBackend = new ArrayList<>(oldParallelism);
+				List<KeyGroupsStateHandle> parallelKeyedStateStream = new ArrayList<>(oldParallelism);
+
+				int counter = 0;
+				for (int p = 0; p < oldParallelism; ++p) {
+					SubtaskState subtaskState = taskState.getState(p);
+					if (null != subtaskState) {
+						++counter;
+						collectParallelStatesByChainOperator(
+								parallelOpStatesBackend, subtaskState.getManagedOperatorState());
+
+						collectParallelStatesByChainOperator(
+								parallelOpStatesStream, subtaskState.getRawOperatorState());
+
+						KeyGroupsStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
+						if (null != keyedStateBackend) {
+							parallelKeyedStatesBackend.add(keyedStateBackend);
+						}
+
+						KeyGroupsStateHandle keyedStateStream = subtaskState.getRawKeyedState();
+						if (null != keyedStateStream) {
+							parallelKeyedStateStream.add(keyedStateStream);
+						}
+					}
+				}
+
+				if (allOrNothingState && counter > 0 && counter < oldParallelism) {
+					throw new IllegalStateException("The checkpoint contained state only for " +
+							"a subset of tasks for vertex " + executionJobVertex);
+				}
+
+				// operator chain index -> lists with collected states (one collection for each parallel subtasks)
+				@SuppressWarnings("unchecked")
+				List<Collection<OperatorStateHandle>>[] partitionedParallelStatesBackend = new List[chainLength];
+
+				@SuppressWarnings("unchecked")
+				List<Collection<OperatorStateHandle>>[] partitionedParallelStatesStream = new List[chainLength];
+
+				//TODO here we can employ different redistribution strategies for state, e.g. union state.
+				// For now we only offer round robin as the default.
+				OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
+
+				for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) {
+					List<OperatorStateHandle> chainOpParallelStatesBackend = parallelOpStatesBackend[chainIdx];
+					List<OperatorStateHandle> chainOpParallelStatesStream = parallelOpStatesStream[chainIdx];
+
+					partitionedParallelStatesBackend[chainIdx] = applyRepartitioner(
+							opStateRepartitioner,
+							chainOpParallelStatesBackend,
+							oldParallelism,
+							newParallelism);
+
+					partitionedParallelStatesStream[chainIdx] = applyRepartitioner(
+							opStateRepartitioner,
+							chainOpParallelStatesStream,
+							oldParallelism,
+							newParallelism);
+				}
+
+				for (int subTaskIdx = 0; subTaskIdx < newParallelism; ++subTaskIdx) {
+					// non-partitioned (legacy) state; stays null when the checkpoint has no state for this subtask
+					ChainedStateHandle<StreamStateHandle> nonPartitionableState = null;
+
+					SubtaskState legacySubtaskState = hasNonPartitionedState ? taskState.getState(subTaskIdx) : null;
+					if (legacySubtaskState != null) {
+						nonPartitionableState = legacySubtaskState.getLegacyOperatorState();
+					}
+
+					// partitionable state
+					@SuppressWarnings("unchecked")
+					Collection<OperatorStateHandle>[] iab = new Collection[chainLength];
+					@SuppressWarnings("unchecked")
+					Collection<OperatorStateHandle>[] ias = new Collection[chainLength];
+					List<Collection<OperatorStateHandle>> operatorStateFromBackend = Arrays.asList(iab);
+					List<Collection<OperatorStateHandle>> operatorStateFromStream = Arrays.asList(ias);
+
+					for (int chainIdx = 0; chainIdx < partitionedParallelStatesBackend.length; ++chainIdx) {
+						List<Collection<OperatorStateHandle>> redistributedOpStateBackend =
+								partitionedParallelStatesBackend[chainIdx];
+
+						List<Collection<OperatorStateHandle>> redistributedOpStateStream =
+								partitionedParallelStatesStream[chainIdx];
+
+						if (redistributedOpStateBackend != null) {
+							operatorStateFromBackend.set(chainIdx, redistributedOpStateBackend.get(subTaskIdx));
+						}
+
+						if (redistributedOpStateStream != null) {
+							operatorStateFromStream.set(chainIdx, redistributedOpStateStream.get(subTaskIdx));
+						}
+					}
+
+					Execution currentExecutionAttempt = executionJobVertex
+							.getTaskVertices()[subTaskIdx]
+							.getCurrentExecutionAttempt();
+
+					List<KeyGroupsStateHandle> newKeyedStatesBackend;
+					List<KeyGroupsStateHandle> newKeyedStateStream;
+					if (parallelismChanged) {
+						KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(subTaskIdx);
+						newKeyedStatesBackend = getKeyGroupsStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
+						newKeyedStateStream = getKeyGroupsStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
+					} else {
+						SubtaskState subtaskState = taskState.getState(subTaskIdx); // may be null, as in the collection loop above
+						KeyGroupsStateHandle oldKeyedStatesBackend = subtaskState != null ? subtaskState.getManagedKeyedState() : null;
+						KeyGroupsStateHandle oldKeyedStatesStream = subtaskState != null ? subtaskState.getRawKeyedState() : null;
+						newKeyedStatesBackend = oldKeyedStatesBackend != null ? Collections.singletonList(oldKeyedStatesBackend) : null;
+						newKeyedStateStream = oldKeyedStatesStream != null ? Collections.singletonList(oldKeyedStatesStream) : null;
+					}
+
+					TaskStateHandles taskStateHandles = new TaskStateHandles(
+							nonPartitionableState,
+							operatorStateFromBackend,
+							operatorStateFromStream,
+							newKeyedStatesBackend,
+							newKeyedStateStream);
+
+					currentExecutionAttempt.setInitialState(taskStateHandles);
+				}
+
+			} else {
+				throw new IllegalStateException("There is no execution job vertex for the job" +
+						" vertex ID " + taskGroupStateEntry.getKey());
+			}
+		}
+
+		return true;
+	}
+
+	/**
+	 * Selects the parts of the given {@link KeyGroupsStateHandle KeyGroupsStateHandles} that
+	 * overlap with the {@link KeyGroupRange} of one subtask.
+	 *
+	 * <p>Every handle is intersected with the subtask's range; an intersection is only added to
+	 * the result if it contains at least one key group.
+	 *
+	 * <p>This is publicly visible to be used in tests.
+	 *
+	 * @param allKeyGroupsHandles handles to intersect with the subtask's key group range
+	 * @param subtaskKeyGroupIds  key group range of the subtask
+	 * @return the non-empty intersections, in the iteration order of {@code allKeyGroupsHandles}
+	 */
+	public static List<KeyGroupsStateHandle> getKeyGroupsStateHandles(
+			Collection<KeyGroupsStateHandle> allKeyGroupsHandles,
+			KeyGroupRange subtaskKeyGroupIds) {
+
+		List<KeyGroupsStateHandle> overlappingHandles = new ArrayList<>();
+		for (KeyGroupsStateHandle candidate : allKeyGroupsHandles) {
+			KeyGroupsStateHandle overlap = candidate.getKeyGroupIntersection(subtaskKeyGroupIds);
+			if (overlap.getNumberOfKeyGroups() <= 0) {
+				// nothing of this handle falls into the subtask's range
+				continue;
+			}
+			overlappingHandles.add(overlap);
+		}
+		return overlappingHandles;
+	}
+
+	/**
+	 * Splits the available key groups into one {@link KeyGroupRange} per parallel subtask. The
+	 * entry at index {@code i} of the returned list is the range that
+	 * {@code KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex} yields for operator
+	 * index {@code i}.
+	 * <p>
+	 * <b>IMPORTANT</b>: The assignment of key groups to partitions has to be in sync with the
+	 * KeyGroupStreamPartitioner.
+	 *
+	 * @param numberKeyGroups Number of available key groups (indexed from 0 to numberKeyGroups - 1)
+	 * @param parallelism     Parallelism to generate the key group partitioning for
+	 * @return List of key group partitions, one entry per operator index
+	 */
+	public static List<KeyGroupRange> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
+		// the number of key groups must not be smaller than the parallelism
+		Preconditions.checkArgument(numberKeyGroups >= parallelism);
+
+		List<KeyGroupRange> partitions = new ArrayList<>(parallelism);
+		for (int operatorIndex = 0; operatorIndex < parallelism; ++operatorIndex) {
+			KeyGroupRange range = KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
+					numberKeyGroups, parallelism, operatorIndex);
+			partitions.add(range);
+		}
+		return partitions;
+	}
+
+	/**
+	 * Distributes the entries of a chained operator state handle onto per-chain-operator lists.
+	 *
+	 * <p>For every chain index, a non-null handle found in {@code chainOpState} is appended to the
+	 * list stored at the same index of {@code chainParallelOpStates}; that list is created on
+	 * first use. Nothing happens if {@code chainOpState} is null.
+	 *
+	 * @param chainParallelOpStates array = chain ops, array[idx] = parallel states for this chain op.
+	 * @param chainOpState          operator state handles indexed by chain position, may be null
+	 */
+	private static void collectParallelStatesByChainOperator(
+			List<OperatorStateHandle>[] chainParallelOpStates, ChainedStateHandle<OperatorStateHandle> chainOpState) {
+
+		if (chainOpState == null) {
+			return;
+		}
+
+		for (int chainIdx = 0; chainIdx < chainParallelOpStates.length; ++chainIdx) {
+			OperatorStateHandle handle = chainOpState.get(chainIdx);
+			if (handle == null) {
+				continue;
+			}
+
+			// lazily create the list for this chain operator
+			if (chainParallelOpStates[chainIdx] == null) {
+				chainParallelOpStates[chainIdx] = new ArrayList<>();
+			}
+			chainParallelOpStates[chainIdx].add(handle);
+		}
+	}
+
+	/**
+	 * Computes the operator state handles per subtask for one operator of the chain.
+	 *
+	 * <p>If the parallelism changed, the work is delegated to the given repartitioner. Otherwise
+	 * each existing handle is wrapped, in order, into its own singleton collection.
+	 *
+	 * @param opStateRepartitioner  repartitioner that is used when the parallelism differs
+	 * @param chainOpParallelStates the operator's state handles, may be null
+	 * @param oldParallelism        parallelism the state was written with
+	 * @param newParallelism        parallelism the state is assigned for
+	 * @return the handles per subtask, or {@code null} if {@code chainOpParallelStates} is null
+	 */
+	private static List<Collection<OperatorStateHandle>> applyRepartitioner(
+			OperatorStateRepartitioner opStateRepartitioner,
+			List<OperatorStateHandle> chainOpParallelStates,
+			int oldParallelism,
+			int newParallelism) {
+
+		if (chainOpParallelStates == null) {
+			return null;
+		}
+
+		// We only redistribute if the parallelism of the operator changed from previous executions
+		if (newParallelism != oldParallelism) {
+			return opStateRepartitioner.repartitionState(chainOpParallelStates, newParallelism);
+		}
+
+		// unchanged parallelism: keep every handle as is, wrapped into a singleton collection
+		List<Collection<OperatorStateHandle>> unchanged = new ArrayList<>(newParallelism);
+		for (OperatorStateHandle handle : chainOpParallelStates) {
+			unchanged.add(Collections.singletonList(handle));
+		}
+		return unchanged;
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
index 2aa0491..9b9a810 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -19,10 +19,13 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
@@ -34,10 +37,31 @@ public class SubtaskState implements StateObject {
 
 	private static final long serialVersionUID = -2394696997971923995L;
 
-	private static final Logger LOG = LoggerFactory.getLogger(SubtaskState.class);
+	/**
+	 * Legacy (non-repartitionable) operator state.
+	 */
+	@Deprecated
+	private final ChainedStateHandle<StreamStateHandle> legacyOperatorState;
 
-	/** The state of the parallel operator */
-	private final ChainedStateHandle<StreamStateHandle> chainedStateHandle;
+	/**
+	 * Snapshot from the {@link org.apache.flink.runtime.state.OperatorStateBackend}.
+	 */
+	private final ChainedStateHandle<OperatorStateHandle> managedOperatorState;
+
+	/**
+	 * Snapshot written using {@link org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream}.
+	 */
+	private final ChainedStateHandle<OperatorStateHandle> rawOperatorState;
+
+	/**
+	 * Snapshot from {@link org.apache.flink.runtime.state.KeyedStateBackend}.
+	 */
+	private final KeyGroupsStateHandle managedKeyedState;
+
+	/**
+	 * Snapshot written using {@link org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream}.
+	 */
+	private final KeyGroupsStateHandle rawKeyedState;
 
 	/**
 	 * The state size. This is also part of the deserialized state handle.
@@ -46,26 +70,76 @@ public class SubtaskState implements StateObject {
 	 */
 	private final long stateSize;
 
-	/** The duration of the checkpoint (ack timestamp - trigger timestamp). */
-	private final long duration;
-	
+	/**
+	 * The duration of the checkpoint (ack timestamp - trigger timestamp).
+	 */
+	private long duration;
+
+	/**
+	 * Creates a subtask state from the given state handles with a checkpoint duration of zero.
+	 *
+	 * @param legacyOperatorState  legacy (non-repartitionable) operator state
+	 * @param managedOperatorState operator state snapshot from the operator state backend
+	 * @param rawOperatorState     operator state snapshot written to a checkpoint stream
+	 * @param managedKeyedState    keyed state snapshot from the keyed state backend
+	 * @param rawKeyedState        keyed state snapshot written to a checkpoint stream
+	 */
+	public SubtaskState(
+			ChainedStateHandle<StreamStateHandle> legacyOperatorState,
+			ChainedStateHandle<OperatorStateHandle> managedOperatorState,
+			ChainedStateHandle<OperatorStateHandle> rawOperatorState,
+			KeyGroupsStateHandle managedKeyedState,
+			KeyGroupsStateHandle rawKeyedState) {
+		// delegate to the full constructor with a duration of 0
+		this(legacyOperatorState, managedOperatorState, rawOperatorState,
+				managedKeyedState, rawKeyedState, 0L);
+	}
+
 	public SubtaskState(
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle,
+			ChainedStateHandle<StreamStateHandle> legacyOperatorState,
+			ChainedStateHandle<OperatorStateHandle> managedOperatorState,
+			ChainedStateHandle<OperatorStateHandle> rawOperatorState,
+			KeyGroupsStateHandle managedKeyedState,
+			KeyGroupsStateHandle rawKeyedState,
 			long duration) {
 
-		this.chainedStateHandle = checkNotNull(chainedStateHandle, "State");
+		this.legacyOperatorState = checkNotNull(legacyOperatorState, "State");
+		this.managedOperatorState = managedOperatorState;
+		this.rawOperatorState = rawOperatorState;
+		this.managedKeyedState = managedKeyedState;
+		this.rawKeyedState = rawKeyedState;
 		this.duration = duration;
 		try {
-			stateSize = chainedStateHandle.getStateSize();
+			long calculateStateSize = getSizeNullSafe(legacyOperatorState);
+			calculateStateSize += getSizeNullSafe(managedOperatorState);
+			calculateStateSize += getSizeNullSafe(rawOperatorState);
+			calculateStateSize += getSizeNullSafe(managedKeyedState);
+			calculateStateSize += getSizeNullSafe(rawKeyedState);
+			stateSize = calculateStateSize;
 		} catch (Exception e) {
 			throw new RuntimeException("Failed to get state size.", e);
 		}
 	}
 
+	/**
+	 * Returns the size of the given state object, treating {@code null} as size zero.
+	 */
+	private static final long getSizeNullSafe(StateObject stateObject) throws Exception {
+		if (stateObject == null) {
+			return 0L;
+		}
+		return stateObject.getStateSize();
+	}
+
 	// --------------------------------------------------------------------------------------------
-	
-	public ChainedStateHandle<StreamStateHandle> getChainedStateHandle() {
-		return chainedStateHandle;
+
+	@Deprecated
+	public ChainedStateHandle<StreamStateHandle> getLegacyOperatorState() {
+		return legacyOperatorState;
+	}
+
+	public ChainedStateHandle<OperatorStateHandle> getManagedOperatorState() {
+		return managedOperatorState;
+	}
+
+	public ChainedStateHandle<OperatorStateHandle> getRawOperatorState() {
+		return rawOperatorState;
+	}
+
+	public KeyGroupsStateHandle getManagedKeyedState() {
+		return managedKeyedState;
+	}
+
+	public KeyGroupsStateHandle getRawKeyedState() {
+		return rawKeyedState;
 	}
 
 	@Override
@@ -79,35 +153,94 @@ public class SubtaskState implements StateObject {
 
 	@Override
 	public void discardState() throws Exception {
-		chainedStateHandle.discardState();
+		StateUtil.bestEffortDiscardAllStateObjects(
+				Arrays.asList(
+						legacyOperatorState,
+						managedOperatorState,
+						rawOperatorState,
+						managedKeyedState,
+						rawKeyedState));
+	}
+
+	public void setDuration(long duration) {
+		this.duration = duration;
 	}
 
 	// --------------------------------------------------------------------------------------------
 
+
 	@Override
 	public boolean equals(Object o) {
 		if (this == o) {
 			return true;
 		}
-		else if (o instanceof SubtaskState) {
-			SubtaskState that = (SubtaskState) o;
-			return this.chainedStateHandle.equals(that.chainedStateHandle) && stateSize == that.stateSize &&
-				duration == that.duration;
+		if (o == null || getClass() != o.getClass()) {
+			return false;
 		}
-		else {
+
+		SubtaskState that = (SubtaskState) o;
+
+		if (stateSize != that.stateSize) {
 			return false;
 		}
+		if (duration != that.duration) {
+			return false;
+		}
+		if (legacyOperatorState != null ?
+				!legacyOperatorState.equals(that.legacyOperatorState)
+				: that.legacyOperatorState != null) {
+			return false;
+		}
+		if (managedOperatorState != null ?
+				!managedOperatorState.equals(that.managedOperatorState)
+				: that.managedOperatorState != null) {
+			return false;
+		}
+		if (rawOperatorState != null ?
+				!rawOperatorState.equals(that.rawOperatorState)
+				: that.rawOperatorState != null) {
+			return false;
+		}
+		if (managedKeyedState != null ?
+				!managedKeyedState.equals(that.managedKeyedState)
+				: that.managedKeyedState != null) {
+			return false;
+		}
+		return rawKeyedState != null ?
+				rawKeyedState.equals(that.rawKeyedState)
+				: that.rawKeyedState == null;
+
+	}
+
+	/**
+	 * Checks whether this subtask state carries any state at all.
+	 *
+	 * @return {@code true} if the legacy, managed or raw operator state is non-null and
+	 *         non-empty, or if the managed or raw keyed state handle is non-null
+	 */
+	public boolean hasState() {
+		// all five handles must be considered: a subtask that only wrote raw operator state
+		// must not be reported as stateless
+		return (null != legacyOperatorState && !legacyOperatorState.isEmpty())
+				|| (null != managedOperatorState && !managedOperatorState.isEmpty())
+				|| (null != rawOperatorState && !rawOperatorState.isEmpty())
+				|| null != managedKeyedState
+				|| null != rawKeyedState;
 	}
 
 	@Override
 	public int hashCode() {
-		return (int) (this.stateSize ^ this.stateSize >>> 32) +
-			31 * ((int) (this.duration ^ this.duration >>> 32) +
-				31 * chainedStateHandle.hashCode());
+		int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0;
+		result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0);
+		result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0);
+		result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0);
+		result = 31 * result + (rawKeyedState != null ? rawKeyedState.hashCode() : 0);
+		result = 31 * result + (int) (stateSize ^ (stateSize >>> 32));
+		result = 31 * result + (int) (duration ^ (duration >>> 32));
+		return result;
 	}
 
 	@Override
 	public String toString() {
-		return String.format("SubtaskState(Size: %d, Duration: %d, State: %s)", stateSize, duration, chainedStateHandle);
+		return "SubtaskState{" +
+				"chainedStateHandle=" + legacyOperatorState +
+				", operatorStateFromBackend=" + managedOperatorState +
+				", operatorStateFromStream=" + rawOperatorState +
+				", keyedStateFromBackend=" + managedKeyedState +
+				", keyedStateHandleFromStream=" + rawKeyedState +
+				", stateSize=" + stateSize +
+				", duration=" + duration +
+				'}';
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
index 7e4eded..3cdc5e9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
@@ -18,11 +18,7 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import com.google.common.collect.Iterables;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.util.Preconditions;
@@ -49,12 +45,6 @@ public class TaskState implements StateObject {
 	/** handles to non-partitioned states, subtaskindex -> subtaskstate */
 	private final Map<Integer, SubtaskState> subtaskStates;
 
-	/** handles to partitionable states, subtaskindex -> partitionable state */
-	private final Map<Integer, ChainedStateHandle<OperatorStateHandle>> partitionableStates;
-
-	/** handles to key-partitioned states, subtaskindex -> keyed state */
-	private final Map<Integer, KeyGroupsStateHandle> keyGroupsStateHandles;
-
 
 	/** parallelism of the operator when it was checkpointed */
 	private final int parallelism;
@@ -62,6 +52,7 @@ public class TaskState implements StateObject {
 	/** maximum parallelism of the operator when the job was first created */
 	private final int maxParallelism;
 
+	/** length of the operator chain */
 	private final int chainLength;
 
 	public TaskState(JobVertexID jobVertexID, int parallelism, int maxParallelism, int chainLength) {
@@ -73,8 +64,6 @@ public class TaskState implements StateObject {
 		this.jobVertexID = jobVertexID;
 
 		this.subtaskStates = new HashMap<>(parallelism);
-		this.partitionableStates = new HashMap<>(parallelism);
-		this.keyGroupsStateHandles = new HashMap<>(parallelism);
 
 		this.parallelism = parallelism;
 		this.maxParallelism = maxParallelism;
@@ -96,32 +85,6 @@ public class TaskState implements StateObject {
 		}
 	}
 
-	public void putPartitionableState(
-			int subtaskIndex,
-			ChainedStateHandle<OperatorStateHandle> partitionableState) {
-
-		Preconditions.checkNotNull(partitionableState);
-
-		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
-			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
-					" exceeds the maximum number of sub tasks " + subtaskStates.size());
-		} else {
-			partitionableStates.put(subtaskIndex, partitionableState);
-		}
-	}
-
-	public void putKeyedState(int subtaskIndex, KeyGroupsStateHandle keyGroupsStateHandle) {
-		Preconditions.checkNotNull(keyGroupsStateHandle);
-
-		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
-			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
-					" exceeds the maximum number of sub tasks " + subtaskStates.size());
-		} else {
-			keyGroupsStateHandles.put(subtaskIndex, keyGroupsStateHandle);
-		}
-	}
-
-
 	public SubtaskState getState(int subtaskIndex) {
 		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
 			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
@@ -131,24 +94,6 @@ public class TaskState implements StateObject {
 		}
 	}
 
-	public ChainedStateHandle<OperatorStateHandle> getPartitionableState(int subtaskIndex) {
-		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
-			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
-					" exceeds the maximum number of sub tasks " + subtaskStates.size());
-		} else {
-			return partitionableStates.get(subtaskIndex);
-		}
-	}
-
-	public KeyGroupsStateHandle getKeyGroupState(int subtaskIndex) {
-		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
-			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
-					" exceeds the maximum number of sub tasks " + keyGroupsStateHandles.size());
-		} else {
-			return keyGroupsStateHandles.get(subtaskIndex);
-		}
-	}
-
 	public Collection<SubtaskState> getStates() {
 		return subtaskStates.values();
 	}
@@ -169,13 +114,9 @@ public class TaskState implements StateObject {
 		return chainLength;
 	}
 
-	public Collection<KeyGroupsStateHandle> getKeyGroupStates() {
-		return keyGroupsStateHandles.values();
-	}
-
 	public boolean hasNonPartitionedState() {
 		for(SubtaskState sts : subtaskStates.values()) {
-			if (sts != null && !sts.getChainedStateHandle().isEmpty()) {
+			if (sts != null && !sts.getLegacyOperatorState().isEmpty()) {
 				return true;
 			}
 		}
@@ -184,8 +125,7 @@ public class TaskState implements StateObject {
 
 	@Override
 	public void discardState() throws Exception {
-		StateUtil.bestEffortDiscardAllStateObjects(
-				Iterables.concat(subtaskStates.values(), partitionableStates.values(), keyGroupsStateHandles.values()));
+		StateUtil.bestEffortDiscardAllStateObjects(subtaskStates.values());
 	}
 
 
@@ -198,16 +138,6 @@ public class TaskState implements StateObject {
 			if (subtaskState != null) {
 				result += subtaskState.getStateSize();
 			}
-
-			ChainedStateHandle<OperatorStateHandle> partitionableState = partitionableStates.get(i);
-			if (partitionableState != null) {
-				result += partitionableState.getStateSize();
-			}
-
-			KeyGroupsStateHandle keyGroupsState = keyGroupsStateHandles.get(i);
-			if (keyGroupsState != null) {
-				result += keyGroupsState.getStateSize();
-			}
 		}
 
 		return result;
@@ -220,9 +150,7 @@ public class TaskState implements StateObject {
 
 			return jobVertexID.equals(other.jobVertexID)
 					&& parallelism == other.parallelism
-					&& subtaskStates.equals(other.subtaskStates)
-					&& partitionableStates.equals(other.partitionableStates)
-					&& keyGroupsStateHandles.equals(other.keyGroupsStateHandles);
+					&& subtaskStates.equals(other.subtaskStates);
 		} else {
 			return false;
 		}
@@ -230,18 +158,10 @@ public class TaskState implements StateObject {
 
 	@Override
 	public int hashCode() {
-		return parallelism + 31 * Objects.hash(jobVertexID, subtaskStates, partitionableStates, keyGroupsStateHandles);
+		return parallelism + 31 * Objects.hash(jobVertexID, subtaskStates);
 	}
 
 	public Map<Integer, SubtaskState> getSubtaskStates() {
 		return Collections.unmodifiableMap(subtaskStates);
 	}
-
-	public Map<Integer, KeyGroupsStateHandle> getKeyGroupsStateHandles() {
-		return Collections.unmodifiableMap(keyGroupsStateHandles);
-	}
-
-	public Map<Integer, ChainedStateHandle<OperatorStateHandle>> getPartitionableStates() {
-		return partitionableStates;
-	}
 }