Posted to commits@flink.apache.org by al...@apache.org on 2016/10/20 14:15:21 UTC

[3/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java
new file mode 100644
index 0000000..f5d633c
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperatorTest.java
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.blob.BlobKey;
+import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
+import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.filecache.FileCache;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.io.network.NetworkEnvironment;
+import org.apache.flink.runtime.io.network.netty.PartitionStateChecker;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
+import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
+import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.taskmanager.CheckpointResponder;
+import org.apache.flink.runtime.taskmanager.Task;
+import org.apache.flink.runtime.taskmanager.TaskManagerConnection;
+import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.SourceStreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.util.SerializedValue;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.net.URL;
+import java.util.Collections;
+import java.util.concurrent.Executor;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class AbstractUdfStreamOperatorTest {
+
+	@Test
+	public void testLifeCycle() throws Exception {
+
+		Configuration taskManagerConfig = new Configuration();
+
+		StreamConfig cfg = new StreamConfig(new Configuration());
+		cfg.setStreamOperator(new LifecycleTrackingStreamSource(new MockSourceFunction()));
+		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+		Task task = createTask(SourceStreamTask.class, cfg, taskManagerConfig);
+
+		task.startTaskThread();
+
+		// wait for clean termination
+		task.getExecutingThread().join();
+		assertEquals(ExecutionState.FINISHED, task.getExecutionState());
+	}
+
+	private static class MockSourceFunction extends RichSourceFunction<Long> {
+
+		private static final long serialVersionUID = 1L;
+
+		@Override
+		public void run(SourceContext<Long> ctx) {
+		}
+
+		@Override
+		public void cancel() {
+		}
+
+		@Override
+		public void setRuntimeContext(RuntimeContext t) {
+			System.out.println("!setRuntimeContext");
+			super.setRuntimeContext(t);
+		}
+
+		@Override
+		public void open(Configuration parameters) throws Exception {
+			System.out.println("!open");
+			super.open(parameters);
+		}
+
+		@Override
+		public void close() throws Exception {
+			System.out.println("!close");
+			super.close();
+		}
+	}
+
+	private Task createTask(
+			Class<? extends AbstractInvokable> invokable,
+			StreamConfig taskConfig,
+			Configuration taskManagerConfig) throws Exception {
+
+		LibraryCacheManager libCache = mock(LibraryCacheManager.class);
+		when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader());
+
+		ResultPartitionManager partitionManager = mock(ResultPartitionManager.class);
+		ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class);
+		PartitionStateChecker partitionStateChecker = mock(PartitionStateChecker.class);
+		Executor executor = mock(Executor.class);
+
+		NetworkEnvironment network = mock(NetworkEnvironment.class);
+		when(network.getResultPartitionManager()).thenReturn(partitionManager);
+		when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC);
+		when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class)))
+				.thenReturn(mock(TaskKvStateRegistry.class));
+
+		TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(
+				new JobID(), "Job Name", new JobVertexID(), new ExecutionAttemptID(),
+				new SerializedValue<>(new ExecutionConfig()),
+				"Test Task", 1, 0, 1, 0,
+				new Configuration(),
+				taskConfig.getConfiguration(),
+				invokable.getName(),
+				Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
+				Collections.<InputGateDeploymentDescriptor>emptyList(),
+				Collections.<BlobKey>emptyList(),
+				Collections.<URL>emptyList(),
+				0);
+
+		return new Task(
+				tdd,
+				mock(MemoryManager.class),
+				mock(IOManager.class),
+				network,
+				mock(BroadcastVariableManager.class),
+				mock(TaskManagerConnection.class),
+				mock(InputSplitProvider.class),
+				mock(CheckpointResponder.class),
+				libCache,
+				mock(FileCache.class),
+				new TaskManagerRuntimeInfo("localhost", taskManagerConfig, System.getProperty("java.io.tmpdir")),
+				new UnregisteredTaskMetricsGroup(),
+				consumableNotifier,
+				partitionStateChecker,
+				executor);
+	}
+
+	static class LifecycleTrackingStreamSource<OUT, SRC extends SourceFunction<OUT>>
+			extends StreamSource<OUT, SRC> implements Serializable {
+
+		//private transient final AtomicInteger currentState;
+
+		private static final long serialVersionUID = 2431488948886850562L;
+
+		public LifecycleTrackingStreamSource(SRC sourceFunction) {
+			super(sourceFunction);
+		}
+
+		@Override
+		public void setup(StreamTask<?, ?> containingTask, StreamConfig config, Output<StreamRecord<OUT>> output) {
+			System.out.println("setup");
+			super.setup(containingTask, config, output);
+		}
+
+		@Override
+		public void snapshotState(StateSnapshotContext context) throws Exception {
+			System.out.println("snapshotState");
+			super.snapshotState(context);
+		}
+
+		@Override
+		public void initializeState(StateInitializationContext context) throws Exception {
+			System.out.println("initializeState");
+			super.initializeState(context);
+		}
+
+		@Override
+		public void open() throws Exception {
+			System.out.println("open");
+			super.open();
+		}
+
+		@Override
+		public void close() throws Exception {
+			System.out.println("close");
+			super.close();
+		}
+
+		@Override
+		public void dispose() throws Exception {
+			super.dispose();
+			System.out.println("dispose");
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
new file mode 100644
index 0000000..75c2261
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.state.ClosableRegistry;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContextImpl;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.util.LongArrayList;
+import org.apache.flink.util.Preconditions;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.mockito.Mockito.mock;
+
+public class StateInitializationContextImplTest {
+
+	static final int NUM_HANDLES = 10;
+
+	private StateInitializationContextImpl initializationContext;
+	private ClosableRegistry closableRegistry;
+
+	private int writtenKeyGroups;
+	private Set<Integer> writtenOperatorStates;
+
+	@Before
+	public void setUp() throws Exception {
+
+
+		this.writtenKeyGroups = 0;
+		this.writtenOperatorStates = new HashSet<>();
+
+		this.closableRegistry = new ClosableRegistry();
+		OperatorStateStore stateStore = mock(OperatorStateStore.class);
+
+		ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(64);
+
+		List<KeyGroupsStateHandle> keyGroupsStateHandles = new ArrayList<>(NUM_HANDLES);
+		int prev = 0;
+		for (int i = 0; i < NUM_HANDLES; ++i) {
+			out.reset();
+			int size = i % 4;
+			int end = prev + size;
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+			KeyGroupRangeOffsets offsets =
+					new KeyGroupRangeOffsets(i == 9 ? KeyGroupRange.EMPTY_KEY_GROUP_RANGE : new KeyGroupRange(prev, end));
+			prev = end + 1;
+			for (int kg : offsets.getKeyGroupRange()) {
+				offsets.setKeyGroupOffset(kg, out.getPosition());
+				dov.writeInt(kg);
+				++writtenKeyGroups;
+			}
+
+			KeyGroupsStateHandle handle =
+					new KeyGroupsStateHandle(offsets, new ByteStateHandleCloseChecking("kg-" + i, out.toByteArray()));
+
+			keyGroupsStateHandles.add(handle);
+		}
+
+		List<OperatorStateHandle> operatorStateHandles = new ArrayList<>(NUM_HANDLES);
+
+		for (int i = 0; i < NUM_HANDLES; ++i) {
+			int size = i % 4;
+			out.reset();
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+			LongArrayList offsets = new LongArrayList(size);
+			for (int s = 0; s < size; ++s) {
+				offsets.add(out.getPosition());
+				int val = i * NUM_HANDLES + s;
+				dov.writeInt(val);
+				writtenOperatorStates.add(val);
+			}
+
+			Map<String, long[]> offsetsMap = new HashMap<>();
+			offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, offsets.toArray());
+			OperatorStateHandle operatorStateHandle =
+					new OperatorStateHandle(offsetsMap, new ByteStateHandleCloseChecking("os-" + i, out.toByteArray()));
+			operatorStateHandles.add(operatorStateHandle);
+		}
+
+		this.initializationContext =
+				new StateInitializationContextImpl(
+						true,
+						stateStore,
+						mock(KeyedStateStore.class),
+						keyGroupsStateHandles,
+						operatorStateHandles,
+						closableRegistry);
+	}
+
+	@Test
+	public void getOperatorStateStreams() throws Exception {
+
+	}
+
+	@Test
+	public void getKeyedStateStreams() throws Exception {
+
+		int readKeyGroupCount = 0;
+
+		for (KeyGroupStatePartitionStreamProvider stateStreamProvider
+				: initializationContext.getRawKeyedStateInputs()) {
+
+			Assert.assertNotNull(stateStreamProvider);
+
+			try (InputStream is = stateStreamProvider.getStream()) {
+				DataInputView div = new DataInputViewStreamWrapper(is);
+				int val = div.readInt();
+				++readKeyGroupCount;
+				Assert.assertEquals(stateStreamProvider.getKeyGroupId(), val);
+			}
+		}
+
+		Assert.assertEquals(writtenKeyGroups, readKeyGroupCount);
+	}
+
+	@Test
+	public void getOperatorStateStore() throws Exception {
+
+		Set<Integer> readStatesCount = new HashSet<>();
+
+		for (StatePartitionStreamProvider statePartitionStreamProvider
+				: initializationContext.getRawOperatorStateInputs()) {
+
+			Assert.assertNotNull(statePartitionStreamProvider);
+
+			try (InputStream is = statePartitionStreamProvider.getStream()) {
+				DataInputView div = new DataInputViewStreamWrapper(is);
+				Assert.assertTrue(readStatesCount.add(div.readInt()));
+			}
+		}
+
+		Assert.assertEquals(writtenOperatorStates, readStatesCount);
+	}
+
+	@Test
+	public void close() throws Exception {
+
+		int count = 0;
+		int stopCount = NUM_HANDLES / 2;
+		boolean isClosed = false;
+
+
+		try {
+			for (KeyGroupStatePartitionStreamProvider stateStreamProvider
+					: initializationContext.getRawKeyedStateInputs()) {
+				Assert.assertNotNull(stateStreamProvider);
+
+				if (count == stopCount) {
+					initializationContext.close();
+					isClosed = true;
+				}
+
+				try (InputStream is = stateStreamProvider.getStream()) {
+					DataInputView div = new DataInputViewStreamWrapper(is);
+					try {
+						int val = div.readInt();
+						Assert.assertEquals(stateStreamProvider.getKeyGroupId(), val);
+						if (isClosed) {
+							Assert.fail("Close was ignored: stream");
+						}
+						++count;
+					} catch (IOException ioex) {
+						if (!isClosed) {
+							throw ioex;
+						}
+					}
+				}
+			}
+			Assert.fail("Close was ignored: registry");
+		} catch (IOException iex) {
+			Assert.assertTrue(isClosed);
+			Assert.assertEquals(stopCount, count);
+		}
+
+	}
+
+	static final class ByteStateHandleCloseChecking extends ByteStreamStateHandle {
+
+		private static final long serialVersionUID = -6201941296931334140L;
+
+		public ByteStateHandleCloseChecking(String handleName, byte[] data) {
+			super(handleName, data);
+		}
+
+		@Override
+		public FSDataInputStream openInputStream() throws IOException {
+			return new FSDataInputStream() {
+				private int index = 0;
+				private boolean closed = false;
+
+				@Override
+				public void seek(long desired) throws IOException {
+					Preconditions.checkArgument(desired >= 0 && desired < Integer.MAX_VALUE);
+					index = (int) desired;
+				}
+
+				@Override
+				public long getPos() throws IOException {
+					return index;
+				}
+
+				@Override
+				public int read() throws IOException {
+					if (closed) {
+						throw new IOException("Stream closed");
+					}
+					return index < data.length ? data[index++] & 0xFF : -1;
+				}
+
+				@Override
+				public void close() throws IOException {
+					super.close();
+					this.closed = true;
+				}
+			};
+		}
+	}
+
+}
\ No newline at end of file
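
The test above drives the restore-side reads through StateInitializationContextImpl directly. For orientation, here is a minimal sketch (not part of this commit; the operator context and payload meaning are assumed) of the same reads as they would appear inside a user operator's initializeState(), using only the provider calls exercised above and the same types the test imports:

	@Override
	public void initializeState(StateInitializationContext context) throws Exception {
		// raw keyed state: one stream provider per restored key group
		for (KeyGroupStatePartitionStreamProvider provider : context.getRawKeyedStateInputs()) {
			try (InputStream in = provider.getStream()) {
				DataInputView div = new DataInputViewStreamWrapper(in);
				int restored = div.readInt(); // assumed payload for key group provider.getKeyGroupId()
			}
		}
		// raw operator state: one stream provider per restored state partition
		for (StatePartitionStreamProvider provider : context.getRawOperatorStateInputs()) {
			try (InputStream in = provider.getStream()) {
				DataInputView div = new DataInputViewStreamWrapper(in);
				int restored = div.readInt(); // assumed payload for this partition
			}
		}
	}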

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateSnapshotContextSynchronousImplTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateSnapshotContextSynchronousImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateSnapshotContextSynchronousImplTest.java
new file mode 100644
index 0000000..0ee839e
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateSnapshotContextSynchronousImplTest.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.ClosableRegistry;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class StateSnapshotContextSynchronousImplTest {
+
+	private StateSnapshotContextSynchronousImpl snapshotContext;
+
+	@Before
+	public void setUp() throws Exception {
+		ClosableRegistry closableRegistry = new ClosableRegistry();
+		CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(1024);
+		KeyGroupRange keyGroupRange = new KeyGroupRange(0, 2);
+		this.snapshotContext = new StateSnapshotContextSynchronousImpl(42, 4711, streamFactory, keyGroupRange, closableRegistry);
+	}
+
+	@Test
+	public void testMetaData() {
+		Assert.assertEquals(42, snapshotContext.getCheckpointId());
+		Assert.assertEquals(4711, snapshotContext.getCheckpointTimestamp());
+	}
+
+	@Test
+	public void testCreateRawKeyedStateOutput() throws Exception {
+		KeyedStateCheckpointOutputStream stream = snapshotContext.getRawKeyedOperatorStateOutput();
+		Assert.assertNotNull(stream);
+	}
+
+	@Test
+	public void testCreateRawOperatorStateOutput() throws Exception {
+		OperatorStateCheckpointOutputStream stream = snapshotContext.getRawOperatorStateOutput();
+		Assert.assertNotNull(stream);
+	}
+}
\ No newline at end of file
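
The assertions above only verify that the synchronous snapshot context hands out non-null raw output streams. As a complement, a minimal sketch (not part of this commit; the written payloads are arbitrary) of actually writing to those streams from an operator's snapshotState(), using the key-group and partition markers that StreamOperatorSnapshotRestoreTest below also exercises:

	@Override
	public void snapshotState(StateSnapshotContext context) throws Exception {
		// raw keyed state: mark each key group before writing its payload
		KeyedStateCheckpointOutputStream keyedOut = context.getRawKeyedOperatorStateOutput();
		DataOutputView keyedView = new DataOutputViewStreamWrapper(keyedOut);
		for (int keyGroup : keyedOut.getKeyGroupList()) {
			keyedOut.startNewKeyGroup(keyGroup);
			keyedView.writeInt(keyGroup); // assumed per-key-group payload
		}

		// raw operator state: each startNewPartition() call starts a new chunk of raw operator state
		OperatorStateCheckpointOutputStream opOut = context.getRawOperatorStateOutput();
		DataOutputView opView = new DataOutputViewStreamWrapper(opOut);
		for (int i = 0; i < 4; ++i) {
			opOut.startNewPartition();
			opView.writeInt(i); // assumed per-partition payload
		}
	}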

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java
new file mode 100644
index 0000000..ada0b86
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamOperatorSnapshotRestoreTest.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
+import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
+import org.apache.flink.util.FutureUtil;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.InputStream;
+import java.util.BitSet;
+import java.util.Collections;
+
+public class StreamOperatorSnapshotRestoreTest {
+
+	@Test
+	public void testOperatorStatesSnapshotRestore() throws Exception {
+
+		//-------------------------------------------------------------------------- snapshot
+
+		TestOneInputStreamOperator op = new TestOneInputStreamOperator(false);
+
+		KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> testHarness =
+				new KeyedOneInputStreamOperatorTestHarness<>(op, new KeySelector<Integer, Integer>() {
+					@Override
+					public Integer getKey(Integer value) throws Exception {
+						return value;
+					}
+				}, TypeInformation.of(Integer.class));
+
+		testHarness.open();
+
+		for (int i = 0; i < 10; ++i) {
+			testHarness.processElement(new StreamRecord<>(i));
+		}
+
+		OperatorSnapshotResult snapshotInProgress = testHarness.snapshot(1L, 1L);
+
+		KeyGroupsStateHandle keyedManaged =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateManagedFuture());
+		KeyGroupsStateHandle keyedRaw =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getKeyedStateRawFuture());
+
+		OperatorStateHandle opManaged =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateManagedFuture());
+		OperatorStateHandle opRaw =
+				FutureUtil.runIfNotDoneAndGet(snapshotInProgress.getOperatorStateRawFuture());
+
+		testHarness.close();
+
+		//-------------------------------------------------------------------------- restore
+
+		op = new TestOneInputStreamOperator(true);
+		testHarness = new KeyedOneInputStreamOperatorTestHarness<>(op, new KeySelector<Integer, Integer>() {
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}, TypeInformation.of(Integer.class));
+
+		testHarness.initializeState(new OperatorStateHandles(
+				0,
+				null,
+				Collections.singletonList(keyedManaged),
+				Collections.singletonList(keyedRaw),
+				Collections.singletonList(opManaged),
+				Collections.singletonList(opRaw)));
+
+		testHarness.open();
+
+		for (int i = 0; i < 10; ++i) {
+			testHarness.processElement(new StreamRecord<>(i));
+		}
+
+		testHarness.close();
+	}
+
+	static class TestOneInputStreamOperator
+			extends AbstractStreamOperator<Integer>
+			implements OneInputStreamOperator<Integer, Integer> {
+
+		private static final long serialVersionUID = -8942866418598856475L;
+
+		public TestOneInputStreamOperator(boolean verifyRestore) {
+			this.verifyRestore = verifyRestore;
+		}
+
+		private boolean verifyRestore;
+		private ValueState<Integer> keyedState;
+		private ListState<Integer> opState;
+
+		@Override
+		public void processElement(StreamRecord<Integer> element) throws Exception {
+			if (verifyRestore) {
+				// check restored managed keyed state
+				long exp = element.getValue() + 1;
+				long act = keyedState.value();
+				Assert.assertEquals(exp, act);
+			} else {
+				// write managed keyed state that goes into snapshot
+				keyedState.update(element.getValue() + 1);
+				// write managed operator state that goes into snapshot
+				opState.add(element.getValue());
+			}
+		}
+
+		@Override
+		public void processWatermark(Watermark mark) throws Exception {
+
+		}
+
+		@Override
+		public void snapshotState(StateSnapshotContext context) throws Exception {
+
+			KeyedStateCheckpointOutputStream out = context.getRawKeyedOperatorStateOutput();
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+
+			// write raw keyed state that goes into snapshot
+			int count = 0;
+			for (int kg : out.getKeyGroupList()) {
+				out.startNewKeyGroup(kg);
+				dov.writeInt(kg + 2);
+				++count;
+			}
+
+			Assert.assertEquals(KeyedOneInputStreamOperatorTestHarness.MAX_PARALLELISM, count);
+
+			// write raw operator state that goes into snapshot
+			OperatorStateCheckpointOutputStream outOp = context.getRawOperatorStateOutput();
+			dov = new DataOutputViewStreamWrapper(outOp);
+			for (int i = 0; i < 13; ++i) {
+				outOp.startNewPartition();
+				dov.writeInt(42 + i);
+			}
+		}
+
+		@Override
+		public void initializeState(StateInitializationContext context) throws Exception {
+
+			Assert.assertEquals(verifyRestore, context.isRestored());
+
+			keyedState = context.getManagedKeyedStateStore().getState(new ValueStateDescriptor<>("managed-keyed", Integer.class, 0));
+			opState = context.getManagedOperatorStateStore().getSerializableListState("managed-op-state");
+
+			if (context.isRestored()) {
+				// check restored raw keyed state
+				int count = 0;
+				for (KeyGroupStatePartitionStreamProvider streamProvider : context.getRawKeyedStateInputs()) {
+					try (InputStream in = streamProvider.getStream()) {
+						DataInputView div = new DataInputViewStreamWrapper(in);
+						Assert.assertEquals(streamProvider.getKeyGroupId() + 2, div.readInt());
+						++count;
+					}
+				}
+				Assert.assertEquals(KeyedOneInputStreamOperatorTestHarness.MAX_PARALLELISM, count);
+
+				// check restored managed operator state
+				BitSet check = new BitSet(10);
+				for (int v : opState.get()) {
+					check.set(v);
+				}
+
+				Assert.assertEquals(10, check.cardinality());
+
+				// check restored raw operator state
+				check = new BitSet(13);
+				for (StatePartitionStreamProvider streamProvider : context.getRawOperatorStateInputs()) {
+					try (InputStream in = streamProvider.getStream()) {
+						DataInputView div = new DataInputViewStreamWrapper(in);
+						check.set(div.readInt() - 42);
+					}
+				}
+				Assert.assertEquals(13, check.cardinality());
+			}
+		}
+	}
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
index 02409a3..155a16f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
@@ -37,11 +37,14 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.DefaultKeyedStateStore;
 import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.junit.Test;
+import org.mockito.Matchers;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
@@ -53,6 +56,7 @@ import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -160,18 +164,24 @@ public class StreamingRuntimeContextTest {
 			final AtomicReference<Object> ref, final ExecutionConfig config) throws Exception {
 		
 		AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class);
+
+		KeyedStateBackend keyedStateBackend= mock(KeyedStateBackend.class);
+
+		DefaultKeyedStateStore keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, config);
+
 		when(operatorMock.getExecutionConfig()).thenReturn(config);
-		
-		when(operatorMock.getPartitionedState(any(StateDescriptor.class))).thenAnswer(
-				new Answer<Object>() {
-					
-					@Override
-					public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
-						ref.set(invocationOnMock.getArguments()[0]);
-						return null;
-					}
-				});
-		
+
+		doAnswer(new Answer<Object>() {
+
+			@Override
+			public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
+				ref.set(invocationOnMock.getArguments()[2]);
+				return null;
+			}
+		}).when(keyedStateBackend).getPartitionedState(Matchers.any(), any(TypeSerializer.class), any(StateDescriptor.class));
+
+		when(operatorMock.getKeyedStateStore()).thenReturn(keyedStateStore);
+
 		return operatorMock;
 	}
 
@@ -179,29 +189,35 @@ public class StreamingRuntimeContextTest {
 	private static AbstractStreamOperator<?> createPlainMockOp() throws Exception {
 
 		AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class);
-		when(operatorMock.getExecutionConfig()).thenReturn(new ExecutionConfig());
-
-		when(operatorMock.getPartitionedState(any(ListStateDescriptor.class))).thenAnswer(
-				new Answer<ListState<String>>() {
-
-					@Override
-					public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
-						ListStateDescriptor<String> descr =
-								(ListStateDescriptor<String>) invocationOnMock.getArguments()[0];
-
-						AbstractKeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend(
-								new DummyEnvironment("test_task", 1, 0),
-								new JobID(),
-								"test_op",
-								IntSerializer.INSTANCE,
-								1,
-								new KeyGroupRange(0, 0),
-								new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
-						backend.setCurrentKey(0);
-						return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr);
-					}
-				});
+		ExecutionConfig config = new ExecutionConfig();
+
+		KeyedStateBackend keyedStateBackend= mock(KeyedStateBackend.class);
+
+		DefaultKeyedStateStore keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, config);
+
+		when(operatorMock.getExecutionConfig()).thenReturn(config);
 
+		doAnswer(new Answer<ListState<String>>() {
+
+			@Override
+			public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
+				ListStateDescriptor<String> descr =
+						(ListStateDescriptor<String>) invocationOnMock.getArguments()[2];
+
+				AbstractKeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend(
+						new DummyEnvironment("test_task", 1, 0),
+						new JobID(),
+						"test_op",
+						IntSerializer.INSTANCE,
+						1,
+						new KeyGroupRange(0, 0),
+						new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+				backend.setCurrentKey(0);
+				return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr);
+			}
+		}).when(keyedStateBackend).getPartitionedState(Matchers.any(), any(TypeSerializer.class), any(ListStateDescriptor.class));
+
+		when(operatorMock.getKeyedStateStore()).thenReturn(keyedStateStore);
 		return operatorMock;
 	}
 	

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
index 59242e8..f2fc876 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
@@ -32,6 +32,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -973,10 +974,7 @@ public class BarrierBufferTest {
 		}
 
 		@Override
-		public void setInitialState(
-				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState,
-				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
+		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {
 			throw new UnsupportedOperationException("should never be called");
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
index b6d0450..7cfbb66 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.junit.Test;
 
 import java.util.Arrays;
@@ -366,10 +367,7 @@ public class BarrierTrackerTest {
 		}
 
 		@Override
-		public void setInitialState(
-				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState,
-				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
+		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {
 
 			throw new UnsupportedOperationException("should never be called");
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
index d1ba489..e5e26e9 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSinkTest.java
@@ -129,7 +129,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 			elementCounter++;
 		}
 
-		testHarness.snapshot(0, 0);
+		testHarness.snapshotLegacy(0, 0);
 		testHarness.notifyOfCompletedCheckpoint(0);
 
 		//isCommitted should have failed, thus sendValues() should never have been called
@@ -140,7 +140,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 			elementCounter++;
 		}
 
-		testHarness.snapshot(1, 0);
+		testHarness.snapshotLegacy(1, 0);
 		testHarness.notifyOfCompletedCheckpoint(1);
 
 		//previous CP should be retried, but will fail the CP commit. Second CP should be skipped.
@@ -151,7 +151,7 @@ public class GenericWriteAheadSinkTest extends WriteAheadSinkTestBase<Tuple1<Int
 			elementCounter++;
 		}
 
-		testHarness.snapshot(2, 0);
+		testHarness.snapshotLegacy(2, 0);
 		testHarness.notifyOfCompletedCheckpoint(2);
 
 		//all CP's should be retried and succeed; since one CP was written twice we have 2 * 10 + 10 + 10 = 40 values

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
index b7203b5..3d1e6e8 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/WriteAheadSinkTestBase.java
@@ -66,7 +66,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -74,7 +74,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -82,7 +82,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		verifyResultsIdealCircumstances(testHarness, sink);
@@ -105,7 +105,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -113,14 +113,14 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 
 		for (int x = 0; x < 20; x++) {
 			testHarness.processElement(new StreamRecord<>(generateValue(elementCounter, 2)));
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		verifyResultsDataPersistenceUponMissedNotify(testHarness, sink);
@@ -143,7 +143,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		StreamStateHandle latestSnapshot = testHarness.snapshot(snapshotCount++, 0);
+		StreamStateHandle latestSnapshot = testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		for (int x = 0; x < 20; x++) {
@@ -166,7 +166,7 @@ public abstract class WriteAheadSinkTestBase<IN, S extends GenericWriteAheadSink
 			elementCounter++;
 		}
 
-		testHarness.snapshot(snapshotCount++, 0);
+		testHarness.snapshotLegacy(snapshotCount++, 0);
 		testHarness.notifyOfCompletedCheckpoint(snapshotCount - 1);
 
 		verifyResultsDataDiscardingUponRestore(testHarness, sink);

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
index 51e61a1..e96109e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
@@ -522,7 +522,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 
 			// draw a snapshot and dispose the window
 			int beforeSnapShot = testHarness.getOutput().size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			List<Integer> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
@@ -611,7 +611,7 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 			// draw a snapshot
 			List<Integer> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int beforeSnapShot = testHarness.getOutput().size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
index 12a842f..802329b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
@@ -634,7 +634,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			// draw a snapshot
 			List<Tuple2<Integer, Integer>> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int beforeSnapShot = resultAtSnapshot.size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
 
@@ -727,7 +727,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			// draw a snapshot
 			List<Tuple2<Integer, Integer>> resultAtSnapshot = extractFromStreamRecords(testHarness.getOutput());
 			int beforeSnapShot = resultAtSnapshot.size();
-			StreamStateHandle state = testHarness.snapshot(1L, System.currentTimeMillis());
+			StreamStateHandle state = testHarness.snapshotLegacy(1L, System.currentTimeMillis());
 			int afterSnapShot = testHarness.getOutput().size();
 			assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java
index ba803e3..2b0b915 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperatorTest.java
@@ -125,7 +125,7 @@ public class WindowOperatorTest extends TestLogger {
 		TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new Tuple2ResultSortComparator());
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -255,7 +255,7 @@ public class WindowOperatorTest extends TestLogger {
 		TestHarnessUtil.assertOutputEqualsSorted("Output was not correct.", expectedOutput, testHarness.getOutput(), new Tuple2ResultSortComparator());
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -396,7 +396,7 @@ public class WindowOperatorTest extends TestLogger {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 2), 1000));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -465,7 +465,7 @@ public class WindowOperatorTest extends TestLogger {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 3), 2500));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -543,7 +543,7 @@ public class WindowOperatorTest extends TestLogger {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key1", 2), 1000));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -641,7 +641,7 @@ public class WindowOperatorTest extends TestLogger {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 33), 1000));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 		testHarness.close();
 		testHarness.setup();
 		testHarness.restore(snapshot);
@@ -796,7 +796,7 @@ public class WindowOperatorTest extends TestLogger {
 		testHarness.processElement(new StreamRecord<>(new Tuple2<>("key2", 1), 1999));
 
 		// do a snapshot, close and restore again
-		StreamStateHandle snapshot = testHarness.snapshot(0L, 0L);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0L, 0L);
 
 		testHarness.close();
 
@@ -884,7 +884,7 @@ public class WindowOperatorTest extends TestLogger {
 		operator.processingTimeTimersQueue.add(timer2);
 		operator.processingTimeTimersQueue.add(timer3);
 		
-		StreamStateHandle snapshot = testHarness.snapshot(0, 0);
+		StreamStateHandle snapshot = testHarness.snapshotLegacy(0, 0);
 
 		WindowOperator<String, Tuple2<String, Integer>, Tuple2<String, Integer>, Tuple2<String, Integer>, TimeWindow> otherOperator = new WindowOperator<>(
 				SlidingEventTimeWindows.of(Time.of(WINDOW_SIZE, TimeUnit.SECONDS), Time.of(WINDOW_SLIDE, TimeUnit.SECONDS)),

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index b5b6582..ee5a203 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -45,6 +45,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskManagerConnection;
@@ -124,8 +125,17 @@ public class InterruptSensitiveRestoreTest {
 			StreamStateHandle state) throws IOException {
 
 		ChainedStateHandle<StreamStateHandle> operatorState = new ChainedStateHandle<>(Collections.singletonList(state));
-		List<KeyGroupsStateHandle> keyGroupState = Collections.emptyList();
-		List<Collection<OperatorStateHandle>> partitionableOperatorState = Collections.emptyList();
+		List<KeyGroupsStateHandle> keyGroupStateFromBackend = Collections.emptyList();
+		List<KeyGroupsStateHandle> keyGroupStateFromStream = Collections.emptyList();
+		List<Collection<OperatorStateHandle>> operatorStateBackend = Collections.emptyList();
+		List<Collection<OperatorStateHandle>> operatorStateStream = Collections.emptyList();
+
+		TaskStateHandles taskStateHandles = new TaskStateHandles(
+				operatorState,
+				operatorStateBackend,
+				operatorStateStream,
+				keyGroupStateFromBackend,
+				keyGroupStateFromStream);
 
 		return new TaskDeploymentDescriptor(
 				new JobID(),
@@ -143,9 +153,7 @@ public class InterruptSensitiveRestoreTest {
 				Collections.<BlobKey>emptyList(),
 				Collections.<URL>emptyList(),
 				0,
-				operatorState,
-				keyGroupState,
-				partitionableOperatorState);
+				taskStateHandles);
 	}
 
 	private static Task createTask(TaskDeploymentDescriptor tdd) throws IOException {

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
index 1b2b723..3dd2ed7 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
@@ -31,15 +31,13 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
-import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.graph.StreamNode;
@@ -64,8 +62,6 @@ import scala.concurrent.duration.FiniteDuration;
 
 import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -74,7 +70,6 @@ import java.util.Map;
 import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.TimeUnit;
 
 import static org.junit.Assert.assertEquals;
@@ -390,7 +385,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());
 
 		final OneInputStreamTask<String, String> restoredTask = new OneInputStreamTask<String, String>();
-		restoredTask.setInitialState(env.getState(), env.getKeyGroupStates(), env.getPartitionableOperatorState());
+		restoredTask.setInitialState(new TaskStateHandles(env.getCheckpointStateHandles()));
 
 		final OneInputStreamTaskTestHarness<String, String> restoredTaskHarness = new OneInputStreamTaskTestHarness<String, String>(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
 		restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO);
@@ -482,9 +477,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 
 	private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment {
 		private volatile long checkpointId;
-		private volatile ChainedStateHandle<StreamStateHandle> state;
-		private volatile List<KeyGroupsStateHandle> keyGroupStates;
-		private volatile List<Collection<OperatorStateHandle>> partitionableOperatorState;
+		private volatile SubtaskState checkpointStateHandles;
 
 		private final OneShotLatch checkpointLatch = new OneShotLatch();
 
@@ -492,24 +485,6 @@ public class OneInputStreamTaskTest extends TestLogger {
 			return checkpointId;
 		}
 
-		public ChainedStateHandle<StreamStateHandle> getState() {
-			return state;
-		}
-
-		List<KeyGroupsStateHandle> getKeyGroupStates() {
-			List<KeyGroupsStateHandle> result = new ArrayList<>();
-			for (KeyGroupsStateHandle keyGroupState : keyGroupStates) {
-				if (keyGroupState != null) {
-					result.add(keyGroupState);
-				}
-			}
-			return result;
-		}
-
-		List<Collection<OperatorStateHandle>> getPartitionableOperatorState() {
-			return partitionableOperatorState;
-		}
-
 		AcknowledgeStreamMockEnvironment(
 				Configuration jobConfig, Configuration taskConfig,
 				ExecutionConfig executionConfig, long memorySize,
@@ -521,26 +496,20 @@ public class OneInputStreamTaskTest extends TestLogger {
 		@Override
 		public void acknowledgeCheckpoint(
 				CheckpointMetaData checkpointMetaData,
-				CheckpointStateHandles checkpointStateHandles) {
+				SubtaskState checkpointStateHandles) {
 
 			this.checkpointId = checkpointMetaData.getCheckpointId();
-			if(checkpointStateHandles != null) {
-				this.state = checkpointStateHandles.getNonPartitionedStateHandles();
-				this.keyGroupStates = checkpointStateHandles.getKeyGroupsStateHandle();
-				ChainedStateHandle<OperatorStateHandle> chainedStateHandle = checkpointStateHandles.getPartitioneableStateHandles();
-				Collection<OperatorStateHandle>[] ia = new Collection[chainedStateHandle.getLength()];
-				this.partitionableOperatorState = Arrays.asList(ia);
-
-				for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
-					partitionableOperatorState.set(i, Collections.singletonList(chainedStateHandle.get(i)));
-				}
-			}
+			this.checkpointStateHandles = checkpointStateHandles;
 			checkpointLatch.trigger();
 		}
 
 		public OneShotLatch getCheckpointLatch() {
 			return checkpointLatch;
 		}
+
+		public SubtaskState getCheckpointStateHandles() {
+			return checkpointStateHandles;
+		}
 	}
 
 	private static class TestingStreamOperator<IN, OUT>
@@ -580,9 +549,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		}
 
 		@Override
-		public RunnableFuture<OperatorStateHandle> snapshotState(
-				long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
-
+		public void snapshotState(StateSnapshotContext context) throws Exception {
 			ListState<Integer> partitionableState =
 					getOperatorStateBackend().getOperatorState(TEST_DESCRIPTOR);
 			partitionableState.clear();
@@ -591,7 +558,11 @@ public class OneInputStreamTaskTest extends TestLogger {
 			partitionableState.add(4711);
 
 			++numberSnapshotCalls;
-			return super.snapshotState(checkpointId, timestamp, streamFactory);
+		}
+
+		@Override
+		public void initializeState(StateInitializationContext context) throws Exception {
+
 		}
 
 		TestingStreamOperator(long seed, long recoveryTimestamp) {
@@ -643,7 +614,6 @@ public class OneInputStreamTaskTest extends TestLogger {
 			assertEquals(random.nextInt(), (int) operatorState);
 		}
 
-
 		private Serializable generateFunctionState() {
 			return random.nextInt();
 		}

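In the hunks above the operator no longer returns a RunnableFuture<OperatorStateHandle> from snapshotState(checkpointId, timestamp, streamFactory); instead it overrides a snapshotState(StateSnapshotContext)/initializeState(StateInitializationContext) pair and writes its partitionable state through the operator state backend. A rough sketch of that pattern follows, assuming (as the @Override annotations above suggest) that both hooks are declared on the operator base class; the class and descriptor names here are invented for illustration:

    import org.apache.flink.api.common.state.ListState;
    import org.apache.flink.api.common.state.ListStateDescriptor;
    import org.apache.flink.api.common.typeutils.base.IntSerializer;
    import org.apache.flink.runtime.state.StateInitializationContext;
    import org.apache.flink.runtime.state.StateSnapshotContext;
    import org.apache.flink.streaming.api.operators.AbstractStreamOperator;

    /** Sketch only: shows the new hook pair, not wired to any input type. */
    public class HookSketchOperator extends AbstractStreamOperator<Integer> {

        private static final ListStateDescriptor<Integer> COUNTER_DESCRIPTOR =
                new ListStateDescriptor<>("sketch-counter", IntSerializer.INSTANCE);

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            // partitionable operator state is written through the backend on snapshot
            ListState<Integer> partitionableState =
                    getOperatorStateBackend().getOperatorState(COUNTER_DESCRIPTOR);
            partitionableState.clear();
            partitionableState.add(42);
        }

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            // restore counterpart of snapshotState; empty here, as in TestingStreamOperator above
        }
    }

On the restore side, the test above hands the acknowledged SubtaskState back to the task in one piece via setInitialState(new TaskStateHandles(env.getCheckpointStateHandles())).
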
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
index f852682..36ecf59 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
@@ -28,6 +28,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
@@ -50,7 +51,6 @@ import org.apache.flink.runtime.plugable.DeserializationDelegate;
 import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -123,7 +123,7 @@ public class StreamMockEnvironment implements Environment {
 
 	public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, long memorySize,
 								 MockInputSplitProvider inputSplitProvider, int bufferSize) {
-		this(jobConfig, taskConfig, null, memorySize, inputSplitProvider, bufferSize);
+		this(jobConfig, taskConfig, new ExecutionConfig(), memorySize, inputSplitProvider, bufferSize);
 	}
 
 	public void addInputGate(InputGate gate) {
@@ -313,7 +313,7 @@ public class StreamMockEnvironment implements Environment {
 
 	@Override
 	public void acknowledgeCheckpoint(
-			CheckpointMetaData checkpointMetaData, CheckpointStateHandles checkpointStateHandles) {
+			CheckpointMetaData checkpointMetaData, SubtaskState subtaskState) {
 	}
 
 	@Override

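With this change the mock environment's acknowledgeCheckpoint(...) receives the whole SubtaskState in one object instead of a CheckpointStateHandles plus separate handle lists. A small sketch of a capturing subclass, patterned on AcknowledgeStreamMockEnvironment above (the class name is made up; the constructor matches the five-argument one shown in this diff):

    class CapturingEnvironment extends StreamMockEnvironment {

        private volatile SubtaskState lastSubtaskState;

        CapturingEnvironment(Configuration jobConfig, Configuration taskConfig, long memorySize,
                             MockInputSplitProvider inputSplitProvider, int bufferSize) {
            super(jobConfig, taskConfig, memorySize, inputSplitProvider, bufferSize);
        }

        @Override
        public void acknowledgeCheckpoint(CheckpointMetaData checkpointMetaData, SubtaskState subtaskState) {
            // the whole subtask state now arrives as a single object
            this.lastSubtaskState = subtaskState;
        }

        SubtaskState getLastSubtaskState() {
            return lastSubtaskState;
        }
    }
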
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
index b9211b1..1bb3fb0 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
@@ -214,6 +214,17 @@ public class TwoInputStreamTaskTest {
 		expectedOutput.add(new StreamRecord<String>("111", initialTime));
 
 		testHarness.waitForInputProcessing();
+
+		// Wait to allow input to end up in the output.
+		// TODO Use count down latches instead as a cleaner solution
+		for (int i = 0; i < 20; ++i) {
+			if (testHarness.getOutput().size() >= expectedOutput.size()) {
+				break;
+			} else {
+				Thread.sleep(100);
+			}
+		}
+
 		// we should not yet see the barrier, only the two elements from non-blocked input
 		TestHarnessUtil.assertOutputEquals("Output was not correct.",
 				testHarness.getOutput(),
@@ -224,17 +235,17 @@ public class TwoInputStreamTaskTest {
 		testHarness.processEvent(new CheckpointBarrier(0, 0), 1, 1);
 
 		testHarness.waitForInputProcessing();
+		testHarness.endInput();
+		testHarness.waitForTaskCompletion();
 
 		// now we should see the barrier and after that the buffered elements
 		expectedOutput.add(new CheckpointBarrier(0, 0));
 		expectedOutput.add(new StreamRecord<String>("Hello-0-0", initialTime));
-		TestHarnessUtil.assertOutputEquals("Output was not correct.",
-				testHarness.getOutput(),
-				expectedOutput);
 
-		testHarness.endInput();
+		TestHarnessUtil.assertOutputEquals("Output was not correct.",
+				expectedOutput,
+				testHarness.getOutput());
 
-		testHarness.waitForTaskCompletion();
 
 		List<String> resultElements = TestHarnessUtil.getRawElementsFromOutput(testHarness.getOutput());
 		Assert.assertEquals(4, resultElements.size());

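The loop added above polls the output queue because processing happens asynchronously and the barrier assertion would otherwise race with the input thread; the second assertion is also flipped so that the expected queue is passed as the expected argument. If the polling were factored out (the TODO suggests moving to latches eventually), a hypothetical helper could look like the following; neither the method nor its name exists in the harness:

    // waits until the harness output reaches the expected size or the timeout expires
    private static void waitForOutputSize(java.util.Queue<Object> output, int expectedSize, long timeoutMillis)
            throws InterruptedException {
        long deadline = System.currentTimeMillis() + timeoutMillis;
        while (output.size() < expectedSize && System.currentTimeMillis() < deadline) {
            Thread.sleep(100);
        }
    }
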
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index 5275a39..41968e6 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -31,14 +31,16 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
-import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.TestTimeServiceProvider;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.concurrent.RunnableFuture;
 
@@ -60,7 +62,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 
 	// when we restore we keep the state here so that we can call restore
 	// when the operator requests the keyed state backend
-	private KeyGroupsStateHandle restoredKeyedState = null;
+	private Collection<KeyGroupsStateHandle> restoredKeyedState = null;
 
 	public KeyedOneInputStreamOperatorTestHarness(
 			OneInputStreamOperator<IN, OUT> operator,
@@ -138,7 +140,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 								keySerializer,
 								numberOfKeyGroups,
 								keyGroupRange,
-								Collections.singletonList(restoredKeyedState),
+								restoredKeyedState,
 								mockTask.getEnvironment().getTaskKvStateRegistry());
 						restoredKeyedState = null;
 						return keyedStateBackend;
@@ -154,7 +156,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 	 *
 	 */
 	@Override
-	public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exception {
+	public StreamStateHandle snapshotLegacy(long checkpointId, long timestamp) throws Exception {
 		// simply use an in-memory handle
 		MemoryStateBackend backend = new MemoryStateBackend();
 
@@ -185,7 +187,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 	}
 
 	/**
-	 * 
+	 *
 	 */
 	@Override
 	public void restore(StreamStateHandle snapshot) throws Exception {
@@ -198,7 +200,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 			byte keyedStatePresent = (byte) inStream.read();
 			if (keyedStatePresent == 1) {
 				ObjectInputStream ois = new ObjectInputStream(inStream);
-				this.restoredKeyedState = (KeyGroupsStateHandle) ois.readObject();
+				this.restoredKeyedState = Collections.singletonList((KeyGroupsStateHandle) ois.readObject());
 			}
 		}
 	}
@@ -208,8 +210,16 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 	 */
 	public void close() throws Exception {
 		super.close();
-		if(keyedStateBackend != null) {
+		if (keyedStateBackend != null) {
 			keyedStateBackend.dispose();
 		}
 	}
+
+	@Override
+	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
+		if (null != operatorStateHandles) {
+			this.restoredKeyedState = operatorStateHandles.getManagedKeyedState();
+		}
+		super.initializeState(operatorStateHandles);
+	}
 }

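The keyed harness now keeps restored keyed state as a Collection<KeyGroupsStateHandle> and can also receive it through the new initializeState(OperatorStateHandles) override, which pulls operatorStateHandles.getManagedKeyedState(). A sketch of the legacy snapshot/restore round trip it still supports, assuming the usual (operator, keySelector, keyType) constructor; 'operator' and 'keySelector' are placeholders, not taken from the diff:

    KeyedOneInputStreamOperatorTestHarness<String, String, String> harness =
            new KeyedOneInputStreamOperatorTestHarness<>(operator, keySelector, BasicTypeInfo.STRING_TYPE_INFO);
    harness.open();
    // ... process some keyed elements ...
    StreamStateHandle snapshot = harness.snapshotLegacy(0L, 0L);

    KeyedOneInputStreamOperatorTestHarness<String, String, String> restored =
            new KeyedOneInputStreamOperatorTestHarness<>(operator, keySelector, BasicTypeInfo.STRING_TYPE_INFO);
    restored.restore(snapshot);   // keyed state ends up in the Collection<KeyGroupsStateHandle> field
    restored.open();
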
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
index d1622ff..9f8d223 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
@@ -28,22 +28,26 @@ import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.ClosableRegistry;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
-import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotResult;
 import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.DefaultTimeServiceProvider;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.runtime.tasks.TestTimeServiceProvider;
 import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
-
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
@@ -65,7 +69,7 @@ import static org.mockito.Mockito.when;
  */
 public class OneInputStreamOperatorTestHarness<IN, OUT> {
 
-	protected static final int MAX_PARALLELISM = 10;
+	public static final int MAX_PARALLELISM = 10;
 
 	final OneInputStreamOperator<IN, OUT> operator;
 
@@ -81,6 +85,8 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 
 	StreamTask<?, ?> mockTask;
 
+	ClosableRegistry closableRegistry;
+
 	// use this as default for tests
 	AbstractStateBackend stateBackend = new MemoryStateBackend();
 
@@ -88,6 +94,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 	 * Whether setup() was called on the operator. This is reset when calling close().
 	 */
 	private boolean setupCalled = false;
+	private boolean initializeCalled = false;
 
 	private volatile boolean wasFailedExternally = false;
 
@@ -121,6 +128,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 		this.config.setCheckpointingEnabled(true);
 		this.executionConfig = executionConfig;
 		this.checkpointLock = checkpointLock;
+		this.closableRegistry = new ClosableRegistry();
 
 		final Environment env = new MockEnvironment("MockTwoInputTask", 3 * 1024 * 1024, new MockInputSplitProvider(), 1024, underlyingConfig, executionConfig, MAX_PARALLELISM, 1, 0);
 		mockTask = mock(StreamTask.class);
@@ -132,6 +140,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 		when(mockTask.getEnvironment()).thenReturn(env);
 		when(mockTask.getExecutionConfig()).thenReturn(executionConfig);
 		when(mockTask.getUserCodeClassLoader()).thenReturn(this.getClass().getClassLoader());
+		when(mockTask.getCancelables()).thenReturn(this.closableRegistry);
 
 		doAnswer(new Answer<Void>() {
 			@Override
@@ -154,6 +163,26 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 			throw new RuntimeException(e.getMessage(), e);
 		}
 
+		try {
+			doAnswer(new Answer<OperatorStateBackend>() {
+				@Override
+				public OperatorStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
+					final StreamOperator<?> operator = (StreamOperator<?>) invocationOnMock.getArguments()[0];
+					final Collection<OperatorStateHandle> stateHandles = (Collection<OperatorStateHandle>) invocationOnMock.getArguments()[1];
+					OperatorStateBackend osb;
+					if (null == stateHandles) {
+						osb = stateBackend.createOperatorStateBackend(env, operator.getClass().getSimpleName());
+					} else {
+						osb = stateBackend.restoreOperatorStateBackend(env, operator.getClass().getSimpleName(), stateHandles);
+					}
+					mockTask.getCancelables().registerClosable(osb);
+					return osb;
+				}
+			}).when(mockTask).createOperatorStateBackend(any(StreamOperator.class), any(Collection.class));
+		} catch (Exception e) {
+			throw new RuntimeException(e.getMessage(), e);
+		}
+
 		timeServiceProvider = testTimeProvider != null ? testTimeProvider :
 			new DefaultTimeServiceProvider(mockTask, this.checkpointLock);
 
@@ -199,8 +228,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 	}
 
 	/**
-	 * Calls
-	 * {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)} ()}
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)} ()}
 	 */
 	public void setup() throws Exception {
 		operator.setup(mockTask, config, new MockOutput());
@@ -208,21 +236,48 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 	}
 
 	/**
-	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#open()}. This also
-	 * calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}.
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}
 	 * if it was not called before.
 	 */
-	public void open() throws Exception {
+	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
 		if (!setupCalled) {
 			setup();
 		}
+		operator.initializeState(operatorStateHandles);
+		initializeCalled = true;
+	}
+
+	/**
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#open()}.
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)} if it
+	 * was not called before.
+	 */
+	public void open() throws Exception {
+		if (!initializeCalled) {
+			initializeState(null);
+		}
 		operator.open();
 	}
 
 	/**
 	 *
 	 */
-	public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exception {
+	public OperatorSnapshotResult snapshot(long checkpointId, long timestamp) throws Exception {
+
+		CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory(
+				new JobID(),
+				"test_op");
+
+		return operator.snapshotState(checkpointId, timestamp, streamFactory);
+	}
+
+	/**
+	 *
+	 */
+	@Deprecated
+	public StreamStateHandle snapshotLegacy(long checkpointId, long timestamp) throws Exception {
+
 		CheckpointStreamFactory.CheckpointStateOutputStream outStream = stateBackend.createStreamFactory(
 				new JobID(),
 				"test_op").createCheckpointStateOutputStream(checkpointId, timestamp);
@@ -244,6 +299,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 	/**
 	 *
 	 */
+	@Deprecated
 	public void restore(StreamStateHandle snapshot) throws Exception {
 		if(operator instanceof StreamCheckpointedOperator) {
 			try (FSDataInputStream in = snapshot.openInputStream()) {

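The one-input harness now mirrors the operator lifecycle: setup(), then initializeState(...), then open(), and snapshotting is split into snapshot(...), which returns an OperatorSnapshotResult bundling the results for the different state types, and the deprecated single-handle snapshotLegacy(...). A sketch of a test flow against the new surface, using StreamMap only as a convenient concrete operator and assuming the statements sit in a test method that declares throws Exception:

    OneInputStreamOperatorTestHarness<String, String> harness =
            new OneInputStreamOperatorTestHarness<>(new StreamMap<>(new MapFunction<String, String>() {
                @Override
                public String map(String value) {
                    return value.toUpperCase();
                }
            }));

    harness.initializeState(null);   // null means "no restored state"; open() would do this itself
    harness.open();

    harness.processElement(new StreamRecord<>("hello", 0L));

    OperatorSnapshotResult newStyleSnapshot = harness.snapshot(0L, 0L);    // new path
    StreamStateHandle legacySnapshot = harness.snapshotLegacy(1L, 1L);     // deprecated path
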
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java
index 32b4c77..7df6848 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/TwoInputStreamOperatorTestHarness.java
@@ -25,12 +25,14 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.ClosableRegistry;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 
 import java.util.concurrent.ConcurrentLinkedQueue;
@@ -56,6 +58,10 @@ public class TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> {
 
 	final Object checkpointLock;
 
+	final ClosableRegistry closableRegistry;
+
+	boolean initializeCalled = false;
+
 	public TwoInputStreamOperatorTestHarness(TwoInputStreamOperator<IN1, IN2, OUT> operator) {
 		this(operator, new StreamConfig(new Configuration()));
 	}
@@ -65,6 +71,7 @@ public class TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> {
 		this.outputList = new ConcurrentLinkedQueue<Object>();
 		this.executionConfig = new ExecutionConfig();
 		this.checkpointLock = new Object();
+		this.closableRegistry = new ClosableRegistry();
 
 		Environment env = new MockEnvironment("MockTwoInputTask", 3 * 1024 * 1024, new MockInputSplitProvider(), 1024);
 		StreamTask<?, ?> mockTask = mock(StreamTask.class);
@@ -73,6 +80,7 @@ public class TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> {
 		when(mockTask.getConfiguration()).thenReturn(config);
 		when(mockTask.getEnvironment()).thenReturn(env);
 		when(mockTask.getExecutionConfig()).thenReturn(executionConfig);
+		when(mockTask.getCancelables()).thenReturn(this.closableRegistry);
 
 		operator.setup(mockTask, new StreamConfig(new Configuration()), new MockOutput());
 	}
@@ -86,11 +94,22 @@ public class TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> {
 		return outputList;
 	}
 
+	/**
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorStateHandles)}.
+	 */
+	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
+		operator.initializeState(operatorStateHandles);
+		initializeCalled = true;
+	}
 
 	/**
 	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#open()}.
 	 */
 	public void open() throws Exception {
+		if(!initializeCalled) {
+			initializeState(mock(OperatorStateHandles.class));
+		}
+
 		operator.open();
 	}
 

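The two-input harness gets the same initialize-before-open contract; when no state was supplied explicitly, open() now calls initializeState(mock(OperatorStateHandles.class)) itself. A short usage sketch, with CoStreamMap and a trivial CoMapFunction standing in for a real operator under test:

    TwoInputStreamOperatorTestHarness<String, Integer, String> harness =
            new TwoInputStreamOperatorTestHarness<>(new CoStreamMap<>(new CoMapFunction<String, Integer, String>() {
                @Override
                public String map1(String value) {
                    return value;
                }

                @Override
                public String map2(Integer value) {
                    return String.valueOf(value);
                }
            }));

    harness.open();   // initializes state with a mocked OperatorStateHandles, then opens the operator
    harness.processElement1(new StreamRecord<>("a", 0L));
    harness.processElement2(new StreamRecord<>(1, 0L));
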
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java
index ab8b70f..a4e26f0 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/WindowingTestHarness.java
@@ -165,7 +165,7 @@ public class WindowingTestHarness<K, IN, W extends Window> {
 	 * Takes a snapshot of the current state of the operator. This can be used to test fault-tolerance.
 	 */
 	public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exception {
-		return testHarness.snapshot(checkpointId, timestamp);
+		return testHarness.snapshotLegacy(checkpointId, timestamp);
 	}
 
 	/**