You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@flink.apache.org by al...@apache.org on 2016/08/31 17:28:26 UTC

[08/27] flink git commit: [FLINK-3761] Refactor State Backends/Make Keyed State Key-Group Aware

http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
index d653f73..30d91b6 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
@@ -31,18 +31,23 @@ import io.netty.channel.socket.SocketChannel;
 import io.netty.channel.socket.nio.NioSocketChannel;
 import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.query.KvStateID;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.KvStateServerAddress;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestResult;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.AbstractStateBackend;
+import org.apache.flink.runtime.state.HashKeyGroupAssigner;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
-import org.apache.flink.runtime.state.memory.MemValueState;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.junit.AfterClass;
 import org.junit.Test;
 
@@ -52,6 +57,7 @@ import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 
 public class KvStateServerTest {
 
@@ -84,26 +90,37 @@ public class KvStateServerTest {
 
 			KvStateServerAddress serverAddress = server.getAddress();
 
-			// Register state
-			MemValueState<Integer, VoidNamespace, Integer> kvState = new MemValueState<>(
+			AbstractStateBackend abstractBackend = new MemoryStateBackend();
+			DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
+			dummyEnv.setKvStateRegistry(registry);
+			KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+					dummyEnv,
+					new JobID(),
+					"test_op",
 					IntSerializer.INSTANCE,
-					VoidNamespaceSerializer.INSTANCE,
-					new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null));
+					new HashKeyGroupAssigner<Integer>(1),
+					new KeyGroupRange(0, 0),
+					registry.createTaskRegistry(new JobID(), new JobVertexID()));
 
-			KvStateID kvStateId = registry.registerKvState(
-					new JobID(),
-					new JobVertexID(),
-					0,
-					"vanilla",
-					kvState);
+			final KvStateServerHandlerTest.TestRegistryListener registryListener =
+					new KvStateServerHandlerTest.TestRegistryListener();
+
+			registry.registerListener(registryListener);
+
+			ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null);
+			desc.setQueryable("vanilla");
+
+			ValueState<Integer> state = backend.getPartitionedState(
+					VoidNamespace.INSTANCE,
+					VoidNamespaceSerializer.INSTANCE,
+					desc);
 
 			// Update KvState
 			int expectedValue = 712828289;
 
 			int key = 99812822;
-			kvState.setCurrentKey(key);
-			kvState.setCurrentNamespace(VoidNamespace.INSTANCE);
-			kvState.update(expectedValue);
+			backend.setCurrentKey(key);
+			state.update(expectedValue);
 
 			// Request
 			byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
@@ -128,10 +145,12 @@ public class KvStateServerTest {
 					.sync().channel();
 
 			long requestId = Integer.MAX_VALUE + 182828L;
+
+			assertTrue(registryListener.registrationName.equals("vanilla"));
 			ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest(
 					channel.alloc(),
 					requestId,
-					kvStateId,
+					registryListener.kvStateId,
 					serializedKeyAndNamespace);
 
 			channel.writeAndFlush(request);

http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
index 04fa089..bc0b9c3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.state;
 
 import org.apache.commons.io.FileUtils;
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.core.fs.Path;
@@ -29,7 +30,9 @@ import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
 
 import java.io.File;
 import java.io.IOException;
@@ -42,17 +45,13 @@ import static org.junit.Assert.*;
 
 public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 
-	private File stateDir;
+	@Rule
+	public TemporaryFolder tempFolder = new TemporaryFolder();
 
 	@Override
 	protected FsStateBackend getStateBackend() throws Exception {
-		stateDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
-		return new FsStateBackend(localFileUri(stateDir));
-	}
-
-	@Override
-	protected void cleanup() throws Exception {
-		deleteDirectorySilently(stateDir);
+		File checkpointPath = tempFolder.newFolder();
+		return new FsStateBackend(localFileUri(checkpointPath));
 	}
 
 	// disable these because the verification does not work for this state backend
@@ -69,66 +68,19 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 	public void testReducingStateRestoreWithWrongSerializers() {}
 
 	@Test
-	public void testSetupAndSerialization() {
-		File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
-		try {
-			final String backendDir = localFileUri(tempDir);
-			FsStateBackend originalBackend = new FsStateBackend(backendDir);
-
-			assertFalse(originalBackend.isInitialized());
-			assertEquals(new URI(backendDir), originalBackend.getBasePath().toUri());
-			assertNull(originalBackend.getCheckpointDirectory());
-
-			// serialize / copy the backend
-			FsStateBackend backend = CommonTestUtils.createCopySerializable(originalBackend);
-			assertFalse(backend.isInitialized());
-			assertEquals(new URI(backendDir), backend.getBasePath().toUri());
-			assertNull(backend.getCheckpointDirectory());
-
-			// no file operations should be possible right now
-			try {
-				FsStateBackend.FsCheckpointStateOutputStream out = backend.createCheckpointStateOutputStream(
-						2L,
-						System.currentTimeMillis());
-
-				out.write(1);
-				out.closeAndGetHandle();
-				fail("should fail with an exception");
-			} catch (IllegalStateException e) {
-				// supreme!
-			}
-
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test-op", IntSerializer.INSTANCE);
-			assertNotNull(backend.getCheckpointDirectory());
-
-			File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath());
-			assertTrue(checkpointDir.exists());
-			assertTrue(isDirectoryEmpty(checkpointDir));
-
-			backend.disposeAllStateForCurrentJob();
-			assertNull(backend.getCheckpointDirectory());
+	public void testStateOutputStream() throws IOException {
+		File basePath = tempFolder.newFolder().getAbsoluteFile();
 
-			assertTrue(isDirectoryEmpty(tempDir));
-		}
-		catch (Exception e) {
-			e.printStackTrace();
-			fail(e.getMessage());
-		}
-		finally {
-			deleteDirectorySilently(tempDir);
-		}
-	}
-
-	@Test
-	public void testStateOutputStream() {
-		File tempDir = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString());
 		try {
 			// the state backend has a very low in-mem state threshold (15 bytes)
-			FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(tempDir.toURI(), 15));
+			FsStateBackend backend = CommonTestUtils.createCopySerializable(new FsStateBackend(basePath.toURI(), 15));
+			JobID jobId = new JobID();
 
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test-op", IntSerializer.INSTANCE);
+			// we know how FsCheckpointStreamFactory is implemented so we know where it
+			// will store checkpoints
+			File checkpointPath = new File(basePath.getAbsolutePath(), jobId.toString());
 
-			File checkpointDir = new File(backend.getCheckpointDirectory().toUri().getPath());
+			CheckpointStreamFactory streamFactory = backend.createStreamFactory(jobId, "test_op");
 
 			byte[] state1 = new byte[1274673];
 			byte[] state2 = new byte[1];
@@ -143,12 +95,14 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 
 			long checkpointId = 97231523452L;
 
-			FsStateBackend.FsCheckpointStateOutputStream stream1 =
-					backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
-			FsStateBackend.FsCheckpointStateOutputStream stream2 =
-					backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
-			FsStateBackend.FsCheckpointStateOutputStream stream3 =
-					backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
+			CheckpointStreamFactory.CheckpointStateOutputStream stream1 =
+					streamFactory.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
+
+			CheckpointStreamFactory.CheckpointStateOutputStream stream2 =
+					streamFactory.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
+
+			CheckpointStreamFactory.CheckpointStateOutputStream stream3 =
+					streamFactory.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
 
 			stream1.write(state1);
 			stream2.write(state2);
@@ -160,15 +114,15 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 
 			// use with try-with-resources
 			StreamStateHandle handle4;
-			try (AbstractStateBackend.CheckpointStateOutputStream stream4 =
-					backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis())) {
+			try (CheckpointStreamFactory.CheckpointStateOutputStream stream4 =
+					streamFactory.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis())) {
 				stream4.write(state4);
 				handle4 = stream4.closeAndGetHandle();
 			}
 
 			// close before accessing handle
