You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sr...@apache.org on 2018/01/22 13:08:47 UTC

[3/6] flink git commit: [FLINK-7720] [checkpoints] Centralize creation of backends and state related resources

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
index 42e1197..ad23303 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
@@ -18,8 +18,6 @@
 
 package org.apache.flink.streaming.api.operators.async;
 
-import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
@@ -35,11 +33,9 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.OperatorID;
-import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
-import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
-import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.streaming.api.datastream.AsyncDataStream;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -54,7 +50,6 @@ import org.apache.flink.streaming.api.operators.async.queue.StreamElementQueueEn
 import org.apache.flink.streaming.api.operators.async.queue.StreamRecordQueueEntry;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.runtime.tasks.AcknowledgeStreamMockEnvironment;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTaskTestHarness;
 import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
@@ -99,11 +94,8 @@ import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -512,15 +504,10 @@ public class AsyncWaitOperatorTest extends TestLogger {
 		streamConfig.setStreamOperator(operator);
 		streamConfig.setOperatorID(operatorID);
 
-		final AcknowledgeStreamMockEnvironment env = new AcknowledgeStreamMockEnvironment(
-				testHarness.jobConfig,
-				testHarness.taskConfig,
-				testHarness.getExecutionConfig(),
-				testHarness.memorySize,
-				new MockInputSplitProvider(),
-				testHarness.bufferSize);
+		final TestTaskStateManager taskStateManagerMock = testHarness.getTaskStateManager();
+		taskStateManagerMock.setWaitForReportLatch(new OneShotLatch());
 
-		testHarness.invoke(env);
+		testHarness.invoke();
 		testHarness.waitForTaskRunning();
 
 		final OneInputStreamTask<Integer, Integer> task = testHarness.getTask();
@@ -541,9 +528,9 @@ public class AsyncWaitOperatorTest extends TestLogger {
 
 		task.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forCheckpoint());
 
-		env.getCheckpointLatch().await();
+		taskStateManagerMock.getWaitForReportLatch().await();
 
-		assertEquals(checkpointId, env.getCheckpointId());
+		assertEquals(checkpointId, taskStateManagerMock.getReportedCheckpointId());
 
 		LazyAsyncFunction.countDown();
 
@@ -551,12 +538,14 @@ public class AsyncWaitOperatorTest extends TestLogger {
 		testHarness.waitForTaskCompletion();
 
 		// set the operator state from previous attempt into the restored one
-		TaskStateSnapshot subtaskStates = env.getCheckpointStateHandles();
+		TaskStateSnapshot subtaskStates = taskStateManagerMock.getLastTaskStateSnapshot();
 
 		final OneInputStreamTaskTestHarness<Integer, Integer> restoredTaskHarness =
 				new OneInputStreamTaskTestHarness<>(
 						OneInputStreamTask::new,
 						BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO);
+
+		restoredTaskHarness.setTaskStateSnapshot(checkpointId, subtaskStates);
 		restoredTaskHarness.setupOutputForSingletonOperatorChain();
 
 		AsyncWaitOperator<Integer, Integer> restoredOperator = new AsyncWaitOperator<>(
@@ -568,7 +557,7 @@ public class AsyncWaitOperatorTest extends TestLogger {
 		restoredTaskHarness.getStreamConfig().setStreamOperator(restoredOperator);
 		restoredTaskHarness.getStreamConfig().setOperatorID(operatorID);
 
-		restoredTaskHarness.invoke(subtaskStates);
+		restoredTaskHarness.invoke();
 		restoredTaskHarness.waitForTaskRunning();
 
 		final OneInputStreamTask<Integer, Integer> restoredTask = restoredTaskHarness.getTask();
@@ -619,7 +608,7 @@ public class AsyncWaitOperatorTest extends TestLogger {
 			2,
 			AsyncDataStream.OutputMode.ORDERED);
 
-		final Environment mockEnvironment = createMockEnvironment();
+		final MockEnvironment mockEnvironment = createMockEnvironment();
 
 		final OneInputStreamOperatorTestHarness<Integer, Integer> testHarness =
 			new OneInputStreamOperatorTestHarness<>(operator, IntSerializer.INSTANCE, mockEnvironment);
@@ -654,9 +643,8 @@ public class AsyncWaitOperatorTest extends TestLogger {
 
 		ArgumentCaptor<Throwable> argumentCaptor = ArgumentCaptor.forClass(Throwable.class);
 
-		verify(mockEnvironment).failExternally(argumentCaptor.capture());
-
-		Throwable failureCause = argumentCaptor.getValue();
+		Throwable failureCause = mockEnvironment.getFailExternallyCause();
+		Assert.assertNotNull(failureCause);
 
 		Assert.assertNotNull(failureCause.getCause());
 		Assert.assertTrue(failureCause.getCause() instanceof ExecutionException);
@@ -666,22 +654,13 @@ public class AsyncWaitOperatorTest extends TestLogger {
 	}
 
 	@Nonnull
-	private Environment createMockEnvironment() {
-		final Environment mockEnvironment = mock(Environment.class);
-
-		final Configuration taskConfiguration = new Configuration();
-		final ExecutionConfig executionConfig = new ExecutionConfig();
-		final TaskMetricGroup metricGroup = UnregisteredMetricGroups.createUnregisteredTaskMetricGroup();
-		final TaskManagerRuntimeInfo taskManagerRuntimeInfo = new TestingTaskManagerRuntimeInfo();
-		final TaskInfo taskInfo = new TaskInfo("foobarTask", 1, 0, 1, 1);
-
-		when(mockEnvironment.getTaskConfiguration()).thenReturn(taskConfiguration);
-		when(mockEnvironment.getExecutionConfig()).thenReturn(executionConfig);
-		when(mockEnvironment.getMetricGroup()).thenReturn(metricGroup);
-		when(mockEnvironment.getTaskManagerInfo()).thenReturn(taskManagerRuntimeInfo);
-		when(mockEnvironment.getTaskInfo()).thenReturn(taskInfo);
-		when(mockEnvironment.getUserClassLoader()).thenReturn(Thread.currentThread().getContextClassLoader());
-		return mockEnvironment;
+	private MockEnvironment createMockEnvironment() {
+		return new MockEnvironment(
+			"foobarTask",
+			1024 * 1024L,
+			new MockInputSplitProvider(),
+			4 * 1024,
+			new TestTaskStateManager());
 	}
 
 	/**
@@ -698,8 +677,7 @@ public class AsyncWaitOperatorTest extends TestLogger {
 
 		ArgumentCaptor<Throwable> failureReason = ArgumentCaptor.forClass(Throwable.class);
 
-		Environment environment = createMockEnvironment();
-		doNothing().when(environment).failExternally(failureReason.capture());
+		MockEnvironment environment = createMockEnvironment();
 
 		StreamTask<?, ?> containingTask = mock(StreamTask.class);
 		when(containingTask.getEnvironment()).thenReturn(environment);
@@ -753,15 +731,7 @@ public class AsyncWaitOperatorTest extends TestLogger {
 			operator.close();
 		}
 
-		// check that no concurrent exception has occurred
-		try {
-			verify(environment, never()).failExternally(any(Throwable.class));
-		} catch (Error e) {
-			// add the exception occurring in the emitter thread (root cause) as a suppressed
-			// exception
-			e.addSuppressed(failureReason.getValue());
-			throw e;
-		}
+		Assert.assertNull(environment.getFailExternallyCause());
 	}
 
 	/**
@@ -896,7 +866,7 @@ public class AsyncWaitOperatorTest extends TestLogger {
 			2,
 			outputMode);
 
-		final Environment mockEnvironment = createMockEnvironment();
+		final MockEnvironment mockEnvironment = createMockEnvironment();
 
 		OneInputStreamOperatorTestHarness<Integer, Integer> harness = new OneInputStreamOperatorTestHarness<>(
 			asyncWaitOperator,
@@ -909,11 +879,11 @@ public class AsyncWaitOperatorTest extends TestLogger {
 			harness.processElement(1, 1L);
 		}
 
-		verify(harness.getEnvironment(), timeout(timeout)).failExternally(any(Exception.class));
-
 		synchronized (harness.getCheckpointLock()) {
 			harness.close();
 		}
+
+		Assert.assertNotNull(harness.getEnvironment().getFailExternallyCause());
 	}
 
 	/**
@@ -961,12 +931,12 @@ public class AsyncWaitOperatorTest extends TestLogger {
 			2,
 			outputMode);
 
-		final Environment mockenvironment = createMockEnvironment();
+		final MockEnvironment mockEnvironment = createMockEnvironment();
 
 		OneInputStreamOperatorTestHarness<Integer, Integer> harness = new OneInputStreamOperatorTestHarness<>(
 			asyncWaitOperator,
 			IntSerializer.INSTANCE,
-			mockenvironment);
+			mockEnvironment);
 
 		harness.open();
 
@@ -976,11 +946,11 @@ public class AsyncWaitOperatorTest extends TestLogger {
 
 		harness.setProcessingTime(10L);
 
-		verify(harness.getEnvironment(), timeout(100L * timeout)).failExternally(any(Exception.class));
-
 		synchronized (harness.getCheckpointLock()) {
 			harness.close();
 		}
+
+		Assert.assertNotNull(mockEnvironment.getFailExternallyCause());
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
index b237373..3cf5248 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SplitStream;
@@ -170,7 +171,12 @@ public class StreamOperatorChainingTest {
 	}
 
 	private MockEnvironment createMockEnvironment(String taskName) {
-		return new MockEnvironment(taskName, 3 * 1024 * 1024, new MockInputSplitProvider(), 1024);
+		return new MockEnvironment(
+			taskName,
+			3 * 1024 * 1024,
+			new MockInputSplitProvider(),
+			1024,
+			new TestTaskStateManager());
 	}
 
 	@Test

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java
index 008c848..1df972c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/TestProcessingTimeServiceTest.java
@@ -44,7 +44,7 @@ public class TestProcessingTimeServiceTest {
 		final TestProcessingTimeService tp = new TestProcessingTimeService();
 
 		final OneInputStreamTaskTestHarness<String, String> testHarness = new OneInputStreamTaskTestHarness<>(
-				(env, state) -> new OneInputStreamTask<>(env, state, tp),
+				(env) -> new OneInputStreamTask<>(env, tp),
 				BasicTypeInfo.STRING_TYPE_INFO,
 				BasicTypeInfo.STRING_TYPE_INFO);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/AcknowledgeStreamMockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/AcknowledgeStreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/AcknowledgeStreamMockEnvironment.java
index c5983ca..8941cc1 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/AcknowledgeStreamMockEnvironment.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/AcknowledgeStreamMockEnvironment.java
@@ -24,6 +24,7 @@ import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.TaskStateManager;
 
 /**
  * Stream environment that allows to wait for checkpoint acknowledgement.
@@ -39,8 +40,16 @@ public class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment {
 		ExecutionConfig executionConfig,
 		long memorySize,
 		MockInputSplitProvider inputSplitProvider,
-		int bufferSize) {
-		super(jobConfig, taskConfig, executionConfig, memorySize, inputSplitProvider, bufferSize);
+		int bufferSize,
+		TaskStateManager taskStateManager) {
+		super(
+			jobConfig,
+			taskConfig,
+			executionConfig,
+			memorySize,
+			inputSplitProvider,
+			bufferSize,
+			taskStateManager);
 	}
 
 	public long getCheckpointId() {

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/CheckpointExceptionHandlerConfigurationTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/CheckpointExceptionHandlerConfigurationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/CheckpointExceptionHandlerConfigurationTest.java
index e47dd0b..08cee55 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/CheckpointExceptionHandlerConfigurationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/CheckpointExceptionHandlerConfigurationTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
@@ -59,6 +60,7 @@ public class CheckpointExceptionHandlerConfigurationTest extends TestLogger {
 		final boolean expectedHandlerFlag = failOnException;
 
 		final DummyEnvironment environment = new DummyEnvironment("test", 1, 0);
+		environment.setTaskStateManager(new TestTaskStateManager());
 		environment.getExecutionConfig().setFailTaskOnCheckpointError(expectedHandlerFlag);
 
 		final CheckpointExceptionHandlerFactory inspectingFactory = new CheckpointExceptionHandlerFactory() {

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index 5ffaa29..499dfb1 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -28,8 +28,8 @@ import org.apache.flink.runtime.blob.BlobCacheService;
 import org.apache.flink.runtime.blob.PermanentBlobCache;
 import org.apache.flink.runtime.blob.TransientBlobCache;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
-import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
@@ -62,6 +62,7 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskManagerActions;
@@ -251,6 +252,13 @@ public class InterruptSensitiveRestoreTest {
 		BlobCacheService blobService =
 			new BlobCacheService(mock(PermanentBlobCache.class), mock(TransientBlobCache.class));
 
+		TestTaskStateManager taskStateManager = new TestTaskStateManager();
+		taskStateManager.setReportedCheckpointId(taskRestore.getRestoreCheckpointId());
+		taskStateManager.setTaskStateSnapshotsByCheckpointId(
+			Collections.singletonMap(
+				taskRestore.getRestoreCheckpointId(),
+				taskRestore.getTaskStateSnapshot()));
+
 		return new Task(
 			jobInformation,
 			taskInformation,
@@ -261,11 +269,11 @@ public class InterruptSensitiveRestoreTest {
 			Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 			Collections.<InputGateDeploymentDescriptor>emptyList(),
 			0,
-			taskRestore,
 			mock(MemoryManager.class),
 			mock(IOManager.class),
 			networkEnvironment,
 			mock(BroadcastVariableManager.class),
+			taskStateManager,
 			mock(TaskManagerActions.class),
 			mock(InputSplitProvider.class),
 			mock(CheckpointResponder.class),

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
index c2c2553..b8acf8b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
@@ -27,15 +27,16 @@ import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.jobgraph.OperatorID;
-import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.graph.StreamNode;
@@ -496,19 +497,15 @@ public class OneInputStreamTaskTest extends TestLogger {
 		StreamConfig streamConfig = testHarness.getStreamConfig();
 
 		configureChainedTestingStreamOperator(streamConfig, numberChainedTasks);
+		TestTaskStateManager taskStateManager = testHarness.taskStateManager;
+		OneShotLatch waitForAcknowledgeLatch = new OneShotLatch();
 
-		AcknowledgeStreamMockEnvironment env = new AcknowledgeStreamMockEnvironment(
-			testHarness.jobConfig,
-			testHarness.taskConfig,
-			testHarness.executionConfig,
-			testHarness.memorySize,
-			new MockInputSplitProvider(),
-			testHarness.bufferSize);
+		taskStateManager.setWaitForReportLatch(waitForAcknowledgeLatch);
 
 		// reset number of restore calls
 		TestingStreamOperator.numberRestoreCalls = 0;
 
-		testHarness.invoke(env);
+		testHarness.invoke();
 		testHarness.waitForTaskRunning(deadline.timeLeft().toMillis());
 
 		final OneInputStreamTask<String, String> streamTask = testHarness.getTask();
@@ -520,9 +517,9 @@ public class OneInputStreamTaskTest extends TestLogger {
 		// since no state was set, there shouldn't be restore calls
 		assertEquals(0, TestingStreamOperator.numberRestoreCalls);
 
-		env.getCheckpointLatch().await();
+		waitForAcknowledgeLatch.await();
 
-		assertEquals(checkpointId, env.getCheckpointId());
+		assertEquals(checkpointId, taskStateManager.getReportedCheckpointId());
 
 		testHarness.endInput();
 		testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());
@@ -533,16 +530,21 @@ public class OneInputStreamTaskTest extends TestLogger {
 
 		restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO);
 
+		restoredTaskHarness.setTaskStateSnapshot(checkpointId, taskStateManager.getLastTaskStateSnapshot());
+
 		StreamConfig restoredTaskStreamConfig = restoredTaskHarness.getStreamConfig();
 
 		configureChainedTestingStreamOperator(restoredTaskStreamConfig, numberChainedTasks);
 
-		TaskStateSnapshot stateHandles = env.getCheckpointStateHandles();
+		TaskStateSnapshot stateHandles = taskStateManager.getLastTaskStateSnapshot();
 		Assert.assertEquals(numberChainedTasks, stateHandles.getSubtaskStateMappings().size());
 
 		TestingStreamOperator.numberRestoreCalls = 0;
 
-		restoredTaskHarness.invoke(stateHandles);
+		// transfer state to new harness
+		restoredTaskHarness.taskStateManager.restoreLatestCheckpointState(
+			taskStateManager.getTaskStateSnapshotsByCheckpointId());
+		restoredTaskHarness.invoke();
 		restoredTaskHarness.endInput();
 		restoredTaskHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());
 
@@ -679,28 +681,6 @@ public class OneInputStreamTaskTest extends TestLogger {
 		public static int numberSnapshotCalls = 0;
 
 		@Override
-		public void open() throws Exception {
-			super.open();
-
-			ListState<Integer> partitionableState = getOperatorStateBackend().getListState(TEST_DESCRIPTOR);
-
-			if (numberSnapshotCalls == 0) {
-				for (Integer v : partitionableState.get()) {
-					fail();
-				}
-			} else {
-				Set<Integer> result = new HashSet<>();
-				for (Integer v : partitionableState.get()) {
-					result.add(v);
-				}
-
-				assertEquals(2, result.size());
-				assertTrue(result.contains(42));
-				assertTrue(result.contains(4711));
-			}
-		}
-
-		@Override
 		public void snapshotState(StateSnapshotContext context) throws Exception {
 			ListState<Integer> partitionableState =
 				getOperatorStateBackend().getListState(TEST_DESCRIPTOR);
@@ -717,6 +697,23 @@ public class OneInputStreamTaskTest extends TestLogger {
 			if (context.isRestored()) {
 				++numberRestoreCalls;
 			}
+
+			ListState<Integer> partitionableState = context.getOperatorStateStore().getListState(TEST_DESCRIPTOR);
+
+			if (numberSnapshotCalls == 0) {
+				for (Integer v : partitionableState.get()) {
+					fail();
+				}
+			} else {
+				Set<Integer> result = new HashSet<>();
+				for (Integer v : partitionableState.get()) {
+					result.add(v);
+				}
+
+				assertEquals(2, result.size());
+				assertTrue(result.contains(42));
+				assertTrue(result.contains(4711));
+			}
 		}
 
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
index 357d629..89a4f81 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
@@ -22,13 +22,12 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate;
 
 import java.io.IOException;
-import java.util.function.BiFunction;
+import java.util.function.Function;
 
 
 /**
@@ -60,7 +59,7 @@ public class OneInputStreamTaskTestHarness<IN, OUT> extends StreamTaskTestHarnes
 	 * of channels per input gate.
 	 */
 	public OneInputStreamTaskTestHarness(
-			BiFunction<Environment, TaskStateSnapshot, ? extends OneInputStreamTask<IN, OUT>> taskFactory,
+			Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory,
 			int numInputGates,
 			int numInputChannelsPerGate,
 			TypeInformation<IN> inputType,
@@ -79,7 +78,7 @@ public class OneInputStreamTaskTestHarness<IN, OUT> extends StreamTaskTestHarnes
 	 * Creates a test harness with one input gate that has one input channel.
 	 */
 	public OneInputStreamTaskTestHarness(
-			BiFunction<Environment, TaskStateSnapshot, ? extends OneInputStreamTask<IN, OUT>> taskFactory,
+			Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory,
 			TypeInformation<IN> inputType,
 			TypeInformation<OUT> outputType) {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java
index b241d05..23745e3 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/RestoreStreamTaskTest.java
@@ -23,8 +23,10 @@ import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
@@ -33,6 +35,7 @@ import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -67,49 +70,53 @@ public class RestoreStreamTaskTest extends TestLogger {
 
 	@Test
 	public void testRestore() throws Exception {
+
 		OperatorID headOperatorID = new OperatorID(42L, 42L);
 		OperatorID tailOperatorID = new OperatorID(44L, 44L);
-		AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain(
+
+		JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain(
 			headOperatorID,
 			new CounterOperator(),
 			tailOperatorID,
 			new CounterOperator(),
 			Optional.empty());
 
-		assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size());
+		TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
 
-		TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles();
+		assertEquals(2, stateHandles.getSubtaskStateMappings().size());
 
-		AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain(
+		createRunAndCheckpointOperatorChain(
 			headOperatorID,
 			new CounterOperator(),
 			tailOperatorID,
 			new CounterOperator(),
-			Optional.of(stateHandles));
+			Optional.of(restore));
 
 		assertEquals(new HashSet<>(Arrays.asList(headOperatorID, tailOperatorID)), RESTORED_OPERATORS);
 	}
 
 	@Test
 	public void testRestoreHeadWithNewId() throws Exception {
+
 		OperatorID tailOperatorID = new OperatorID(44L, 44L);
-		AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain(
+
+		JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain(
 			new OperatorID(42L, 42L),
 			new CounterOperator(),
 			tailOperatorID,
 			new CounterOperator(),
 			Optional.empty());
 
-		assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size());
+		TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
 
-		TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles();
+		assertEquals(2, stateHandles.getSubtaskStateMappings().size());
 
-		AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain(
+		createRunAndCheckpointOperatorChain(
 			new OperatorID(4242L, 4242L),
 			new CounterOperator(),
 			tailOperatorID,
 			new CounterOperator(),
-			Optional.of(stateHandles));
+			Optional.of(restore));
 
 		assertEquals(Collections.singleton(tailOperatorID), RESTORED_OPERATORS);
 	}
@@ -118,23 +125,22 @@ public class RestoreStreamTaskTest extends TestLogger {
 	public void testRestoreTailWithNewId() throws Exception {
 		OperatorID headOperatorID = new OperatorID(42L, 42L);
 
-		AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain(
+		JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain(
 			headOperatorID,
 			new CounterOperator(),
 			new OperatorID(44L, 44L),
 			new CounterOperator(),
 			Optional.empty());
 
-		assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size());
+		TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
+		assertEquals(2, stateHandles.getSubtaskStateMappings().size());
 
-		TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles();
-
-		AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain(
+		createRunAndCheckpointOperatorChain(
 			headOperatorID,
 			new CounterOperator(),
 			new OperatorID(4444L, 4444L),
 			new CounterOperator(),
-			Optional.of(stateHandles));
+			Optional.of(restore));
 
 		assertEquals(Collections.singleton(headOperatorID), RESTORED_OPERATORS);
 	}
@@ -144,14 +150,16 @@ public class RestoreStreamTaskTest extends TestLogger {
 		OperatorID headOperatorID = new OperatorID(42L, 42L);
 		OperatorID tailOperatorID = new OperatorID(44L, 44L);
 
-		AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain(
+		JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain(
 			headOperatorID,
 			new CounterOperator(),
 			tailOperatorID,
 			new CounterOperator(),
 			Optional.empty());
 
-		assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size());
+		TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
+
+		assertEquals(2, stateHandles.getSubtaskStateMappings().size());
 
 		// test empty state in case of scale up
 		OperatorSubtaskState emptyHeadOperatorState = StateAssignmentOperation.operatorSubtaskStateFrom(
@@ -161,15 +169,14 @@ public class RestoreStreamTaskTest extends TestLogger {
 			Collections.emptyMap(),
 			Collections.emptyMap());
 
-		TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles();
 		stateHandles.putSubtaskStateByOperatorID(headOperatorID, emptyHeadOperatorState);
 
-		AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain(
+		createRunAndCheckpointOperatorChain(
 			headOperatorID,
 			new CounterOperator(),
 			tailOperatorID,
 			new CounterOperator(),
-			Optional.of(stateHandles));
+			Optional.of(restore));
 
 		assertEquals(new HashSet<>(Arrays.asList(headOperatorID, tailOperatorID)), RESTORED_OPERATORS);
 	}
@@ -179,33 +186,32 @@ public class RestoreStreamTaskTest extends TestLogger {
 		OperatorID headOperatorID = new OperatorID(42L, 42L);
 		OperatorID tailOperatorID = new OperatorID(44L, 44L);
 
-		AcknowledgeStreamMockEnvironment environment1 = createRunAndCheckpointOperatorChain(
+		JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain(
 			headOperatorID,
 			new StatelessOperator(),
 			tailOperatorID,
 			new CounterOperator(),
 			Optional.empty());
 
-		assertEquals(2, environment1.getCheckpointStateHandles().getSubtaskStateMappings().size());
+		TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot();
+		assertEquals(2, stateHandles.getSubtaskStateMappings().size());
 
-		TaskStateSnapshot stateHandles = environment1.getCheckpointStateHandles();
-
-		AcknowledgeStreamMockEnvironment environment2 = createRunAndCheckpointOperatorChain(
+		createRunAndCheckpointOperatorChain(
 			headOperatorID,
 			new StatelessOperator(),
 			tailOperatorID,
 			new CounterOperator(),
-			Optional.of(stateHandles));
+			Optional.of(restore));
 
 		assertEquals(new HashSet<>(Arrays.asList(headOperatorID, tailOperatorID)), RESTORED_OPERATORS);
 	}
 
-	private AcknowledgeStreamMockEnvironment createRunAndCheckpointOperatorChain(
-			OperatorID headId,
-			OneInputStreamOperator<String, String> headOperator,
-			OperatorID tailId,
-			OneInputStreamOperator<String, String> tailOperator,
-			Optional<TaskStateSnapshot> stateHandles) throws Exception {
+	private JobManagerTaskRestore createRunAndCheckpointOperatorChain(
+		OperatorID headId,
+		OneInputStreamOperator<String, String> headOperator,
+		OperatorID tailId,
+		OneInputStreamOperator<String, String> tailOperator,
+		Optional<JobManagerTaskRestore> restore) throws Exception {
 
 		final OneInputStreamTaskTestHarness<String, String> testHarness =
 			new OneInputStreamTaskTestHarness<>(
@@ -218,39 +224,56 @@ public class RestoreStreamTaskTest extends TestLogger {
 			.chain(tailId, tailOperator, StringSerializer.INSTANCE)
 			.finish();
 
-		AcknowledgeStreamMockEnvironment environment = new AcknowledgeStreamMockEnvironment(
+		if (restore.isPresent()) {
+			JobManagerTaskRestore taskRestore = restore.get();
+			testHarness.setTaskStateSnapshot(
+				taskRestore.getRestoreCheckpointId(),
+				taskRestore.getTaskStateSnapshot());
+		}
+
+		StreamMockEnvironment environment = new StreamMockEnvironment(
 			testHarness.jobConfig,
 			testHarness.taskConfig,
 			testHarness.executionConfig,
 			testHarness.memorySize,
 			new MockInputSplitProvider(),
-			testHarness.bufferSize);
+			testHarness.bufferSize,
+			testHarness.taskStateManager);
 
-		testHarness.invoke(environment, stateHandles.orElse(null));
+		testHarness.invoke(environment);
 		testHarness.waitForTaskRunning();
 
 		OneInputStreamTask<String, String> streamTask = testHarness.getTask();
 
 		processRecords(testHarness);
-		triggerCheckpoint(testHarness, environment, streamTask);
+		triggerCheckpoint(testHarness, streamTask);
+
+		TestTaskStateManager taskStateManager = testHarness.taskStateManager;
+
+		JobManagerTaskRestore jobManagerTaskRestore = new JobManagerTaskRestore(
+			taskStateManager.getReportedCheckpointId(),
+			taskStateManager.getLastTaskStateSnapshot());
 
 		testHarness.endInput();
 		testHarness.waitForTaskCompletion();
-
-		return environment;
+		return jobManagerTaskRestore;
 	}
 
 	private void triggerCheckpoint(
 			OneInputStreamTaskTestHarness<String, String> testHarness,
-			AcknowledgeStreamMockEnvironment environment,
 			OneInputStreamTask<String, String> streamTask) throws Exception {
+
 		long checkpointId = 1L;
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 1L);
 
+		testHarness.taskStateManager.setWaitForReportLatch(new OneShotLatch());
+
 		while (!streamTask.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forCheckpoint())) {}
 
-		environment.getCheckpointLatch().await();
-		assertEquals(checkpointId, environment.getCheckpointId());
+		testHarness.taskStateManager.getWaitForReportLatch().await();
+		long reportedCheckpointId = testHarness.taskStateManager.getReportedCheckpointId();
+
+		assertEquals(checkpointId, reportedCheckpointId);
 	}
 
 	private void processRecords(OneInputStreamTaskTestHarness<String, String> testHarness) throws Exception {

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskStoppingTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskStoppingTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskStoppingTest.java
index a58f436..52d8429 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskStoppingTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskStoppingTest.java
@@ -40,7 +40,7 @@ public class SourceStreamTaskStoppingTest {
 	@Test
 	public void testStop() {
 		final StoppableSourceStreamTask<Object, StoppableSource> sourceTask =
-				new StoppableSourceStreamTask<>(new DummyEnvironment("test", 1, 0), null);
+				new StoppableSourceStreamTask<>(new DummyEnvironment("test", 1, 0));
 
 		sourceTask.headOperator = new StoppableStreamSource<>(new StoppableSource());
 
@@ -53,7 +53,7 @@ public class SourceStreamTaskStoppingTest {
 	public void testStopBeforeInitialization() throws Exception {
 
 		final StoppableSourceStreamTask<Object, StoppableFailingSource> sourceTask =
-				new StoppableSourceStreamTask<>(new DummyEnvironment("test", 1, 0), null);
+				new StoppableSourceStreamTask<>(new DummyEnvironment("test", 1, 0));
 		sourceTask.stop();
 
 		sourceTask.headOperator = new StoppableStreamSource<>(new StoppableFailingSource());

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
index 71371f0..0ba081e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
@@ -26,6 +26,7 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
@@ -45,8 +46,11 @@ import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.TaskLocalStateStore;
+import org.apache.flink.runtime.state.TaskStateManager;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
+import org.apache.flink.util.Preconditions;
 
 import java.util.Collection;
 import java.util.Collections;
@@ -79,7 +83,9 @@ public class StreamMockEnvironment implements Environment {
 
 	private final List<ResultPartitionWriter> outputs;
 
-	private final JobID jobID = new JobID();
+	private final JobID jobID;
+
+	private final ExecutionAttemptID executionAttemptID;
 
 	private final BroadcastVariableManager bcVarManager = new BroadcastVariableManager();
 
@@ -91,25 +97,60 @@ public class StreamMockEnvironment implements Environment {
 
 	private final ExecutionConfig executionConfig;
 
+	private final TaskStateManager taskStateManager;
+
 	private volatile boolean wasFailedExternally = false;
 
 	private TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class);
 
-	public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, ExecutionConfig executionConfig,
-								long memorySize, MockInputSplitProvider inputSplitProvider, int bufferSize) {
+	public StreamMockEnvironment(
+		Configuration jobConfig,
+		Configuration taskConfig,
+		ExecutionConfig executionConfig,
+		long memorySize,
+		MockInputSplitProvider inputSplitProvider,
+		int bufferSize,
+		TaskStateManager taskStateManager) {
+		this(
+			new JobID(),
+			new ExecutionAttemptID(0L, 0L),
+			jobConfig,
+			taskConfig,
+			executionConfig,
+			memorySize,
+			inputSplitProvider,
+			bufferSize,
+			taskStateManager);
+	}
+
+	public StreamMockEnvironment(
+		JobID jobID,
+		ExecutionAttemptID executionAttemptID,
+		Configuration jobConfig,
+		Configuration taskConfig,
+		ExecutionConfig executionConfig,
+		long memorySize,
+		MockInputSplitProvider inputSplitProvider,
+		int bufferSize,
+		TaskStateManager taskStateManager) {
+
+		this.jobID = jobID;
+		this.executionAttemptID = executionAttemptID;
+
+		int subtaskIndex = 0;
 		this.taskInfo = new TaskInfo(
 			"", /* task name */
 			1, /* num key groups / max parallelism */
-			0, /* index of this subtask */
+			subtaskIndex, /* index of this subtask */
 			1, /* num subtasks */
 			0 /* attempt number */);
 		this.jobConfiguration = jobConfig;
 		this.taskConfiguration = taskConfig;
 		this.inputs = new LinkedList<InputGate>();
 		this.outputs = new LinkedList<ResultPartitionWriter>();
-
 		this.memManager = new MemoryManager(memorySize, 1);
 		this.ioManager = new IOManagerAsync();
+		this.taskStateManager = Preconditions.checkNotNull(taskStateManager);
 		this.inputSplitProvider = inputSplitProvider;
 		this.bufferSize = bufferSize;
 
@@ -118,11 +159,19 @@ public class StreamMockEnvironment implements Environment {
 
 		KvStateRegistry registry = new KvStateRegistry();
 		this.kvStateRegistry = registry.createTaskRegistry(jobID, getJobVertexId());
+
+		final TaskLocalStateStore localStateStore = new TaskLocalStateStore(jobID, getJobVertexId(), subtaskIndex);
 	}
 
-	public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, long memorySize,
-								MockInputSplitProvider inputSplitProvider, int bufferSize) {
-		this(jobConfig, taskConfig, new ExecutionConfig(), memorySize, inputSplitProvider, bufferSize);
+	public StreamMockEnvironment(
+		Configuration jobConfig,
+		Configuration taskConfig,
+		long memorySize,
+		MockInputSplitProvider inputSplitProvider,
+		int bufferSize,
+		TaskStateManager taskStateManager) {
+
+		this(jobConfig, taskConfig, new ExecutionConfig(), memorySize, inputSplitProvider, bufferSize, taskStateManager);
 	}
 
 	public void addInputGate(InputGate gate) {
@@ -226,7 +275,7 @@ public class StreamMockEnvironment implements Environment {
 
 	@Override
 	public ExecutionAttemptID getExecutionId() {
-		return new ExecutionAttemptID(0L, 0L);
+		return executionAttemptID;
 	}
 
 	@Override
@@ -235,6 +284,11 @@ public class StreamMockEnvironment implements Environment {
 	}
 
 	@Override
+	public TaskStateManager getTaskStateManager() {
+		return taskStateManager;
+	}
+
+	@Override
 	public AccumulatorRegistry getAccumulatorRegistry() {
 		return accumulatorRegistry;
 	}
@@ -250,6 +304,10 @@ public class StreamMockEnvironment implements Environment {
 
 	@Override
 	public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) {
+		taskStateManager.reportStateHandles(
+			new CheckpointMetaData(checkpointId, 0L),
+			checkpointMetrics,
+			subtaskState);
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java
index 6bd4acc..5b5bf70 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskCancellationBarrierTest.java
@@ -22,7 +22,6 @@ import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.checkpoint.decline.CheckpointDeclineOnCancellationBarrierException;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
@@ -35,8 +34,6 @@ import org.apache.flink.streaming.api.operators.co.CoStreamMap;
 
 import org.junit.Test;
 
-import javax.annotation.Nullable;
-
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
@@ -177,8 +174,8 @@ public class StreamTaskCancellationBarrierTest {
 		private final Object lock = new Object();
 		private volatile boolean running = true;
 
-		protected InitBlockingTask(Environment env, @Nullable TaskStateSnapshot initialState) {
-			super(env, initialState);
+		protected InitBlockingTask(Environment env) {
+			super(env);
 		}
 
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
index 812cb56..3e4ccae 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
@@ -28,7 +28,6 @@ import org.apache.flink.runtime.blob.PermanentBlobCache;
 import org.apache.flink.runtime.blob.TransientBlobCache;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
@@ -60,6 +59,7 @@ import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskManagerActions;
@@ -156,11 +156,11 @@ public class StreamTaskTerminationTest extends TestLogger {
 			Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 			Collections.<InputGateDeploymentDescriptor>emptyList(),
 			0,
-			null,
 			new MemoryManager(32L * 1024L, 1),
 			new IOManagerAsync(),
 			networkEnv,
 			mock(BroadcastVariableManager.class),
+			new TestTaskStateManager(),
 			mock(TaskManagerActions.class),
 			mock(InputSplitProvider.class),
 			mock(CheckpointResponder.class),
@@ -204,8 +204,8 @@ public class StreamTaskTerminationTest extends TestLogger {
 	 */
 	public static class BlockingStreamTask<T, OP extends StreamOperator<T>> extends StreamTask<T, OP> {
 
-		public BlockingStreamTask(Environment env, @Nullable TaskStateSnapshot initialState) {
-			super(env, initialState);
+		public BlockingStreamTask(Environment env) {
+			super(env);
 		}
 
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index d9a21a9..bcf44f2 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -36,7 +36,6 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.SubtaskState;
-import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
@@ -68,11 +67,19 @@ import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupStatePartitionStreamProvider;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackendFactory;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TaskLocalStateStore;
+import org.apache.flink.runtime.state.TaskStateManager;
+import org.apache.flink.runtime.state.TaskStateManagerImpl;
+import org.apache.flink.runtime.state.TestTaskStateManager;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
@@ -85,12 +92,16 @@ import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
 import org.apache.flink.streaming.api.operators.OperatorSnapshotResult;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
 import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.streamstatus.StreamStatusMaintainer;
+import org.apache.flink.util.CloseableIterable;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.TestLogger;
@@ -100,7 +111,6 @@ import org.junit.Assert;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
-import org.mockito.Mockito;
 import org.mockito.internal.util.reflection.Whitebox;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -108,8 +118,6 @@ import org.powermock.core.classloader.annotations.PowerMockIgnore;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
 
-import javax.annotation.Nullable;
-
 import java.io.Closeable;
 import java.io.IOException;
 import java.io.ObjectInputStream;
@@ -215,8 +223,10 @@ public class StreamTaskTest extends TestLogger {
 		taskManagerConfig.setString(CheckpointingOptions.STATE_BACKEND, MockStateBackend.class.getName());
 
 		StreamConfig cfg = new StreamConfig(new Configuration());
+		cfg.setStateKeySerializer(mock(TypeSerializer.class));
 		cfg.setOperatorID(new OperatorID(4711L, 42L));
-		cfg.setStreamOperator(new StreamSource<>(new MockSourceFunction()));
+		TestStreamSource<Long, MockSourceFunction> streamSource = new TestStreamSource<>(new MockSourceFunction());
+		cfg.setStreamOperator(streamSource);
 		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
 
 		Task task = createTask(StateBackendTestSource.class, cfg, taskManagerConfig);
@@ -227,9 +237,14 @@ public class StreamTaskTest extends TestLogger {
 		// wait for clean termination
 		task.getExecutingThread().join();
 
-		// ensure that the state backends are closed
-		verify(StateBackendTestSource.operatorStateBackend).close();
-		verify(StateBackendTestSource.keyedStateBackend).close();
+		// ensure that the state backends and stream iterables are closed ...
+		verify(TestStreamSource.operatorStateBackend).close();
+		verify(TestStreamSource.keyedStateBackend).close();
+		verify(TestStreamSource.rawOperatorStateInputs).close();
+		verify(TestStreamSource.rawKeyedStateInputs).close();
+		// ... and disposed
+		verify(TestStreamSource.operatorStateBackend).dispose();
+		verify(TestStreamSource.keyedStateBackend).dispose();
 
 		assertEquals(ExecutionState.FINISHED, task.getExecutionState());
 	}
@@ -240,8 +255,10 @@ public class StreamTaskTest extends TestLogger {
 		taskManagerConfig.setString(CheckpointingOptions.STATE_BACKEND, MockStateBackend.class.getName());
 
 		StreamConfig cfg = new StreamConfig(new Configuration());
+		cfg.setStateKeySerializer(mock(TypeSerializer.class));
 		cfg.setOperatorID(new OperatorID(4711L, 42L));
-		cfg.setStreamOperator(new StreamSource<>(new MockSourceFunction()));
+		TestStreamSource<Long, MockSourceFunction> streamSource = new TestStreamSource<>(new MockSourceFunction());
+		cfg.setStreamOperator(streamSource);
 		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
 
 		Task task = createTask(StateBackendTestSource.class, cfg, taskManagerConfig);
@@ -252,9 +269,14 @@ public class StreamTaskTest extends TestLogger {
 		// wait for clean termination
 		task.getExecutingThread().join();
 
-		// ensure that the state backends are closed
-		verify(StateBackendTestSource.operatorStateBackend).close();
-		verify(StateBackendTestSource.keyedStateBackend).close();
+		// ensure that the state backends and stream iterables are closed ...
+		verify(TestStreamSource.operatorStateBackend).close();
+		verify(TestStreamSource.keyedStateBackend).close();
+		verify(TestStreamSource.rawOperatorStateInputs).close();
+		verify(TestStreamSource.rawKeyedStateInputs).close();
+		// ... and disposed
+		verify(TestStreamSource.operatorStateBackend).dispose();
+		verify(TestStreamSource.keyedStateBackend).dispose();
 
 		assertEquals(ExecutionState.FAILED, task.getExecutionState());
 	}
@@ -463,6 +485,8 @@ public class StreamTaskTest extends TestLogger {
 		when(mockTaskInfo.getIndexOfThisSubtask()).thenReturn(0);
 		Environment mockEnvironment = mock(Environment.class);
 		when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo);
+
+		CheckpointResponder checkpointResponder = mock(CheckpointResponder.class);
 		doAnswer(new Answer() {
 			@Override
 			public Object answer(InvocationOnMock invocation) throws Throwable {
@@ -473,7 +497,21 @@ public class StreamTaskTest extends TestLogger {
 
 				return null;
 			}
-		}).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class));
+		}).when(checkpointResponder).acknowledgeCheckpoint(
+			any(JobID.class),
+			any(ExecutionAttemptID.class),
+			anyLong(),
+			any(CheckpointMetrics.class),
+			any(TaskStateSnapshot.class));
+
+		TaskStateManager taskStateManager = new TaskStateManagerImpl(
+			new JobID(1L, 2L),
+			new ExecutionAttemptID(1L, 2L),
+			mock(TaskLocalStateStore.class),
+			null,
+			checkpointResponder);
+
+		when(mockEnvironment.getTaskStateManager()).thenReturn(taskStateManager);
 
 		StreamTask<?, ?> streamTask = new EmptyStreamTask(mockEnvironment);
 		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp);
@@ -525,7 +563,12 @@ public class StreamTaskTest extends TestLogger {
 		ArgumentCaptor<TaskStateSnapshot> subtaskStateCaptor = ArgumentCaptor.forClass(TaskStateSnapshot.class);
 
 		// check that the checkpoint has been completed
-		verify(mockEnvironment).acknowledgeCheckpoint(eq(checkpointId), any(CheckpointMetrics.class), subtaskStateCaptor.capture());
+		verify(checkpointResponder).acknowledgeCheckpoint(
+			any(JobID.class),
+			any(ExecutionAttemptID.class),
+			eq(checkpointId),
+			any(CheckpointMetrics.class),
+			subtaskStateCaptor.capture());
 
 		TaskStateSnapshot subtaskStates = subtaskStateCaptor.getValue();
 		OperatorSubtaskState subtaskState = subtaskStates.getSubtaskStateMappings().iterator().next().getValue();
@@ -692,17 +735,30 @@ public class StreamTaskTest extends TestLogger {
 		final OneShotLatch checkpointCompletedLatch = new OneShotLatch();
 		final List<SubtaskState> checkpointResult = new ArrayList<>(1);
 
-		// we remember what is acknowledged (expected to be null as our task will snapshot empty states).
+		CheckpointResponder checkpointResponder = mock(CheckpointResponder.class);
 		doAnswer(new Answer() {
 			@Override
-			public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
-				SubtaskState subtaskState = invocationOnMock.getArgumentAt(2, SubtaskState.class);
+			public Object answer(InvocationOnMock invocation) throws Throwable {
+				SubtaskState subtaskState = invocation.getArgumentAt(4, SubtaskState.class);
 				checkpointResult.add(subtaskState);
 				checkpointCompletedLatch.trigger();
 				return null;
 			}
-		}).when(mockEnvironment).acknowledgeCheckpoint(anyLong(), any(CheckpointMetrics.class), any(TaskStateSnapshot.class));
-
+		}).when(checkpointResponder).acknowledgeCheckpoint(
+			any(JobID.class),
+			any(ExecutionAttemptID.class),
+			anyLong(),
+			any(CheckpointMetrics.class),
+			any(TaskStateSnapshot.class));
+
+		TaskStateManager taskStateManager = new TaskStateManagerImpl(
+			new JobID(1L, 2L),
+			new ExecutionAttemptID(1L, 2L),
+			mock(TaskLocalStateStore.class),
+			null,
+			checkpointResponder);
+
+		when(mockEnvironment.getTaskStateManager()).thenReturn(taskStateManager);
 		when(mockEnvironment.getTaskInfo()).thenReturn(mockTaskInfo);
 
 		StreamTask<?, ?> streamTask = new EmptyStreamTask(mockEnvironment);
@@ -759,7 +815,8 @@ public class StreamTaskTest extends TestLogger {
 				new MockInputSplitProvider(),
 				1,
 				taskConfiguration,
-				new ExecutionConfig())) {
+				new ExecutionConfig(),
+				new TestTaskStateManager())) {
 			StreamTask<Void, BlockingCloseStreamOperator> streamTask = new NoOpStreamTask<>(mockEnvironment);
 			final AtomicReference<Throwable> atomicThrowable = new AtomicReference<>(null);
 
@@ -922,11 +979,11 @@ public class StreamTaskTest extends TestLogger {
 			Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 			Collections.<InputGateDeploymentDescriptor>emptyList(),
 			0,
-			new JobManagerTaskRestore(1L, null),
 			mock(MemoryManager.class),
 			mock(IOManager.class),
 			network,
 			mock(BroadcastVariableManager.class),
+			new TestTaskStateManager(),
 			mock(TaskManagerActions.class),
 			mock(InputSplitProvider.class),
 			mock(CheckpointResponder.class),
@@ -1001,40 +1058,17 @@ public class StreamTaskTest extends TestLogger {
 
 		@Override
 		public AbstractStateBackend createFromConfig(Configuration config) {
-			AbstractStateBackend stateBackendMock = mock(AbstractStateBackend.class);
-
-			try {
-				Mockito.when(stateBackendMock.createOperatorStateBackend(
-						Mockito.any(Environment.class),
-						Mockito.any(String.class)))
-					.thenAnswer(new Answer<OperatorStateBackend>() {
-						@Override
-						public OperatorStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
-							return Mockito.mock(OperatorStateBackend.class);
-						}
-					});
-
-				Mockito.when(stateBackendMock.createKeyedStateBackend(
-						Mockito.any(Environment.class),
-						Mockito.any(JobID.class),
-						Mockito.any(String.class),
-						Mockito.any(TypeSerializer.class),
-						Mockito.any(int.class),
-						Mockito.any(KeyGroupRange.class),
-						Mockito.any(TaskKvStateRegistry.class)))
-					.thenAnswer(new Answer<AbstractKeyedStateBackend>() {
-						@Override
-						public AbstractKeyedStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
-							return Mockito.mock(AbstractKeyedStateBackend.class);
-						}
-					});
-			}
-			catch (Exception e) {
-				// this is needed, because the signatures of the mocked methods throw 'Exception'
-				throw new RuntimeException(e);
-			}
+			return new MemoryStateBackend() {
+				@Override
+				public OperatorStateBackend createOperatorStateBackend(Environment env, String operatorIdentifier) throws Exception {
+					return spy(super.createOperatorStateBackend(env, operatorIdentifier));
+				}
 
-			return stateBackendMock;
+				@Override
+				public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(Environment env, JobID jobID, String operatorIdentifier, TypeSerializer<K> keySerializer, int numberOfKeyGroups, KeyGroupRange keyGroupRange, TaskKvStateRegistry kvStateRegistry) {
+					return spy(super.createKeyedStateBackend(env, jobID, operatorIdentifier, keySerializer, numberOfKeyGroups, keyGroupRange, kvStateRegistry));
+				}
+			};
 		}
 	}
 
@@ -1069,22 +1103,13 @@ public class StreamTaskTest extends TestLogger {
 
 		private static volatile boolean fail;
 
-		private static volatile OperatorStateBackend operatorStateBackend;
-		private static volatile AbstractKeyedStateBackend keyedStateBackend;
-
-		public StateBackendTestSource(Environment env, @Nullable TaskStateSnapshot initialState) {
-			super(env, initialState);
+		public StateBackendTestSource(Environment env) {
+			super(env);
 		}
 
 		@Override
 		protected void init() throws Exception {
-			operatorStateBackend = createOperatorStateBackend(
-				Mockito.mock(StreamOperator.class),
-				null);
-			keyedStateBackend = createKeyedStateBackend(
-				Mockito.mock(TypeSerializer.class),
-				4,
-				Mockito.mock(KeyGroupRange.class));
+
 		}
 
 		@Override
@@ -1100,6 +1125,69 @@ public class StreamTaskTest extends TestLogger {
 		@Override
 		protected void cancelTask() throws Exception {}
 
+		@Override
+		public StreamTaskStateInitializer createStreamTaskStateInitializer() {
+			final StreamTaskStateInitializer streamTaskStateManager = super.createStreamTaskStateInitializer();
+			return (operatorID, operatorClassName, keyContext, keySerializer, closeableRegistry) -> {
+
+				final StreamOperatorStateContext context = streamTaskStateManager.streamOperatorStateContext(
+					operatorID,
+					operatorClassName,
+					keyContext,
+					keySerializer,
+					closeableRegistry);
+
+				return new StreamOperatorStateContext() {
+					@Override
+					public boolean isRestored() {
+						return context.isRestored();
+					}
+
+					@Override
+					public OperatorStateBackend operatorStateBackend() {
+						return context.operatorStateBackend();
+					}
+
+					@Override
+					public AbstractKeyedStateBackend<?> keyedStateBackend() {
+						return context.keyedStateBackend();
+					}
+
+					@Override
+					public InternalTimeServiceManager<?, ?> internalTimerServiceManager() {
+						InternalTimeServiceManager<?, ?> timeServiceManager = context.internalTimerServiceManager();
+						return timeServiceManager != null ? spy(timeServiceManager) : null;
+					}
+
+					@Override
+					public CheckpointStreamFactory checkpointStreamFactory() {
+						return context.checkpointStreamFactory();
+					}
+
+					@Override
+					public CloseableIterable<StatePartitionStreamProvider> rawOperatorStateInputs() {
+						return replaceWithSpy(context.rawOperatorStateInputs());
+					}
+
+					@Override
+					public CloseableIterable<KeyGroupStatePartitionStreamProvider> rawKeyedStateInputs() {
+						return replaceWithSpy(context.rawKeyedStateInputs());
+					}
+
+					public <T extends Closeable> T replaceWithSpy(T closeable) {
+						T spyCloseable = spy(closeable);
+						if (closeableRegistry.unregisterCloseable(closeable)) {
+							try {
+								closeableRegistry.registerCloseable(spyCloseable);
+							} catch (IOException e) {
+								throw new RuntimeException(e);
+							}
+						}
+						return spyCloseable;
+					}
+				};
+			};
+		}
 	}
 
 	/**
@@ -1111,8 +1199,8 @@ public class StreamTaskTest extends TestLogger {
 
 		private LockHolder holder;
 
-		public CancelLockingTask(Environment env, @Nullable TaskStateSnapshot initialState) {
-			super(env, initialState);
+		public CancelLockingTask(Environment env) {
+			super(env);
 		}
 
 		@Override
@@ -1155,8 +1243,8 @@ public class StreamTaskTest extends TestLogger {
 	 */
 	public static class CancelFailingTask extends StreamTask<String, AbstractStreamOperator<String>> {
 
-		public CancelFailingTask(Environment env, @Nullable TaskStateSnapshot initialState) {
-			super(env, initialState);
+		public CancelFailingTask(Environment env) {
+			super(env);
 		}
 
 		@Override
@@ -1245,4 +1333,39 @@ public class StreamTaskTest extends TestLogger {
 			interrupt();
 		}
 	}
+
+	/**
+	 * A {@link StreamSource} that, while initializing its state, stores the keyed and operator
+	 * state backends and the raw operator/keyed state inputs in static fields.
+	 *
+	 * <p>The fields are static, so every instance that initializes state overwrites whatever
+	 * an earlier instance stored there.
+	 */
+	static class TestStreamSource<OUT, SRC extends SourceFunction<OUT>> extends StreamSource<OUT, SRC> {
+
+		// Static holders assigned in initializeState(...). Declared volatile so that a reader
+		// on another thread sees the latest assignment.
+		// NOTE(review): the readers are not visible from this class; presumably they are test
+		// methods running on a different thread than the task - confirm against the callers.
+		static volatile AbstractKeyedStateBackend<?> keyedStateBackend;
+		static volatile OperatorStateBackend operatorStateBackend;
+		static volatile CloseableIterable<StatePartitionStreamProvider> rawOperatorStateInputs;
+		static volatile CloseableIterable<KeyGroupStatePartitionStreamProvider> rawKeyedStateInputs;
+
+		public TestStreamSource(SRC sourceFunction) {
+			super(sourceFunction);
+		}
+
+		@Override
+		public void initializeState(StateInitializationContext context) throws Exception {
+			// Record what this operator was handed before delegating to the parent class.
+			keyedStateBackend = (AbstractKeyedStateBackend<?>) getKeyedStateBackend();
+			operatorStateBackend = getOperatorStateBackend();
+			rawOperatorStateInputs =
+				(CloseableIterable<StatePartitionStreamProvider>) context.getRawOperatorStateInputs();
+			rawKeyedStateInputs =
+				(CloseableIterable<KeyGroupStatePartitionStreamProvider>) context.getRawKeyedStateInputs();
+			super.initializeState(context);
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
index 6573ecd..e535bed 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
 import org.apache.flink.streaming.api.graph.StreamConfig;
@@ -46,14 +47,12 @@ import org.apache.flink.util.Preconditions;
 
 import org.junit.Assert;
 
-import javax.annotation.Nullable;
-
 import java.io.IOException;
 import java.util.Collections;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.concurrent.LinkedBlockingQueue;
-import java.util.function.BiFunction;
+import java.util.function.Function;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
@@ -74,7 +73,7 @@ public class StreamTaskTestHarness<OUT> {
 
 	public static final int DEFAULT_NETWORK_BUFFER_SIZE = 1024;
 
-	private final BiFunction<Environment, TaskStateSnapshot, ? extends StreamTask<OUT, ?>> taskFactory;
+	private final Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory;
 
 	public long memorySize = 0;
 	public int bufferSize = 0;
@@ -85,6 +84,8 @@ public class StreamTaskTestHarness<OUT> {
 	public Configuration taskConfig;
 	protected StreamConfig streamConfig;
 
+	protected TestTaskStateManager taskStateManager;
+
 	private StreamTask<OUT, ?> task;
 
 	private TypeSerializer<OUT> outputSerializer;
@@ -106,7 +107,7 @@ public class StreamTaskTestHarness<OUT> {
 	protected StreamTestSingleInputGate[] inputGates;
 
 	public StreamTaskTestHarness(
-			BiFunction<Environment, TaskStateSnapshot, ? extends StreamTask<OUT, ?>> taskFactory,
+			Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory,
 			TypeInformation<OUT> outputType) {
 
 		this.taskFactory = checkNotNull(taskFactory);
@@ -121,6 +122,8 @@ public class StreamTaskTestHarness<OUT> {
 
 		outputSerializer = outputType.createSerializer(executionConfig);
 		outputStreamRecordSerializer = new StreamElementSerializer<OUT>(outputSerializer);
+
+		this.taskStateManager = new TestTaskStateManager();
 	}
 
 	public ProcessingTimeService getProcessingTimeService() {
@@ -132,6 +135,16 @@ public class StreamTaskTestHarness<OUT> {
 	 */
 	protected void initializeInputs() throws IOException, InterruptedException {}
 
+	public TestTaskStateManager getTaskStateManager() {
+		return taskStateManager;
+	}
+
+	public void setTaskStateSnapshot(long checkpointId, TaskStateSnapshot taskStateSnapshot) {
+		taskStateManager.setReportedCheckpointId(checkpointId);
+		taskStateManager.setTaskStateSnapshotsByCheckpointId(
+			Collections.singletonMap(checkpointId, taskStateSnapshot));
+	}
+
 	@SuppressWarnings("unchecked")
 	private void initializeOutput() {
 		outputList = new LinkedBlockingQueue<Object>();
@@ -174,16 +187,13 @@ public class StreamTaskTestHarness<OUT> {
 
 	public StreamMockEnvironment createEnvironment() {
 		return new StreamMockEnvironment(
-				jobConfig, taskConfig, executionConfig, memorySize, new MockInputSplitProvider(), bufferSize);
-	}
-
-	/**
-	 * Invoke the Task. This resets the output of any previous invocation. This will start a new
-	 * Thread to execute the Task in. Use {@link #waitForTaskCompletion()} to wait for the
-	 * Task thread to finish running.
-	 */
-	public void invoke() throws Exception {
-		invoke(createEnvironment(), null);
+			jobConfig,
+			taskConfig,
+			executionConfig,
+			memorySize,
+			new MockInputSplitProvider(),
+			bufferSize,
+			taskStateManager);
 	}
 
 	/**
@@ -191,10 +201,9 @@ public class StreamTaskTestHarness<OUT> {
 	 * Thread to execute the Task in. Use {@link #waitForTaskCompletion()} to wait for the
 	 * Task thread to finish running.
 	 *
-	 * <p>Variant for providing initial task state.
 	 */
-	public void invoke(TaskStateSnapshot initialState) throws Exception {
-		invoke(createEnvironment(), initialState);
+	public void invoke() throws Exception {
+		invoke(createEnvironment());
 	}
 
 	/**
@@ -202,22 +211,10 @@ public class StreamTaskTestHarness<OUT> {
 	 * Thread to execute the Task in. Use {@link #waitForTaskCompletion()} to wait for the
 	 * Task thread to finish running.
 	 *
-	 * <p>Variant for providing a custom environment but no initial state.
 	 */
 	public void invoke(StreamMockEnvironment mockEnv) throws Exception {
-		invoke(mockEnv, null);
-	}
-
-	/**
-	 * Invoke the Task. This resets the output of any previous invocation. This will start a new
-	 * Thread to execute the Task in. Use {@link #waitForTaskCompletion()} to wait for the
-	 * Task thread to finish running.
-	 *
-	 * <p>Variant for providing a custom environment and initial task state.
-	 */
-	public void invoke(StreamMockEnvironment mockEnv, @Nullable TaskStateSnapshot initialState) throws Exception {
 		this.mockEnv = checkNotNull(mockEnv);
-		this.task = taskFactory.apply(mockEnv, initialState);
+		this.task = taskFactory.apply(mockEnv);
 
 		initializeInputs();
 		initializeOutput();

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
index b04a72a..243da62 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
@@ -63,6 +63,7 @@ import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
@@ -78,8 +79,6 @@ import org.apache.flink.util.TestLogger;
 import org.junit.Assert;
 import org.junit.Test;
 
-import javax.annotation.Nullable;
-
 import java.io.IOException;
 import java.util.Collections;
 import java.util.concurrent.Callable;
@@ -231,11 +230,11 @@ public class TaskCheckpointingBehaviourTest extends TestLogger {
 				Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 				Collections.<InputGateDeploymentDescriptor>emptyList(),
 				0,
-				null,
 				mock(MemoryManager.class),
 				mock(IOManager.class),
 				network,
 				mock(BroadcastVariableManager.class),
+				new TestTaskStateManager(),
 				mock(TaskManagerActions.class),
 				mock(InputSplitProvider.class),
 				checkpointResponder,
@@ -478,8 +477,8 @@ public class TaskCheckpointingBehaviourTest extends TestLogger {
 	 */
 	public static final class TestStreamTask extends OneInputStreamTask<Object, Object> {
 
-		public TestStreamTask(Environment env, @Nullable TaskStateSnapshot initialState) {
-			super(env, initialState);
+		public TestStreamTask(Environment env) {
+			super(env);
 		}
 
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTestHarness.java
index 155e45b..78e6de2 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTestHarness.java
@@ -20,7 +20,6 @@ package org.apache.flink.streaming.runtime.tasks;
 
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
@@ -33,7 +32,7 @@ import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
 import java.io.IOException;
 import java.util.LinkedList;
 import java.util.List;
-import java.util.function.BiFunction;
+import java.util.function.Function;
 
 
 /**
@@ -72,7 +71,7 @@ public class TwoInputStreamTaskTestHarness<IN1, IN2, OUT> extends StreamTaskTest
 	 * it should be assigned to the first (1), or second (2) input of the task.
 	 */
 	public TwoInputStreamTaskTestHarness(
-			BiFunction<Environment, TaskStateSnapshot, ? extends TwoInputStreamTask<IN1, IN2, OUT>> taskFactory,
+			Function<Environment, ? extends TwoInputStreamTask<IN1, IN2, OUT>> taskFactory,
 			int numInputGates,
 			int numInputChannelsPerGate,
 			int[] inputGateAssignment,
@@ -99,7 +98,7 @@ public class TwoInputStreamTaskTestHarness<IN1, IN2, OUT> extends StreamTaskTest
 	 * second task input.
 	 */
 	public TwoInputStreamTaskTestHarness(
-			BiFunction<Environment, TaskStateSnapshot, ? extends TwoInputStreamTask<IN1, IN2, OUT>> taskFactory,
+			Function<Environment, ? extends TwoInputStreamTask<IN1, IN2, OUT>> taskFactory,
 			TypeInformation<IN1> inputType1,
 			TypeInformation<IN2> inputType2,
 			TypeInformation<OUT> outputType) {