You are viewing a plain-text version of this content. The canonical link for it (the "here" hyperlink on the original archive page) is not preserved in this text version.
Posted to commits@flink.apache.org by al...@apache.org on 2017/01/12 12:23:34 UTC

[2/2] flink git commit: [FLINK-5421] Add explicit restore() method in Snapshotable

[FLINK-5421] Add explicit restore() method in Snapshotable


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/aaf8e09d
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/aaf8e09d
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/aaf8e09d

Branch: refs/heads/master
Commit: aaf8e09d8f9ee7f04cb79d317e3122282153858c
Parents: 3f7f172
Author: Stefan Richter <s....@data-artisans.com>
Authored: Thu Jan 5 23:45:13 2017 +0100
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Thu Jan 12 11:26:38 2017 +0100

----------------------------------------------------------------------
 .../state/RocksDBKeyedStateBackend.java         |  67 +++-----
 .../streaming/state/RocksDBStateBackend.java    |  37 -----
 .../runtime/state/AbstractStateBackend.java     |  32 +---
 .../state/DefaultOperatorStateBackend.java      | 151 ++++++++-----------
 .../flink/runtime/state/Snapshotable.java       |   9 ++
 .../state/StateInitializationContextImpl.java   |  66 +++++---
 .../state/filesystem/FsStateBackend.java        |  21 ---
 .../state/heap/HeapKeyedStateBackend.java       |  37 ++---
 .../state/memory/MemoryStateBackend.java        |  22 ---
 .../runtime/state/OperatorStateBackendTest.java |  11 +-
 .../runtime/state/StateBackendTestBase.java     |  21 ++-
 .../streaming/runtime/tasks/StreamTask.java     |  48 +++---
 .../runtime/tasks/BlockingCheckpointsTest.java  |  13 --
 .../tasks/InterruptSensitiveRestoreTest.java    | 123 ++++++++++++++-
 .../util/AbstractStreamOperatorTestHarness.java |  15 +-
 .../KeyedOneInputStreamOperatorTestHarness.java |  37 ++---
 .../KeyedTwoInputStreamOperatorTestHarness.java |  33 ++--
 .../streaming/runtime/StateBackendITCase.java   |  15 --
 18 files changed, 353 insertions(+), 405 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 1c0a4b7..71e2c79 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -180,51 +180,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		kvStateInformation = new HashMap<>();
 	}
 
-	public RocksDBKeyedStateBackend(
-			JobID jobId,
-			String operatorIdentifier,
-			ClassLoader userCodeClassLoader,
-			File instanceBasePath,
-			DBOptions dbOptions,
-			ColumnFamilyOptions columnFamilyOptions,
-			TaskKvStateRegistry kvStateRegistry,
-			TypeSerializer<K> keySerializer,
-			int numberOfKeyGroups,
-			KeyGroupRange keyGroupRange,
-			Collection<KeyGroupsStateHandle> restoreState
-	) throws Exception {
-
-		this(jobId,
-			operatorIdentifier,
-			userCodeClassLoader,
-			instanceBasePath,
-			dbOptions,
-			columnFamilyOptions,
-			kvStateRegistry,
-			keySerializer,
-			numberOfKeyGroups,
-			keyGroupRange);
-
-		LOG.info("Initializing RocksDB keyed state backend from snapshot.");
-
-		if (LOG.isDebugEnabled()) {
-			LOG.debug("Restoring snapshot from state handles: {}.", restoreState);
-		}
-
-		try {
-			if (MigrationUtil.isOldSavepointKeyedState(restoreState)) {
-				LOG.info("Converting RocksDB state from old savepoint.");
-				restoreOldSavepointKeyedState(restoreState);
-			} else {
-				RocksDBRestoreOperation restoreOperation = new RocksDBRestoreOperation(this);
-				restoreOperation.doRestore(restoreState);
-			}
-		} catch (Exception ex) {
-			dispose();
-			throw ex;
-		}
-	}
-
 	/**
 	 * Should only be called by one thread, and only after all accesses to the DB happened.
 	 */
@@ -631,6 +586,28 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 	}
 
+	@Override
+	public void restore(Collection<KeyGroupsStateHandle> restoreState) throws Exception {
+		LOG.info("Initializing RocksDB keyed state backend from snapshot.");
+
+		if (LOG.isDebugEnabled()) {
+			LOG.debug("Restoring snapshot from state handles: {}.", restoreState);
+		}
+
+		try {
+			if (MigrationUtil.isOldSavepointKeyedState(restoreState)) {
+				LOG.info("Converting RocksDB state from old savepoint.");
+				restoreOldSavepointKeyedState(restoreState);
+			} else {
+				RocksDBRestoreOperation restoreOperation = new RocksDBRestoreOperation(this);
+				restoreOperation.doRestore(restoreState);
+			}
+		} catch (Exception ex) {
+			dispose();
+			throw ex;
+		}
+	}
+
 	/**
 	 * Encapsulates the process of restoring a RocksDBKeyedStateBackend from a snapshot.
 	 */

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index c2e33d4..1e5620f 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -28,13 +28,10 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.util.AbstractID;
-
 import org.rocksdb.ColumnFamilyOptions;
 import org.rocksdb.DBOptions;