-			AbstractStateBackend.CheckpointStateOutputStream stream5 =
-					backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
+			CheckpointStreamFactory.CheckpointStateOutputStream stream5 =
+					streamFactory.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis());
 			stream5.write(state4);
 			stream5.close();
 			try {
@@ -180,7 +134,7 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 
 			validateBytesInStream(handle1.openInputStream(), state1);
 			handle1.discardState();
-			assertFalse(isDirectoryEmpty(checkpointDir));
+			assertFalse(isDirectoryEmpty(basePath));
 			ensureLocalFileDeleted(handle1.getFilePath());
 
 			validateBytesInStream(handle2.openInputStream(), state2);
@@ -191,15 +145,12 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 
 			validateBytesInStream(handle4.openInputStream(), state4);
 			handle4.discardState();
-			assertTrue(isDirectoryEmpty(checkpointDir));
+			assertTrue(isDirectoryEmpty(checkpointPath));
 		}
 		catch (Exception e) {
 			e.printStackTrace();
 			fail(e.getMessage());
 		}
-		finally {
-			deleteDirectorySilently(tempDir);
-		}
 	}
 
 	// ------------------------------------------------------------------------
@@ -253,8 +204,7 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 
 	@Test
 	public void testConcurrentMapIfQueryable() throws Exception {
-		backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
-		StateBackendTestBase.testConcurrentMapIfQueryable(backend);
+		super.testConcurrentMapIfQueryable();
 	}
 
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
index 940b337..944938b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state;
 
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
@@ -40,9 +41,6 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 		return new MemoryStateBackend();
 	}
 
-	@Override
-	protected void cleanup() throws Exception { }
-
 	// disable these because the verification does not work for this state backend
 	@Override
 	@Test
