Posted to commits@flink.apache.org by se...@apache.org on 2016/09/30 12:47:51 UTC

[01/10] flink git commit: [FLINK-4573] [web dashboard] Fix potential resource leak due to unclosed RandomAccessFile in TaskManagerLogHandler

Repository: flink
Updated Branches:
  refs/heads/master 477d1c5d4 -> 92f4539af


[FLINK-4573] [web dashboard] Fix potential resource leak due to unclosed RandomAccessFile in TaskManagerLogHandler

This closes #2556


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/2afc0924
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/2afc0924
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/2afc0924

Branch: refs/heads/master
Commit: 2afc092461cf68cf0f3c26a3ab4c58a7bd68cf71
Parents: 477d1c5
Author: Liwei Lin <lw...@gmail.com>
Authored: Tue Sep 27 20:49:52 2016 +0800
Committer: Stephan Ewen <se...@apache.org>
Committed: Fri Sep 30 11:32:39 2016 +0200

----------------------------------------------------------------------
 .../webmonitor/handlers/TaskManagerLogHandler.java        | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/2afc0924/flink-runtime-web/src/main/java/org/apache/flink/runtime/webmonitor/handlers/TaskManagerLogHandler.java
----------------------------------------------------------------------
diff --git a/flink-runtime-web/src/main/java/org/apache/flink/runtime/webmonitor/handlers/TaskManagerLogHandler.java b/flink-runtime-web/src/main/java/org/apache/flink/runtime/webmonitor/handlers/TaskManagerLogHandler.java
index 5343049..2f0d438 100644
--- a/flink-runtime-web/src/main/java/org/apache/flink/runtime/webmonitor/handlers/TaskManagerLogHandler.java
+++ b/flink-runtime-web/src/main/java/org/apache/flink/runtime/webmonitor/handlers/TaskManagerLogHandler.java
@@ -210,7 +210,15 @@ public class TaskManagerLogHandler extends RuntimeMonitorHandlerBase {
 							LOG.error("Displaying TaskManager log failed.", e);
 							return;
 						}
-						long fileLength = raf.length();
+						long fileLength;
+						try {
+							fileLength = raf.length();
+						} catch (IOException ioe) {
+							display(ctx, request, "Displaying TaskManager log failed.");
+							LOG.error("Displaying TaskManager log failed.", ioe);
+							raf.close();
+							throw ioe;
+						}
 						final FileChannel fc = raf.getChannel();
 
 						HttpResponse response = new DefaultHttpResponse(HTTP_1_1, OK);
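
For context on the fix above: once the RandomAccessFile is open, any exception thrown
before its channel is handed off must close it, or the file descriptor leaks. A minimal
standalone sketch of the same close-on-failure idiom (the openForDownload helper is
hypothetical, not the handler code):

import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.channels.FileChannel;

public class CloseOnFailure {

	// Opens the file and returns its channel. If anything fails between
	// opening the file and handing the channel to the caller, the file is
	// closed before the exception propagates, so no descriptor leaks.
	static FileChannel openForDownload(File file) throws IOException {
		RandomAccessFile raf = new RandomAccessFile(file, "r");
		try {
			long fileLength = raf.length(); // may throw IOException
			System.out.println("serving " + fileLength + " bytes");
			return raf.getChannel();        // ownership passes to the caller
		} catch (IOException e) {
			raf.close();                    // the guard this commit adds
			throw e;
		}
	}
}

Try-with-resources does not fit the handler itself because the channel outlives the
method (it backs the HTTP response that is written out later), which is why the commit
closes explicitly on the error path instead.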


[08/10] flink git commit: [FLINK-4379] [checkpoints] Introduce rescalable operator state

Posted by se...@apache.org.
[FLINK-4379] [checkpoints] Introduce rescalable operator state

This introduces the Operator State Backend, which stores state that is not partitioned
by a key. It replaces the 'Checkpointed' interface.

Additionally, this introduces CheckpointStateHandles as a container for all checkpoint-related state handles.

This closes #2512
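
To illustrate the user-facing change: with the old Checkpointed interface, operator state
was an opaque blob tied to a fixed operator instance; the interfaces added here
(CheckpointedFunction, ListCheckpointed) expose it as a list of elements that the runtime
can redistribute when the parallelism changes (round-robin, via the
RoundRobinOperatorStateRepartitioner added in this commit). A hedged sketch of a source
using ListCheckpointed, assuming its signatures as added here (snapshotState(checkpointId,
timestamp) returning a List, restoreState taking a List); the CountingSource class itself
is illustrative only:

import java.util.Collections;
import java.util.List;

import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.functions.source.SourceFunction;

// Emits an ever-increasing counter. Its state is exposed as a list,
// so the runtime can split or merge it across subtasks on rescaling.
public class CountingSource implements SourceFunction<Long>, ListCheckpointed<Long> {

	private static final long serialVersionUID = 1L;

	private volatile boolean running = true;
	private long count;

	@Override
	public void run(SourceContext<Long> ctx) throws Exception {
		while (running) {
			synchronized (ctx.getCheckpointLock()) {
				ctx.collect(count++);
			}
		}
	}

	@Override
	public void cancel() {
		running = false;
	}

	@Override
	public List<Long> snapshotState(long checkpointId, long timestamp) {
		// One element per subtask; elements are redistributed on rescale.
		return Collections.singletonList(count);
	}

	@Override
	public void restoreState(List<Long> state) {
		// After rescaling, a subtask may receive zero, one, or several elements.
		for (Long c : state) {
			count += c;
		}
	}
}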


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/53ed6ada
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/53ed6ada
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/53ed6ada

Branch: refs/heads/master
Commit: 53ed6adac8cbe6b5dcb692dc9b94970f3ec5887c
Parents: 2afc092
Author: Stefan Richter <s....@data-artisans.com>
Authored: Wed Aug 31 23:59:27 2016 +0200
Committer: Stephan Ewen <se...@apache.org>
Committed: Fri Sep 30 12:38:46 2016 +0200

----------------------------------------------------------------------
 .../streaming/state/AbstractRocksDBState.java   |   6 +-
 .../state/RocksDBKeyedStateBackend.java         |  75 ++--
 .../streaming/state/RocksDBStateBackend.java    |   8 +-
 .../state/RocksDBAsyncSnapshotTest.java         |  12 +-
 .../state/RocksDBStateBackendConfigTest.java    |  48 ++-
 .../api/common/functions/RuntimeContext.java    | 125 +-----
 .../util/AbstractRuntimeUDFContext.java         |  28 +-
 .../flink/api/common/state/OperatorState.java   |  70 ---
 .../flink/api/common/state/ValueState.java      |   2 +-
 .../java/typeutils/runtime/JavaSerializer.java  | 116 +++++
 .../flink/hdfstests/FileStateBackendTest.java   |  26 +-
 .../AbstractCEPBasePatternOperator.java         |   3 +-
 .../operator/AbstractCEPPatternOperator.java    |   2 -
 .../AbstractKeyedCEPPatternOperator.java        |   2 -
 .../checkpoint/CheckpointCoordinator.java       | 127 +++++-
 .../runtime/checkpoint/CompletedCheckpoint.java |   5 -
 .../checkpoint/OperatorStateRepartitioner.java  |  42 ++
 .../runtime/checkpoint/PendingCheckpoint.java   |  95 +++--
 .../RoundRobinOperatorStateRepartitioner.java   | 190 +++++++++
 .../flink/runtime/checkpoint/SubtaskState.java  |   9 -
 .../flink/runtime/checkpoint/TaskState.java     |  79 +++-
 .../savepoint/SavepointV1Serializer.java        |  97 ++++-
 .../deployment/TaskDeploymentDescriptor.java    |  50 ++-
 .../flink/runtime/execution/Environment.java    |  16 +-
 .../flink/runtime/executiongraph/Execution.java |  25 +-
 .../runtime/executiongraph/ExecutionGraph.java  |   2 -
 .../runtime/executiongraph/ExecutionVertex.java |   6 +-
 .../runtime/jobgraph/tasks/StatefulTask.java    |  11 +-
 .../checkpoint/AcknowledgeCheckpoint.java       |  67 ++-
 .../runtime/state/AbstractCloseableHandle.java  | 126 ------
 .../state/AbstractKeyedStateBackend.java        | 342 +++++++++++++++
 .../runtime/state/AbstractStateBackend.java     |  43 +-
 .../flink/runtime/state/ChainedStateHandle.java |   7 +-
 .../runtime/state/CheckpointStateHandles.java   | 103 +++++
 .../flink/runtime/state/ClosableRegistry.java   |  84 ++++
 .../state/DefaultOperatorStateBackend.java      | 215 ++++++++++
 .../runtime/state/KeyGroupRangeOffsets.java     |   2 +
 .../runtime/state/KeyGroupsStateHandle.java     |   6 -
 .../flink/runtime/state/KeyedStateBackend.java  | 301 ++-----------
 .../runtime/state/OperatorStateBackend.java     |  35 ++
 .../runtime/state/OperatorStateHandle.java      | 109 +++++
 .../flink/runtime/state/OperatorStateStore.java |  47 +++
 ...artitionableCheckpointStateOutputStream.java |  96 +++++
 .../state/RetrievableStreamStateHandle.java     |   2 +-
 .../flink/runtime/state/SnapshotProvider.java   |  45 ++
 .../apache/flink/runtime/state/StateObject.java |   6 +-
 .../apache/flink/runtime/state/StateUtil.java   |  37 --
 .../state/filesystem/FileStateHandle.java       |   8 +-
 .../state/filesystem/FsStateBackend.java        |   6 +-
 .../state/heap/HeapKeyedStateBackend.java       | 210 ++++-----
 .../state/memory/ByteStreamStateHandle.java     |  13 +-
 .../state/memory/MemoryStateBackend.java        |   9 +-
 .../ActorGatewayCheckpointResponder.java        |  11 +-
 .../taskmanager/CheckpointResponder.java        |  15 +-
 .../runtime/taskmanager/RuntimeEnvironment.java |  12 +-
 .../apache/flink/runtime/taskmanager/Task.java  |  11 +-
 .../checkpoint/CheckpointCoordinatorTest.java   | 421 +++++++++++++++----
 .../checkpoint/CheckpointStateRestoreTest.java  |  46 +-
 .../CompletedCheckpointStoreTest.java           |   2 +-
 .../checkpoint/PendingCheckpointTest.java       |   2 +-
 .../checkpoint/PendingSavepointTest.java        |   2 +-
 ...ZooKeeperCompletedCheckpointStoreITCase.java |   5 -
 .../checkpoint/savepoint/SavepointV1Test.java   |  20 +-
 .../stats/SimpleCheckpointStatsTrackerTest.java |   2 +-
 .../jobmanager/JobManagerHARecoveryTest.java    |  20 +-
 .../messages/CheckpointMessagesTest.java        |  17 +-
 .../operators/testutils/DummyEnvironment.java   |   3 +-
 .../operators/testutils/MockEnvironment.java    |   3 +-
 .../runtime/query/QueryableStateClientTest.java |   4 +-
 .../runtime/query/netty/KvStateClientTest.java  |   5 +-
 .../query/netty/KvStateServerHandlerTest.java   |   7 +-
 .../runtime/query/netty/KvStateServerTest.java  |   4 +-
 .../state/AbstractCloseableHandleTest.java      |  97 -----
 .../runtime/state/FileStateBackendTest.java     |  35 +-
 .../runtime/state/MemoryStateBackendTest.java   |  15 +-
 .../runtime/state/OperatorStateBackendTest.java | 155 +++++++
 .../runtime/state/StateBackendTestBase.java     | 115 ++---
 .../FsCheckpointStateOutputStreamTest.java      |  16 +-
 .../runtime/taskmanager/TaskAsyncCallTest.java  |   5 +-
 .../ZooKeeperStateHandleStoreITCase.java        |   4 -
 .../connectors/kafka/FlinkKafkaConsumer08.java  |  37 +-
 .../connectors/kafka/KafkaConsumer08Test.java   |   4 +-
 .../connectors/kafka/FlinkKafkaConsumer09.java  |  43 +-
 .../kafka/FlinkKafkaConsumerBase.java           | 182 +++++---
 .../kafka/FlinkKafkaProducerBase.java           |  27 +-
 .../kafka/internals/AbstractFetcher.java        |   4 +-
 .../kafka/AtLeastOnceProducerTest.java          |  13 +-
 .../kafka/FlinkKafkaConsumerBaseTest.java       | 149 ++++++-
 .../connectors/kafka/KafkaConsumerTestBase.java |  30 +-
 .../kafka/testutils/MockRuntimeContext.java     |  10 -
 .../streaming/api/checkpoint/Checkpointed.java  |   1 +
 .../api/checkpoint/CheckpointedFunction.java    |  65 +++
 .../api/checkpoint/ListCheckpointed.java        |  65 +++
 .../source/ContinuousFileReaderOperator.java    |   5 +-
 .../api/operators/AbstractStreamOperator.java   |  96 ++++-
 .../operators/AbstractUdfStreamOperator.java    |  65 ++-
 .../operators/StreamCheckpointedOperator.java   |  58 +++
 .../streaming/api/operators/StreamOperator.java |  43 +-
 .../api/operators/StreamingRuntimeContext.java  |  32 --
 .../operators/GenericWriteAheadSink.java        |  25 +-
 .../windowing/EvictingWindowOperator.java       |   4 +-
 .../operators/windowing/WindowOperator.java     |  14 +-
 .../streaming/runtime/tasks/OperatorChain.java  |  50 ++-
 .../streaming/runtime/tasks/StreamTask.java     | 314 ++++++++------
 .../operators/StreamingRuntimeContextTest.java  |   8 +-
 .../streaming/runtime/io/BarrierBufferTest.java |   8 +-
 .../runtime/io/BarrierTrackerTest.java          |  13 +-
 .../operators/StreamOperatorChainingTest.java   |  15 +-
 .../tasks/InterruptSensitiveRestoreTest.java    |  68 +--
 .../runtime/tasks/OneInputStreamTaskTest.java   |  82 +++-
 .../runtime/tasks/StreamMockEnvironment.java    |   3 +-
 .../KeyedOneInputStreamOperatorTestHarness.java |  37 +-
 .../util/OneInputStreamOperatorTestHarness.java |  24 +-
 .../UdfStreamOperatorCheckpointingITCase.java   |  16 +-
 .../streaming/runtime/StateBackendITCase.java   |  11 +-
 115 files changed, 3981 insertions(+), 1890 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
index e878ad5..9da33ef 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
@@ -156,7 +156,7 @@ public abstract class AbstractRocksDBState<K, N, S extends State, SD extends Sta
 
 	private void writeKey(K key) throws IOException {
 		//write key
-		int beforeWrite = (int) keySerializationStream.getPosition();
+		int beforeWrite = keySerializationStream.getPosition();
 		backend.getKeySerializer().serialize(key, keySerializationDateDataOutputView);
 
 		if (ambiguousKeyPossible) {
@@ -166,7 +166,7 @@ public abstract class AbstractRocksDBState<K, N, S extends State, SD extends Sta
 	}
 
 	private void writeNameSpace(N namespace) throws IOException {
-		int beforeWrite = (int) keySerializationStream.getPosition();
+		int beforeWrite = keySerializationStream.getPosition();
 		namespaceSerializer.serialize(namespace, keySerializationDateDataOutputView);
 
 		if (ambiguousKeyPossible) {
@@ -176,7 +176,7 @@ public abstract class AbstractRocksDBState<K, N, S extends State, SD extends Sta
 	}
 
 	private void writeLengthFrom(int fromPosition) throws IOException {
-		int length = (int) (keySerializationStream.getPosition() - fromPosition);
+		int length = keySerializationStream.getPosition() - fromPosition;
 		writeVariableIntBytes(length);
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index d5a96af..126ebd2 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -39,12 +39,12 @@ import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.io.async.AbstractAsyncIOCallable;
 import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.util.SerializableObject;
 import org.apache.flink.util.InstantiationUtil;
@@ -73,12 +73,12 @@ import java.util.PriorityQueue;
 import java.util.concurrent.RunnableFuture;
 
 /**
- * A {@link KeyedStateBackend} that stores its state in {@code RocksDB} and will serialize state to
+ * An {@link AbstractKeyedStateBackend} that stores its state in {@code RocksDB} and will serialize state to
  * streams provided by a {@link org.apache.flink.runtime.state.CheckpointStreamFactory} upon
  * checkpointing. This state backend can store very large state that exceeds memory and spills
- * to disk.
+ * to disk. Except for snapshotting, this class should not be considered thread-safe.
  */
-public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
+public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 	private static final Logger LOG = LoggerFactory.getLogger(RocksDBKeyedStateBackend.class);
 
@@ -98,9 +98,9 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 	private final File instanceRocksDBPath;
 
 	/**
-	 * Lock for protecting cleanup of the RocksDB db. We acquire this when doing asynchronous
-	 * checkpoints and when disposing the db. Otherwise, the asynchronous snapshot might try
-	 * iterating over a disposed db.
+	 * Lock for protecting cleanup of the RocksDB against the checkpointing thread. We acquire this when doing
+	 * asynchronous checkpoints and when disposing the DB. Otherwise, the asynchronous snapshot might try
+	 * iterating over a disposed DB. After acquiring the lock, always first check if (db == null).
 	 */
 	private final SerializableObject dbDisposeLock = new SerializableObject();
 
@@ -110,13 +110,13 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 	 * instance. They all write to this instance but to their own column family.
 	 */
 	@GuardedBy("dbDisposeLock")
-	protected volatile RocksDB db;
+	protected RocksDB db;
 
 	/**
 	 * Information about the k/v states as we create them. This is used to retrieve the
 	 * column family that is used for a state and also for sanity checks when restoring.
 	 */
-	private Map<String, Tuple2<ColumnFamilyHandle, StateDescriptor>> kvStateInformation;
+	private Map<String, Tuple2<ColumnFamilyHandle, StateDescriptor<?, ?>>> kvStateInformation;
 
 	/** Number of bytes required to prefix the key groups. */
 	private final int keyGroupPrefixBytes;
@@ -187,8 +187,8 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 			KeyGroupRange keyGroupRange,
 			List<KeyGroupsStateHandle> restoreState
 	) throws Exception {
-		this(
-			jobId,
+
+		this(jobId,
 			operatorIdentifier,
 			userCodeClassLoader,
 			instanceBasePath,
@@ -210,15 +210,11 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 	}
 
 	/**
-	 * @see java.io.Closeable
-	 *
-	 * Should only be called by one thread.
-	 *
-	 * @throws Exception
+	 * Should only be called by one thread, and only after all accesses to the DB happened.
 	 */
 	@Override
-	public void close() throws Exception {
-		super.close();
+	public void dispose() {
+		super.dispose();
 
 		final RocksDB cleanupRockDBReference;
 
@@ -233,13 +229,17 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 
 		// Dispose decoupled db
 		if (cleanupRockDBReference != null) {
-			for (Tuple2<ColumnFamilyHandle, StateDescriptor> column : kvStateInformation.values()) {
+			for (Tuple2<ColumnFamilyHandle, StateDescriptor<?, ?>> column : kvStateInformation.values()) {
 				column.f0.dispose();
 			}
 			cleanupRockDBReference.dispose();
 		}
 
-		FileUtils.deleteDirectory(instanceBasePath);
+		try {
+			FileUtils.deleteDirectory(instanceBasePath);
+		} catch (IOException ioex) {
+			LOG.info("Could not delete instance base path for RocksDB: " + instanceBasePath);
+		}
 	}
 
 	public int getKeyGroupPrefixBytes() {
@@ -248,7 +248,7 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 
 	/**
 	 * Triggers an asynchronous snapshot of the keyed state backend from RocksDB. This snapshot can be canceled and
-	 * is also stopped when the backend is closed through {@link #close()}. For each backend, this method must always
+	 * is also stopped when the backend is closed through {@link #dispose()}. For each backend, this method must always
 	 * be called by the same thread.
 	 *
 	 * @param checkpointId The Id of the checkpoint.
@@ -386,13 +386,13 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 			Preconditions.checkArgument(outStream == null, "Output stream for snapshot is already set.");
 			outStream = checkpointStreamFactory.
 					createCheckpointStateOutputStream(checkpointId, checkpointTimeStamp);
+			stateBackend.cancelStreamRegistry.registerClosable(outStream);
 			outputView = new DataOutputViewStreamWrapper(outStream);
 		}
 
 		/**
 		 * 3) Write the actual data from RocksDB from the time we took the snapshot object in (1).
 		 *
-		 * @return
 		 * @throws IOException
 		 */
 		public void writeDBSnapshot() throws IOException, InterruptedException {
@@ -408,7 +408,8 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 		 * @throws IOException
 		 */
 		public void closeCheckpointStream() throws IOException {
-			if(outStream != null) {
+			if (outStream != null) {
+				stateBackend.cancelStreamRegistry.unregisterClosable(outStream);
 				snapshotResultStateHandle = closeSnapshotStreamAndGetHandle();
 			}
 		}
@@ -451,7 +452,7 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 
 			int kvStateId = 0;
 			//iterate all column families, where each column family holds one k/v state, to write the metadata
-			for (Map.Entry<String, Tuple2<ColumnFamilyHandle, StateDescriptor>> column : stateBackend.kvStateInformation.entrySet()) {
+			for (Map.Entry<String, Tuple2<ColumnFamilyHandle, StateDescriptor<?, ?>>> column : stateBackend.kvStateInformation.entrySet()) {
 
 				//be cooperative and check for interruption from time to time in the hot loop
 				checkInterrupted();
@@ -463,7 +464,7 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 				ReadOptions readOptions = new ReadOptions();
 				readOptions.setSnapshot(snapshot);
 				RocksIterator iterator = stateBackend.db.newIterator(column.getValue().f0, readOptions);
-				kvStateIterators.add(new Tuple2<RocksIterator, Integer>(iterator, kvStateId));
+				kvStateIterators.add(new Tuple2<>(iterator, kvStateId));
 				++kvStateId;
 			}
 		}
@@ -624,15 +625,16 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 				throws IOException, RocksDBException, ClassNotFoundException {
 			try {
 				currentStateHandleInStream = currentKeyGroupsStateHandle.getStateHandle().openInputStream();
+				rocksDBKeyedStateBackend.cancelStreamRegistry.registerClosable(currentStateHandleInStream);
 				currentStateHandleInView = new DataInputViewStreamWrapper(currentStateHandleInStream);
 				restoreKVStateMetaData();
 				restoreKVStateData();
 			} finally {
-				if(currentStateHandleInStream != null) {
+				if (currentStateHandleInStream != null) {
+					rocksDBKeyedStateBackend.cancelStreamRegistry.unregisterClosable(currentStateHandleInStream);
 					currentStateHandleInStream.close();
 				}
 			}
-
 		}
 
 		/**
@@ -652,19 +654,20 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 			//restore the empty columns for the k/v states through the metadata
 			for (int i = 0; i < numColumns; i++) {
 
-				StateDescriptor stateDescriptor = InstantiationUtil.deserializeObject(
+				StateDescriptor<?, ?> stateDescriptor = (StateDescriptor<?, ?>) InstantiationUtil.deserializeObject(
 						currentStateHandleInStream,
 						rocksDBKeyedStateBackend.userCodeClassLoader);
 
-				Tuple2<ColumnFamilyHandle, StateDescriptor> columnFamily = rocksDBKeyedStateBackend.
+				Tuple2<ColumnFamilyHandle, StateDescriptor<?, ?>> columnFamily = rocksDBKeyedStateBackend.
 						kvStateInformation.get(stateDescriptor.getName());
 
-				if(null == columnFamily) {
+				if (null == columnFamily) {
 					ColumnFamilyDescriptor columnFamilyDescriptor = new ColumnFamilyDescriptor(
 							stateDescriptor.getName().getBytes(), rocksDBKeyedStateBackend.columnOptions);
 
-					columnFamily = new Tuple2<>(rocksDBKeyedStateBackend.db.
-							createColumnFamily(columnFamilyDescriptor), stateDescriptor);
+					columnFamily = new Tuple2<ColumnFamilyHandle, StateDescriptor<?, ?>>(
+							rocksDBKeyedStateBackend.db.createColumnFamily(columnFamilyDescriptor), stateDescriptor);
+
 					rocksDBKeyedStateBackend.kvStateInformation.put(stateDescriptor.getName(), columnFamily);
 				}
 
@@ -727,9 +730,9 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 	 * <p>This also checks whether the {@link StateDescriptor} for a state matches the one
 	 * that we checkpointed, i.e. is already in the map of column families.
 	 */
-	protected ColumnFamilyHandle getColumnFamily(StateDescriptor descriptor) {
+	protected ColumnFamilyHandle getColumnFamily(StateDescriptor<?, ?> descriptor) {
 
-		Tuple2<ColumnFamilyHandle, StateDescriptor> stateInfo = kvStateInformation.get(descriptor.getName());
+		Tuple2<ColumnFamilyHandle, StateDescriptor<?, ?>> stateInfo = kvStateInformation.get(descriptor.getName());
 
 		if (stateInfo != null) {
 			if (!stateInfo.f1.equals(descriptor)) {
@@ -744,7 +747,9 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 
 		try {
 			ColumnFamilyHandle columnFamily = db.createColumnFamily(columnDescriptor);
-			kvStateInformation.put(descriptor.getName(), new Tuple2<>(columnFamily, descriptor));
+			Tuple2<ColumnFamilyHandle, StateDescriptor<?, ?>> tuple =
+					new Tuple2<ColumnFamilyHandle, StateDescriptor<?, ?>>(columnFamily, descriptor);
+			kvStateInformation.put(descriptor.getName(), tuple);
 			return columnFamily;
 		} catch (RocksDBException e) {
 			throw new RuntimeException("Error creating ColumnFamilyHandle.", e);
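
The dbDisposeLock comment above describes a protocol: acquire the lock for both
asynchronous snapshots and disposal, null out the shared reference under the lock, and
have every reader re-check for null after acquiring it. A minimal sketch of that idiom
in isolation (Database is a hypothetical stand-in for RocksDB, not Flink code):

// Guards a shared native resource against a concurrent snapshot thread.
class GuardedBackend {

	private final Object disposeLock = new Object();
	private Database db = new Database();

	// Called by the checkpointing thread.
	void snapshot() {
		synchronized (disposeLock) {
			if (db == null) {
				return; // already disposed; nothing to snapshot
			}
			db.iterate(); // safe: dispose() cannot run while we hold the lock
		}
	}

	// Called once by the owning thread on shutdown.
	void dispose() {
		Database toClean;
		synchronized (disposeLock) {
			toClean = db;
			db = null; // later snapshot attempts see null and bail out
		}
		if (toClean != null) {
			toClean.close(); // heavy cleanup happens outside the lock
		}
	}

	static class Database {
		void iterate() { /* read state */ }
		void close() { /* release native resources */ }
	}
}

Doing the heavy cleanup outside the lock (as dispose() above does, mirroring the diff)
keeps snapshot attempts from blocking on disposal; they briefly take the lock, observe
null, and bail out.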

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index b6ce224..a0c980b 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -23,11 +23,11 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.rocksdb.ColumnFamilyOptions;
 import org.rocksdb.DBOptions;
@@ -224,7 +224,7 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 	}
 
 	@Override
-	public <K> KeyedStateBackend<K> createKeyedStateBackend(
+	public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
 			Environment env,
 			JobID jobID,
 			String operatorIdentifier,
@@ -251,7 +251,9 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 	}
 
 	@Override
-	public <K> KeyedStateBackend<K> restoreKeyedStateBackend(Environment env, JobID jobID,
+	public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
+			Environment env,
+			JobID jobID,
 			String operatorIdentifier,
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
index c0c9ca1..bccbabc 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBAsyncSnapshotTest.java
@@ -29,10 +29,8 @@ import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
@@ -66,7 +64,6 @@ import java.io.IOException;
 import java.lang.reflect.Field;
 import java.net.URI;
 import java.util.Arrays;
-import java.util.List;
 import java.util.UUID;
 import java.util.concurrent.CancellationException;
 import java.util.concurrent.ExecutorService;
@@ -138,12 +135,11 @@ public class RocksDBAsyncSnapshotTest {
 			@Override
 			public void acknowledgeCheckpoint(
 					long checkpointId,
-					ChainedStateHandle<StreamStateHandle> chainedStateHandle, 
-					List<KeyGroupsStateHandle> keyGroupStateHandles,
+					CheckpointStateHandles checkpointStateHandles,
 					long synchronousDurationMillis, long asynchronousDurationMillis,
 					long bytesBufferedInAlignment, long alignmentDurationNanos) {
 
-				super.acknowledgeCheckpoint(checkpointId, chainedStateHandle, keyGroupStateHandles,
+				super.acknowledgeCheckpoint(checkpointId, checkpointStateHandles,
 						synchronousDurationMillis, asynchronousDurationMillis,
 						bytesBufferedInAlignment, alignmentDurationNanos);
 
@@ -156,7 +152,7 @@ public class RocksDBAsyncSnapshotTest {
 				}
 
 				// should be only one k/v state
-				assertEquals(1, keyGroupStateHandles.size());
+				assertEquals(1, checkpointStateHandles.getKeyGroupsStateHandle().size());
 
 				// we now know that the checkpoint went through
 				ensureCheckpointLatch.trigger();

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java
index 3b851be..07fc27c 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendConfigTest.java
@@ -26,7 +26,6 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.query.KvStateRegistry;
-
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.util.OperatingSystem;
@@ -34,7 +33,6 @@ import org.junit.Assume;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
-
 import org.junit.rules.TemporaryFolder;
 import org.rocksdb.ColumnFamilyOptions;
 import org.rocksdb.CompactionStyle;
@@ -45,8 +43,18 @@ import java.io.File;
 import static org.hamcrest.CoreMatchers.anyOf;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.startsWith;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 
 /**
@@ -88,13 +96,15 @@ public class RocksDBStateBackendConfigTest {
 		assertArrayEquals(new String[] { testDir1.getAbsolutePath(), testDir2.getAbsolutePath() }, rocksDbBackend.getDbStoragePaths());
 
 		Environment env = getMockEnvironment(new File[] {});
-		RocksDBKeyedStateBackend<Integer> keyedBackend = (RocksDBKeyedStateBackend<Integer>) rocksDbBackend.createKeyedStateBackend(env,
-				env.getJobID(),
-				"test_op",
-				IntSerializer.INSTANCE,
-				1,
-				new KeyGroupRange(0, 0),
-				env.getTaskKvStateRegistry());
+		RocksDBKeyedStateBackend<Integer> keyedBackend = (RocksDBKeyedStateBackend<Integer>) rocksDbBackend.
+				createKeyedStateBackend(
+						env,
+						env.getJobID(),
+						"test_op",
+						IntSerializer.INSTANCE,
+						1,
+						new KeyGroupRange(0, 0),
+						env.getTaskKvStateRegistry());
 
 
 		File instanceBasePath = keyedBackend.getInstanceBasePath();
@@ -142,13 +152,15 @@ public class RocksDBStateBackendConfigTest {
 		assertNull(rocksDbBackend.getDbStoragePaths());
 
 		Environment env = getMockEnvironment(tempDirs);
-		RocksDBKeyedStateBackend<Integer> keyedBackend = (RocksDBKeyedStateBackend<Integer>) rocksDbBackend.createKeyedStateBackend(env,
-				env.getJobID(),
-				"test_op",
-				IntSerializer.INSTANCE,
-				1,
-				new KeyGroupRange(0, 0),
-				env.getTaskKvStateRegistry());
+		RocksDBKeyedStateBackend<Integer> keyedBackend = (RocksDBKeyedStateBackend<Integer>) rocksDbBackend.
+				createKeyedStateBackend(
+						env,
+						env.getJobID(),
+						"test_op",
+						IntSerializer.INSTANCE,
+						1,
+						new KeyGroupRange(0, 0),
+						env.getTaskKvStateRegistry());
 
 
 		File instanceBasePath = keyedBackend.getInstanceBasePath();

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
index a9e8da9..ce513cb 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
@@ -18,12 +18,8 @@
 
 package org.apache.flink.api.common.functions;
 
-import java.io.Serializable;
-import java.util.List;
-import java.util.Map;
-
-import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.annotation.Public;
+import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.accumulators.Accumulator;
 import org.apache.flink.api.common.accumulators.DoubleCounter;
@@ -33,14 +29,16 @@ import org.apache.flink.api.common.accumulators.LongCounter;
 import org.apache.flink.api.common.cache.DistributedCache;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.OperatorState;
 import org.apache.flink.api.common.state.ReducingState;
 import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.metrics.MetricGroup;
 
+import java.io.Serializable;
+import java.util.List;
+import java.util.Map;
+
 /**
  * A RuntimeContext contains information about the context in which functions are executed. Each parallel instance
  * of the function will have a context through which it can access static contextual information (such as 
@@ -347,117 +345,4 @@ public interface RuntimeContext {
 	 */
 	@PublicEvolving
 	<T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties);
-	
-	/**
-	 * Gets the key/value state, which is only accessible if the function is executed on
-	 * a KeyedStream. Upon calling {@link ValueState#value()}, the key/value state will
-	 * return the value bound to the key of the element currently processed by the function.
-	 * Each operator may maintain multiple key/value states, addressed with different names.
-	 *
-	 * <p>Because the scope of each value is the key of the currently processed element,
-	 * and the elements are distributed by the Flink runtime, the system can transparently
-	 * scale out and redistribute the state and KeyedStream.
-	 *
-	 * <p>The following code example shows how to implement a continuous counter that counts
-	 * how many times elements of a certain key occur, and emits an updated count for that
-	 * element on each occurrence. 
-	 *
-	 * <pre>{@code
-	 * DataStream<MyType> stream = ...;
-	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");     
-	 *
-	 * keyedStream.map(new RichMapFunction<MyType, Tuple2<MyType, Long>>() {
-	 *
-	 *     private State<Long> state;
-	 *
-	 *     public void open(Configuration cfg) {
-	 *         state = getRuntimeContext().getKeyValueState(Long.class, 0L);
-	 *     }
-	 *
-	 *     public Tuple2<MyType, Long> map(MyType value) {
-	 *         long count = state.value();
-	 *         state.update(value + 1);
-	 *         return new Tuple2<>(value, count);
-	 *     }
-	 * });
-	 *
-	 * }</pre>
-	 *
-	 * <p>This method attempts to deduce the type information from the given type class. If the
-	 * full type cannot be determined from the class (for example because of generic parameters),
-	 * the TypeInformation object must be manually passed via 
-	 * {@link #getKeyValueState(String, TypeInformation, Object)}. 
-	 * 
-	 *
-	 * @param name The name of the key/value state.
-	 * @param stateType The class of the type that is stored in the state. Used to generate
-	 *                  serializers for managed memory and checkpointing.
-	 * @param defaultState The default state value, returned when the state is accessed and
-	 *                     no value has yet been set for the key. May be null.
-	 *
-	 * @param <S> The type of the state.
-	 *
-	 * @return The key/value state access.
-	 *
-	 * @throws UnsupportedOperationException Thrown, if no key/value state is available for the
-	 *                                       function (function is not part os a KeyedStream).
-	 * 
-	 * @deprecated Use the more expressive {@link #getState(ValueStateDescriptor)} instead.
-	 */
-	@Deprecated
-	@PublicEvolving
-	<S> OperatorState<S> getKeyValueState(String name, Class<S> stateType, S defaultState);
-
-	/**
-	 * Gets the key/value state, which is only accessible if the function is executed on
-	 * a KeyedStream. Upon calling {@link ValueState#value()}, the key/value state will
-	 * return the value bound to the key of the element currently processed by the function.
-	 * Each operator may maintain multiple key/value states, addressed with different names.
-	 * 
-	 * <p>Because the scope of each value is the key of the currently processed element,
-	 * and the elements are distributed by the Flink runtime, the system can transparently
-	 * scale out and redistribute the state and KeyedStream.
-	 * 
-	 * <p>The following code example shows how to implement a continuous counter that counts
-	 * how many times elements of a certain key occur, and emits an updated count for that
-	 * element on each occurrence. 
-	 * 
-	 * <pre>{@code
-	 * DataStream<MyType> stream = ...;
-	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");     
-	 * 
-	 * keyedStream.map(new RichMapFunction<MyType, Tuple2<MyType, Long>>() {
-	 * 
-	 *     private State<Long> state;
-	 *     
-	 *     public void open(Configuration cfg) {
-	 *         state = getRuntimeContext().getKeyValueState(Long.class, 0L);
-	 *     }
-	 *     
-	 *     public Tuple2<MyType, Long> map(MyType value) {
-	 *         long count = state.value();
-	 *         state.update(value + 1);
-	 *         return new Tuple2<>(value, count);
-	 *     }
-	 * });
-	 *     
-	 * }</pre>
-	 * 
-	 * @param name The name of the key/value state.
-	 * @param stateType The type information for the type that is stored in the state.
-	 *                  Used to create serializers for managed memory and checkpoints.
-	 * @param defaultState The default state value, returned when the state is accessed and
-	 *                     no value has yet been set for the key. May be null.
-	 * @param <S> The type of the state.
-	 *
-	 * @return The key/value state access.
-	 * 
-	 * @throws UnsupportedOperationException Thrown, if no key/value state is available for the
-	 *                                       function (function is not part os a KeyedStream).
-	 * 
-	 * @deprecated Use the more expressive {@link #getState(ValueStateDescriptor)} instead.
-	 */
-	@Deprecated
-	@PublicEvolving
-	<S> OperatorState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState);
 }
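
The two getKeyValueState variants removed above are subsumed by the descriptor-based API
that the deleted javadoc itself points to. As a hedged sketch, here is the counter example
from that javadoc rewritten against getState(ValueStateDescriptor); the CountPerKey class
is illustrative, not part of this commit:

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;

// Counts occurrences per key on a KeyedStream and emits the running count.
public class CountPerKey extends RichMapFunction<Long, Tuple2<Long, Long>> {

	private transient ValueState<Long> state;

	@Override
	public void open(Configuration cfg) {
		state = getRuntimeContext().getState(
				new ValueStateDescriptor<>("count", Long.class, 0L));
	}

	@Override
	public Tuple2<Long, Long> map(Long value) throws Exception {
		long count = state.value() + 1;
		state.update(count);
		return new Tuple2<>(value, count);
	}
}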

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
index 6645964..4f559bf 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
@@ -18,11 +18,6 @@
 
 package org.apache.flink.api.common.functions.util;
 
-import java.io.Serializable;
-import java.util.Collections;
-import java.util.Map;
-import java.util.concurrent.Future;
-
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.TaskInfo;
@@ -36,15 +31,18 @@ import org.apache.flink.api.common.cache.DistributedCache;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.OperatorState;
 import org.apache.flink.api.common.state.ReducingState;
 import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.metrics.MetricGroup;
 
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.Future;
+
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -207,20 +205,4 @@ public abstract class AbstractRuntimeUDFContext implements RuntimeContext {
 		throw new UnsupportedOperationException(
 				"This state is only accessible by functions executed on a KeyedStream");
 	}
-
-	@Override
-	@Deprecated
-	@PublicEvolving
-	public <S> OperatorState<S> getKeyValueState(String name, Class<S> stateType, S defaultState) {
-		throw new UnsupportedOperationException(
-				"This state is only accessible by functions executed on a KeyedStream");
-	}
-
-	@Override
-	@Deprecated
-	@PublicEvolving
-	public <S> OperatorState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) {
-		throw new UnsupportedOperationException(
-				"This state is only accessible by functions executed on a KeyedStream");
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorState.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorState.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorState.java
deleted file mode 100644
index ac4ed07..0000000
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorState.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.api.common.state;
-
-import org.apache.flink.annotation.PublicEvolving;
-
-import java.io.IOException;
-
-/**
- * This state interface abstracts persistent key/value state in streaming programs.
- * The state is accessed and modified by user functions, and checkpointed consistently
- * by the system as part of the distributed snapshots.
- *
- * <p>The state is only accessible by functions applied on a KeyedDataStream. The key is
- * automatically supplied by the system, so the function always sees the value mapped to the
- * key of the current element. That way, the system can handle stream and state partitioning
- * consistently together.
- *
- * @param <T> Type of the value in the operator state
- * 
- * @deprecated OperatorState has been replaced by {@link ValueState}.
- */
-@Deprecated
-@PublicEvolving
-public interface OperatorState<T> extends State {
-
-	/**
-	 * Returns the current value for the state. When the state is not
-	 * partitioned the returned value is the same for all inputs in a given
-	 * operator instance. If state partitioning is applied, the value returned
-	 * depends on the current operator input, as the operator maintains an
-	 * independent state for each partition.
-	 *
-	 * @return The operator state value corresponding to the current input.
-	 *
-	 * @throws IOException Thrown if the system cannot access the state.
-	 */
-	T value() throws IOException;
-
-	/**
-	 * Updates the operator state accessible by {@link #value()} to the given
-	 * value. The next time {@link #value()} is called (for the same state
-	 * partition) the returned state will represent the updated value. When a
-	 * partitioned state is updated with null, the state for the current key 
-	 * will be removed and the default value is returned on the next access.
-	 *
-	 * @param value
-	 *            The new value for the state.
-	 *
-	 * @throws IOException Thrown if the system cannot access the state.
-	 */
-	void update(T value) throws IOException;
-
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-core/src/main/java/org/apache/flink/api/common/state/ValueState.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/ValueState.java b/flink-core/src/main/java/org/apache/flink/api/common/state/ValueState.java
index 607cb32..de3250a 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/ValueState.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/ValueState.java
@@ -37,7 +37,7 @@ import java.io.IOException;
  * @param <T> Type of the value in the state.
  */
 @PublicEvolving
-public interface ValueState<T> extends State, OperatorState<T> {
+public interface ValueState<T> extends State {
 
 	/**
 	 * Returns the current value for the state. When the state is not

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/JavaSerializer.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/JavaSerializer.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/JavaSerializer.java
new file mode 100644
index 0000000..4ae00d1
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/JavaSerializer.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.java.typeutils.runtime;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.util.InstantiationUtil;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+
+public class JavaSerializer<T extends Serializable> extends TypeSerializer<T> {
+
+	private static final long serialVersionUID = 1L;
+
+	@Override
+	public boolean isImmutableType() {
+		return false;
+	}
+
+	@Override
+	public TypeSerializer<T> duplicate() {
+		return this;
+	}
+
+	@Override
+	public T createInstance() {
+		return null;
+	}
+
+	@Override
+	public T copy(T from) {
+
+		try {
+			return InstantiationUtil.clone(from);
+		} catch (IOException | ClassNotFoundException e) {
+			throw new RuntimeException("Could not copy instance of " + from + '.', e);
+		}
+	}
+
+	@Override
+	public T copy(T from, T reuse) {
+		return copy(from);
+	}
+
+	@Override
+	public int getLength() {
+		return -1; // -1 signals a variable-length data type
+	}
+
+	@Override
+	public void serialize(T record, DataOutputView target) throws IOException {
+		ObjectOutputStream oos = new ObjectOutputStream(new DataOutputViewStream(target));
+		oos.writeObject(record);
+		oos.flush();
+	}
+
+	@Override
+	public T deserialize(DataInputView source) throws IOException {
+		ObjectInputStream ois = new ObjectInputStream(new DataInputViewStream(source));
+
+		try {
+			@SuppressWarnings("unchecked")
+			T object = (T) ois.readObject();
+			return object;
+		} catch (ClassNotFoundException e) {
+			throw new RuntimeException("Could not deserialize object.", e);
+		}
+	}
+
+	@Override
+	public T deserialize(T reuse, DataInputView source) throws IOException {
+		return deserialize(source);
+	}
+
+	@Override
+	public void copy(DataInputView source, DataOutputView target) throws IOException {
+		int size = source.readInt();
+		target.writeInt(size);
+		target.write(source, size);
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		return obj instanceof JavaSerializer && ((JavaSerializer<T>) obj).canEqual(this);
+	}
+
+	@Override
+	public boolean canEqual(Object obj) {
+		return obj instanceof JavaSerializer;
+	}
+
+	@Override
+	public int hashCode() {
+		return getClass().hashCode();
+	}
+}
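
As a sanity sketch of how the new serializer composes with the existing stream wrappers
from org.apache.flink.core.memory (both already used elsewhere in this commit), here is a
round trip through a byte array; the scaffold class is illustrative only:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;

import org.apache.flink.api.java.typeutils.runtime.JavaSerializer;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;

public class JavaSerializerRoundTrip {

	public static void main(String[] args) throws Exception {
		JavaSerializer<String> serializer = new JavaSerializer<>();

		// Serialize into an in-memory buffer.
		ByteArrayOutputStream bytes = new ByteArrayOutputStream();
		serializer.serialize("hello", new DataOutputViewStreamWrapper(bytes));

		// Deserialize from the same bytes.
		String copy = serializer.deserialize(
				new DataInputViewStreamWrapper(new ByteArrayInputStream(bytes.toByteArray())));
		System.out.println(copy); // prints: hello
	}
}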

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java
index df40998..080485e 100644
--- a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java
+++ b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java
@@ -19,7 +19,6 @@
 package org.apache.flink.hdfstests;
 
 import org.apache.commons.io.FileUtils;
-
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.core.fs.FileStatus;
@@ -31,10 +30,8 @@ import org.apache.flink.runtime.state.StateBackendTestBase;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
-
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hdfs.MiniDFSCluster;
-
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -243,16 +240,21 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 	}
 
 	private static void validateBytesInStream(InputStream is, byte[] data) throws IOException {
-		byte[] holder = new byte[data.length];
 
-		int pos = 0;
-		int read;
-		while (pos < holder.length && (read = is.read(holder, pos, holder.length - pos)) != -1) {
-			pos += read;
-		}
+		try {
+			byte[] holder = new byte[data.length];
+
+			int pos = 0;
+			int read;
+			while (pos < holder.length && (read = is.read(holder, pos, holder.length - pos)) != -1) {
+				pos += read;
+			}
 
-		assertEquals("not enough data", holder.length, pos);
-		assertEquals("too much data", -1, is.read());
-		assertArrayEquals("wrong data", data, holder);
+			assertEquals("not enough data", holder.length, pos);
+			assertEquals("too much data", -1, is.read());
+			assertArrayEquals("wrong data", data, holder);
+		} finally {
+			is.close();
+		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractCEPBasePatternOperator.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractCEPBasePatternOperator.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractCEPBasePatternOperator.java
index aad408c..2f21346 100644
--- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractCEPBasePatternOperator.java
+++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractCEPBasePatternOperator.java
@@ -20,6 +20,7 @@ package org.apache.flink.cep.operator;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.cep.nfa.NFA;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -36,7 +37,7 @@ import java.util.PriorityQueue;
  */
 public abstract class AbstractCEPBasePatternOperator<IN, OUT>
 	extends AbstractStreamOperator<OUT>
-	implements OneInputStreamOperator<IN, OUT> {
+	implements OneInputStreamOperator<IN, OUT>, StreamCheckpointedOperator {
 
 	private static final long serialVersionUID = -4166778210774160757L;
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractCEPPatternOperator.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractCEPPatternOperator.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractCEPPatternOperator.java
index 64ffa2a..10bb6ff 100644
--- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractCEPPatternOperator.java
+++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractCEPPatternOperator.java
@@ -104,7 +104,6 @@ abstract public class AbstractCEPPatternOperator<IN, OUT> extends AbstractCEPBas
 
 	@Override
 	public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
-		super.snapshotState(out, checkpointId, timestamp);
 		final ObjectOutputStream oos = new ObjectOutputStream(out);
 
 		oos.writeObject(nfa);
@@ -119,7 +118,6 @@ abstract public class AbstractCEPPatternOperator<IN, OUT> extends AbstractCEPBas
 	@Override
 	@SuppressWarnings("unchecked")
 	public void restoreState(FSDataInputStream state) throws Exception {
-		super.restoreState(state);
 		final ObjectInputStream ois = new ObjectInputStream(state);
 
 		nfa = (NFA<IN>)ois.readObject();

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java
index 09773a2..07e2662 100644
--- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java
+++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java
@@ -187,7 +187,6 @@ abstract public class AbstractKeyedCEPPatternOperator<IN, KEY, OUT> extends Abst
 
 	@Override
 	public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
-		super.snapshotState(out, checkpointId, timestamp);
 
 		DataOutputView ov = new DataOutputViewStreamWrapper(out);
 		ov.writeInt(keys.size());
@@ -199,7 +198,6 @@ abstract public class AbstractKeyedCEPPatternOperator<IN, KEY, OUT> extends Abst
 
 	@Override
 	public void restoreState(FSDataInputStream state) throws Exception {
-		super.restoreState(state);
 
 		DataInputView inputView = new DataInputViewStreamWrapper(state);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 6a43ddf..4428427 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -34,9 +34,11 @@ import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
@@ -45,7 +47,9 @@ import scala.concurrent.Future;
 
 import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
@@ -398,9 +402,9 @@ public class CheckpointCoordinator {
 				return new CheckpointTriggerResult(CheckpointDeclineReason.EXCEPTION);
 			}
 
-		final PendingCheckpoint checkpoint = props.isSavepoint() ?
-			new PendingSavepoint(job, checkpointID, timestamp, ackTasks, savepointStore) :
-			new PendingCheckpoint(job, checkpointID, timestamp, ackTasks);
+			final PendingCheckpoint checkpoint = props.isSavepoint() ?
+					new PendingSavepoint(job, checkpointID, timestamp, ackTasks, savepointStore) :
+					new PendingCheckpoint(job, checkpointID, timestamp, ackTasks);
 
 			// schedule the timer that will clean up the expired checkpoints
 			TimerTask canceller = new TimerTask() {
@@ -627,9 +631,8 @@ public class CheckpointCoordinator {
 				isPendingCheckpoint = true;
 
 				if (checkpoint.acknowledgeTask(
-					message.getTaskExecutionId(),
-					message.getStateHandle(),
-					message.getKeyGroupsStateHandle())) {
+						message.getTaskExecutionId(),
+						message.getCheckpointStateHandles())) {
 					if (checkpoint.isFullyAcknowledged()) {
 						completed = checkpoint.finalizeCheckpoint();
 
@@ -640,7 +643,7 @@ public class CheckpointCoordinator {
 
 						if (LOG.isDebugEnabled()) {
 							StringBuilder builder = new StringBuilder();
-							for (Map.Entry<JobVertexID, TaskState> entry: completed.getTaskStates().entrySet()) {
+							for (Map.Entry<JobVertexID, TaskState> entry : completed.getTaskStates().entrySet()) {
 								builder.append("JobVertexID: ").append(entry.getKey()).append(" {").append(entry.getValue()).append("}");
 							}
 
@@ -654,8 +657,7 @@ public class CheckpointCoordinator {
 
 						triggerQueuedRequests();
 					}
-				}
-				else {
+				} else {
 					// checkpoint did not accept message
 					LOG.error("Received duplicate or invalid acknowledge message for checkpoint " + checkpointId
 							+ " , task " + message.getTaskExecutionId());
@@ -790,22 +792,80 @@ public class CheckpointCoordinator {
 					}
 
 
+					int oldParallelism = taskState.getParallelism();
+					int newParallelism = executionJobVertex.getParallelism();
+					boolean parallelismChanged = oldParallelism != newParallelism;
 					boolean hasNonPartitionedState = taskState.hasNonPartitionedState();
 
-					if (hasNonPartitionedState && taskState.getParallelism() != executionJobVertex.getParallelism()) {
+					if (hasNonPartitionedState && parallelismChanged) {
 						throw new IllegalStateException("Cannot restore the latest checkpoint because " +
 							"the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " +
 							"state and its parallelism changed. The operator" + executionJobVertex.getJobVertexId() +
-							" has parallelism " + executionJobVertex.getParallelism() + " whereas the corresponding" +
-							"state object has a parallelism of " + taskState.getParallelism());
+							" has parallelism " + newParallelism + " whereas the corresponding" +
+							"state object has a parallelism of " + oldParallelism);
 					}
 
 					List<KeyGroupRange> keyGroupPartitions = createKeyGroupPartitions(
-						executionJobVertex.getMaxParallelism(),
-						executionJobVertex.getParallelism());
+							executionJobVertex.getMaxParallelism(),
+							newParallelism);
+					
+					// operator chain index -> list of the stored partitionable states from all parallel instances
+					@SuppressWarnings("unchecked")
+					List<OperatorStateHandle>[] chainParallelStates =
+							new List[taskState.getChainLength()];
+
+					for (int i = 0; i < oldParallelism; ++i) {
+
+						ChainedStateHandle<OperatorStateHandle> partitionableState =
+								taskState.getPartitionableState(i);
+
+						if (partitionableState != null) {
+							for (int j = 0; j < partitionableState.getLength(); ++j) {
+								OperatorStateHandle opParalleState = partitionableState.get(j);
+								if (opParalleState != null) {
+									List<OperatorStateHandle> opParallelStates =
+											chainParallelStates[j];
+									if (opParallelStates == null) {
+										opParallelStates = new ArrayList<>();
+										chainParallelStates[j] = opParallelStates;
+									}
+									opParallelStates.add(opParalleState);
+								}
+							}
+						}
+					}
+
+					// operator chain index -> lists with collected states (one collection for each parallel subtask)
+					@SuppressWarnings("unchecked")
+					List<Collection<OperatorStateHandle>>[] redistributedParallelStates =
+							new List[taskState.getChainLength()];
+
+					//TODO here we can employ different redistribution strategies for state, e.g. union state. For now we only offer round robin as the default.
+					OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
+
+					for (int i = 0; i < chainParallelStates.length; ++i) {
+						List<OperatorStateHandle> chainOpParallelStates = chainParallelStates[i];
+						if (chainOpParallelStates != null) {
+							//We only redistribute if the parallelism of the operator changed from previous executions
+							if (parallelismChanged) {
+								redistributedParallelStates[i] = repartitioner.repartitionState(
+										chainOpParallelStates,
+										newParallelism);
+							} else {
+								List<Collection<OperatorStateHandle>> repacking = new ArrayList<>(newParallelism);
+								for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) {
+									repacking.add(Collections.singletonList(operatorStateHandle));
+								}
+								redistributedParallelStates[i] = repacking;
+							}
+						}
+					}
 
 					int counter = 0;
-					for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
+
+					for (int i = 0; i < newParallelism; ++i) {
+
+						// non-partitioned state
 						ChainedStateHandle<StreamStateHandle> state = null;
 
 						if (hasNonPartitionedState) {
@@ -813,25 +873,46 @@ public class CheckpointCoordinator {
 
 							if (subtaskState != null) {
 								// count the number of executions for which we set a state
-								counter++;
+								++counter;
 								state = subtaskState.getChainedStateHandle();
 							}
 						}
 
+						// partitionable state
+						@SuppressWarnings("unchecked")
+						Collection<OperatorStateHandle>[] ia = new Collection[taskState.getChainLength()];
+						List<Collection<OperatorStateHandle>> subTaskPartitionableState = Arrays.asList(ia);
+
+						for (int j = 0; j < redistributedParallelStates.length; ++j) {
+							List<Collection<OperatorStateHandle>> redistributedParallelState =
+									redistributedParallelStates[j];
+
+							if (redistributedParallelState != null) {
+								subTaskPartitionableState.set(j, redistributedParallelState.get(i));
+							}
+						}
+
+						// key-partitioned state
 						KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(i);
 
-						List<KeyGroupsStateHandle> subtaskKeyGroupStates = getKeyGroupsStateHandles(
-								taskState.getKeyGroupStates(),
-								subtaskKeyGroupIds);
+						// Again, we only repartition if the parallelism changed
+						List<KeyGroupsStateHandle> subtaskKeyGroupStates = parallelismChanged ?
+								getKeyGroupsStateHandles(taskState.getKeyGroupStates(), subtaskKeyGroupIds)
+								: Collections.singletonList(taskState.getKeyGroupState(i));
 
 						Execution currentExecutionAttempt = executionJobVertex
 							.getTaskVertices()[i]
 							.getCurrentExecutionAttempt();
 
-						currentExecutionAttempt.setInitialState(state, subtaskKeyGroupStates);
+						CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(
+								state,
+								null/*subTaskPartitionableState*/, //TODO choose the right structure and put the redistributed states here
+								subtaskKeyGroupStates);
+
+						currentExecutionAttempt.setInitialState(checkpointStateHandles, subTaskPartitionableState);
 					}
 
-					if (allOrNothingState && counter > 0 && counter < executionJobVertex.getParallelism()) {
+					if (allOrNothingState && counter > 0 && counter < newParallelism) {
 						throw new IllegalStateException("The checkpoint contained state only for " +
 							"a subset of tasks for vertex " + executionJobVertex);
 					}
@@ -859,7 +940,7 @@ public class CheckpointCoordinator {
 
 		for (KeyGroupsStateHandle storedKeyGroup : allKeyGroupsHandles) {
 			KeyGroupsStateHandle intersection = storedKeyGroup.getKeyGroupIntersection(subtaskKeyGroupIds);
-			if(intersection.getNumberOfKeyGroups() > 0) {
+			if (intersection.getNumberOfKeyGroups() > 0) {
 				subtaskKeyGroupStates.add(intersection);
 			}
 		}
@@ -881,7 +962,7 @@ public class CheckpointCoordinator {
 	public static List<KeyGroupRange> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
 		Preconditions.checkArgument(numberKeyGroups >= parallelism);
 		List<KeyGroupRange> result = new ArrayList<>(parallelism);
-		int start = 0;
+
 		for (int i = 0; i < parallelism; ++i) {
 			result.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(numberKeyGroups, parallelism, i));
 		}
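
For orientation: conceptually, createKeyGroupPartitions splits the range of key-groups [0, numberKeyGroups) into `parallelism` contiguous ranges of near-equal size. The exact boundaries come from KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex; the snippet below is only an illustrative, self-contained sketch of that idea, not the code Flink actually runs.

// Illustrative sketch only: splits [0, numberKeyGroups) into `parallelism`
// contiguous, near-equal ranges. The real boundaries are computed by
// KeyGroupRangeAssignment and may differ at the edges.
public final class KeyGroupPartitionSketch {

	public static int[][] partition(int numberKeyGroups, int parallelism) {
		int[][] ranges = new int[parallelism][2];
		for (int i = 0; i < parallelism; i++) {
			int start = i * numberKeyGroups / parallelism;          // inclusive
			int end = (i + 1) * numberKeyGroups / parallelism - 1;  // inclusive
			ranges[i] = new int[] { start, end };
		}
		return ranges;
	}

	public static void main(String[] args) {
		// 10 key-groups over parallelism 3 -> [0, 2], [3, 5], [6, 9] in this sketch
		for (int[] r : partition(10, 3)) {
			System.out.println("[" + r[0] + ", " + r[1] + "]");
		}
	}
}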

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
index 7cb3916..0d279f1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
@@ -153,9 +153,4 @@ public class CompletedCheckpoint implements StateObject {
 	public String toString() {
 		return String.format("Checkpoint %d @ %d for %s", checkpointID, timestamp, job);
 	}
-
-	@Override
-	public void close() throws IOException {
-		StateUtil.bestEffortCloseAllStateObjects(taskStates.values());
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java
new file mode 100644
index 0000000..98810f1
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorStateRepartitioner.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.state.OperatorStateHandle;
+
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Interface that allows the implementation of different strategies for repartitioning operator state as the parallelism changes.
+ */
+public interface OperatorStateRepartitioner {
+
+	/**
+	 * @param previousParallelSubtaskStates List of state handles to the parallel subtask states of an operator, as they
+	 *                                      have been checkpointed.
+	 * @param parallelism                   The parallelism that we consider for the state redistribution. Determines the size of the
+	 *                                      returned list.
+	 * @return List with one entry per parallel subtask. Each subtask now receives one collection of states that
+	 * together build up the new total state for this subtask.
+	 */
+	List<Collection<OperatorStateHandle>> repartitionState(
+			List<OperatorStateHandle> previousParallelSubtaskStates,
+			int parallelism);
+}
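
To make the contract concrete, here is a hedged sketch of an alternative strategy (the class name UnionOperatorStateRepartitioner is hypothetical and not part of this commit): it hands every new parallel subtask the complete set of previous state handles, which is the "union state" idea mentioned in the TODO in CheckpointCoordinator.

import org.apache.flink.runtime.state.OperatorStateHandle;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

// Hypothetical example, not part of this commit: a "union" strategy that
// gives each new parallel subtask the full set of previously checkpointed
// state partitions. It only illustrates the repartitionState(...) contract.
public class UnionOperatorStateRepartitioner implements OperatorStateRepartitioner {

	@Override
	public List<Collection<OperatorStateHandle>> repartitionState(
			List<OperatorStateHandle> previousParallelSubtaskStates,
			int parallelism) {

		List<Collection<OperatorStateHandle>> result = new ArrayList<>(parallelism);
		for (int i = 0; i < parallelism; ++i) {
			// every subtask sees all previous partitions; deduplication and
			// offset handling are deliberately ignored in this sketch
			result.add(new ArrayList<>(previousParallelSubtaskStates));
		}
		return result;
	}
}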

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index d499a5a..2ca9d69 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -23,7 +23,9 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
@@ -167,49 +169,80 @@ public class PendingCheckpoint {
 	}
 	
 	public boolean acknowledgeTask(
-		ExecutionAttemptID attemptID,
-		ChainedStateHandle<StreamStateHandle> state,
-		List<KeyGroupsStateHandle> keyGroupsState) {
+			ExecutionAttemptID attemptID,
+			CheckpointStateHandles checkpointStateHandles) {
 
 		synchronized (lock) {
 			if (discarded) {
 				return false;
 			}
-			
-			ExecutionVertex vertex = notYetAcknowledgedTasks.remove(attemptID);
-			if (vertex != null) {
-				if (state != null || keyGroupsState != null) {
 
-					JobVertexID jobVertexID = vertex.getJobvertexId();
+			ExecutionVertex vertex = notYetAcknowledgedTasks.remove(attemptID);
 
-					TaskState taskState;
+			if (vertex != null) {
 
-					if (taskStates.containsKey(jobVertexID)) {
-						taskState = taskStates.get(jobVertexID);
-					} else {
-						taskState = new TaskState(jobVertexID, vertex.getTotalNumberOfParallelSubtasks(), vertex.getMaxParallelism());
-						taskStates.put(jobVertexID, taskState);
+				if (checkpointStateHandles != null) {
+					List<KeyGroupsStateHandle> keyGroupsState = checkpointStateHandles.getKeyGroupsStateHandle();
+					ChainedStateHandle<StreamStateHandle> nonPartitionedState =
+							checkpointStateHandles.getNonPartitionedStateHandles();
+					ChainedStateHandle<OperatorStateHandle> partitioneableState =
+							checkpointStateHandles.getPartitioneableStateHandles();
+
+					if (nonPartitionedState != null || partitioneableState != null || keyGroupsState != null) {
+
+						JobVertexID jobVertexID = vertex.getJobvertexId();
+
+						int subtaskIndex = vertex.getParallelSubtaskIndex();
+
+						TaskState taskState;
+
+						if (taskStates.containsKey(jobVertexID)) {
+							taskState = taskStates.get(jobVertexID);
+						} else {
+							//TODO this should go away when we remove chained state, assigning state to operators directly instead
+							int chainLength;
+							if (nonPartitionedState != null) {
+								chainLength = nonPartitionedState.getLength();
+							} else if (partitioneableState != null) {
+								chainLength = partitioneableState.getLength();
+							} else {
+								chainLength = 1;
+							}
+
+							taskState = new TaskState(
+								jobVertexID,
+								vertex.getTotalNumberOfParallelSubtasks(),
+								vertex.getMaxParallelism(),
+								chainLength);
+
+							taskStates.put(jobVertexID, taskState);
+						}
+
+						long duration = System.currentTimeMillis() - checkpointTimestamp;
+
+						if (nonPartitionedState != null) {
+							taskState.putState(
+									subtaskIndex,
+									new SubtaskState(nonPartitionedState, duration));
+						}
+
+						if (partitioneableState != null && !partitioneableState.isEmpty()) {
+							taskState.putPartitionableState(subtaskIndex, partitioneableState);
+						}
+
+						// currently a checkpoint can only contain keyed state
+						// for the head operator
+						if (keyGroupsState != null && !keyGroupsState.isEmpty()) {
+							KeyGroupsStateHandle keyGroupsStateHandle = keyGroupsState.get(0);
+							taskState.putKeyedState(subtaskIndex, keyGroupsStateHandle);
+						}
 					}
+				}
 
-					long duration = System.currentTimeMillis() - checkpointTimestamp;
+				++numAcknowledgedTasks;
 
-					if (state != null) {
-						taskState.putState(
-							vertex.getParallelSubtaskIndex(),
-							new SubtaskState(state, duration));
-					}
-
-					// currently a checkpoint can only contain keyed state
-					// for the head operator
-					if (keyGroupsState != null && !keyGroupsState.isEmpty()) {
-						KeyGroupsStateHandle keyGroupsStateHandle = keyGroupsState.get(0);
-						taskState.putKeyedState(vertex.getParallelSubtaskIndex(), keyGroupsStateHandle);
-					}
-				}
-				numAcknowledgedTasks++;
 				return true;
-			}
-			else {
+			} else {
 				return false;
 			}
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
new file mode 100644
index 0000000..09a35f6
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Current default implementation of {@link OperatorStateRepartitioner} that redistributes state in a round-robin fashion.
+ */
+public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepartitioner {
+
+	public static final OperatorStateRepartitioner INSTANCE = new RoundRobinOperatorStateRepartitioner();
+	private static final boolean OPTIMIZE_MEMORY_USE = false;
+
+	@Override
+	public List<Collection<OperatorStateHandle>> repartitionState(
+			List<OperatorStateHandle> previousParallelSubtaskStates,
+			int parallelism) {
+
+		Preconditions.checkNotNull(previousParallelSubtaskStates);
+		Preconditions.checkArgument(parallelism > 0);
+
+		// Reorganize: group by (State Name -> StreamStateHandle + Offsets)
+		Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState =
+				groupByStateName(previousParallelSubtaskStates);
+
+		if (OPTIMIZE_MEMORY_USE) {
+			previousParallelSubtaskStates.clear(); // free for GC, at the cost that the old handles are no longer available
+		}
+
+		// Assemble result from all merge maps
+		List<Collection<OperatorStateHandle>> result = new ArrayList<>(parallelism);
+
+		// Do the actual repartitioning for all named states
+		List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList =
+				repartition(nameToState, parallelism);
+
+		for (int i = 0; i < mergeMapList.size(); ++i) {
+			result.add(i, new ArrayList<>(mergeMapList.get(i).values()));
+		}
+
+		return result;
+	}
+
+	/**
+	 * Group by the different named states.
+	 */
+	private Map<String, List<Tuple2<StreamStateHandle, long[]>>> groupByStateName(
+			List<OperatorStateHandle> previousParallelSubtaskStates) {
+
+		//Reorganize: group by (State Name -> StreamStateHandle + Offsets)
+		Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState = new HashMap<>();
+		for (OperatorStateHandle psh : previousParallelSubtaskStates) {
+
+			for (Map.Entry<String, long[]> e : psh.getStateNameToPartitionOffsets().entrySet()) {
+
+				List<Tuple2<StreamStateHandle, long[]>> stateLocations = nameToState.get(e.getKey());
+
+				if (stateLocations == null) {
+					stateLocations = new ArrayList<>();
+					nameToState.put(e.getKey(), stateLocations);
+				}
+
+				stateLocations.add(new Tuple2<>(psh.getDelegateStateHandle(), e.getValue()));
+			}
+		}
+		return nameToState;
+	}
+
+	/**
+	 * Repartition all named states.
+	 */
+	private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(
+			Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState, int parallelism) {
+
+		// One map per parallel subtask; used to merge state partitions that share the same StreamStateHandle
+		List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList = new ArrayList<>(parallelism);
+		// Initialize
+		for (int i = 0; i < parallelism; ++i) {
+			mergeMapList.add(new HashMap<StreamStateHandle, OperatorStateHandle>());
+		}
+
+		int startParallelOP = 0;
+		// Iterate all named states and repartition one named state at a time per iteration
+		for (Map.Entry<String, List<Tuple2<StreamStateHandle, long[]>>> e : nameToState.entrySet()) {
+
+			List<Tuple2<StreamStateHandle, long[]>> current = e.getValue();
+
+			// Determine actual number of partitions for this named state
+			int totalPartitions = 0;
+			for (Tuple2<StreamStateHandle, long[]> offsets : current) {
+				totalPartitions += offsets.f1.length;
+			}
+
+			// Repartition the state across the parallel operator instances
+			int lstIdx = 0;
+			int offsetIdx = 0;
+			int baseFraction = totalPartitions / parallelism;
+			int remainder = totalPartitions % parallelism;
+
+			int newStartParallelOp = startParallelOP;
+
+			for (int i = 0; i < parallelism; ++i) {
+
+				// Preparation: calculate the actual index considering wrap around
+				int parallelOpIdx = (i + startParallelOP) % parallelism;
+
+				// Now calculate the number of partitions we will assign to the parallel instance in this round ...
+				int numberOfPartitionsToAssign = baseFraction;
+
+				// ... and distribute odd partitions while we still have some, one at a time
+				if (remainder > 0) {
+					++numberOfPartitionsToAssign;
+					--remainder;
+				} else if (remainder == 0) {
+					// We are out of odd partitions now and begin our next redistribution round with the current
+					// parallel operator to ensure fair load balancing
+					newStartParallelOp = parallelOpIdx;
+					--remainder;
+				}
+
+				// Now start collecting the partitions for this parallel instance into this list
+				List<Tuple2<StreamStateHandle, long[]>> parallelOperatorState = new ArrayList<>();
+
+				while (numberOfPartitionsToAssign > 0) {
+					Tuple2<StreamStateHandle, long[]> handleWithOffsets = current.get(lstIdx);
+					long[] offsets = handleWithOffsets.f1;
+					int remaining = offsets.length - offsetIdx;
+					// Repartition offsets
+					long[] offs;
+					if (remaining > numberOfPartitionsToAssign) {
+						offs = Arrays.copyOfRange(offsets, offsetIdx, offsetIdx + numberOfPartitionsToAssign);
+						offsetIdx += numberOfPartitionsToAssign;
+					} else {
+						if (OPTIMIZE_MEMORY_USE) {
+							handleWithOffsets.f1 = null; // GC
+						}
+						offs = Arrays.copyOfRange(offsets, offsetIdx, offsets.length);
+						offsetIdx = 0;
+						++lstIdx;
+					}
+
+					parallelOperatorState.add(
+							new Tuple2<>(handleWithOffsets.f0, offs));
+
+					numberOfPartitionsToAssign -= remaining;
+
+					// As a last step we merge partitions that use the same StreamStateHandle into a single
+					// OperatorStateHandle
+					Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(parallelOpIdx);
+					OperatorStateHandle psh = mergeMap.get(handleWithOffsets.f0);
+					if (psh == null) {
+						psh = new OperatorStateHandle(handleWithOffsets.f0, new HashMap<String, long[]>());
+						mergeMap.put(handleWithOffsets.f0, psh);
+					}
+					psh.getStateNameToPartitionOffsets().put(e.getKey(), offs);
+				}
+			}
+			startParallelOP = newStartParallelOp;
+			e.setValue(null);
+		}
+		return mergeMapList;
+	}
+}
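
The fairness of the assignment above comes from the baseFraction/remainder split: every subtask gets totalPartitions / parallelism partitions, and the first totalPartitions % parallelism subtasks (starting from a rotating start index) get one extra. A minimal standalone sketch of just this counting logic, with handles, offsets and the rotating start index omitted:

public class RoundRobinCountSketch {

	// Distributes `totalPartitions` units over `parallelism` subtasks so that
	// any two counts differ by at most one.
	public static int[] assignCounts(int totalPartitions, int parallelism) {
		int baseFraction = totalPartitions / parallelism;
		int remainder = totalPartitions % parallelism;

		int[] counts = new int[parallelism];
		for (int i = 0; i < parallelism; ++i) {
			counts[i] = baseFraction + (i < remainder ? 1 : 0);
		}
		return counts;
	}

	public static void main(String[] args) {
		// 7 partitions over 3 subtasks -> 3, 2, 2
		for (int count : assignCounts(7, 3)) {
			System.out.println(count);
		}
	}
}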

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
index 9beb233..2aa0491 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -24,8 +24,6 @@ import org.apache.flink.runtime.state.StreamStateHandle;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
-
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -112,11 +110,4 @@ public class SubtaskState implements StateObject {
 	public String toString() {
 		return String.format("SubtaskState(Size: %d, Duration: %d, State: %s)", stateSize, duration, chainedStateHandle);
 	}
-
-	@Override
-	public void close() throws IOException {
-		chainedStateHandle.close();
-	}
-
-
 }


[09/10] flink git commit: [FLINK-4702] [kafka connector] Commit offsets to Kafka asynchronously and don't block on polls

Posted by se...@apache.org.
[FLINK-4702] [kafka connector] Commit offsets to Kafka asynchronously and don't block on polls

Letting the Kafka offset commit block on polls means that 'notifyCheckpointComplete()' may
take a very long time. This is mostly relevant for low-throughput Kafka topics.
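
The fix replaces the shared lock around the consumer with a hand-off: the checkpoint thread publishes the offsets into an AtomicReference and wakes the consumer, and the main fetch loop picks them up and calls commitAsync. A condensed, self-contained sketch of that hand-off (with a String standing in for the real Map<TopicPartition, OffsetAndMetadata>, and the KafkaConsumer calls reduced to comments; see the actual Kafka09Fetcher diff below):

import java.util.concurrent.atomic.AtomicReference;

// Condensed sketch: the checkpoint thread never touches the non-threadsafe
// KafkaConsumer; it only publishes work for the poll loop.
public class AsyncCommitHandOff {

	// offsets published by the checkpoint thread, consumed by the poll loop
	private final AtomicReference<String> nextOffsetsToCommit = new AtomicReference<>();

	// set before commitAsync() is issued, cleared by its completion callback
	private volatile boolean commitInProgress;

	// checkpoint thread (stands in for commitSpecificOffsetsToKafka)
	public void publish(String offsets) {
		if (nextOffsetsToCommit.getAndSet(offsets) != null) {
			System.out.println("skipping older offsets - a newer checkpoint completed first");
		}
		// real code: consumer.wakeup(), so a blocking poll() returns promptly
	}

	// top of each iteration of the main fetch loop
	public void maybeCommit() {
		String toCommit = nextOffsetsToCommit.getAndSet(null);
		if (toCommit != null && !commitInProgress) {
			commitInProgress = true;
			// real code: consumer.commitAsync(toCommit, callback), where the
			// callback sets commitInProgress = false once Kafka confirms
			System.out.println("committing " + toCommit);
			commitInProgress = false; // stand-in for the async callback
		}
	}
}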


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/92f4539a
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/92f4539a
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/92f4539a

Branch: refs/heads/master
Commit: 92f4539afc714f7dbd293c3ad677b3b5807c6911
Parents: 6f8f5eb
Author: Stephan Ewen <se...@apache.org>
Authored: Thu Sep 29 18:09:51 2016 +0200
Committer: Stephan Ewen <se...@apache.org>
Committed: Fri Sep 30 12:38:46 2016 +0200

----------------------------------------------------------------------
 .../kafka/internal/Kafka09Fetcher.java          |  73 +++--
 .../connectors/kafka/Kafka09FetcherTest.java    | 304 +++++++++++++++++++
 2 files changed, 355 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/92f4539a/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/Kafka09Fetcher.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/Kafka09Fetcher.java b/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/Kafka09Fetcher.java
index 9c861c9..1da2259 100644
--- a/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/Kafka09Fetcher.java
+++ b/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/internal/Kafka09Fetcher.java
@@ -37,6 +37,7 @@ import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.clients.consumer.OffsetCommitCallback;
 import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.TopicPartition;
@@ -50,6 +51,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
+import java.util.concurrent.atomic.AtomicReference;
 
 /**
  * A fetcher that fetches data from Kafka brokers via the Kafka 0.9 consumer API.
@@ -74,18 +76,24 @@ public class Kafka09Fetcher<T> extends AbstractFetcher<T, TopicPartition> implem
 	/** The maximum number of milliseconds to wait for a fetch batch */
 	private final long pollTimeout;
 
-	/** Mutex to guard against concurrent access to the non-threadsafe Kafka consumer */
-	private final Object consumerLock = new Object();
+	/** The next offsets that the main thread should commit */
+	private final AtomicReference<Map<TopicPartition, OffsetAndMetadata>> nextOffsetsToCommit;
+	
+	/** The callback invoked by Kafka once an offset commit is complete */
+	private final OffsetCommitCallback offsetCommitCallback;
 
 	/** Reference to the Kafka consumer, once it is created */
 	private volatile KafkaConsumer<byte[], byte[]> consumer;
-
+	
 	/** Reference to the proxy, forwarding exceptions from the fetch thread to the main thread */
 	private volatile ExceptionProxy errorHandler;
 
 	/** Flag to mark the main work loop as alive */
 	private volatile boolean running = true;
 
+	/** Flag tracking whether the latest commit request has completed */
+	private volatile boolean commitInProgress;
+
 	// ------------------------------------------------------------------------
 
 	public Kafka09Fetcher(
@@ -105,6 +113,8 @@ public class Kafka09Fetcher<T> extends AbstractFetcher<T, TopicPartition> implem
 		this.runtimeContext = runtimeContext;
 		this.kafkaProperties = kafkaProperties;
 		this.pollTimeout = pollTimeout;
+		this.nextOffsetsToCommit = new AtomicReference<>();
+		this.offsetCommitCallback = new CommitCallback();
 
 		// if checkpointing is enabled, we are not automatically committing to Kafka.
 		kafkaProperties.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG,
@@ -203,19 +213,23 @@ public class Kafka09Fetcher<T> extends AbstractFetcher<T, TopicPartition> implem
 
 			// main fetch loop
 			while (running) {
+
+				// check if there is something to commit
+				final Map<TopicPartition, OffsetAndMetadata> toCommit = nextOffsetsToCommit.getAndSet(null);
+				if (toCommit != null && !commitInProgress) {
+					// reset the work to be committed, so that we don't repeatedly commit the same offsets
+					// also record that a commit is already in progress
+					commitInProgress = true;
+					consumer.commitAsync(toCommit, offsetCommitCallback);
+				}
+
 				// get the next batch of records
 				final ConsumerRecords<byte[], byte[]> records;
-				synchronized (consumerLock) {
-					try {
-						records = consumer.poll(pollTimeout);
-					}
-					catch (WakeupException we) {
-						if (running) {
-							throw we;
-						} else {
-							continue;
-						}
-					}
+				try {
+					records = consumer.poll(pollTimeout);
+				}
+				catch (WakeupException we) {
+					continue;
 				}
 
 				// get the records for each topic partition
@@ -252,10 +266,9 @@ public class Kafka09Fetcher<T> extends AbstractFetcher<T, TopicPartition> implem
 		}
 		finally {
 			try {
-				synchronized (consumerLock) {
-					consumer.close();
-				}
-			} catch (Throwable t) {
+				consumer.close();
+			}
+			catch (Throwable t) {
 				LOG.warn("Error while closing Kafka 0.9 consumer", t);
 			}
 		}
@@ -283,10 +296,14 @@ public class Kafka09Fetcher<T> extends AbstractFetcher<T, TopicPartition> implem
 			}
 		}
 
-		if (this.consumer != null) {
-			synchronized (consumerLock) {
-				this.consumer.commitSync(offsetsToCommit);
-			}
+		// record the work to be committed by the main consumer thread and make sure the consumer thread notices it
+		if (nextOffsetsToCommit.getAndSet(offsetsToCommit) != null) {
+			LOG.warn("Committing offsets to Kafka takes longer than the checkpoint interval. " +
+					"Skipping commit of previous offsets because newer complete checkpoint offsets are available. " +
+					"This does not compromise Flink's checkpoint integrity.");
+		}
+		if (consumer != null) {
+			consumer.wakeup();
 		}
 	}
 
@@ -301,4 +318,16 @@ public class Kafka09Fetcher<T> extends AbstractFetcher<T, TopicPartition> implem
 		}
 		return result;
 	}
+
+	private class CommitCallback implements OffsetCommitCallback {
+
+		@Override
+		public void onComplete(Map<TopicPartition, OffsetAndMetadata> offsets, Exception ex) {
+			commitInProgress = false;
+
+			if (ex != null) {
+				LOG.warn("Committing offsets to Kafka failed. This does not compromise Flink's checkpoints.", ex);
+			}
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/92f4539a/flink-streaming-connectors/flink-connector-kafka-0.9/src/test/java/org/apache/flink/streaming/connectors/kafka/Kafka09FetcherTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-0.9/src/test/java/org/apache/flink/streaming/connectors/kafka/Kafka09FetcherTest.java b/flink-streaming-connectors/flink-connector-kafka-0.9/src/test/java/org/apache/flink/streaming/connectors/kafka/Kafka09FetcherTest.java
new file mode 100644
index 0000000..4fd6c9f
--- /dev/null
+++ b/flink-streaming-connectors/flink-connector-kafka-0.9/src/test/java/org/apache/flink/streaming/connectors/kafka/Kafka09FetcherTest.java
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka;
+
+import org.apache.flink.core.testutils.MultiShotLatch;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.connectors.kafka.internal.Kafka09Fetcher;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
+import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchemaWrapper;
+import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
+
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.clients.consumer.OffsetCommitCallback;
+import org.apache.kafka.common.TopicPartition;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Properties;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyLong;
+import static org.powermock.api.mockito.PowerMockito.doAnswer;
+import static org.powermock.api.mockito.PowerMockito.mock;
+import static org.powermock.api.mockito.PowerMockito.when;
+import static org.powermock.api.mockito.PowerMockito.whenNew;
+
+/**
+ * Unit tests for the {@link Kafka09Fetcher}.
+ */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest(Kafka09Fetcher.class)
+public class Kafka09FetcherTest {
+
+	@Test
+	public void testCommitDoesNotBlock() throws Exception {
+
+		// test data
+		final KafkaTopicPartition testPartition = new KafkaTopicPartition("test", 42);
+		final Map<KafkaTopicPartition, Long> testCommitData = new HashMap<>();
+		testCommitData.put(testPartition, 11L);
+
+		// to synchronize when the consumer is in its blocking method
+		final OneShotLatch sync = new OneShotLatch();
+
+		// ----- the mock consumer with blocking poll calls ----
+		final MultiShotLatch blockerLatch = new MultiShotLatch();
+		
+		KafkaConsumer<?, ?> mockConsumer = mock(KafkaConsumer.class);
+		when(mockConsumer.poll(anyLong())).thenAnswer(new Answer<ConsumerRecords<?, ?>>() {
+			
+			@Override
+			public ConsumerRecords<?, ?> answer(InvocationOnMock invocation) throws InterruptedException {
+				sync.trigger();
+				blockerLatch.await();
+				return ConsumerRecords.empty();
+			}
+		});
+
+		doAnswer(new Answer<Void>() {
+			@Override
+			public Void answer(InvocationOnMock invocation) {
+				blockerLatch.trigger();
+				return null;
+			}
+		}).when(mockConsumer).wakeup();
+
+		// make sure the fetcher creates the mock consumer
+		whenNew(KafkaConsumer.class).withAnyArguments().thenReturn(mockConsumer);
+
+		// ----- create the test fetcher -----
+
+		@SuppressWarnings("unchecked")
+		SourceContext<String> sourceContext = mock(SourceContext.class);
+		List<KafkaTopicPartition> topics = Collections.singletonList(new KafkaTopicPartition("test", 42));
+		KeyedDeserializationSchema<String> schema = new KeyedDeserializationSchemaWrapper<>(new SimpleStringSchema());
+		StreamingRuntimeContext context = mock(StreamingRuntimeContext.class);
+		
+		final Kafka09Fetcher<String> fetcher = new Kafka09Fetcher<>(
+				sourceContext, topics, null, null, context, schema, new Properties(), 0L, false);
+
+		// ----- run the fetcher -----
+
+		final AtomicReference<Throwable> error = new AtomicReference<>();
+		final Thread fetcherRunner = new Thread("fetcher runner") {
+
+			@Override
+			public void run() {
+				try {
+					fetcher.runFetchLoop();
+				} catch (Throwable t) {
+					error.set(t);
+				}
+			}
+		};
+		fetcherRunner.start();
+
+		// wait until the fetcher has reached the method of interest
+		sync.await();
+
+		// ----- trigger the offset commit -----
+		
+		final AtomicReference<Throwable> commitError = new AtomicReference<>();
+		final Thread committer = new Thread("committer runner") {
+			@Override
+			public void run() {
+				try {
+					fetcher.commitSpecificOffsetsToKafka(testCommitData);
+				} catch (Throwable t) {
+					commitError.set(t);
+				}
+			}
+		};
+		committer.start();
+
+		// ----- ensure that the committer finishes in time  -----
+		committer.join(30000);
+		assertFalse("The committer did not finish in time", committer.isAlive());
+
+		// ----- test done, wait till the fetcher is done for a clean shutdown -----
+		fetcher.cancel();
+		fetcherRunner.join();
+
+		// check that there were no errors in the fetcher
+		final Throwable fetcherError = error.get();
+		if (fetcherError != null) {
+			throw new Exception("Exception in the fetcher", fetcherError);
+		}
+		final Throwable committerError = commitError.get();
+		if (committerError != null) {
+			throw new Exception("Exception in the committer", committerError);
+		}
+	}
+
+	@Test
+	public void ensureOffsetsGetCommitted() throws Exception {
+		
+		// test data
+		final KafkaTopicPartition testPartition1 = new KafkaTopicPartition("test", 42);
+		final KafkaTopicPartition testPartition2 = new KafkaTopicPartition("another", 99);
+		
+		final Map<KafkaTopicPartition, Long> testCommitData1 = new HashMap<>();
+		testCommitData1.put(testPartition1, 11L);
+		testCommitData1.put(testPartition2, 18L);
+
+		final Map<KafkaTopicPartition, Long> testCommitData2 = new HashMap<>();
+		testCommitData2.put(testPartition1, 19L);
+		testCommitData2.put(testPartition2, 28L);
+
+		final BlockingQueue<Map<TopicPartition, OffsetAndMetadata>> commitStore = new LinkedBlockingQueue<>();
+
+
+		// ----- the mock consumer with poll(), wakeup(), and commit(A)sync calls ----
+
+		final MultiShotLatch blockerLatch = new MultiShotLatch();
+
+		KafkaConsumer<?, ?> mockConsumer = mock(KafkaConsumer.class);
+
+		when(mockConsumer.poll(anyLong())).thenAnswer(new Answer<ConsumerRecords<?, ?>>() {
+			@Override
+			public ConsumerRecords<?, ?> answer(InvocationOnMock invocation) throws InterruptedException {
+				blockerLatch.await();
+				return ConsumerRecords.empty();
+			}
+		});
+
+		doAnswer(new Answer<Void>() {
+			@Override
+			public Void answer(InvocationOnMock invocation) {
+				blockerLatch.trigger();
+				return null;
+			}
+		}).when(mockConsumer).wakeup();
+
+		doAnswer(new Answer<Void>() {
+			@Override
+			public Void answer(InvocationOnMock invocation) {
+				@SuppressWarnings("unchecked")
+				Map<TopicPartition, OffsetAndMetadata> offsets = 
+						(Map<TopicPartition, OffsetAndMetadata>) invocation.getArguments()[0];
+
+				OffsetCommitCallback callback = (OffsetCommitCallback) invocation.getArguments()[1];
+
+				commitStore.add(offsets);
+				callback.onComplete(offsets, null);
+
+				return null; 
+			}
+		}).when(mockConsumer).commitAsync(
+				Mockito.<Map<TopicPartition, OffsetAndMetadata>>any(), any(OffsetCommitCallback.class));
+
+		// make sure the fetcher creates the mock consumer
+		whenNew(KafkaConsumer.class).withAnyArguments().thenReturn(mockConsumer);
+
+		// ----- create the test fetcher -----
+
+		@SuppressWarnings("unchecked")
+		SourceContext<String> sourceContext = mock(SourceContext.class);
+		List<KafkaTopicPartition> topics = Collections.singletonList(new KafkaTopicPartition("test", 42));
+		KeyedDeserializationSchema<String> schema = new KeyedDeserializationSchemaWrapper<>(new SimpleStringSchema());
+		StreamingRuntimeContext context = mock(StreamingRuntimeContext.class);
+
+		final Kafka09Fetcher<String> fetcher = new Kafka09Fetcher<>(
+				sourceContext, topics, null, null, context, schema, new Properties(), 0L, false);
+
+		// ----- run the fetcher -----
+
+		final AtomicReference<Throwable> error = new AtomicReference<>();
+		final Thread fetcherRunner = new Thread("fetcher runner") {
+
+			@Override
+			public void run() {
+				try {
+					fetcher.runFetchLoop();
+				} catch (Throwable t) {
+					error.set(t);
+				}
+			}
+		};
+		fetcherRunner.start();
+
+		// ----- trigger the first offset commit -----
+
+		fetcher.commitSpecificOffsetsToKafka(testCommitData1);
+		Map<TopicPartition, OffsetAndMetadata> result1 = commitStore.take();
+
+		for (Entry<TopicPartition, OffsetAndMetadata> entry : result1.entrySet()) {
+			TopicPartition partition = entry.getKey();
+			if (partition.topic().equals("test")) {
+				assertEquals(42, partition.partition());
+				assertEquals(11L, entry.getValue().offset());
+			}
+			else if (partition.topic().equals("another")) {
+				assertEquals(99, partition.partition());
+				assertEquals(18L, entry.getValue().offset());
+			}
+		}
+
+		// ----- trigger the second offset commit -----
+
+		fetcher.commitSpecificOffsetsToKafka(testCommitData2);
+		Map<TopicPartition, OffsetAndMetadata> result2 = commitStore.take();
+
+		for (Entry<TopicPartition, OffsetAndMetadata> entry : result2.entrySet()) {
+			TopicPartition partition = entry.getKey();
+			if (partition.topic().equals("test")) {
+				assertEquals(42, partition.partition());
+				assertEquals(19L, entry.getValue().offset());
+			}
+			else if (partition.topic().equals("another")) {
+				assertEquals(99, partition.partition());
+				assertEquals(28L, entry.getValue().offset());
+			}
+		}
+		
+		// ----- test done, wait till the fetcher is done for a clean shutdown -----
+		fetcher.cancel();
+		fetcherRunner.join();
+
+		// check that there were no errors in the fetcher
+		final Throwable caughtError = error.get();
+		if (caughtError != null) {
+			throw new Exception("Exception in the fetcher", caughtError);
+		}
+	}
+}


[06/10] flink git commit: [FLINK-4379] [checkpoints] Introduce rescalable operator state

Posted by se...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
index 5612f73..7293a84 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
@@ -7,7 +7,7 @@
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
- *     http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -18,178 +18,55 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.api.common.functions.ReduceFunction;
-import org.apache.flink.api.common.state.FoldingState;
-import org.apache.flink.api.common.state.FoldingStateDescriptor;
-import org.apache.flink.api.common.state.ListState;
-import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MergingState;
-import org.apache.flink.api.common.state.ReducingState;
-import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
-import org.apache.flink.api.common.state.StateBackend;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.util.Preconditions;
 
-import java.util.ArrayList;
 import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.concurrent.RunnableFuture;
 
 /**
- * A keyed state backend is responsible for managing keyed state. The state can be checkpointed
- * to streams using {@link #snapshot(long, long, CheckpointStreamFactory)}.
+ * A keyed state backend provides methods for managing keyed state.
  *
  * @param <K> The key by which state is keyed.
  */
-public abstract class KeyedStateBackend<K> {
-
-	/** {@link TypeSerializer} for our key. */
-	protected final TypeSerializer<K> keySerializer;
-
-	/** The currently active key. */
-	protected K currentKey;
-
-	/** The key group of the currently active key */
-	private int currentKeyGroup;
-
-	/** So that we can give out state when the user uses the same key. */
-	protected HashMap<String, KvState<?>> keyValueStatesByName;
-
-	/** For caching the last accessed partitioned state */
-	private String lastName;
-
-	@SuppressWarnings("rawtypes")
-	private KvState lastState;
-
-	/** The number of key-groups aka max parallelism */
-	protected final int numberOfKeyGroups;
-
-	/** Range of key-groups for which this backend is responsible */
-	protected final KeyGroupRange keyGroupRange;
-
-	/** KvStateRegistry helper for this task */
-	protected final TaskKvStateRegistry kvStateRegistry;
-
-	protected final ClassLoader userCodeClassLoader;
-
-	public KeyedStateBackend(
-			TaskKvStateRegistry kvStateRegistry,
-			TypeSerializer<K> keySerializer,
-			ClassLoader userCodeClassLoader,
-			int numberOfKeyGroups,
-			KeyGroupRange keyGroupRange) {
-
-		this.kvStateRegistry = Preconditions.checkNotNull(kvStateRegistry);
-		this.keySerializer = Preconditions.checkNotNull(keySerializer);
-		this.userCodeClassLoader = Preconditions.checkNotNull(userCodeClassLoader);
-		this.numberOfKeyGroups = Preconditions.checkNotNull(numberOfKeyGroups);
-		this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange);
-	}
+public interface KeyedStateBackend<K> {
 
 	/**
-	 * Closes the state backend, releasing all internal resources, but does not delete any persistent
-	 * checkpoint data.
-	 *
-	 * @throws Exception Exceptions can be forwarded and will be logged by the system
+	 * Sets the current key that is used for partitioned state.
+	 * @param newKey The new current key.
 	 */
-	public void close() throws Exception {
-		if (kvStateRegistry != null) {
-			kvStateRegistry.unregisterAll();
-		}
-
-		lastName = null;
-		lastState = null;
-		keyValueStatesByName = null;
-	}
+	void setCurrentKey(K newKey);
 
 	/**
-	 * Creates and returns a new {@link ValueState}.
-	 *
-	 * @param namespaceSerializer TypeSerializer for the state namespace.
-	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
-	 *
-	 * @param <N> The type of the namespace.
-	 * @param <T> The type of the value that the {@code ValueState} can store.
+	 * Used by states to access the current key.
 	 */
-	protected abstract <N, T> ValueState<T> createValueState(TypeSerializer<N> namespaceSerializer, ValueStateDescriptor<T> stateDesc) throws Exception;
+	K getCurrentKey();
 
 	/**
-	 * Creates and returns a new {@link ListState}.
-	 *
-	 * @param namespaceSerializer TypeSerializer for the state namespace.
-	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
-	 *
-	 * @param <N> The type of the namespace.
-	 * @param <T> The type of the values that the {@code ListState} can store.
+	 * Returns the key-group to which the current key belongs.
 	 */
-	protected abstract <N, T> ListState<T> createListState(TypeSerializer<N> namespaceSerializer, ListStateDescriptor<T> stateDesc) throws Exception;
+	int getCurrentKeyGroupIndex();
 
 	/**
-	 * Creates and returns a new {@link ReducingState}.
-	 *
-	 * @param namespaceSerializer TypeSerializer for the state namespace.
-	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
-	 *
-	 * @param <N> The type of the namespace.
-	 * @param <T> The type of the values that the {@code ListState} can store.
+	 * Returns the number of key-groups aka max parallelism.
 	 */
-	protected abstract <N, T> ReducingState<T> createReducingState(TypeSerializer<N> namespaceSerializer, ReducingStateDescriptor<T> stateDesc) throws Exception;
+	int getNumberOfKeyGroups();
 
 	/**
-	 * Creates and returns a new {@link FoldingState}.
-	 *
-	 * @param namespaceSerializer TypeSerializer for the state namespace.
-	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
-	 *
-	 * @param <N> The type of the namespace.
-	 * @param <T> Type of the values folded into the state
-	 * @param <ACC> Type of the value in the state	 *
+	 * Returns the key group range for this backend.
 	 */
-	protected abstract <N, T, ACC> FoldingState<T, ACC> createFoldingState(TypeSerializer<N> namespaceSerializer, FoldingStateDescriptor<T, ACC> stateDesc) throws Exception;
-
-	/**
-	 * Sets the current key that is used for partitioned state.
-	 * @param newKey The new current key.
-	 */
-	public void setCurrentKey(K newKey) {
-		this.currentKey = newKey;
-		this.currentKeyGroup = KeyGroupRangeAssignment.assignToKeyGroup(newKey, numberOfKeyGroups);
-	}
+	KeyGroupRange getKeyGroupRange();
 
 	/**
 	 * {@link TypeSerializer} for the state backend key type.
 	 */
-	public TypeSerializer<K> getKeySerializer() {
-		return keySerializer;
-	}
-
-	/**
-	 * Used by states to access the current key.
-	 */
-	public K getCurrentKey() {
-		return currentKey;
-	}
-
-	public int getCurrentKeyGroupIndex() {
-		return currentKeyGroup;
-	}
-
-	public int getNumberOfKeyGroups() {
-		return numberOfKeyGroups;
-	}
+	TypeSerializer<K> getKeySerializer();
 
 	/**
 	 * Creates or retrieves a partitioned state backed by this state backend.
 	 *
-	 * @param stateDescriptor The state identifier for the state. This contains name
-	 *                           and can create a default state value.
+	 * @param stateDescriptor The identifier for the state. This contains the name and can create a default state value.
 
 	 * @param <N> The type of the namespace.
 	 * @param <S> The type of the state.
@@ -199,145 +76,21 @@ public abstract class KeyedStateBackend<K> {
 	 * @throws Exception Exceptions may occur during initialization of the state and should be forwarded.
 	 */
 	@SuppressWarnings({"rawtypes", "unchecked"})
-	public <N, S extends State> S getPartitionedState(final N namespace, final TypeSerializer<N> namespaceSerializer, final StateDescriptor<S, ?> stateDescriptor) throws Exception {
-		Preconditions.checkNotNull(namespace, "Namespace");
-		Preconditions.checkNotNull(namespaceSerializer, "Namespace serializer");
-
-		if (keySerializer == null) {
-			throw new RuntimeException("State key serializer has not been configured in the config. " +
-					"This operation cannot use partitioned state.");
-		}
-		
-		if (!stateDescriptor.isSerializerInitialized()) {
-			stateDescriptor.initializeSerializerUnlessSet(new ExecutionConfig());
-		}
-
-		if (keyValueStatesByName == null) {
-			keyValueStatesByName = new HashMap<>();
-		}
-
-		if (lastName != null && lastName.equals(stateDescriptor.getName())) {
-			lastState.setCurrentNamespace(namespace);
-			return (S) lastState;
-		}
-
-		KvState<?> previous = keyValueStatesByName.get(stateDescriptor.getName());
-		if (previous != null) {
-			lastState = previous;
-			lastState.setCurrentNamespace(namespace);
-			lastName = stateDescriptor.getName();
-			return (S) previous;
-		}
-
-		// create a new blank key/value state
-		S state = stateDescriptor.bind(new StateBackend() {
-			@Override
-			public <T> ValueState<T> createValueState(ValueStateDescriptor<T> stateDesc) throws Exception {
-				return KeyedStateBackend.this.createValueState(namespaceSerializer, stateDesc);
-			}
+	<N, S extends State> S getPartitionedState(
+			N namespace,
+			TypeSerializer<N> namespaceSerializer,
+			StateDescriptor<S, ?> stateDescriptor) throws Exception;
 
-			@Override
-			public <T> ListState<T> createListState(ListStateDescriptor<T> stateDesc) throws Exception {
-				return KeyedStateBackend.this.createListState(namespaceSerializer, stateDesc);
-			}
-
-			@Override
-			public <T> ReducingState<T> createReducingState(ReducingStateDescriptor<T> stateDesc) throws Exception {
-				return KeyedStateBackend.this.createReducingState(namespaceSerializer, stateDesc);
-			}
-
-			@Override
-			public <T, ACC> FoldingState<T, ACC> createFoldingState(FoldingStateDescriptor<T, ACC> stateDesc) throws Exception {
-				return KeyedStateBackend.this.createFoldingState(namespaceSerializer, stateDesc);
-			}
-
-		});
-
-		KvState kvState = (KvState) state;
-
-		keyValueStatesByName.put(stateDescriptor.getName(), kvState);
-
-		lastName = stateDescriptor.getName();
-		lastState = kvState;
-
-		kvState.setCurrentNamespace(namespace);
-
-		// Publish queryable state
-		if (stateDescriptor.isQueryable()) {
-			if (kvStateRegistry == null) {
-				throw new IllegalStateException("State backend has not been initialized for job.");
-			}
-
-			String name = stateDescriptor.getQueryableStateName();
-			kvStateRegistry.registerKvState(keyGroupRange, name, kvState);
-		}
-
-		return state;
-	}
 
 	@SuppressWarnings("unchecked,rawtypes")
-	public <N, S extends MergingState<?, ?>> void mergePartitionedStates(final N target, Collection<N> sources, final TypeSerializer<N> namespaceSerializer, final StateDescriptor<S, ?> stateDescriptor) throws Exception {
-		if (stateDescriptor instanceof ReducingStateDescriptor) {
-			ReducingStateDescriptor reducingStateDescriptor = (ReducingStateDescriptor) stateDescriptor;
-			ReduceFunction reduceFn = reducingStateDescriptor.getReduceFunction();
-			ReducingState state = (ReducingState) getPartitionedState(target, namespaceSerializer, stateDescriptor);
-			KvState kvState = (KvState) state;
-			Object result = null;
-			for (N source: sources) {
-				kvState.setCurrentNamespace(source);
-				Object sourceValue = state.get();
-				if (result == null) {
-					result = state.get();
-				} else if (sourceValue != null) {
-					result = reduceFn.reduce(result, sourceValue);
-				}
-				state.clear();
-			}
-			kvState.setCurrentNamespace(target);
-			if (result != null) {
-				state.add(result);
-			}
-		} else if (stateDescriptor instanceof ListStateDescriptor) {
-			ListState<Object> state = (ListState) getPartitionedState(target, namespaceSerializer, stateDescriptor);
-			KvState kvState = (KvState) state;
-			List<Object> result = new ArrayList<>();
-			for (N source: sources) {
-				kvState.setCurrentNamespace(source);
-				Iterable<Object> sourceValue = state.get();
-				if (sourceValue != null) {
-					for (Object o : sourceValue) {
-						result.add(o);
-					}
-				}
-				state.clear();
-			}
-			kvState.setCurrentNamespace(target);
-			for (Object o : result) {
-				state.add(o);
-			}
-		} else {
-			throw new RuntimeException("Cannot merge states for " + stateDescriptor);
-		}
-	}
+	<N, S extends MergingState<?, ?>> void mergePartitionedStates(
+			N target,
+			Collection<N> sources,
+			TypeSerializer<N> namespaceSerializer,
+			StateDescriptor<S, ?> stateDescriptor) throws Exception;
 
 	/**
-	 * Snapshots the keyed state by writing it to streams that are provided by a
-	 * {@link CheckpointStreamFactory}.
-	 *
-	 * @param checkpointId The ID of the checkpoint.
-	 * @param timestamp The timestamp of the checkpoint.
-	 * @param streamFactory The factory that we can use for writing our state to streams.
-	 *
-	 * @return A future that will yield a {@link KeyGroupsStateHandle} with the index and
-	 * written key group state stream.
+	 * Closes the backend and releases all resources.
 	 */
-	public abstract RunnableFuture<KeyGroupsStateHandle> snapshot(
-			long checkpointId,
-			long timestamp,
-			CheckpointStreamFactory streamFactory) throws Exception;
-
-
-	public KeyGroupRange getKeyGroupRange() {
-		return keyGroupRange;
-	}
+	void dispose();
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java
new file mode 100644
index 0000000..4e980b7
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import java.io.Closeable;
+
+/**
+ * Interface that combines both the user-facing {@link OperatorStateStore} interface and the system
+ * interface {@link SnapshotProvider}.
+ *
+ */
+public interface OperatorStateBackend extends OperatorStateStore, SnapshotProvider<OperatorStateHandle>, Closeable {
+
+	/**
+	 * Disposes the backend and releases all resources.
+	 */
+	void dispose();
+
+}
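
A brief usage sketch of this combined interface (illustrative only: the helper name, the "offsets" state, and the assumption that a backend and stream factory are handed in by the runtime are not part of this change):

import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.OperatorStateBackend;
import org.apache.flink.runtime.state.OperatorStateHandle;

import java.util.concurrent.RunnableFuture;

class OperatorStateBackendSketch {

	// Hypothetical helper; 'backend' and 'streamFactory' come from the runtime.
	static OperatorStateHandle checkpoint(
			OperatorStateBackend backend,
			CheckpointStreamFactory streamFactory,
			long checkpointId,
			long timestamp) throws Exception {

		// user-facing side (OperatorStateStore): register or restore named state
		ListState<Long> offsets = backend.getPartitionableState(
				new ListStateDescriptor<>("offsets", LongSerializer.INSTANCE));
		offsets.add(42L);

		// system side (SnapshotProvider): write a snapshot through the factory
		RunnableFuture<OperatorStateHandle> future =
				backend.snapshot(checkpointId, timestamp, streamFactory);
		future.run();        // synchronous backends return an already-done future
		return future.get(); // handle to the written snapshot
	}
}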

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
new file mode 100644
index 0000000..3e2d713
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Map;
+
+/**
+ * State handle for partitionable operator state. Besides being a {@link StreamStateHandle}, this also provides a
+ * map that contains the offsets to the partitions of named states in the stream.
+ */
+public class OperatorStateHandle implements StreamStateHandle {
+
+	private static final long serialVersionUID = 35876522969227335L;
+
+	/** unique state name -> offsets for available partitions in the handle stream */
+	private final Map<String, long[]> stateNameToPartitionOffsets;
+	private final StreamStateHandle delegateStateHandle;
+
+	public OperatorStateHandle(
+			StreamStateHandle delegateStateHandle,
+			Map<String, long[]> stateNameToPartitionOffsets) {
+
+		this.delegateStateHandle = Preconditions.checkNotNull(delegateStateHandle);
+		this.stateNameToPartitionOffsets = Preconditions.checkNotNull(stateNameToPartitionOffsets);
+	}
+
+	public Map<String, long[]> getStateNameToPartitionOffsets() {
+		return stateNameToPartitionOffsets;
+	}
+
+	@Override
+	public void discardState() throws Exception {
+		delegateStateHandle.discardState();
+	}
+
+	@Override
+	public long getStateSize() throws IOException {
+		return delegateStateHandle.getStateSize();
+	}
+
+	@Override
+	public FSDataInputStream openInputStream() throws IOException {
+		return delegateStateHandle.openInputStream();
+	}
+
+	public StreamStateHandle getDelegateStateHandle() {
+		return delegateStateHandle;
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o) {
+			return true;
+		}
+
+		if (!(o instanceof OperatorStateHandle)) {
+			return false;
+		}
+
+		OperatorStateHandle that = (OperatorStateHandle) o;
+
+		if (stateNameToPartitionOffsets.size() != that.stateNameToPartitionOffsets.size()) {
+			return false;
+		}
+
+		for (Map.Entry<String, long[]> entry : stateNameToPartitionOffsets.entrySet()) {
+			if (!Arrays.equals(entry.getValue(), that.stateNameToPartitionOffsets.get(entry.getKey()))) {
+				return false;
+			}
+		}
+
+		return delegateStateHandle.equals(that.delegateStateHandle);
+	}
+
+	@Override
+	public int hashCode() {
+		int result = delegateStateHandle.hashCode();
+		for (Map.Entry<String, long[]> entry : stateNameToPartitionOffsets.entrySet()) {
+
+			int entryHash = entry.getKey().hashCode();
+			if (entry.getValue() != null) {
+				entryHash += Arrays.hashCode(entry.getValue());
+			}
+			result = 31 * result + entryHash;
+		}
+		return result;
+	}
+}
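
To show how the offset map is meant to be consumed on restore, a minimal sketch (readPartitions is a hypothetical helper; the actual deserialization at each offset is elided):

import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.runtime.state.OperatorStateHandle;

class OperatorStateHandleSketch {

	// Seeks to each recorded offset of one named state and would deserialize
	// the partition found there.
	static void readPartitions(OperatorStateHandle handle, String stateName) throws Exception {
		long[] offsets = handle.getStateNameToPartitionOffsets().get(stateName);
		if (offsets == null) {
			return; // this handle carries no partitions of that state
		}
		FSDataInputStream in = handle.openInputStream();
		try {
			for (long offset : offsets) {
				in.seek(offset);
				// ... read one state partition starting at 'offset'
			}
		} finally {
			in.close();
		}
	}
}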

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java
new file mode 100644
index 0000000..6914a7c
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+
+import java.util.Set;
+
+/**
+ * Interface for a backend that manages partitionable operator state.
+ */
+public interface OperatorStateStore {
+
+	/**
+	 * Creates (or restores) the partitionable state in this backend. Each state is registered under a unique name.
+	 * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
+	 *
+	 * @param stateDescriptor The descriptor for this state, providing a name and serializer
+	 * @param <S> The generic type of the state
+	 * @return A list state holding all partitions of the state that are assigned to this task.
+	 * @throws Exception Exceptions may occur during creation or restore of the state and should be forwarded.
+	 */
+	<S> ListState<S> getPartitionableState(ListStateDescriptor<S> stateDescriptor) throws Exception;
+
+	/**
+	 * Returns a set with the names of all currently registered states.
+	 * @return set of names for all registered states.
+	 */
+	Set<String> getRegisteredStateNames();
+}
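
A sketch of the intended usage from user code, assuming a store provided by the runtime; the state name and element type are examples only. On first access the state is registered; after recovery the returned list already contains the partitions redistributed to this task:

import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.runtime.state.OperatorStateStore;

class OperatorStateStoreSketch {

	static void useStore(OperatorStateStore store) throws Exception {
		ListState<String> pending = store.getPartitionableState(
				new ListStateDescriptor<>("pending-files", StringSerializer.INSTANCE));

		// after restore: iterate the partitions assigned to this task
		for (String file : pending.get()) {
			// ... resume work for 'file'
		}

		pending.add("next-file");

		// names of everything registered with this store so far
		System.out.println(store.getRegisteredStateNames());
	}
}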

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java
new file mode 100644
index 0000000..065f9c2
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+public class PartitionableCheckpointStateOutputStream extends FSDataOutputStream {
+
+	private final Map<String, long[]> stateNameToPartitionOffsets;
+	private final CheckpointStreamFactory.CheckpointStateOutputStream delegate;
+
+	public PartitionableCheckpointStateOutputStream(CheckpointStreamFactory.CheckpointStateOutputStream delegate) {
+		this.delegate = Preconditions.checkNotNull(delegate);
+		this.stateNameToPartitionOffsets = new HashMap<>();
+	}
+
+	@Override
+	public long getPos() throws IOException {
+		return delegate.getPos();
+	}
+
+	@Override
+	public void flush() throws IOException {
+		delegate.flush();
+	}
+
+	@Override
+	public void sync() throws IOException {
+		delegate.sync();
+	}
+
+	@Override
+	public void write(int b) throws IOException {
+		delegate.write(b);
+	}
+
+	@Override
+	public void write(byte[] b) throws IOException {
+		delegate.write(b);
+	}
+
+	@Override
+	public void write(byte[] b, int off, int len) throws IOException {
+		delegate.write(b, off, len);
+	}
+
+	@Override
+	public void close() throws IOException {
+		delegate.close();
+	}
+
+	public OperatorStateHandle closeAndGetHandle() throws IOException {
+		StreamStateHandle streamStateHandle = delegate.closeAndGetHandle();
+		return new OperatorStateHandle(streamStateHandle, stateNameToPartitionOffsets);
+	}
+
+	public void startNewPartition(String stateName) throws IOException {
+		long[] offs = stateNameToPartitionOffsets.get(stateName);
+		if (offs == null) {
+			offs = new long[1];
+		} else {
+			//TODO maybe we can use some primitive array list here instead of an array to avoid resize on each call.
+			offs = Arrays.copyOf(offs, offs.length + 1);
+		}
+
+		offs[offs.length - 1] = getPos();
+		stateNameToPartitionOffsets.put(stateName, offs);
+	}
+
+	public static PartitionableCheckpointStateOutputStream wrap(
+			CheckpointStreamFactory.CheckpointStateOutputStream stream) {
+		return new PartitionableCheckpointStateOutputStream(stream);
+	}
+}
\ No newline at end of file
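
The write-side counterpart of the offset map, sketched with stand-in payloads: each startNewPartition call records the current stream position under the state's name before the partition's bytes are written.

import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.PartitionableCheckpointStateOutputStream;

class PartitionWriteSketch {

	static OperatorStateHandle writeTwoPartitions(
			CheckpointStreamFactory.CheckpointStateOutputStream rawStream) throws Exception {

		PartitionableCheckpointStateOutputStream out =
				PartitionableCheckpointStateOutputStream.wrap(rawStream);

		out.startNewPartition("offsets"); // records getPos() for "offsets"
		out.write(new byte[] {1, 2, 3});  // stand-in for serialized partition 0

		out.startNewPartition("offsets"); // second partition of the same state
		out.write(new byte[] {4, 5, 6});

		// the handle carries the delegate's handle plus {"offsets" -> [pos0, pos1]}
		return out.closeAndGetHandle();
	}
}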

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java
index 9ecc4c9..9934382 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java
@@ -76,6 +76,6 @@ public class RetrievableStreamStateHandle<T extends Serializable> implements
 
 	@Override
 	public void close() throws IOException {
-		wrappedStreamStateHandle.close();
+//		wrappedStreamStateHandle.close();
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java
new file mode 100644
index 0000000..c47fedd
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import java.util.concurrent.RunnableFuture;
+
+/**
+ * Interface for operations that can perform snapshots of their state.
+ *
+ * @param <S> Generic type of the state object that is created as handle to snapshots.
+ */
+public interface SnapshotProvider<S extends StateObject> {
+
+	/**
+	 * Operation that writes a snapshot into a stream that is provided by the given {@link CheckpointStreamFactory} and
+	 * returns a {@link RunnableFuture} that gives a state handle to the snapshot. It is up to the implementation
+	 * whether the operation is performed synchronously or asynchronously. In the latter case, the returned Runnable
+	 * must be executed before obtaining the handle.
+	 *
+	 * @param checkpointId  The ID of the checkpoint.
+	 * @param timestamp     The timestamp of the checkpoint.
+	 * @param streamFactory The factory that we can use for writing our state to streams.
+	 * @return A runnable future that will yield a {@link StateObject}.
+	 */
+	RunnableFuture<S> snapshot(
+			long checkpointId,
+			long timestamp,
+			CheckpointStreamFactory streamFactory) throws Exception;
+}
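
For the synchronous case mentioned in the Javadoc, an implementor can do all the work eagerly and wrap the finished handle in the DoneFuture used elsewhere in this change. A minimal sketch (the single written byte stands in for real state):

import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.DoneFuture;
import org.apache.flink.runtime.state.SnapshotProvider;
import org.apache.flink.runtime.state.StreamStateHandle;

import java.util.concurrent.RunnableFuture;

class EagerSnapshotProvider implements SnapshotProvider<StreamStateHandle> {

	@Override
	public RunnableFuture<StreamStateHandle> snapshot(
			long checkpointId,
			long timestamp,
			CheckpointStreamFactory streamFactory) throws Exception {

		CheckpointStreamFactory.CheckpointStateOutputStream out =
				streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);
		out.write(0); // ... write the actual state here
		// all work is already done, so run() on the returned future is a no-op
		return new DoneFuture<>(out.closeAndGetHandle());
	}
}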

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
index 4c65318..a502b9d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
@@ -28,13 +28,9 @@ import java.io.IOException;
  * <ul>
  *     <li><b>Discard State</b>: The {@link #discardState()} method defines how state is permanently
  *         disposed/deleted. After that method call, state may not be recoverable any more.</li>
- 
- *     <li><b>Close the current state access</b>: The {@link #close()} method defines how to
- *         stop the current access or recovery to the state. Called for example when an operation is
- *         canceled during recovery.</li>
  * </ul>
  */
-public interface StateObject extends java.io.Closeable, java.io.Serializable {
+public interface StateObject extends java.io.Serializable {
 
 	/**
 	 * Discards the state referred to by this handle, to free up resources in

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
index aa28404..a4799bf 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
@@ -18,8 +18,6 @@
 
 package org.apache.flink.runtime.state;
 
-import java.io.IOException;
-
 /**
  * Helpers for {@link StateObject} related code.
  */
@@ -63,39 +61,4 @@ public class StateUtil {
 			}
 		}
 	}
-
-	/**
-	 * Iterates through the passed state handles and calls discardState() on each handle that is not null. All
-	 * occurring exceptions are suppressed and collected until the iteration is over and emitted as a single exception.
-	 *
-	 * @param handlesToDiscard State handles to discard. Passed iterable is allowed to deliver null values.
-	 * @throws IOException exception that is a collection of all suppressed exceptions that were caught during iteration
-	 */
-	public static void bestEffortCloseAllStateObjects(
-			Iterable<? extends StateObject> handlesToDiscard) throws IOException {
-
-		if (handlesToDiscard != null) {
-
-			IOException suppressedExceptions = null;
-
-			for (StateObject state : handlesToDiscard) {
-
-				if (state != null) {
-					try {
-						state.close();
-					} catch (Exception ex) {
-						//best effort to still cleanup other states and deliver exceptions in the end
-						if (suppressedExceptions == null) {
-							suppressedExceptions = new IOException(ex);
-						}
-						suppressedExceptions.addSuppressed(ex);
-					}
-				}
-			}
-
-			if (suppressedExceptions != null) {
-				throw suppressedExceptions;
-			}
-		}
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStateHandle.java
index f361263..29e905c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStateHandle.java
@@ -21,7 +21,6 @@ package org.apache.flink.runtime.state.filesystem;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.runtime.state.AbstractCloseableHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
 import java.io.IOException;
@@ -34,7 +33,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  * {@link StreamStateHandle} for state that was written to a file stream. The written data is
  * identified by the file path. The state can be read again by calling {@link #openInputStream()}.
  */
-public class FileStateHandle extends AbstractCloseableHandle implements StreamStateHandle {
+public class FileStateHandle implements StreamStateHandle {
 
 	private static final long serialVersionUID = 350284443258002355L;
 
@@ -69,10 +68,7 @@ public class FileStateHandle extends AbstractCloseableHandle implements StreamSt
 
 	@Override
 	public FSDataInputStream openInputStream() throws IOException {
-		ensureNotClosed();
-		FSDataInputStream inputStream = getFileSystem().open(filePath);
-		registerCloseable(inputStream);
-		return inputStream;
+		return getFileSystem().open(filePath);
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
index 99e3684..e027632 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
@@ -24,11 +24,11 @@ import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -175,7 +175,7 @@ public class FsStateBackend extends AbstractStateBackend {
 	}
 
 	@Override
-	public <K> KeyedStateBackend<K> createKeyedStateBackend(
+	public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
 			Environment env,
 			JobID jobID,
 			String operatorIdentifier,
@@ -192,7 +192,7 @@ public class FsStateBackend extends AbstractStateBackend {
 	}
 
 	@Override
-	public <K> KeyedStateBackend<K> restoreKeyedStateBackend(
+	public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
 			Environment env,
 			JobID jobID,
 			String operatorIdentifier,

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index c13be70..a766373 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state.heap;
 
+import org.apache.commons.io.IOUtils;
 import org.apache.flink.api.common.state.FoldingState;
 import org.apache.flink.api.common.state.FoldingStateDescriptor;
 import org.apache.flink.api.common.state.ListState;
@@ -27,17 +28,18 @@ import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.ArrayListSerializer;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
@@ -51,13 +53,13 @@ import java.util.Map;
 import java.util.concurrent.RunnableFuture;
 
 /**
- * A {@link KeyedStateBackend} that keeps state on the Java Heap and will serialize state to
+ * A {@link AbstractKeyedStateBackend} that keeps state on the Java Heap and will serialize state to
  * streams provided by a {@link org.apache.flink.runtime.state.CheckpointStreamFactory} upon
  * checkpointing.
  *
  * @param <K> The key by which state is keyed.
  */
-public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> {
+public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 	private static final Logger LOG = LoggerFactory.getLogger(HeapKeyedStateBackend.class);
 
@@ -165,85 +167,83 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> {
 			long timestamp,
 			CheckpointStreamFactory streamFactory) throws Exception {
 
-		CheckpointStreamFactory.CheckpointStateOutputStream stream =
-				streamFactory.createCheckpointStateOutputStream(
-						checkpointId,
-						timestamp);
-
 		if (stateTables.isEmpty()) {
 			return new DoneFuture<>(null);
 		}
 
-		DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(stream);
+		try (CheckpointStreamFactory.CheckpointStateOutputStream stream = streamFactory.
+				createCheckpointStateOutputStream(checkpointId, timestamp)) {
 
-		Preconditions.checkState(stateTables.size() <= Short.MAX_VALUE,
-				"Too many KV-States: " + stateTables.size() +
-						". Currently at most " + Short.MAX_VALUE + " states are supported");
+			DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(stream);
 
-		outView.writeShort(stateTables.size());
+			Preconditions.checkState(stateTables.size() <= Short.MAX_VALUE,
+					"Too many KV-States: " + stateTables.size() +
+							". Currently at most " + Short.MAX_VALUE + " states are supported");
 
-		Map<String, Integer> kVStateToId = new HashMap<>(stateTables.size());
+			outView.writeShort(stateTables.size());
 
-		for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
+			Map<String, Integer> kVStateToId = new HashMap<>(stateTables.size());
 
-			outView.writeUTF(kvState.getKey());
+			for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
 
-			TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
-			TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
+				outView.writeUTF(kvState.getKey());
 
-			InstantiationUtil.serializeObject(stream, namespaceSerializer);
-			InstantiationUtil.serializeObject(stream, stateSerializer);
+				TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
+				TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
 
-			kVStateToId.put(kvState.getKey(), kVStateToId.size());
-		}
+				InstantiationUtil.serializeObject(stream, namespaceSerializer);
+				InstantiationUtil.serializeObject(stream, stateSerializer);
 
-		int offsetCounter = 0;
-		long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()];
+				kVStateToId.put(kvState.getKey(), kVStateToId.size());
+			}
 
-		for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); keyGroupIndex++) {
-			keyGroupRangeOffsets[offsetCounter++] = stream.getPos();
-			outView.writeInt(keyGroupIndex);
+			int offsetCounter = 0;
+			long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()];
 
-			for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
+			for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); keyGroupIndex++) {
+				keyGroupRangeOffsets[offsetCounter++] = stream.getPos();
+				outView.writeInt(keyGroupIndex);
 
-				outView.writeShort(kVStateToId.get(kvState.getKey()));
+				for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
 
-				TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
-				TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
+					outView.writeShort(kVStateToId.get(kvState.getKey()));
 
-				// Map<NamespaceT, Map<KeyT, StateT>>
-				Map<?, ? extends Map<K, ?>> namespaceMap = kvState.getValue().get(keyGroupIndex);
-				if (namespaceMap == null) {
-					outView.writeByte(0);
-					continue;
-				}
+					TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
+					TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
+
+					// Map<NamespaceT, Map<KeyT, StateT>>
+					Map<?, ? extends Map<K, ?>> namespaceMap = kvState.getValue().get(keyGroupIndex);
+					if (namespaceMap == null) {
+						outView.writeByte(0);
+						continue;
+					}
 
-				outView.writeByte(1);
+					outView.writeByte(1);
 
-				// number of namespaces
-				outView.writeInt(namespaceMap.size());
-				for (Map.Entry<?, ? extends Map<K, ?>> namespace : namespaceMap.entrySet()) {
-					namespaceSerializer.serialize(namespace.getKey(), outView);
+					// number of namespaces
+					outView.writeInt(namespaceMap.size());
+					for (Map.Entry<?, ? extends Map<K, ?>> namespace : namespaceMap.entrySet()) {
+						namespaceSerializer.serialize(namespace.getKey(), outView);
 
-					Map<K, ?> entryMap = namespace.getValue();
+						Map<K, ?> entryMap = namespace.getValue();
 
-					// number of entries
-					outView.writeInt(entryMap.size());
-					for (Map.Entry<K, ?> entry : entryMap.entrySet()) {
-						keySerializer.serialize(entry.getKey(), outView);
-						stateSerializer.serialize(entry.getValue(), outView);
+						// number of entries
+						outView.writeInt(entryMap.size());
+						for (Map.Entry<K, ?> entry : entryMap.entrySet()) {
+							keySerializer.serialize(entry.getKey(), outView);
+							stateSerializer.serialize(entry.getValue(), outView);
+						}
 					}
 				}
+				outView.flush();
 			}
-			outView.flush();
-		}
-
-		StreamStateHandle streamStateHandle = stream.closeAndGetHandle();
 
-		KeyGroupRangeOffsets offsets = new KeyGroupRangeOffsets(keyGroupRange, keyGroupRangeOffsets);
-		final KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(offsets, streamStateHandle);
+			StreamStateHandle streamStateHandle = stream.closeAndGetHandle();
 
-		return new DoneFuture(keyGroupsStateHandle);
+			KeyGroupRangeOffsets offsets = new KeyGroupRangeOffsets(keyGroupRange, keyGroupRangeOffsets);
+			final KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(offsets, streamStateHandle);
+			return new DoneFuture<>(keyGroupsStateHandle);
+		}
 	}
 
 	@SuppressWarnings({"unchecked", "rawtypes"})
@@ -251,71 +251,81 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> {
 
 		for (KeyGroupsStateHandle keyGroupsHandle : state) {
 
-			if(keyGroupsHandle == null) {
+			if (keyGroupsHandle == null) {
 				continue;
 			}
 
-			FSDataInputStream fsDataInputStream = keyGroupsHandle.getStateHandle().openInputStream();
-			DataInputViewStreamWrapper inView = new DataInputViewStreamWrapper(fsDataInputStream);
+			FSDataInputStream fsDataInputStream = null;
 
-			int numKvStates = inView.readShort();
+			try {
 
-			Map<Integer, String> kvStatesById = new HashMap<>(numKvStates);
+				fsDataInputStream = keyGroupsHandle.getStateHandle().openInputStream();
+				cancelStreamRegistry.registerClosable(fsDataInputStream);
 
-			for (int i = 0; i < numKvStates; ++i) {
-				String stateName = inView.readUTF();
+				DataInputViewStreamWrapper inView = new DataInputViewStreamWrapper(fsDataInputStream);
 
-				TypeSerializer namespaceSerializer =
-						InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
-				TypeSerializer stateSerializer =
-						InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
+				int numKvStates = inView.readShort();
 
-				StateTable<K, ?, ?> stateTable = new StateTable(
-						stateSerializer,
-						namespaceSerializer,
-						keyGroupRange);
-				stateTables.put(stateName, stateTable);
-				kvStatesById.put(i, stateName);
-			}
+				Map<Integer, String> kvStatesById = new HashMap<>(numKvStates);
+
+				for (int i = 0; i < numKvStates; ++i) {
+					String stateName = inView.readUTF();
 
-			for (int keyGroupIndex = keyGroupRange.getStartKeyGroup();  keyGroupIndex <= keyGroupRange.getEndKeyGroup(); ++keyGroupIndex) {
-				long offset = keyGroupsHandle.getOffsetForKeyGroup(keyGroupIndex);
-				fsDataInputStream.seek(offset);
+					TypeSerializer namespaceSerializer =
+							InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
+					TypeSerializer stateSerializer =
+							InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
 
-				int writtenKeyGroupIndex = inView.readInt();
-				assert writtenKeyGroupIndex == keyGroupIndex;
+					StateTable<K, ?, ?> stateTable = new StateTable(stateSerializer,
+							namespaceSerializer,
+							keyGroupRange);
+					stateTables.put(stateName, stateTable);
+					kvStatesById.put(i, stateName);
+				}
 
-				for (int i = 0; i < numKvStates; i++) {
-					int kvStateId = inView.readShort();
+				for (Tuple2<Integer, Long> groupOffset : keyGroupsHandle.getGroupRangeOffsets()) {
+					int keyGroupIndex = groupOffset.f0;
+					long offset = groupOffset.f1;
+					fsDataInputStream.seek(offset);
 
-					byte isPresent = inView.readByte();
-					if (isPresent == 0) {
-						continue;
-					}
+					int writtenKeyGroupIndex = inView.readInt();
+					assert writtenKeyGroupIndex == keyGroupIndex;
+
+					for (int i = 0; i < numKvStates; i++) {
+						int kvStateId = inView.readShort();
+
+						byte isPresent = inView.readByte();
+						if (isPresent == 0) {
+							continue;
+						}
 
-					StateTable<K, ?, ?> stateTable = stateTables.get(kvStatesById.get(kvStateId));
-					Preconditions.checkNotNull(stateTable);
+						StateTable<K, ?, ?> stateTable = stateTables.get(kvStatesById.get(kvStateId));
+						Preconditions.checkNotNull(stateTable);
 
-					TypeSerializer namespaceSerializer = stateTable.getNamespaceSerializer();
-					TypeSerializer stateSerializer = stateTable.getStateSerializer();
+						TypeSerializer namespaceSerializer = stateTable.getNamespaceSerializer();
+						TypeSerializer stateSerializer = stateTable.getStateSerializer();
 
-					Map namespaceMap = new HashMap<>();
-					stateTable.set(keyGroupIndex, namespaceMap);
+						Map namespaceMap = new HashMap<>();
+						stateTable.set(keyGroupIndex, namespaceMap);
 
-					int numNamespaces = inView.readInt();
-					for (int k = 0; k < numNamespaces; k++) {
-						Object namespace = namespaceSerializer.deserialize(inView);
-						Map entryMap = new HashMap<>();
-						namespaceMap.put(namespace, entryMap);
+						int numNamespaces = inView.readInt();
+						for (int k = 0; k < numNamespaces; k++) {
+							Object namespace = namespaceSerializer.deserialize(inView);
+							Map entryMap = new HashMap<>();
+							namespaceMap.put(namespace, entryMap);
 
-						int numEntries = inView.readInt();
-						for (int l = 0; l < numEntries; l++) {
-							Object key = keySerializer.deserialize(inView);
-							Object value = stateSerializer.deserialize(inView);
-							entryMap.put(key, value);
+							int numEntries = inView.readInt();
+							for (int l = 0; l < numEntries; l++) {
+								Object key = keySerializer.deserialize(inView);
+								Object value = stateSerializer.deserialize(inView);
+								entryMap.put(key, value);
+							}
 						}
 					}
 				}
+			} finally {
+				cancelStreamRegistry.unregisterClosable(fsDataInputStream);
+				IOUtils.closeQuietly(fsDataInputStream);
 			}
 		}
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
index b9ff255..7d8b6ce 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
@@ -19,10 +19,8 @@
 package org.apache.flink.runtime.state.memory;
 
 import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.runtime.state.AbstractCloseableHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.InstantiationUtil;
-
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
@@ -32,7 +30,7 @@ import java.util.Arrays;
 /**
  * A state handle that contains stream state in a byte array.
  */
-public class ByteStreamStateHandle extends AbstractCloseableHandle implements StreamStateHandle {
+public class ByteStreamStateHandle implements StreamStateHandle {
 
 	private static final long serialVersionUID = -5280226231200217594L;
 
@@ -52,9 +50,8 @@ public class ByteStreamStateHandle extends AbstractCloseableHandle implements St
 
 	@Override
 	public FSDataInputStream openInputStream() throws IOException {
-		ensureNotClosed();
 
-		FSDataInputStream inputStream = new FSDataInputStream() {
+		return new FSDataInputStream() {
 			int index = 0;
 
 			@Override
@@ -73,8 +70,6 @@ public class ByteStreamStateHandle extends AbstractCloseableHandle implements St
 				return index < data.length ? data[index++] & 0xFF : -1;
 			}
 		};
-		registerCloseable(inputStream);
-		return inputStream;
 	}
 
 	public byte[] getData() {
@@ -106,9 +101,7 @@ public class ByteStreamStateHandle extends AbstractCloseableHandle implements St
 
 	@Override
 	public int hashCode() {
-		int result = super.hashCode();
-		result = 31 * result + Arrays.hashCode(data);
-		return result;
+		return Arrays.hashCode(data);
 	}
 
 	public static StreamStateHandle fromSerializable(Serializable value) throws IOException {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
index cc145ff..1772dbe 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
@@ -22,11 +22,11 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 
 import java.io.IOException;
@@ -71,12 +71,13 @@ public class MemoryStateBackend extends AbstractStateBackend {
 	}
 
 	@Override
-	public CheckpointStreamFactory createStreamFactory(JobID jobId, String operatorIdentifier) throws IOException {
+	public CheckpointStreamFactory createStreamFactory(
+			JobID jobId, String operatorIdentifier) throws IOException {
 		return new MemCheckpointStreamFactory(maxStateSize);
 	}
 
 	@Override
-	public <K> KeyedStateBackend<K> createKeyedStateBackend(
+	public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
 			Environment env, JobID jobID,
 			String operatorIdentifier,
 			TypeSerializer<K> keySerializer,
@@ -93,7 +94,7 @@ public class MemoryStateBackend extends AbstractStateBackend {
 	}
 
 	@Override
-	public <K> KeyedStateBackend<K> restoreKeyedStateBackend(
+	public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
 			Environment env, JobID jobID,
 			String operatorIdentifier,
 			TypeSerializer<K> keySerializer,

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
index c317bed..8bf1127 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
@@ -23,13 +23,9 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.util.Preconditions;
 
-import java.util.List;
-
 /**
  * Implementation using {@link ActorGateway} to forward the messages.
  */
@@ -46,8 +42,7 @@ public class ActorGatewayCheckpointResponder implements CheckpointResponder {
 			JobID jobID,
 			ExecutionAttemptID executionAttemptID,
 			long checkpointID,
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle,
-			List<KeyGroupsStateHandle> keyGroupStateHandles,
+			CheckpointStateHandles checkpointStateHandles,
 			long synchronousDurationMillis,
 			long asynchronousDurationMillis,
 			long bytesBufferedInAlignment,
@@ -55,7 +50,7 @@ public class ActorGatewayCheckpointResponder implements CheckpointResponder {
 
 		AcknowledgeCheckpoint message = new AcknowledgeCheckpoint(
 				jobID, executionAttemptID, checkpointID,
-				chainedStateHandle, keyGroupStateHandles,
+				checkpointStateHandles,
 				synchronousDurationMillis, asynchronousDurationMillis,
 				bytesBufferedInAlignment, alignmentDurationNanos);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
index b3f9827..698a7f4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
@@ -20,11 +20,7 @@ package org.apache.flink.runtime.taskmanager;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
-
-import java.util.List;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 
 /**
  * Responder for checkpoint acknowledge and decline messages in the {@link Task}.
@@ -40,10 +36,8 @@ public interface CheckpointResponder {
 	 *             Execution attempt ID of the running task
 	 * @param checkpointID
 	 *             Checkpoint ID of the checkpoint
-	 * @param chainedStateHandle
-	 *             Chained state handle
-	 * @param keyGroupStateHandles
-	 *             State handles for key groups
+	 * @param checkpointStateHandles 
+	 *             State handles for the checkpoint
 	 * @param synchronousDurationMillis
 	 *             The duration (in milliseconds) of the synchronous part of the operator checkpoint
 	 * @param asynchronousDurationMillis
@@ -57,8 +51,7 @@ public interface CheckpointResponder {
 		JobID jobID,
 		ExecutionAttemptID executionAttemptID,
 		long checkpointID,
-		ChainedStateHandle<StreamStateHandle> chainedStateHandle,
-		List<KeyGroupsStateHandle> keyGroupStateHandles,
+		CheckpointStateHandles checkpointStateHandles,
 		long synchronousDurationMillis,
 		long asynchronousDurationMillis,
 		long bytesBufferedInAlignment,

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
index 23b6f82..c2ba7ef 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
@@ -35,11 +35,8 @@ import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 
-import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Future;
 
@@ -246,7 +243,7 @@ public class RuntimeEnvironment implements Environment {
 			long bytesBufferedInAlignment,
 			long alignmentDurationNanos) {
 
-		acknowledgeCheckpoint(checkpointId, null, null,
+		acknowledgeCheckpoint(checkpointId, null,
 				synchronousDurationMillis, asynchronousDurationMillis,
 				bytesBufferedInAlignment, alignmentDurationNanos);
 	}
@@ -254,8 +251,7 @@ public class RuntimeEnvironment implements Environment {
 	@Override
 	public void acknowledgeCheckpoint(
 			long checkpointId,
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle,
-			List<KeyGroupsStateHandle> keyGroupStateHandles,
+			CheckpointStateHandles checkpointStateHandles,
 			long synchronousDurationMillis,
 			long asynchronousDurationMillis,
 			long bytesBufferedInAlignment,
@@ -264,7 +260,7 @@ public class RuntimeEnvironment implements Environment {
 
 		checkpointResponder.acknowledgeCheckpoint(
 				jobId, executionId, checkpointId,
-				chainedStateHandle, keyGroupStateHandles,
+				checkpointStateHandles,
 				synchronousDurationMillis, asynchronousDurationMillis,
 				bytesBufferedInAlignment, alignmentDurationNanos);
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index 62dc8b7..8463fa0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -59,6 +59,7 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
 import org.apache.flink.util.Preconditions;
@@ -68,6 +69,7 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.net.URL;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -241,6 +243,8 @@ public class Task implements Runnable, TaskActions {
 	 */
 	private volatile List<KeyGroupsStateHandle> keyGroupStates;
 
+	private volatile List<Collection<OperatorStateHandle>> partitionableOperatorState;
+
 	/** Initialized from the Flink configuration. May also be set at the ExecutionConfig */
 	private long taskCancellationInterval;
 
@@ -278,6 +282,7 @@ public class Task implements Runnable, TaskActions {
 		this.chainedOperatorState = tdd.getOperatorState();
 		this.serializedExecutionConfig = checkNotNull(tdd.getSerializedExecutionConfig());
 		this.keyGroupStates = tdd.getKeyGroupState();
+		this.partitionableOperatorState = tdd.getPartitionableOperatorState();
 
 		this.taskCancellationInterval = jobConfiguration.getLong(
 			ConfigConstants.TASK_CANCELLATION_INTERVAL_MILLIS,
@@ -488,7 +493,7 @@ public class Task implements Runnable, TaskActions {
 		Map<String, Future<Path>> distributedCacheEntries = new HashMap<String, Future<Path>>();
 		AbstractInvokable invokable = null;
 
-		ClassLoader userCodeClassLoader = null;
+		ClassLoader userCodeClassLoader;
 		try {
 			// ----------------------------
 			//  Task Bootstrap - We periodically
@@ -564,10 +569,10 @@ public class Task implements Runnable, TaskActions {
 			// the state into the task. The state is non-empty if this is an execution
 			// of a task that failed but had backed-up state from a checkpoint
 
-			if (chainedOperatorState != null || keyGroupStates != null) {
+			if (chainedOperatorState != null || keyGroupStates != null || partitionableOperatorState != null) {
 				if (invokable instanceof StatefulTask) {
 					StatefulTask op = (StatefulTask) invokable;
-					op.setInitialState(chainedOperatorState, keyGroupStates);
+					op.setInitialState(chainedOperatorState, keyGroupStates, partitionableOperatorState);
 				} else {
 					throw new IllegalStateException("Found operator state for a non-stateful task invokable");
 				}


[10/10] flink git commit: [FLINK-4379] [checkpoints] Fix minor bug and improve debug logging

Posted by se...@apache.org.
[FLINK-4379] [checkpoints] Fix minor bug and improve debug logging


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/6f8f5eb3
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/6f8f5eb3
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/6f8f5eb3

Branch: refs/heads/master
Commit: 6f8f5eb3b9ba07cd3bb4d9f7edd43d4b8862acbe
Parents: 53ed6ad
Author: Stephan Ewen <se...@apache.org>
Authored: Thu Sep 29 21:12:38 2016 +0200
Committer: Stephan Ewen <se...@apache.org>
Committed: Fri Sep 30 12:38:46 2016 +0200

----------------------------------------------------------------------
 .../streaming/runtime/tasks/StreamTask.java      | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/6f8f5eb3/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 1725eca..88c3ba4 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -717,6 +717,13 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
 				cancelables.registerClosable(asyncCheckpointRunnable);
 				asyncOperationsThreadPool.submit(asyncCheckpointRunnable);
+
+				if (LOG.isDebugEnabled()) {
+					LOG.debug("{} - finished synchronous part of checkpoint {}. " +
+							"Alignment duration: {} ms, snapshot duration {} ms",
+							getName(), checkpointId, alignmentDurationNanos / 1_000_000, syncDurationMillis);
+				}
+
 				return true;
 			} else {
 				return false;
@@ -998,12 +1005,12 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 				final long asyncEndNanos = System.nanoTime();
 				final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000;
 
-				if (nonPartitionedStateHandles.isEmpty() && keyedStates.isEmpty()) {
-					owner.getEnvironment().acknowledgeCheckpoint(checkpointId,
+				if (nonPartitionedStateHandles.isEmpty() && partitioneableStateHandles.isEmpty() && keyedStates.isEmpty()) {
+					owner.getEnvironment().acknowledgeCheckpoint(
+							checkpointId,
 							syncDurationMillies, asyncDurationMillis,
 							bytesBufferedInAlignment, alignmentDurationNanos);
-				} else  {
-
+				} else {
 					CheckpointStateHandles allStateHandles = new CheckpointStateHandles(
 							nonPartitionedStateHandles,
 							partitioneableStateHandles,
@@ -1016,8 +1023,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 				}
 
 				if (LOG.isDebugEnabled()) {
-					LOG.debug("Finished asynchronous checkpoints for checkpoint {} on task {}. Returning handles on " +
-							"keyed states {}.", checkpointId, name, keyedStates);
+					LOG.debug("{} - finished asynchronous part of checkpoint {}. Asynchronous duration: {} ms", 
+							owner.getName(), checkpointId, asyncDurationMillis);
 				}
 			}
 			catch (Exception e) {


[02/10] flink git commit: [FLINK-4379] [checkpoints] Introduce rescalable operator state

Posted by se...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
index 2036f69..f638ddd 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
@@ -50,6 +50,7 @@ import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
@@ -317,7 +318,7 @@ public class StreamMockEnvironment implements Environment {
 	@Override
 	public void acknowledgeCheckpoint(
 			long checkpointId,
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle, List<KeyGroupsStateHandle> keyGroupStateHandles,
+			CheckpointStateHandles checkpointStateHandles,
 			long synchronousDurationMillis, long asynchronousDurationMillis,
 			long bytesBufferedInAlignment, long alignmentDurationNanos) {
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index 430c6de..247edd6 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -24,12 +24,14 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
 import org.mockito.invocation.InvocationOnMock;
@@ -41,11 +43,12 @@ import java.util.Collections;
 import java.util.concurrent.RunnableFuture;
 
 import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.anyInt;
+import static org.mockito.Mockito.doAnswer;
 
 /**
  * Extension of {@link OneInputStreamOperatorTestHarness} that allows the operator to get
- * a {@link KeyedStateBackend}.
+ * an {@link AbstractKeyedStateBackend}.
  *
  */
 public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
@@ -53,7 +56,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 
 	// in case the operator creates one we store it here so that we
 	// can snapshot its state
-	private KeyedStateBackend<?> keyedStateBackend = null;
+	private AbstractKeyedStateBackend<?> keyedStateBackend = null;
 
 	// when we restore we keep the state here so that we can call restore
 	// when the operator requests the keyed state backend
@@ -114,7 +117,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 					final KeyGroupRange keyGroupRange = (KeyGroupRange) invocationOnMock.getArguments()[2];
 
 					if(keyedStateBackend != null) {
-						keyedStateBackend.close();
+						keyedStateBackend.dispose();
 					}
 
 					if (restoredKeyedState == null) {
@@ -148,7 +151,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 	}
 
 	/**
-	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#snapshotState(org.apache.flink.core.fs.FSDataOutputStream, long, long)} ()}
+	 *
 	 */
 	@Override
 	public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exception {
@@ -159,7 +162,9 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 		CheckpointStreamFactory.CheckpointStateOutputStream outStream =
 				streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);
 
-		operator.snapshotState(outStream, checkpointId, timestamp);
+		if (operator instanceof StreamCheckpointedOperator) {
+			((StreamCheckpointedOperator) operator).snapshotState(outStream, checkpointId, timestamp);
+		}
 
 		if (keyedStateBackend != null) {
 			RunnableFuture<KeyGroupsStateHandle> keyedSnapshotRunnable = keyedStateBackend.snapshot(checkpointId,
@@ -180,17 +185,21 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 	}
 
 	/**
-	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#restoreState(org.apache.flink.core.fs.FSDataInputStream)} ()}
+	 * 
 	 */
 	@Override
 	public void restore(StreamStateHandle snapshot) throws Exception {
-		FSDataInputStream inStream = snapshot.openInputStream();
-		operator.restoreState(inStream);
+		try (FSDataInputStream inStream = snapshot.openInputStream()) {
+
+			if (operator instanceof StreamCheckpointedOperator) {
+				((StreamCheckpointedOperator) operator).restoreState(inStream);
+			}
 
-		byte keyedStatePresent = (byte) inStream.read();
-		if (keyedStatePresent == 1) {
-			ObjectInputStream ois = new ObjectInputStream(inStream);
-			this.restoredKeyedState = (KeyGroupsStateHandle) ois.readObject();
+			byte keyedStatePresent = (byte) inStream.read();
+			if (keyedStatePresent == 1) {
+				ObjectInputStream ois = new ObjectInputStream(inStream);
+				this.restoredKeyedState = (KeyGroupsStateHandle) ois.readObject();
+			}
 		}
 	}
 
@@ -200,7 +209,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 	public void close() throws Exception {
 		super.close();
 		if(keyedStateBackend != null) {
-			keyedStateBackend.close();
+			keyedStateBackend.dispose();
 		}
 	}
 }
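
Taken together, snapshot() and restore() above pin down the stream layout this harness uses; spelled
out (derived from the two methods, with the marker byte written in the part of snapshot() elided from
this hunk):

	[ operator state bytes ]     written/consumed via the StreamCheckpointedOperator methods, if implemented
	byte  keyedStatePresent      1 if a keyed-state handle follows
	[ KeyGroupsStateHandle ]     Java-serialized, read back through an ObjectInputStream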

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
index acf046a..d6f46fd 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
@@ -22,7 +22,7 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.Output;
@@ -39,7 +40,6 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.AsynchronousException;
 import org.apache.flink.streaming.runtime.tasks.DefaultTimeServiceProvider;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
-import org.apache.flink.streaming.runtime.tasks.TestTimeServiceProvider;
 import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -204,14 +204,18 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 	}
 
 	/**
-	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#snapshotState(FSDataOutputStream, long, long)} ()}
+	 *
 	 */
 	public StreamStateHandle snapshot(long checkpointId, long timestamp) throws Exception {
 		CheckpointStreamFactory.CheckpointStateOutputStream outStream = stateBackend.createStreamFactory(
 				new JobID(),
 				"test_op").createCheckpointStateOutputStream(checkpointId, timestamp);
-		operator.snapshotState(outStream, checkpointId, timestamp);
-		return outStream.closeAndGetHandle();
+		if (operator instanceof StreamCheckpointedOperator) {
+			((StreamCheckpointedOperator) operator).snapshotState(outStream, checkpointId, timestamp);
+			return outStream.closeAndGetHandle();
+		} else {
+			throw new RuntimeException("Operator is not StreamCheckpointedOperator");
+		}
 	}
 
 	/**
@@ -222,10 +226,16 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 	}
 
 	/**
-	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#restoreState(org.apache.flink.core.fs.FSDataInputStream)} ()}
+	 *
 	 */
 	public void restore(StreamStateHandle snapshot) throws Exception {
-		operator.restoreState(snapshot.openInputStream());
+		if (operator instanceof StreamCheckpointedOperator) {
+			try (FSDataInputStream in = snapshot.openInputStream()) {
+				((StreamCheckpointedOperator) operator).restoreState(in);
+			}
+		} else {
+			throw new RuntimeException("Operator is not StreamCheckpointedOperator");
+		}
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UdfStreamOperatorCheckpointingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UdfStreamOperatorCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UdfStreamOperatorCheckpointingITCase.java
index c12bcb9..5874f56 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UdfStreamOperatorCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UdfStreamOperatorCheckpointingITCase.java
@@ -26,6 +26,7 @@ import org.apache.flink.api.java.tuple.Tuple;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.datastream.KeyedStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.SinkFunction;
@@ -35,6 +36,8 @@ import org.apache.flink.streaming.api.operators.StreamGroupedFold;
 import org.apache.flink.streaming.api.operators.StreamGroupedReduce;
 import org.junit.Assert;
 
+import java.util.Collections;
+import java.util.List;
 import java.util.Queue;
 import java.util.Random;
 
@@ -180,7 +183,7 @@ public class UdfStreamOperatorCheckpointingITCase extends StreamFaultToleranceTe
 	 */
 	private static class OnceFailingIdentityMapFunction
 			extends RichMapFunction<Tuple2<Integer, Long>, Tuple2<Integer, Long>> 
-			implements Checkpointed<Long> {
+			implements ListCheckpointed<Long> {
 
 		private static volatile boolean hasFailed = false;
 
@@ -211,15 +214,16 @@ public class UdfStreamOperatorCheckpointingITCase extends StreamFaultToleranceTe
 			return value;
 		}
 
-
 		@Override
-		public Long snapshotState(long checkpointId, long checkpointTimestamp) {
-			return count;
+		public List<Long> snapshotState(long checkpointId, long timestamp) throws Exception {
+			return Collections.singletonList(count);
 		}
 
 		@Override
-		public void restoreState(Long state) {
-			count = state;
+		public void restoreState(List<Long> state) throws Exception {
+			if (!state.isEmpty()) {
+				count = state.get(0);
+			}
 		}
 	}
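
The hunk above is essentially the migration recipe from Checkpointed<T> to the new ListCheckpointed<T>:
wrap the old single snapshot value in a singleton list, and treat the restored list as zero or more
redistributable pieces. A self-contained sketch of the same pattern (class and field names invented
for illustration):

	import java.util.Collections;
	import java.util.List;
	import org.apache.flink.api.common.functions.RichMapFunction;
	import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;

	public class CountingMapper extends RichMapFunction<String, String>
			implements ListCheckpointed<Long> {

		private long count;

		@Override
		public String map(String value) {
			count++;
			return value;
		}

		@Override
		public List<Long> snapshotState(long checkpointId, long timestamp) throws Exception {
			// each list element is one redistributable piece of state
			return Collections.singletonList(count);
		}

		@Override
		public void restoreState(List<Long> state) throws Exception {
			// after rescaling, a subtask may receive zero, one, or several pieces
			count = 0;
			for (Long c : state) {
				count += c;
			}
		}
	}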
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
index 694f006..2a635ab 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
@@ -21,17 +21,18 @@ package org.apache.flink.test.streaming.runtime;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.client.JobExecutionException;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase;
 import org.junit.Test;
@@ -66,7 +67,7 @@ public class StateBackendITCase extends StreamingMultipleProgramsTestBase {
 				@Override
 				public void open(Configuration parameters) throws Exception {
 					super.open(parameters);
-					getRuntimeContext().getKeyValueState("test", String.class, "");
+					getRuntimeContext().getState(new ValueStateDescriptor<Integer>("Test", Integer.class, 0));
 				}
 
 				@Override
@@ -99,7 +100,8 @@ public class StateBackendITCase extends StreamingMultipleProgramsTestBase {
 		}
 
 		@Override
-		public <K> KeyedStateBackend<K> createKeyedStateBackend(Environment env,
+		public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
+				Environment env,
 				JobID jobID,
 				String operatorIdentifier,
 				TypeSerializer<K> keySerializer,
@@ -110,7 +112,8 @@ public class StateBackendITCase extends StreamingMultipleProgramsTestBase {
 		}
 
 		@Override
-		public <K> KeyedStateBackend<K> restoreKeyedStateBackend(Environment env,
+		public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
+				Environment env,
 				JobID jobID,
 				String operatorIdentifier,
 				TypeSerializer<K> keySerializer,


[07/10] flink git commit: [FLINK-4379] [checkpoints] Introduce rescalable operator state

Posted by se...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
index f5e3618..7e4eded 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
@@ -20,7 +20,9 @@ package org.apache.flink.runtime.checkpoint;
 
 import com.google.common.collect.Iterables;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.util.Preconditions;
@@ -47,27 +49,36 @@ public class TaskState implements StateObject {
 	/** handles to non-partitioned states, subtaskindex -> subtaskstate */
 	private final Map<Integer, SubtaskState> subtaskStates;
 
-	/** handles to partitioned states, subtaskindex -> keyed state */
+	/** handles to partitionable states, subtaskindex -> partitionable state */
+	private final Map<Integer, ChainedStateHandle<OperatorStateHandle>> partitionableStates;
+
+	/** handles to key-partitioned states, subtaskindex -> keyed state */
 	private final Map<Integer, KeyGroupsStateHandle> keyGroupsStateHandles;
 
+
 	/** parallelism of the operator when it was checkpointed */
 	private final int parallelism;
 
 	/** maximum parallelism of the operator when the job was first created */
 	private final int maxParallelism;
 
-	public TaskState(JobVertexID jobVertexID, int parallelism, int maxParallelism) {
+	private final int chainLength;
+
+	public TaskState(JobVertexID jobVertexID, int parallelism, int maxParallelism, int chainLength) {
 		Preconditions.checkArgument(
 				parallelism <= maxParallelism,
 				"Parallelism " + parallelism + " is not smaller or equal to max parallelism " + maxParallelism + ".");
+		Preconditions.checkArgument(chainLength > 0, "There has to be at least one operator in the operator chain.");
 
 		this.jobVertexID = jobVertexID;
 
 		this.subtaskStates = new HashMap<>(parallelism);
+		this.partitionableStates = new HashMap<>(parallelism);
 		this.keyGroupsStateHandles = new HashMap<>(parallelism);
 
 		this.parallelism = parallelism;
 		this.maxParallelism = maxParallelism;
+		this.chainLength = chainLength;
 	}
 
 	public JobVertexID getJobVertexID() {
@@ -85,6 +96,20 @@ public class TaskState implements StateObject {
 		}
 	}
 
+	public void putPartitionableState(
+			int subtaskIndex,
+			ChainedStateHandle<OperatorStateHandle> partitionableState) {
+
+		Preconditions.checkNotNull(partitionableState);
+
+		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
+			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
+					" exceeds the maximum number of sub tasks " + subtaskStates.size());
+		} else {
+			partitionableStates.put(subtaskIndex, partitionableState);
+		}
+	}
+
 	public void putKeyedState(int subtaskIndex, KeyGroupsStateHandle keyGroupsStateHandle) {
 		Preconditions.checkNotNull(keyGroupsStateHandle);
 
@@ -106,6 +131,15 @@ public class TaskState implements StateObject {
 		}
 	}
 
+	public ChainedStateHandle<OperatorStateHandle> getPartitionableState(int subtaskIndex) {
+		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
+			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
+					" exceeds the maximum number of sub tasks " + subtaskStates.size());
+		} else {
+			return partitionableStates.get(subtaskIndex);
+		}
+	}
+
 	public KeyGroupsStateHandle getKeyGroupState(int subtaskIndex) {
 		if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
 			throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
@@ -131,6 +165,10 @@ public class TaskState implements StateObject {
 		return maxParallelism;
 	}
 
+	public int getChainLength() {
+		return chainLength;
+	}
+
 	public Collection<KeyGroupsStateHandle> getKeyGroupStates() {
 		return keyGroupsStateHandles.values();
 	}
@@ -147,7 +185,7 @@ public class TaskState implements StateObject {
 	@Override
 	public void discardState() throws Exception {
 		StateUtil.bestEffortDiscardAllStateObjects(
-				Iterables.concat(subtaskStates.values(), keyGroupsStateHandles.values()));
+				Iterables.concat(subtaskStates.values(), partitionableStates.values(), keyGroupsStateHandles.values()));
 	}
 
 
@@ -156,11 +194,19 @@ public class TaskState implements StateObject {
 		long result = 0L;
 
 		for (int i = 0; i < parallelism; i++) {
-			if (subtaskStates.get(i) != null) {
-				result += subtaskStates.get(i).getStateSize();
+			SubtaskState subtaskState = subtaskStates.get(i);
+			if (subtaskState != null) {
+				result += subtaskState.getStateSize();
+			}
+
+			ChainedStateHandle<OperatorStateHandle> partitionableState = partitionableStates.get(i);
+			if (partitionableState != null) {
+				result += partitionableState.getStateSize();
 			}
-			if (keyGroupsStateHandles.get(i) != null) {
-				result += keyGroupsStateHandles.get(i).getStateSize();
+
+			KeyGroupsStateHandle keyGroupsState = keyGroupsStateHandles.get(i);
+			if (keyGroupsState != null) {
+				result += keyGroupsState.getStateSize();
 			}
 		}
 
@@ -172,8 +218,11 @@ public class TaskState implements StateObject {
 		if (obj instanceof TaskState) {
 			TaskState other = (TaskState) obj;
 
-			return jobVertexID.equals(other.jobVertexID) && parallelism == other.parallelism &&
-				subtaskStates.equals(other.subtaskStates) && keyGroupsStateHandles.equals(other.keyGroupsStateHandles);
+			return jobVertexID.equals(other.jobVertexID)
+					&& parallelism == other.parallelism
+					&& subtaskStates.equals(other.subtaskStates)
+					&& partitionableStates.equals(other.partitionableStates)
+					&& keyGroupsStateHandles.equals(other.keyGroupsStateHandles);
 		} else {
 			return false;
 		}
@@ -181,13 +230,7 @@ public class TaskState implements StateObject {
 
 	@Override
 	public int hashCode() {
-		return parallelism + 31 * Objects.hash(jobVertexID, subtaskStates, keyGroupsStateHandles);
-	}
-
-	@Override
-	public void close() throws IOException {
-		StateUtil.bestEffortCloseAllStateObjects(
-				Iterables.concat(subtaskStates.values(), keyGroupsStateHandles.values()));
+		return parallelism + 31 * Objects.hash(jobVertexID, subtaskStates, partitionableStates, keyGroupsStateHandles);
 	}
 
 	public Map<Integer, SubtaskState> getSubtaskStates() {
@@ -197,4 +240,8 @@ public class TaskState implements StateObject {
 	public Map<Integer, KeyGroupsStateHandle> getKeyGroupsStateHandles() {
 		return Collections.unmodifiableMap(keyGroupsStateHandles);
 	}
+
+	public Map<Integer, ChainedStateHandle<OperatorStateHandle>> getPartitionableStates() {
+		return partitionableStates;
+	}
 }
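
A condensed sketch of how a coordinator-side consumer would populate the extended TaskState, using
only the constructor and putters from this hunk (the handle variables are placeholders):

	// vertex with parallelism 2, max parallelism 128, operator chain of length 1
	TaskState taskState = new TaskState(jobVertexID, 2, 128, 1);

	// per subtask: non-partitioned, partitionable, and key-partitioned state
	taskState.putState(0, subtaskState);
	taskState.putPartitionableState(0, chainedOperatorStateHandle);
	taskState.putKeyedState(0, keyGroupsStateHandle);

	// now sums over all three maps, as shown above
	long stateSize = taskState.getStateSize();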

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
index f07f44f..536062a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
@@ -35,6 +36,7 @@ import java.io.DataOutputStream;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
@@ -51,6 +53,7 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 	private static final byte BYTE_STREAM_STATE_HANDLE = 1;
 	private static final byte FILE_STREAM_STATE_HANDLE = 2;
 	private static final byte KEY_GROUPS_HANDLE = 3;
+	private static final byte PARTITIONABLE_OPERATOR_STATE_HANDLE = 4;
 
 
 	public static final SavepointV1Serializer INSTANCE = new SavepointV1Serializer();
@@ -75,8 +78,9 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 				int parallelism = taskState.getParallelism();
 				dos.writeInt(parallelism);
 				dos.writeInt(taskState.getMaxParallelism());
+				dos.writeInt(taskState.getChainLength());
 
-				// Sub task states
+				// Sub task non-partitionable states
 				Map<Integer, SubtaskState> subtaskStateMap = taskState.getSubtaskStates();
 				dos.writeInt(subtaskStateMap.size());
 				for (Map.Entry<Integer, SubtaskState> entry : subtaskStateMap.entrySet()) {
@@ -93,7 +97,22 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 					dos.writeLong(subtaskState.getDuration());
 				}
 
+				// Sub task partitionable states
+				Map<Integer, ChainedStateHandle<OperatorStateHandle>> partitionableStatesMap = taskState.getPartitionableStates();
+				dos.writeInt(partitionableStatesMap.size());
 
+				for (Map.Entry<Integer, ChainedStateHandle<OperatorStateHandle>> entry : partitionableStatesMap.entrySet()) {
+					dos.writeInt(entry.getKey());
+
+					ChainedStateHandle<OperatorStateHandle> chainedStateHandle = entry.getValue();
+					dos.writeInt(chainedStateHandle.getLength());
+					for (int j = 0; j < chainedStateHandle.getLength(); ++j) {
+						OperatorStateHandle stateHandle = chainedStateHandle.get(j);
+						serializePartitionableStateHandle(stateHandle, dos);
+					}
+				}
+
+				// Keyed state
 				Map<Integer, KeyGroupsStateHandle> keyGroupsStateHandles = taskState.getKeyGroupsStateHandles();
 				dos.writeInt(keyGroupsStateHandles.size());
 				for (Map.Entry<Integer, KeyGroupsStateHandle> entry : keyGroupsStateHandles.entrySet()) {
@@ -119,9 +138,10 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 			JobVertexID jobVertexId = new JobVertexID(dis.readLong(), dis.readLong());
 			int parallelism = dis.readInt();
 			int maxParallelism = dis.readInt();
+			int chainLength = dis.readInt();
 
 			// Add task state
-			TaskState taskState = new TaskState(jobVertexId, parallelism, maxParallelism);
+			TaskState taskState = new TaskState(jobVertexId, parallelism, maxParallelism, chainLength);
 			taskStates.add(taskState);
 
 			// Sub task states
@@ -142,6 +162,24 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 				taskState.putState(subtaskIndex, subtaskState);
 			}
 
+			int numPartitionableOpStates = dis.readInt();
+
+			for (int j = 0; j < numPartitionableOpStates; j++) {
+				int subtaskIndex = dis.readInt();
+				int chainedStateHandleSize = dis.readInt();
+				List<OperatorStateHandle> streamStateHandleList = new ArrayList<>(chainedStateHandleSize);
+
+				for (int k = 0; k < chainedStateHandleSize; ++k) {
+					OperatorStateHandle streamStateHandle = deserializePartitionableStateHandle(dis);
+					streamStateHandleList.add(streamStateHandle);
+				}
+
+				ChainedStateHandle<OperatorStateHandle> chainedStateHandle =
+						new ChainedStateHandle<>(streamStateHandleList);
+
+				taskState.putPartitionableState(subtaskIndex, chainedStateHandle);
+			}
+
 			// Key group states
 			int numKeyGroupStates = dis.readInt();
 			for (int j = 0; j < numKeyGroupStates; j++) {
@@ -157,7 +195,9 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 		return new SavepointV1(checkpointId, taskStates);
 	}
 
-	public static void serializeKeyGroupStateHandle(KeyGroupsStateHandle stateHandle, DataOutputStream dos) throws IOException {
+	public static void serializeKeyGroupStateHandle(
+			KeyGroupsStateHandle stateHandle, DataOutputStream dos) throws IOException {
+
 		if (stateHandle != null) {
 			dos.writeByte(KEY_GROUPS_HANDLE);
 			dos.writeInt(stateHandle.getGroupRangeOffsets().getKeyGroupRange().getStartKeyGroup());
@@ -172,10 +212,10 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 	}
 
 	public static KeyGroupsStateHandle deserializeKeyGroupStateHandle(DataInputStream dis) throws IOException {
-		int type = dis.readByte();
+		final int type = dis.readByte();
 		if (NULL_HANDLE == type) {
 			return null;
-		} else {
+		} else if (KEY_GROUPS_HANDLE == type) {
 			int startKeyGroup = dis.readInt();
 			int numKeyGroups = dis.readInt();
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(startKeyGroup, startKeyGroup + numKeyGroups - 1);
@@ -186,6 +226,53 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 			KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, offsets);
 			StreamStateHandle stateHandle = deserializeStreamStateHandle(dis);
 			return new KeyGroupsStateHandle(keyGroupRangeOffsets, stateHandle);
+		} else {
+			throw new IllegalStateException("Reading invalid KeyGroupsStateHandle, type: " + type);
+		}
+	}
+
+	public static void serializePartitionableStateHandle(
+			OperatorStateHandle stateHandle, DataOutputStream dos) throws IOException {
+
+		if (stateHandle != null) {
+			dos.writeByte(PARTITIONABLE_OPERATOR_STATE_HANDLE);
+			Map<String, long[]> partitionOffsetsMap = stateHandle.getStateNameToPartitionOffsets();
+			dos.writeInt(partitionOffsetsMap.size());
+			for (Map.Entry<String, long[]> entry : partitionOffsetsMap.entrySet()) {
+				dos.writeUTF(entry.getKey());
+				long[] offsets = entry.getValue();
+				dos.writeInt(offsets.length);
+				for (int i = 0; i < offsets.length; ++i) {
+					dos.writeLong(offsets[i]);
+				}
+			}
+			serializeStreamStateHandle(stateHandle.getDelegateStateHandle(), dos);
+		} else {
+			dos.writeByte(NULL_HANDLE);
+		}
+	}
+
+	public static OperatorStateHandle deserializePartitionableStateHandle(
+			DataInputStream dis) throws IOException {
+
+		final int type = dis.readByte();
+		if (NULL_HANDLE == type) {
+			return null;
+		} else if (PARTITIONABLE_OPERATOR_STATE_HANDLE == type) {
+			int mapSize = dis.readInt();
+			Map<String, long[]> offsetsMap = new HashMap<>(mapSize);
+			for (int i = 0; i < mapSize; ++i) {
+				String key = dis.readUTF();
+				long[] offsets = new long[dis.readInt()];
+				for (int j = 0; j < offsets.length; ++j) {
+					offsets[j] = dis.readLong();
+				}
+				offsetsMap.put(key, offsets);
+			}
+			StreamStateHandle stateHandle = deserializeStreamStateHandle(dis);
+			return new OperatorStateHandle(stateHandle, offsetsMap);
+		} else {
+			throw new IllegalStateException("Reading invalid OperatorStateHandle, type: " + type);
 		}
 	}
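
For reference, the byte layout the two new methods agree on, read straight off the code above:

	byte       type tag: PARTITIONABLE_OPERATOR_STATE_HANDLE (4), or NULL_HANDLE (0) for a null handle
	int        number of named states in the offsets map
	per entry:
		UTF      state name
		int      number of offsets
		long[]   offsets into the delegate stream
	...        the delegate StreamStateHandle, via serializeStreamStateHandle(...)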
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
index ca976e4..7bbdb2a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.SerializedValue;
 
@@ -100,6 +101,8 @@ public final class TaskDeploymentDescriptor implements Serializable {
 	/** Handle to the key-grouped state of the head operator in the chain */
 	private final List<KeyGroupsStateHandle> keyGroupState;
 
+	private final List<Collection<OperatorStateHandle>> partitionableOperatorState;
+
 	/** The execution configuration (see {@link ExecutionConfig}) related to the specific job. */
 	private final SerializedValue<ExecutionConfig> serializedExecutionConfig;
 
@@ -107,26 +110,27 @@ public final class TaskDeploymentDescriptor implements Serializable {
 	 * Constructs a task deployment descriptor.
 	 */
 	public TaskDeploymentDescriptor(
-			JobID jobID,
-			String jobName,
-			JobVertexID vertexID,
-			ExecutionAttemptID executionId,
-			SerializedValue<ExecutionConfig> serializedExecutionConfig,
-			String taskName,
-			int numberOfKeyGroups,
-			int indexInSubtaskGroup,
-			int numberOfSubtasks,
-			int attemptNumber,
-			Configuration jobConfiguration,
-			Configuration taskConfiguration,
-			String invokableClassName,
-			List<ResultPartitionDeploymentDescriptor> producedPartitions,
-			List<InputGateDeploymentDescriptor> inputGates,
-			List<BlobKey> requiredJarFiles,
-			List<URL> requiredClasspaths,
-			int targetSlotNumber,
-			ChainedStateHandle<StreamStateHandle> operatorState,
-			List<KeyGroupsStateHandle> keyGroupState) {
+		JobID jobID,
+		String jobName,
+		JobVertexID vertexID,
+		ExecutionAttemptID executionId,
+		SerializedValue<ExecutionConfig> serializedExecutionConfig,
+		String taskName,
+		int numberOfKeyGroups,
+		int indexInSubtaskGroup,
+		int numberOfSubtasks,
+		int attemptNumber,
+		Configuration jobConfiguration,
+		Configuration taskConfiguration,
+		String invokableClassName,
+		List<ResultPartitionDeploymentDescriptor> producedPartitions,
+		List<InputGateDeploymentDescriptor> inputGates,
+		List<BlobKey> requiredJarFiles,
+		List<URL> requiredClasspaths,
+		int targetSlotNumber,
+		ChainedStateHandle<StreamStateHandle> operatorState,
+		List<KeyGroupsStateHandle> keyGroupState,
+		List<Collection<OperatorStateHandle>> partitionableOperatorStateHandles) {
 
 		checkArgument(indexInSubtaskGroup >= 0);
 		checkArgument(numberOfSubtasks > indexInSubtaskGroup);
@@ -153,6 +157,7 @@ public final class TaskDeploymentDescriptor implements Serializable {
 		this.targetSlotNumber = targetSlotNumber;
 		this.operatorState = operatorState;
 		this.keyGroupState = keyGroupState;
+		this.partitionableOperatorState = partitionableOperatorStateHandles;
 	}
 
 	public TaskDeploymentDescriptor(
@@ -195,6 +200,7 @@ public final class TaskDeploymentDescriptor implements Serializable {
 			requiredClasspaths,
 			targetSlotNumber,
 			null,
+			null,
 			null);
 	}
 
@@ -347,4 +353,8 @@ public final class TaskDeploymentDescriptor implements Serializable {
 	public List<KeyGroupsStateHandle> getKeyGroupState() {
 		return keyGroupState;
 	}
+
+	public List<Collection<OperatorStateHandle>> getPartitionableOperatorState() {
+		return partitionableOperatorState;
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
index 273c0d9..f6cde95 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
@@ -34,13 +34,10 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KvState;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 
-import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Future;
 
@@ -187,12 +184,8 @@ public interface Environment {
 	 * the checkpoint with the give checkpoint-ID. This method does include
 	 * the given state in the checkpoint.
 	 *
-	 * @param checkpointId
-	 *             The ID of the checkpoint.
-	 * @param chainedStateHandle
-	 *             Handle for the chained operator state
-	 * @param keyGroupStateHandles
-	 *             Handles for key group state
+	 * @param checkpointId The ID of the checkpoint.
+	 * @param checkpointStateHandles All state handles for the checkpointed state
 	 * @param synchronousDurationMillis
 	 *             The duration (in milliseconds) of the synchronous part of the operator checkpoint
 	 * @param asynchronousDurationMillis
@@ -204,8 +197,7 @@ public interface Environment {
 	 */
 	void acknowledgeCheckpoint(
 			long checkpointId,
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle,
-			List<KeyGroupsStateHandle> keyGroupStateHandles,
+			CheckpointStateHandles checkpointStateHandles,
 			long synchronousDurationMillis,
 			long asynchronousDurationMillis,
 			long bytesBufferedInAlignment,

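The caller side of this new signature can be seen in the StreamTask hunk of commit 6f8f5eb3 above:
an acknowledgement carries a CheckpointStateHandles bundle only when some state was actually
snapshotted, otherwise the state-less overload is used. A condensed sketch with illustrative
variable names:

	if (nonPartitionedStateHandles.isEmpty()
			&& partitionableStateHandles.isEmpty()
			&& keyedStates.isEmpty()) {
		env.acknowledgeCheckpoint(checkpointId,
				syncDurationMillis, asyncDurationMillis,
				bytesBufferedInAlignment, alignmentDurationNanos);
	} else {
		env.acknowledgeCheckpoint(checkpointId,
				new CheckpointStateHandles(
						nonPartitionedStateHandles, partitionableStateHandles, keyedStates),
				syncDurationMillis, asyncDurationMillis,
				bytesBufferedInAlignment, alignmentDurationNanos);
	}
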
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
index 912ff10..b92e3af 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.deployment.PartialInputChannelDeploymentDescript
 import org.apache.flink.runtime.deployment.ResultPartitionLocation;
 import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
 import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.instance.SimpleSlot;
 import org.apache.flink.runtime.instance.SlotProvider;
@@ -46,6 +47,7 @@ import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
 import org.apache.flink.runtime.messages.Messages;
 import org.apache.flink.runtime.messages.TaskMessages.TaskOperationResult;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
@@ -56,6 +58,7 @@ import scala.concurrent.ExecutionContext;
 import scala.concurrent.duration.FiniteDuration;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Callable;
@@ -134,6 +137,8 @@ public class Execution {
 
 	private ChainedStateHandle<StreamStateHandle> chainedStateHandle;
 
+	private List<Collection<OperatorStateHandle>> chainedPartitionableStateHandle;
+
 	private List<KeyGroupsStateHandle> keyGroupsStateHandles;
 	
 
@@ -223,6 +228,10 @@ public class Execution {
 		return keyGroupsStateHandles;
 	}
 
+	public List<Collection<OperatorStateHandle>> getChainedPartitionableStateHandle() {
+		return chainedPartitionableStateHandle;
+	}
+
 	public boolean isFinished() {
 		return state.isTerminal();
 	}
@@ -246,18 +255,19 @@ public class Execution {
 	 * Sets the initial state for the execution. The serialized state is then shipped via the
 	 * {@link TaskDeploymentDescriptor} to the TaskManagers.
 	 *
-	 * @param chainedStateHandle Chained operator state
-	 * @param keyGroupsStateHandles Key-group state (= partitioned state)
+	 * @param checkpointStateHandles all checkpointed operator state
+	 * @param chainedPartitionableStateHandle the chained partitionable operator state
 	 */
-	public void setInitialState(
-		ChainedStateHandle<StreamStateHandle> chainedStateHandle,
-			List<KeyGroupsStateHandle> keyGroupsStateHandles) {
+	public void setInitialState(
+			CheckpointStateHandles checkpointStateHandles,
+			List<Collection<OperatorStateHandle>> chainedPartitionableStateHandle) {
 
 		if (state != ExecutionState.CREATED) {
 			throw new IllegalArgumentException("Can only assign operator state when execution attempt is in CREATED");
 		}
-		this.chainedStateHandle = chainedStateHandle;
-		this.keyGroupsStateHandles = keyGroupsStateHandles;
+
+		if (checkpointStateHandles != null) {
+			this.chainedStateHandle = checkpointStateHandles.getNonPartitionedStateHandles();
+			this.chainedPartitionableStateHandle = chainedPartitionableStateHandle;
+			this.keyGroupsStateHandles = checkpointStateHandles.getKeyGroupsStateHandle();
+		}
 	}
 
 	// --------------------------------------------------------------------------------------------
@@ -385,6 +395,7 @@ public class Execution {
 				slot,
 				chainedStateHandle,
 				keyGroupsStateHandles,
+				chainedPartitionableStateHandle,
 				attemptNumber);
 
 			// register this execution at the execution graph, to receive call backs

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
index 7c3fa0b..6023205 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
@@ -56,10 +56,8 @@ import org.apache.flink.runtime.util.SerializableObject;
 import org.apache.flink.runtime.util.SerializedThrowable;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.SerializedValue;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-
 import scala.concurrent.ExecutionContext;
 import scala.concurrent.duration.FiniteDuration;
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
index a8d5ee4..4837803 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.instance.SlotProvider;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.instance.SimpleSlot;
@@ -53,6 +54,7 @@ import scala.concurrent.duration.FiniteDuration;
 import java.io.Serializable;
 import java.net.URL;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
@@ -629,6 +631,7 @@ public class ExecutionVertex {
 			SimpleSlot targetSlot,
 			ChainedStateHandle<StreamStateHandle> operatorState,
 			List<KeyGroupsStateHandle> keyGroupStates,
+			List<Collection<OperatorStateHandle>> partitionableOperatorStateHandle,
 			int attemptNumber) {
 
 		// Produced intermediate results
@@ -681,7 +684,8 @@ public class ExecutionVertex {
 			classpaths,
 			targetSlot.getRoot().getSlotNumber(),
 			operatorState,
-			keyGroupStates);
+			keyGroupStates,
+			partitionableOperatorStateHandle);
 	}
 
 	// --------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
index 9ddfdf7..55e3e09 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/StatefulTask.java
@@ -20,8 +20,10 @@ package org.apache.flink.runtime.jobgraph.tasks;
 
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
+import java.util.Collection;
 import java.util.List;
 
 /**
@@ -33,11 +35,16 @@ public interface StatefulTask {
 	/**
 	 * Sets the initial state of the operator, upon recovery. The initial state is typically
 	 * a snapshot of the state from a previous execution.
-	 * 
+	 *
+	 * TODO this should use {@link org.apache.flink.runtime.state.CheckpointStateHandles} after redoing chained state.
+	 *
 	 * @param chainedState Handle for the chained operator states.
 	 * @param keyGroupsState Handle for key group states.
 	 */
-	void setInitialState(ChainedStateHandle<StreamStateHandle> chainedState, List<KeyGroupsStateHandle> keyGroupsState) throws Exception;
+	void setInitialState(
+		ChainedStateHandle<StreamStateHandle> chainedState,
+		List<KeyGroupsStateHandle> keyGroupsState,
+		List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception;
 
 	/**
 	 * This method is called to trigger a checkpoint, asynchronously by the checkpoint

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
index 72396eb..e95e7b3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/checkpoint/AcknowledgeCheckpoint.java
@@ -20,11 +20,7 @@ package org.apache.flink.runtime.messages.checkpoint;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
-
-import java.util.List;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 
@@ -32,7 +28,7 @@ import static org.apache.flink.util.Preconditions.checkArgument;
  * This message is sent from the {@link org.apache.flink.runtime.taskmanager.TaskManager} to the
  * {@link org.apache.flink.runtime.jobmanager.JobManager} to signal that the checkpoint of an
  * individual task is completed.
- * 
+ * <p>
  * <p>This message may carry the handle to the task's chained operator state and the key group
  * state.
  */
@@ -40,9 +36,8 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 
 	private static final long serialVersionUID = -7606214777192401493L;
 
-	private final ChainedStateHandle<StreamStateHandle> stateHandle;
 
-	private final List<KeyGroupsStateHandle> keyGroupsStateHandle;
+	private final CheckpointStateHandles checkpointStateHandles;
 
 	/** The duration (in milliseconds) that the synchronous part of the checkpoint took */
 	private final long synchronousDurationMillis;
@@ -62,24 +57,22 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 			JobID job,
 			ExecutionAttemptID taskExecutionId,
 			long checkpointId) {
-		this(job, taskExecutionId, checkpointId, null, null);
+		this(job, taskExecutionId, checkpointId, null);
 	}
 
 	public AcknowledgeCheckpoint(
 			JobID job,
 			ExecutionAttemptID taskExecutionId,
 			long checkpointId,
-			ChainedStateHandle<StreamStateHandle> state,
-			List<KeyGroupsStateHandle> keyGroupStateAndSizes) {
-		this(job, taskExecutionId, checkpointId, state, keyGroupStateAndSizes, -1L, -1L, -1L, -1L);
+			CheckpointStateHandles checkpointStateHandles) {
+		this(job, taskExecutionId, checkpointId, checkpointStateHandles, -1L, -1L, -1L, -1L);
 	}
 
 	public AcknowledgeCheckpoint(
 			JobID job,
 			ExecutionAttemptID taskExecutionId,
 			long checkpointId,
-			ChainedStateHandle<StreamStateHandle> state,
-			List<KeyGroupsStateHandle> keyGroupStateAndSizes,
+			CheckpointStateHandles checkpointStateHandles,
 			long synchronousDurationMillis,
 			long asynchronousDurationMillis,
 			long bytesBufferedInAlignment,
@@ -87,9 +80,7 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 
 		super(job, taskExecutionId, checkpointId);
 
-		// these may be null in cases where the operator has no state
-		this.stateHandle = state;
-		this.keyGroupsStateHandle = keyGroupStateAndSizes;
+		this.checkpointStateHandles = checkpointStateHandles;
 
 		// these may be "-1", in case the values are unknown or not set
 		checkArgument(synchronousDurationMillis >= -1);
@@ -107,12 +98,8 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 	//  properties
 	// ------------------------------------------------------------------------
 
-	public ChainedStateHandle<StreamStateHandle> getStateHandle() {
-		return stateHandle;
-	}
-
-	public List<KeyGroupsStateHandle> getKeyGroupsStateHandle() {
-		return keyGroupsStateHandle;
+	public CheckpointStateHandles getCheckpointStateHandles() {
+		return checkpointStateHandles;
 	}
 
 	public long getSynchronousDurationMillis() {
@@ -134,31 +121,33 @@ public class AcknowledgeCheckpoint extends AbstractCheckpointMessage implements
 	// --------------------------------------------------------------------------------------------
 
 	@Override
-	public int hashCode() {
-		return super.hashCode();
-	}
-
-	@Override
 	public boolean equals(Object o) {
 		if (this == o) {
-			return true ;
+			return true;
 		}
-		else if (o instanceof AcknowledgeCheckpoint) {
-			AcknowledgeCheckpoint that = (AcknowledgeCheckpoint) o;
-			return super.equals(o) &&
-					(this.stateHandle == null ? that.stateHandle == null : 
-							(that.stateHandle != null && this.stateHandle.equals(that.stateHandle))) &&
-					(this.keyGroupsStateHandle == null ? that.keyGroupsStateHandle == null : 
-							(that.keyGroupsStateHandle != null && this.keyGroupsStateHandle.equals(that.keyGroupsStateHandle)));
+		if (!(o instanceof AcknowledgeCheckpoint)) {
+			return false;
 		}
-		else {
+		if (!super.equals(o)) {
 			return false;
 		}
+
+		AcknowledgeCheckpoint that = (AcknowledgeCheckpoint) o;
+		return checkpointStateHandles != null ?
+				checkpointStateHandles.equals(that.checkpointStateHandles) : that.checkpointStateHandles == null;
+
+	}
+
+	@Override
+	public int hashCode() {
+		int result = super.hashCode();
+		result = 31 * result + (checkpointStateHandles != null ? checkpointStateHandles.hashCode() : 0);
+		return result;
 	}
 
 	@Override
 	public String toString() {
-		return String.format("Confirm Task Checkpoint %d for (%s/%s) - state=%s keyGroupState=%s",
-				getCheckpointId(), getJob(), getTaskExecutionId(), stateHandle, keyGroupsStateHandle);
+		return String.format("Confirm Task Checkpoint %d for (%s/%s) - state=%s",
+				getCheckpointId(), getJob(), getTaskExecutionId(), checkpointStateHandles);
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractCloseableHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractCloseableHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractCloseableHandle.java
deleted file mode 100644
index 5966c95..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractCloseableHandle.java
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state;
-
-import java.io.Closeable;
-import java.io.IOException;
-import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
-
-/**
- * A simple base for closable handles.
- *
- * Offers to register a stream (or other closable object) that close calls are delegated to if
- * the handle is closed or was already closed.
- */
-public abstract class AbstractCloseableHandle implements Closeable, StateObject {
-
-	/** Serial Version UID must be constant to maintain format compatibility */
-	private static final long serialVersionUID = 1L;
-
-	/** To atomically update the "closable" field without needing to add a member class like "AtomicBoolean */
-	private static final AtomicIntegerFieldUpdater<AbstractCloseableHandle> CLOSER =
-			AtomicIntegerFieldUpdater.newUpdater(AbstractCloseableHandle.class, "isClosed");
-
-	// ------------------------------------------------------------------------
-
-	/** The closeable to close if this handle is closed late */
-	private transient volatile Closeable toClose;
-
-	/** Flag to remember if this handle was already closed */
-	@SuppressWarnings("unused") // this field is actually updated, but via the "CLOSER" updater
-	private transient volatile int isClosed;
-
-	// ------------------------------------------------------------------------
-
-	protected final void registerCloseable(Closeable toClose) throws IOException {
-		if (toClose == null) {
-			return;
-		}
-
-		// NOTE: The order of operations matters here:
-		// (1) first setting the closeable
-		// (2) checking the flag.
-		// Because the order in the {@link #close()} method is the opposite, and
-		// both variables are volatile (reordering barriers), we can be sure that
-		// one of the methods always notices the effect of a concurrent call to the
-		// other method.
-
-		this.toClose = toClose;
-
-		// check if we were closed early
-		if (this.isClosed != 0) {
-			toClose.close();
-			throw new IOException("handle is closed");
-		}
-	}
-
-	/**
-	 * Closes the handle.
-	 *
-	 * <p>If a "Closeable" has been registered via {@link #registerCloseable(Closeable)},
-	 * then this will be closes.
-	 *
-	 * <p>If any "Closeable" will be registered via {@link #registerCloseable(Closeable)} in the future,
-	 * it will immediately be closed and that method will throw an exception.
-	 *
-	 * @throws IOException Exceptions occurring while closing an already registered {@code Closeable}
-	 *                     are forwarded.
-	 *
-	 * @see #registerCloseable(Closeable)
-	 */
-	@Override
-	public final void close() throws IOException {
-		// NOTE: The order of operations matters here:
-		// (1) first setting the closed flag
-		// (2) checking whether there is already a closeable
-		// Because the order in the {@link #registerCloseable(Closeable)} method is the opposite, and
-		// both variables are volatile (reordering barriers), we can be sure that
-		// one of the methods always notices the effect of a concurrent call to the
-		// other method.
-
-		if (CLOSER.compareAndSet(this, 0, 1)) {
-			final Closeable toClose = this.toClose;
-			if (toClose != null) {
-				this.toClose = null;
-				toClose.close();
-			}
-		}
-	}
-
-	/**
-	 * Checks whether this handle has been closed.
-	 *
-	 * @return True is the handle is closed, false otherwise.
-	 */
-	public boolean isClosed() {
-		return isClosed != 0;
-	}
-
-	/**
-	 * This method checks whether the handle is closed and throws an exception if it is closed.
-	 * If the handle is not closed, this method does nothing.
-	 *
-	 * @throws IOException Thrown, if the handle has been closed.
-	 */
-	public void ensureNotClosed() throws IOException {
-		if (isClosed != 0) {
-			throw new IOException("handle is closed");
-		}
-	}
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
new file mode 100644
index 0000000..7ca3b38
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -0,0 +1,342 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.FoldingState;
+import org.apache.flink.api.common.state.FoldingStateDescriptor;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MergingState;
+import org.apache.flink.api.common.state.ReducingState;
+import org.apache.flink.api.common.state.ReducingStateDescriptor;
+import org.apache.flink.api.common.state.State;
+import org.apache.flink.api.common.state.StateBackend;
+import org.apache.flink.api.common.state.StateDescriptor;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+
+/**
+ * Base implementation of KeyedStateBackend. The state can be checkpointed
+ * to streams using {@link #snapshot(long, long, CheckpointStreamFactory)}.
+ *
+ * @param <K> Type of the key by which state is keyed.
+ */
+public abstract class AbstractKeyedStateBackend<K>
+		implements KeyedStateBackend<K>, SnapshotProvider<KeyGroupsStateHandle>, Closeable {
+
+	/** {@link TypeSerializer} for our key. */
+	protected final TypeSerializer<K> keySerializer;
+
+	/** The currently active key. */
+	protected K currentKey;
+
+	/** The key group of the currently active key */
+	private int currentKeyGroup;
+
+	/** So that we can give out state when the user uses the same key. */
+	protected HashMap<String, KvState<?>> keyValueStatesByName;
+
+	/** For caching the last accessed partitioned state */
+	private String lastName;
+
+	@SuppressWarnings("rawtypes")
+	private KvState lastState;
+
+	/** The number of key-groups aka max parallelism */
+	protected final int numberOfKeyGroups;
+
+	/** Range of key-groups for which this backend is responsible */
+	protected final KeyGroupRange keyGroupRange;
+
+	/** KvStateRegistry helper for this task */
+	protected final TaskKvStateRegistry kvStateRegistry;
+
+	/** Registry for all opened streams, so they can be closed if the task using this backend is closed */
+	protected ClosableRegistry cancelStreamRegistry;
+
+	protected final ClassLoader userCodeClassLoader;
+
+	public AbstractKeyedStateBackend(
+			TaskKvStateRegistry kvStateRegistry,
+			TypeSerializer<K> keySerializer,
+			ClassLoader userCodeClassLoader,
+			int numberOfKeyGroups,
+			KeyGroupRange keyGroupRange) {
+
+		this.kvStateRegistry = Preconditions.checkNotNull(kvStateRegistry);
+		this.keySerializer = Preconditions.checkNotNull(keySerializer);
+		this.numberOfKeyGroups = Preconditions.checkNotNull(numberOfKeyGroups);
+		this.userCodeClassLoader = Preconditions.checkNotNull(userCodeClassLoader);
+		this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange);
+		this.cancelStreamRegistry = new ClosableRegistry();
+	}
+
+	/**
+	 * Closes the state backend, releasing all internal resources, but does not delete any persistent
+	 * checkpoint data.
+	 *
+	 */
+	@Override
+	public void dispose() {
+		if (kvStateRegistry != null) {
+			kvStateRegistry.unregisterAll();
+		}
+
+		lastName = null;
+		lastState = null;
+		keyValueStatesByName = null;
+	}
+
+	/**
+	 * Creates and returns a new {@link ValueState}.
+	 *
+	 * @param namespaceSerializer TypeSerializer for the state namespace.
+	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
+	 *
+	 * @param <N> The type of the namespace.
+	 * @param <T> The type of the value that the {@code ValueState} can store.
+	 */
+	protected abstract <N, T> ValueState<T> createValueState(TypeSerializer<N> namespaceSerializer, ValueStateDescriptor<T> stateDesc) throws Exception;
+
+	/**
+	 * Creates and returns a new {@link ListState}.
+	 *
+	 * @param namespaceSerializer TypeSerializer for the state namespace.
+	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
+	 *
+	 * @param <N> The type of the namespace.
+	 * @param <T> The type of the values that the {@code ListState} can store.
+	 */
+	protected abstract <N, T> ListState<T> createListState(TypeSerializer<N> namespaceSerializer, ListStateDescriptor<T> stateDesc) throws Exception;
+
+	/**
+	 * Creates and returns a new {@link ReducingState}.
+	 *
+	 * @param namespaceSerializer TypeSerializer for the state namespace.
+	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
+	 *
+	 * @param <N> The type of the namespace.
+	 * @param <T> The type of the values that the {@code ListState} can store.
+	 */
+	protected abstract <N, T> ReducingState<T> createReducingState(TypeSerializer<N> namespaceSerializer, ReducingStateDescriptor<T> stateDesc) throws Exception;
+
+	/**
+	 * Creates and returns a new {@link FoldingState}.
+	 *
+	 * @param namespaceSerializer TypeSerializer for the state namespace.
+	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
+	 *
+	 * @param <N> The type of the namespace.
+	 * @param <T> Type of the values folded into the state
+	 * @param <ACC> Type of the value in the state
+	 */
+	protected abstract <N, T, ACC> FoldingState<T, ACC> createFoldingState(TypeSerializer<N> namespaceSerializer, FoldingStateDescriptor<T, ACC> stateDesc) throws Exception;
+
+	/**
+	 * @see KeyedStateBackend
+	 */
+	@Override
+	public void setCurrentKey(K newKey) {
+		this.currentKey = newKey;
+		this.currentKeyGroup = KeyGroupRangeAssignment.assignToKeyGroup(newKey, numberOfKeyGroups);
+	}
+
+	/**
+	 * @see KeyedStateBackend
+	 */
+	@Override
+	public TypeSerializer<K> getKeySerializer() {
+		return keySerializer;
+	}
+
+	/**
+	 * @see KeyedStateBackend
+	 */
+	@Override
+	public K getCurrentKey() {
+		return currentKey;
+	}
+
+	/**
+	 * @see KeyedStateBackend
+	 */
+	@Override
+	public int getCurrentKeyGroupIndex() {
+		return currentKeyGroup;
+	}
+
+	/**
+	 * @see KeyedStateBackend
+	 */
+	@Override
+	public int getNumberOfKeyGroups() {
+		return numberOfKeyGroups;
+	}
+
+	/**
+	 * @see KeyedStateBackend
+	 */
+	public KeyGroupRange getKeyGroupRange() {
+		return keyGroupRange;
+	}
+
+	/**
+	 * @see KeyedStateBackend
+	 */
+	@Override
+	@SuppressWarnings({"rawtypes", "unchecked"})
+	public <N, S extends State> S getPartitionedState(final N namespace, final TypeSerializer<N> namespaceSerializer, final StateDescriptor<S, ?> stateDescriptor) throws Exception {
+		Preconditions.checkNotNull(namespace, "Namespace");
+		Preconditions.checkNotNull(namespaceSerializer, "Namespace serializer");
+
+		if (keySerializer == null) {
+			throw new RuntimeException("State key serializer has not been configured in the config. " +
+					"This operation cannot use partitioned state.");
+		}
+
+		if (!stateDescriptor.isSerializerInitialized()) {
+			stateDescriptor.initializeSerializerUnlessSet(new ExecutionConfig());
+		}
+
+		if (keyValueStatesByName == null) {
+			keyValueStatesByName = new HashMap<>();
+		}
+
+		if (lastName != null && lastName.equals(stateDescriptor.getName())) {
+			lastState.setCurrentNamespace(namespace);
+			return (S) lastState;
+		}
+
+		KvState<?> previous = keyValueStatesByName.get(stateDescriptor.getName());
+		if (previous != null) {
+			lastState = previous;
+			lastState.setCurrentNamespace(namespace);
+			lastName = stateDescriptor.getName();
+			return (S) previous;
+		}
+
+		// create a new blank key/value state
+		S state = stateDescriptor.bind(new StateBackend() {
+			@Override
+			public <T> ValueState<T> createValueState(ValueStateDescriptor<T> stateDesc) throws Exception {
+				return AbstractKeyedStateBackend.this.createValueState(namespaceSerializer, stateDesc);
+			}
+
+			@Override
+			public <T> ListState<T> createListState(ListStateDescriptor<T> stateDesc) throws Exception {
+				return AbstractKeyedStateBackend.this.createListState(namespaceSerializer, stateDesc);
+			}
+
+			@Override
+			public <T> ReducingState<T> createReducingState(ReducingStateDescriptor<T> stateDesc) throws Exception {
+				return AbstractKeyedStateBackend.this.createReducingState(namespaceSerializer, stateDesc);
+			}
+
+			@Override
+			public <T, ACC> FoldingState<T, ACC> createFoldingState(FoldingStateDescriptor<T, ACC> stateDesc) throws Exception {
+				return AbstractKeyedStateBackend.this.createFoldingState(namespaceSerializer, stateDesc);
+			}
+
+		});
+
+		KvState kvState = (KvState) state;
+
+		keyValueStatesByName.put(stateDescriptor.getName(), kvState);
+
+		lastName = stateDescriptor.getName();
+		lastState = kvState;
+
+		kvState.setCurrentNamespace(namespace);
+
+		// Publish queryable state
+		if (stateDescriptor.isQueryable()) {
+			if (kvStateRegistry == null) {
+				throw new IllegalStateException("State backend has not been initialized for job.");
+			}
+
+			String name = stateDescriptor.getQueryableStateName();
+			kvStateRegistry.registerKvState(keyGroupRange, name, kvState);
+		}
+
+		return state;
+	}
+
+	@Override
+	@SuppressWarnings({"unchecked", "rawtypes"})
+	public <N, S extends MergingState<?, ?>> void mergePartitionedStates(final N target, Collection<N> sources, final TypeSerializer<N> namespaceSerializer, final StateDescriptor<S, ?> stateDescriptor) throws Exception {
+		if (stateDescriptor instanceof ReducingStateDescriptor) {
+			ReducingStateDescriptor reducingStateDescriptor = (ReducingStateDescriptor) stateDescriptor;
+			ReduceFunction reduceFn = reducingStateDescriptor.getReduceFunction();
+			ReducingState state = (ReducingState) getPartitionedState(target, namespaceSerializer, stateDescriptor);
+			KvState kvState = (KvState) state;
+			Object result = null;
+			for (N source: sources) {
+				kvState.setCurrentNamespace(source);
+				Object sourceValue = state.get();
+				if (result == null) {
+					result = state.get();
+				} else if (sourceValue != null) {
+					result = reduceFn.reduce(result, sourceValue);
+				}
+				state.clear();
+			}
+			kvState.setCurrentNamespace(target);
+			if (result != null) {
+				state.add(result);
+			}
+		} else if (stateDescriptor instanceof ListStateDescriptor) {
+			ListState<Object> state = (ListState) getPartitionedState(target, namespaceSerializer, stateDescriptor);
+			KvState kvState = (KvState) state;
+			List<Object> result = new ArrayList<>();
+			for (N source: sources) {
+				kvState.setCurrentNamespace(source);
+				Iterable<Object> sourceValue = state.get();
+				if (sourceValue != null) {
+					for (Object o : sourceValue) {
+						result.add(o);
+					}
+				}
+				state.clear();
+			}
+			kvState.setCurrentNamespace(target);
+			for (Object o : result) {
+				state.add(o);
+			}
+		} else {
+			throw new RuntimeException("Cannot merge states for " + stateDescriptor);
+		}
+	}
+
+	@Override
+	public void close() throws IOException {
+		cancelStreamRegistry.close();
+	}
+}
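
Note how getPartitionedState caches the last accessed state under lastName/lastState, so per-record access to the same named state skips the map lookup. A minimal usage sketch, assuming a backend instance and a current key are available (VoidNamespace as in the tests further below):

    // second and later lookups of "count" hit the lastName/lastState fast path
    ValueStateDescriptor<Long> desc = new ValueStateDescriptor<>("count", Long.class, 0L);

    backend.setCurrentKey(key);
    ValueState<Long> count = backend.getPartitionedState(
            VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, desc);
    count.update(count.value() + 1);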

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
index 0d2bf45..c2e665b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
@@ -24,6 +24,7 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 
 import java.io.IOException;
+import java.util.Collection;
 import java.util.List;
 
 /**
@@ -36,31 +37,33 @@ public abstract class AbstractStateBackend implements java.io.Serializable {
 	 * Creates a {@link CheckpointStreamFactory} that can be used to create streams
 	 * that should end up in a checkpoint.
 	 *
-	 * @param jobId The {@link JobID} of the job for which we are creating checkpoint streams.
+	 * @param jobId              The {@link JobID} of the job for which we are creating checkpoint streams.
 	 * @param operatorIdentifier An identifier of the operator for which we create streams.
 	 */
 	public abstract CheckpointStreamFactory createStreamFactory(
 			JobID jobId,
-			String operatorIdentifier) throws IOException;
+			String operatorIdentifier
+	) throws IOException;
 
 	/**
-	 * Creates a new {@link KeyedStateBackend} that is responsible for keeping keyed state
+	 * Creates a new {@link AbstractKeyedStateBackend} that is responsible for keeping keyed state
 	 * and can be checkpointed to checkpoint streams.
 	 */
-	public abstract <K> KeyedStateBackend<K> createKeyedStateBackend(
+	public abstract <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
 			Environment env,
 			JobID jobID,
 			String operatorIdentifier,
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
-			TaskKvStateRegistry kvStateRegistry) throws Exception;
+			TaskKvStateRegistry kvStateRegistry
+	) throws Exception;
 
 	/**
-	 * Creates a new {@link KeyedStateBackend} that restores its state from the given list
+	 * Creates a new {@link AbstractKeyedStateBackend} that restores its state from the given list
 	 * {@link KeyGroupsStateHandle KeyGroupStateHandles}.
 	 */
-	public abstract <K> KeyedStateBackend<K> restoreKeyedStateBackend(
+	public abstract <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
 			Environment env,
 			JobID jobID,
 			String operatorIdentifier,
@@ -68,6 +71,30 @@ public abstract class AbstractStateBackend implements java.io.Serializable {
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
 			List<KeyGroupsStateHandle> restoredState,
-			TaskKvStateRegistry kvStateRegistry) throws Exception;
+			TaskKvStateRegistry kvStateRegistry
+	) throws Exception;
 
+
+	/**
+	 * Creates a new {@link OperatorStateBackend} that can be used for storing partitionable operator
+	 * state in checkpoint streams.
+	 */
+	public OperatorStateBackend createOperatorStateBackend(
+			Environment env,
+			String operatorIdentifier
+	) throws Exception {
+		return new DefaultOperatorStateBackend();
+	}
+
+	/**
+	 * Creates a new {@link OperatorStateBackend} that restores its state from the given collection of
+	 * {@link OperatorStateHandle}.
+	 */
+	public OperatorStateBackend restoreOperatorStateBackend(
+			Environment env,
+			String operatorIdentifier,
+			Collection<OperatorStateHandle> restoreSnapshots
+	) throws Exception {
+		return new DefaultOperatorStateBackend(restoreSnapshots);
+	}
 }
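
The two new methods give every backend a default operator state implementation, so subclasses only have to implement the keyed-state factory methods. A sketch of the call sites, with env, the operator identifier, and the handle collection as placeholders:

    // fresh backend for the first checkpoint, restored backend on recovery
    OperatorStateBackend fresh =
            stateBackend.createOperatorStateBackend(env, "my-operator");
    OperatorStateBackend restored =
            stateBackend.restoreOperatorStateBackend(env, "my-operator", restoredHandles);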

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java
index 74057ee..c6904c0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ChainedStateHandle.java
@@ -26,7 +26,7 @@ import java.util.Collections;
 import java.util.List;
 
 /**
- * Handle to the non-partitioned states for the operators in an operator chain.
+ * Handle to state handles for the operators in an operator chain.
  */
 public class ChainedStateHandle<T extends StateObject> implements StateObject {
 
@@ -123,9 +123,4 @@ public class ChainedStateHandle<T extends StateObject> implements StateObject {
 	public static <T extends StateObject> ChainedStateHandle<T> wrapSingleHandle(T stateHandleToWrap) {
 		return new ChainedStateHandle<T>(Collections.singletonList(stateHandleToWrap));
 	}
-
-	@Override
-	public void close() throws IOException {
-		StateUtil.bestEffortCloseAllStateObjects(operatorStateHandles);
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStateHandles.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStateHandles.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStateHandles.java
new file mode 100644
index 0000000..9daf963
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointStateHandles.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * Container that bundles all state handles from the different state types of a checkpointed state.
+ * TODO This will be changed in the future if we get rid of chained state and instead connect state directly to individual operators in a chain.
+ */
+public class CheckpointStateHandles implements Serializable {
+
+	private static final long serialVersionUID = 3252351989995L;
+
+	private final ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles;
+
+	private final ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles;
+
+	private final List<KeyGroupsStateHandle> keyGroupsStateHandle;
+
+	public CheckpointStateHandles(
+			ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles,
+			ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles,
+			List<KeyGroupsStateHandle> keyGroupsStateHandle) {
+
+		this.nonPartitionedStateHandles = nonPartitionedStateHandles;
+		this.partitioneableStateHandles = partitioneableStateHandles;
+		this.keyGroupsStateHandle = keyGroupsStateHandle;
+	}
+
+	public ChainedStateHandle<StreamStateHandle> getNonPartitionedStateHandles() {
+		return nonPartitionedStateHandles;
+	}
+
+	public ChainedStateHandle<OperatorStateHandle> getPartitioneableStateHandles() {
+		return partitioneableStateHandles;
+	}
+
+	public List<KeyGroupsStateHandle> getKeyGroupsStateHandle() {
+		return keyGroupsStateHandle;
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o) {
+			return true;
+		}
+		if (!(o instanceof CheckpointStateHandles)) {
+			return false;
+		}
+
+		CheckpointStateHandles that = (CheckpointStateHandles) o;
+
+		if (nonPartitionedStateHandles != null ?
+				!nonPartitionedStateHandles.equals(that.nonPartitionedStateHandles)
+				: that.nonPartitionedStateHandles != null) {
+			return false;
+		}
+
+		if (partitioneableStateHandles != null ?
+				!partitioneableStateHandles.equals(that.partitioneableStateHandles)
+				: that.partitioneableStateHandles != null) {
+			return false;
+		}
+		return keyGroupsStateHandle != null ?
+				keyGroupsStateHandle.equals(that.keyGroupsStateHandle) : that.keyGroupsStateHandle == null;
+	}
+
+	@Override
+	public int hashCode() {
+		int result = nonPartitionedStateHandles != null ? nonPartitionedStateHandles.hashCode() : 0;
+		result = 31 * result + (partitioneableStateHandles != null ? partitioneableStateHandles.hashCode() : 0);
+		result = 31 * result + (keyGroupsStateHandle != null ? keyGroupsStateHandle.hashCode() : 0);
+		return result;
+	}
+
+	@Override
+	public String toString() {
+		return "CheckpointStateHandles{" +
+				"nonPartitionedStateHandles=" + nonPartitionedStateHandles +
+				", partitioneableStateHandles=" + partitioneableStateHandles +
+				", keyGroupsStateHandle=" + keyGroupsStateHandle +
+				'}';
+	}
+}
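
On the task side the container is assembled once all parts of a checkpoint are complete. A sketch with placeholder variables for the three handle types:

    CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(
            chainedNonPartitionedHandles,  // ChainedStateHandle<StreamStateHandle>
            chainedOperatorStateHandles,   // ChainedStateHandle<OperatorStateHandle>
            keyGroupsStateHandles);        // List<KeyGroupsStateHandle>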

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java
new file mode 100644
index 0000000..26d6192
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ClosableRegistry.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.commons.io.IOUtils;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+
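+/**
+ * Registry for instances of {@link Closeable}, which allows all registered
+ * instances to be closed with one call, e.g. when the task that owns the
+ * registered streams is closed or cancelled.
+ */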
+public class ClosableRegistry implements Closeable {
+
+	private final Set<Closeable> registeredCloseables;
+	private boolean closed;
+
+	public ClosableRegistry() {
+		this.registeredCloseables = new HashSet<>();
+		this.closed = false;
+	}
+
+	public boolean registerClosable(Closeable closeable) {
+
+		if (null == closeable) {
+			return false;
+		}
+
+		synchronized (getSynchronizationLock()) {
+			if (closed) {
+				throw new IllegalStateException("Cannot register Closable, registry is already closed.");
+			}
+
+			return registeredCloseables.add(closeable);
+		}
+	}
+
+	public boolean unregisterClosable(Closeable closeable) {
+
+		if (null == closeable) {
+			return false;
+		}
+
+		synchronized (getSynchronizationLock()) {
+			return registeredCloseables.remove(closeable);
+		}
+	}
+
+	@Override
+	public void close() throws IOException {
+
+		synchronized (getSynchronizationLock()) {
+
+			for (Closeable closeable : registeredCloseables) {
+				IOUtils.closeQuietly(closeable);
+			}
+
+			registeredCloseables.clear();
+
+			// set the closed flag even when nothing was registered, so that
+			// later registerClosable() calls fail fast instead of leaking
+			closed = true;
+		}
+	}
+
+	private Object getSynchronizationLock() {
+		return registeredCloseables;
+	}
+}
\ No newline at end of file
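
A typical use is to guard a state stream so that a concurrent close of the registry, for example on task cancellation, also closes the stream. The pattern below mirrors how DefaultOperatorStateBackend uses the registry later in this patch (stateHandle is a placeholder):

    ClosableRegistry registry = new ClosableRegistry();

    FSDataInputStream in = stateHandle.openInputStream();
    try {
        registry.registerClosable(in);
        // read the state here; a registry.close() from another thread
        // closes 'in' and unblocks a reader stuck in I/O
    } finally {
        registry.unregisterClosable(in);
        in.close();
    }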

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
new file mode 100644
index 0000000..0bd5eeb
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.RunnableFuture;
+
+/**
+ * Default implementation of OperatorStateStore that provides the ability to make snapshots.
+ */
+public class DefaultOperatorStateBackend implements OperatorStateBackend {
+
+	private final Map<String, PartitionableListState<?>> registeredStates;
+	private final Collection<OperatorStateHandle> restoreSnapshots;
+	private final ClosableRegistry closeStreamOnCancelRegistry;
+
+	/**
+	 * Restores an OperatorStateStore (lazily) using the provided snapshots.
+	 *
+	 * @param restoreSnapshots snapshots that are available to restore partitionable states on request.
+	 */
+	public DefaultOperatorStateBackend(
+			Collection<OperatorStateHandle> restoreSnapshots) {
+		this.restoreSnapshots = restoreSnapshots;
+		this.registeredStates = new HashMap<>();
+		this.closeStreamOnCancelRegistry = new ClosableRegistry();
+	}
+
+	/**
+	 * Creates an empty OperatorStateStore.
+	 */
+	public DefaultOperatorStateBackend() {
+		this(null);
+	}
+
+	/**
+	 * @see OperatorStateStore
+	 */
+	@Override
+	public <S> ListState<S> getPartitionableState(
+			ListStateDescriptor<S> stateDescriptor) throws IOException {
+
+		Preconditions.checkNotNull(stateDescriptor);
+
+		String name = Preconditions.checkNotNull(stateDescriptor.getName());
+		TypeSerializer<S> partitionStateSerializer = Preconditions.checkNotNull(stateDescriptor.getSerializer());
+
+		@SuppressWarnings("unchecked")
+		PartitionableListState<S> partitionableListState = (PartitionableListState<S>) registeredStates.get(name);
+
+		if (null == partitionableListState) {
+
+			partitionableListState = new PartitionableListState<>(partitionStateSerializer);
+
+			registeredStates.put(name, partitionableListState);
+
+			// Try to restore previous state if state handles to snapshots are provided
+			if (restoreSnapshots != null) {
+				for (OperatorStateHandle stateHandle : restoreSnapshots) {
+
+					long[] offsets = stateHandle.getStateNameToPartitionOffsets().get(name);
+
+					if (offsets != null) {
+
+						FSDataInputStream in = stateHandle.openInputStream();
+						try {
+							closeStreamOnCancelRegistry.registerClosable(in);
+
+							DataInputView div = new DataInputViewStreamWrapper(in);
+
+							for (int i = 0; i < offsets.length; ++i) {
+
+								in.seek(offsets[i]);
+								S partitionState = partitionStateSerializer.deserialize(div);
+								partitionableListState.add(partitionState);
+							}
+						} finally {
+							closeStreamOnCancelRegistry.unregisterClosable(in);
+							in.close();
+						}
+					}
+				}
+			}
+		}
+
+		return partitionableListState;
+	}
+
+	/**
+	 * @see SnapshotProvider
+	 */
+	@Override
+	public RunnableFuture<OperatorStateHandle> snapshot(
+			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
+
+		if (registeredStates.isEmpty()) {
+			return new DoneFuture<>(null);
+		}
+
+		Map<String, long[]> writtenStatesMetaData = new HashMap<>(registeredStates.size());
+
+		CheckpointStreamFactory.CheckpointStateOutputStream out = streamFactory.
+				createCheckpointStateOutputStream(checkpointId, timestamp);
+
+		try {
+			closeStreamOnCancelRegistry.registerClosable(out);
+
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+
+			dov.writeInt(registeredStates.size());
+			for (Map.Entry<String, PartitionableListState<?>> entry : registeredStates.entrySet()) {
+
+				long[] partitionOffsets = entry.getValue().write(out);
+				writtenStatesMetaData.put(entry.getKey(), partitionOffsets);
+			}
+
+			OperatorStateHandle handle = new OperatorStateHandle(out.closeAndGetHandle(), writtenStatesMetaData);
+
+			return new DoneFuture<>(handle);
+		} finally {
+			closeStreamOnCancelRegistry.unregisterClosable(out);
+			out.close();
+		}
+	}
+
+	@Override
+	public void dispose() {
+
+	}
+
+	static final class PartitionableListState<S> implements ListState<S> {
+
+		private final List<S> listState;
+		private final TypeSerializer<S> partitionStateSerializer;
+
+		public PartitionableListState(TypeSerializer<S> partitionStateSerializer) {
+			this.listState = new ArrayList<>();
+			this.partitionStateSerializer = Preconditions.checkNotNull(partitionStateSerializer);
+		}
+
+		@Override
+		public void clear() {
+			listState.clear();
+		}
+
+		@Override
+		public Iterable<S> get() {
+			return listState;
+		}
+
+		@Override
+		public void add(S value) {
+			listState.add(value);
+		}
+
+		public long[] write(FSDataOutputStream out) throws IOException {
+
+			long[] partitionOffsets = new long[listState.size()];
+
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+
+			for (int i = 0; i < listState.size(); ++i) {
+				S element = listState.get(i);
+				partitionOffsets[i] = out.getPos();
+				partitionStateSerializer.serialize(element, dov);
+			}
+
+			return partitionOffsets;
+		}
+	}
+
+	@Override
+	public Set<String> getRegisteredStateNames() {
+		return registeredStates.keySet();
+	}
+
+	@Override
+	public void close() throws IOException {
+		closeStreamOnCancelRegistry.close();
+	}
+}
+
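
Taken together, an operator registers a named list state, appends elements, and the snapshot records one stream offset per element so the list can be redistributed across parallel instances on restore. A usage sketch, assuming an operatorStateBackend and a streamFactory are available (the descriptor mirrors the tests below; the default snapshot completes synchronously via DoneFuture):

    // register (or re-obtain after restore) a partitionable list state by name
    ListStateDescriptor<String> desc = new ListStateDescriptor<>("buffered", String.class);
    desc.initializeSerializerUnlessSet(new ExecutionConfig());

    ListState<String> state = operatorStateBackend.getPartitionableState(desc);
    state.add("element-1");

    // checkpoint: the resulting handle maps state name -> per-element offsets
    RunnableFuture<OperatorStateHandle> future =
            operatorStateBackend.snapshot(checkpointId, timestamp, streamFactory);
    OperatorStateHandle handle = future.get();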

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java
index 4f0a82b..8e7207e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupRangeOffsets.java
@@ -31,6 +31,8 @@ import java.util.Iterator;
  */
 public class KeyGroupRangeOffsets implements Iterable<Tuple2<Integer, Long>> , Serializable {
 
+	private static final long serialVersionUID = 6595415219136429696L;
+
 	/** the range of key-groups */
 	private final KeyGroupRange keyGroupRange;
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
index 7f87e86..ea12808 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupsStateHandle.java
@@ -138,7 +138,6 @@ public class KeyGroupsStateHandle implements StateObject {
 			return false;
 		}
 		return stateHandle.equals(that.stateHandle);
-
 	}
 
 	@Override
@@ -155,9 +154,4 @@ public class KeyGroupsStateHandle implements StateObject {
 				", data=" + stateHandle +
 				'}';
 	}
-
-	@Override
-	public void close() throws IOException {
-		stateHandle.close();
-	}
 }


[04/10] flink git commit: [FLINK-4379] [checkpoints] Introduce rescalable operator state

Posted by se...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 73e2808..2f21574 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -80,11 +80,11 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		return getStateBackend().createStreamFactory(new JobID(), "test_op");
 	}
 
-	protected <K> KeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> keySerializer) throws Exception {
+	protected <K> AbstractKeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> keySerializer) throws Exception {
 		return createKeyedBackend(keySerializer, new DummyEnvironment("test", 1, 0));
 	}
 
-	protected <K> KeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> keySerializer, Environment env) throws Exception {
+	protected <K> AbstractKeyedStateBackend<K> createKeyedBackend(TypeSerializer<K> keySerializer, Environment env) throws Exception {
 		return createKeyedBackend(
 				keySerializer,
 				10,
@@ -92,7 +92,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 				env);
 	}
 
-	protected <K> KeyedStateBackend<K> createKeyedBackend(
+	protected <K> AbstractKeyedStateBackend<K> createKeyedBackend(
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
@@ -104,14 +104,15 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 				keySerializer,
 				numberOfKeyGroups,
 				keyGroupRange,
-				env.getTaskKvStateRegistry());
+				env.getTaskKvStateRegistry());
 	}
 
-	protected <K> KeyedStateBackend<K> restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle state) throws Exception {
+	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(TypeSerializer<K> keySerializer, KeyGroupsStateHandle state) throws Exception {
 		return restoreKeyedBackend(keySerializer, state, new DummyEnvironment("test", 1, 0));
 	}
 
-	protected <K> KeyedStateBackend<K> restoreKeyedBackend(
+	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(
 			TypeSerializer<K> keySerializer,
 			KeyGroupsStateHandle state,
 			Environment env) throws Exception {
@@ -123,7 +124,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 				env);
 	}
 
-	protected <K> KeyedStateBackend<K> restoreKeyedBackend(
+	protected <K> AbstractKeyedStateBackend<K> restoreKeyedBackend(
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
@@ -144,7 +145,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	@SuppressWarnings("unchecked")
 	public void testValueState() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
-		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", String.class, null);
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -195,7 +196,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals("u3", state.value());
 		assertEquals("u3", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-		backend.close();
+		backend.dispose();
 		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 
 		snapshot1.discardState();
@@ -211,7 +212,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals("2", restored1.value());
 		assertEquals("2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-		backend.close();
+		backend.dispose();
 		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
 
 		snapshot2.discardState();
@@ -230,7 +231,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals("u3", restored2.value());
 		assertEquals("u3", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-		backend.close();
+		backend.dispose();
 	}
 
 	@Test
@@ -238,7 +239,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testMultipleValueStates() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
 
-		KeyedStateBackend<Integer> backend = createKeyedBackend(
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
 				IntSerializer.INSTANCE,
 				1,
 				new KeyGroupRange(0, 0),
@@ -271,7 +272,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		// draw a snapshot
 		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-		backend.close();
+		backend.dispose();
 		backend = restoreKeyedBackend(
 				IntSerializer.INSTANCE,
 				1,
@@ -290,7 +291,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals("1", state1.value());
 		assertEquals(13, (int) state2.value());
 
-		backend.close();
+		backend.dispose();
 	}
 
 	/**
@@ -313,7 +314,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		}
 
 		CheckpointStreamFactory streamFactory = createStreamFactory();
-		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<Long> kvId = new ValueStateDescriptor<>("id", LongSerializer.INSTANCE, 42L);
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -344,14 +345,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		// draw a snapshot
 		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-		backend.close();
+		backend.dispose();
 		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 
 		snapshot1.discardState();
 
 		backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 
-		backend.close();
+		backend.dispose();
 	}
 
 	@Test
@@ -359,7 +360,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testListState() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 			kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -411,7 +412,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("u3", joiner.join(state.get()));
 			assertEquals("u3", joiner.join(getSerializedList(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -427,7 +428,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("2", joiner.join(restored1.get()));
 			assertEquals("2", joiner.join(getSerializedList(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
-			backend.close();
+			backend.dispose();
 			// restore the second snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
 			snapshot2.discardState();
@@ -446,7 +447,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("u3", joiner.join(restored2.get()));
 			assertEquals("u3", joiner.join(getSerializedList(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -459,7 +460,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testReducingState() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ReducingStateDescriptor<String> kvId = new ReducingStateDescriptor<>("id", new AppendingReduce(), String.class);
 			kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -510,7 +511,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("u3", state.get());
 			assertEquals("u3", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -526,7 +527,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("2", restored1.get());
 			assertEquals("2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 			// restore the second snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
 			snapshot2.discardState();
@@ -545,7 +546,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("u3", restored2.get());
 			assertEquals("u3", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -558,7 +559,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testFoldingState() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			FoldingStateDescriptor<Integer, String> kvId = new FoldingStateDescriptor<>("id",
 					"Fold-Initial:",
@@ -613,7 +614,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("Fold-Initial:,103", state.get());
 			assertEquals("Fold-Initial:,103", getSerializedValue(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -629,7 +630,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("Fold-Initial:,2", restored1.get());
 			assertEquals("Fold-Initial:,2", getSerializedValue(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 			// restore the second snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
 			snapshot1.discardState();
@@ -649,7 +650,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertEquals("Fold-Initial:,103", restored2.get());
 			assertEquals("Fold-Initial:,103", getSerializedValue(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -672,7 +673,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		final int MAX_PARALLELISM = 10;
 
 		CheckpointStreamFactory streamFactory = createStreamFactory();
-		KeyedStateBackend<Integer> backend = createKeyedBackend(
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
 				IntSerializer.INSTANCE,
 				MAX_PARALLELISM,
 				new KeyGroupRange(0, MAX_PARALLELISM - 1),
@@ -714,10 +715,10 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 				Collections.singletonList(snapshot),
 				KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(MAX_PARALLELISM, 2, 1));
 
-		backend.close();
+		backend.dispose();
 
 		// backend for the first half of the key group range
-		KeyedStateBackend<Integer> firstHalfBackend = restoreKeyedBackend(
+		AbstractKeyedStateBackend<Integer> firstHalfBackend = restoreKeyedBackend(
 				IntSerializer.INSTANCE,
 				MAX_PARALLELISM,
 				new KeyGroupRange(0, 4),
@@ -725,7 +726,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 				new DummyEnvironment("test", 1, 0));
 
 		// backend for the second half of the key group range
-		KeyedStateBackend<Integer> secondHalfBackend = restoreKeyedBackend(
+		AbstractKeyedStateBackend<Integer> secondHalfBackend = restoreKeyedBackend(
 				IntSerializer.INSTANCE,
 				MAX_PARALLELISM,
 				new KeyGroupRange(5, 9),
@@ -749,8 +750,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		secondHalfBackend.setCurrentKey(keyInSecondHalf);
 		assertTrue(secondHalfState.value().equals("ShouldBeInSecondHalf"));
 
-		firstHalfBackend.close();
-		secondHalfBackend.close();
+		firstHalfBackend.dispose();
+		secondHalfBackend.dispose();
 	}
 
 	@Test
@@ -758,7 +759,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testValueStateRestoreWithWrongSerializers() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", String.class, null);
 			kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -773,7 +774,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			// draw a snapshot
 			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -798,7 +799,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			} catch (Exception e) {
 				fail("wrong exception " + e);
 			}
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -811,7 +812,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testListStateRestoreWithWrongSerializers() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 			ListState<String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
@@ -824,7 +825,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			// draw a snapshot
 			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -849,7 +850,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			} catch (Exception e) {
 				fail("wrong exception " + e);
 			}
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -862,7 +863,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	public void testReducingStateRestoreWithWrongSerializers() {
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ReducingStateDescriptor<String> kvId = new ReducingStateDescriptor<>("id",
 					new AppendingReduce(),
@@ -877,7 +878,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			// draw a snapshot
 			KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
 
-			backend.close();
+			backend.dispose();
 			// restore the first snapshot and validate it
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
 			snapshot1.discardState();
@@ -902,7 +903,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			} catch (Exception e) {
 				fail("wrong exception " + e);
 			}
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -912,7 +913,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 	@Test
 	public void testCopyDefaultValue() throws Exception {
-		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<IntValue> kvId = new ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1));
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -930,7 +931,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		assertEquals(default1, default2);
 		assertFalse(default1 == default2);
 
-		backend.close();
+		backend.dispose();
 	}
 
 	/**
@@ -940,7 +941,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	 */
 	@Test
 	public void testRequireNonNullNamespace() throws Exception {
-		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<IntValue> kvId = new ValueStateDescriptor<>("id", IntValue.class, new IntValue(-1));
 		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
@@ -963,7 +964,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		} catch (NullPointerException ignored) {
 		}
 
-		backend.close();
+		backend.dispose();
 	}
 
 	/**
@@ -973,7 +974,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 	@SuppressWarnings("unchecked")
 	protected void testConcurrentMapIfQueryable() throws Exception {
 		final int numberOfKeyGroups = 1;
-		KeyedStateBackend<Integer> backend = createKeyedBackend(
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
 				IntSerializer.INSTANCE,
 				numberOfKeyGroups,
 				new KeyGroupRange(0, 0),
@@ -1095,7 +1096,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 			assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap);
 		}
 
-		backend.close();
+		backend.dispose();
 	}
 
 	/**
@@ -1107,7 +1108,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		KvStateRegistry registry = env.getKvStateRegistry();
 
 		CheckpointStreamFactory streamFactory = createStreamFactory();
-		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
 		KeyGroupRange expectedKeyGroupRange = backend.getKeyGroupRange();
 
 		KvStateRegistryListener listener = mock(KvStateRegistryListener.class);
@@ -1128,11 +1129,11 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 		KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory));
 
-		backend.close();
+		backend.dispose();
 
 		verify(listener, times(1)).notifyKvStateUnregistered(
 				eq(env.getJobID()), eq(env.getJobVertexId()), eq(expectedKeyGroupRange), eq("banana"));
-		backend.close();
+		backend.dispose();
 		// Initialize again
 		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env);
 		snapshot.discardState();
@@ -1143,7 +1144,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 		verify(listener, times(2)).notifyKvStateRegistered(
 				eq(env.getJobID()), eq(env.getJobVertexId()), eq(expectedKeyGroupRange), eq("banana"), any(KvStateID.class));
 
-		backend.close();
+		backend.dispose();
 
 	}
 
@@ -1152,17 +1153,17 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> {
 
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
-			KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 
 			// draw a snapshot
 			KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(682375462379L, 1, streamFactory));
 			assertNull(snapshot);
-			backend.close();
+			backend.dispose();
 
 			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot);
-			backend.close();
+			backend.dispose();
 		}
 		catch (Exception e) {
 			e.printStackTrace();

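Note on the test changes above: they all exercise the same lifecycle -- take a snapshot, dispose() the backend (which replaces close() as the resource-release call), then restore a fresh backend from the snapshot. A minimal sketch of that cycle, reusing the helper names visible in the diff (createKeyedBackend, restoreKeyedBackend, runSnapshot; streamFactory as produced by createStreamFactory()):

    AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);

    // draw a snapshot of whatever state has been written so far
    KeyGroupsStateHandle snapshot = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));

    // dispose() releases the backend's resources; the instance must not be used afterwards
    backend.dispose();

    // restoring yields a new backend initialized from the snapshot
    backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot);
    snapshot.discardState();

    backend.dispose();
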
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
index a6a555d..d484f2e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStateOutputStreamTest.java
@@ -20,8 +20,6 @@ package org.apache.flink.runtime.state.filesystem;
 
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.runtime.state.AbstractStateBackend;
-
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.junit.Test;
@@ -31,7 +29,8 @@ import java.io.File;
 import java.io.InputStream;
 import java.util.Random;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertTrue;
 
 public class FsCheckpointStateOutputStreamTest {
 
@@ -112,13 +111,14 @@ public class FsCheckpointStateOutputStreamTest {
 		// make sure the writing process did not alter the original byte array
 		assertArrayEquals(original, bytes);
 
-		InputStream inStream = handle.openInputStream();
-		byte[] validation = new byte[bytes.length];
+		try (InputStream inStream = handle.openInputStream()) {
+			byte[] validation = new byte[bytes.length];
 
-		DataInputStream dataInputStream = new DataInputStream(inStream);
-		dataInputStream.readFully(validation);
+			DataInputStream dataInputStream = new DataInputStream(inStream);
+			dataInputStream.readFully(validation);
 
-		assertArrayEquals(bytes, validation);
+			assertArrayEquals(bytes, validation);
+		}
 
 		handle.discardState();
 	}

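The rewrite above is the standard try-with-resources pattern: the input stream obtained from the state handle is now closed even if readFully or the assertion throws. The same pattern in isolation (handle and bytes as in the test, with the two streams composed into one resource):

    byte[] validation = new byte[bytes.length];
    try (DataInputStream in = new DataInputStream(handle.openInputStream())) {
        in.readFully(validation);   // the stream is closed when the block exits, normally or exceptionally
    }
    assertArrayEquals(bytes, validation);
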
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index 454196f..7bc2c29 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -46,6 +46,7 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
 import org.apache.flink.util.SerializedValue;
@@ -53,6 +54,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import java.net.URL;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.Executor;
@@ -209,7 +211,8 @@ public class TaskAsyncCallTest {
 
 		@Override
 		public void setInitialState(ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState) throws Exception {
+									List<KeyGroupsStateHandle> keyGroupsState,
+									List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
 
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
index 7e8868c..8f9c932 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/zookeeper/ZooKeeperStateHandleStoreITCase.java
@@ -33,7 +33,6 @@ import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
@@ -587,8 +586,5 @@ public class ZooKeeperStateHandleStoreITCase extends TestLogger {
 		public int getNumberOfDiscardCalls() {
 			return numberOfDiscardCalls;
 		}
-
-		@Override
-		public void close() throws IOException {}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java b/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
index c16629d..d7a6364 100644
--- a/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
+++ b/flink-streaming-connectors/flink-connector-kafka-0.8/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer08.java
@@ -177,7 +177,7 @@ public class FlinkKafkaConsumer08<T> extends FlinkKafkaConsumerBase<T> {
 	 *           The properties that are used to configure both the fetcher and the offset handler.
 	 */
 	public FlinkKafkaConsumer08(List<String> topics, KeyedDeserializationSchema<T> deserializer, Properties props) {
-		super(deserializer);
+		super(topics, deserializer);
 
 		checkNotNull(topics, "topics");
 		this.kafkaProperties = checkNotNull(props, "props");
@@ -187,22 +187,6 @@ public class FlinkKafkaConsumer08<T> extends FlinkKafkaConsumerBase<T> {
 
 		this.invalidOffsetBehavior = getInvalidOffsetBehavior(props);
 		this.autoCommitInterval = PropertiesUtil.getLong(props, "auto.commit.interval.ms", 60000);
-
-		// Connect to a broker to get the partitions for all topics
-		List<KafkaTopicPartition> partitionInfos = 
-				KafkaTopicPartition.dropLeaderData(getPartitionsForTopic(topics, props));
-
-		if (partitionInfos.size() == 0) {
-			throw new RuntimeException(
-					"Unable to retrieve any partitions for the requested topics " + topics + 
-							". Please check previous log entries");
-		}
-
-		if (LOG.isInfoEnabled()) {
-			logPartitionInfo(LOG, partitionInfos);
-		}
-
-		setSubscribedPartitions(partitionInfos);
 	}
 
 	@Override
@@ -221,6 +205,25 @@ public class FlinkKafkaConsumer08<T> extends FlinkKafkaConsumerBase<T> {
 				invalidOffsetBehavior, autoCommitInterval, useMetrics);
 	}
 
+	@Override
+	protected List<KafkaTopicPartition> getKafkaPartitions(List<String> topics) {
+		// Connect to a broker to get the partitions for all topics
+		List<KafkaTopicPartition> partitionInfos =
+			KafkaTopicPartition.dropLeaderData(getPartitionsForTopic(topics, kafkaProperties));
+
+		if (partitionInfos.size() == 0) {
+			throw new RuntimeException(
+				"Unable to retrieve any partitions for the requested topics " + topics +
+					". Please check previous log entries");
+		}
+
+		if (LOG.isInfoEnabled()) {
+			logPartitionInfo(LOG, partitionInfos);
+		}
+
+		return partitionInfos;
+	}
+
 	// ------------------------------------------------------------------------
 	//  Kafka / ZooKeeper communication utilities
 	// ------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java b/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
index 36fb7e6..f0b58cf 100644
--- a/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
+++ b/flink-streaming-connectors/flink-connector-kafka-0.8/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumer08Test.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.streaming.connectors.kafka;
 
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
 
 import org.apache.kafka.clients.consumer.ConsumerConfig;
@@ -80,7 +81,8 @@ public class KafkaConsumer08Test {
 			props.setProperty("bootstrap.servers", "localhost:11111, localhost:22222");
 			props.setProperty("group.id", "non-existent-group");
 
-			new FlinkKafkaConsumer08<>(Collections.singletonList("no op topic"), new SimpleStringSchema(), props);
+			FlinkKafkaConsumer08<String> consumer = new FlinkKafkaConsumer08<>(Collections.singletonList("no op topic"), new SimpleStringSchema(), props);
+			consumer.open(new Configuration());
 			fail();
 		}
 		catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java b/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
index 8c3eaf8..9708777 100644
--- a/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
+++ b/flink-streaming-connectors/flink-connector-kafka-0.9/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumer09.java
@@ -149,9 +149,8 @@ public class FlinkKafkaConsumer09<T> extends FlinkKafkaConsumerBase<T> {
 	 *           The properties that are used to configure both the fetcher and the offset handler.
 	 */
 	public FlinkKafkaConsumer09(List<String> topics, KeyedDeserializationSchema<T> deserializer, Properties props) {
-		super(deserializer);
+		super(topics, deserializer);
 
-		checkNotNull(topics, "topics");
 		this.properties = checkNotNull(props, "props");
 		setDeserializer(this.properties);
 
@@ -166,7 +165,27 @@ public class FlinkKafkaConsumer09<T> extends FlinkKafkaConsumerBase<T> {
 		catch (Exception e) {
 			throw new IllegalArgumentException("Cannot parse poll timeout for '" + KEY_POLL_TIMEOUT + '\'', e);
 		}
+	}
+
+	@Override
+	protected AbstractFetcher<T, ?> createFetcher(
+			SourceContext<T> sourceContext,
+			List<KafkaTopicPartition> thisSubtaskPartitions,
+			SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
+			SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
+			StreamingRuntimeContext runtimeContext) throws Exception {
+
+		boolean useMetrics = !Boolean.valueOf(properties.getProperty(KEY_DISABLE_METRICS, "false"));
+
+		return new Kafka09Fetcher<>(sourceContext, thisSubtaskPartitions,
+				watermarksPeriodic, watermarksPunctuated,
+				runtimeContext, deserializer,
+				properties, pollTimeout, useMetrics);
+		
+	}
 
+	@Override
+	protected List<KafkaTopicPartition> getKafkaPartitions(List<String> topics) {
 		// read the partitions that belong to the listed topics
 		final List<KafkaTopicPartition> partitions = new ArrayList<>();
 
@@ -192,25 +211,7 @@ public class FlinkKafkaConsumer09<T> extends FlinkKafkaConsumerBase<T> {
 			logPartitionInfo(LOG, partitions);
 		}
 
-		// register these partitions
-		setSubscribedPartitions(partitions);
-	}
-
-	@Override
-	protected AbstractFetcher<T, ?> createFetcher(
-			SourceContext<T> sourceContext,
-			List<KafkaTopicPartition> thisSubtaskPartitions,
-			SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
-			SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
-			StreamingRuntimeContext runtimeContext) throws Exception {
-
-		boolean useMetrics = !Boolean.valueOf(properties.getProperty(KEY_DISABLE_METRICS, "false"));
-
-		return new Kafka09Fetcher<>(sourceContext, thisSubtaskPartitions,
-				watermarksPeriodic, watermarksPunctuated,
-				runtimeContext, deserializer,
-				properties, pollTimeout, useMetrics);
-		
+		return partitions;
 	}
 
 	// ------------------------------------------------------------------------

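In both the 0.8 and 0.9 consumers, partition discovery moves out of the constructor (which typically runs on the client where the program is composed) into a getKafkaPartitions(topics) hook that the base class invokes from open() on each parallel subtask, so discovery failures now surface at job execution time. A condensed sketch of the resulting template method, with names taken from the diffs:

    public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFunction<T> {

        @Override
        public void open(Configuration configuration) {
            // runs on the task, not on the client
            List<KafkaTopicPartition> kafkaTopicPartitions = getKafkaPartitions(topics);
            if (kafkaTopicPartitions != null) {
                assignTopicPartitions(kafkaTopicPartitions);
            }
        }

        // each Kafka version supplies its own discovery logic
        protected abstract List<KafkaTopicPartition> getKafkaPartitions(List<String> topics);
    }

This is also why KafkaConsumer08Test above now has to call consumer.open(new Configuration()) to provoke the "unable to retrieve partitions" failure.
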
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
index 2b2c527..939b77b 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
@@ -18,11 +18,16 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.commons.collections.map.LinkedMap;
-
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.typeinfo.TypeHint;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.state.CheckpointListener;
-import org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously;
+import org.apache.flink.runtime.state.OperatorStateStore;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
@@ -30,18 +35,21 @@ import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
 import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartitionState;
 import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
 import org.apache.flink.util.SerializedValue;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -55,11 +63,12 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  */
 public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFunction<T> implements 
 		CheckpointListener,
-		CheckpointedAsynchronously<HashMap<KafkaTopicPartition, Long>>,
-		ResultTypeQueryable<T>
-{
+		ResultTypeQueryable<T>,
+		CheckpointedFunction {
 	private static final long serialVersionUID = -6272159445203409112L;
 
+	private static final String KAFKA_OFFSETS = "kafka_offsets";
+
 	protected static final Logger LOG = LoggerFactory.getLogger(FlinkKafkaConsumerBase.class);
 	
 	/** The maximum number of pending non-committed checkpoints to track, to avoid memory leaks */
@@ -71,12 +80,14 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	// ------------------------------------------------------------------------
 	//  configuration state, set on the client relevant for all subtasks
 	// ------------------------------------------------------------------------
+
+	private final List<String> topics;
 	
 	/** The schema to convert between Kafka's byte messages, and Flink's objects */
 	protected final KeyedDeserializationSchema<T> deserializer;
 
 	/** The set of topic partitions that the source will read */
-	protected List<KafkaTopicPartition> allSubscribedPartitions;
+	protected List<KafkaTopicPartition> subscribedPartitions;
 	
 	/** Optional timestamp extractor / watermark generator that will be run per Kafka partition,
 	 * to exploit per-partition timestamp characteristics.
@@ -88,6 +99,8 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	 * The assigner is kept in serialized form, to deserialize it into multiple copies */
 	private SerializedValue<AssignerWithPunctuatedWatermarks<T>> punctuatedWatermarkAssigner;
 
+	private transient OperatorStateStore stateStore;
+
 	// ------------------------------------------------------------------------
 	//  runtime state (used individually by each parallel subtask) 
 	// ------------------------------------------------------------------------
@@ -112,8 +125,14 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	 * @param deserializer
 	 *           The deserializer to turn raw byte messages into Java/Scala objects.
 	 */
-	public FlinkKafkaConsumerBase(KeyedDeserializationSchema<T> deserializer) {
+	public FlinkKafkaConsumerBase(List<String> topics, KeyedDeserializationSchema<T> deserializer) {
+		this.topics = checkNotNull(topics);
+		checkArgument(topics.size() > 0, "You have to define at least one topic.");
+
 		this.deserializer = checkNotNull(deserializer, "valueDeserializer");
+
+		TypeInformation<Tuple2<KafkaTopicPartition, Long>> typeInfo =
+				TypeInformation.of(new TypeHint<Tuple2<KafkaTopicPartition, Long>>(){});
 	}
 
 	/**
@@ -124,7 +143,7 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	 */
 	protected void setSubscribedPartitions(List<KafkaTopicPartition> allSubscribedPartitions) {
 		checkNotNull(allSubscribedPartitions);
-		this.allSubscribedPartitions = Collections.unmodifiableList(allSubscribedPartitions);
+		this.subscribedPartitions = Collections.unmodifiableList(allSubscribedPartitions);
 	}
 
 	// ------------------------------------------------------------------------
@@ -205,20 +224,16 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 
 	@Override
 	public void run(SourceContext<T> sourceContext) throws Exception {
-		if (allSubscribedPartitions == null) {
+		if (subscribedPartitions == null) {
 			throw new Exception("The partitions were not set for the consumer");
 		}
-		
-		// figure out which partitions this subtask should process
-		final List<KafkaTopicPartition> thisSubtaskPartitions = assignPartitions(allSubscribedPartitions,
-				getRuntimeContext().getNumberOfParallelSubtasks(), getRuntimeContext().getIndexOfThisSubtask());
-		
+
 		// we need only do work, if we actually have partitions assigned
-		if (!thisSubtaskPartitions.isEmpty()) {
+		if (!subscribedPartitions.isEmpty()) {
 
 			// (1) create the fetcher that will communicate with the Kafka brokers
 			final AbstractFetcher<T, ?> fetcher = createFetcher(
-					sourceContext, thisSubtaskPartitions, 
+					sourceContext, subscribedPartitions,
 					periodicWatermarkAssigner, punctuatedWatermarkAssigner,
 					(StreamingRuntimeContext) getRuntimeContext());
 
@@ -277,6 +292,15 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	}
 
 	@Override
+	public void open(Configuration configuration) {
+		List<KafkaTopicPartition> kafkaTopicPartitions = getKafkaPartitions(topics);
+
+		if (kafkaTopicPartitions != null) {
+			assignTopicPartitions(kafkaTopicPartitions);
+		}
+	}
+
+	@Override
 	public void close() throws Exception {
 		// pretty much the same logic as cancelling
 		try {
@@ -289,44 +313,76 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	// ------------------------------------------------------------------------
 	//  Checkpoint and restore
 	// ------------------------------------------------------------------------
-	
+
+
 	@Override
-	public HashMap<KafkaTopicPartition, Long> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
-		if (!running) {
-			LOG.debug("snapshotState() called on closed source");
-			return null;
-		}
-		
-		final AbstractFetcher<?, ?> fetcher = this.kafkaFetcher;
-		if (fetcher == null) {
-			// the fetcher has not yet been initialized, which means we need to return the
-			// originally restored offsets
-			return restoreToOffset;
-		}
+	public void initializeState(OperatorStateStore stateStore) throws Exception {
 
-		HashMap<KafkaTopicPartition, Long> currentOffsets = fetcher.snapshotCurrentState();
+		this.stateStore = stateStore;
 
-		if (LOG.isDebugEnabled()) {
-			LOG.debug("Snapshotting state. Offsets: {}, checkpoint id: {}, timestamp: {}",
-					KafkaTopicPartition.toString(currentOffsets), checkpointId, checkpointTimestamp);
-		}
+		ListState<Serializable> offsets = stateStore.getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR);
 
-		// the map cannot be asynchronously updated, because only one checkpoint call can happen
-		// on this function at a time: either snapshotState() or notifyCheckpointComplete()
-		pendingCheckpoints.put(checkpointId, currentOffsets);
-		
-		// truncate the map, to prevent infinite growth
-		while (pendingCheckpoints.size() > MAX_NUM_PENDING_CHECKPOINTS) {
-			pendingCheckpoints.remove(0);
+		restoreToOffset = new HashMap<>();
+
+		for (Serializable serializable : offsets.get()) {
+			@SuppressWarnings("unchecked")
+			Tuple2<KafkaTopicPartition, Long> kafkaOffset = (Tuple2<KafkaTopicPartition, Long>) serializable;
+			restoreToOffset.put(kafkaOffset.f0, kafkaOffset.f1);
 		}
 
-		return currentOffsets;
+		LOG.info("Setting restore state in the FlinkKafkaConsumer: {}", restoreToOffset);
 	}
 
 	@Override
-	public void restoreState(HashMap<KafkaTopicPartition, Long> restoredOffsets) {
-		LOG.info("Setting restore state in the FlinkKafkaConsumer: {}", restoredOffsets);
-		restoreToOffset = restoredOffsets;
+	public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
+		if (!running) {
+			LOG.debug("prepareSnapshot() called on closed source");
+		} else {
+
+			ListState<Serializable> listState = stateStore.getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR);
+
+			listState.clear();
+
+			final AbstractFetcher<?, ?> fetcher = this.kafkaFetcher;
+			if (fetcher == null) {
+				// the fetcher has not yet been initialized, which means we need to return the
+				// the fetcher has not yet been initialized, which means we need to checkpoint the
+				// originally restored offsets or the assigned partitions
+				if (restoreToOffset != null) {
+					// the map cannot be asynchronously updated, because only one checkpoint call can happen
+					// on this function at a time: either prepareSnapshot() or notifyCheckpointComplete()
+					pendingCheckpoints.put(checkpointId, restoreToOffset);
+
+					// truncate the map, to prevent infinite growth
+					while (pendingCheckpoints.size() > MAX_NUM_PENDING_CHECKPOINTS) {
+						pendingCheckpoints.remove(0);
+					}
+
+					for (Map.Entry<KafkaTopicPartition, Long> kafkaTopicPartitionLongEntry : restoreToOffset.entrySet()) {
+						listState.add(Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue()));
+					}
+				} else if (subscribedPartitions != null) {
+					for (KafkaTopicPartition subscribedPartition : subscribedPartitions) {
+						listState.add(Tuple2.of(subscribedPartition, KafkaTopicPartitionState.OFFSET_NOT_SET));
+					}
+				}
+			} else {
+				HashMap<KafkaTopicPartition, Long> currentOffsets = fetcher.snapshotCurrentState();
+
+				// the map cannot be asynchronously updated, because only one checkpoint call can happen
+				// on this function at a time: either prepareSnapshot() or notifyCheckpointComplete()
+				pendingCheckpoints.put(checkpointId, currentOffsets);
+
+				// truncate the map, to prevent infinite growth
+				while (pendingCheckpoints.size() > MAX_NUM_PENDING_CHECKPOINTS) {
+					pendingCheckpoints.remove(0);
+				}
+
+				for (Map.Entry<KafkaTopicPartition, Long> kafkaTopicPartitionLongEntry : currentOffsets.entrySet()) {
+					listState.add(Tuple2.of(kafkaTopicPartitionLongEntry.getKey(), kafkaTopicPartitionLongEntry.getValue()));
+				}
+			}
+		}
 	}
 
 	@Override
@@ -401,6 +457,8 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 			SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic,
 			SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
 			StreamingRuntimeContext runtimeContext) throws Exception;
+
+	protected abstract List<KafkaTopicPartition> getKafkaPartitions(List<String> topics);
 	
 	// ------------------------------------------------------------------------
 	//  ResultTypeQueryable methods 
@@ -415,6 +473,35 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	//  Utilities
 	// ------------------------------------------------------------------------
 
+	private void assignTopicPartitions(List<KafkaTopicPartition> kafkaTopicPartitions) {
+		subscribedPartitions = new ArrayList<>();
+
+		if (restoreToOffset != null) {
+			for (KafkaTopicPartition kafkaTopicPartition : kafkaTopicPartitions) {
+				if (restoreToOffset.containsKey(kafkaTopicPartition)) {
+					subscribedPartitions.add(kafkaTopicPartition);
+				}
+			}
+		} else {
+			Collections.sort(kafkaTopicPartitions, new Comparator<KafkaTopicPartition>() {
+				@Override
+				public int compare(KafkaTopicPartition o1, KafkaTopicPartition o2) {
+					int topicComparison = o1.getTopic().compareTo(o2.getTopic());
+
+					if (topicComparison == 0) {
+						return o1.getPartition() - o2.getPartition();
+					} else {
+						return topicComparison;
+					}
+				}
+			});
+
+			for (int i = getRuntimeContext().getIndexOfThisSubtask(); i < kafkaTopicPartitions.size(); i += getRuntimeContext().getNumberOfParallelSubtasks()) {
+				subscribedPartitions.add(kafkaTopicPartitions.get(i));
+			}
+		}
+	}
+
 	/**
 	 * Selects which of the given partitions should be handled by a specific consumer,
 	 * given a certain number of consumers.
@@ -427,8 +514,7 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	 */
 	protected static List<KafkaTopicPartition> assignPartitions(
 			List<KafkaTopicPartition> allPartitions,
-			int numConsumers, int consumerIndex)
-	{
+			int numConsumers, int consumerIndex) {
 		final List<KafkaTopicPartition> thisSubtaskPartitions = new ArrayList<>(
 				allPartitions.size() / numConsumers + 1);
 

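When there is no restored state, assignTopicPartitions above sorts the partitions (by topic, then partition number, so every subtask sees the same order) and picks every n-th one starting at the subtask index. A self-contained illustration of that index arithmetic (the parallelism and partition count are made-up values):

    public class RoundRobinAssignmentDemo {
        public static void main(String[] args) {
            int parallelism = 3;
            int numPartitions = 8;
            for (int subtask = 0; subtask < parallelism; subtask++) {
                StringBuilder assigned = new StringBuilder();
                for (int i = subtask; i < numPartitions; i += parallelism) {
                    assigned.append(i).append(' ');
                }
                // subtask 0 -> 0 3 6, subtask 1 -> 1 4 7, subtask 2 -> 2 5
                System.out.println("subtask " + subtask + " -> " + assigned.toString().trim());
            }
        }
    }
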
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
index e63f033..8b87004 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
@@ -20,16 +20,16 @@ package org.apache.flink.streaming.connectors.kafka;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.util.SerializableObject;
 import org.apache.flink.metrics.MetricGroup;
-import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.runtime.state.OperatorStateStore;
+import org.apache.flink.runtime.util.SerializableObject;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.connectors.kafka.internals.metrics.KafkaMetricWrapper;
 import org.apache.flink.streaming.connectors.kafka.partitioner.KafkaPartitioner;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchema;
 import org.apache.flink.util.NetUtils;
-
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.Producer;
@@ -40,11 +40,9 @@ import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.serialization.ByteArraySerializer;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.Serializable;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
@@ -61,7 +59,7 @@ import static java.util.Objects.requireNonNull;
  *
  * @param <IN> Type of the messages to write into Kafka.
  */
-public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> implements Checkpointed<Serializable> {
+public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> implements CheckpointedFunction {
 
 	private static final Logger LOG = LoggerFactory.getLogger(FlinkKafkaProducerBase.class);
 
@@ -126,6 +124,8 @@ public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> im
 	/** Number of unacknowledged records. */
 	protected long pendingRecords;
 
+	protected OperatorStateStore stateStore;
+
 
 	/**
 	 * The main constructor for creating a FlinkKafkaProducer.
@@ -330,7 +330,12 @@ public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> im
 	protected abstract void flush();
 
 	@Override
-	public Serializable snapshotState(long checkpointId, long checkpointTimestamp) {
+	public void initializeState(OperatorStateStore stateStore) throws Exception {
+		this.stateStore = stateStore;
+	}
+
+	@Override
+	public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
 		if (flushOnCheckpoint) {
 			// flushing is activated: We need to wait until pendingRecords is 0
 			flush();
@@ -341,16 +346,8 @@ public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> im
 				// pending records count is 0. We can now confirm the checkpoint
 			}
 		}
-		// return empty state
-		return null;
-	}
-
-	@Override
-	public void restoreState(Serializable state) {
-		// nothing to do here
 	}
 
-
 	// ----------------------------------- Utilities --------------------------
 
 	protected void checkErroneous() throws Exception {

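With flushOnCheckpoint enabled, prepareSnapshot above must not return until every buffered record has been acknowledged, since the checkpoint would otherwise claim data that is still in flight. The waiting logic itself is elided from this hunk; a hypothetical sketch of its shape (pendingRecordsLock is an assumed name, the actual synchronization object is not shown in the diff):

    if (flushOnCheckpoint) {
        flush();                                     // ask the Kafka producer to send what it has buffered
        synchronized (pendingRecordsLock) {          // assumed lock object, see note above
            while (pendingRecords > 0) {
                pendingRecordsLock.wait(100);        // woken by the producer's ack callback
            }
            // pending records count is 0, the checkpoint can be confirmed
        }
    }
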
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
index 9255445..7ce3a9d 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractFetcher.java
@@ -183,9 +183,7 @@ public abstract class AbstractFetcher<T, KPH> {
 
 		HashMap<KafkaTopicPartition, Long> state = new HashMap<>(allPartitions.length);
 		for (KafkaTopicPartitionState<?> partition : subscribedPartitions()) {
-			if (partition.isOffsetDefined()) {
-				state.put(partition.getKafkaTopicPartition(), partition.getOffset());
-			}
+			state.put(partition.getKafkaTopicPartition(), partition.getOffset());
 		}
 		return state;
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
index b02593c..766a107 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/AtLeastOnceProducerTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.flink.api.java.tuple.Tuple1;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.OperatorStateStore;
 import org.apache.flink.streaming.connectors.kafka.testutils.MockRuntimeContext;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchema;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchemaWrapper;
@@ -37,7 +38,6 @@ import org.junit.Test;
 import scala.concurrent.duration.Deadline;
 import scala.concurrent.duration.FiniteDuration;
 
-import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -45,6 +45,8 @@ import java.util.Properties;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import static org.mockito.Mockito.mock;
+
 /**
  * Test ensuring that the producer is not dropping buffered records
  */
@@ -111,7 +113,7 @@ public class AtLeastOnceProducerTest {
 		Thread threadB = new Thread(confirmer);
 		threadB.start();
 		// this should block:
-		producer.snapshotState(0, 0);
+		producer.prepareSnapshot(0, 0);
 		synchronized (threadA) {
 			threadA.notifyAll(); // just in case, to let the test fail faster
 		}
@@ -130,6 +132,8 @@ public class AtLeastOnceProducerTest {
 
 
 	private static class TestingKafkaProducer<T> extends FlinkKafkaProducerBase<T> {
+		private static final long serialVersionUID = -1759403646061180067L;
+
 		private MockProducer prod;
 		private AtomicBoolean snapshottingFinished;
 
@@ -145,12 +149,11 @@ public class AtLeastOnceProducerTest {
 		}
 
 		@Override
-		public Serializable snapshotState(long checkpointId, long checkpointTimestamp) {
+		public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
 			// call the actual snapshot state
-			Serializable ret = super.snapshotState(checkpointId, checkpointTimestamp);
+			super.prepareSnapshot(checkpointId, timestamp);
 			// notify test that snapshotting has been done
 			snapshottingFinished.set(true);
-			return ret;
 		}
 
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
index 9b517df..fc8b7e9 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
@@ -19,6 +19,11 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.commons.collections.map.LinkedMap;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.OperatorStateStore;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
@@ -26,15 +31,26 @@ import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
 import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
 import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
 import org.apache.flink.util.SerializedValue;
-
 import org.junit.Test;
+import org.mockito.Matchers;
 
 import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class FlinkKafkaConsumerBaseTest {
 
@@ -82,7 +98,13 @@ public class FlinkKafkaConsumerBaseTest {
 		final AbstractFetcher<String, ?> fetcher = mock(AbstractFetcher.class);
 
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(fetcher, new LinkedMap(), false);
-		assertNull(consumer.snapshotState(17L, 23L));
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
+		when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+		consumer.prepareSnapshot(17L, 17L);
+
+		assertFalse(listState.get().iterator().hasNext());
 		consumer.notifyCheckpointComplete(66L);
 	}
 
@@ -91,14 +113,37 @@ public class FlinkKafkaConsumerBaseTest {
 	 */
 	@Test
 	public void checkRestoredCheckpointWhenFetcherNotReady() throws Exception {
-		HashMap<KafkaTopicPartition, Long> restoreState = new HashMap<>();
-		restoreState.put(new KafkaTopicPartition("abc", 13), 16768L);
-		restoreState.put(new KafkaTopicPartition("def", 7), 987654321L);
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> expectedState = new TestingListState<>();
+		expectedState.add(Tuple2.of(new KafkaTopicPartition("abc", 13), 16768L));
+		expectedState.add(Tuple2.of(new KafkaTopicPartition("def", 7), 987654321L));
+
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
 
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new LinkedMap(), true);
-		consumer.restoreState(restoreState);
-		
-		assertEquals(restoreState, consumer.snapshotState(17L, 23L));
+
+		when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(expectedState);
+		consumer.initializeState(operatorStateStore);
+
+		when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+		consumer.prepareSnapshot(17L, 17L);
+
+		Set<Tuple2<KafkaTopicPartition, Long>> expected = new HashSet<>();
+
+		for (Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 : expectedState.get()) {
+			expected.add(kafkaTopicPartitionLongTuple2);
+		}
+
+		int counter = 0;
+
+		for (Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 : listState.get()) {
+			assertTrue(expected.contains(kafkaTopicPartitionLongTuple2));
+			counter++;
+		}
+
+		assertEquals(expected.size(), counter);
 	}
 
 	/**
@@ -107,7 +152,15 @@ public class FlinkKafkaConsumerBaseTest {
 	@Test
 	public void checkRestoredNullCheckpointWhenFetcherNotReady() throws Exception {
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new LinkedMap(), true);
-		assertNull(consumer.snapshotState(17L, 23L));
+
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
+		when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+		consumer.initializeState(operatorStateStore);
+		consumer.prepareSnapshot(17L, 17L);
+
+		assertFalse(listState.get().iterator().hasNext());
 	}
 	
 	@Test
@@ -132,15 +185,40 @@ public class FlinkKafkaConsumerBaseTest {
 	
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(fetcher, pendingCheckpoints, true);
 		assertEquals(0, pendingCheckpoints.size());
-		
+
+		OperatorStateStore backend = mock(OperatorStateStore.class);
+
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState1 = new TestingListState<>();
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState2 = new TestingListState<>();
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState3 = new TestingListState<>();
+
+		when(backend.getPartitionableState(Matchers.any(ListStateDescriptor.class))).
+				thenReturn(listState1, listState1, listState2, listState2, listState3, listState3);
+
+		consumer.initializeState(backend);
+
 		// checkpoint 1
-		HashMap<KafkaTopicPartition, Long> snapshot1 = consumer.snapshotState(138L, 19L);
+		consumer.prepareSnapshot(138L, 138L);
+
+		HashMap<KafkaTopicPartition, Long> snapshot1 = new HashMap<>();
+
+		for (Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 : listState1.get()) {
+			snapshot1.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
+		}
+
 		assertEquals(state1, snapshot1);
 		assertEquals(1, pendingCheckpoints.size());
 		assertEquals(state1, pendingCheckpoints.get(138L));
 
 		// checkpoint 2
-		HashMap<KafkaTopicPartition, Long> snapshot2 = consumer.snapshotState(140L, 1578L);
+		consumer.prepareSnapshot(140L, 140L);
+
+		HashMap<KafkaTopicPartition, Long> snapshot2 = new HashMap<>();
+
+		for (Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 : listState2.get()) {
+			snapshot2.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
+		}
+
 		assertEquals(state2, snapshot2);
 		assertEquals(2, pendingCheckpoints.size());
 		assertEquals(state2, pendingCheckpoints.get(140L));
@@ -151,7 +229,14 @@ public class FlinkKafkaConsumerBaseTest {
 		assertTrue(pendingCheckpoints.containsKey(140L));
 
 		// checkpoint 3
-		HashMap<KafkaTopicPartition, Long> snapshot3 = consumer.snapshotState(141L, 1578L);
+		consumer.prepareSnapshot(141L, 141L);
+
+		HashMap<KafkaTopicPartition, Long> snapshot3 = new HashMap<>();
+
+		for (Tuple2<KafkaTopicPartition, Long> kafkaTopicPartitionLongTuple2 : listState3.get()) {
+			snapshot3.put(kafkaTopicPartitionLongTuple2.f0, kafkaTopicPartitionLongTuple2.f1);
+		}
+
 		assertEquals(state3, snapshot3);
 		assertEquals(2, pendingCheckpoints.size());
 		assertEquals(state3, pendingCheckpoints.get(141L));
@@ -164,9 +249,14 @@ public class FlinkKafkaConsumerBaseTest {
 		consumer.notifyCheckpointComplete(666); // invalid checkpoint
 		assertEquals(0, pendingCheckpoints.size());
 
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
+		when(operatorStateStore.getPartitionableState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
 		// create 500 snapshots
 		for (int i = 100; i < 600; i++) {
-			consumer.snapshotState(i, 15 * i);
+			consumer.prepareSnapshot(i, i);
+			listState.clear();
 		}
 		assertEquals(FlinkKafkaConsumerBase.MAX_NUM_PENDING_CHECKPOINTS, pendingCheckpoints.size());
 
@@ -211,12 +301,37 @@ public class FlinkKafkaConsumerBaseTest {
 
 		@SuppressWarnings("unchecked")
 		public DummyFlinkKafkaConsumer() {
-			super((KeyedDeserializationSchema<T>) mock(KeyedDeserializationSchema.class));
+			super(Arrays.asList("abc", "def"), (KeyedDeserializationSchema<T>) mock(KeyedDeserializationSchema.class));
 		}
 
 		@Override
 		protected AbstractFetcher<T, ?> createFetcher(SourceContext<T> sourceContext, List<KafkaTopicPartition> thisSubtaskPartitions, SerializedValue<AssignerWithPeriodicWatermarks<T>> watermarksPeriodic, SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated, StreamingRuntimeContext runtimeContext) throws Exception {
 			return null;
 		}
+
+		@Override
+		protected List<KafkaTopicPartition> getKafkaPartitions(List<String> topics) {
+			return Collections.emptyList();
+		}
+	}
+
+	private static final class TestingListState<T> implements ListState<T> {
+
+		private final List<T> list = new ArrayList<>();
+
+		@Override
+		public void clear() {
+			list.clear();
+		}
+
+		@Override
+		public Iterable<T> get() throws Exception {
+			return list;
+		}
+
+		@Override
+		public void add(T value) throws Exception {
+			list.add(value);
+		}
 	}
 }

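One Mockito detail in the snapshot test above: thenReturn accepts varargs, so consecutive calls receive the stubbed values in order and the last value repeats once the list is exhausted. That is how successive getPartitionableState calls in the test each observe the intended TestingListState instance:

    // 1st call returns listState1, 2nd listState1, 3rd listState2, ...; the last value repeats
    when(backend.getPartitionableState(Matchers.any(ListStateDescriptor.class)))
            .thenReturn(listState1, listState1, listState2, listState2, listState3, listState3);
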
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
index a87ff8a..9c36b43 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaConsumerTestBase.java
@@ -68,7 +68,6 @@ import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
-import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -92,7 +91,6 @@ import org.apache.flink.test.util.SuccessException;
 import org.apache.flink.testutils.junit.RetryOnException;
 import org.apache.flink.testutils.junit.RetryRule;
 import org.apache.flink.util.Collector;
-import org.apache.flink.util.StringUtils;
 import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.junit.Assert;
@@ -186,15 +184,27 @@ public abstract class KafkaConsumerTestBase extends KafkaTestBase {
 			DataStream<String> stream = see.addSource(source);
 			stream.print();
 			see.execute("No broker test");
-		} catch(RuntimeException re) {
+		} catch (ProgramInvocationException pie) {
 			if(kafkaServer.getVersion().equals("0.9")) {
-				Assert.assertTrue("Wrong RuntimeException thrown: " + StringUtils.stringifyException(re),
-						re.getClass().equals(TimeoutException.class) &&
-								re.getMessage().contains("Timeout expired while fetching topic metadata"));
+				assertTrue(pie.getCause() instanceof JobExecutionException);
+
+				JobExecutionException jee = (JobExecutionException) pie.getCause();
+
+				assertTrue(jee.getCause() instanceof TimeoutException);
+
+				TimeoutException te = (TimeoutException) jee.getCause();
+
+				assertEquals("Timeout expired while fetching topic metadata", te.getMessage());
 			} else {
-				Assert.assertTrue("Wrong RuntimeException thrown: " + StringUtils.stringifyException(re),
-						re.getClass().equals(RuntimeException.class) &&
-								re.getMessage().contains("Unable to retrieve any partitions for the requested topics [doesntexist]"));
+				assertTrue(pie.getCause() instanceof JobExecutionException);
+
+				JobExecutionException jee = (JobExecutionException) pie.getCause();
+
+				assertTrue(jee.getCause() instanceof RuntimeException);
+
+				RuntimeException re = (RuntimeException) jee.getCause();
+
+				assertTrue(re.getMessage().contains("Unable to retrieve any partitions for the requested topics [doesntexist]"));
 			}
 		}
 	}
@@ -413,7 +423,7 @@ public abstract class KafkaConsumerTestBase extends KafkaTestBase {
 		DataGenerators.generateRandomizedIntegerSequence(
 				StreamExecutionEnvironment.createRemoteEnvironment("localhost", flinkPort),
 				kafkaServer,
-				topic, numPartitions, numElementsPerPartition, true);
+				topic, numPartitions, numElementsPerPartition, false);
 
 		// run the topology that fails and recovers
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
index da2c652..5be4195 100644
--- a/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
+++ b/flink-streaming-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/MockRuntimeContext.java
@@ -173,16 +173,6 @@ public class MockRuntimeContext extends StreamingRuntimeContext {
 	}
 
 	@Override
-	public <S> org.apache.flink.api.common.state.OperatorState<S> getKeyValueState(String name, Class<S> stateType, S defaultState) {
-		throw new UnsupportedOperationException();
-	}
-
-	@Override
-	public <S> org.apache.flink.api.common.state.OperatorState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) {
-		throw new UnsupportedOperationException();
-	}
-
-	@Override
 	public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) {
 		throw new UnsupportedOperationException();
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
index 6e2850c..4a0fd60 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/Checkpointed.java
@@ -36,6 +36,7 @@ import java.io.Serializable;
  * 
  * @param <T> The type of the operator state.
  */
+@Deprecated
 @PublicEvolving
 public interface Checkpointed<T extends Serializable> {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
new file mode 100644
index 0000000..2227201
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointedFunction.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.checkpoint;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.runtime.state.OperatorStateStore;
+
+/**
+ *
+ * Similar to {@link Checkpointed}, this interface must be implemented by functions that have potentially
+ * repartitionable state that needs to be checkpointed. Methods from this interface are called upon checkpointing and
+ * restoring of state.
+ *
+ * In #initializeState, the implementing class receives the {@link org.apache.flink.runtime.state.OperatorStateStore}
+ * in which to store its state. Before each snapshot, all persistent state must have been stored in the state store.
+ *
+ * When the backend is received for initialization, the user registers states with it via
+ * {@link org.apache.flink.api.common.state.StateDescriptor}. All previously stored state is then found in the
+ * returned {@link org.apache.flink.api.common.state.State} (currently, only
+ * {@link org.apache.flink.api.common.state.ListState} is supported).
+ *
+ * In #prepareSnapshot, the implementing class must ensure that all of its operator state has been passed to the
+ * operator backend, i.e. that the state was stored in the relevant {@link org.apache.flink.api.common.state.State}
+ * instances that are requested on restore. Note that users might want to clear and reinsert the complete state
+ * first if incremental updates of the states are not possible.
+ */
+@PublicEvolving
+public interface CheckpointedFunction {
+
+	/**
+	 *
+	 * This method is called when state should be stored for a checkpoint. The state can be registered and written to
+	 * the provided backend.
+	 *
+	 * @param checkpointId ID of the checkpoint being performed
+	 * @param timestamp Timestamp of the checkpoint
+	 * @throws Exception Thrown if the snapshot preparation fails; this causes the checkpoint to fail.
+	 */
+	void prepareSnapshot(long checkpointId, long timestamp) throws Exception;
+
+	/**
+	 * This method is called when an operator is opened, so that the function can set the state store to which it
+	 * hands its state on snapshots.
+	 *
+	 * @param stateStore the state store in which this function stores its state
+	 * @throws Exception Thrown if the state initialization fails.
+	 */
+	void initializeState(OperatorStateStore stateStore) throws Exception;
+}
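
For illustration, a minimal sketch of a function implementing this interface: it buffers elements and writes them to a registered ListState before each snapshot. The class name, the "buffered-elements" state name, and the getPartitionableState accessor are assumptions for the example; the exact OperatorStateStore accessor is not shown in this excerpt.

    import java.util.ArrayList;
    import java.util.List;

    import org.apache.flink.api.common.state.ListState;
    import org.apache.flink.api.common.state.ListStateDescriptor;
    import org.apache.flink.api.java.typeutils.runtime.JavaSerializer;
    import org.apache.flink.runtime.state.OperatorStateStore;
    import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;

    public class BufferingFunction implements CheckpointedFunction {

        private final List<Integer> buffer = new ArrayList<>();
        private transient ListState<Integer> checkpointedState;

        @Override
        public void initializeState(OperatorStateStore stateStore) throws Exception {
            // assumed accessor: register the named state; on restore, previously
            // stored elements are already contained in the returned ListState
            checkpointedState = stateStore.getPartitionableState(
                    new ListStateDescriptor<>("buffered-elements", new JavaSerializer<Integer>()));
            for (Integer element : checkpointedState.get()) {
                buffer.add(element);
            }
        }

        @Override
        public void prepareSnapshot(long checkpointId, long timestamp) throws Exception {
            // incremental updates are not tracked here, so clear and reinsert
            // the complete state, as the interface documentation suggests
            checkpointedState.clear();
            for (Integer element : buffer) {
                checkpointedState.add(element);
            }
        }
    }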

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
new file mode 100644
index 0000000..430b2b9
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ListCheckpointed.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.checkpoint;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.java.typeutils.runtime.JavaSerializer;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * This interface must be implemented by functions that have state that needs to be
+ * checkpointed. The functions get a call whenever a checkpoint should take place
+ * and return a snapshot of their state as a list of redistributable sub-states,
+ * which will be checkpointed.
+ *
+ * @param <T> The type of the operator state.
+ */
+@PublicEvolving
+public interface ListCheckpointed<T extends Serializable> {
+
+	ListStateDescriptor<Serializable> DEFAULT_LIST_DESCRIPTOR =
+			new ListStateDescriptor<>("", new JavaSerializer<>());
+
+	/**
+	 * Gets the current state of the function or operator. The state must reflect the result of all
+	 * prior invocations to this function.
+	 *
+	 * @param checkpointId The ID of the checkpoint.
+	 * @param timestamp Timestamp of the checkpoint.
+	 * @return The operator state in a list of redistributable, atomic sub-states.
+	 * @throws Exception Thrown if the creation of the state object failed. This causes the
+	 *                   checkpoint to fail. The system may decide to fail the operation (and trigger
+	 *                   recovery), or to discard this checkpoint attempt and continue running,
+	 *                   trying again with the next checkpoint attempt.
+	 */
+	List<T> snapshotState(long checkpointId, long timestamp) throws Exception;
+
+	/**
+	 * Restores the state of the function or operator to that of a previous checkpoint.
+	 * This method is invoked when a function is executed as part of a recovery run.
+	 * <p>
+	 * Note that restoreState() is called before open().
+	 *
+	 * @param state The state to be restored as a list of atomic sub-states.
+	 */
+	void restoreState(List<T> state) throws Exception;
+}
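
As a worked example, a minimal sketch of a counting function using this interface: the count is snapshotted as a single-element list, and on restore the (possibly redistributed) sub-states are summed. Class and field names are illustrative only.

    import java.util.Collections;
    import java.util.List;

    import org.apache.flink.api.common.functions.RichMapFunction;
    import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;

    public class CountingMapper extends RichMapFunction<String, String>
            implements ListCheckpointed<Long> {

        private long count;

        @Override
        public String map(String value) {
            count++;
            return value;
        }

        @Override
        public List<Long> snapshotState(long checkpointId, long timestamp) {
            // one atomic, redistributable sub-state; a larger state could be
            // split into more sub-states to allow finer-grained rescaling
            return Collections.singletonList(count);
        }

        @Override
        public void restoreState(List<Long> state) {
            // after rescaling, several sub-states may be assigned to this instance
            count = 0;
            for (Long c : state) {
                count += c;
            }
        }
    }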

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
index 0c0b81a..838bee6 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileReaderOperator.java
@@ -28,6 +28,7 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.fs.FileInputSplit;
 import org.apache.flink.metrics.Counter;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.OutputTypeConfigurable;
@@ -60,7 +61,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  */
 @Internal
 public class ContinuousFileReaderOperator<OUT, S extends Serializable> extends AbstractStreamOperator<OUT>
-	implements OneInputStreamOperator<FileInputSplit, OUT>, OutputTypeConfigurable<OUT> {
+	implements OneInputStreamOperator<FileInputSplit, OUT>, OutputTypeConfigurable<OUT>, StreamCheckpointedOperator {
 
 	private static final long serialVersionUID = 1L;
 
@@ -374,7 +375,6 @@ public class ContinuousFileReaderOperator<OUT, S extends Serializable> extends A
 
 	@Override
 	public void snapshotState(FSDataOutputStream os, long checkpointId, long timestamp) throws Exception {
-		super.snapshotState(os, checkpointId, timestamp);
 
 		final ObjectOutputStream oos = new ObjectOutputStream(os);
 
@@ -397,7 +397,6 @@ public class ContinuousFileReaderOperator<OUT, S extends Serializable> extends A
 
 	@Override
 	public void restoreState(FSDataInputStream is) throws Exception {
-		super.restoreState(is);
 
 		final ObjectInputStream ois = new ObjectInputStream(is);
 


[05/10] flink git commit: [FLINK-4379] [checkpoints] Introduce rescalable operator state

Posted by se...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 9adaa86..c39e436 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -20,7 +20,9 @@ package org.apache.flink.runtime.checkpoint;
 
 import com.google.common.collect.Iterables;
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.checkpoint.savepoint.HeapSavepointStore;
 import org.apache.flink.runtime.checkpoint.stats.DisabledCheckpointStatsTracker;
 import org.apache.flink.runtime.execution.ExecutionState;
@@ -34,21 +36,21 @@ import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
-
 import org.junit.Assert;
 import org.junit.Test;
-
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
-
 import scala.concurrent.ExecutionContext;
 import scala.concurrent.Future;
 
@@ -56,6 +58,8 @@ import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
@@ -1459,7 +1463,7 @@ public class CheckpointCoordinatorTest {
 					maxConcurrentAttempts,
 					new ExecutionVertex[] { triggerVertex },
 					new ExecutionVertex[] { ackVertex },
-					new ExecutionVertex[] { commitVertex }, 
+					new ExecutionVertex[] { commitVertex },
 					new StandaloneCheckpointIDCounter(),
 					new StandaloneCompletedCheckpointStore(2, cl),
 					new HeapSavepointStore(),
@@ -1531,7 +1535,7 @@ public class CheckpointCoordinatorTest {
 					maxConcurrentAttempts, // max two concurrent checkpoints
 					new ExecutionVertex[] { triggerVertex },
 					new ExecutionVertex[] { ackVertex },
-					new ExecutionVertex[] { commitVertex }, 
+					new ExecutionVertex[] { commitVertex },
 					new StandaloneCheckpointIDCounter(),
 					new StandaloneCompletedCheckpointStore(2, cl),
 					new HeapSavepointStore(),
@@ -1791,29 +1795,29 @@ public class CheckpointCoordinatorTest {
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID1, index);
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8);
 			List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
 
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-				jid,
-				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-				checkpointId,
-				nonPartitionedState,
-				partitionedKeyGroupState);
+					jid,
+					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
 
-
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID2, index);
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8);
 			List<KeyGroupsStateHandle> partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(nonPartitionedState, partitionableState, partitionedKeyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-				jid,
-				jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-				checkpointId,
-				nonPartitionedState,
-				partitionedKeyGroupState);
+					jid,
+					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -1895,13 +1899,12 @@ public class CheckpointCoordinatorTest {
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
-
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-				jid,
-				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-				checkpointId,
-				valueSizeTuple,
-				keyGroupState);
+					jid,
+					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -1910,13 +1913,12 @@ public class CheckpointCoordinatorTest {
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID2, index);
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-				jid,
-				jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-				checkpointId,
-				valueSizeTuple,
-				keyGroupState);
+					jid,
+					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2014,12 +2016,12 @@ public class CheckpointCoordinatorTest {
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
 					jobVertexID1, keyGroupPartitions1.get(index));
 
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, null, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-				jid,
-				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-				checkpointId,
-				valueSizeTuple,
-				keyGroupState);
+					jid,
+					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2031,12 +2033,12 @@ public class CheckpointCoordinatorTest {
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
 					jobVertexID2, keyGroupPartitions2.get(index));
 
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(state, null, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
-					state,
-					keyGroupState);
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2132,28 +2134,32 @@ public class CheckpointCoordinatorTest {
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
 			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8);
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index));
 
+
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(valueSizeTuple, partitionableState, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
-					valueSizeTuple,
-					keyGroupState);
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
 
 
+		final List<ChainedStateHandle<OperatorStateHandle>> originalPartitionableStates = new ArrayList<>(jobVertex2.getParallelism());
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index));
-
+			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8);
+			originalPartitionableStates.add(partitionableState);
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(null, partitionableState, keyGroupState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
-					null,
-					keyGroupState);
+					checkpointStateHandles);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2185,22 +2191,49 @@ public class CheckpointCoordinatorTest {
 
 		// verify the restored state
 		verifiyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
-
+		List<List<Collection<OperatorStateHandle>>> actualPartitionableStates = new ArrayList<>(newJobVertex2.getParallelism());
 		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
 			List<KeyGroupsStateHandle> originalKeyGroupState = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i));
 
 			ChainedStateHandle<StreamStateHandle> operatorState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
+			List<Collection<OperatorStateHandle>> partitionableState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
 			List<KeyGroupsStateHandle> keyGroupState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
 
+			actualPartitionableStates.add(partitionableState);
 			assertNull(operatorState);
-			comparePartitionedState(originalKeyGroupState, keyGroupState);
+			compareKeyPartitionedState(originalKeyGroupState, keyGroupState);
 		}
+		comparePartitionableState(originalPartitionableStates, actualPartitionableStates);
 	}
 
 	// ------------------------------------------------------------------------
 	//  Utilities
 	// ------------------------------------------------------------------------
 
+	static void sendAckMessageToCoordinator(
+			CheckpointCoordinator coord,
+			long checkpointId, JobID jid,
+			ExecutionJobVertex jobVertex,
+			JobVertexID jobVertexID,
+			List<KeyGroupRange> keyGroupPartitions) throws Exception {
+
+		for (int index = 0; index < jobVertex.getParallelism(); index++) {
+			ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID, index);
+			List<KeyGroupsStateHandle> keyGroupState = generateKeyGroupState(
+					jobVertexID,
+					keyGroupPartitions.get(index));
+
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(state, null, keyGroupState);
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+					jid,
+					jobVertex.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+					checkpointId,
+					checkpointStateHandles);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
+		}
+	}
+
 	public static List<KeyGroupsStateHandle> generateKeyGroupState(
 			JobVertexID jobVertexID,
 			KeyGroupRange keyGroupPartition) throws IOException {
@@ -2217,23 +2250,45 @@ public class CheckpointCoordinatorTest {
 		return generateKeyGroupState(keyGroupPartition, testStatesLists);
 	}
 
-	public static List<KeyGroupsStateHandle> generateKeyGroupState(KeyGroupRange keyGroupRange, List< ? extends Serializable> states) throws IOException {
+	public static List<KeyGroupsStateHandle> generateKeyGroupState(
+			KeyGroupRange keyGroupRange,
+			List<? extends Serializable> states) throws IOException {
+
 		Preconditions.checkArgument(keyGroupRange.getNumberOfKeyGroups() == states.size());
 
-		long[] offsets = new long[keyGroupRange.getNumberOfKeyGroups()];
-		List<byte[]> serializedGroupValues = new ArrayList<>(offsets.length);
+		Tuple2<byte[], List<long[]>> serializedDataWithOffsets =
+				serializeTogetherAndTrackOffsets(Collections.<List<? extends Serializable>>singletonList(states));
+
+		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));
+
+		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
+				serializedDataWithOffsets.f0);
+		KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(
+				keyGroupRangeOffsets,
+				allSerializedStatesHandle);
+		List<KeyGroupsStateHandle> keyGroupsStateHandleList = new ArrayList<>();
+		keyGroupsStateHandleList.add(keyGroupsStateHandle);
+		return keyGroupsStateHandleList;
+	}
+
+	public static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets(
+			List<List<? extends Serializable>> serializables) throws IOException {
 
-		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, offsets);
+		List<long[]> offsets = new ArrayList<>(serializables.size());
+		List<byte[]> serializedGroupValues = new ArrayList<>();
 
 		int runningGroupsOffset = 0;
-		// generate test state for all keygroups
-		int idx = 0;
-		for (int keyGroup : keyGroupRange) {
-			keyGroupRangeOffsets.setKeyGroupOffset(keyGroup,runningGroupsOffset);
-			byte[] serializedValue = InstantiationUtil.serializeObject(states.get(idx));
-			runningGroupsOffset += serializedValue.length;
-			serializedGroupValues.add(serializedValue);
-			++idx;
+		for(List<? extends Serializable> list : serializables) {
+
+			long[] currentOffsets = new long[list.size()];
+			offsets.add(currentOffsets);
+
+			for (int i = 0; i < list.size(); ++i) {
+				currentOffsets[i] = runningGroupsOffset;
+				byte[] serializedValue = InstantiationUtil.serializeObject(list.get(i));
+				serializedGroupValues.add(serializedValue);
+				runningGroupsOffset += serializedValue.length;
+			}
 		}
 
 		//write all generated values in a single byte array, which is indexed by groupOffsetsInFinalByteArray
@@ -2248,15 +2303,7 @@ public class CheckpointCoordinatorTest {
 					serializedGroupValue.length);
 			runningGroupsOffset += serializedGroupValue.length;
 		}
-
-		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
-				allSerializedValuesConcatenated);
-		KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(
-				keyGroupRangeOffsets,
-				allSerializedStatesHandle);
-		List<KeyGroupsStateHandle> keyGroupsStateHandleList = new ArrayList<>();
-		keyGroupsStateHandleList.add(keyGroupsStateHandle);
-		return keyGroupsStateHandleList;
+		return new Tuple2<>(allSerializedValuesConcatenated, offsets);
 	}
 
 	public static ChainedStateHandle<StreamStateHandle> generateStateForVertex(
@@ -2273,6 +2320,55 @@ public class CheckpointCoordinatorTest {
 		return ChainedStateHandle.wrapSingleHandle(ByteStreamStateHandle.fromSerializable(value));
 	}
 
+	public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
+			JobVertexID jobVertexID,
+			int index,
+			int namedStates,
+			int partitionsPerState) throws IOException {
+
+		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
+
+		for (int i = 0; i < namedStates; ++i) {
+			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
+			// generate state
+			Random random = new Random(jobVertexID.hashCode() * index + i * namedStates);
+			for (int j = 0; j < partitionsPerState; ++j) {
+				int simulatedStateValue = random.nextInt();
+				testStatesLists.add(simulatedStateValue);
+			}
+			statesListsMap.put("state-" + i, testStatesLists);
+		}
+
+		return generateChainedPartitionableStateHandle(statesListsMap);
+	}
+
+	public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
+			Map<String, List<? extends Serializable>> states) throws IOException {
+
+		List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size());
+
+		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
+			namedStateSerializables.add(entry.getValue());
+		}
+
+		Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables);
+
+		Map<String, long[]> offsetsMap = new HashMap<>(states.size());
+
+		int idx = 0;
+		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
+			offsetsMap.put(entry.getKey(), serializationWithOffsets.f1.get(idx));
+			++idx;
+		}
+
+		ByteStreamStateHandle streamStateHandle = new ByteStreamStateHandle(
+				serializationWithOffsets.f0);
+
+		OperatorStateHandle operatorStateHandle =
+				new OperatorStateHandle(streamStateHandle, offsetsMap);
+		return ChainedStateHandle.wrapSingleHandle(operatorStateHandle);
+	}
+
 	public static ExecutionJobVertex mockExecutionJobVertex(
 		JobVertexID jobVertexID,
 		int parallelism,
@@ -2348,16 +2444,24 @@ public class CheckpointCoordinatorTest {
 					getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
 			assertEquals(expectNonPartitionedState.get(0), actualNonPartitionedState.get(0));
 
+			ChainedStateHandle<OperatorStateHandle> expectedPartitionableState =
+					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8);
+
+			List<Collection<OperatorStateHandle>> actualPartitionableState = executionJobVertex.
+					getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
+
+			assertEquals(expectedPartitionableState.get(0), actualPartitionableState.get(0).iterator().next());
+
 			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState = generateKeyGroupState(
 					jobVertexID,
 					keyGroupPartitions.get(i));
 			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState = executionJobVertex.
 					getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles();
-			comparePartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState);
+			compareKeyPartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState);
 		}
 	}
 
-	public static void comparePartitionedState(
+	public static void compareKeyPartitionedState(
 			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
 			List<KeyGroupsStateHandle> actualPartitionedKeyGroupState) throws Exception {
 
@@ -2370,22 +2474,68 @@ public class CheckpointCoordinatorTest {
 
 		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
 
-		FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.getStateHandle().openInputStream();
-		for(int groupId : expectedHeadOpKeyGroupStateHandle.keyGroups()) {
-			long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-			inputStream.seek(offset);
-			int expectedKeyGroupState = InstantiationUtil.deserializeObject(inputStream);
-			for(KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) {
-				if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
-					long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-					FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.getStateHandle().openInputStream();
-					actualInputStream.seek(actualOffset);
-					int actualGroupState = InstantiationUtil.deserializeObject(actualInputStream);
-
-					assertEquals(expectedKeyGroupState, actualGroupState);
+		try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.getStateHandle().openInputStream()) {
+			for (int groupId : expectedHeadOpKeyGroupStateHandle.keyGroups()) {
+				long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+				inputStream.seek(offset);
+				int expectedKeyGroupState = InstantiationUtil.deserializeObject(inputStream);
+				for (KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) {
+					if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) {
+						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+						try (FSDataInputStream actualInputStream =
+								     oneActualKeyGroupStateHandle.getStateHandle().openInputStream()) {
+							actualInputStream.seek(actualOffset);
+							int actualGroupState = InstantiationUtil.deserializeObject(actualInputStream);
+							assertEquals(expectedKeyGroupState, actualGroupState);
+						}
+					}
+				}
+			}
+		}
+	}
+
+	public static void comparePartitionableState(
+			List<ChainedStateHandle<OperatorStateHandle>> expected,
+			List<List<Collection<OperatorStateHandle>>> actual) throws Exception {
+
+		List<String> expectedResult = new ArrayList<>();
+		for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) {
+			for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
+				OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i);
+				try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
+					for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+						for (long offset : entry.getValue()) {
+							in.seek(offset);
+							Integer state = InstantiationUtil.deserializeObject(in);
+							expectedResult.add(i + " : " + entry.getKey() + " : " + state);
+						}
+					}
 				}
 			}
 		}
+		Collections.sort(expectedResult);
+
+		List<String> actualResult = new ArrayList<>();
+		for (List<Collection<OperatorStateHandle>> collectionList : actual) {
+			if (collectionList != null) {
+				for (int i = 0; i < collectionList.size(); ++i) {
+					Collection<OperatorStateHandle> stateHandles = collectionList.get(i);
+					for (OperatorStateHandle operatorStateHandle : stateHandles) {
+						try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
+							for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+								for (long offset : entry.getValue()) {
+									in.seek(offset);
+									Integer state = InstantiationUtil.deserializeObject(in);
+									actualResult.add(i + " : " + entry.getKey() + " : " + state);
+								}
+							}
+						}
+					}
+				}
+			}
+		}
+		Collections.sort(actualResult);
+		Assert.assertEquals(expectedResult, actualResult);
 	}
 
 	@Test
@@ -2415,4 +2565,117 @@ public class CheckpointCoordinatorTest {
 		}
 	}
 
+
+	@Test
+	public void testPartitionableStateRepartitioning() {
+		Random r = new Random(42);
+
+		for (int run = 0; run < 10000; ++run) {
+			int oldParallelism = 1 + r.nextInt(9);
+			int newParallelism = 1 + r.nextInt(9);
+
+			int numNamedStates = 1 + r.nextInt(9);
+			int maxPartitionsPerState = 1 + r.nextInt(9);
+
+			doTestPartitionableStateRepartitioning(
+					r, oldParallelism, newParallelism, numNamedStates, maxPartitionsPerState);
+		}
+	}
+
+	private void doTestPartitionableStateRepartitioning(
+			Random r, int oldParallelism, int newParallelism, int numNamedStates, int maxPartitionsPerState) {
+
+		List<OperatorStateHandle> previousParallelOpInstanceStates = new ArrayList<>(oldParallelism);
+
+		for (int i = 0; i < oldParallelism; ++i) {
+			Path fakePath = new Path("/fake-" + i);
+			Map<String, long[]> namedStatesToOffsets = new HashMap<>();
+			int off = 0;
+			for (int s = 0; s < numNamedStates; ++s) {
+				long[] offs = new long[1 + r.nextInt(maxPartitionsPerState)];
+				if (offs.length > 0) {
+					for (int o = 0; o < offs.length; ++o) {
+						offs[o] = off;
+						++off;
+					}
+					namedStatesToOffsets.put("State-" + s, offs);
+				}
+			}
+
+			previousParallelOpInstanceStates.add(
+					new OperatorStateHandle(new FileStateHandle(fakePath, -1), namedStatesToOffsets));
+		}
+
+		Map<StreamStateHandle, Map<String, List<Long>>> expected = new HashMap<>();
+
+		int expectedTotalPartitions = 0;
+		for (OperatorStateHandle psh : previousParallelOpInstanceStates) {
+			Map<String, long[]> offsMap = psh.getStateNameToPartitionOffsets();
+			Map<String, List<Long>> offsMapWithList = new HashMap<>(offsMap.size());
+			for (Map.Entry<String, long[]> e : offsMap.entrySet()) {
+				long[] offs = e.getValue();
+				expectedTotalPartitions += offs.length;
+				List<Long> offsList = new ArrayList<>(offs.length);
+				for (int i = 0; i < offs.length; ++i) {
+					offsList.add(i, offs[i]);
+				}
+				offsMapWithList.put(e.getKey(), offsList);
+			}
+			expected.put(psh.getDelegateStateHandle(), offsMapWithList);
+		}
+
+		OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
+
+		List<Collection<OperatorStateHandle>> pshs =
+				repartitioner.repartitionState(previousParallelOpInstanceStates, newParallelism);
+
+		Map<StreamStateHandle, Map<String, List<Long>>> actual = new HashMap<>();
+
+		int minCount = Integer.MAX_VALUE;
+		int maxCount = 0;
+		int actualTotalPartitions = 0;
+		for (int p = 0; p < newParallelism; ++p) {
+			int partitionCount = 0;
+
+			Collection<OperatorStateHandle> pshc = pshs.get(p);
+			for (OperatorStateHandle sh : pshc) {
+				for (Map.Entry<String, long[]> namedState : sh.getStateNameToPartitionOffsets().entrySet()) {
+
+					Map<String, List<Long>> x = actual.get(sh.getDelegateStateHandle());
+					if (x == null) {
+						x = new HashMap<>();
+						actual.put(sh.getDelegateStateHandle(), x);
+					}
+
+					List<Long> actualOffs = x.get(namedState.getKey());
+					if (actualOffs == null) {
+						actualOffs = new ArrayList<>();
+						x.put(namedState.getKey(), actualOffs);
+					}
+					long[] add = namedState.getValue();
+					for (int i = 0; i < add.length; ++i) {
+						actualOffs.add(add[i]);
+					}
+
+					partitionCount += namedState.getValue().length;
+				}
+			}
+
+			minCount = Math.min(minCount, partitionCount);
+			maxCount = Math.max(maxCount, partitionCount);
+			actualTotalPartitions += partitionCount;
+		}
+
+		for (Map<String, List<Long>> v : actual.values()) {
+			for (List<Long> l : v.values()) {
+				Collections.sort(l);
+			}
+		}
+
+		int maxLoadDiff = maxCount - minCount;
+		Assert.assertTrue("Difference in partition load is > 1 : " + maxLoadDiff, maxLoadDiff <= 1);
+		Assert.assertEquals(expectedTotalPartitions, actualTotalPartitions);
+		Assert.assertEquals(expected, actual);
+	}
+
 }
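
The repartitioning test above pins down the round-robin contract: the union of all named-state partitions must be preserved, and the per-subtask partition counts may differ by at most one. A minimal sketch of that distribution idea (not the actual RoundRobinOperatorStateRepartitioner, which additionally tracks the owning stream handles and state names):

    import java.util.ArrayList;
    import java.util.List;

    public class RoundRobinSketch {

        // Distribute partitions over newParallelism buckets so that bucket
        // sizes differ by at most one -- the invariant the test asserts.
        public static <T> List<List<T>> roundRobin(List<T> partitions, int newParallelism) {
            List<List<T>> buckets = new ArrayList<>(newParallelism);
            for (int i = 0; i < newParallelism; i++) {
                buckets.add(new ArrayList<T>());
            }
            for (int i = 0; i < partitions.size(); i++) {
                buckets.get(i % newParallelism).add(partitions.get(i));
            }
            return buckets;
        }
    }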

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index a4896aa..bb78b6a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -29,14 +29,18 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.util.SerializableObject;
-
+import org.hamcrest.BaseMatcher;
+import org.hamcrest.Description;
 import org.junit.Test;
 import org.mockito.Mockito;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -112,9 +116,11 @@ public class CheckpointStateRestoreTest {
 			PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next();
 			final long checkpointId = pending.getCheckpointId();
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates));
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates);
+
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, checkpointStateHandles));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, checkpointStateHandles));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, checkpointStateHandles));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));
 
@@ -125,11 +131,27 @@ public class CheckpointStateRestoreTest {
 			coord.restoreLatestCheckpointedState(map, true, false);
 
 			// verify that each stateful vertex got the state
-			verify(statefulExec1, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<List<KeyGroupsStateHandle>>any());
-			verify(statefulExec2, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<List<KeyGroupsStateHandle>>any());
-			verify(statefulExec3, times(1)).setInitialState(Mockito.eq(serializedState), Mockito.<List<KeyGroupsStateHandle>>any());
-			verify(statelessExec1, times(0)).setInitialState(Mockito.<ChainedStateHandle<StreamStateHandle>>any(), Mockito.<List<KeyGroupsStateHandle>>any());
-			verify(statelessExec2, times(0)).setInitialState(Mockito.<ChainedStateHandle<StreamStateHandle>>any(), Mockito.<List<KeyGroupsStateHandle>>any());
+
+			BaseMatcher<CheckpointStateHandles> matcher = new BaseMatcher<CheckpointStateHandles>() {
+				@Override
+				public boolean matches(Object o) {
+					if (o instanceof CheckpointStateHandles) {
+						return ((CheckpointStateHandles) o).getNonPartitionedStateHandles().equals(serializedState);
+					}
+					return false;
+				}
+
+				@Override
+				public void describeTo(Description description) {
+					description.appendValue(serializedState);
+				}
+			};
+
+			verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
+			verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
+			verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher), Mockito.<List<Collection<OperatorStateHandle>>>any());
+			verify(statelessExec1, times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), Mockito.<List<Collection<OperatorStateHandle>>>any());
+			verify(statelessExec2, times(0)).setInitialState(Mockito.<CheckpointStateHandles>any(), Mockito.<List<Collection<OperatorStateHandle>>>any());
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -193,9 +215,11 @@ public class CheckpointStateRestoreTest {
 			final long checkpointId = pending.getCheckpointId();
 
 			// the difference to the test "testSetState" is that one stateful subtask does not report state
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates));
+			CheckpointStateHandles checkpointStateHandles = new CheckpointStateHandles(serializedState, null, serializedKeyGroupStates);
+
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, checkpointStateHandles));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, serializedState, serializedKeyGroupStates));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, checkpointStateHandles));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index 6182ffd..289f5c3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -197,7 +197,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		JobVertexID jvid = new JobVertexID();
 
 		Map<JobVertexID, TaskState> taskGroupStates = new HashMap<>();
-		TaskState taskState = new TaskState(jvid, numberOfStates, numberOfStates);
+		TaskState taskState = new TaskState(jvid, numberOfStates, numberOfStates, 1);
 		taskGroupStates.put(jvid, taskState);
 
 		for (int i = 0; i < numberOfStates; i++) {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
index fd4e02d..b8126e9 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
@@ -106,7 +106,7 @@ public class PendingCheckpointTest {
 		PendingCheckpoint pending = createPendingCheckpoint();
 		PendingCheckpointTest.setTaskState(pending, state);
 
-		pending.acknowledgeTask(ATTEMPT_ID, null, null);
+		pending.acknowledgeTask(ATTEMPT_ID, null);
 
 		CompletedCheckpoint checkpoint = pending.finalizeCheckpoint();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
index 7258545..3701359 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingSavepointTest.java
@@ -117,7 +117,7 @@ public class PendingSavepointTest {
 
 		Future<String> future = pending.getCompletionFuture();
 
-		pending.acknowledgeTask(ATTEMPT_ID, null, null);
+		pending.acknowledgeTask(ATTEMPT_ID, null);
 
 		CompletedCheckpoint checkpoint = pending.finalizeCheckpoint();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
index 6a8d072..9fbe574 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
@@ -186,10 +186,5 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		public long getStateSize() throws IOException {
 			return 0;
 		}
-
-		@Override
-		public void close() throws IOException {
-			
-		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
index ef10032..c82be18 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
@@ -24,6 +24,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.junit.Test;
@@ -32,7 +33,9 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.ThreadLocalRandom;
 
 import static org.junit.Assert.assertEquals;
@@ -67,17 +70,30 @@ public class SavepointV1Test {
 		List<TaskState> taskStates = new ArrayList<>(numTaskStates);
 
 		for (int i = 0; i < numTaskStates; i++) {
-			TaskState taskState = new TaskState(new JobVertexID(), numSubtaskStates, numSubtaskStates);
+			TaskState taskState = new TaskState(new JobVertexID(), numSubtaskStates, numSubtaskStates, 1);
 			for (int j = 0; j < numSubtaskStates; j++) {
 				StreamStateHandle stateHandle = new ByteStreamStateHandle("Hello".getBytes());
 				taskState.putState(i, new SubtaskState(
 						new ChainedStateHandle<>(Collections.singletonList(stateHandle)), 0));
+
+				stateHandle = new ByteStreamStateHandle("Beautiful".getBytes());
+				Map<String, long[]> offsetsMap = new HashMap<>();
+				offsetsMap.put("A", new long[]{0, 10, 20});
+				offsetsMap.put("B", new long[]{30, 40, 50});
+
+				OperatorStateHandle operatorStateHandle =
+						new OperatorStateHandle(stateHandle, offsetsMap);
+
+				taskState.putPartitionableState(
+						i,
+						new ChainedStateHandle<OperatorStateHandle>(
+								Collections.singletonList(operatorStateHandle)));
 			}
 
 			taskState.putKeyedState(
 					0,
 					new KeyGroupsStateHandle(
-							new KeyGroupRangeOffsets(1,1, new long[] {42}), new ByteStreamStateHandle("Hello".getBytes())));
+							new KeyGroupRangeOffsets(1,1, new long[] {42}), new ByteStreamStateHandle("World".getBytes())));
 
 			taskStates.add(taskState);
 		}
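
The offsets map in this test illustrates the OperatorStateHandle layout: each named state maps to the byte offsets of its partitions inside the single underlying stream handle. A sketch of reading such a handle back, mirroring comparePartitionableState above and assuming JDK-serialized partition payloads; the helper class and method names are hypothetical:

    import java.io.Serializable;
    import java.util.Map;

    import org.apache.flink.core.fs.FSDataInputStream;
    import org.apache.flink.runtime.state.OperatorStateHandle;
    import org.apache.flink.util.InstantiationUtil;

    public final class OperatorStateReadSketch {

        static void readAllPartitions(OperatorStateHandle handle) throws Exception {
            try (FSDataInputStream in = handle.openInputStream()) {
                for (Map.Entry<String, long[]> entry :
                        handle.getStateNameToPartitionOffsets().entrySet()) {
                    for (long offset : entry.getValue()) {
                        in.seek(offset);  // jump to one partition of this named state
                        Serializable partition = InstantiationUtil.deserializeObject(in);
                        // each partition is an atomic, redistributable unit of state
                    }
                }
            }
        }
    }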

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
index 504143b..1e95732 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/stats/SimpleCheckpointStatsTrackerTest.java
@@ -319,7 +319,7 @@ public class SimpleCheckpointStatsTrackerTest {
 				JobVertexID operatorId = operatorIds[operatorIndex];
 				int parallelism = operatorParallelism[operatorIndex];
 
-				TaskState taskState = new TaskState(operatorId, parallelism, maxParallelism);
+				TaskState taskState = new TaskState(operatorId, parallelism, maxParallelism, 1);
 
 				taskGroupStates.put(operatorId, taskState);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index ef8e3bd..9b12cac 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -26,6 +26,7 @@ import akka.testkit.JavaTestKit;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.runtime.akka.AkkaUtils;
 import org.apache.flink.runtime.akka.ListeningBehaviour;
 import org.apache.flink.runtime.blob.BlobServer;
@@ -54,9 +55,11 @@ import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService;
 import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService;
 import org.apache.flink.runtime.messages.JobManagerMessages;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.RetrievableStreamStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManager;
 import org.apache.flink.runtime.testingUtils.TestingJobManager;
@@ -80,6 +83,7 @@ import scala.concurrent.duration.FiniteDuration;
 
 import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -441,10 +445,15 @@ public class JobManagerHARecoveryTest {
 		private int completedCheckpoints = 0;
 
 		@Override
-		public void setInitialState(ChainedStateHandle<StreamStateHandle> chainedState, List<KeyGroupsStateHandle> keyGroupsState) throws Exception {
+		public void setInitialState(
+			ChainedStateHandle<StreamStateHandle> chainedState,
+			List<KeyGroupsStateHandle> keyGroupsState,
+			List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
 			int subtaskIndex = getIndexInSubtaskGroup();
 			if (subtaskIndex < recoveredStates.length) {
-				recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(chainedState.get(0).openInputStream());
+				try (FSDataInputStream in = chainedState.get(0).openInputStream()) {
+					recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in);
+				}
 			}
 		}
 
@@ -456,11 +465,12 @@ public class JobManagerHARecoveryTest {
 
 				RetrievableStreamStateHandle<Long> state = new RetrievableStreamStateHandle<Long>(byteStreamStateHandle);
 				ChainedStateHandle<StreamStateHandle> chainedStateHandle = new ChainedStateHandle<StreamStateHandle>(Collections.singletonList(state));
+				CheckpointStateHandles checkpointStateHandles =
+						new CheckpointStateHandles(chainedStateHandle, null, Collections.<KeyGroupsStateHandle>emptyList());
 
 				getEnvironment().acknowledgeCheckpoint(
 						checkpointId,
-						chainedStateHandle,
-						Collections.<KeyGroupsStateHandle>emptyList(),
+						checkpointStateHandles,
 						0L, 0L, 0L, 0L);
 				return true;
 			} catch (Exception ex) {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index 6a6ac64..4873335 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -23,11 +23,12 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.junit.Test;
 
@@ -65,13 +66,17 @@ public class CheckpointMessagesTest {
 
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42);
 
+			CheckpointStateHandles checkpointStateHandles =
+					new CheckpointStateHandles(
+							CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()),
+							CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8),
+							CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())));
+
 			AcknowledgeCheckpoint withState = new AcknowledgeCheckpoint(
 					new JobID(),
 					new ExecutionAttemptID(),
 					87658976143L,
-					CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()),
-					CheckpointCoordinatorTest.generateKeyGroupState(
-							keyGroupRange, Collections.singletonList(new MyHandle())));
+					checkpointStateHandles);
 
 			testSerializabilityEqualsHashCode(noState);
 			testSerializabilityEqualsHashCode(withState);
@@ -83,7 +88,6 @@ public class CheckpointMessagesTest {
 
 	private static void testSerializabilityEqualsHashCode(Serializable o) throws IOException {
 		Object copy = CommonTestUtils.createCopySerializable(o);
-		System.out.println(o.getClass() +" "+copy.getClass());
 		assertEquals(o, copy);
 		assertEquals(o.hashCode(), copy.hashCode());
 		assertNotNull(o.toString());
@@ -117,9 +121,6 @@ public class CheckpointMessagesTest {
 		}
 
 		@Override
-		public void close() throws IOException {}
-
-		@Override
 		public FSDataInputStream openInputStream() throws IOException {
 			return null;
 		}

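To make the message-level change above easier to follow: the different kinds of state handles now travel inside one CheckpointStateHandles container, in the argument order used by the test (chained non-partitioned state, partitionable operator state, keyed state). A hedged sketch based only on the constructor calls in this hunk (the helper class AckMessages is invented for illustration):

import java.util.Collections;

import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
import org.apache.flink.runtime.state.ChainedStateHandle;
import org.apache.flink.runtime.state.CheckpointStateHandles;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;

// Illustrative helper: acknowledges a checkpoint for a task that carries only
// chained (non-partitioned) state.
public final class AckMessages {

	private AckMessages() {}

	public static AcknowledgeCheckpoint ackWithChainedState(
			JobID jobId,
			ExecutionAttemptID attemptId,
			long checkpointId,
			ChainedStateHandle<StreamStateHandle> chainedState) {

		CheckpointStateHandles handles = new CheckpointStateHandles(
				chainedState,
				null, // no partitionable operator state in this sketch
				Collections.<KeyGroupsStateHandle>emptyList());

		return new AcknowledgeCheckpoint(jobId, attemptId, checkpointId, handles);
	}
}
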
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index a857d1b..c855230 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
@@ -162,7 +163,7 @@ public class DummyEnvironment implements Environment {
 	@Override
 	public void acknowledgeCheckpoint(
 			long checkpointId,
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle, List<KeyGroupsStateHandle> keyGroupStateHandles,
+			CheckpointStateHandles checkpointStateHandles,
 			long synchronousDurationMillis, long asynchronousDurationMillis,
 			long bytesBufferedInAlignment, long alignmentDurationNanos) {
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index 75e88eb..c3ed6c0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -46,6 +46,7 @@ import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
@@ -323,7 +324,7 @@ public class MockEnvironment implements Environment {
 	@Override
 	public void acknowledgeCheckpoint(
 			long checkpointId,
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle, List<KeyGroupsStateHandle> keyGroupStateHandles,
+			CheckpointStateHandles checkpointStateHandles,
 			long synchronousDurationMillis, long asynchronousDurationMillis,
 			long bytesBufferedInAlignment, long alignmentDurationNanos) {
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
index 1039568..4279635 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
@@ -32,8 +32,8 @@ import org.apache.flink.runtime.query.netty.KvStateClient;
 import org.apache.flink.runtime.query.netty.KvStateServer;
 import org.apache.flink.runtime.query.netty.UnknownKvStateID;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.heap.HeapValueState;
@@ -246,7 +246,7 @@ public class QueryableStateClientTest {
 		MemoryStateBackend backend = new MemoryStateBackend();
 		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 
-		KeyedStateBackend<Integer> keyedStateBackend = backend.createKeyedStateBackend(dummyEnv,
+		AbstractKeyedStateBackend<Integer> keyedStateBackend = backend.createKeyedStateBackend(dummyEnv,
 				new JobID(),
 				"test_op",
 				IntSerializer.INSTANCE,

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
index c8fb4bb..0db8b31 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
@@ -41,9 +41,9 @@ import org.apache.flink.runtime.query.KvStateServerAddress;
 import org.apache.flink.runtime.query.netty.message.KvStateRequest;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.KvState;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
@@ -538,7 +538,8 @@ public class KvStateClientTest {
 		KvStateRegistry dummyRegistry = new KvStateRegistry();
 		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 		dummyEnv.setKvStateRegistry(dummyRegistry);
-		KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+
+		AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
 				dummyEnv,
 				new JobID(),
 				"test_op",

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
index 7e6d713..ed4a822 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
@@ -38,6 +38,7 @@ import org.apache.flink.runtime.query.netty.message.KvStateRequestFailure;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestResult;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateBackend;
@@ -92,7 +93,7 @@ public class KvStateServerHandlerTest {
 		AbstractStateBackend abstractBackend = new MemoryStateBackend();
 		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 		dummyEnv.setKvStateRegistry(registry);
-		KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+		AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
 				dummyEnv,
 				new JobID(),
 				"test_op",
@@ -490,7 +491,7 @@ public class KvStateServerHandlerTest {
 		AbstractStateBackend abstractBackend = new MemoryStateBackend();
 		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 		dummyEnv.setKvStateRegistry(registry);
-		KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+		AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
 				dummyEnv,
 				new JobID(),
 				"test_op",
@@ -586,7 +587,7 @@ public class KvStateServerHandlerTest {
 		AbstractStateBackend abstractBackend = new MemoryStateBackend();
 		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 		dummyEnv.setKvStateRegistry(registry);
-		KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+		AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
 				dummyEnv,
 				new JobID(),
 				"test_op",

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
index e92fb10..b1c4a9f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
@@ -41,9 +41,9 @@ import org.apache.flink.runtime.query.KvStateServerAddress;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestResult;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
@@ -91,7 +91,7 @@ public class KvStateServerTest {
 			AbstractStateBackend abstractBackend = new MemoryStateBackend();
 			DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
 			dummyEnv.setKvStateRegistry(registry);
-			KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+			AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
 					dummyEnv,
 					new JobID(),
 					"test_op",

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
deleted file mode 100644
index e613105..0000000
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AbstractCloseableHandleTest.java
+++ /dev/null
@@ -1,97 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state;
-
-import org.junit.Test;
-
-import java.io.Closeable;
-import java.io.IOException;
-
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
-
-public class AbstractCloseableHandleTest {
-
-	@Test
-	public void testRegisterThenClose() throws Exception {
-		Closeable closeable = mock(Closeable.class);
-
-		AbstractCloseableHandle handle = new CloseableHandle();
-		assertFalse(handle.isClosed());
-
-		// no immediate closing
-		handle.registerCloseable(closeable);
-		verify(closeable, times(0)).close();
-		assertFalse(handle.isClosed());
-
-		// close forwarded once
-		handle.close();
-		verify(closeable, times(1)).close();
-		assertTrue(handle.isClosed());
-
-		// no repeated closing
-		handle.close();
-		verify(closeable, times(1)).close();
-		assertTrue(handle.isClosed());
-	}
-
-	@Test
-	public void testCloseThenRegister() throws Exception {
-		Closeable closeable = mock(Closeable.class);
-
-		AbstractCloseableHandle handle = new CloseableHandle();
-		assertFalse(handle.isClosed());
-
-		// close the handle before setting the closeable
-		handle.close();
-		assertTrue(handle.isClosed());
-
-		// immediate closing
-		try {
-			handle.registerCloseable(closeable);
-			fail("this should throw an excepion");
-		} catch (IOException e) {
-			// expected
-			assertTrue(e.getMessage().contains("closed"));
-		}
-
-		// should still have called "close" on the Closeable
-		verify(closeable, times(1)).close();
-		assertTrue(handle.isClosed());
-
-		// no repeated closing
-		handle.close();
-		verify(closeable, times(1)).close();
-		assertTrue(handle.isClosed());
-	}
-
-	// ------------------------------------------------------------------------
-
-	private static final class CloseableHandle extends AbstractCloseableHandle {
-		private static final long serialVersionUID = 1L;
-
-		@Override
-		public void discardState() {}
-
-		@Override
-		public long getStateSize() {
-			return 0;
-		}
-	}
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
index bc0b9c3..0b04ebc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
@@ -20,16 +20,11 @@ package org.apache.flink.runtime.state;
 
 import org.apache.commons.io.FileUtils;
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.testutils.CommonTestUtils;
-import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
-
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
-
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -39,9 +34,12 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.net.URI;
 import java.util.Random;
-import java.util.UUID;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 
@@ -188,18 +186,21 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
 	}
 
 	private static void validateBytesInStream(InputStream is, byte[] data) throws IOException {
-		byte[] holder = new byte[data.length];
+		try {
+			byte[] holder = new byte[data.length];
 
-		int pos = 0;
-		int read;
-		while (pos < holder.length && (read = is.read(holder, pos, holder.length - pos)) != -1) {
-			pos += read;
-		}
+			int pos = 0;
+			int read;
+			while (pos < holder.length && (read = is.read(holder, pos, holder.length - pos)) != -1) {
+				pos += read;
+			}
 
-		assertEquals("not enough data", holder.length, pos);
-		assertEquals("too much data", -1, is.read());
-		assertArrayEquals("wrong data", data, holder);
-		is.close();
+			assertEquals("not enough data", holder.length, pos);
+			assertEquals("too much data", -1, is.read());
+			assertArrayEquals("wrong data", data, holder);
+		} finally {
+			is.close();
+		}
 	}
 
 	@Test

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
index 944938b..ac6adff 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
@@ -19,8 +19,6 @@
 package org.apache.flink.runtime.state;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.junit.Test;
 
@@ -29,7 +27,10 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.util.HashMap;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 /**
  * Tests for the {@link org.apache.flink.runtime.state.memory.MemoryStateBackend}.
@@ -105,10 +106,10 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 
 			assertNotNull(handle);
 
-			ObjectInputStream ois = new ObjectInputStream(handle.openInputStream());
-			assertEquals(state, ois.readObject());
-			assertTrue(ois.available() <= 0);
-			ois.close();
+			try (ObjectInputStream ois = new ObjectInputStream(handle.openInputStream())) {
+				assertEquals(state, ois.readObject());
+				assertTrue(ois.available() <= 0);
+			}
 		}
 		catch (Exception e) {
 			e.printStackTrace();

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
new file mode 100644
index 0000000..56c8987
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.java.typeutils.runtime.JavaSerializer;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.Iterator;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+public class OperatorStateBackendTest {
+
+	AbstractStateBackend abstractStateBackend = new MemoryStateBackend(1024);
+
+	private OperatorStateBackend createNewOperatorStateBackend() throws Exception {
+		return abstractStateBackend.createOperatorStateBackend(null, "test-operator");
+	}
+
+	@Test
+	public void testCreateNew() throws Exception {
+		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
+		assertNotNull(operatorStateBackend);
+		assertTrue(operatorStateBackend.getRegisteredStateNames().isEmpty());
+	}
+
+	@Test
+	public void testRegisterStates() throws Exception {
+		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
+		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
+		ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
+		ListState<Serializable> listState1 = operatorStateBackend.getPartitionableState(stateDescriptor1);
+		assertNotNull(listState1);
+		assertEquals(1, operatorStateBackend.getRegisteredStateNames().size());
+		Iterator<Serializable> it = listState1.get().iterator();
+		assertTrue(!it.hasNext());
+		listState1.add(42);
+		listState1.add(4711);
+
+		it = listState1.get().iterator();
+		assertEquals(42, it.next());
+		assertEquals(4711, it.next());
+		assertTrue(!it.hasNext());
+
+		ListState<Serializable> listState2 = operatorStateBackend.getPartitionableState(stateDescriptor2);
+		assertNotNull(listState2);
+		assertEquals(2, operatorStateBackend.getRegisteredStateNames().size());
+		assertTrue(!it.hasNext());
+		listState2.add(7);
+		listState2.add(13);
+		listState2.add(23);
+
+		it = listState2.get().iterator();
+		assertEquals(7, it.next());
+		assertEquals(13, it.next());
+		assertEquals(23, it.next());
+		assertTrue(!it.hasNext());
+
+		ListState<Serializable> listState1b = operatorStateBackend.getPartitionableState(stateDescriptor1);
+		assertNotNull(listState1b);
+		listState1b.add(123);
+		it = listState1b.get().iterator();
+		assertEquals(42, it.next());
+		assertEquals(4711, it.next());
+		assertEquals(123, it.next());
+		assertTrue(!it.hasNext());
+
+		it = listState1.get().iterator();
+		assertEquals(42, it.next());
+		assertEquals(4711, it.next());
+		assertEquals(123, it.next());
+		assertTrue(!it.hasNext());
+
+		it = listState1b.get().iterator();
+		assertEquals(42, it.next());
+		assertEquals(4711, it.next());
+		assertEquals(123, it.next());
+		assertTrue(!it.hasNext());
+	}
+
+	@Test
+	public void testSnapshotRestore() throws Exception {
+		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
+		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
+		ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
+		ListState<Serializable> listState1 = operatorStateBackend.getPartitionableState(stateDescriptor1);
+		ListState<Serializable> listState2 = operatorStateBackend.getPartitionableState(stateDescriptor2);
+
+		listState1.add(42);
+		listState1.add(4711);
+
+		listState2.add(7);
+		listState2.add(13);
+		listState2.add(23);
+
+		CheckpointStreamFactory streamFactory = abstractStateBackend.createStreamFactory(new JobID(), "testOperator");
+		OperatorStateHandle stateHandle = operatorStateBackend.snapshot(1, 1, streamFactory).get();
+
+		try {
+
+			operatorStateBackend.dispose();
+
+			operatorStateBackend = abstractStateBackend.
+					restoreOperatorStateBackend(null, "testOperator", Collections.singletonList(stateHandle));
+
+			assertEquals(0, operatorStateBackend.getRegisteredStateNames().size());
+
+			listState1 = operatorStateBackend.getPartitionableState(stateDescriptor1);
+			listState2 = operatorStateBackend.getPartitionableState(stateDescriptor2);
+
+			assertEquals(2, operatorStateBackend.getRegisteredStateNames().size());
+
+
+			Iterator<Serializable> it = listState1.get().iterator();
+			assertEquals(42, it.next());
+			assertEquals(4711, it.next());
+			assertTrue(!it.hasNext());
+
+			it = listState2.get().iterator();
+			assertEquals(7, it.next());
+			assertEquals(13, it.next());
+			assertEquals(23, it.next());
+			assertTrue(!it.hasNext());
+
+			operatorStateBackend.dispose();
+		} finally {
+
+			stateHandle.discardState();
+		}
+	}
+
+}
\ No newline at end of file

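Condensed from the test above, the typical life cycle of the new OperatorStateBackend is: register list state via a descriptor, snapshot through a CheckpointStreamFactory, dispose, then restore from the resulting handle. A runnable sketch that mirrors the test's calls (class name, state name, and values are illustrative):

import java.io.Serializable;
import java.util.Collections;

import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.java.typeutils.runtime.JavaSerializer;
import org.apache.flink.runtime.state.AbstractStateBackend;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.OperatorStateBackend;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;

public class OperatorStateRoundTrip {

	public static void main(String[] args) throws Exception {
		AbstractStateBackend backend = new MemoryStateBackend(1024);
		OperatorStateBackend osb = backend.createOperatorStateBackend(null, "demo-operator");

		ListStateDescriptor<Serializable> descriptor =
				new ListStateDescriptor<>("buffer", new JavaSerializer<>());
		ListState<Serializable> state = osb.getPartitionableState(descriptor);
		state.add(42);

		// snapshot, then throw the in-memory backend away
		CheckpointStreamFactory streams = backend.createStreamFactory(new JobID(), "demo-operator");
		OperatorStateHandle handle = osb.snapshot(1L, 1L, streams).get();
		osb.dispose();

		// restore into a fresh backend instance from the handle
		osb = backend.restoreOperatorStateBackend(null, "demo-operator", Collections.singletonList(handle));
		for (Serializable s : osb.getPartitionableState(descriptor).get()) {
			System.out.println(s); // prints 42
		}

		osb.dispose();
		handle.discardState();
	}
}
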

[03/10] flink git commit: [FLINK-4379] [checkpoints] Introduce rescalable operator state

Posted by se...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index a73f3b2..0ca89ef 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -18,29 +18,35 @@
 
 package org.apache.flink.streaming.api.operators;
 
+import org.apache.commons.io.IOUtils;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.metrics.Counter;
 import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.watermark.Watermark;
-import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.Collection;
+import java.util.concurrent.RunnableFuture;
+
 /**
  * Base class for all stream operators. Operators that contain a user function should extend the class 
  * {@link AbstractUdfStreamOperator} instead (which is a specialized subclass of this class). 
@@ -90,7 +96,12 @@ public abstract class AbstractStreamOperator<OUT>
 	private transient KeySelector<?, ?> stateKeySelector2;
 
 	/** Backend for keyed state. This might be empty if we're not on a keyed stream. */
-	private transient KeyedStateBackend<?> keyedStateBackend;
+	private transient AbstractKeyedStateBackend<?> keyedStateBackend;
+
+	/** Operator state backend */
+	private transient OperatorStateBackend operatorStateBackend;
+
+	private transient Collection<OperatorStateHandle> lazyRestoreStateHandles;
 
 	protected transient MetricGroup metrics;
 
@@ -116,9 +127,14 @@ public abstract class AbstractStreamOperator<OUT>
 		return metrics;
 	}
 
+	@Override
+	public void restoreState(Collection<OperatorStateHandle> stateHandles) {
+		this.lazyRestoreStateHandles = stateHandles;
+	}
+
 	/**
 	 * This method is called immediately before any elements are processed, it should contain the
-	 * operator's initialization logic.
+	 * operator's initialization logic, e.g. state initialization.
 	 *
 	 * <p>The default implementation does nothing.
 	 * 
@@ -126,24 +142,39 @@ public abstract class AbstractStreamOperator<OUT>
 	 */
 	@Override
 	public void open() throws Exception {
+		initOperatorState();
+		initKeyedState();
+	}
+
+	private void initKeyedState() {
 		try {
 			TypeSerializer<Object> keySerializer = config.getStateKeySerializer(getUserCodeClassloader());
 			// create a keyed state backend if there is keyed state, as indicated by the presence of a key serializer
 			if (null != keySerializer) {
-				ExecutionConfig execConf = container.getEnvironment().getExecutionConfig();;
 
 				KeyGroupRange subTaskKeyGroupRange = KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
 						container.getEnvironment().getTaskInfo().getNumberOfKeyGroups(),
 						container.getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(),
 						container.getIndexInSubtaskGroup());
 
-				keyedStateBackend = container.createKeyedStateBackend(
+				this.keyedStateBackend = container.createKeyedStateBackend(
 						keySerializer,
 						container.getConfiguration().getNumberOfKeyGroups(getUserCodeClassloader()),
 						subTaskKeyGroupRange);
+
 			}
+
+		} catch (Exception e) {
+			throw new IllegalStateException("Could not initialize keyed state backend.", e);
+		}
+	}
+
+	private void initOperatorState() {
+		try {
+			// create an operator state backend
+			this.operatorStateBackend = container.createOperatorStateBackend(this, lazyRestoreStateHandles);
 		} catch (Exception e) {
-			throw new RuntimeException("Could not initialize keyed state backend.", e);
+			throw new IllegalStateException("Could not initialize operator state backend.", e);
 		}
 	}
 
@@ -171,18 +202,25 @@ public abstract class AbstractStreamOperator<OUT>
 	 */
 	@Override
 	public void dispose() throws Exception {
+
+		if (operatorStateBackend != null) {
+			IOUtils.closeQuietly(operatorStateBackend);
+			operatorStateBackend.dispose();
+		}
+
 		if (keyedStateBackend != null) {
-			keyedStateBackend.close();
+			IOUtils.closeQuietly(keyedStateBackend);
+			keyedStateBackend.dispose();
 		}
 	}
 
 	@Override
-	public void snapshotState(FSDataOutputStream out,
-			long checkpointId,
-			long timestamp) throws Exception {}
+	public RunnableFuture<OperatorStateHandle> snapshotState(
+			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
 
-	@Override
-	public void restoreState(FSDataInputStream in) throws Exception {}
+		return operatorStateBackend != null ?
+				operatorStateBackend.snapshot(checkpointId, timestamp, streamFactory) : null;
+	}
 
 	@Override
 	public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {}
@@ -223,10 +261,24 @@ public abstract class AbstractStreamOperator<OUT>
 	}
 
 	@SuppressWarnings("rawtypes, unchecked")
-	public <K> KeyedStateBackend<K> getStateBackend() {
+	public <K> KeyedStateBackend<K> getKeyedStateBackend() {
+
+		if (null == keyedStateBackend) {
+			initKeyedState();
+		}
+
 		return (KeyedStateBackend<K>) keyedStateBackend;
 	}
 
+	public OperatorStateBackend getOperatorStateBackend() {
+
+		if (null == operatorStateBackend) {
+			initOperatorState();
+		}
+
+		return operatorStateBackend;
+	}
+
 	/**
 	 * Returns the {@link TimeServiceProvider} responsible for getting  the current
 	 * processing time and registering timers.
@@ -268,18 +320,18 @@ public abstract class AbstractStreamOperator<OUT>
 	@Override
 	@SuppressWarnings({"unchecked", "rawtypes"})
 	public void setKeyContextElement1(StreamRecord record) throws Exception {
-		if (stateKeySelector1 != null) {
-			Object key = ((KeySelector) stateKeySelector1).getKey(record.getValue());
-			getStateBackend().setCurrentKey(key);
-		}
+		setRawKeyContextElement(record, stateKeySelector1);
 	}
 
 	@Override
 	@SuppressWarnings({"unchecked", "rawtypes"})
 	public void setKeyContextElement2(StreamRecord record) throws Exception {
-		if (stateKeySelector2 != null) {
-			Object key = ((KeySelector) stateKeySelector2).getKey(record.getValue());
+		setRawKeyContextElement(record, stateKeySelector2);
+	}
 
+	private void setRawKeyContextElement(StreamRecord record, KeySelector<?, ?> selector) throws Exception {
+		if (selector != null) {
+			Object key = ((KeySelector) selector).getKey(record.getValue());
 			setKeyContext(key);
 		}
 	}
@@ -290,7 +342,7 @@ public abstract class AbstractStreamOperator<OUT>
 			try {
 				// need to work around type restrictions
 				@SuppressWarnings("unchecked,rawtypes")
-				KeyedStateBackend rawBackend = (KeyedStateBackend) keyedStateBackend;
+				AbstractKeyedStateBackend rawBackend = (AbstractKeyedStateBackend) keyedStateBackend;
 
 				rawBackend.setCurrentKey(key);
 			} catch (Exception e) {

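The operator's snapshotState now hands back a RunnableFuture instead of writing to a stream, which lets backends materialize state asynchronously. A hedged caller-side sketch (this is not the actual StreamTask code; the null check and the run-before-get sequence are assumptions derived from the interface contract introduced in this commit):

import java.util.concurrent.RunnableFuture;

import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.streaming.api.operators.StreamOperator;

// Illustrative helper: drive one operator snapshot to completion synchronously.
public final class SnapshotRunner {

	private SnapshotRunner() {}

	public static OperatorStateHandle runSnapshot(
			StreamOperator<?> operator,
			long checkpointId,
			long timestamp,
			CheckpointStreamFactory streamFactory) throws Exception {

		RunnableFuture<OperatorStateHandle> future =
				operator.snapshotState(checkpointId, timestamp, streamFactory);

		if (future == null) {
			return null; // the operator has no operator state
		}
		if (!future.isDone()) {
			future.run(); // execute the (possibly asynchronous) materialization inline
		}
		return future.get();
	}
}
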
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
index 6ac73e7..f683d9a 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
@@ -18,23 +18,31 @@
 
 package org.apache.flink.streaming.api.operators;
 
-import java.io.Serializable;
-
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.Function;
 import org.apache.flink.api.common.functions.util.FunctionUtils;
+import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.util.InstantiationUtil;
 
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.RunnableFuture;
+
 import static java.util.Objects.requireNonNull;
 
 /**
@@ -50,7 +58,8 @@ import static java.util.Objects.requireNonNull;
 @PublicEvolving
 public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 		extends AbstractStreamOperator<OUT>
-		implements OutputTypeConfigurable<OUT> {
+		implements OutputTypeConfigurable<OUT>,
+		StreamCheckpointedOperator {
 
 	private static final long serialVersionUID = 1L;
 	
@@ -91,6 +100,28 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 		super.open();
 		
 		FunctionUtils.openFunction(userFunction, new Configuration());
+
+		if (userFunction instanceof CheckpointedFunction) {
+			((CheckpointedFunction) userFunction).initializeState(getOperatorStateBackend());
+		} else if (userFunction instanceof ListCheckpointed) {
+			@SuppressWarnings("unchecked")
+			ListCheckpointed<Serializable> listCheckpointedFun = (ListCheckpointed<Serializable>) userFunction;
+
+			ListState<Serializable> listState =
+					getOperatorStateBackend().getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR);
+
+			List<Serializable> list = new ArrayList<>();
+
+			for (Serializable serializable : listState.get()) {
+				list.add(serializable);
+			}
+
+			try {
+				listCheckpointedFun.restoreState(list);
+			} catch (Exception e) {
+				throw new Exception("Failed to restore state to function: " + e.getMessage(), e);
+			}
+		}
 	}
 
 	@Override
@@ -115,7 +146,6 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	
 	@Override
 	public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
-		super.snapshotState(out, checkpointId, timestamp);
 
 		if (userFunction instanceof Checkpointed) {
 			@SuppressWarnings("unchecked")
@@ -138,7 +168,6 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 
 	@Override
 	public void restoreState(FSDataInputStream in) throws Exception {
-		super.restoreState(in);
 
 		if (userFunction instanceof Checkpointed) {
 			@SuppressWarnings("unchecked")
@@ -160,6 +189,32 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	}
 
 	@Override
+	public RunnableFuture<OperatorStateHandle> snapshotState(
+			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
+
+		if (userFunction instanceof CheckpointedFunction) {
+			((CheckpointedFunction) userFunction).prepareSnapshot(checkpointId, timestamp);
+		}
+
+		if (userFunction instanceof ListCheckpointed) {
+			@SuppressWarnings("unchecked")
+			List<Serializable> partitionableState =
+					((ListCheckpointed<Serializable>) userFunction).snapshotState(checkpointId, timestamp);
+
+			ListState<Serializable> listState =
+					getOperatorStateBackend().getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR);
+
+			listState.clear();
+
+			for (Serializable statePartition : partitionableState) {
+				listState.add(statePartition);
+			}
+		}
+
+		return super.snapshotState(checkpointId, timestamp, streamFactory);
+	}
+
+	@Override
 	public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {
 		super.notifyOfCompletedCheckpoint(checkpointId);
 

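The bridge above means a user function only implements ListCheckpointed; the operator moves the returned list in and out of the backend's partitionable list state. A minimal, hypothetical user function (the class CountingMap is invented; the interface signatures are inferred from the calls in this hunk):

import java.util.Collections;
import java.util.List;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;

// Counts records and checkpoints the count as a single-element list. On
// rescaling, list elements may be redistributed across subtasks, which is
// why restoreState() sums whatever partitions it receives.
public class CountingMap implements MapFunction<String, String>, ListCheckpointed<Long> {

	private long count;

	@Override
	public String map(String value) {
		count++;
		return value;
	}

	@Override
	public List<Long> snapshotState(long checkpointId, long timestamp) throws Exception {
		return Collections.singletonList(count);
	}

	@Override
	public void restoreState(List<Long> state) throws Exception {
		for (Long partition : state) {
			count += partition;
		}
	}
}
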
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java
new file mode 100644
index 0000000..50cdc02
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+
+@Deprecated
+public interface StreamCheckpointedOperator {
+
+	/**
+	 * Called to draw a state snapshot from the operator. This method snapshots the operator state
+	 * (if the operator is stateful).
+	 *
+	 * @param out The stream to which we have to write our state.
+	 * @param checkpointId The ID of the checkpoint.
+	 * @param timestamp The timestamp of the checkpoint.
+	 *
+	 * @throws Exception Forwards exceptions that occur while drawing snapshots from the operator
+	 *                   and the key/value state.
+	 */
+	void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception;
+
+	/**
+	 * Restores the operator state, if this operator's execution is recovering from a checkpoint.
+	 * This method restores the operator state (if the operator is stateful) and the key/value state
+	 * (if it had been used and was initialized when the snapshot occurred).
+	 *
+	 * <p>This method is called after {@link #setup(StreamTask, StreamConfig, Output)}
+	 * and before {@link #open()}.
+	 *
+	 * @param in The stream from which we have to restore our state.
+	 *
+	 * @throws Exception Exceptions during state restore should be forwarded, so that the system can
+	 *                   properly react to failed state restore and fail the execution attempt.
+	 */
+	void restoreState(FSDataInputStream in) throws Exception;
+
+}
\ No newline at end of file

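For readers still on this deprecated interface, a bare-bones implementation looks like the following. The sketch uses plain JDK object streams, and the note that the framework (not the operator) closes the checkpoint streams is an assumption of this sketch, not something stated in the commit:

import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.fs.FSDataOutputStream;
import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;

// Illustrative legacy-style operator state: a single counter written to and
// read from the raw checkpoint streams.
public class LegacySnapshotExample implements StreamCheckpointedOperator {

	private long counter;

	@Override
	public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
		ObjectOutputStream oos = new ObjectOutputStream(out);
		oos.writeLong(counter);
		oos.flush(); // assumption: the framework owns and closes the underlying stream
	}

	@Override
	public void restoreState(FSDataInputStream in) throws Exception {
		ObjectInputStream ois = new ObjectInputStream(in);
		counter = ois.readLong();
	}
}
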
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
index f1e8160..fae5fd0 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
@@ -17,16 +17,18 @@
 
 package org.apache.flink.streaming.api.operators;
 
-import java.io.Serializable;
-
 import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.metrics.MetricGroup;
-import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.concurrent.RunnableFuture;
+
 /**
  * Basic interface for stream operators. Implementers would implement one of
  * {@link org.apache.flink.streaming.api.operators.OneInputStreamOperator} or
@@ -91,32 +93,27 @@ public interface StreamOperator<OUT> extends Serializable {
 	// ------------------------------------------------------------------------
 
 	/**
-	 * Called to draw a state snapshot from the operator. This method snapshots the operator state
-	 * (if the operator is stateful).
-	 *
-	 * @param out The stream to which we have to write our state.
-	 * @param checkpointId The ID of the checkpoint.
-	 * @param timestamp The timestamp of the checkpoint.
+	 * Called to draw a state snapshot from the operator.
 	 *
-	 * @throws Exception Forwards exceptions that occur while drawing snapshots from the operator
-	 *                   and the key/value state.
+	 * @throws Exception Forwards exceptions that occur while preparing for the snapshot
 	 */
-	void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception;
 
 	/**
-	 * Restores the operator state, if this operator's execution is recovering from a checkpoint.
-	 * This method restores the operator state (if the operator is stateful) and the key/value state
-	 * (if it had been used and was initialized when the snapshot occurred).
+	 * Called to draw a state snapshot from the operator.
 	 *
-	 * <p>This method is called after {@link #setup(StreamTask, StreamConfig, Output)}
-	 * and before {@link #open()}.
-	 *
-	 * @param in The stream from which we have to restore our state.
+	 * @return a runnable future to the state handle that points to the snapshotted state. For synchronous implementations,
+	 * the runnable might already be finished.
+	 * @throws Exception exception that happened during snapshotting.
+	 */
+	RunnableFuture<OperatorStateHandle> snapshotState(
+			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception;
+
+	/**
+	 * Provides state handles to restore the operator state.
 	 *
-	 * @throws Exception Exceptions during state restore should be forwarded, so that the system can
-	 *                   properly react to failed state restore and fail the execution attempt.
+	 * @param stateHandles state handles to the operator state.
 	 */
-	void restoreState(FSDataInputStream in) throws Exception;
+	void restoreState(Collection<OperatorStateHandle> stateHandles);
 
 	/**
 	 * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager.

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
index 4f85e3a..cc2e54b 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
@@ -24,13 +24,10 @@ import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
 import org.apache.flink.api.common.functions.util.AbstractRuntimeUDFContext;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.OperatorState;
 import org.apache.flink.api.common.state.ReducingState;
 import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.streaming.api.CheckpointingMode;
@@ -143,35 +140,6 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext {
 		}
 	}
 
-	@Override
-	@Deprecated
-	public <S> OperatorState<S> getKeyValueState(String name, Class<S> stateType, S defaultState) {
-		requireNonNull(stateType, "The state type class must not be null");
-
-		TypeInformation<S> typeInfo;
-		try {
-			typeInfo = TypeExtractor.getForClass(stateType);
-		}
-		catch (Exception e) {
-			throw new RuntimeException("Cannot analyze type '" + stateType.getName() +
-					"' from the class alone, due to generic type parameters. " +
-					"Please specify the TypeInformation directly.", e);
-		}
-
-		return getKeyValueState(name, typeInfo, defaultState);
-	}
-
-	@Override
-	@Deprecated
-	public <S> OperatorState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) {
-		requireNonNull(name, "The name of the state must not be null");
-		requireNonNull(stateType, "The state type information must not be null");
-
-		ValueStateDescriptor<S> stateProps = 
-				new ValueStateDescriptor<>(name, stateType, defaultState);
-		return getState(stateProps);
-	}
-
 	// ------------------ expose (read only) relevant information from the stream config -------- //
 
 	/**

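Code that relied on the removed getKeyValueState(...) shortcuts migrates to the descriptor-based path that the deleted lines used internally: build a ValueStateDescriptor and pass it to getState(...). A hedged sketch (the class name MigratedCounter is illustrative; the function must run on a keyed stream for getState to work):

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.configuration.Configuration;

// Keeps a per-key counter in keyed value state, replacing the removed
// getKeyValueState("count", Long.class, 0L) call.
public class MigratedCounter extends RichMapFunction<String, String> {

	private transient ValueState<Long> count;

	@Override
	public void open(Configuration parameters) {
		count = getRuntimeContext().getState(
				new ValueStateDescriptor<>("count", Long.class, 0L));
	}

	@Override
	public String map(String value) throws Exception {
		count.update(count.value() + 1);
		return value;
	}
}
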
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
index 35d1108..b5500b7 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.io.disk.InputViewIterator;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.util.ReusingMutableToRegularIteratorWrapper;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -51,7 +52,9 @@ import java.util.UUID;
  *
  * @param <IN> Type of the elements emitted by this sink
  */
-public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<IN> implements OneInputStreamOperator<IN, IN> {
+public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<IN>
+		implements OneInputStreamOperator<IN, IN>, StreamCheckpointedOperator {
+
 	private static final long serialVersionUID = 1L;
 
 	protected static final Logger LOG = LoggerFactory.getLogger(GenericWriteAheadSink.class);
@@ -110,7 +113,6 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 	public void snapshotState(FSDataOutputStream out,
 			long checkpointId,
 			long timestamp) throws Exception {
-		super.snapshotState(out, checkpointId, timestamp);
 
 		saveHandleInState(checkpointId, timestamp);
 
@@ -119,7 +121,6 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 
 	@Override
 	public void restoreState(FSDataInputStream in) throws Exception {
-		super.restoreState(in);
 
 		this.state = InstantiationUtil.deserializeObject(in, getUserCodeClassloader());
 	}
@@ -151,11 +152,19 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 					try {
 						if (!committer.isCheckpointCommitted(pastCheckpointId)) {
 							Tuple2<Long, StreamStateHandle> handle = state.pendingHandles.get(pastCheckpointId);
-							FSDataInputStream in = handle.f1.openInputStream();
-							boolean success = sendValues(new ReusingMutableToRegularIteratorWrapper<>(new InputViewIterator<>(new DataInputViewStreamWrapper(in), serializer), serializer), handle.f0);
-							if (success) { //if the sending has failed we will retry on the next notify
-								committer.commitCheckpoint(pastCheckpointId);
-								checkpointsToRemove.add(pastCheckpointId);
+							try (FSDataInputStream in = handle.f1.openInputStream()) {
+								boolean success = sendValues(
+										new ReusingMutableToRegularIteratorWrapper<>(
+												new InputViewIterator<>(
+														new DataInputViewStreamWrapper(
+																in),
+														serializer),
+												serializer),
+										handle.f0);
+								if (success) { //if the sending has failed we will retry on the next notify
+									committer.commitCheckpoint(pastCheckpointId);
+									checkpointsToRemove.add(pastCheckpointId);
+								}
 							}
 						} else {
 							checkpointsToRemove.add(pastCheckpointId);

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java
index 4de7729..a838faa 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java
@@ -88,7 +88,7 @@ public class EvictingWindowOperator<K, IN, OUT, W extends Window> extends Window
 				element.getTimestamp(),
 				windowAssignerContext);
 
-		final K key = (K) getStateBackend().getCurrentKey();
+		final K key = (K) getKeyedStateBackend().getCurrentKey();
 
 		if (windowAssigner instanceof MergingWindowAssigner) {
 
@@ -122,7 +122,7 @@ public class EvictingWindowOperator<K, IN, OUT, W extends Window> extends Window
 								}
 
 								// merge the merged state windows into the newly resulting state window
-								getStateBackend().mergePartitionedStates(
+								getKeyedStateBackend().mergePartitionedStates(
 									stateWindowResult,
 									mergedStateWindows,
 									windowSerializer,

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
index e4939db..ffdf334 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
@@ -298,7 +298,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window>
 		Collection<W> elementWindows = windowAssigner.assignWindows(
 			element.getValue(), element.getTimestamp(), windowAssignerContext);
 
-		final K key = (K) getStateBackend().getCurrentKey();
+		final K key = (K) getKeyedStateBackend().getCurrentKey();
 
 		if (windowAssigner instanceof MergingWindowAssigner) {
 			MergingWindowSet<W> mergingWindows = getMergingWindowSet();
@@ -329,7 +329,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window>
 						}
 
 						// merge the merged state windows into the newly resulting state window
-						getStateBackend().mergePartitionedStates(
+						getKeyedStateBackend().mergePartitionedStates(
 							stateWindowResult,
 							mergedStateWindows,
 							windowSerializer,
@@ -554,18 +554,18 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window>
 	 */
 	@SuppressWarnings("unchecked")
 	protected MergingWindowSet<W> getMergingWindowSet() throws Exception {
-		MergingWindowSet<W> mergingWindows = mergingWindowsByKey.get((K) getStateBackend().getCurrentKey());
+		MergingWindowSet<W> mergingWindows = mergingWindowsByKey.get((K) getKeyedStateBackend().getCurrentKey());
 		if (mergingWindows == null) {
 			// try to retrieve from state
 
 			TupleSerializer<Tuple2<W, W>> tupleSerializer = new TupleSerializer<>((Class) Tuple2.class, new TypeSerializer[] {windowSerializer, windowSerializer} );
 			ListStateDescriptor<Tuple2<W, W>> mergeStateDescriptor = new ListStateDescriptor<>("merging-window-set", tupleSerializer);
-			ListState<Tuple2<W, W>> mergeState = getStateBackend().getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor);
+			ListState<Tuple2<W, W>> mergeState = getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor);
 
 			mergingWindows = new MergingWindowSet<>((MergingWindowAssigner<? super IN, W>) windowAssigner, mergeState);
 			mergeState.clear();
 
-			mergingWindowsByKey.put((K) getStateBackend().getCurrentKey(), mergingWindows);
+			mergingWindowsByKey.put((K) getKeyedStateBackend().getCurrentKey(), mergingWindows);
 		}
 		return mergingWindows;
 	}
@@ -709,7 +709,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window>
 		public <S extends MergingState<?, ?>> void mergePartitionedState(StateDescriptor<S, ?> stateDescriptor) {
 			if (mergedWindows != null && mergedWindows.size() > 0) {
 				try {
-					WindowOperator.this.getStateBackend().mergePartitionedStates(window,
+					WindowOperator.this.getKeyedStateBackend().mergePartitionedStates(window,
 							mergedWindows,
 							windowSerializer,
 							stateDescriptor);
@@ -869,7 +869,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window>
 			ListStateDescriptor<Tuple2<W, W>> mergeStateDescriptor = new ListStateDescriptor<>("merging-window-set", tupleSerializer);
 			for (Map.Entry<K, MergingWindowSet<W>> key: mergingWindowsByKey.entrySet()) {
 				setKeyContext(key.getKey());
-				ListState<Tuple2<W, W>> mergeState = getStateBackend().getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor);
+				ListState<Tuple2<W, W>> mergeState = getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor);
 				mergeState.clear();
 				key.getValue().persist(mergeState);
 			}
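
The window operators now resolve all per-window state through the keyed state
backend: getKeyedStateBackend() tracks the current key, and the operator-level
getPartitionedState() shortcut scopes state to (current key, namespace). A rough
sketch of that access pattern, with the descriptor types simplified:

    // the current key is owned by the keyed backend, not the operator
    K key = (K) getKeyedStateBackend().getCurrentKey();

    // state scoped to the current key and the given namespace
    ListStateDescriptor<Tuple2<W, W>> descriptor =
            new ListStateDescriptor<>("merging-window-set", tupleSerializer);
    ListState<Tuple2<W, W>> mergeState =
            getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);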

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
index 0e24516..9e96f5d 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
@@ -17,12 +17,6 @@
 
 package org.apache.flink.streaming.runtime.tasks;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
@@ -35,20 +29,25 @@ import org.apache.flink.runtime.plugable.SerializationDelegate;
 import org.apache.flink.streaming.api.collector.selector.CopyingDirectedOutput;
 import org.apache.flink.streaming.api.collector.selector.DirectedOutput;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
-import org.apache.flink.streaming.api.watermark.Watermark;
-import org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.runtime.io.StreamRecordWriter;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
 /**
  * The {@code OperatorChain} contains all operators that are executed as one chain within a single
  * {@link StreamTask}.
@@ -57,7 +56,7 @@ import org.slf4j.LoggerFactory;
  *              head operator.
  */
 @Internal
-public class OperatorChain<OUT> {
+public class OperatorChain<OUT, OP extends StreamOperator<OUT>> {
 	
 	private static final Logger LOG = LoggerFactory.getLogger(OperatorChain.class);
 	
@@ -66,16 +65,17 @@ public class OperatorChain<OUT> {
 	private final RecordWriterOutput<?>[] streamOutputs;
 	
 	private final Output<StreamRecord<OUT>> chainEntryPoint;
-	
 
-	public OperatorChain(StreamTask<OUT, ?> containingTask,
-							StreamOperator<OUT> headOperator,
-							AccumulatorRegistry.Reporter reporter) {
+	private final OP headOperator;
+
+	public OperatorChain(StreamTask<OUT, OP> containingTask, AccumulatorRegistry.Reporter reporter) {
 		
 		final ClassLoader userCodeClassloader = containingTask.getUserCodeClassLoader();
 		final StreamConfig configuration = containingTask.getConfiguration();
 		final boolean enableTimestamps = containingTask.isSerializingTimestamps();
 
+		headOperator = configuration.getStreamOperator(userCodeClassloader);
+
 		// we read the chained configs, and the order of record writer registrations by output name
 		Map<Integer, StreamConfig> chainedConfigs = configuration.getTransitiveChainedTaskConfigs(userCodeClassloader);
 		chainedConfigs.put(configuration.getVertexID(), configuration);
@@ -104,11 +104,15 @@ public class OperatorChain<OUT> {
 			List<StreamOperator<?>> allOps = new ArrayList<>(chainedConfigs.size());
 			this.chainEntryPoint = createOutputCollector(containingTask, configuration,
 					chainedConfigs, userCodeClassloader, streamOutputMap, allOps);
+
+			if (headOperator != null) {
+				headOperator.setup(containingTask, configuration, getChainEntryPoint());
+			}
+
+			// add head operator to end of chain
+			allOps.add(headOperator);
 			
-			this.allOperators = allOps.toArray(new StreamOperator<?>[allOps.size() + 1]);
-			
-			// add the head operator to the end of the list
-			this.allOperators[this.allOperators.length - 1] = headOperator;
+			this.allOperators = allOps.toArray(new StreamOperator<?>[allOps.size()]);
 			
 			success = true;
 		}
@@ -181,7 +185,15 @@ public class OperatorChain<OUT> {
 			}
 		}
 	}
-	
+
+	public OP getHeadOperator() {
+		return headOperator;
+	}
+
+	public int getChainLength() {
+		return allOperators == null ? 0 : allOperators.length;
+	}
+
 	// ------------------------------------------------------------------------
 	//  initialization utilities
 	// ------------------------------------------------------------------------
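
With the new OP type parameter the chain creates and owns its head operator, so
callers no longer construct it and call setup() themselves. A minimal sketch of
the new usage, assuming a containing task of type StreamTask<OUT, OP>:

    // the chain reads the head operator from the task's StreamConfig itself
    OperatorChain<OUT, OP> chain = new OperatorChain<>(
            task, task.getEnvironment().getAccumulatorRegistry().getReadWriteReporter());

    OP head = chain.getHeadOperator();    // already set up against the chain entry point
    int length = chain.getChainLength();  // head operator counted at the end of the chain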

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 7976f01..1725eca 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -26,14 +26,19 @@ import org.apache.flink.configuration.IllegalConfigurationException;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.runtime.execution.CancelTaskException;
+import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.ClosableRegistry;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackendFactory;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
@@ -41,24 +46,23 @@ import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-
+import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.RunnableFuture;
@@ -70,19 +74,19 @@ import java.util.concurrent.ScheduledThreadPoolExecutor;
  * the Task's operator chain. Operators that are chained together execute synchronously in the
  * same thread and hence on the same stream partition. A common case for these chains
  * are successive map/flatmap/filter tasks.
- * 
- * <p>The task chain contains one "head" operator and multiple chained operators. 
+ *
+ * <p>The task chain contains one "head" operator and multiple chained operators.
  * The StreamTask is specialized for the type of the head operator: one-input and two-input tasks,
  * as well as for sources, iteration heads and iteration tails.
- * 
- * <p>The Task class deals with the setup of the streams read by the head operator, and the streams 
+ *
+ * <p>The Task class deals with the setup of the streams read by the head operator, and the streams
  * produced by the operators at the ends of the operator chain. Note that the chain may fork and
  * thus have multiple ends.
  *
- * The life cycle of the task is set up as follows: 
+ * The life cycle of the task is set up as follows:
  * <pre>{@code
- *  -- restoreState() -> restores state of all operators in the chain
- *  
+ *  -- getPartitionableState() -> restores state of all operators in the chain
+ *
  *  -- invoke()
  *        |
  *        +----> Create basic utils (config, etc) and load the chain of operators
@@ -99,35 +103,35 @@ import java.util.concurrent.ScheduledThreadPoolExecutor;
  * <p> The {@code StreamTask} has a lock object called {@code lock}. All calls to methods on a
  * {@code StreamOperator} must be synchronized on this lock object to ensure that no methods
  * are called concurrently.
- * 
+ *
  * @param <OUT>
- * @param <Operator>
+ * @param <OP>
  */
 @Internal
-public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
+public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		extends AbstractInvokable
 		implements StatefulTask, AsyncExceptionHandler {
 
 	/** The thread group that holds all trigger timer threads */
 	public static final ThreadGroup TRIGGER_THREAD_GROUP = new ThreadGroup("Triggers");
-	
+
 	/** The logger used by the StreamTask and its subclasses */
 	private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);
-	
+
 	// ------------------------------------------------------------------------
-	
+
 	/**
 	 * All interaction with the {@code StreamOperator} must be synchronized on this lock object to ensure that
 	 * we don't have concurrent method calls that void consistent checkpoints.
 	 */
 	private final Object lock = new Object();
-	
+
 	/** the head operator that consumes the input streams of this task */
-	protected Operator headOperator;
+	protected OP headOperator;
 
 	/** The chain of operators executed by this task */
-	private OperatorChain<OUT> operatorChain;
-	
+	private OperatorChain<OUT, OP> operatorChain;
+
 	/** The configuration of this streaming task */
 	private StreamConfig configuration;
 
@@ -135,7 +139,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	private AbstractStateBackend stateBackend;
 
 	/** Keyed state backend for the head operator, if it is keyed. There can only ever be one. */
-	private KeyedStateBackend<?> keyedStateBackend;
+	private AbstractKeyedStateBackend<?> keyedStateBackend;
 
 	/**
 	 * The internal {@link TimeServiceProvider} used to define the current
@@ -146,12 +150,14 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 
 	/** The map of user-defined accumulators of this task */
 	private Map<String, Accumulator<?, ?>> accumulatorMap;
-	
+
 	/** The chained operator state to be restored once the initialization is done */
 	private ChainedStateHandle<StreamStateHandle> lazyRestoreChainedOperatorState;
 
 	private List<KeyGroupsStateHandle> lazyRestoreKeyGroupStates;
 
+	private List<Collection<OperatorStateHandle>> lazyRestoreOperatorState;
+
 	/**
 	 * This field is used to forward an exception that is caught in the timer thread or other
 	 * asynchronous Threads. Subclasses must ensure that exceptions stored here get thrown on the
@@ -159,12 +165,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	private volatile AsynchronousException asyncException;
 
 	/** The currently active background materialization threads */
-	private final Set<Closeable> cancelables = new HashSet<>();
-	
+	private final ClosableRegistry cancelables = new ClosableRegistry();
+
 	/** Flag to mark the task "in operation", in which case check
 	 * needs to be initialized to true, so that early cancel() before invoke() behaves correctly */
 	private volatile boolean isRunning;
-	
+
 	/** Flag to mark this task as canceled */
 	private volatile boolean canceled;
 
@@ -178,11 +184,11 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	// ------------------------------------------------------------------------
 
 	protected abstract void init() throws Exception;
-	
+
 	protected abstract void run() throws Exception;
-	
+
 	protected abstract void cleanup() throws Exception;
-	
+
 	protected abstract void cancelTask() throws Exception;
 
 	// ------------------------------------------------------------------------
@@ -232,13 +238,8 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 				timerService = DefaultTimeServiceProvider.create(this, executor, getCheckpointLock());
 			}
 
-			headOperator = configuration.getStreamOperator(getUserCodeClassLoader());
-			operatorChain = new OperatorChain<>(this, headOperator, 
-						getEnvironment().getAccumulatorRegistry().getReadWriteReporter());
-
-			if (headOperator != null) {
-				headOperator.setup(this, configuration, operatorChain.getChainEntryPoint());
-			}
+			operatorChain = new OperatorChain<>(this, getEnvironment().getAccumulatorRegistry().getReadWriteReporter());
+			headOperator = operatorChain.getHeadOperator();
 
 			getEnvironment().getMetricGroup().gauge("lastCheckpointSize", new Gauge<Long>() {
 				@Override
@@ -249,12 +250,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 
 			// task specific initialization
 			init();
-			
+
 			// save the work of reloading state, etc, if the task is already canceled
 			if (canceled) {
 				throw new CancelTaskException();
 			}
-			
+
 			// -------- Invoke --------
 			LOG.debug("Invoking {}", getName());
 
@@ -278,7 +279,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 			run();
 
 			LOG.debug("Finished task {}", getName());
-			
+
 			// make sure no further checkpoint and notification actions happen.
 			// we make sure that no other thread is currently in the locked scope before
 			// we close the operators by trying to acquire the checkpoint scope lock
@@ -286,13 +287,13 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 			// at the same time, this makes sure that during any "regular" exit where still
 			synchronized (lock) {
 				isRunning = false;
-				
+
 				// this is part of the main logic, so if this fails, the task is considered failed
 				closeAllOperators();
 			}
 
 			LOG.debug("Closed operators for task {}", getName());
-			
+
 			// make sure all buffered data is flushed
 			operatorChain.flushOutputs();
 
@@ -324,7 +325,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 
 			// stop all asynchronous checkpoint threads
 			try {
-				closeAllClosables();
+				cancelables.close();
 				shutdownAsyncThreads();
 			}
 			catch (Throwable t) {
@@ -371,13 +372,13 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 		isRunning = false;
 		canceled = true;
 		cancelTask();
-		closeAllClosables();
+		cancelables.close();
 	}
 
 	public final boolean isRunning() {
 		return isRunning;
 	}
-	
+
 	public final boolean isCanceled() {
 		return canceled;
 	}
@@ -476,36 +477,14 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 			}
 		}
 
-		closeAllClosables();
-	}
-
-	private void closeAllClosables() {
-		// first, create a copy of the cancelables to prevent concurrent modifications
-		// and to not hold the lock for too long. the copy can be a cheap list
-		List<Closeable> localCancelables = null;
-		synchronized (cancelables) {
-			if (cancelables.size() > 0) {
-				localCancelables = new ArrayList<>(cancelables);
-				cancelables.clear();
-			}
-		}
-
-		if (localCancelables != null) {
-			for (Closeable cancelable : localCancelables) {
-				try {
-					cancelable.close();
-				} catch (Throwable t) {
-					LOG.error("Error on canceling operation", t);
-				}
-			}
-		}
+		cancelables.close();
 	}
 
 	boolean isSerializingTimestamps() {
 		TimeCharacteristic tc = configuration.getTimeCharacteristic();
 		return tc == TimeCharacteristic.EventTime | tc == TimeCharacteristic.IngestionTime;
 	}
-	
+
 	// ------------------------------------------------------------------------
 	//  Access to properties and utilities
 	// ------------------------------------------------------------------------
@@ -525,7 +504,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	public Object getCheckpointLock() {
 		return lock;
 	}
-	
+
 	public StreamConfig getConfiguration() {
 		return configuration;
 	}
@@ -533,11 +512,11 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	public Map<String, Accumulator<?, ?>> getAccumulatorMap() {
 		return accumulatorMap;
 	}
-	
+
 	Output<StreamRecord<OUT>> getHeadOutput() {
 		return operatorChain.getChainEntryPoint();
 	}
-	
+
 	RecordWriterOutput<?>[] getStreamOutputs() {
 		return operatorChain.getStreamOutputs();
 	}
@@ -547,40 +526,59 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	// ------------------------------------------------------------------------
 
 	@Override
-	public void setInitialState(ChainedStateHandle<StreamStateHandle> chainedState, List<KeyGroupsStateHandle> keyGroupsState) {
+	public void setInitialState(
+		ChainedStateHandle<StreamStateHandle> chainedState,
+		List<KeyGroupsStateHandle> keyGroupsState,
+		List<Collection<OperatorStateHandle>> partitionableOperatorState) {
+
 		lazyRestoreChainedOperatorState = chainedState;
 		lazyRestoreKeyGroupStates = keyGroupsState;
+		lazyRestoreOperatorState = partitionableOperatorState;
 	}
 
 	private void restoreState() throws Exception {
 		final StreamOperator<?>[] allOperators = operatorChain.getAllOperators();
 
-		try {
-			if (lazyRestoreChainedOperatorState != null) {
+		if (lazyRestoreChainedOperatorState != null) {
+			Preconditions.checkState(lazyRestoreChainedOperatorState.getLength() == allOperators.length,
+					"Invalid number of operator states. Found: " + lazyRestoreChainedOperatorState.getLength() +
+							". Expected: " + allOperators.length);
+		}
 
-				synchronized (cancelables) {
-					cancelables.add(lazyRestoreChainedOperatorState);
-				}
+		if (lazyRestoreOperatorState != null) {
+			Preconditions.checkArgument(lazyRestoreOperatorState.isEmpty()
+							|| lazyRestoreOperatorState.size() == allOperators.length,
+					"Invalid number of operator states. Found: " + lazyRestoreOperatorState.size() +
+							". Expected: " + allOperators.length);
+		}
 
-				for (int i = 0; i < lazyRestoreChainedOperatorState.getLength(); i++) {
+		for (int i = 0; i < allOperators.length; i++) {
+			StreamOperator<?> operator = allOperators[i];
+
+			if (null != lazyRestoreOperatorState && !lazyRestoreOperatorState.isEmpty()) {
+				operator.restoreState(lazyRestoreOperatorState.get(i));
+			}
+
+			// TODO deprecated code path
+			if (operator instanceof StreamCheckpointedOperator) {
+
+				if (lazyRestoreChainedOperatorState != null) {
 					StreamStateHandle state = lazyRestoreChainedOperatorState.get(i);
-					if (state == null) {
-						continue;
-					}
-					StreamOperator<?> operator = allOperators[i];
 
-					if (operator != null) {
+					if (state != null) {
 						LOG.debug("Restore state of task {} in chain ({}).", i, getName());
-						try (FSDataInputStream inputStream = state.openInputStream()) {
-							operator.restoreState(inputStream);
+
+						FSDataInputStream is = state.openInputStream();
+						try {
+							cancelables.registerClosable(is);
+							((StreamCheckpointedOperator) operator).restoreState(is);
+						} finally {
+							cancelables.unregisterClosable(is);
+							is.close();
 						}
 					}
 				}
 			}
-		} finally {
-			synchronized (cancelables) {
-				cancelables.remove(lazyRestoreChainedOperatorState);
-			}
 		}
 	}
 
@@ -629,29 +627,58 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 				// Given this, we immediately emit the checkpoint barriers, so the downstream operators
 				// can start their checkpoint work as soon as possible
 				operatorChain.broadcastCheckpointBarrier(checkpointId, timestamp);
-				
+
 				// now draw the state snapshot
 				final StreamOperator<?>[] allOperators = operatorChain.getAllOperators();
-				final List<StreamStateHandle> nonPartitionedStates = Arrays.asList(new StreamStateHandle[allOperators.length]);
+
+				final List<StreamStateHandle> nonPartitionedStates =
+						Arrays.asList(new StreamStateHandle[allOperators.length]);
+
+				final List<OperatorStateHandle> operatorStates =
+						Arrays.asList(new OperatorStateHandle[allOperators.length]);
 
 				for (int i = 0; i < allOperators.length; i++) {
 					StreamOperator<?> operator = allOperators[i];
 
 					if (operator != null) {
+
+						final String operatorId = createOperatorIdentifier(operator, configuration.getVertexID());
+
 						CheckpointStreamFactory streamFactory =
-								stateBackend.createStreamFactory(
-										getEnvironment().getJobID(),
-										createOperatorIdentifier(
-												operator,
-												configuration.getVertexID()));
+								stateBackend.createStreamFactory(getEnvironment().getJobID(), operatorId);
+
+						//TODO deprecated code path
+						if (operator instanceof StreamCheckpointedOperator) {
+
+							CheckpointStreamFactory.CheckpointStateOutputStream outStream =
+									streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);
+
+
+							cancelables.registerClosable(outStream);
+
+							try {
+								((StreamCheckpointedOperator) operator).
+										snapshotState(outStream, checkpointId, timestamp);
+
+								nonPartitionedStates.set(i, outStream.closeAndGetHandle());
+							} finally {
+								cancelables.unregisterClosable(outStream);
+							}
+						}
 
-						CheckpointStreamFactory.CheckpointStateOutputStream outStream =
-								streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);
+						RunnableFuture<OperatorStateHandle> handleFuture =
+								operator.snapshotState(checkpointId, timestamp, streamFactory);
 
-						operator.snapshotState(outStream, checkpointId, timestamp);
+						if (null != handleFuture) {
+							//TODO for now we assume there are only synchronous snapshots, no need to start the runnable.
+							if (!handleFuture.isDone()) {
+								throw new IllegalStateException("Currently only supports synchronous snapshots!");
+							}
 
-						nonPartitionedStates.set(i, outStream.closeAndGetHandle());
+							operatorStates.set(i, handleFuture.get());
+						}
 					}
+
 				}
 
 				RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture = null;
@@ -659,16 +686,16 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 				if (keyedStateBackend != null) {
 					CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory(
 							getEnvironment().getJobID(),
-							createOperatorIdentifier(
-									headOperator,
-									configuration.getVertexID()));
-					keyGroupsStateHandleFuture = keyedStateBackend.snapshot(
-							checkpointId,
-							timestamp,
-							streamFactory);
+							createOperatorIdentifier(headOperator, configuration.getVertexID()));
+
+					keyGroupsStateHandleFuture = keyedStateBackend.snapshot(checkpointId, timestamp, streamFactory);
 				}
 
-				ChainedStateHandle<StreamStateHandle> chainedStateHandles = new ChainedStateHandle<>(nonPartitionedStates);
+				ChainedStateHandle<StreamStateHandle> chainedNonPartitionedStateHandles =
+						new ChainedStateHandle<>(nonPartitionedStates);
+
+				ChainedStateHandle<OperatorStateHandle> chainedPartitionedStateHandles =
+						new ChainedStateHandle<>(operatorStates);
 
 				LOG.debug("Finished synchronous checkpoints for checkpoint {} on task {}", checkpointId, getName());
 
@@ -679,7 +706,8 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 						"checkpoint-" + checkpointId + "-" + timestamp,
 						this,
 						cancelables,
-						chainedStateHandles,
+						chainedNonPartitionedStateHandles,
+						chainedPartitionedStateHandles,
 						keyGroupsStateHandleFuture,
 						checkpointId,
 						bytesBufferedAlignment,
@@ -687,9 +715,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 						syncDurationMillis,
 						endOfSyncPart);
 
-				synchronized (cancelables) {
-					cancelables.add(asyncCheckpointRunnable);
-				}
+				cancelables.registerClosable(asyncCheckpointRunnable);
 				asyncOperationsThreadPool.submit(asyncCheckpointRunnable);
 				return true;
 			} else {
@@ -707,7 +733,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 		synchronized (lock) {
 			if (isRunning) {
 				LOG.debug("Notification of complete checkpoint for task {}", getName());
-				
+
 				for (StreamOperator<?> operator : operatorChain.getAllOperators()) {
 					if (operator != null) {
 						operator.notifyOfCompletedCheckpoint(checkpointId);
@@ -760,7 +786,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 						Class<? extends StateBackendFactory> clazz =
 								Class.forName(backendName, false, getUserCodeClassLoader()).asSubclass(StateBackendFactory.class);
 
-						stateBackend = ((StateBackendFactory<?>) clazz.newInstance()).createFromConfig(flinkConfig);
+						stateBackend = clazz.newInstance().createFromConfig(flinkConfig);
 					} catch (ClassNotFoundException e) {
 						throw new IllegalConfigurationException("Cannot find configured state backend: " + backendName);
 					} catch (ClassCastException e) {
@@ -772,10 +798,26 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 					}
 			}
 		}
+
 		return stateBackend;
 	}
 
-	public <K> KeyedStateBackend<K> createKeyedStateBackend(
+	public OperatorStateBackend createOperatorStateBackend(
+			StreamOperator<?> op, Collection<OperatorStateHandle> restoreStateHandles) throws Exception {
+
+		Environment env = getEnvironment();
+		String opId = createOperatorIdentifier(op, configuration.getVertexID());
+
+		OperatorStateBackend newBackend = restoreStateHandles == null ?
+				stateBackend.createOperatorStateBackend(env, opId)
+				: stateBackend.restoreOperatorStateBackend(env, opId, restoreStateHandles);
+
+		cancelables.registerClosable(newBackend);
+
+		return newBackend;
+	}
+
+	public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange) throws Exception {
@@ -811,8 +853,10 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 					getEnvironment().getTaskKvStateRegistry());
 		}
 
+		cancelables.registerClosable(keyedStateBackend);
+
 		@SuppressWarnings("unchecked")
-		KeyedStateBackend<K> typedBackend = (KeyedStateBackend<K>) keyedStateBackend;
+		AbstractKeyedStateBackend<K> typedBackend = (AbstractKeyedStateBackend<K>) keyedStateBackend;
 		return typedBackend;
 	}
 
@@ -825,9 +869,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	public CheckpointStreamFactory createCheckpointStreamFactory(StreamOperator<?> operator) throws IOException {
 		return stateBackend.createStreamFactory(
 				getEnvironment().getJobID(),
-				createOperatorIdentifier(
-						operator,
-						configuration.getVertexID()));
+				createOperatorIdentifier(operator, configuration.getVertexID()));
 
 	}
 
@@ -867,7 +909,6 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 		if (isRunning) {
 			LOG.error("Asynchronous exception registered.", exception);
 		}
-
 		if (this.asyncException == null) {
 			this.asyncException = exception;
 		}
@@ -877,20 +918,23 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	//  Utilities
 	// ------------------------------------------------------------------------
 
+
 	@Override
 	public String toString() {
 		return getName();
 	}
 
 	// ------------------------------------------------------------------------
-	
+
 	private static class AsyncCheckpointRunnable implements Runnable, Closeable {
 
 		private final StreamTask<?, ?> owner;
 
-		private final Set<Closeable> cancelables;
+		private final ClosableRegistry cancelables;
+
+		private final ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles;
 
-		private final ChainedStateHandle<StreamStateHandle> chainedStateHandles;
+		private final ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles;
 
 		private final RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture;
 
@@ -909,8 +953,9 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 		AsyncCheckpointRunnable(
 				String name,
 				StreamTask<?, ?> owner,
-				Set<Closeable> cancelables,
-				ChainedStateHandle<StreamStateHandle> chainedStateHandles,
+				ClosableRegistry cancelables,
+				ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles,
+				ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles,
 				RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture,
 				long checkpointId,
 				long bytesBufferedInAlignment,
@@ -921,7 +966,8 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 			this.name = name;
 			this.owner = owner;
 			this.cancelables = cancelables;
-			this.chainedStateHandles = chainedStateHandles;
+			this.nonPartitionedStateHandles = nonPartitionedStateHandles;
+			this.partitioneableStateHandles = partitioneableStateHandles;
 			this.keyGroupsStateHandleFuture = keyGroupsStateHandleFuture;
 			this.checkpointId = checkpointId;
 			this.bytesBufferedInAlignment = bytesBufferedInAlignment;
@@ -952,13 +998,19 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 				final long asyncEndNanos = System.nanoTime();
 				final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000;
 
-				if (chainedStateHandles.isEmpty() && keyedStates.isEmpty()) {
+				if (nonPartitionedStateHandles.isEmpty() && keyedStates.isEmpty()) {
 					owner.getEnvironment().acknowledgeCheckpoint(checkpointId,
 							syncDurationMillies, asyncDurationMillis,
 							bytesBufferedInAlignment, alignmentDurationNanos);
 				} else  {
+
+					CheckpointStateHandles allStateHandles = new CheckpointStateHandles(
+							nonPartitionedStateHandles,
+							partitioneableStateHandles,
+							keyedStates);
+
 					owner.getEnvironment().acknowledgeCheckpoint(checkpointId,
-							chainedStateHandles, keyedStates,
+							allStateHandles,
 							syncDurationMillies, asyncDurationMillis,
 							bytesBufferedInAlignment, alignmentDurationNanos);
 				}
@@ -974,9 +1026,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 				owner.registerAsyncException(asyncException);
 			}
 			finally {
-				synchronized (cancelables) {
-					cancelables.remove(this);
-				}
+				cancelables.unregisterClosable(this);
 			}
 		}
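
The rewritten snapshot loop keeps the deprecated stream-based path for operators
implementing StreamCheckpointedOperator, collects the new per-operator
OperatorStateHandle futures (synchronous only, for now), and finally ships all
three kinds of state in a single CheckpointStateHandles container. A condensed
sketch of that flow, with stream registration and error handling elided:

    for (int i = 0; i < allOperators.length; i++) {
        StreamOperator<?> operator = allOperators[i];
        if (operator == null) {
            continue;
        }
        RunnableFuture<OperatorStateHandle> future =
                operator.snapshotState(checkpointId, timestamp, streamFactory);
        if (future != null) {
            // only synchronous snapshots are supported at this point
            operatorStates.set(i, future.get());
        }
    }

    // non-partitioned, partitionable and keyed state travel together
    CheckpointStateHandles allStateHandles = new CheckpointStateHandles(
            nonPartitionedStateHandles, partitioneableStateHandles, keyedStates);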
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
index fe09788..02409a3 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
@@ -36,8 +36,8 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
@@ -188,15 +188,15 @@ public class StreamingRuntimeContextTest {
 					public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
 						ListStateDescriptor<String> descr =
 								(ListStateDescriptor<String>) invocationOnMock.getArguments()[0];
-						KeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend(
+
+						AbstractKeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend(
 								new DummyEnvironment("test_task", 1, 0),
 								new JobID(),
 								"test_op",
 								IntSerializer.INSTANCE,
 								1,
 								new KeyGroupRange(0, 0),
-								new KvStateRegistry().createTaskRegistry(new JobID(),
-										new JobVertexID()));
+								new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
 						backend.setCurrentKey(0);
 						return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr);
 					}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
index b549ef8..5d68841 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
@@ -21,6 +21,7 @@ package org.apache.flink.streaming.runtime.io;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
@@ -28,15 +29,15 @@ import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
-
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
 
 import java.io.File;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
@@ -974,7 +975,8 @@ public class BarrierBufferTest {
 		@Override
 		public void setInitialState(
 				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState) throws Exception {
+				List<KeyGroupsStateHandle> keyGroupsState,
+				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
 			throw new UnsupportedOperationException("should never be called");
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
index 314dcc4..f2f9092 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
@@ -19,21 +19,25 @@
 package org.apache.flink.streaming.runtime.io;
 
 import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
-
 import org.junit.Test;
 
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 /**
  * Tests for the behavior of the barrier tracker.
@@ -363,7 +367,8 @@ public class BarrierTrackerTest {
 		@Override
 		public void setInitialState(
 				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState) throws Exception {
+				List<KeyGroupsStateHandle> keyGroupsState,
+				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception {
 
 			throw new UnsupportedOperationException("should never be called");
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
index f4ac5b2..32e8ea9 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
@@ -19,7 +19,6 @@ package org.apache.flink.streaming.runtime.operators;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.MapFunction;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.execution.Environment;
@@ -27,8 +26,6 @@ import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.runtime.state.AbstractStateBackend;
-import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SplitStream;
@@ -42,19 +39,15 @@ import org.apache.flink.streaming.runtime.tasks.OperatorChain;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.junit.Assert;
 import org.junit.Test;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
 
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.contains;
-import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
-import static org.hamcrest.MatcherAssert.assertThat;
 
 /**
  * Tests for stream operator chaining behaviour.
@@ -156,9 +149,8 @@ public class StreamOperatorChainingTest {
 		StreamTask<Integer, StreamMap<Integer, Integer>> mockTask =
 				createMockTask(streamConfig, chainedVertex.getName());
 
-		OperatorChain<Integer> operatorChain = new OperatorChain<>(
+		OperatorChain<Integer, StreamMap<Integer, Integer>> operatorChain = new OperatorChain<>(
 				mockTask,
-				headOperator,
 				mock(AccumulatorRegistry.Reporter.class));
 
 		headOperator.setup(mockTask, streamConfig, operatorChain.getChainEntryPoint());
@@ -299,9 +291,8 @@ public class StreamOperatorChainingTest {
 		StreamTask<Integer, StreamMap<Integer, Integer>> mockTask =
 				createMockTask(streamConfig, chainedVertex.getName());
 
-		OperatorChain<Integer> operatorChain = new OperatorChain<>(
+		OperatorChain<Integer, StreamMap<Integer, Integer>> operatorChain = new OperatorChain<>(
 				mockTask,
-				headOperator,
 				mock(AccumulatorRegistry.Reporter.class));
 
 		headOperator.setup(mockTask, streamConfig, operatorChain.getChainEntryPoint());

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index 6a7b024..b5b6582 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -41,9 +41,9 @@ import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.AbstractCloseableHandle;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
@@ -56,19 +56,23 @@ import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.util.SerializedValue;
-
 import org.junit.Test;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.Serializable;
 import java.net.URL;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.Executor;
 
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 /**
  * This test checks that task restores that get stuck in the presence of interrupts
@@ -121,6 +125,7 @@ public class InterruptSensitiveRestoreTest {
 
 		ChainedStateHandle<StreamStateHandle> operatorState = new ChainedStateHandle<>(Collections.singletonList(state));
 		List<KeyGroupsStateHandle> keyGroupState = Collections.emptyList();
+		List<Collection<OperatorStateHandle>> partitionableOperatorState = Collections.emptyList();
 
 		return new TaskDeploymentDescriptor(
 				new JobID(),
@@ -139,42 +144,47 @@ public class InterruptSensitiveRestoreTest {
 				Collections.<URL>emptyList(),
 				0,
 				operatorState,
-				keyGroupState);
+				keyGroupState,
+				partitionableOperatorState);
 	}
-	
+
 	private static Task createTask(TaskDeploymentDescriptor tdd) throws IOException {
 		NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class);
 		when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class)))
 				.thenReturn(mock(TaskKvStateRegistry.class));
 
 		return new Task(
-			tdd,
-			mock(MemoryManager.class),
-			mock(IOManager.class),
-			networkEnvironment,
-			mock(BroadcastVariableManager.class),
+				tdd,
+				mock(MemoryManager.class),
+				mock(IOManager.class),
+				networkEnvironment,
+				mock(BroadcastVariableManager.class),
 				mock(TaskManagerConnection.class),
 				mock(InputSplitProvider.class),
 				mock(CheckpointResponder.class),
-			new FallbackLibraryCacheManager(),
-			new FileCache(new Configuration()),
-			new TaskManagerRuntimeInfo(
-					"localhost", new Configuration(), EnvironmentInformation.getTemporaryFileDirectory()),
-			new UnregisteredTaskMetricsGroup(),
-			mock(ResultPartitionConsumableNotifier.class),
-			mock(PartitionStateChecker.class),
-			mock(Executor.class));
-		
+				new FallbackLibraryCacheManager(),
+				new FileCache(new Configuration()),
+				new TaskManagerRuntimeInfo(
+						"localhost", new Configuration(), EnvironmentInformation.getTemporaryFileDirectory()),
+				new UnregisteredTaskMetricsGroup(),
+				mock(ResultPartitionConsumableNotifier.class),
+				mock(PartitionStateChecker.class),
+				mock(Executor.class));
+
 	}
 
 	// ------------------------------------------------------------------------
 
 	@SuppressWarnings("serial")
-	private static class InterruptLockingStateHandle extends AbstractCloseableHandle implements StreamStateHandle {
+	private static class InterruptLockingStateHandle implements StreamStateHandle {
+
+		private volatile boolean closed;
 
 		@Override
 		public FSDataInputStream openInputStream() throws IOException {
-			ensureNotClosed();
+
+			closed = false;
+
 			FSDataInputStream is = new FSDataInputStream() {
 
 				@Override
@@ -191,8 +201,14 @@ public class InterruptSensitiveRestoreTest {
 					block();
 					throw new EOFException();
 				}
+
+				@Override
+				public void close() throws IOException {
+					super.close();
+					closed = true;
+				}
 			};
-			registerCloseable(is);
+
 			return is;
 		}
 
@@ -207,7 +223,7 @@ public class InterruptSensitiveRestoreTest {
 				}
 			}
 			catch (InterruptedException e) {
-				while (!isClosed()) {
+				while (!closed) {
 					try {
 						synchronized (this) {
 							wait();
@@ -227,7 +243,7 @@ public class InterruptSensitiveRestoreTest {
 	}
 
 	// ------------------------------------------------------------------------
-	
+
 	private static class TestSource implements SourceFunction<Object>, Checkpointed<Serializable> {
 		private static final long serialVersionUID = 1L;
 
@@ -250,4 +266,4 @@ public class InterruptSensitiveRestoreTest {
 			fail("should never be called");
 		}
 	}
-}
+}
\ No newline at end of file
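
The test no longer leans on AbstractCloseableHandle; the handle tracks its own
closed flag, and the task side registers streams with its ClosableRegistry so
that cancel() can close them and unblock a stuck restore. A trimmed sketch of
that guard, mirroring the restoreState() change in StreamTask above:

    FSDataInputStream in = state.openInputStream();
    try {
        cancelables.registerClosable(in);    // cancel() can now close it and interrupt the read
        ((StreamCheckpointedOperator) operator).restoreState(in);
    } finally {
        cancelables.unregisterClosable(in);
        in.close();
    }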

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
index 88fb383..4003e59 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
@@ -21,7 +21,10 @@ package org.apache.flink.streaming.runtime.tasks;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.FSDataInputStream;
@@ -31,8 +34,12 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.graph.StreamNode;
@@ -56,17 +63,23 @@ import scala.concurrent.duration.FiniteDuration;
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
+import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.TimeUnit;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 /**
  * Tests for {@link OneInputStreamTask}.
@@ -82,6 +95,9 @@ import static org.junit.Assert.assertTrue;
 @PowerMockIgnore({"javax.management.*", "com.sun.jndi.*"})
 public class OneInputStreamTaskTest extends TestLogger {
 
+	private static final ListStateDescriptor<Integer> TEST_DESCRIPTOR =
+			new ListStateDescriptor<>("test", new IntSerializer());
+
 	/**
 	 * This test verifies that open() and close() are correctly called. This test also verifies
 	 * that timestamps of emitted elements are correct. {@link StreamMap} assigns the input
@@ -358,7 +374,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		testHarness.invoke(env);
 		testHarness.waitForTaskRunning(deadline.timeLeft().toMillis());
 
-		streamTask.triggerCheckpoint(checkpointId, checkpointTimestamp);
+		while (!streamTask.triggerCheckpoint(checkpointId, checkpointTimestamp));
 
 		// since no state was set, there shouldn't be restore calls
 		assertEquals(0, TestingStreamOperator.numberRestoreCalls);
@@ -371,7 +387,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());
 
 		final OneInputStreamTask<String, String> restoredTask = new OneInputStreamTask<String, String>();
-		restoredTask.setInitialState(env.getState(), env.getKeyGroupStates());
+		restoredTask.setInitialState(env.getState(), env.getKeyGroupStates(), env.getPartitionableOperatorState());
 
 		final OneInputStreamTaskTestHarness<String, String> restoredTaskHarness = new OneInputStreamTaskTestHarness<String, String>(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
 		restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO);
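
For orientation, the restore call above now hands three kinds of state back to the task; annotated with the types declared on the mock environment further down, it reads roughly as:

	restoredTask.setInitialState(
			env.getState(),                       // ChainedStateHandle<StreamStateHandle>, legacy non-partitioned state
			env.getKeyGroupStates(),              // List<KeyGroupsStateHandle>, keyed state grouped by key group
			env.getPartitionableOperatorState()); // List<Collection<OperatorStateHandle>>, the new rescalable operator state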
@@ -465,6 +481,7 @@ public class OneInputStreamTaskTest extends TestLogger {
 		private volatile long checkpointId;
 		private volatile ChainedStateHandle<StreamStateHandle> state;
 		private volatile List<KeyGroupsStateHandle> keyGroupStates;
+		private volatile List<Collection<OperatorStateHandle>> partitionableOperatorState;
 
 		private final OneShotLatch checkpointLatch = new OneShotLatch();
 
@@ -486,6 +503,10 @@ public class OneInputStreamTaskTest extends TestLogger {
 			return result;
 		}
 
+		List<Collection<OperatorStateHandle>> getPartitionableOperatorState() {
+			return partitionableOperatorState;
+		}
+
 		AcknowledgeStreamMockEnvironment(
 				Configuration jobConfig, Configuration taskConfig,
 				ExecutionConfig executionConfig, long memorySize,
@@ -497,13 +518,21 @@ public class OneInputStreamTaskTest extends TestLogger {
 		@Override
 		public void acknowledgeCheckpoint(
 				long checkpointId,
-				ChainedStateHandle<StreamStateHandle> state, List<KeyGroupsStateHandle> keyGroupStates,
+				CheckpointStateHandles checkpointStateHandles,
				long syncDuration, long asyncDuration, long alignmentByte, long alignmentDuration) {
 
 			this.checkpointId = checkpointId;
-			this.state = state;
-			this.keyGroupStates = keyGroupStates;
-
+			if (checkpointStateHandles != null) {
+				this.state = checkpointStateHandles.getNonPartitionedStateHandles();
+				this.keyGroupStates = checkpointStateHandles.getKeyGroupsStateHandle();
+				ChainedStateHandle<OperatorStateHandle> chainedStateHandle = checkpointStateHandles.getPartitioneableStateHandles();
+				Collection<OperatorStateHandle>[] handlesPerOperator = new Collection[chainedStateHandle.getLength()];
+				this.partitionableOperatorState = Arrays.asList(handlesPerOperator);
+
+				for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
+					partitionableOperatorState.set(i, Collections.singletonList(chainedStateHandle.get(i)));
+				}
+			}
 			checkpointLatch.trigger();
 		}
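
One remark on the handle unpacking in acknowledgeCheckpoint() above: Java cannot instantiate generic arrays, so new Collection[...] compiles only as a raw array with an unchecked warning. An equivalent sketch that avoids the raw array, using only accessors that appear in this patch (java.util.List and java.util.ArrayList are already imported by this test, as the import hunk further up shows):

	ChainedStateHandle<OperatorStateHandle> chained = checkpointStateHandles.getPartitioneableStateHandles();
	List<Collection<OperatorStateHandle>> perOperator = new ArrayList<>(chained.getLength());
	for (int i = 0; i < chained.getLength(); ++i) {
		perOperator.add(Collections.singletonList(chained.get(i)));
	}
	this.partitionableOperatorState = perOperator;

Either version ends up with one singleton collection of operator state handles per operator in the chain.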
 
@@ -513,17 +542,56 @@ public class OneInputStreamTaskTest extends TestLogger {
 	}
 
 	private static class TestingStreamOperator<IN, OUT>
-			extends AbstractStreamOperator<OUT> implements OneInputStreamOperator<IN, OUT> {
+			extends AbstractStreamOperator<OUT>
+			implements OneInputStreamOperator<IN, OUT>, StreamCheckpointedOperator {
 
 		private static final long serialVersionUID = 774614855940397174L;
 
 		public static int numberRestoreCalls = 0;
+		public static int numberSnapshotCalls = 0;
 
 		private final long seed;
 		private final long recoveryTimestamp;
 
 		private transient Random random;
 
+		@Override
+		public void open() throws Exception {
+			super.open();
+
+			ListState<Integer> partitionableState = getOperatorStateBackend().getPartitionableState(TEST_DESCRIPTOR);
+
+			if (numberSnapshotCalls == 0) {
+				for (Integer v : partitionableState.get()) {
+					fail("Unexpected restored state on first snapshot: " + v);
+				}
+			} else {
+				Set<Integer> result = new HashSet<>();
+				for (Integer v : partitionableState.get()) {
+					result.add(v);
+				}
+
+				assertEquals(2, result.size());
+				assertTrue(result.contains(42));
+				assertTrue(result.contains(4711));
+			}
+		}
+
+		@Override
+		public RunnableFuture<OperatorStateHandle> snapshotState(
+				long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
+
+			ListState<Integer> partitionableState =
+					getOperatorStateBackend().getPartitionableState(TEST_DESCRIPTOR);
+			partitionableState.clear();
+
+			partitionableState.add(42);
+			partitionableState.add(4711);
+
+			++numberSnapshotCalls;
+			return super.snapshotState(checkpointId, timestamp, streamFactory);
+		}
+
 		TestingStreamOperator(long seed, long recoveryTimestamp) {
 			this.seed = seed;
 			this.recoveryTimestamp = recoveryTimestamp;