-
 import org.rocksdb.NativeLibraryLoader;
 import org.rocksdb.RocksDB;
 import org.slf4j.Logger;
@@ -46,7 +43,6 @@ import java.lang.reflect.Field;
 import java.net.URI;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.List;
 import java.util.Random;
 import java.util.UUID;
@@ -262,39 +258,6 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 				keyGroupRange);
 	}
 
-	@Override
-	public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
-			Environment env,
-			JobID jobID,
-			String operatorIdentifier,
-			TypeSerializer<K> keySerializer,
-			int numberOfKeyGroups,
-			KeyGroupRange keyGroupRange,
-			Collection<KeyGroupsStateHandle> restoredState,
-			TaskKvStateRegistry kvStateRegistry) throws Exception {
-
-		// first, make sure that the RocksDB JNI library is loaded
-		// we do this explicitly here to have better error handling
-		String tempDir = env.getTaskManagerInfo().getTmpDirectories()[0];
-		ensureRocksDBIsLoaded(tempDir);
-
-		lazyInitializeForJob(env, operatorIdentifier);
-
-		File instanceBasePath = new File(getDbPath(), UUID.randomUUID().toString());
-		return new RocksDBKeyedStateBackend<>(
-				jobID,
-				operatorIdentifier,
-				env.getUserClassLoader(),
-				instanceBasePath,
-				getDbOptions(),
-				getColumnOptions(),
-				kvStateRegistry,
-				keySerializer,
-				numberOfKeyGroups,
-				keyGroupRange,
-				restoredState);
-	}
-
 	// ------------------------------------------------------------------------
 	//  Parameters
 	// ------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
index 1b53f1a..60d035a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
@@ -24,7 +24,6 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 
 import java.io.IOException;
-import java.util.Collection;
 
 /**
  * A state backend defines how state is stored and snapshotted during checkpoints.
@@ -59,41 +58,12 @@ public abstract class AbstractStateBackend implements java.io.Serializable {
 	) throws Exception;
 
 	/**
-	 * Creates a new {@link AbstractKeyedStateBackend} that restores its state from the given list
-	 * {@link KeyGroupsStateHandle KeyGroupStateHandles}.
-	 */
-	public abstract <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
-			Environment env,
-			JobID jobID,
-			String operatorIdentifier,
-			TypeSerializer<K> keySerializer,
-			int numberOfKeyGroups,
-			KeyGroupRange keyGroupRange,
-			Collection<KeyGroupsStateHandle> restoredState,
-			TaskKvStateRegistry kvStateRegistry
-	) throws Exception;
-
-
-	/**
 	 * Creates a new {@link OperatorStateBackend} that can be used for storing partitionable operator
 	 * state in checkpoint streams.
 	 */
 	public OperatorStateBackend createOperatorStateBackend(
 			Environment env,
-			String operatorIdentifier
-	) throws Exception {
+			String operatorIdentifier) throws Exception {
 		return new DefaultOperatorStateBackend(env.getUserClassLoader());
 	}