@@ -60,15 +58,15 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 	public void testOversizedState() {
 		try {
 			MemoryStateBackend backend = new MemoryStateBackend(10);
+			CheckpointStreamFactory streamFactory = backend.createStreamFactory(new JobID(), "test_op");
 
 			HashMap<String, Integer> state = new HashMap<>();
 			state.put("hey there", 2);
 			state.put("the crazy brown fox stumbles over a sentence that does not contain every letter", 77);
 
 			try {
-				AbstractStateBackend.CheckpointStateOutputStream outStream = backend.createCheckpointStateOutputStream(
-						12,
-						459);
+				CheckpointStreamFactory.CheckpointStateOutputStream outStream =
+						streamFactory.createCheckpointStateOutputStream(12, 459);
 
 				ObjectOutputStream oos = new ObjectOutputStream(outStream);
 				oos.writeObject(state);
@@ -93,12 +91,13 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 	public void testStateStream() {
 		try {
 			MemoryStateBackend backend = new MemoryStateBackend();
+			CheckpointStreamFactory streamFactory = backend.createStreamFactory(new JobID(), "test_op");
 
 			HashMap<String, Integer> state = new HashMap<>();
 			state.put("hey there", 2);
 			state.put("the crazy brown fox stumbles over a sentence that does not contain every letter", 77);
 
-			AbstractStateBackend.CheckpointStateOutputStream os = backend.createCheckpointStateOutputStream(1, 2);
+			CheckpointStreamFactory.CheckpointStateOutputStream os = streamFactory.createCheckpointStateOutputStream(1, 2);
 			ObjectOutputStream oos = new ObjectOutputStream(os);
 			oos.writeObject(state);
 			oos.flush();
@@ -121,12 +120,13 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 	public void testOversizedStateStream() {
 		try {
 			MemoryStateBackend backend = new MemoryStateBackend(10);
+			CheckpointStreamFactory streamFactory = backend.createStreamFactory(new JobID(), "test_op");
 
 			HashMap<String, Integer> state = new HashMap<>();
 			state.put("hey there", 2);
 			state.put("the crazy brown fox stumbles over a sentence that does not contain every letter", 77);
 
-			AbstractStateBackend.CheckpointStateOutputStream os = backend.createCheckpointStateOutputStream(1, 2);
+			CheckpointStreamFactory.CheckpointStateOutputStream os = streamFactory.createCheckpointStateOutputStream(1, 2);
 			ObjectOutputStream oos = new ObjectOutputStream(os);
 
 			try {
@@ -147,7 +147,6 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 
 	@Test
 	public void testConcurrentMapIfQueryable() throws Exception {
-		backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
-		StateBackendTestBase.testConcurrentMapIfQueryable(backend);
+		super.testConcurrentMapIfQueryable();
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 834c35c..f094bd5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -21,10 +21,12 @@ package org.apache.flink.runtime.state;
 import com.google.common.base.Joiner;
 import org.apache.commons.io.output.ByteArrayOutputStream;
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.functions.FoldFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.state.FoldingState;
 import org.apache.flink.api.common.state.FoldingStateDescriptor;
+import org.apache.flink.api.common.state.KeyGroupAssigner;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.ReducingState;
@@ -37,21 +39,25 @@ import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.checkpoint.CheckpointCoordinator;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.KvStateID;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.KvStateRegistryListener;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
+import org.apache.flink.runtime.state.heap.AbstractHeapState;
+import org.apache.flink.runtime.state.heap.StateTable;
 import org.apache.flink.types.IntValue;
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
+import java.util.Random;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Future;
+import java.util.concurrent.RunnableFuture;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -71,27 +77,77 @@ import static org.mockito.Mockito.verify;
 @SuppressWarnings("serial")
 public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
-	protected B backend;
-
 	protected abstract B getStateBackend() throws Exception;
 
-	protected abstract void cleanup() throws Exception;
+	protected CheckpointStreamFactory createStreamFactory() throws Exception {
+		return getStateBackend().createStreamFactory(new JobID(), "test_op");
+	}
+
+	protected <K> KeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> keySerializer) throws Exception {
+		return createKeyedBackend(keySerializer, new DummyEnvironment("test", 1, 0));
+	}
+
+	protected <K> KeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> keySerializer, Environment env) throws Exception {
+		return createKeyedBackend(
+				keySerializer,
+				new HashKeyGroupAssigner<K>(10),
+				new KeyGroupRange(0, 9),
+				env);
+	}
+
+	protected <K> KeyedStateBackend<K> createKeyedBackend(
+			TypeSerializer<K> keySerializer,
+			KeyGroupAssigner<K> keyGroupAssigner,
+			KeyGroupRange keyGroupRange,
+			Environment env) throws Exception {
+		return getStateBackend().createKeyedStateBackend(
+				env,
+				new JobID(),
+				"test_op",
+				keySerializer,
+				keyGroupAssigner,
+				keyGroupRange,
+				env.getTaskKvStateRegistry());
+	}
 
-	@Before
-	public void setup() throws Exception {
-		this.backend = getStateBackend();
+	protected <K> KeyedStateBackend<K> restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle state) throws Exception {
+		return restoreKeyedBackend(keySerializer, state, new DummyEnvironment("test", 1, 0));
 	}
 
-	@After
-	public void teardown() throws Exception {
-		this.backend.discardState();
-		cleanup();
+	protected <K> KeyedStateBackend<K> restoreKeyedBackend(
+			TypeSerializer<K> keySerializer,
+			KeyGroupsStateHandle state,
+			Environment env) throws Exception {
+		return restoreKeyedBackend(
+				keySerializer,
+				new HashKeyGroupAssigner<K>(10),
+				new KeyGroupRange(0, 9),
+				Collections.singletonList(state),
+				env);
+	}
+
+	protected <K> KeyedStateBackend<K> restoreKeyedBackend(
+			TypeSerializer<K> keySerializer,
+			KeyGroupAssigner<K> keyGroupAssigner,
+			KeyGroupRange keyGroupRange,
+			List<KeyGroupsStateHandle> state,
+			Environment env) throws Exception {
+		return getStateBackend().restoreKeyedStateBackend(
+				env,
+				new JobID(),
+				"test_op",
+				keySerializer,
+				keyGroupAssigner,
+				keyGroupRange,
+				state,
+				env.getTaskKvStateRegistry());
 	}
 
 	@Test
 	@SuppressWarnings("unchecked")
 	public void testValueState() throws Exception {
-		backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+		CheckpointStreamFactory streamFactory = createStreamFactory();
+		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", String.class, null);
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -102,7 +158,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 		ValueState<String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 		@SuppressWarnings("unchecked")
-		KvState<Integer, VoidNamespace, ?, ?, B> kvState = (KvState<Integer, VoidNamespace, ?, ?, B>) state;
+		KvState<VoidNamespace> kvState = (KvState<VoidNamespace>) state;
 
 		// some modifications to the state
 		backend.setCurrentKey(1);
@@ -118,13 +174,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals("1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 		// draw a snapshot
-		HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2);
-
-		for (String key: snapshot1.keySet()) {
-			if (snapshot1.get(key) instanceof AsynchronousKvStateSnapshot) {
-				snapshot1.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot1.get(key)).materialize());
-			}
-		}
+		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -135,13 +185,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		state.update("u3");
 
 		// draw another snapshot
-		HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot2 = backend.snapshotPartitionedState(682375462379L, 4);
-
-		for (String key: snapshot2.keySet()) {
-			if (snapshot2.get(key) instanceof AsynchronousKvStateSnapshot) {
-				snapshot2.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot2.get(key)).materialize());
-			}
-		}
+		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory));
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -154,18 +198,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals("u3", state.value());
 		assertEquals("u3", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-		backend.discardState();
-		backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+		backend.close();
+		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 
-		backend.injectKeyValueStateSnapshots((HashMap) snapshot1);
-
-		for (String key: snapshot1.keySet()) {
-			snapshot1.get(key).discardState();
-		}
+		snapshot1.discardState();
 
 		ValueState<String> restored1 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 		@SuppressWarnings("unchecked")
-		KvState<Integer, VoidNamespace, ?, ?, B> restoredKvState1 = (KvState<Integer, VoidNamespace, ?, ?, B>) restored1;
+		KvState<VoidNamespace> restoredKvState1 = (KvState<VoidNamespace>) restored1;
 
 		backend.setCurrentKey(1);
 		assertEquals("1", restored1.value());
@@ -174,18 +214,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals("2", restored1.value());
 		assertEquals("2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-		backend.discardState();
-		backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+		backend.close();
+		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
 
-		backend.injectKeyValueStateSnapshots((HashMap) snapshot2);
-
-		for (String key: snapshot2.keySet()) {
-			snapshot2.get(key).discardState();
-		}
+		snapshot2.discardState();
 
 		ValueState<String> restored2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 		@SuppressWarnings("unchecked")
-		KvState<Integer, VoidNamespace, ?, ?, B> restoredKvState2 = (KvState<Integer, VoidNamespace, ?, ?, B>) restored2;
+		KvState<VoidNamespace> restoredKvState2 = (KvState<VoidNamespace>) restored2;
 
 		backend.setCurrentKey(1);
 		assertEquals("u1", restored2.value());
@@ -196,6 +232,68 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		backend.setCurrentKey(3);
 		assertEquals("u3", restored2.value());
 		assertEquals("u3", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
+
+		backend.close();
+	}
+
+	@Test
+	@SuppressWarnings("unchecked")
+	public void testMultipleValueStates() throws Exception {
+		CheckpointStreamFactory streamFactory = createStreamFactory();
+
+		KeyedStateBackend<Integer> backend = createKeyedBackend(
+				IntSerializer.INSTANCE,
+				new HashKeyGroupAssigner<Integer>(1),
+				new KeyGroupRange(0, 0),
+				new DummyEnvironment("test_op", 1, 0));
+
+		ValueStateDescriptor<String> desc1 = new ValueStateDescriptor<>("a-string", StringSerializer.INSTANCE, null);
+		ValueStateDescriptor<Integer> desc2 = new ValueStateDescriptor<>("an-integer", IntSerializer.INSTANCE, null);
+
+		desc1.initializeSerializerUnlessSet(new ExecutionConfig());
+		desc2.initializeSerializerUnlessSet(new ExecutionConfig());
+
+		ValueState<String> state1 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, desc1);
+		ValueState<Integer> state2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, desc2);
+
+		// some modifications to the state
+		backend.setCurrentKey(1);
+		assertNull(state1.value());
+		assertNull(state2.value());
+		state1.update("1");
+
+		// state2 should still have nothing
+		assertEquals("1", state1.value());
+		assertNull(state2.value());
+		state2.update(13);
+
+		// both have some state now
+		assertEquals("1", state1.value());
+		assertEquals(13, (int) state2.value());
+
+		// draw a snapshot
+		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
+
+		backend.close();
+		backend = restoreKeyedBackend(
+				IntSerializer.INSTANCE,
+				new HashKeyGroupAssigner<Integer>(1),
+				new KeyGroupRange(0, 0),
+				Collections.singletonList(snapshot1),
+				new DummyEnvironment("test_op", 1, 0));
+
+		snapshot1.discardState();
+
+		backend.setCurrentKey(1);
+
+		state1 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, desc1);
+		state2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, desc2);
+
+		// verify that they are still the same
+		assertEquals("1", state1.value());
+		assertEquals(13, (int) state2.value());
+
+		backend.close();
 	}
 
 	/**
@@ -217,7 +315,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			// alrighty
 		}
 
-		backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+		CheckpointStreamFactory streamFactory = createStreamFactory();
+		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<Long> kvId = new ValueStateDescriptor<>("id", LongSerializer.INSTANCE, 42L);
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -246,31 +345,24 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals(42L, (long) state.value());
 
 		// draw a snapshot
-		HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2);
-
-		for (String key: snapshot1.keySet()) {
-			if (snapshot1.get(key) instanceof AsynchronousKvStateSnapshot) {
-				snapshot1.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot1.get(key)).materialize());
-			}
-		}
-
-		backend.discardState();
-		backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-		backend.injectKeyValueStateSnapshots((HashMap) snapshot1);
+		backend.close();
+		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 
-		for (String key: snapshot1.keySet()) {
-			snapshot1.get(key).discardState();
-		}
+		snapshot1.discardState();
 
 		backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+
+		backend.close();
 	}
 
 	@Test
 	@SuppressWarnings("unchecked,rawtypes")
 	public void testListState() {
 		try {
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+			CheckpointStreamFactory streamFactory = createStreamFactory();
+			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 			kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -281,7 +373,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 			ListState<String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 			@SuppressWarnings("unchecked")
-			KvState<Integer, VoidNamespace, ?, ?, B> kvState = (KvState<Integer, VoidNamespace, ?, ?, B>) state;
+			KvState<VoidNamespace> kvState = (KvState<VoidNamespace>) state;
 
 			Joiner joiner = Joiner.on(",");
 			// some modifications to the state
@@ -298,13 +390,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("1", joiner.join(getSerializedList(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
 			// draw a snapshot
-			HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2);
-
-			for (String key: snapshot1.keySet()) {
-				if (snapshot1.get(key) instanceof AsynchronousKvStateSnapshot) {
-					snapshot1.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot1.get(key)).materialize());
-				}
-			}
+			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
 			// make some more modifications
 			backend.setCurrentKey(1);
@@ -315,13 +401,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			state.add("u3");
 
 			// draw another snapshot
-			HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot2 = backend.snapshotPartitionedState(682375462379L, 4);
-
-			for (String key: snapshot2.keySet()) {
-				if (snapshot2.get(key) instanceof AsynchronousKvStateSnapshot) {
-					snapshot2.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot2.get(key)).materialize());
-				}
-			}
+			KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory));
 
 			// validate the original state
 			backend.setCurrentKey(1);
@@ -334,19 +414,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("u3", joiner.join(state.get()));
 			assertEquals("u3", joiner.join(getSerializedList(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
-			backend.discardState();
-
+			backend.close();
 			// restore the first snapshot and validate it
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
-			backend.injectKeyValueStateSnapshots((HashMap) snapshot1);
-
-			for (String key: snapshot1.keySet()) {
-				snapshot1.get(key).discardState();
-			}
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
+			snapshot1.discardState();
 
 			ListState<String> restored1 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 			@SuppressWarnings("unchecked")
-			KvState<Integer, VoidNamespace, ?, ?, B> restoredKvState1 = (KvState<Integer, VoidNamespace, ?, ?, B>) restored1;
+			KvState<VoidNamespace> restoredKvState1 = (KvState<VoidNamespace>) restored1;
 
 			backend.setCurrentKey(1);
 			assertEquals("1", joiner.join(restored1.get()));
@@ -355,19 +430,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("2", joiner.join(restored1.get()));
 			assertEquals("2", joiner.join(getSerializedList(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
-			backend.discardState();
-
+			backend.close();
 			// restore the second snapshot and validate it
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
-			backend.injectKeyValueStateSnapshots((HashMap) snapshot2);
-
-			for (String key: snapshot2.keySet()) {
-				snapshot2.get(key).discardState();
-			}
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
+			snapshot2.discardState();
 
 			ListState<String> restored2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 			@SuppressWarnings("unchecked")
-			KvState<Integer, VoidNamespace, ?, ?, B> restoredKvState2 = (KvState<Integer, VoidNamespace, ?, ?, B>) restored2;
+			KvState<VoidNamespace> restoredKvState2 = (KvState<VoidNamespace>) restored2;
 
 			backend.setCurrentKey(1);
 			assertEquals("1,u1", joiner.join(restored2.get()));
@@ -378,6 +448,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			backend.setCurrentKey(3);
 			assertEquals("u3", joiner.join(restored2.get()));
 			assertEquals("u3", joiner.join(getSerializedList(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
+
+			backend.close();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -389,7 +461,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	@SuppressWarnings("unchecked")
 	public void testReducingState() {
 		try {
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+			CheckpointStreamFactory streamFactory = createStreamFactory();
+			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ReducingStateDescriptor<String> kvId = new ReducingStateDescriptor<>("id", new AppendingReduce(), String.class);
 			kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -400,7 +473,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 			ReducingState<String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 			@SuppressWarnings("unchecked")
-			KvState<Integer, VoidNamespace, ?, ?, B> kvState = (KvState<Integer, VoidNamespace, ?, ?, B>) state;
+			KvState<VoidNamespace> kvState = (KvState<VoidNamespace>) state;
 
 			// some modifications to the state
 			backend.setCurrentKey(1);
@@ -416,13 +489,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 			// draw a snapshot
-			HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2);
-
-			for (String key: snapshot1.keySet()) {
-				if (snapshot1.get(key) instanceof AsynchronousKvStateSnapshot) {
-					snapshot1.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot1.get(key)).materialize());
-				}
-			}
+			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
 			// make some more modifications
 			backend.setCurrentKey(1);
@@ -433,13 +500,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			state.add("u3");
 
 			// draw another snapshot
-			HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot2 = backend.snapshotPartitionedState(682375462379L, 4);
-
-			for (String key: snapshot2.keySet()) {
-				if (snapshot2.get(key) instanceof AsynchronousKvStateSnapshot) {
-					snapshot2.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot2.get(key)).materialize());
-				}
-			}
+			KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory));
 
 			// validate the original state
 			backend.setCurrentKey(1);
@@ -452,19 +513,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("u3", state.get());
 			assertEquals("u3", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.discardState();
-
+			backend.close();
 			// restore the first snapshot and validate it
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
-			backend.injectKeyValueStateSnapshots((HashMap) snapshot1);
-
-			for (String key: snapshot1.keySet()) {
-				snapshot1.get(key).discardState();
-			}
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
+			snapshot1.discardState();
 
 			ReducingState<String> restored1 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 			@SuppressWarnings("unchecked")
-			KvState<Integer, VoidNamespace, ?, ?, B> restoredKvState1 = (KvState<Integer, VoidNamespace, ?, ?, B>) restored1;
+			KvState<VoidNamespace> restoredKvState1 = (KvState<VoidNamespace>) restored1;
 
 			backend.setCurrentKey(1);
 			assertEquals("1", restored1.get());
@@ -473,19 +529,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("2", restored1.get());
 			assertEquals("2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.discardState();
-
+			backend.close();
 			// restore the second snapshot and validate it
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
-			backend.injectKeyValueStateSnapshots((HashMap) snapshot2);
-
-			for (String key: snapshot2.keySet()) {
-				snapshot2.get(key).discardState();
-			}
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
+			snapshot2.discardState();
 
 			ReducingState<String> restored2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 			@SuppressWarnings("unchecked")
-			KvState<Integer, VoidNamespace, ?, ?, B> restoredKvState2 = (KvState<Integer, VoidNamespace, ?, ?, B>) restored2;
+			KvState<VoidNamespace> restoredKvState2 = (KvState<VoidNamespace>) restored2;
 
 			backend.setCurrentKey(1);
 			assertEquals("1,u1", restored2.get());
@@ -496,6 +547,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			backend.setCurrentKey(3);
 			assertEquals("u3", restored2.get());
 			assertEquals("u3", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
+
+			backend.close();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -507,7 +560,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	@SuppressWarnings("unchecked,rawtypes")
 	public void testFoldingState() {
 		try {
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+			CheckpointStreamFactory streamFactory = createStreamFactory();
+			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			FoldingStateDescriptor<Integer, String> kvId = new FoldingStateDescriptor<>("id",
 					"Fold-Initial:",
@@ -521,7 +575,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 			FoldingState<Integer, String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 			@SuppressWarnings("unchecked")
-			KvState<Integer, VoidNamespace, ?, ?, B> kvState = (KvState<Integer, VoidNamespace, ?, ?, B>) state;
+			KvState<VoidNamespace> kvState = (KvState<VoidNamespace>) state;
 
 			// some modifications to the state
 			backend.setCurrentKey(1);
@@ -537,13 +591,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("Fold-Initial:,1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 			// draw a snapshot
-			HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2);
-
-			for (String key: snapshot1.keySet()) {
-				if (snapshot1.get(key) instanceof AsynchronousKvStateSnapshot) {
-					snapshot1.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot1.get(key)).materialize());
-				}
-			}
+			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
 			// make some more modifications
 			backend.setCurrentKey(1);
@@ -555,13 +603,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			state.add(103);
 
 			// draw another snapshot
-			HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot2 = backend.snapshotPartitionedState(682375462379L, 4);
-
-			for (String key: snapshot2.keySet()) {
-				if (snapshot2.get(key) instanceof AsynchronousKvStateSnapshot) {
-					snapshot2.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot2.get(key)).materialize());
-				}
-			}
+			KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory));
 
 			// validate the original state
 			backend.setCurrentKey(1);
@@ -574,19 +616,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("Fold-Initial:,103", state.get());
 			assertEquals("Fold-Initial:,103", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.discardState();
-
+			backend.close();
 			// restore the first snapshot and validate it
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
-			backend.injectKeyValueStateSnapshots((HashMap) snapshot1);
-
-			for (String key: snapshot1.keySet()) {
-				snapshot1.get(key).discardState();
-			}
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
+			snapshot1.discardState();
 
 			FoldingState<Integer, String> restored1 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 			@SuppressWarnings("unchecked")
-			KvState<Integer, VoidNamespace, ?, ?, B> restoredKvState1 = (KvState<Integer, VoidNamespace, ?, ?, B>) restored1;
+			KvState<VoidNamespace> restoredKvState1 = (KvState<VoidNamespace>) restored1;
 
 			backend.setCurrentKey(1);
 			assertEquals("Fold-Initial:,1", restored1.get());
@@ -595,20 +632,15 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("Fold-Initial:,2", restored1.get());
 			assertEquals("Fold-Initial:,2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.discardState();
-
+			backend.close();
 			// restore the second snapshot and validate it
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
-			backend.injectKeyValueStateSnapshots((HashMap) snapshot2);
-
-			for (String key: snapshot2.keySet()) {
-				snapshot2.get(key).discardState();
-			}
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
+			snapshot2.discardState();
 
 			@SuppressWarnings("unchecked")
 			FoldingState<Integer, String> restored2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 			@SuppressWarnings("unchecked")
-			KvState<Integer, VoidNamespace, ?, ?, B> restoredKvState2 = (KvState<Integer, VoidNamespace, ?, ?, B>) restored2;
+			KvState<VoidNamespace> restoredKvState2 = (KvState<VoidNamespace>) restored2;
 
 			backend.setCurrentKey(1);
 			assertEquals("Fold-Initial:,101", restored2.get());
@@ -619,6 +651,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			backend.setCurrentKey(3);
 			assertEquals("Fold-Initial:,103", restored2.get());
 			assertEquals("Fold-Initial:,103", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
+
+			backend.close();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -626,17 +660,115 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		}
 	}
 
+	/**
+	 * This test verifies that state is correctly assigned to key groups and that restore
+	 * restores the relevant key groups in the backend.
+	 *
+	 * <p>We have ten key groups. Initially, one backend is responsible for all ten key groups.
+	 * Then we snapshot, split up the state and restore it into two backends where each is responsible
+	 * for five key groups. Then we make sure that the state is only available in the correct
+	 * backend.
+	 * @throws Exception if creating, snapshotting or restoring a backend fails
+	 */
+	@Test
+	public void testKeyGroupSnapshotRestore() throws Exception {
+		final int MAX_PARALLELISM = 10;
+
+		CheckpointStreamFactory streamFactory = createStreamFactory();
+
+		HashKeyGroupAssigner<Integer> keyGroupAssigner = new HashKeyGroupAssigner<>(MAX_PARALLELISM);
+
+		KeyedStateBackend<Integer> backend = createKeyedBackend(
+				IntSerializer.INSTANCE,
+				keyGroupAssigner,
+				new KeyGroupRange(0, MAX_PARALLELISM - 1),
+				new DummyEnvironment("test", 1, 0));
+
+		ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", String.class, null);
+		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
+
+		ValueState<String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+
+		// pick keys whose key groups fall into the first half (0-4) and the second half (5-9),
+		// respectively; an initial guess is replaced by seeded random keys until it fits its half
+		int keyInFirstHalf = 17;
+		int keyInSecondHalf = 42;
+		Random rand = new Random(0);
+
+		while (keyGroupAssigner.getKeyGroupIndex(keyInFirstHalf) * 2 / MAX_PARALLELISM != 0) {
+			keyInFirstHalf = rand.nextInt();
+		}
+
+		while (keyGroupAssigner.getKeyGroupIndex(keyInSecondHalf) * 2 / MAX_PARALLELISM != 1) {
+			keyInSecondHalf = rand.nextInt();
+		}
+
+		backend.setCurrentKey(keyInFirstHalf);
+		state.update("ShouldBeInFirstHalf");
+
+		backend.setCurrentKey(keyInSecondHalf);
+		state.update("ShouldBeInSecondHalf");
+
+
+		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(0, 0, streamFactory));
+
+		List<KeyGroupsStateHandle> firstHalfKeyGroupStates = CheckpointCoordinator.getKeyGroupsStateHandles(
+				Collections.singletonList(snapshot),
+				new KeyGroupRange(0, 4));
+
+		List<KeyGroupsStateHandle> secondHalfKeyGroupStates = CheckpointCoordinator.getKeyGroupsStateHandles(
+				Collections.singletonList(snapshot),
+				new KeyGroupRange(5, 9));
+
+		backend.close();
+
+		// backend for the first half of the key group range
+		KeyedStateBackend<Integer> firstHalfBackend = restoreKeyedBackend(
+				IntSerializer.INSTANCE,
+				keyGroupAssigner,
+				new KeyGroupRange(0, 4),
+				firstHalfKeyGroupStates,
+				new DummyEnvironment("test", 1, 0));
+
+		// backend for the second half of the key group range
+		KeyedStateBackend<Integer> secondHalfBackend = restoreKeyedBackend(
+				IntSerializer.INSTANCE,
+				keyGroupAssigner,
+				new KeyGroupRange(5, 9),
+				secondHalfKeyGroupStates,
+				new DummyEnvironment("test", 1, 0));
+
+
+		ValueState<String> firstHalfState = firstHalfBackend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+
+		firstHalfBackend.setCurrentKey(keyInFirstHalf);
+		assertEquals("ShouldBeInFirstHalf", firstHalfState.value());
+
+		firstHalfBackend.setCurrentKey(keyInSecondHalf);
+		assertTrue(firstHalfState.value() == null);
+
+		ValueState<String> secondHalfState = secondHalfBackend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+
+		secondHalfBackend.setCurrentKey(keyInFirstHalf);
+		assertTrue(secondHalfState.value() == null);
+
+		secondHalfBackend.setCurrentKey(keyInSecondHalf);
+		assertEquals("ShouldBeInSecondHalf", secondHalfState.value());
+
+		firstHalfBackend.close();
+		secondHalfBackend.close();
+	}
+
 	@Test
 	@SuppressWarnings("unchecked")
 	public void testValueStateRestoreWithWrongSerializers() {
 		try {
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0),
-				"test_op",
-				IntSerializer.INSTANCE);
+			CheckpointStreamFactory streamFactory = createStreamFactory();
+			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", String.class, null);
 			kvId.initializeSerializerUnlessSet(new ExecutionConfig());
