You are viewing a plain text version of this content. The canonical link to the original page was not preserved in this plain-text rendering.
Posted to commits@flink.apache.org by se...@apache.org on 2016/09/30 12:47:54 UTC

[04/10] flink git commit: [FLINK-4379] [checkpoints] Introduce rescalable operator state

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 73e2808..2f21574 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -80,11 +80,11 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		return getStateBackend().createStreamFactory(new JobID(), "test_op");
 	}
 
-	protected <K> KeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> keySerializer) throws Exception {
+	protected <K> AbstractKeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> keySerializer) throws Exception {
 		return createKeyedBackend(keySerializer, new DummyEnvironment("test", 1, 0));
 	}
 
-	protected <K> KeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> keySerializer, Environment env) throws Exception {
+	protected <K> AbstractKeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> keySerializer, Environment env) throws Exception {
 		return createKeyedBackend(
 				keySerializer,
 				10,
@@ -92,7 +92,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 				env);
 	}
 
-	protected <K> KeyedStateBackend<K> createKeyedBackend(
+	protected <K> AbstractKeyedStateBackend<K> createKeyedBackend(
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
@@ -104,14 +104,15 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 				keySerializer,
 				numberOfKeyGroups,
 				keyGroupRange,
-				env.getTaskKvStateRegistry());
+				env.getTaskKvStateRegistry())
+;
 	}
 
-	protected <K> KeyedStateBackend<K> restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle state) throws Exception {
+	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle state) throws Exception {
 		return restoreKeyedBackend(keySerializer, state, new DummyEnvironment("test", 1, 0));
 	}
 
-	protected <K> KeyedStateBackend<K> restoreKeyedBackend(
+	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(
 			TypeSerializer<K> keySerializer,
 			KeyGroupsStateHandle state,
 			Environment env) throws Exception {
@@ -123,7 +124,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 				env);
 	}
 