-
-	/**
-	 * Creates a new {@link OperatorStateBackend} that restores its state from the given collection of
-	 * {@link OperatorStateHandle}.
-	 */
-	public OperatorStateBackend restoreOperatorStateBackend(
-			Environment env,
-			String operatorIdentifier,
-			Collection<OperatorStateHandle> restoreSnapshots
-	) throws Exception {
-		return new DefaultOperatorStateBackend(env.getUserClassLoader(), restoreSnapshots);
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
index d7a10d5..10bb409 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -50,33 +50,16 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 	public static final String DEFAULT_OPERATOR_STATE_NAME = "_default_";
 	
 	private final Map<String, PartitionableListState<?>> registeredStates;
-	private final Collection<OperatorStateHandle> restoreSnapshots;
 	private final CloseableRegistry closeStreamOnCancelRegistry;
 	private final JavaSerializer<Serializable> javaSerializer;
 	private final ClassLoader userClassloader;
 
-	/**
-	 * Restores a OperatorStateStore (lazily) using the provided snapshots.
-	 *
-	 * @param restoreSnapshots snapshots that are available to restore partitionable states on request.
-	 */
-	public DefaultOperatorStateBackend(
-			ClassLoader userClassLoader,
-			Collection<OperatorStateHandle> restoreSnapshots) throws IOException {
+	public DefaultOperatorStateBackend(ClassLoader userClassLoader) throws IOException {
 
+		this.closeStreamOnCancelRegistry = new CloseableRegistry();
 		this.userClassloader = Preconditions.checkNotNull(userClassLoader);
 		this.javaSerializer = new JavaSerializer<>();
 		this.registeredStates = new HashMap<>();
-		this.closeStreamOnCancelRegistry = new CloseableRegistry();
-		this.restoreSnapshots = restoreSnapshots;
-		restoreState();
-	}
-
-	/**
-	 * Creates an empty OperatorStateStore.
-	 */
-	public DefaultOperatorStateBackend(ClassLoader userClassLoader) throws IOException {
-		this(userClassLoader, null);
 	}
 
 	@SuppressWarnings("unchecked")
@@ -111,69 +94,6 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 		return partitionableListState;
 	}
 
-	private void restoreState() throws IOException {
-
-		if (null == restoreSnapshots) {
-			return;
-		}
-
-		for (OperatorStateHandle stateHandle : restoreSnapshots) {
-
-			if (stateHandle == null) {
-				continue;
-			}
-
-			FSDataInputStream in = stateHandle.openInputStream();
-			closeStreamOnCancelRegistry.registerClosable(in);
-
-			ClassLoader restoreClassLoader = Thread.currentThread().getContextClassLoader();
-
-			try {
-				Thread.currentThread().setContextClassLoader(userClassloader);
-				OperatorBackendSerializationProxy backendSerializationProxy =
-						new OperatorBackendSerializationProxy(userClassloader);
-
-				backendSerializationProxy.read(new DataInputViewStreamWrapper(in));
-
-				List<OperatorBackendSerializationProxy.StateMetaInfo<?>> metaInfoList =
-						backendSerializationProxy.getNamedStateSerializationProxies();
-
-				// Recreate all PartitionableListStates from the meta info
-				for (OperatorBackendSerializationProxy.StateMetaInfo<?> stateMetaInfo : metaInfoList) {
-					PartitionableListState<?> listState = registeredStates.get(stateMetaInfo.getName());
-
-					if (null == listState) {
-						listState = new PartitionableListState<>(
-								stateMetaInfo.getName(),
-								stateMetaInfo.getStateSerializer());
-
-						registeredStates.put(listState.getName(), listState);
-					} else {
-						Preconditions.checkState(listState.getPartitionStateSerializer().isCompatibleWith(
-								stateMetaInfo.getStateSerializer()), "Incompatible state serializers found: " +
-								listState.getPartitionStateSerializer() + " is not compatible with " +
-								stateMetaInfo.getStateSerializer());
-					}
-				}
-
-				// Restore all the state in PartitionableListStates
-				for (Map.Entry<String, long[]> nameToOffsets : stateHandle.getStateNameToPartitionOffsets().entrySet()) {
-					PartitionableListState<?> stateListForName = registeredStates.get(nameToOffsets.getKey());
-
-					Preconditions.checkState(null != stateListForName, "Found state without " +
-							"corresponding meta info: " + nameToOffsets.getKey());
-
-					deserializeStateValues(stateListForName, in, nameToOffsets.getValue());
-				}
-
-			} finally {
-				Thread.currentThread().setContextClassLoader(restoreClassLoader);
-				closeStreamOnCancelRegistry.unregisterClosable(in);
-				IOUtils.closeQuietly(in);
-			}
-		}
-	}
-
 	private static <S> void deserializeStateValues(
 			PartitionableListState<S> stateListForName,
 			FSDataInputStream in,
@@ -239,6 +159,70 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 	}
 
 	@Override
+	public void restore(Collection<OperatorStateHandle> restoreSnapshots) throws Exception {
+
+		if (null == restoreSnapshots) {
+			return;
+		}
+
+		for (OperatorStateHandle stateHandle : restoreSnapshots) {
+
+			if (stateHandle == null) {
+				continue;
+			}
+
+			FSDataInputStream in = stateHandle.openInputStream();
+			closeStreamOnCancelRegistry.registerClosable(in);
+
+			ClassLoader restoreClassLoader = Thread.currentThread().getContextClassLoader();
+
+			try {
+				Thread.currentThread().setContextClassLoader(userClassloader);
+				OperatorBackendSerializationProxy backendSerializationProxy =
+						new OperatorBackendSerializationProxy(userClassloader);
+
+				backendSerializationProxy.read(new DataInputViewStreamWrapper(in));
+
+				List<OperatorBackendSerializationProxy.StateMetaInfo<?>> metaInfoList =
+						backendSerializationProxy.getNamedStateSerializationProxies();
+
+				// Recreate all PartitionableListStates from the meta info
+				for (OperatorBackendSerializationProxy.StateMetaInfo<?> stateMetaInfo : metaInfoList) {
+					PartitionableListState<?> listState = registeredStates.get(stateMetaInfo.getName());
+
+					if (null == listState) {
+						listState = new PartitionableListState<>(
+								stateMetaInfo.getName(),
+								stateMetaInfo.getStateSerializer());
+
+						registeredStates.put(listState.getName(), listState);
+					} else {
+						Preconditions.checkState(listState.getPartitionStateSerializer().isCompatibleWith(
+								stateMetaInfo.getStateSerializer()), "Incompatible state serializers found: " +
+								listState.getPartitionStateSerializer() + " is not compatible with " +
+								stateMetaInfo.getStateSerializer());
+					}
+				}
+
+				// Restore all the state in PartitionableListStates
+				for (Map.Entry<String, long[]> nameToOffsets : stateHandle.getStateNameToPartitionOffsets().entrySet()) {
+					PartitionableListState<?> stateListForName = registeredStates.get(nameToOffsets.getKey());
+
+					Preconditions.checkState(null != stateListForName, "Found state without " +
+							"corresponding meta info: " + nameToOffsets.getKey());
+
+					deserializeStateValues(stateListForName, in, nameToOffsets.getValue());
+				}
+
+			} finally {
+				Thread.currentThread().setContextClassLoader(restoreClassLoader);
+				closeStreamOnCancelRegistry.unregisterClosable(in);
+				IOUtils.closeQuietly(in);
+			}
+		}
+	}
+
+	@Override
 	public void dispose() {
 		registeredStates.clear();
 	}
@@ -314,5 +298,4 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 					'}';
 		}
 	}