-			
+
 			ValueState<String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 
 			backend.setCurrentKey(1);
@@ -645,23 +777,12 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			state.update("2");
 
 			// draw a snapshot
-			HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2);
-
-			for (String key: snapshot1.keySet()) {
-				if (snapshot1.get(key) instanceof AsynchronousKvStateSnapshot) {
-					snapshot1.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot1.get(key)).materialize());
-				}
-			}
-
-			backend.discardState();
+			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
+			backend.close();
 			// restore the first snapshot and validate it
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
-			backend.injectKeyValueStateSnapshots((HashMap) snapshot1);
-
-			for (String key: snapshot1.keySet()) {
-				snapshot1.get(key).discardState();
-			}
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
+			snapshot1.discardState();
 
 			@SuppressWarnings("unchecked")
 			TypeSerializer<String> fakeStringSerializer =
@@ -683,6 +804,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			} catch (Exception e) {
 				fail("wrong exception " + e);
 			}
+			backend.close();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -694,7 +816,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	@SuppressWarnings("unchecked")
 	public void testListStateRestoreWithWrongSerializers() {
 		try {
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+			CheckpointStreamFactory streamFactory = createStreamFactory();
+			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 			ListState<String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
@@ -705,23 +828,12 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			state.add("2");
 
 			// draw a snapshot
-			HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2);
-
-			for (String key: snapshot1.keySet()) {
-				if (snapshot1.get(key) instanceof AsynchronousKvStateSnapshot) {
-					snapshot1.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot1.get(key)).materialize());
-				}
-			}
-
-			backend.discardState();
+			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
+			backend.close();
 			// restore the first snapshot and validate it
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
-			backend.injectKeyValueStateSnapshots((HashMap) snapshot1);
-
-			for (String key: snapshot1.keySet()) {
-				snapshot1.get(key).discardState();
-			}
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
+			snapshot1.discardState();
 
 			@SuppressWarnings("unchecked")
 			TypeSerializer<String> fakeStringSerializer =