-	protected <K> KeyedStateBackend<K> restoreKeyedBackend(
+	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
@@ -144,7 +145,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	@SuppressWarnings("unchecked")
 	public void testValueState() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
-		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", String.class, null);
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -195,7 +196,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals("u3", state.value());
 		assertEquals("u3", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-		backend.close();
+		backend.dispose();
 		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 
 		snapshot1.discardState();
@@ -211,7 +212,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals("2", restored1.value());
 		assertEquals("2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-		backend.close();
+		backend.dispose();
 		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
 
 		snapshot2.discardState();
@@ -230,7 +231,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals("u3", restored2.value());
 		assertEquals("u3", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-		backend.close();
+		backend.dispose();
 	}
 
 	@Test
@@ -238,7 +239,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testMultipleValueStates() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
 
-		KeyedStateBackend<Integer> backend = createKeyedBackend(
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
 				IntSerializer.INSTANCE,
 				1,
 				new KeyGroupRange(0, 0),
@@ -271,7 +272,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		// draw a snapshot
 		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-		backend.close();
+		backend.dispose();
 		backend = restoreKeyedBackend(
 				IntSerializer.INSTANCE,
 				1,
@@ -290,7 +291,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals("1", state1.value());
 		assertEquals(13, (int) state2.value());
 
-		backend.close();
+		backend.dispose();
 	}
 
 	/**
@@ -313,7 +314,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		}
 
 		CheckpointStreamFactory streamFactory = createStreamFactory();
-		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<Long> kvId = new ValueStateDescriptor<>("id", LongSerializer.INSTANCE, 42L);
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -344,14 +345,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		// draw a snapshot
 		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-		backend.close();
+		backend.dispose();
 		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 
 		snapshot1.discardState();
 
 		backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 
-		backend.close();
+		backend.dispose();
 	}
 
 	@Test
@@ -359,7 +360,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testListState() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 			kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -411,7 +412,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("u3", joiner.join(state.get()));
 			assertEquals("u3", joiner.join(getSerializedList(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -427,7 +428,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("2", joiner.join(restored1.get()));
 			assertEquals("2", joiner.join(getSerializedList(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
-			backend.close();
+			backend.dispose();
 			// restore the second snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
 			snapshot2.discardState();
@@ -446,7 +447,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("u3", joiner.join(restored2.get()));
 			assertEquals("u3", joiner.join(getSerializedList(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -459,7 +460,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testReducingState() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ReducingStateDescriptor<String> kvId = new ReducingStateDescriptor<>("id", new AppendingReduce(), String.class);
 			kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -510,7 +511,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("u3", state.get());
 			assertEquals("u3", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -526,7 +527,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("2", restored1.get());
 			assertEquals("2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 			// restore the second snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
 			snapshot2.discardState();
@@ -545,7 +546,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("u3", restored2.get());
 			assertEquals("u3", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -558,7 +559,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testFoldingState() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			FoldingStateDescriptor<Integer, String> kvId = new FoldingStateDescriptor<>("id",
 					"Fold-Initial:",
@@ -613,7 +614,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("Fold-Initial:,103", state.get());
 			assertEquals("Fold-Initial:,103", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -629,7 +630,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("Fold-Initial:,2", restored1.get());
 			assertEquals("Fold-Initial:,2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 			// restore the second snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
 			snapshot1.discardState();
@@ -649,7 +650,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("Fold-Initial:,103", restored2.get());
 			assertEquals("Fold-Initial:,103", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -672,7 +673,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		final int MAX_PARALLELISM = 10;
 
 		CheckpointStreamFactory streamFactory = createStreamFactory();
-		KeyedStateBackend<Integer> backend = createKeyedBackend(
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
 				IntSerializer.INSTANCE,
 				MAX_PARALLELISM,
 				new KeyGroupRange(0, MAX_PARALLELISM - 1),
@@ -714,10 +715,10 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 				Collections.singletonList(snapshot),
 				KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 1));
 
-		backend.close();
+		backend.dispose();
 
 		// backend for the first half of the key group range
-		KeyedStateBackend<Integer> firstHalfBackend = restoreKeyedBackend(
+		AbstractKeyedStateBackend<Integer> firstHalfBackend = restoreKeyedBackend(
 				IntSerializer.INSTANCE,
 				MAX_PARALLELISM,
 				new KeyGroupRange(0, 4),
@@ -725,7 +726,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 				new DummyEnvironment("test", 1, 0));
 
 		// backend for the second half of the key group range
-		KeyedStateBackend<Integer> secondHalfBackend = restoreKeyedBackend(
+		AbstractKeyedStateBackend<Integer> secondHalfBackend = restoreKeyedBackend(
 				IntSerializer.INSTANCE,
 				MAX_PARALLELISM,
 				new KeyGroupRange(5, 9),
@@ -749,8 +750,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		secondHalfBackend.setCurrentKey(keyInSecondHalf);
 		assertTrue(secondHalfState.value().equals("ShouldBeInSecondHalf"));
 
-		firstHalfBackend.close();
-		secondHalfBackend.close();
+		firstHalfBackend.dispose();
+		secondHalfBackend.dispose();
 	}
 
 	@Test
@@ -758,7 +759,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testValueStateRestoreWithWrongSerializers() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", String.class, null);
 			kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -773,7 +774,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			// draw a snapshot
 			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -798,7 +799,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			} catch (Exception e) {
 				fail("wrong exception " + e);
 			}
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -811,7 +812,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testListStateRestoreWithWrongSerializers() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 			ListState<String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
@@ -824,7 +825,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			// draw a snapshot
 			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -849,7 +850,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			} catch (Exception e) {
 				fail("wrong exception " + e);
 			}
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -862,7 +863,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testReducingStateRestoreWithWrongSerializers() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ReducingStateDescriptor<String> kvId = new ReducingStateDescriptor<>("id",
 					new AppendingReduce(),
@@ -877,7 +878,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			// draw a snapshot
 			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -902,7 +903,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			} catch (Exception e) {
 				fail("wrong exception " + e);
 			}
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -912,7 +913,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 	@Test
 	public void testCopyDefaultValue() throws Exception {
-		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<IntValue> kvId = new ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1));
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -930,7 +931,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals(default1, default2);
 		assertFalse(default1 == default2);
 
-		backend.close();
+		backend.dispose();
 	}
 
 	/**
@@ -940,7 +941,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	 */
 	@Test
 	public void testRequireNonNullNamespace() throws Exception {
-		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<IntValue> kvId = new ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1));
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -963,7 +964,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		} catch (NullPointerException ignored) {
 		}
 
-		backend.close();
+		backend.dispose();
 	}
 
 	/**
@@ -973,7 +974,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	@SuppressWarnings("unchecked")
 	protected void testConcurrentMapIfQueryable() throws Exception {
 		final int numberOfKeyGroups = 1;
-		KeyedStateBackend<Integer> backend = createKeyedBackend(
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
 				IntSerializer.INSTANCE,
 				numberOfKeyGroups,
 				new KeyGroupRange(0, 0),
@@ -1095,7 +1096,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap);
 		}
 
-		backend.close();
+		backend.dispose();
 	}
 
 	/**
@@ -1107,7 +1108,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		KvStateRegistry registry = env.getKvStateRegistry();
 
 		CheckpointStreamFactory streamFactory = createStreamFactory();
-		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
 		KeyGroupRange expectedKeyGroupRange = backend.getKeyGroupRange();
 
 		KvStateRegistryListener listener = mock(KvStateRegistryListener.class);
@@ -1128,11 +1129,11 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory));
 
-		backend.close();
+		backend.dispose();
 
 		verify(listener, times(1)).notifyKvStateUnregistered(
 				eq(env.getJobID()), eq(env.getJobVertexId()), eq(expectedKeyGroupRange), eq("banana"));
-		backend.close();
+		backend.dispose();
 		// Initialize again
 		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env);
 		snapshot.discardState();
@@ -1143,7 +1144,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		verify(listener, times(2)).notifyKvStateRegistered(
 				eq(env.getJobID()), eq(env.getJobVertexId()), eq(expectedKeyGroupRange), eq("banana"), any(KvStateID.class));
 
-		backend.close();
+		backend.dispose();
 
 	}
 
@@ -1152,17 +1153,17 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 
 			// draw a snapshot
 			KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(682375462379L, 1, streamFactory));
 			assertNull(snapshot);
-			backend.close();
+			backend.dispose();
 
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot);
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
index a6a555d..d484f2e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
@@ -20,8 +20,6 @@ package org.apache.flink.runtime.state.filesystem;
 
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.runtime.state.AbstractStateBackend;
-
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.junit.Test;
@@ -31,7 +29,8 @@ import java.io.File;
 import java.io.InputStream;
 import java.util.Random;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertTrue;
 
 public class FsCheckpointStateOutputStreamTest {
 
@@ -112,13 +111,14 @@ public class FsCheckpointStateOutputStreamTest {
 		// make sure the writing process did not alter the original byte array
 		assertArrayEquals(original, bytes);
 
-		InputStream inStream = handle.openInputStream();
-		byte[] validation = new byte[bytes.length];
+		try (InputStream inStream = handle.openInputStream()) {
+			byte[] validation = new byte[bytes.length];
 
-		DataInputStream dataInputStream = new DataInputStream(inStream);
-		dataInputStream.readFully(validation);
+			DataInputStream dataInputStream = new DataInputStream(inStream);
+			dataInputStream.readFully(validation);
 
-		assertArrayEquals(bytes, validation);
+			assertArrayEquals(bytes, validation);
+		}
 
 		handle.discardState();
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index 454196f..7bc2c29 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -46,6 +46,7 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
 import org.apache.flink.util.SerializedValue;
@@ -53,6 +54,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import java.net.URL;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.Executor;
@@ -209,7 +211,8 @@ public class TaskAsyncCallTest {
 
 		@Override
 		public void setInitialState(ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState) throws Exception {
+									List<KeyGroupsStateHandle> keyGroupsState,
+									List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
 
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
index 7e8868c..8f9c932 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
@@ -33,7 +33,6 @@ import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
@@ -587,8 +586,5 @@ public class ZooKeeperStateHandleStoreITCase extends TestLogger {
 		public int getNumberOfDiscardCalls() {
 			return numberOfDiscardCalls;
 		}
-
-		@Override
-		public void close() throws IOException {}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java b/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
index c16629d..d7a6364 100644
--- a/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
+++ b/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
@@ -177,7 +177,7 @@ public class FlinkKafkaConsumer08<T> extends FlinkKafkaConsumerBase<T> {
 	 *           The properties that are used to configure both the fetcher and the offset handler.
 	 */
 	public FlinkKafkaConsumer08(List<String> topics, KeyedDeserializationSchema<T> deserializer, Properties props) {
-		super(deserializer);
+		super(topics, deserializer);
 
 		checkNotNull(topics, "topics");
 		this.kafkaProperties = checkNotNull(props, "props");
@@ -187,22 +187,6 @@ public class FlinkKafkaConsumer08<T> extends FlinkKafkaConsumerBase<T> {
 
 		this.invalidOffsetBehavior = getInvalidOffsetBehavior(props);
 		this.autoCommitInterval = PropertiesUtil.getLong(props, "auto.commit.interval.ms", 60000);
-
-		// Connect to a broker to get the partitions for all topics
-		List<KafkaTopicPartition> partitionInfos = 
-				KafkaTopicPartition.dropLeaderData(getPartitionsForTopic(topics, props));
-
-		if (partitionInfos.size() == 0) {
-			throw new RuntimeException(
-					"Unable to retrieve any partitions for the requested topics " + topics + 
-							". Please check previous log entries");
-		}
-
-		if (LOG.isInfoEnabled()) {
-			logPartitionInfo(LOG, partitionInfos);
-		}
-
-		setSubscribedPartitions(partitionInfos);
 	}
 
 	@Override
@@ -221,6 +205,25 @@ public class FlinkKafkaConsumer08<T> extends FlinkKafkaConsumerBase<T> {
 				invalidOffsetBehavior, autoCommitInterval, useMetrics);
 	}
 
+	@Override
+	protected List<KafkaTopicPartition> getKafkaPartitions(List<String> topics) {
+		// Connect to a broker to get the partitions for all topics
+		List<KafkaTopicPartition> partitionInfos =
+			KafkaTopicPartition.dropLeaderData(getPartitionsForTopic(topics, kafkaProperties));
+
+		if (partitionInfos.size() == 0) {
+			throw new RuntimeException(
+				"Unable to retrieve any partitions for the requested topics " + topics +
+					". Please check previous log entries");
+		}
+
+		if (LOG.isInfoEnabled()) {
+			logPartitionInfo(LOG, partitionInfos);
+		}
+
+		return partitionInfos;
+	}
+
 	// ------------------------------------------------------------------------
 	//  Kafka / ZooKeeper communication utilities
 	// ------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java b/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
index 36fb7e6..f0b58cf 100644
--- a/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
+++ b/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.streaming.connectors.kafka;
 
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
 
 import org.apache.kafka.clients.consumer.ConsumerConfig;
@@ -80,7 +81,8 @@ public class KafkaConsumer08Test {
 			props.setProperty("bootstrap.servers", "localhost:11111, localhost:22222");
 			props.setProperty("group.id", "non-existent-group");
 
-			new FlinkKafkaConsumer08<>(Collections.singletonList("no op topic"), new SimpleStringSchema(), props);
+			FlinkKafkaConsumer08<String> consumer = new FlinkKafkaConsumer08<>(Collections.singletonList("no op topic"), new SimpleStringSchema(), props);
+			consumer.open(new Configuration()); // partitions are only fetched in open(), which is where this is expected to fail
 			fail();
 		}
 		catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java b/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
index 8c3eaf8..9708777 100644
--- a/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
+++ b/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
@@ -149,9 +149,8 @@ public class FlinkKafkaConsumer09<T> extends FlinkKafkaConsumerBase<T> {
 	 *           The properties that are used to configure both the fetcher and the offset handler.
 	 */
 	public FlinkKafkaConsumer09(List<String> topics, KeyedDeserializationSchema<T> deserializer, Properties props) {
-		super(deserializer);
+		super(topics, deserializer); // the base constructor performs the null/empty checks on the topics
 
-		checkNotNull(topics, "topics");
 		this.properties = checkNotNull(props, "props");
 		setDeserializer(this.properties);
 
@@ -166,7 +165,27 @@ public class FlinkKafkaConsumer09<T> extends FlinkKafkaConsumerBase<T> {
 		catch (Exception e) {
 			throw new IllegalArgumentException("Cannot parse poll timeout for '" + KEY_POLL_TIMEOUT + '\'', e);
 		}
+	}
+
+	@Override
+	protected AbstractFetcher<T, ?> createFetcher(
+			SourceContext<T> sourceContext,
+			List<KafkaTopicPartition> thisSubtaskPartitions,
+			SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+			SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+			StreamingRuntimeContext runtimeContext) throws Exception {
+
+		// metrics are reported unless they are explicitly disabled through KEY_DISABLE_METRICS
+		boolean useMetrics = !Boolean.parseBoolean(properties.getProperty(KEY_DISABLE_METRICS, "false"));
+
+		return new Kafka09Fetcher<>(sourceContext, thisSubtaskPartitions,
+				watermarksPeriodic, watermarksPunctuated,
+				runtimeContext, deserializer,
+				properties, pollTimeout, useMetrics);
+	}
 
+	@Override
+	protected List<KafkaTopicPartition> getKafkaPartitions(List<String> topics) {
 		// read the partitions that belong to the listed topics
 		final List<KafkaTopicPartition> partitions = new ArrayList<>();
 
@@ -192,25 +211,7 @@ public class FlinkKafkaConsumer09<T> extends FlinkKafkaConsumerBase<T> {
 			logPartitionInfo(LOG, partitions);
 		}
 
-		// register these partitions
-		setSubscribedPartitions(partitions);
-	}
-
-	@Override
-	protected AbstractFetcher<T, ?> createFetcher(
-			SourceContext<T> sourceContext,
-			List<KafkaTopicPartition> thisSubtaskPartitions,
-			SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
-			SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
-			StreamingRuntimeContext runtimeContext) throws Exception {
-
-		boolean useMetrics = !Boolean.valueOf(properties.getProperty(KEY_DISABLE_METRICS, "false"));
-
-		return new Kafka09Fetcher<>(sourceContext, thisSubtaskPartitions,
-				watermarksPeriodic, watermarksPunctuated,
-				runtimeContext, deserializer,
-				properties, pollTimeout, useMetrics);
-		
+		return partitions; // FlinkKafkaConsumerBase#open() assigns them to the parallel subtasks
 	}
 
 	// ------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
index 2b2c527..939b77b 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
@@ -18,11 +18,16 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.commons.collections.map.LinkedMap;
-
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.typeinfo.TypeHint;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.state.CheckpointListener;
-import org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously;
+import org.apache.flink.runtime.state.OperatorStateStore;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
@@ -30,18 +35,21 @@ import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
 import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
 import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
 import org.apache.flink.util.SerializedValue;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -55,11 +63,10 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  */
 public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFunction<T> implements 
 		CheckpointListener,
-		CheckpointedAsynchronously<HashMap<KafkaTopicPartition, Long>>,
-		ResultTypeQueryable<T>
-{
+		ResultTypeQueryable<T>,
+		CheckpointedFunction {
 	private static final long serialVersionUID = -6272159445203409112L;
 
 	protected static final Logger LOG = LoggerFactory.getLogger(FlinkKafkaConsumerBase.class);
 	
 	/** The maximum number of pending non-committed checkpoints to track, to avoid memory leaks */
@@ -71,12 +80,15 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	// ------------------------------------------------------------------------
 	//  configuration state, set on the client relevant for all subtasks
 	// ------------------------------------------------------------------------
+
+	/** The topics to consume from; their partitions are looked up in open() */
+	private final List<String> topics;
 	
 	/** The schema to convert between Kafka's byte messages, and Flink's objects */
 	protected final KeyedDeserializationSchema<T> deserializer;
 
-	/** The set of topic partitions that the source will read */
-	protected List<KafkaTopicPartition> allSubscribedPartitions;
+	/** The topic partitions read by this parallel subtask; assigned in open() */
+	protected List<KafkaTopicPartition> subscribedPartitions;
 	
 	/** Optional timestamp extractor / watermark generator that will be run per Kafka partition,
 	 * to exploit per-partition timestamp characteristics.
@@ -88,6 +99,8 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	 * The assigner is kept in serialized form, to deserialize it into multiple copies */
 	private SerializedValue<AssignerWithPunctuatedWatermarks<T>> punctuatedWatermarkAssigner;
 
+	private transient OperatorStateStore stateStore; // set in initializeState(), used by prepareSnapshot()
+
 	// ------------------------------------------------------------------------
 	//  runtime state (used individually by each parallel subtask) 
 	// ------------------------------------------------------------------------
@@ -112,8 +125,13 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
+	 * @param topics
+	 *           The Kafka topics to consume from; at least one topic must be given.
 	 * @param deserializer
 	 *           The deserializer to turn raw byte messages into Java/Scala objects.
 	 */
-	public FlinkKafkaConsumerBase(KeyedDeserializationSchema<T> deserializer) {
+	public FlinkKafkaConsumerBase(List<String> topics, KeyedDeserializationSchema<T> deserializer) {
+		this.topics = checkNotNull(topics, "topics");
+		checkArgument(!topics.isEmpty(), "You have to define at least one topic.");
+
 		this.deserializer = checkNotNull(deserializer, "valueDeserializer");
 	}
 
 	/**
@@ -124,7 +143,7 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	 */
 	protected void setSubscribedPartitions(List<KafkaTopicPartition> allSubscribedPartitions) {
 		checkNotNull(allSubscribedPartitions);
-		this.allSubscribedPartitions = Collections.unmodifiableList(allSubscribedPartitions);
+		this.subscribedPartitions = Collections.unmodifiableList(allSubscribedPartitions); // NOTE(review): open() may overwrite this via assignTopicPartitions()
 	}
 
 	// ------------------------------------------------------------------------
@@ -205,20 +224,16 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 
 	@Override
 	public void run(SourceContext<T> sourceContext) throws Exception {
-		if (allSubscribedPartitions == null) {
+		if (subscribedPartitions == null) {
 			throw new Exception("The partitions were not set for the consumer");
 		}
-		
-		// figure out which partitions this subtask should process
-		final List<KafkaTopicPartition> thisSubtaskPartitions = assignPartitions(allSubscribedPartitions,
-				getRuntimeContext().getNumberOfParallelSubtasks(), getRuntimeContext().getIndexOfThisSubtask());
-		
+
 		// we need only do work, if we actually have partitions assigned
-		if (!thisSubtaskPartitions.isEmpty()) {
+		if (!subscribedPartitions.isEmpty()) { // open() already narrowed the list to this subtask's share
 
 			// (1) create the fetcher that will communicate with the Kafka brokers
 			final AbstractFetcher<T, ?> fetcher = createFetcher(
-					sourceContext, thisSubtaskPartitions, 
+					sourceContext, subscribedPartitions,
 					periodicWatermarkAssigner, punctuatedWatermarkAssigner,
 					(StreamingRuntimeContext) getRuntimeContext());
 
@@ -277,6 +292,15 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	}
 
 	@Override
+	public void open(Configuration configuration) {
+		List<KafkaTopicPartition> kafkaTopicPartitions = getKafkaPartitions(topics);
+
+		if (kafkaTopicPartitions != null) { // on null, subscribedPartitions is left as it was
+			assignTopicPartitions(kafkaTopicPartitions);
+		}
+	}
+
+	@Override
 	public void close() throws Exception {
 		// pretty much the same logic as cancelling
 		try {
@@ -289,44 +313,81 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	// ------------------------------------------------------------------------
 	//  Checkpoint and restore
 	// ------------------------------------------------------------------------
-	
+
 	@Override
-	public HashMap<KafkaTopicPartition, Long> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
-		if (!running) {
-			LOG.debug("snapshotState() called on closed source");
-			return null;
-		}
-		
-		final AbstractFetcher<?, ?> fetcher = this.kafkaFetcher;
-		if (fetcher == null) {
-			// the fetcher has not yet been initialized, which means we need to return the
-			// originally restored offsets
-			return restoreToOffset;
-		}
+	public void initializeState(OperatorStateStore stateStore) throws Exception {
 
-		HashMap<KafkaTopicPartition, Long> currentOffsets = fetcher.snapshotCurrentState();
+		this.stateStore = stateStore;
 
-		if (LOG.isDebugEnabled()) {
-			LOG.debug("Snapshotting state. Offsets: {}, checkpoint id: {}, timestamp: {}",
-					KafkaTopicPartition.toString(currentOffsets), checkpointId, checkpointTimestamp);
-		}
+		ListState<Serializable> offsets = stateStore.getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR);
 
-		// the map cannot be asynchronously updated, because only one checkpoint call can happen
-		// on this function at a time: either snapshotState() or notifyCheckpointComplete()
-		pendingCheckpoints.put(checkpointId, currentOffsets);
-		
-		// truncate the map, to prevent infinite growth
-		while (pendingCheckpoints.size() > MAX_NUM_PENDING_CHECKPOINTS) {
-			pendingCheckpoints.remove(0);
+		// NOTE(review): the map is non-null even when the list state is empty, in which case
+		// assignTopicPartitions() subscribes to no partition at all -- confirm this is intended
+		// for a job that is started without restored state
+		restoreToOffset = new HashMap<>();
+
+		for (Serializable serializable : offsets.get()) {
+			@SuppressWarnings("unchecked")
+			Tuple2<KafkaTopicPartition, Long> kafkaOffset = (Tuple2<KafkaTopicPartition, Long>) serializable;
+			restoreToOffset.put(kafkaOffset.f0, kafkaOffset.f1);
 		}
 
-		return currentOffsets;
+		LOG.info("Setting restore state in the FlinkKafkaConsumer: {}", restoreToOffset);
 	}
 
 	@Override
-	public void restoreState(HashMap<KafkaTopicPartition, Long> restoredOffsets) {
-		LOG.info("Setting restore state in the FlinkKafkaConsumer: {}", restoredOffsets);
-		restoreToOffset = restoredOffsets;
+	public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
+		if (!running) {
+			LOG.debug("prepareSnapshot() called on closed source");
+			return;
+		}
+
+		ListState<Serializable> listState = stateStore.getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR);
+		listState.clear();
+
+		final AbstractFetcher<?, ?> fetcher = this.kafkaFetcher;
+		if (fetcher != null) {
+			// the fetcher exists: checkpoint the offsets it has reached so far
+			HashMap<KafkaTopicPartition, Long> currentOffsets = fetcher.snapshotCurrentState();
+
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("Snapshotting state. Offsets: {}, checkpoint id: {}, timestamp: {}",
+						KafkaTopicPartition.toString(currentOffsets), checkpointId, timestamp);
+			}
+
+			storeOffsets(checkpointId, currentOffsets, listState);
+		} else if (restoreToOffset != null) {
+			// the fetcher has not been created yet: checkpoint the originally restored offsets again
+			storeOffsets(checkpointId, restoreToOffset, listState);
+		} else if (subscribedPartitions != null) {
+			// neither a fetcher nor restored offsets: checkpoint the assigned partitions without an offset
+			for (KafkaTopicPartition subscribedPartition : subscribedPartitions) {
+				listState.add(Tuple2.of(subscribedPartition, KafkaTopicPartitionState.OFFSET_NOT_SET));
+			}
+		}
+	}
+
+	/**
+	 * Tracks the offsets as a pending, not yet committed checkpoint and adds them to the
+	 * given operator list state.
+	 */
+	private void storeOffsets(
+			long checkpointId,
+			HashMap<KafkaTopicPartition, Long> offsets,
+			ListState<Serializable> listState) throws Exception {
+
+		// the map cannot be asynchronously updated, because only one checkpoint call can happen
+		// on this function at a time: either prepareSnapshot() or notifyCheckpointComplete()
+		pendingCheckpoints.put(checkpointId, offsets);
+
+		// truncate the map, to prevent infinite growth
+		while (pendingCheckpoints.size() > MAX_NUM_PENDING_CHECKPOINTS) {
+			pendingCheckpoints.remove(0);
+		}
+
+		for (Map.Entry<KafkaTopicPartition, Long> offset : offsets.entrySet()) {
+			listState.add(Tuple2.of(offset.getKey(), offset.getValue()));
+		}
 	}
 
 	@Override
@@ -401,6 +457,16 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 			SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
 			SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
 			StreamingRuntimeContext runtimeContext) throws Exception;
+
+	/**
+	 * Looks up the partitions of the given topics. Called from {@link #open(Configuration)}, which
+	 * then selects the partitions this parallel subtask reads from the returned list.
+	 * The returned list must be modifiable, because it may be sorted in place.
+	 *
+	 * @param topics The topics to look up, as passed to the constructor.
+	 * @return The partitions of the topics; when {@code null}, open() leaves the subscribed partitions unchanged.
+	 */
+	protected abstract List<KafkaTopicPartition> getKafkaPartitions(List<String> topics);
 	
 	// ------------------------------------------------------------------------
 	//  ResultTypeQueryable methods 
@@ -415,6 +473,48 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	//  Utilities
 	// ------------------------------------------------------------------------
 
+	/**
+	 * Determines the partitions this parallel subtask reads and stores them in
+	 * {@link #subscribedPartitions}.
+	 *
+	 * <p>If restored offsets are present, only the given partitions that also appear in the
+	 * restored offsets are kept. Otherwise the given list is sorted in place by topic and
+	 * partition number and its entries are distributed round-robin over the parallel subtasks.
+	 *
+	 * @param kafkaTopicPartitions The partitions of all subscribed topics (must be modifiable).
+	 */
+	private void assignTopicPartitions(List<KafkaTopicPartition> kafkaTopicPartitions) {
+		subscribedPartitions = new ArrayList<>();
+
+		if (restoreToOffset != null) {
+			for (KafkaTopicPartition kafkaTopicPartition : kafkaTopicPartitions) {
+				if (restoreToOffset.containsKey(kafkaTopicPartition)) {
+					subscribedPartitions.add(kafkaTopicPartition);
+				}
+			}
+		} else {
+			// a deterministic order lets all subtasks compute a consistent assignment,
+			// assuming they all received the same partition list
+			Collections.sort(kafkaTopicPartitions, new Comparator<KafkaTopicPartition>() {
+				@Override
+				public int compare(KafkaTopicPartition o1, KafkaTopicPartition o2) {
+					int topicComparison = o1.getTopic().compareTo(o2.getTopic());
+
+					if (topicComparison == 0) {
+						return Integer.compare(o1.getPartition(), o2.getPartition());
+					} else {
+						return topicComparison;
+					}
+				}
+			});
+
+			final int numParallelSubtasks = getRuntimeContext().getNumberOfParallelSubtasks();
+			for (int i = getRuntimeContext().getIndexOfThisSubtask(); i < kafkaTopicPartitions.size(); i += numParallelSubtasks) {
+				subscribedPartitions.add(kafkaTopicPartitions.get(i));
+			}
+		}
+	}
+
 	/**
 	 * Selects which of the given partitions should be handled by a specific consumer,
 	 * given a certain number of consumers.
@@ -427,8 +514,7 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	 */
 	protected static List<KafkaTopicPartition> assignPartitions(
 			List<KafkaTopicPartition> allPartitions,
-			int numConsumers, int consumerIndex)
-	{
+			int numConsumers, int consumerIndex) { // NOTE(review): run() no longer calls this; open() uses assignTopicPartitions() -- confirm remaining callers
 		final List<KafkaTopicPartition> thisSubtaskPartitions = new ArrayList<>(
 				allPartitions.size() / numConsumers + 1);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
index e63f033..8b87004 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
@@ -20,16 +20,16 @@ package org.apache.flink.streaming.connectors.kafka;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.util.SerializableObject;
 import org.apache.flink.metrics.MetricGroup;
-import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.runtime.state.OperatorStateStore;
+import org.apache.flink.runtime.util.SerializableObject;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.connectors.kafka.internals.metrics.KafkaMetricWrapper;
 import org.apache.flink.streaming.connectors.kafka.partitioner.KafkaPartitioner;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchema;
 import org.apache.flink.util.NetUtils;
-
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.Producer;
@@ -40,11 +40,9 @@ import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.serialization.ByteArraySerializer;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.Serializable;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
@@ -61,7 +59,7 @@ import static java.util.Objects.requireNonNull;
  *
  * @param <IN> Type of the messages to write into Kafka.
  */
-public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> implements Checkpointed<Serializable> {
+public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> implements CheckpointedFunction { // writes no operator state, see prepareSnapshot()
 
 	private static final Logger LOG = LoggerFactory.getLogger(FlinkKafkaProducerBase.class);
 
@@ -126,6 +124,9 @@ public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> im
 	/** Number of unacknowledged records. */
 	protected long pendingRecords;
 
+	/** Operator state store handed in by initializeState(); only set at runtime, hence transient */
+	protected transient OperatorStateStore stateStore;
+
 
 	/**
 	 * The main constructor for creating a FlinkKafkaProducer.
@@ -330,7 +330,12 @@ public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> im
 	protected abstract void flush();
 
 	@Override
-	public Serializable snapshotState(long checkpointId, long checkpointTimestamp) {
+	public void initializeState(OperatorStateStore stateStore) throws Exception {
+		this.stateStore = stateStore; // only kept; this class writes no operator state
+	}
+
+	@Override
+	public void prepareSnapshot(long checkpointId, long timestamp) throws Exception { // blocks until pending records are flushed if flushOnCheckpoint is set
 		if (flushOnCheckpoint) {
 			// flushing is activated: We need to wait until pendingRecords is 0
 			flush();
@@ -341,16 +346,8 @@ public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> im
 				// pending records count is 0. We can now confirm the checkpoint
 			}
 		}
-		// return empty state
-		return null;
-	}
-
-	@Override
-	public void restoreState(Serializable state) {
-		// nothing to do here
 	}
 
-
 	// ----------------------------------- Utilities --------------------------
 
 	protected void checkErroneous() throws Exception {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
index 9255445..7ce3a9d 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
@@ -183,9 +183,7 @@ public abstract class AbstractFetcher<T, KPH> {
 
 		HashMap<KafkaTopicPartition, Long> state = new HashMap<>(allPartitions.length);
 		for (KafkaTopicPartitionState<?> partition : subscribedPartitions()) {
-			if (partition.isOffsetDefined()) {
-				state.put(partition.getKafkaTopicPartition(), partition.getOffset());
-			}
+			state.put(partition.getKafkaTopicPartition(), partition.getOffset()); // also records partitions whose offset is not defined yet
 		}
 		return state;
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
index b02593c..766a107 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.flink.api.java.tuple.Tuple1;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.OperatorStateStore;
 import org.apache.flink.streaming.connectors.kafka.testutils.MockRuntimeContext;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchema;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchemaWrapper;
@@ -37,7 +38,6 @@ import org.junit.Test;
 import scala.concurrent.duration.Deadline;
 import scala.concurrent.duration.FiniteDuration;
 
-import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -45,6 +45,8 @@ import java.util.Properties;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import static org.mockito.Mockito.mock;
+
 /**
  * Test ensuring that the producer is not dropping buffered records
  */
@@ -111,7 +113,7 @@ public class AtLeastOnceProducerTest {
 		Thread threadB = new Thread(confirmer);
 		threadB.start();
 		// this should block:
-		producer.snapshotState(0, 0);
+		producer.prepareSnapshot(0, 0);
 		synchronized (threadA) {
 			threadA.notifyAll(); // just in case, to let the test fail faster
 		}
@@ -130,6 +132,8 @@ public class AtLeastOnceProducerTest {
 
 
 	private static class TestingKafkaProducer<T> extends FlinkKafkaProducerBase<T> {
+		private static final long serialVersionUID = -1759403646061180067L;
+
 		private MockProducer prod;
 		private AtomicBoolean snapshottingFinished;
 
@@ -145,12 +149,11 @@ public class AtLeastOnceProducerTest {
 		}
 
 		@Override
-		public Serializable snapshotState(long checkpointId, long checkpointTimestamp) {
+		public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
 			// call the actual snapshot state
-			Serializable ret = super.snapshotState(checkpointId, checkpointTimestamp);
+			super.prepareSnapshot(checkpointId, timestamp); // may block until the pending records are flushed
 			// notify test that snapshotting has been done
 			snapshottingFinished.set(true);
-			return ret;
 		}
 
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
index 9b517df..fc8b7e9 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
@@ -19,6 +19,11 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.commons.collections.map.LinkedMap;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.OperatorStateStore;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
@@ -26,15 +31,26 @@ import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
 import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
 import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
 import org.apache.flink.util.SerializedValue;
-
 import org.junit.Test;
+import org.mockito.Matchers;
 
 import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class FlinkKafkaConsumerBaseTest {
 
@@ -82,7 +98,15 @@ public class FlinkKafkaConsumerBaseTest {
 		final AbstractFetcher<String, ?> fetcher = mock(AbstractFetcher.class);
 
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(fetcher, new LinkedMap(), false);
-		assertNull(consumer.snapshotState(17L, 23L));
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
+		when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+		// hand the mocked store to the consumer, so the assertion below checks what prepareSnapshot() writes
+		consumer.initializeState(operatorStateStore);
+		consumer.prepareSnapshot(17L, 17L);
+
+		assertFalse(listState.get().iterator().hasNext());
 		consumer.notifyCheckpointComplete(66L);
 	}
 
@@ -91,14 +113,37 @@ public class FlinkKafkaConsumerBaseTest {
 	 */
 	@Test
 	public void checkRestoredCheckpointWhenFetcherNotReady() throws Exception {
-		HashMap<KafkaTopicPartition, Long> restoreState = new HashMap<>();
-		restoreState.put(new KafkaTopicPartition("abc", 13), 16768L);
-		restoreState.put(new KafkaTopicPartition("def", 7), 987654321L);
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> expectedState = new TestingListState<>();
+		expectedState.add(Tuple2.of(new KafkaTopicPartition("abc", 13), 16768L));
+		expectedState.add(Tuple2.of(new KafkaTopicPartition("def", 7), 987654321L));
+
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
 
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new LinkedMap(), true);
-		consumer.restoreState(restoreState);
-		
-		assertEquals(restoreState, consumer.snapshotState(17L, 23L));
+
+		when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(expectedState);
+		consumer.initializeState(operatorStateStore);
+
+		when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState); // the snapshot is written into this fresh, empty list state
+
+		consumer.prepareSnapshot(17L, 17L);
+
+		Set<Tuple2<KafkaTopicPartition, Long>> expected = new HashSet<Tuple2<KafkaTopicPartition, Long>>();
+
+		for (Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 : expectedState.get()) {
+			expected.add(kafkaTopicPartitionLongTuple2);
+		}
+
+		int counter = 0;
+
+		for (Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 : listState.get()) {
+			assertTrue(expected.contains(kafkaTopicPartitionLongTuple2));
+			counter++;
+		}
+
+		assertEquals(expected.size(), counter);
 	}
 
 	/**
@@ -107,7 +152,15 @@ public class FlinkKafkaConsumerBaseTest {
 	@Test
 	public void checkRestoredNullCheckpointWhenFetcherNotReady() throws Exception {
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new LinkedMap(), true);
-		assertNull(consumer.snapshotState(17L, 23L));
+
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
+		when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+		consumer.initializeState(operatorStateStore);
+		consumer.prepareSnapshot(17L, 17L);
+
+		assertFalse(listState.get().iterator().hasNext());
 	}
 	
 	@Test
@@ -132,15 +185,40 @@ public class FlinkKafkaConsumerBaseTest {
 	
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(fetcher, pendingCheckpoints, true);
 		assertEquals(0, pendingCheckpoints.size());
-		
+
+		OperatorStateStore backend = mock(OperatorStateStore.class);
+
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState1 = new TestingListState<>();
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState2 = new TestingListState<>();
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState3 = new TestingListState<>();
+
+		when(backend.getPartitionableState(Matchers.any(ListStateDescriptor.class))).
+				thenReturn(listState1, listState1, listState2, listState2, listState3, listState3);
+
+		consumer.initializeState(backend);
+
 		// checkpoint 1
-		HashMap<KafkaTopicPartition, Long> snapshot1 = consumer.snapshotState(138L, 19L);
+		consumer.prepareSnapshot(138L, 138L);
+
+		HashMap<KafkaTopicPartition, Long> snapshot1 = new HashMap<>();
+
+		for (Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 : listState1.get()) {
+			snapshot1.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
+		}
+
 		assertEquals(state1, snapshot1);
 		assertEquals(1, pendingCheckpoints.size());
 		assertEquals(state1, pendingCheckpoints.get(138L));
 
 		// checkpoint 2
-		HashMap<KafkaTopicPartition, Long> snapshot2 = consumer.snapshotState(140L, 1578L);
+		consumer.prepareSnapshot(140L, 140L);
+
+		HashMap<KafkaTopicPartition, Long> snapshot2 = new HashMap<>();
+
+		for (Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 : listState2.get()) {
+			snapshot2.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
+		}
+
 		assertEquals(state2, snapshot2);
 		assertEquals(2, pendingCheckpoints.size());
 		assertEquals(state2, pendingCheckpoints.get(140L));
@@ -151,7 +229,14 @@ public class FlinkKafkaConsumerBaseTest {
 		assertTrue(pendingCheckpoints.containsKey(140L));
 
 		// checkpoint 3
-		HashMap<KafkaTopicPartition, Long> snapshot3 = consumer.snapshotState(141L, 1578L);
+		consumer.prepareSnapshot(141L, 141L);
+
+		HashMap<KafkaTopicPartition, Long> snapshot3 = new HashMap<>();
+		// third snapshot: read it back from listState3 (per the thenReturn sequence above) into snapshot3
+		for (Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 : listState3.get()) {
+			snapshot3.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
+		}
+
 		assertEquals(state3, snapshot3);
 		assertEquals(2, pendingCheckpoints.size());
 		assertEquals(state3, pendingCheckpoints.get(141L));
@@ -164,9 +249,14 @@ public class FlinkKafkaConsumerBaseTest {
 		consumer.notifyCheckpointComplete(666); // invalid checkpoint
 		assertEquals(0, pendingCheckpoints.size());
 
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
+		when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
 		// create 500 snapshots
 		for (int i = 100; i < 600; i++) {
-			consumer.snapshotState(i, 15 * i);
+			consumer.prepareSnapshot(i, i);
+			listState.clear();
 		}
 		assertEquals(FlinkKafkaConsumerBase.MAX_NUM_PENDING_CHECKPOINTS, pendingCheckpoints.size());
 
@@ -211,12 +301,37 @@ public class FlinkKafkaConsumerBaseTest {
 
 		@SuppressWarnings("unchecked")
 		public DummyFlinkKafkaConsumer() {
-			super((KeyedDeserializationSchema<T>) mock(KeyedDeserializationSchema.class));
+			super(Arrays.asList("abc", "def"), (KeyedDeserializationSchema<T>) mock(KeyedDeserializationSchema.class));
 		}
 
 		@Override
 		protected AbstractFetcher<T, ?> createFetcher(SourceContext<T> sourceContext, List<KafkaTopicPartition> thisSubtaskPartitions, SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic, SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated, StreamingRuntimeContext runtimeContext) throws Exception {
 			return null;
 		}
+
+		@Override
+		protected List<KafkaTopicPartition> getKafkaPartitions(List<String> topics) {
+			return Collections.emptyList();
+		}
+	}
+
+	private static final class TestingListState<T> implements ListState<T> {
+
+		private final List<T> list = new ArrayList<>();
+
+		@Override
+		public void clear() {
+			list.clear();
+		}
+
+		@Override
+		public Iterable<T> get() throws Exception {
+			return list;
+		}
+
+		@Override
+		public void add(T value) throws Exception {
+			list.add(value);
+		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
index a87ff8a..9c36b43 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
@@ -68,7 +68,6 @@ import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
-import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -92,7 +91,6 @@ import org.apache.flink.test.util.SuccessException;
 import org.apache.flink.testutils.junit.RetryOnException;
 import org.apache.flink.testutils.junit.RetryRule;
 import org.apache.flink.util.Collector;
-import org.apache.flink.util.StringUtils;
 import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.junit.Assert;
@@ -186,15 +184,27 @@ public abstract class KafkaConsumerTestBase extends KafkaTestBase {
 			DataStream<String> stream = see.addSource(source);
 			stream.print();
 			see.execute("No broker test");
-		} catch(RuntimeException re) {
+		} catch(ProgramInvocationException pie) {
 			if(kafkaServer.getVersion().equals("0.9")) {
-				Assert.assertTrue("Wrong RuntimeException thrown: " + StringUtils.stringifyException(re),
-						re.getClass().equals(TimeoutException.class) &&
-								re.getMessage().contains("Timeout expired while fetching topic metadata"));
+				assertTrue(pie.getCause() instanceof JobExecutionException);
+
+				JobExecutionException jee = (JobExecutionException) pie.getCause();
+
+				assertTrue(jee.getCause() instanceof TimeoutException);
+
+				TimeoutException te = (TimeoutException) jee.getCause();
+
+				assertEquals("Timeout expired while fetching topic metadata", te.getMessage());
 			} else {
-				Assert.assertTrue("Wrong RuntimeException thrown: " + StringUtils.stringifyException(re),
-						re.getClass().equals(RuntimeException.class) &&
-								re.getMessage().contains("Unable to retrieve any partitions for the requested topics [doesntexist]"));
+				assertTrue(pie.getCause() instanceof JobExecutionException);
+
+				JobExecutionException jee = (JobExecutionException) pie.getCause();
+
+				assertTrue(jee.getCause() instanceof RuntimeException);
+
+				RuntimeException re = (RuntimeException) jee.getCause();
+
+				assertTrue(re.getMessage().contains("Unable to retrieve any partitions for the requested topics [doesntexist]"));
 			}
 		}
 	}
@@ -413,7 +423,7 @@ public abstract class KafkaConsumerTestBase extends KafkaTestBase {
 		DataGenerators.generateRandomizedIntegerSequence(
 				StreamExecutionEnvironment.createRemoteEnvironment("localhost", flinkPort),
 				kafkaServer,
-				topic, numPartitions, numElementsPerPartition, true);
+				topic, numPartitions, numElementsPerPartition, false);
 
 		// run the topology that fails and recovers
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
index da2c652..5be4195 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
@@ -173,16 +173,6 @@ public class MockRuntimeContext extends StreamingRuntimeContext {
 	}
 
 	@Override
-	public <S> org.apache.flink.api.common.state.OperatorState<S> getKeyValueState(String name, Class<S> stateType, S defaultState) {
-		throw new UnsupportedOperationException();
-	}
-
-	@Override
-	public <S> org.apache.flink.api.common.state.OperatorState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) {
-		throw new UnsupportedOperationException();
-	}
-
-	@Override
 	public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) {
 		throw new UnsupportedOperationException();
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
index 6e2850c..4a0fd60 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
@@ -36,6 +36,7 @@ import java.io.Serializable;
  * 
  * @param <T> The type of the operator state.
  */
+@Deprecated
 @PublicEvolving
 public interface Checkpointed<T extends Serializable> {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
new file mode 100644
index 0000000..2227201
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.checkpoint;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.runtime.state.OperatorStateStore;
+
+/**
+ * Interface for functions with potentially repartitionable state that needs to be checkpointed.
+ *
+ * <p>Similar to {@link Checkpointed}, this interface must be implemented by functions that have such state.
+ * Methods from this interface are called upon checkpointing and restoring of state.
+ *
+ * <p>In {@link #initializeState} the implementing class receives the {@link OperatorStateStore}
+ * to store its state. At the latest before each snapshot, all persistent state must be stored in the state store.
+ *
+ * <p>When the state store is received for initialization, the user registers states with it via a
+ * {@link org.apache.flink.api.common.state.StateDescriptor}. Then, all previously stored state is found in the
+ * returned {@link org.apache.flink.api.common.state.State} (currently only
+ * {@link org.apache.flink.api.common.state.ListState} is supported).
+ *
+ * <p>In {@link #prepareSnapshot}, the implementing class must ensure that all operator state is passed to the
+ * state store, i.e. that the state was stored in the relevant {@link org.apache.flink.api.common.state.State}
+ * instances that are requested on restore. Note that users might want to clear and reinsert the complete state
+ * first if incremental updates of the states are not possible.
+ */
+@PublicEvolving
+public interface CheckpointedFunction {
+
+	/**
+	 * Called when state should be stored for a checkpoint. The state can be registered with and written to
+	 * the state store that was handed to {@link #initializeState}.
+	 *
+	 * @param checkpointId Id of the checkpoint to perform
+	 * @param timestamp Timestamp of the checkpoint
+	 * @throws Exception Thrown if the state cannot be written to the state store; presumably this fails
+	 *                   the checkpoint attempt (TODO confirm against the runtime).
+	 */
+	void prepareSnapshot(long checkpointId, long timestamp) throws Exception;
+
+	/**
+	 * Called when an operator is opened, so that the function can obtain the state store to which it
+	 * hands its state on snapshot.
+	 *
+	 * @param stateStore the state store to which this function stores its state
+	 * @throws Exception Thrown if the state cannot be initialized or restored from the state store.
+	 */
+	void initializeState(OperatorStateStore stateStore) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
new file mode 100644
index 0000000..430b2b9
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.checkpoint;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.java.typeutils.runtime.JavaSerializer;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * This interface must be implemented by functions that have state that needs to be
+ * checkpointed. The functions get a call whenever a checkpoint should take place
+ * and return a snapshot of their state as a list of redistributable sub-states,
+ * which will be checkpointed.
+ *
+ * @param <T> The type of the operator state.
+ */
+@PublicEvolving
+public interface ListCheckpointed<T extends Serializable> {
+	/** Default list state descriptor: empty name, elements serialized with {@link JavaSerializer}. */
+	ListStateDescriptor<Serializable> DEFAULT_LIST_DESCRIPTOR =
+			new ListStateDescriptor<>("", new JavaSerializer<>());
+
+	/**
+	 * Gets the current state of the function or operator. The state must reflect the result of all
+	 * prior invocations to this function.
+	 *
+	 * @param checkpointId The ID of the checkpoint.
+	 * @param timestamp Timestamp of the checkpoint.
+	 * @return The operator state in a list of redistributable, atomic sub-states.
+	 * @throws Exception Thrown if the creation of the state object failed. This causes the
+	 *                   checkpoint to fail. The system may decide to fail the operation (and trigger
+	 *                   recovery), or to discard this checkpoint attempt and to continue running
+	 *                   and to try again with the next checkpoint attempt.
+	 */
+	List<T> snapshotState(long checkpointId, long timestamp) throws Exception;
+
+	/**
+	 * Restores the state of the function or operator to that of a previous checkpoint.
+	 * This method is invoked when a function is executed as part of a recovery run.
+	 * <p>
+	 * Note that restoreState() is called before open().
+	 *
+	 * @param state The state to be restored as a list of atomic sub-states.
+	 */
+	void restoreState(List<T> state) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
index 0c0b81a..838bee6 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
@@ -28,6 +28,7 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileInputSplit;
 import org.apache.flink.metrics.Counter;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.OutputTypeConfigurable;
@@ -60,7 +61,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  */
 @Internal
 public class ContinuousFileReaderOperator<OUT, S extends Serializable> extends AbstractStreamOperator<OUT>
-	implements OneInputStreamOperator<FileInputSplit, OUT>, OutputTypeConfigurable<OUT> {
+	implements OneInputStreamOperator<FileInputSplit, OUT>, OutputTypeConfigurable<OUT>, StreamCheckpointedOperator {
 
 	private static final long serialVersionUID = 1L;
 
@@ -374,7 +375,6 @@ public class ContinuousFileReaderOperator<OUT, S extends Serializable> extends A
 
 	@Override
 	public void snapshotState(FSDataOutputStream os, long checkpointId, long timestamp) throws Exception {
-		super.snapshotState(os, checkpointId, timestamp);
 
 		final ObjectOutputStream oos = new ObjectOutputStream(os);
 
@@ -397,7 +397,6 @@ public class ContinuousFileReaderOperator<OUT, S extends Serializable> extends A
 
 	@Override
 	public void restoreState(FSDataInputStream is) throws Exception {
-		super.restoreState(is);
 
 		final ObjectInputStream ois = new ObjectInputStream(is);