-}
-
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
index 2aa282d..a4a6bc4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/Snapshotable.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state;
 
+import java.util.Collection;
 import java.util.concurrent.RunnableFuture;
 
 /**
@@ -42,4 +43,12 @@ public interface Snapshotable<S extends StateObject> {
 			long checkpointId,
 			long timestamp,
 			CheckpointStreamFactory streamFactory) throws Exception;
+
+	/**
+	 * Restores state that was previously snapshotted from the provided parameters. Typically the parameters are state
+	 * handles from which the old state is read.
+	 *
+	 * @param state the old state to restore.
+	 */
+	void restore(Collection<S> state) throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
index c86ff6c..be59a2a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
@@ -30,6 +30,7 @@ import java.io.IOException;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
+import java.util.NoSuchElementException;
 
 /**
  * Default implementation of {@link StateInitializationContext}.
@@ -155,19 +156,21 @@ public class StateInitializationContextImpl implements StateInitializationContex
 		public boolean hasNext() {
 			if (null != currentStateHandle && currentOffsetsIterator.hasNext()) {
 				return true;
-			} else {
-				while (stateHandleIterator.hasNext()) {
-					currentStateHandle = stateHandleIterator.next();
-					if (currentStateHandle.getNumberOfKeyGroups() > 0) {
-						currentOffsetsIterator = currentStateHandle.getGroupRangeOffsets().iterator();
-						closableRegistry.unregisterClosable(currentStream);
-						IOUtils.closeQuietly(currentStream);
-						currentStream = null;
-						return true;
-					}
+			}
+
+			while (stateHandleIterator.hasNext()) {
+				currentStateHandle = stateHandleIterator.next();
+				if (currentStateHandle.getNumberOfKeyGroups() > 0) {
+					currentOffsetsIterator = currentStateHandle.getGroupRangeOffsets().iterator();
+					closableRegistry.unregisterClosable(currentStream);
+					IOUtils.closeQuietly(currentStream);
+					currentStream = null;
+
+					return true;
 				}
-				return false;
 			}
+
+			return false;
 		}
 
 		private void openStream() throws IOException {
@@ -178,6 +181,11 @@ public class StateInitializationContextImpl implements StateInitializationContex
 
 		@Override
 		public KeyGroupStatePartitionStreamProvider next() {
+
+			if (!hasNext()) {
+				throw new NoSuchElementException("Iterator exhausted");
+			}
+
 			Tuple2<Integer, Long> keyGroupOffset = currentOffsetsIterator.next();
 			try {
 				if (null == currentStream) {
@@ -220,26 +228,28 @@ public class StateInitializationContextImpl implements StateInitializationContex
 
 		@Override
 		public boolean hasNext() {
-			if (null != currentStateHandle && offPos < offsets.length) {
+
+			if (null != offsets && offPos < offsets.length) {
 				return true;
-			} else {
-				while (stateHandleIterator.hasNext()) {
-					currentStateHandle = stateHandleIterator.next();
-					long[] offsets = currentStateHandle.getStateNameToPartitionOffsets().get(stateName);
-					if (null != offsets && offsets.length > 0) {
+			}
+
+			while (stateHandleIterator.hasNext()) {
+				currentStateHandle = stateHandleIterator.next();
+				long[] offsets = currentStateHandle.getStateNameToPartitionOffsets().get(stateName);
+				if (null != offsets && offsets.length > 0) {
 
-						this.offsets = offsets;
-						this.offPos = 0;
+					this.offsets = offsets;
+					this.offPos = 0;
 
-						closableRegistry.unregisterClosable(currentStream);
-						IOUtils.closeQuietly(currentStream);
-						currentStream = null;
+					closableRegistry.unregisterClosable(currentStream);
+					IOUtils.closeQuietly(currentStream);
+					currentStream = null;
 
-						return true;
-					}
+					return true;
 				}
-				return false;
 			}
+
+			return false;
 		}
 
 		private void openStream() throws IOException {
@@ -250,7 +260,13 @@ public class StateInitializationContextImpl implements StateInitializationContex
 
 		@Override
 		public StatePartitionStreamProvider next() {
+
+			if (!hasNext()) {
+				throw new NoSuchElementException("Iterator exhausted");
+			}
+
 			long offset = offsets[offPos++];
+
 			try {
 				if (null == currentStream) {
 					openStream();

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
index 4e15cd5..281dbb0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
@@ -28,7 +28,6 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -36,7 +35,6 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.net.URI;
 import java.net.URISyntaxException;
-import java.util.Collection;
 
 /**
  * The file state backend is a state backend that stores the state of streaming jobs in a file system.
@@ -192,25 +190,6 @@ public class FsStateBackend extends AbstractStateBackend {
 	}
 
 	@Override
-	public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
-			Environment env,
-			JobID jobID,
-			String operatorIdentifier,
-			TypeSerializer<K> keySerializer,
-			int numberOfKeyGroups,
-			KeyGroupRange keyGroupRange,
-			Collection<KeyGroupsStateHandle> restoredState,
-			TaskKvStateRegistry kvStateRegistry) throws Exception {
-		return new HeapKeyedStateBackend<>(
-				kvStateRegistry,
-				keySerializer,
-				env.getUserClassLoader(),
-				numberOfKeyGroups,
-				keyGroupRange,
-				restoredState);
-	}
-
-	@Override
 	public String toString() {
 		return "File State Backend @ " + basePath;
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index d07901b..d461dfd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -101,28 +101,6 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		LOG.info("Initializing heap keyed state backend with stream factory.");
 	}
 
-	public HeapKeyedStateBackend(
-			TaskKvStateRegistry kvStateRegistry,
-			TypeSerializer<K> keySerializer,
-			ClassLoader userCodeClassLoader,
-			int numberOfKeyGroups,
-			KeyGroupRange keyGroupRange,
-			Collection<KeyGroupsStateHandle> restoredState) throws Exception {
-		super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange);
-
-		LOG.info("Initializing heap keyed state backend from snapshot.");
-
-		if (LOG.isDebugEnabled()) {
-			LOG.debug("Restoring snapshot from state handles: {}.", restoredState);
-		}
-
-		if (MigrationUtil.isOldSavepointKeyedState(restoredState)) {
-			restoreOldSavepointKeyedState(restoredState);
-		} else {
-			restorePartitionedState(restoredState);
-		}
-	}
-
 	// ------------------------------------------------------------------------
 	//  state backend operations
 	// ------------------------------------------------------------------------
@@ -251,6 +229,21 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 	}
 
+	@Override
+	public void restore(Collection<KeyGroupsStateHandle> restoredState) throws Exception {
+		LOG.info("Initializing heap keyed state backend from snapshot.");
+
+		if (LOG.isDebugEnabled()) {
+			LOG.debug("Restoring snapshot from state handles: {}.", restoredState);
+		}
+
+		if (MigrationUtil.isOldSavepointKeyedState(restoredState)) {
+			restoreOldSavepointKeyedState(restoredState);
+		} else {
+			restorePartitionedState(restoredState);
+		}
+	}
+
 	private <N, S> void writeStateTableForKeyGroup(
 			DataOutputView outView,
 			StateTable<K, N, S> stateTable,

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
index 33f03ad..58a86df 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
@@ -26,11 +26,9 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 
 import java.io.IOException;
-import java.util.Collection;
 
 /**
  * A {@link AbstractStateBackend} that stores all its data and checkpoints in memory and has no
@@ -92,24 +90,4 @@ public class MemoryStateBackend extends AbstractStateBackend {
 				numberOfKeyGroups,
 				keyGroupRange);
 	}
-
-	@Override
-	public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
-			Environment env, JobID jobID,
-			String operatorIdentifier,
-			TypeSerializer<K> keySerializer,
-			int numberOfKeyGroups,
-			KeyGroupRange keyGroupRange,
-			Collection<KeyGroupsStateHandle> restoredState,
-			TaskKvStateRegistry kvStateRegistry) throws Exception {
-
-		return new HeapKeyedStateBackend<>(
-				kvStateRegistry,
-				keySerializer,
-				env.getUserClassLoader(),
-				numberOfKeyGroups,
-				keyGroupRange,
-				restoredState);
-	}
-
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
index 648d762..515011f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -45,7 +45,9 @@ public class OperatorStateBackendTest {
 	}
 
 	private OperatorStateBackend createNewOperatorStateBackend() throws Exception {
-		return abstractStateBackend.createOperatorStateBackend(createMockEnvironment(), "test-operator");
+		return abstractStateBackend.createOperatorStateBackend(
+				createMockEnvironment(),
+				"test-operator");
 	}
 
 	@Test
@@ -131,8 +133,11 @@ public class OperatorStateBackendTest {
 
 			operatorStateBackend.dispose();
 
-			operatorStateBackend = abstractStateBackend.restoreOperatorStateBackend(
-					createMockEnvironment(), "testOperator", Collections.singletonList(stateHandle));
+			operatorStateBackend = abstractStateBackend.createOperatorStateBackend(
+					createMockEnvironment(),
+					"testOperator");
+
+			operatorStateBackend.restore(Collections.singletonList(stateHandle));
 
 			assertEquals(2, operatorStateBackend.getRegisteredStateNames().size());
 

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 5655f1c..9bc4c53 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -58,7 +58,13 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.RunnableFuture;
 
 import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
@@ -101,8 +107,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				keySerializer,
 				numberOfKeyGroups,
 				keyGroupRange,
-				env.getTaskKvStateRegistry())
-;
+				env.getTaskKvStateRegistry());
 	}
 
 	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle state) throws Exception {
@@ -127,15 +132,21 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			KeyGroupRange keyGroupRange,
 			List<KeyGroupsStateHandle> state,
 			Environment env) throws Exception {
-		return getStateBackend().restoreKeyedStateBackend(
+
+		AbstractKeyedStateBackend<K> backend = getStateBackend().createKeyedStateBackend(
 				env,
 				new JobID(),
 				"test_op",
 				keySerializer,
 				numberOfKeyGroups,
 				keyGroupRange,
-				state,
 				env.getTaskKvStateRegistry());
+
+		if (null != state) {
+			backend.restore(state);
+		}
+
+		return backend;
 	}
 
 	@Test

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 675c606..1c20393 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -742,13 +742,17 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		Environment env = getEnvironment();
 		String opId = createOperatorIdentifier(op, getConfiguration().getVertexID());
 
-		OperatorStateBackend newBackend = restoreStateHandles == null ?
-				stateBackend.createOperatorStateBackend(env, opId)
-				: stateBackend.restoreOperatorStateBackend(env, opId, restoreStateHandles);
+		OperatorStateBackend operatorStateBackend = stateBackend.createOperatorStateBackend(env, opId);
 
-		cancelables.registerClosable(newBackend);
+		// let operator state backend participate in the operator lifecycle, i.e. make it responsive to cancelation
+		cancelables.registerClosable(operatorStateBackend);
 
-		return newBackend;
+		// restore if we have some old state
+		if (null != restoreStateHandles) {
+			operatorStateBackend.restore(restoreStateHandles);
+		}
+
+		return operatorStateBackend;
 	}
 
 	public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
@@ -764,29 +768,23 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 				headOperator,
 				configuration.getVertexID());
 
-		if (null != restoreStateHandles && null != restoreStateHandles.getManagedKeyedState()) {
-			keyedStateBackend = stateBackend.restoreKeyedStateBackend(
-					getEnvironment(),
-					getEnvironment().getJobID(),
-					operatorIdentifier,
-					keySerializer,
-					numberOfKeyGroups,
-					keyGroupRange,
-					restoreStateHandles.getManagedKeyedState(),
-					getEnvironment().getTaskKvStateRegistry());
-		} else {
-			keyedStateBackend = stateBackend.createKeyedStateBackend(
-					getEnvironment(),
-					getEnvironment().getJobID(),
-					operatorIdentifier,
-					keySerializer,
-					numberOfKeyGroups,
-					keyGroupRange,
-					getEnvironment().getTaskKvStateRegistry());
-		}
+		keyedStateBackend = stateBackend.createKeyedStateBackend(
+				getEnvironment(),
+				getEnvironment().getJobID(),
+				operatorIdentifier,
+				keySerializer,
+				numberOfKeyGroups,
+				keyGroupRange,
+				getEnvironment().getTaskKvStateRegistry());
 
+		// let keyed state backend participate in the operator lifecycle, i.e. make it responsive to cancelation
 		cancelables.registerClosable(keyedStateBackend);
 
+		// restore if we have some old state
+		if (null != restoreStateHandles && null != restoreStateHandles.getManagedKeyedState()) {
+			keyedStateBackend.restore(restoreStateHandles.getManagedKeyedState());
+		}
+
 		@SuppressWarnings("unchecked")
 		AbstractKeyedStateBackend<K> typedBackend = (AbstractKeyedStateBackend<K>) keyedStateBackend;
 		return typedBackend;

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
index 7becbf4..492b470 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java
@@ -52,7 +52,6 @@ import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CheckpointStreamFactory.CheckpointStateOutputStream;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -64,12 +63,10 @@ import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.StreamFilter;
 import org.apache.flink.util.SerializedValue;
-
 import org.junit.Test;
 
 import java.io.IOException;
 import java.net.URL;
-import java.util.Collection;
 import java.util.Collections;
 
 import static org.junit.Assert.assertEquals;
@@ -183,16 +180,6 @@ public class BlockingCheckpointsTest {
 
 			throw new UnsupportedOperationException();
 		}
-
-		@Override
-		public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
-				Environment env, JobID jobID, String operatorIdentifier,
-				TypeSerializer<K> keySerializer, int numberOfKeyGroups,
-				KeyGroupRange keyGroupRange, Collection<KeyGroupsStateHandle> restoredState,
-				TaskKvStateRegistry kvStateRegistry) throws Exception {
-
-			throw new UnsupportedOperationException();
-		}
 	}
 
 	private static final class LockingOutputStreamFactory implements CheckpointStreamFactory {

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index aa2492c..0206cf5 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.tasks;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.testutils.OneShotLatch;
@@ -44,8 +45,14 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
@@ -55,11 +62,11 @@ import org.apache.flink.runtime.util.EnvironmentInformation;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.util.SerializedValue;
-
 import org.junit.Test;
 
 import java.io.EOFException;
@@ -68,7 +75,9 @@ import java.io.Serializable;
 import java.net.URL;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.Executor;
 
 import static org.junit.Assert.assertEquals;
@@ -89,17 +98,61 @@ public class InterruptSensitiveRestoreTest {
 
 	private static final OneShotLatch IN_RESTORE_LATCH = new OneShotLatch();
 
+	private static final int OPERATOR_MANAGED = 0;
+	private static final int OPERATOR_RAW = 1;
+	private static final int KEYED_MANAGED = 2;
+	private static final int KEYED_RAW = 3;
+	private static final int LEGACY = 4;
+
+	@Test
+	public void testRestoreWithInterruptLegacy() throws Exception {
+		testRestoreWithInterrupt(LEGACY);
+	}
+
+	@Test
+	public void testRestoreWithInterruptOperatorManaged() throws Exception {
+		testRestoreWithInterrupt(OPERATOR_MANAGED);
+	}
+
+	@Test
+	public void testRestoreWithInterruptOperatorRaw() throws Exception {
+		testRestoreWithInterrupt(OPERATOR_RAW);
+	}
+
 	@Test
-	public void testRestoreWithInterrupt() throws Exception {
+	public void testRestoreWithInterruptKeyedManaged() throws Exception {
+		testRestoreWithInterrupt(KEYED_MANAGED);
+	}
+
+	@Test
+	public void testRestoreWithInterruptKeyedRaw() throws Exception {
+		testRestoreWithInterrupt(KEYED_RAW);
+	}
 
+	private void testRestoreWithInterrupt(int mode) throws Exception {
+
+		IN_RESTORE_LATCH.reset();
 		Configuration taskConfig = new Configuration();
 		StreamConfig cfg = new StreamConfig(taskConfig);
 		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
-		cfg.setStreamOperator(new StreamSource<>(new TestSource()));
+		switch (mode) {
+			case OPERATOR_MANAGED:
+			case OPERATOR_RAW:
+			case KEYED_MANAGED:
+			case KEYED_RAW:
+				cfg.setStateKeySerializer(IntSerializer.INSTANCE);
+				cfg.setStreamOperator(new StreamSource<>(new TestSource()));
+				break;
+			case LEGACY:
+				cfg.setStreamOperator(new StreamSource<>(new TestSourceLegacy()));
+				break;
+			default:
+				throw new IllegalArgumentException();
+		}
 
 		StreamStateHandle lockingHandle = new InterruptLockingStateHandle();
 
-		Task task = createTask(taskConfig, lockingHandle);
+		Task task = createTask(taskConfig, lockingHandle, mode);
 
 		// start the task and wait until it is in "restore"
 		task.startTaskThread();
@@ -124,18 +177,51 @@ public class InterruptSensitiveRestoreTest {
 
 	private static Task createTask(
 			Configuration taskConfig,
-			StreamStateHandle state) throws IOException {
+			StreamStateHandle state,
+			int mode) throws IOException {
 
 		NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class);
 		when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class)))
 				.thenReturn(mock(TaskKvStateRegistry.class));
 
-		ChainedStateHandle<StreamStateHandle> operatorState = new ChainedStateHandle<>(Collections.singletonList(state));
+
+		ChainedStateHandle<StreamStateHandle> operatorState = null;
 		List<KeyGroupsStateHandle> keyGroupStateFromBackend = Collections.emptyList();
 		List<KeyGroupsStateHandle> keyGroupStateFromStream = Collections.emptyList();
 		List<Collection<OperatorStateHandle>> operatorStateBackend = Collections.emptyList();
 		List<Collection<OperatorStateHandle>> operatorStateStream = Collections.emptyList();
 
+		Map<String, long[]> operatorStateMetadata = new HashMap<>(1);
+		operatorStateMetadata.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, new long[]{0});
+
+		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(new KeyGroupRange(0,0));
+
+		Collection<OperatorStateHandle> operatorStateHandles =
+				Collections.singletonList(new OperatorStateHandle(operatorStateMetadata, state));
+
+		List<KeyGroupsStateHandle> keyGroupsStateHandles =
+				Collections.singletonList(new KeyGroupsStateHandle(keyGroupRangeOffsets, state));
+
+		switch (mode) {
+			case OPERATOR_MANAGED:
+				operatorStateBackend = Collections.singletonList(operatorStateHandles);
+				break;
+			case OPERATOR_RAW:
+				operatorStateStream = Collections.singletonList(operatorStateHandles);
+				break;
+			case KEYED_MANAGED:
+				keyGroupStateFromBackend = keyGroupsStateHandles;
+				break;
+			case KEYED_RAW:
+				keyGroupStateFromStream = keyGroupsStateHandles;
+				break;
+			case LEGACY:
+				operatorState = new ChainedStateHandle<>(Collections.singletonList(state));
+				break;
+			default:
+				throw new IllegalArgumentException();
+		}
+
 		TaskStateHandles taskStateHandles = new TaskStateHandles(
 			operatorState,
 			operatorStateBackend,
@@ -258,7 +344,7 @@ public class InterruptSensitiveRestoreTest {
 
 	// ------------------------------------------------------------------------
 
-	private static class TestSource implements SourceFunction<Object>, Checkpointed<Serializable> {
+	private static class TestSourceLegacy implements SourceFunction<Object>, Checkpointed<Serializable> {
 		private static final long serialVersionUID = 1L;
 
 		@Override
@@ -280,4 +366,27 @@ public class InterruptSensitiveRestoreTest {
 			fail("should never be called");
 		}
 	}
+
+	private static class TestSource implements SourceFunction<Object>, CheckpointedFunction {
+		private static final long serialVersionUID = 1L;
+
+		@Override
+		public void run(SourceContext<Object> ctx) throws Exception {
+			fail("should never be called");
+		}
+
+		@Override
+		public void cancel() {}
+
+
+		@Override
+		public void snapshotState(FunctionSnapshotContext context) throws Exception {
+			fail("should never be called");
+		}
+
+		@Override
+		public void initializeState(FunctionInitializationContext context) throws Exception {
+			((StateInitializationContext)context).getRawOperatorStateInputs().iterator().next().getStream().read();
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
index 346d5c3..7fe4ebc 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
@@ -192,12 +192,17 @@ public class AbstractStreamOperatorTestHarness<OUT> {
 					final StreamOperator<?> operator = (StreamOperator<?>) invocationOnMock.getArguments()[0];
 					final Collection<OperatorStateHandle> stateHandles = (Collection<OperatorStateHandle>) invocationOnMock.getArguments()[1];
 					OperatorStateBackend osb;
-					if (null == stateHandles) {
-						osb = stateBackend.createOperatorStateBackend(environment, operator.getClass().getSimpleName());
-					} else {
-						osb = stateBackend.restoreOperatorStateBackend(environment, operator.getClass().getSimpleName(), stateHandles);
-					}
+
+					osb = stateBackend.createOperatorStateBackend(
+							environment,
+							operator.getClass().getSimpleName());
+
 					mockTask.getCancelables().registerClosable(osb);
+
+					if (null != stateHandles) {
+						osb.restore(stateHandles);
+					}
+
 					return osb;
 				}
 			}).when(mockTask).createOperatorStateBackend(any(StreamOperator.class), any(Collection.class));

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index 3a47a1d..4abb6e2 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -100,33 +100,24 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 					final int numberOfKeyGroups = (Integer) invocationOnMock.getArguments()[1];
 					final KeyGroupRange keyGroupRange = (KeyGroupRange) invocationOnMock.getArguments()[2];
 
-					if(keyedStateBackend != null) {
+					if (keyedStateBackend != null) {
 						keyedStateBackend.dispose();
 					}
 
-					if (restoredKeyedState == null) {
-						keyedStateBackend = stateBackend.createKeyedStateBackend(
-								mockTask.getEnvironment(),
-								new JobID(),
-								"test_op",
-								keySerializer,
-								numberOfKeyGroups,
-								keyGroupRange,
-								mockTask.getEnvironment().getTaskKvStateRegistry());
-						return keyedStateBackend;
-					} else {
-						keyedStateBackend = stateBackend.restoreKeyedStateBackend(
-								mockTask.getEnvironment(),
-								new JobID(),
-								"test_op",
-								keySerializer,
-								numberOfKeyGroups,
-								keyGroupRange,
-								restoredKeyedState,
-								mockTask.getEnvironment().getTaskKvStateRegistry());
-						restoredKeyedState = null;
-						return keyedStateBackend;
+					keyedStateBackend = stateBackend.createKeyedStateBackend(
+							mockTask.getEnvironment(),
+							new JobID(),
+							"test_op",
+							keySerializer,
+							numberOfKeyGroups,
+							keyGroupRange,
+							mockTask.getEnvironment().getTaskKvStateRegistry());
+
+					if (restoredKeyedState != null) {
+						keyedStateBackend.restore(restoredKeyedState);
 					}
+
+					return keyedStateBackend;
 				}
 			}).when(mockTask).createKeyedStateBackend(any(TypeSerializer.class), anyInt(), any(KeyGroupRange.class));
 		} catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
index 0aa91d9..8e76f70 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
@@ -94,29 +94,18 @@ public class KeyedTwoInputStreamOperatorTestHarness<K, IN1, IN2, OUT>
 						keyedStateBackend.close();
 					}
 
-					if (restoredKeyedState == null) {
-						keyedStateBackend = stateBackend.createKeyedStateBackend(
-								mockTask.getEnvironment(),
-								new JobID(),
-								"test_op",
-								keySerializer,
-								numberOfKeyGroups,
-								keyGroupRange,
-								mockTask.getEnvironment().getTaskKvStateRegistry());
-						return keyedStateBackend;
-					} else {
-						keyedStateBackend = stateBackend.restoreKeyedStateBackend(
-								mockTask.getEnvironment(),
-								new JobID(),
-								"test_op",
-								keySerializer,
-								numberOfKeyGroups,
-								keyGroupRange,
-								restoredKeyedState,
-								mockTask.getEnvironment().getTaskKvStateRegistry());
-						restoredKeyedState = null;
-						return keyedStateBackend;
+					keyedStateBackend = stateBackend.createKeyedStateBackend(
+							mockTask.getEnvironment(),
+							new JobID(),
+							"test_op",
+							keySerializer,
+							numberOfKeyGroups,
+							keyGroupRange,
+							mockTask.getEnvironment().getTaskKvStateRegistry());
+					if (restoredKeyedState != null) {
+						keyedStateBackend.restore(restoredKeyedState);
 					}
+					return keyedStateBackend;
 				}
 			}).when(mockTask).createKeyedStateBackend(any(TypeSerializer.class), anyInt(), any(KeyGroupRange.class));
 		} catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/flink/blob/aaf8e09d/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
index 963d18a..0e62fbb 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
@@ -32,13 +32,11 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase;
 import org.junit.Test;
 
 import java.io.IOException;
-import java.util.Collection;
 
 import static org.junit.Assert.fail;
 
@@ -110,19 +108,6 @@ public class StateBackendITCase extends StreamingMultipleProgramsTestBase {
 				TaskKvStateRegistry kvStateRegistry) throws Exception {
 			throw new SuccessException();
 		}
-
-		@Override
-		public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
-				Environment env,
-				JobID jobID,
-				String operatorIdentifier,
-				TypeSerializer<K> keySerializer,
-				int numberOfKeyGroups,
-				KeyGroupRange keyGroupRange,
-				Collection<KeyGroupsStateHandle> restoredState,
-				TaskKvStateRegistry kvStateRegistry) throws Exception {
-			throw new SuccessException();
-		}
 	}
 
 	static final class SuccessException extends IOException {