@@ -743,6 +855,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			} catch (Exception e) {
 				fail("wrong exception " + e);
 			}
+			backend.close();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -754,7 +867,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	@SuppressWarnings("unchecked")
 	public void testReducingStateRestoreWithWrongSerializers() {
 		try {
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+			CheckpointStreamFactory streamFactory = createStreamFactory();
+			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ReducingStateDescriptor<String> kvId = new ReducingStateDescriptor<>("id",
 					new AppendingReduce(),
@@ -767,23 +881,12 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			state.add("2");
 
 			// draw a snapshot
-			HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot1 = backend.snapshotPartitionedState(682375462378L, 2);
-
-			for (String key: snapshot1.keySet()) {
-				if (snapshot1.get(key) instanceof AsynchronousKvStateSnapshot) {
-					snapshot1.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot1.get(key)).materialize());
-				}
-			}
-
-			backend.discardState();
+			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
+			backend.close();
 			// restore the first snapshot and validate it
-			backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
-			backend.injectKeyValueStateSnapshots((HashMap) snapshot1);
-
-			for (String key: snapshot1.keySet()) {
-				snapshot1.get(key).discardState();
-			}
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
+			snapshot1.discardState();
 
 			@SuppressWarnings("unchecked")
 			TypeSerializer<String> fakeStringSerializer =
@@ -805,6 +908,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			} catch (Exception e) {
 				fail("wrong exception " + e);
 			}
+			backend.close();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -814,7 +918,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 	@Test
 	public void testCopyDefaultValue() throws Exception {
-		backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<IntValue> kvId = new ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1));
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -831,6 +935,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertNotNull(default2);
 		assertEquals(default1, default2);
 		assertFalse(default1 == default2);
+
+		backend.close();
 	}
 
 	/**
@@ -840,7 +946,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	 */
 	@Test
 	public void testRequireNonNullNamespace() throws Exception {
-		backend.initializeForJob(new DummyEnvironment("test", 1, 0), "test_op", IntSerializer.INSTANCE);
+		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<IntValue> kvId = new ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1));
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -862,6 +968,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			fail("Did not throw expected NullPointerException");
 		} catch (NullPointerException ignored) {
 		}
+
+		backend.close();
 	}
 
 	/**
@@ -869,7 +977,13 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	 * flag and create concurrent variants for internal state structures.
 	 */
 	@SuppressWarnings("unchecked")
-	protected static <B extends AbstractStateBackend> void testConcurrentMapIfQueryable(B backend) throws Exception {
+	protected void testConcurrentMapIfQueryable() throws Exception {
+		KeyedStateBackend<Integer> backend = createKeyedBackend(
+				IntSerializer.INSTANCE,
+				new HashKeyGroupAssigner<Integer>(1),
+				new KeyGroupRange(0, 0),
+				new DummyEnvironment("test_op", 1, 0));
+
 		{
 			// ValueState
 			ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>(
@@ -884,20 +998,19 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 					VoidNamespaceSerializer.INSTANCE,
 					desc);
 
-			KvState<Integer, VoidNamespace, ?, ?, ?> kvState = (KvState<Integer, VoidNamespace, ?, ?, ?>) state;
+			KvState<VoidNamespace> kvState = (KvState<VoidNamespace>) state;
 			assertTrue(kvState instanceof AbstractHeapState);
 
-			Map<VoidNamespace, Map<Integer, ?>> stateMap = ((AbstractHeapState) kvState).getStateMap();
-			assertTrue(stateMap instanceof ConcurrentHashMap);
-
 			kvState.setCurrentNamespace(VoidNamespace.INSTANCE);
-			kvState.setCurrentKey(1);
+			backend.setCurrentKey(1);
 			state.update(121818273);
 
-			Map<Integer, ?> namespaceMap = stateMap.get(VoidNamespace.INSTANCE);
+			int keyGroupIndex = new HashKeyGroupAssigner<>(1).getKeyGroupIndex(1);
+			StateTable stateTable = ((AbstractHeapState) kvState).getStateTable();
+			assertNotNull("State not set", stateTable.get(keyGroupIndex));
+			assertTrue(stateTable.get(keyGroupIndex) instanceof ConcurrentHashMap);
+			assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap);
 
-			assertNotNull("Value not set", namespaceMap);
-			assertTrue(namespaceMap instanceof ConcurrentHashMap);
 		}
 
 		{
@@ -911,20 +1024,18 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 					VoidNamespaceSerializer.INSTANCE,
 					desc);
 
-			KvState<Integer, VoidNamespace, ?, ?, ?> kvState = (KvState<Integer, VoidNamespace, ?, ?, ?>) state;
+			KvState<VoidNamespace> kvState = (KvState<VoidNamespace>) state;
 			assertTrue(kvState instanceof AbstractHeapState);
 
-			Map<VoidNamespace, Map<Integer, ?>> stateMap = ((AbstractHeapState) kvState).getStateMap();
-			assertTrue(stateMap instanceof ConcurrentHashMap);
-
 			kvState.setCurrentNamespace(VoidNamespace.INSTANCE);
-			kvState.setCurrentKey(1);
+			backend.setCurrentKey(1);
 			state.add(121818273);
 
-			Map<Integer, ?> namespaceMap = stateMap.get(VoidNamespace.INSTANCE);
-
-			assertNotNull("List not set", namespaceMap);
-			assertTrue(namespaceMap instanceof ConcurrentHashMap);
+			int keyGroupIndex = new HashKeyGroupAssigner<>(1).getKeyGroupIndex(1);
+			StateTable stateTable = ((AbstractHeapState) kvState).getStateTable();
+			assertNotNull("State not set", stateTable.get(keyGroupIndex));
+			assertTrue(stateTable.get(keyGroupIndex) instanceof ConcurrentHashMap);
+			assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap);
 		}
 
 		{
@@ -944,20 +1055,18 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 					VoidNamespaceSerializer.INSTANCE,
 					desc);
 
-			KvState<Integer, VoidNamespace, ?, ?, ?> kvState = (KvState<Integer, VoidNamespace, ?, ?, ?>) state;
+			KvState<VoidNamespace> kvState = (KvState<VoidNamespace>) state;
 			assertTrue(kvState instanceof AbstractHeapState);
 
-			Map<VoidNamespace, Map<Integer, ?>> stateMap = ((AbstractHeapState) kvState).getStateMap();
-			assertTrue(stateMap instanceof ConcurrentHashMap);
-
 			kvState.setCurrentNamespace(VoidNamespace.INSTANCE);
-			kvState.setCurrentKey(1);
+			backend.setCurrentKey(1);
 			state.add(121818273);
 
-			Map<Integer, ?> namespaceMap = stateMap.get(VoidNamespace.INSTANCE);
-
-			assertNotNull("List not set", namespaceMap);
-			assertTrue(namespaceMap instanceof ConcurrentHashMap);
+			int keyGroupIndex = new HashKeyGroupAssigner<>(1).getKeyGroupIndex(1);
+			StateTable stateTable = ((AbstractHeapState) kvState).getStateTable();
+			assertNotNull("State not set", stateTable.get(keyGroupIndex));
+			assertTrue(stateTable.get(keyGroupIndex) instanceof ConcurrentHashMap);
+			assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap);
 		}
 
 		{
@@ -977,21 +1086,21 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 					VoidNamespaceSerializer.INSTANCE,
 					desc);
 
-			KvState<Integer, VoidNamespace, ?, ?, ?> kvState = (KvState<Integer, VoidNamespace, ?, ?, ?>) state;
+			KvState<VoidNamespace> kvState = (KvState<VoidNamespace>) state;
 			assertTrue(kvState instanceof AbstractHeapState);
 
-			Map<VoidNamespace, Map<Integer, ?>> stateMap = ((AbstractHeapState) kvState).getStateMap();
-			assertTrue(stateMap instanceof ConcurrentHashMap);
-
 			kvState.setCurrentNamespace(VoidNamespace.INSTANCE);
-			kvState.setCurrentKey(1);
+			backend.setCurrentKey(1);
 			state.add(121818273);
 
-			Map<Integer, ?> namespaceMap = stateMap.get(VoidNamespace.INSTANCE);
-
-			assertNotNull("List not set", namespaceMap);
-			assertTrue(namespaceMap instanceof ConcurrentHashMap);
+			int keyGroupIndex = new HashKeyGroupAssigner<>(1).getKeyGroupIndex(1);
+			StateTable stateTable = ((AbstractHeapState) kvState).getStateTable();
+			assertNotNull("State not set", stateTable.get(keyGroupIndex));
+			assertTrue(stateTable.get(keyGroupIndex) instanceof ConcurrentHashMap);
+			assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap);
 		}
+
+		backend.close();
 	}
 
 	/**
@@ -1002,11 +1111,12 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		DummyEnvironment env = new DummyEnvironment("test", 1, 0);
 		KvStateRegistry registry = env.getKvStateRegistry();
 
+		CheckpointStreamFactory streamFactory = createStreamFactory();
+		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
+
 		KvStateRegistryListener listener = mock(KvStateRegistryListener.class);
 		registry.registerListener(listener);
 
-		backend.initializeForJob(env, "test_op", IntSerializer.INSTANCE);
-
 		ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>(
 				"test",
 				IntSerializer.INSTANCE,
@@ -1020,25 +1130,16 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 				eq(env.getJobID()), eq(env.getJobVertexId()), eq(0), eq("banana"), any(KvStateID.class));
 
 
-		HashMap<String, KvStateSnapshot<?, ?, ?, ?, ?>> snapshot = backend
-				.snapshotPartitionedState(682375462379L, 4);
-
-		for (String key: snapshot.keySet()) {
-			if (snapshot.get(key) instanceof AsynchronousKvStateSnapshot) {
-				snapshot.put(key, ((AsynchronousKvStateSnapshot<?, ?, ?, ?, ?>) snapshot.get(key)).materialize());
-			}
-		}
+		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory));
 
-		// Verify unregistered
-		backend.discardState();
+		backend.close();
 
 		verify(listener, times(1)).notifyKvStateUnregistered(
 				eq(env.getJobID()), eq(env.getJobVertexId()), eq(0), eq("banana"));
-
+		backend.close();
 		// Initialize again
-		backend.initializeForJob(env, "test_op", IntSerializer.INSTANCE);
-
-		backend.injectKeyValueStateSnapshots((HashMap) snapshot);
+		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env);
+		snapshot.discardState();
 
 		backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, desc);
 
@@ -1046,6 +1147,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		verify(listener, times(2)).notifyKvStateRegistered(
 				eq(env.getJobID()), eq(env.getJobVertexId()), eq(0), eq("banana"), any(KvStateID.class));
 
+		backend.close();
 
 	}
 	
@@ -1093,7 +1195,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	 * if it is not null.
 	 */
 	private static <V, K, N> V getSerializedValue(
-			KvState<K, N, ?, ?, ?> kvState,
+			KvState<N> kvState,
 			K key,
 			TypeSerializer<K> keySerializer,
 			N namespace,
@@ -1117,7 +1219,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	 * if it is not null.
 	 */
 	private static <V, K, N> List<V> getSerializedList(
-			KvState<K, N, ?, ?, ?> kvState,
+			KvState<N> kvState,
 			K key,
 			TypeSerializer<K> keySerializer,
 			N namespace,
@@ -1135,4 +1237,12 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			return KvStateRequestSerializer.deserializeList(serializedValue, valueSerializer);
 		}
 	}
+
+	private KeyGroupsStateHandle runSnapshot(RunnableFuture<KeyGroupsStateHandle> snapshotRunnableFuture) throws Exception {
+		// Drive the snapshot on a helper thread unless it already completed synchronously, then block for the handle.
+		if (!snapshotRunnableFuture.isDone()) {
+			new Thread(snapshotRunnableFuture).start();
+		}
+		return snapshotRunnableFuture.get();
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
index 1d45115..a6a555d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
@@ -34,29 +34,26 @@ import java.util.Random;
 import static org.junit.Assert.*;
 
 public class FsCheckpointStateOutputStreamTest {
-	
+
 	/** The temp dir, obtained in a platform neutral way */
 	private static final Path TEMP_DIR_PATH = new Path(new File(System.getProperty("java.io.tmpdir")).toURI());
-	
-	
+
+
 	@Test(expected = IllegalArgumentException.class)
 	public void testWrongParameters() {
 		// this should fail
-		new FsStateBackend.FsCheckpointStateOutputStream(
+		new FsCheckpointStreamFactory.FsCheckpointStateOutputStream(
 			TEMP_DIR_PATH, FileSystem.getLocalFileSystem(), 4000, 5000);
 	}
 
 
 	@Test
 	public void testEmptyState() throws Exception {
-		AbstractStateBackend.CheckpointStateOutputStream stream = new FsStateBackend.FsCheckpointStateOutputStream(
-			TEMP_DIR_PATH, FileSystem.getLocalFileSystem(), 1024, 512);
+		FsCheckpointStreamFactory.CheckpointStateOutputStream stream =
+				new FsCheckpointStreamFactory.FsCheckpointStateOutputStream(TEMP_DIR_PATH, FileSystem.getLocalFileSystem(), 1024, 512);
 
 		StreamStateHandle handle = stream.closeAndGetHandle();
-		assertTrue(handle instanceof ByteStreamStateHandle);
-
-		InputStream inStream = handle.openInputStream();
-		assertEquals(-1, inStream.read());
+		assertTrue(handle == null);
 	}
 
 	@Test
@@ -73,17 +70,17 @@ public class FsCheckpointStateOutputStreamTest {
 	public void testStateAboveMemThreshold() throws Exception {
 		runTest(576446, 259, 17, true);
 	}
-	
+
 	@Test
 	public void testZeroThreshold() throws Exception {
 		runTest(16678, 4096, 0, true);
 	}
-	
+
 	private void runTest(int numBytes, int bufferSize, int threshold, boolean expectFile) throws Exception {
-		AbstractStateBackend.CheckpointStateOutputStream stream =
-			new FsStateBackend.FsCheckpointStateOutputStream(
+		FsCheckpointStreamFactory.CheckpointStateOutputStream stream =
+			new FsCheckpointStreamFactory.FsCheckpointStateOutputStream(
 				TEMP_DIR_PATH, FileSystem.getLocalFileSystem(), bufferSize, threshold);
-		
+
 		Random rnd = new Random();
 		byte[] original = new byte[numBytes];
 		byte[] bytes = new byte[original.length];

http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index 429fc6b..ab4ca3b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -159,7 +159,7 @@ public class TaskAsyncCallTest {
 		TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(
 				new JobID(), "Job Name", new JobVertexID(), new ExecutionAttemptID(),
 				new SerializedValue<>(new ExecutionConfig()),
-				"Test Task", 0, 1, 0,
+				"Test Task", 1, 0, 1, 0,
 				new Configuration(), new Configuration(),
 				CheckpointsInOrderInvokable.class.getName(),
 				Collections.<ResultPartitionDeploymentDescriptor>emptyList(),

http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java
index ce88c09..54cd7c6 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java
@@ -162,7 +162,7 @@ public class TaskManagerTest extends TestLogger {
 				final SerializedValue<ExecutionConfig> executionConfig = new SerializedValue<>(new ExecutionConfig());
 
 				final TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(jid, "TestJob", vid, eid, executionConfig,
-						"TestTask", 2, 7, 0, new Configuration(), new Configuration(),
+						"TestTask", 7, 2, 7, 0, new Configuration(), new Configuration(),
 						TestInvokableCorrect.class.getName(),
 						Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 						Collections.<InputGateDeploymentDescriptor>emptyList(),
@@ -265,7 +265,7 @@ public class TaskManagerTest extends TestLogger {
 				final TaskDeploymentDescriptor tdd1 = new TaskDeploymentDescriptor(
 						jid1, "TestJob1", vid1, eid1,
 						new SerializedValue<>(new ExecutionConfig()),
-						"TestTask1", 1, 5, 0,
+						"TestTask1", 5, 1, 5, 0,
 						new Configuration(), new Configuration(), TestInvokableBlockingCancelable.class.getName(),
 						Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 						Collections.<InputGateDeploymentDescriptor>emptyList(),
@@ -274,7 +274,7 @@ public class TaskManagerTest extends TestLogger {
 				final TaskDeploymentDescriptor tdd2 = new TaskDeploymentDescriptor(
 						jid2, "TestJob2", vid2, eid2,
 						new SerializedValue<>(new ExecutionConfig()),
-						"TestTask2", 2, 7, 0,
+						"TestTask2", 7, 2, 7, 0,
 						new Configuration(), new Configuration(), TestInvokableBlockingCancelable.class.getName(),
 						Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 						Collections.<InputGateDeploymentDescriptor>emptyList(),
@@ -403,13 +403,13 @@ public class TaskManagerTest extends TestLogger {
 				final SerializedValue<ExecutionConfig> executionConfig = new SerializedValue<>(new ExecutionConfig());
 
 				final TaskDeploymentDescriptor tdd1 = new TaskDeploymentDescriptor(jid1, "TestJob", vid1, eid1, executionConfig,
-						"TestTask1", 1, 5, 0, new Configuration(), new Configuration(), StoppableInvokable.class.getName(),
+						"TestTask1", 5, 1, 5, 0, new Configuration(), new Configuration(), StoppableInvokable.class.getName(),
 						Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 						Collections.<InputGateDeploymentDescriptor>emptyList(),
 						new ArrayList<BlobKey>(), Collections.<URL>emptyList(), 0);
 
 				final TaskDeploymentDescriptor tdd2 = new TaskDeploymentDescriptor(jid2, "TestJob", vid2, eid2, executionConfig,
-						"TestTask2", 2, 7, 0, new Configuration(), new Configuration(), TestInvokableBlockingCancelable.class.getName(),
+						"TestTask2", 7, 2, 7, 0, new Configuration(), new Configuration(), TestInvokableBlockingCancelable.class.getName(),
 						Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 						Collections.<InputGateDeploymentDescriptor>emptyList(),
 						new ArrayList<BlobKey>(), Collections.<URL>emptyList(), 0);
@@ -531,7 +531,7 @@ public class TaskManagerTest extends TestLogger {
 				final TaskDeploymentDescriptor tdd1 = new TaskDeploymentDescriptor(
 						jid, "TestJob", vid1, eid1,
 						new SerializedValue<>(new ExecutionConfig()),
-						"Sender", 0, 1, 0,
+						"Sender", 1, 0, 1, 0,
 						new Configuration(), new Configuration(), Tasks.Sender.class.getName(),
 						Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 						Collections.<InputGateDeploymentDescriptor>emptyList(),
@@ -540,7 +540,7 @@ public class TaskManagerTest extends TestLogger {
 				final TaskDeploymentDescriptor tdd2 = new TaskDeploymentDescriptor(
 						jid, "TestJob", vid2, eid2,
 						new SerializedValue<>(new ExecutionConfig()),
-						"Receiver", 2, 7, 0,
+						"Receiver", 7, 2, 7, 0,
 						new Configuration(), new Configuration(), Tasks.Receiver.class.getName(),
 						Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 						Collections.<InputGateDeploymentDescriptor>emptyList(),
@@ -636,7 +636,7 @@ public class TaskManagerTest extends TestLogger {
 				final TaskDeploymentDescriptor tdd1 = new TaskDeploymentDescriptor(
 						jid, "TestJob", vid1, eid1,
 						new SerializedValue<>(new ExecutionConfig()),
-						"Sender", 0, 1, 0,
+						"Sender", 1, 0, 1, 0,
 						new Configuration(), new Configuration(), Tasks.Sender.class.getName(),
 						irpdd, Collections.<InputGateDeploymentDescriptor>emptyList(), new ArrayList<BlobKey>(),
 						Collections.<URL>emptyList(), 0);
@@ -644,7 +644,7 @@ public class TaskManagerTest extends TestLogger {
 				final TaskDeploymentDescriptor tdd2 = new TaskDeploymentDescriptor(
 						jid, "TestJob", vid2, eid2,
 						new SerializedValue<>(new ExecutionConfig()),
-						"Receiver", 2, 7, 0,
+						"Receiver", 7, 2, 7, 0,
 						new Configuration(), new Configuration(), Tasks.Receiver.class.getName(),
 						Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 						Collections.singletonList(ircdd),
@@ -781,7 +781,7 @@ public class TaskManagerTest extends TestLogger {
 				final TaskDeploymentDescriptor tdd1 = new TaskDeploymentDescriptor(
 						jid, "TestJob", vid1, eid1,
 						new SerializedValue<>(new ExecutionConfig()),
-						"Sender", 0, 1, 0,
+						"Sender", 1, 0, 1, 0,
 						new Configuration(), new Configuration(), Tasks.Sender.class.getName(),
 						irpdd, Collections.<InputGateDeploymentDescriptor>emptyList(),
 						new ArrayList<BlobKey>(), Collections.<URL>emptyList(), 0);
@@ -789,7 +789,7 @@ public class TaskManagerTest extends TestLogger {
 				final TaskDeploymentDescriptor tdd2 = new TaskDeploymentDescriptor(
 						jid, "TestJob", vid2, eid2,
 						new SerializedValue<>(new ExecutionConfig()),
-						"Receiver", 2, 7, 0,
+						"Receiver", 7, 2, 7, 0,
 						new Configuration(), new Configuration(), Tasks.BlockingReceiver.class.getName(),
 						Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 						Collections.singletonList(ircdd),
@@ -929,7 +929,7 @@ public class TaskManagerTest extends TestLogger {
 				final TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(
 						jid, "TestJob", vid, eid,
 						new SerializedValue<>(new ExecutionConfig()),
-						"Receiver", 0, 1, 0,
+						"Receiver", 1, 0, 1, 0,
 						new Configuration(), new Configuration(),
 						Tasks.AgnosticReceiver.class.getName(),
 						Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
@@ -1025,7 +1025,7 @@ public class TaskManagerTest extends TestLogger {
 				final TaskDeploymentDescriptor tdd = new TaskDeploymentDescriptor(
 						jid, "TestJob", vid, eid,
 						new SerializedValue<>(new ExecutionConfig()),
-						"Receiver", 0, 1, 0,
+						"Receiver", 1, 0, 1, 0,
 						new Configuration(), new Configuration(),
 						Tasks.AgnosticReceiver.class.getName(),
 						Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
@@ -1104,6 +1104,7 @@ public class TaskManagerTest extends TestLogger {
 						new ExecutionAttemptID(),
 						new SerializedValue<>(new ExecutionConfig()),
 						"Task",
+						1,
 						0,
 						1,
 						0,

http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
index 2f8e3db..f145b48 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
@@ -639,7 +639,7 @@ public class TaskTest {
 		return new TaskDeploymentDescriptor(
 				new JobID(), "Test Job", new JobVertexID(), new ExecutionAttemptID(),
 				execConfig,
-				"Test Task", 0, 1, 0,
+				"Test Task", 1, 0, 1, 0,
 				new Configuration(), new Configuration(),
 				invokable.getName(),
 				Collections.<ResultPartitionDeploymentDescriptor>emptyList(),

http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
index 6d67560..0c0b81a 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
@@ -125,7 +125,7 @@ public class ContinuousFileReaderOperator<OUT, S extends Serializable> extends A
 	}
 
 	@Override
-	public void dispose() {
+	public void dispose() throws Exception {
 		super.dispose();
 
 		// first try to cancel it properly and

http://git-wip-us.apache.org/repos/asf/flink/blob/4809f536/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
index 81c5c48..f9f26e9 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
@@ -152,7 +152,6 @@ public class StreamGraphGenerator {
 
 			if (maxParallelism <= 0) {
 				maxParallelism = transform.getParallelism();
-
 				/**
 				 * TODO: Remove once the parallelism settings works properly in Flink (FLINK-3885)
 				 * Currently, the parallelism will be set to 1 on the JobManager iff it encounters