Posted to commits@flink.apache.org by se...@apache.org on 2016/09/30 12:47:56 UTC

[06/10] flink git commit: [FLINK-4379] [checkpoints] Introduce rescalable operator state

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
index 5612f73..7293a84 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
@@ -7,7 +7,7 @@
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
- *     http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -18,178 +18,55 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.api.common.functions.ReduceFunction;
-import org.apache.flink.api.common.state.FoldingState;
-import org.apache.flink.api.common.state.FoldingStateDescriptor;
-import org.apache.flink.api.common.state.ListState;
-import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MergingState;
-import org.apache.flink.api.common.state.ReducingState;
-import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
-import org.apache.flink.api.common.state.StateBackend;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.util.Preconditions;
 
-import java.util.ArrayList;
 import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.concurrent.RunnableFuture;
 
 /**
- * A keyed state backend is responsible for managing keyed state. The state can be checkpointed
- * to streams using {@link #snapshot(long, long, CheckpointStreamFactory)}.
+ * A keyed state backend provides methods for managing keyed state.
  *
  * @param <K> The key by which state is keyed.
  */
-public abstract class KeyedStateBackend<K> {
-
-	/** {@link TypeSerializer} for our key. */
-	protected final TypeSerializer<K> keySerializer;
-
-	/** The currently active key. */
-	protected K currentKey;
-
-	/** The key group of the currently active key */
-	private int currentKeyGroup;
-
-	/** So that we can give out state when the user uses the same key. */
-	protected HashMap<String, KvState<?>> keyValueStatesByName;
-
-	/** For caching the last accessed partitioned state */
-	private String lastName;
-
-	@SuppressWarnings("rawtypes")
-	private KvState lastState;
-
-	/** The number of key-groups aka max parallelism */
-	protected final int numberOfKeyGroups;
-
-	/** Range of key-groups for which this backend is responsible */
-	protected final KeyGroupRange keyGroupRange;
-
-	/** KvStateRegistry helper for this task */
-	protected final TaskKvStateRegistry kvStateRegistry;
-
-	protected final ClassLoader userCodeClassLoader;
-
-	public KeyedStateBackend(
-			TaskKvStateRegistry kvStateRegistry,
-			TypeSerializer<K> keySerializer,
-			ClassLoader userCodeClassLoader,
-			int numberOfKeyGroups,
-			KeyGroupRange keyGroupRange) {
-
-		this.kvStateRegistry = Preconditions.checkNotNull(kvStateRegistry);
-		this.keySerializer = Preconditions.checkNotNull(keySerializer);
-		this.userCodeClassLoader = Preconditions.checkNotNull(userCodeClassLoader);
-		this.numberOfKeyGroups = Preconditions.checkNotNull(numberOfKeyGroups);
-		this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange);
-	}
+public interface KeyedStateBackend<K> {
 
 	/**
-	 * Closes the state backend, releasing all internal resources, but does not delete any persistent
-	 * checkpoint data.
-	 *
-	 * @throws Exception Exceptions can be forwarded and will be logged by the system
+	 * Sets the current key that is used for partitioned state.
+	 * @param newKey The new current key.
 	 */
-	public void close() throws Exception {
-		if (kvStateRegistry != null) {
-			kvStateRegistry.unregisterAll();
-		}
-
-		lastName = null;
-		lastState = null;
-		keyValueStatesByName = null;
-	}
+	void setCurrentKey(K newKey);
 
 	/**
-	 * Creates and returns a new {@link ValueState}.
-	 *
-	 * @param namespaceSerializer TypeSerializer for the state namespace.
-	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
-	 *
-	 * @param <N> The type of the namespace.
-	 * @param <T> The type of the value that the {@code ValueState} can store.
+	 * Used by states to access the current key.
 	 */
-	protected abstract <N, T> ValueState<T> createValueState(TypeSerializer<N> namespaceSerializer, ValueStateDescriptor<T> stateDesc) throws Exception;
+	K getCurrentKey();
 
 	/**
-	 * Creates and returns a new {@link ListState}.
-	 *
-	 * @param namespaceSerializer TypeSerializer for the state namespace.
-	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
-	 *
-	 * @param <N> The type of the namespace.
-	 * @param <T> The type of the values that the {@code ListState} can store.
+	 * Returns the key-group to which the current key belongs.
 	 */
-	protected abstract <N, T> ListState<T> createListState(TypeSerializer<N> namespaceSerializer, ListStateDescriptor<T> stateDesc) throws Exception;
+	int getCurrentKeyGroupIndex();
 
 	/**
-	 * Creates and returns a new {@link ReducingState}.
-	 *
-	 * @param namespaceSerializer TypeSerializer for the state namespace.
-	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
-	 *
-	 * @param <N> The type of the namespace.
-	 * @param <T> The type of the values that the {@code ListState} can store.
+	 * Returns the number of key-groups aka max parallelism.
 	 */
-	protected abstract <N, T> ReducingState<T> createReducingState(TypeSerializer<N> namespaceSerializer, ReducingStateDescriptor<T> stateDesc) throws Exception;
+	int getNumberOfKeyGroups();
 
 	/**
-	 * Creates and returns a new {@link FoldingState}.
-	 *
-	 * @param namespaceSerializer TypeSerializer for the state namespace.
-	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
-	 *
-	 * @param <N> The type of the namespace.
-	 * @param <T> Type of the values folded into the state
-	 * @param <ACC> Type of the value in the state	 *
+	 * Returns the key group range for this backend.
 	 */
-	protected abstract <N, T, ACC> FoldingState<T, ACC> createFoldingState(TypeSerializer<N> namespaceSerializer, FoldingStateDescriptor<T, ACC> stateDesc) throws Exception;
-
-	/**
-	 * Sets the current key that is used for partitioned state.
-	 * @param newKey The new current key.
-	 */
-	public void setCurrentKey(K newKey) {
-		this.currentKey = newKey;
-		this.currentKeyGroup = KeyGroupRangeAssignment.assignToKeyGroup(newKey, numberOfKeyGroups);
-	}
+	KeyGroupRange getKeyGroupRange();
 
 	/**
 	 * {@link TypeSerializer} for the state backend key type.
 	 */
-	public TypeSerializer<K> getKeySerializer() {
-		return keySerializer;
-	}
-
-	/**
-	 * Used by states to access the current key.
-	 */
-	public K getCurrentKey() {
-		return currentKey;
-	}
-
-	public int getCurrentKeyGroupIndex() {
-		return currentKeyGroup;
-	}
-
-	public int getNumberOfKeyGroups() {
-		return numberOfKeyGroups;
-	}
+	TypeSerializer<K> getKeySerializer();
 
 	/**
 	 * Creates or retrieves a partitioned state backed by this state backend.
 	 *
-	 * @param stateDescriptor The state identifier for the state. This contains name
-	 *                           and can create a default state value.
+	 * @param stateDescriptor The identifier for the state. This contains the name and can create a default state value.
 
 	 * @param <N> The type of the namespace.
 	 * @param <S> The type of the state.
@@ -199,145 +76,21 @@ public abstract class KeyedStateBackend<K> {
 	 * @throws Exception Exceptions may occur during initialization of the state and should be forwarded.
 	 */
 	@SuppressWarnings({"rawtypes", "unchecked"})
-	public <N, S extends State> S getPartitionedState(final N namespace, final TypeSerializer<N> namespaceSerializer, final StateDescriptor<S, ?> stateDescriptor) throws Exception {
-		Preconditions.checkNotNull(namespace, "Namespace");
-		Preconditions.checkNotNull(namespaceSerializer, "Namespace serializer");
-
-		if (keySerializer == null) {
-			throw new RuntimeException("State key serializer has not been configured in the config. " +
-					"This operation cannot use partitioned state.");
-		}
-		
-		if (!stateDescriptor.isSerializerInitialized()) {
-			stateDescriptor.initializeSerializerUnlessSet(new ExecutionConfig());
-		}
-
-		if (keyValueStatesByName == null) {
-			keyValueStatesByName = new HashMap<>();
-		}
-
-		if (lastName != null && lastName.equals(stateDescriptor.getName())) {
-			lastState.setCurrentNamespace(namespace);
-			return (S) lastState;
-		}
-
-		KvState<?> previous = keyValueStatesByName.get(stateDescriptor.getName());
-		if (previous != null) {
-			lastState = previous;
-			lastState.setCurrentNamespace(namespace);
-			lastName = stateDescriptor.getName();
-			return (S) previous;
-		}
-
-		// create a new blank key/value state
-		S state = stateDescriptor.bind(new StateBackend() {
-			@Override
-			public <T> ValueState<T> createValueState(ValueStateDescriptor<T> stateDesc) throws Exception {
-				return KeyedStateBackend.this.createValueState(namespaceSerializer, stateDesc);
-			}
+	<N, S extends State> S getPartitionedState(
+			N namespace,
+			TypeSerializer<N> namespaceSerializer,
+			StateDescriptor<S, ?> stateDescriptor) throws Exception;
 
-			@Override
-			public <T> ListState<T> createListState(ListStateDescriptor<T> stateDesc) throws Exception {
-				return KeyedStateBackend.this.createListState(namespaceSerializer, stateDesc);
-			}
-
-			@Override
-			public <T> ReducingState<T> createReducingState(ReducingStateDescriptor<T> stateDesc) throws Exception {
-				return KeyedStateBackend.this.createReducingState(namespaceSerializer, stateDesc);
-			}
-
-			@Override
-			public <T, ACC> FoldingState<T, ACC> createFoldingState(FoldingStateDescriptor<T, ACC> stateDesc) throws Exception {
-				return KeyedStateBackend.this.createFoldingState(namespaceSerializer, stateDesc);
-			}
-
-		});
-
-		KvState kvState = (KvState) state;
-
-		keyValueStatesByName.put(stateDescriptor.getName(), kvState);
-
-		lastName = stateDescriptor.getName();
-		lastState = kvState;
-
-		kvState.setCurrentNamespace(namespace);
-
-		// Publish queryable state
-		if (stateDescriptor.isQueryable()) {
-			if (kvStateRegistry == null) {
-				throw new IllegalStateException("State backend has not been initialized for job.");
-			}
-
-			String name = stateDescriptor.getQueryableStateName();
-			kvStateRegistry.registerKvState(keyGroupRange, name, kvState);
-		}
-
-		return state;
-	}
 
 	@SuppressWarnings("unchecked,rawtypes")
-	public <N, S extends MergingState<?, ?>> void mergePartitionedStates(final N target, Collection<N> sources, final TypeSerializer<N> namespaceSerializer, final StateDescriptor<S, ?> stateDescriptor) throws Exception {
-		if (stateDescriptor instanceof ReducingStateDescriptor) {
-			ReducingStateDescriptor reducingStateDescriptor = (ReducingStateDescriptor) stateDescriptor;
-			ReduceFunction reduceFn = reducingStateDescriptor.getReduceFunction();
-			ReducingState state = (ReducingState) getPartitionedState(target, namespaceSerializer, stateDescriptor);
-			KvState kvState = (KvState) state;
-			Object result = null;
-			for (N source: sources) {
-				kvState.setCurrentNamespace(source);
-				Object sourceValue = state.get();
-				if (result == null) {
-					result = state.get();
-				} else if (sourceValue != null) {
-					result = reduceFn.reduce(result, sourceValue);
-				}
-				state.clear();
-			}
-			kvState.setCurrentNamespace(target);
-			if (result != null) {
-				state.add(result);
-			}
-		} else if (stateDescriptor instanceof ListStateDescriptor) {
-			ListState<Object> state = (ListState) getPartitionedState(target, namespaceSerializer, stateDescriptor);
-			KvState kvState = (KvState) state;
-			List<Object> result = new ArrayList<>();
-			for (N source: sources) {
-				kvState.setCurrentNamespace(source);
-				Iterable<Object> sourceValue = state.get();
-				if (sourceValue != null) {
-					for (Object o : sourceValue) {
-						result.add(o);
-					}
-				}
-				state.clear();
-			}
-			kvState.setCurrentNamespace(target);
-			for (Object o : result) {
-				state.add(o);
-			}
-		} else {
-			throw new RuntimeException("Cannot merge states for " + stateDescriptor);
-		}
-	}
+	<N, S extends MergingState<?, ?>> void mergePartitionedStates(
+			N target,
+			Collection<N> sources,
+			TypeSerializer<N> namespaceSerializer,
+			StateDescriptor<S, ?> stateDescriptor) throws Exception;
 
 	/**
-	 * Snapshots the keyed state by writing it to streams that are provided by a
-	 * {@link CheckpointStreamFactory}.
-	 *
-	 * @param checkpointId The ID of the checkpoint.
-	 * @param timestamp The timestamp of the checkpoint.
-	 * @param streamFactory The factory that we can use for writing our state to streams.
-	 *
-	 * @return A future that will yield a {@link KeyGroupsStateHandle} with the index and
-	 * written key group state stream.
+	 * Closes the backend and releases all resources.
 	 */
-	public abstract RunnableFuture<KeyGroupsStateHandle> snapshot(
-			long checkpointId,
-			long timestamp,
-			CheckpointStreamFactory streamFactory) throws Exception;
-
-
-	public KeyGroupRange getKeyGroupRange() {
-		return keyGroupRange;
-	}
+	void dispose();
 }
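
(For orientation: a minimal, hypothetical usage sketch of the slimmed-down interface, not part of this commit. The helper class, the dummy Integer namespace, and the "sum" descriptor are illustrative assumptions; in the runtime the namespace is normally supplied by the operator, e.g. a window namespace.)

import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.runtime.state.KeyedStateBackend;

public class KeyedStateBackendUsageSketch {

	// Selects the current key, then reads and updates a named partitioned state scoped to that key.
	public static void addToRunningSum(KeyedStateBackend<String> backend, String key, long value) throws Exception {
		backend.setCurrentKey(key);
		ValueState<Long> sum = backend.getPartitionedState(
				1, // illustrative dummy namespace
				IntSerializer.INSTANCE,
				new ValueStateDescriptor<>("sum", Long.class, 0L));
		sum.update(sum.value() + value);
	}
}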

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java
new file mode 100644
index 0000000..4e980b7
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateBackend.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import java.io.Closeable;
+
+/**
+ * Interface that combines both the user-facing {@link OperatorStateStore} interface and the system interface
+ * {@link SnapshotProvider}.
+ *
+ */
+public interface OperatorStateBackend extends OperatorStateStore, SnapshotProvider<OperatorStateHandle>, Closeable {
+
+	/**
+	 * Disposes the backend and releases all resources.
+	 */
+	void dispose();
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
new file mode 100644
index 0000000..3e2d713
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Map;
+
+/**
+ * State handle for partitionable operator state. Besides being a {@link StreamStateHandle}, this also provides a
+ * map that contains the offsets to the partitions of named states in the stream.
+ */
+public class OperatorStateHandle implements StreamStateHandle {
+
+	private static final long serialVersionUID = 35876522969227335L;
+
+	/** unique state name -> offsets for available partitions in the handle stream */
+	private final Map<String, long[]> stateNameToPartitionOffsets;
+	private final StreamStateHandle delegateStateHandle;
+
+	public OperatorStateHandle(
+			StreamStateHandle delegateStateHandle,
+			Map<String, long[]> stateNameToPartitionOffsets) {
+
+		this.delegateStateHandle = Preconditions.checkNotNull(delegateStateHandle);
+		this.stateNameToPartitionOffsets = Preconditions.checkNotNull(stateNameToPartitionOffsets);
+	}
+
+	public Map<String, long[]> getStateNameToPartitionOffsets() {
+		return stateNameToPartitionOffsets;
+	}
+
+	@Override
+	public void discardState() throws Exception {
+		delegateStateHandle.discardState();
+	}
+
+	@Override
+	public long getStateSize() throws IOException {
+		return delegateStateHandle.getStateSize();
+	}
+
+	@Override
+	public FSDataInputStream openInputStream() throws IOException {
+		return delegateStateHandle.openInputStream();
+	}
+
+	public StreamStateHandle getDelegateStateHandle() {
+		return delegateStateHandle;
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o) {
+			return true;
+		}
+
+		if (!(o instanceof OperatorStateHandle)) {
+			return false;
+		}
+
+		OperatorStateHandle that = (OperatorStateHandle) o;
+
+		if(stateNameToPartitionOffsets.size() != that.stateNameToPartitionOffsets.size()) {
+			return false;
+		}
+
+		for (Map.Entry<String, long[]> entry : stateNameToPartitionOffsets.entrySet()) {
+			if (!Arrays.equals(entry.getValue(), that.stateNameToPartitionOffsets.get(entry.getKey()))) {
+				return false;
+			}
+		}
+
+		return delegateStateHandle.equals(that.delegateStateHandle);
+	}
+
+	@Override
+	public int hashCode() {
+		int result = delegateStateHandle.hashCode();
+		for (Map.Entry<String, long[]> entry : stateNameToPartitionOffsets.entrySet()) {
+
+			int entryHash = entry.getKey().hashCode();
+			if (entry.getValue() != null) {
+				entryHash += Arrays.hashCode(entry.getValue());
+			}
+			result = 31 * result + entryHash;
+		}
+		return result;
+	}
+}
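
(A hypothetical restore-side sketch, not part of this commit: the offset map tells a reader where each partition of a named state starts within the delegate stream. The state name "kafka-offsets" and the reader logic are illustrative only.)

import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.runtime.state.OperatorStateHandle;

import java.io.IOException;

public class OperatorStateHandleReadSketch {

	public static void readPartitions(OperatorStateHandle handle) throws IOException {
		long[] offsets = handle.getStateNameToPartitionOffsets().get("kafka-offsets");
		if (offsets == null) {
			return; // this handle contains no partitions of that state
		}
		try (FSDataInputStream in = handle.openInputStream()) {
			for (long offset : offsets) {
				in.seek(offset);
				// deserialize one partition of the state, starting at this offset
			}
		}
	}
}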

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java
new file mode 100644
index 0000000..6914a7c
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateStore.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+
+import java.util.Set;
+
+/**
+ * Interface for a backend that manages partitionable operator state.
+ */
+public interface OperatorStateStore {
+
+	/**
+	 * Creates (or restores) the partitionable state in this backend. Each state is registered under a unique name.
+	 * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
+	 *
+	 * @param stateDescriptor The descriptor for this state, providing a name and serializer
+	 * @param <S> The generic type of the state
+	 * @return A list for all state partitions.
+	 * @throws Exception
+	 */
+	<S> ListState<S> getPartitionableState(ListStateDescriptor<S> stateDescriptor) throws Exception;
+
+	/**
+	 * Returns a set with the names of all currently registered states.
+	 * @return set of names for all registered states.
+	 */
+	Set<String> getRegisteredStateNames();
+}
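
(A hypothetical user-facing sketch, not part of this commit: a function obtains a named, partitionable list state -- freshly created or restored -- and appends to it. The state name and the use of LongSerializer are illustrative assumptions.)

import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.runtime.state.OperatorStateStore;

public class OperatorStateStoreUsageSketch {

	public static void rememberOffset(OperatorStateStore store, long offset) throws Exception {
		// Registered under a unique name; after a restore the same call returns the restored partitions.
		ListState<Long> offsets = store.getPartitionableState(
				new ListStateDescriptor<>("kafka-offsets", LongSerializer.INSTANCE));
		offsets.add(offset);
	}
}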

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java
new file mode 100644
index 0000000..065f9c2
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionableCheckpointStateOutputStream.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+public class PartitionableCheckpointStateOutputStream extends FSDataOutputStream {
+
+	private final Map<String, long[]> stateNameToPartitionOffsets;
+	private final CheckpointStreamFactory.CheckpointStateOutputStream delegate;
+
+	public PartitionableCheckpointStateOutputStream(CheckpointStreamFactory.CheckpointStateOutputStream delegate) {
+		this.delegate = Preconditions.checkNotNull(delegate);
+		this.stateNameToPartitionOffsets = new HashMap<>();
+	}
+
+	@Override
+	public long getPos() throws IOException {
+		return delegate.getPos();
+	}
+
+	@Override
+	public void flush() throws IOException {
+		delegate.flush();
+	}
+
+	@Override
+	public void sync() throws IOException {
+		delegate.sync();
+	}
+
+	@Override
+	public void write(int b) throws IOException {
+		delegate.write(b);
+	}
+
+	@Override
+	public void write(byte[] b) throws IOException {
+		delegate.write(b);
+	}
+
+	@Override
+	public void write(byte[] b, int off, int len) throws IOException {
+		delegate.write(b, off, len);
+	}
+
+	@Override
+	public void close() throws IOException {
+		delegate.close();
+	}
+
+	public OperatorStateHandle closeAndGetHandle() throws IOException {
+		StreamStateHandle streamStateHandle = delegate.closeAndGetHandle();
+		return new OperatorStateHandle(streamStateHandle, stateNameToPartitionOffsets);
+	}
+
+	public void startNewPartition(String stateName) throws IOException {
+		long[] offs = stateNameToPartitionOffsets.get(stateName);
+		if (offs == null) {
+			offs = new long[1];
+		} else {
+			//TODO maybe we can use some primitive array list here instead of an array to avoid resize on each call.
+			offs = Arrays.copyOf(offs, offs.length + 1);
+		}
+
+		offs[offs.length - 1] = getPos();
+		stateNameToPartitionOffsets.put(stateName, offs);
+	}
+
+	public static PartitionableCheckpointStateOutputStream wrap(
+			CheckpointStreamFactory.CheckpointStateOutputStream stream) {
+		return new PartitionableCheckpointStateOutputStream(stream);
+	}
+}
\ No newline at end of file
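
(A hypothetical snapshot-side sketch, not part of this commit: each call to startNewPartition(name) records the current stream position as the start offset of the next partition of that state, and closeAndGetHandle() returns an OperatorStateHandle carrying both the written bytes and the offset map. The state name and the long values written are illustrative.)

import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.PartitionableCheckpointStateOutputStream;

import java.io.IOException;

public class PartitionableWriteSketch {

	public static OperatorStateHandle writeTwoPartitions(
			CheckpointStreamFactory.CheckpointStateOutputStream rawStream) throws IOException {

		PartitionableCheckpointStateOutputStream out =
				PartitionableCheckpointStateOutputStream.wrap(rawStream);
		DataOutputViewStreamWrapper view = new DataOutputViewStreamWrapper(out);

		out.startNewPartition("kafka-offsets"); // records offset of the first partition
		view.writeLong(42L);

		out.startNewPartition("kafka-offsets"); // records offset of the second partition
		view.writeLong(43L);

		return out.closeAndGetHandle();
	}
}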

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java
index 9ecc4c9..9934382 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RetrievableStreamStateHandle.java
@@ -76,6 +76,6 @@ public class RetrievableStreamStateHandle<T extends Serializable> implements
 
 	@Override
 	public void close() throws IOException {
-		wrappedStreamStateHandle.close();
+//		wrappedStreamStateHandle.close();
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java
new file mode 100644
index 0000000..c47fedd
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SnapshotProvider.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import java.util.concurrent.RunnableFuture;
+
+/**
+ * Interface for operations that can perform snapshots of their state.
+ *
+ * @param <S> Generic type of the state object that is created as handle to snapshots.
+ */
+public interface SnapshotProvider<S extends StateObject> {
+
+	/**
+	 * Operation that writes a snapshot into a stream that is provided by the given {@link CheckpointStreamFactory} and
+	 * returns a @{@link RunnableFuture} that gives a state handle to the snapshot. It is up to the implementation if
+	 * the operation is performed synchronous or asynchronous. In the later case, the returned Runnable must be executed
+	 * first before obtaining the handle.
+	 *
+	 * @param checkpointId  The ID of the checkpoint.
+	 * @param timestamp     The timestamp of the checkpoint.
+	 * @param streamFactory The factory that we can use for writing our state to streams.
+	 * @return A runnable future that will yield a {@link StateObject}.
+	 */
+	RunnableFuture<S> snapshot(
+			long checkpointId,
+			long timestamp,
+			CheckpointStreamFactory streamFactory) throws Exception;
+}
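
(A hypothetical caller-side sketch, not part of this commit: the returned RunnableFuture may already be complete (synchronous snapshot) or may still have to be run (asynchronous snapshot) before the handle can be obtained. The helper method and the choice of OperatorStateHandle as the generic type are illustrative.)

import java.util.concurrent.RunnableFuture;

import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.SnapshotProvider;

public class SnapshotCallerSketch {

	public static OperatorStateHandle materialize(
			SnapshotProvider<OperatorStateHandle> provider,
			long checkpointId,
			long timestamp,
			CheckpointStreamFactory streamFactory) throws Exception {

		RunnableFuture<OperatorStateHandle> future =
				provider.snapshot(checkpointId, timestamp, streamFactory);

		if (!future.isDone()) {
			future.run(); // execute the asynchronous part, if any
		}
		return future.get();
	}
}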

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
index 4c65318..a502b9d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
@@ -28,13 +28,9 @@ import java.io.IOException;
  * <ul>
  *     <li><b>Discard State</b>: The {@link #discardState()} method defines how state is permanently
  *         disposed/deleted. After that method call, state may not be recoverable any more.</li>
- 
- *     <li><b>Close the current state access</b>: The {@link #close()} method defines how to
- *         stop the current access or recovery to the state. Called for example when an operation is
- *         canceled during recovery.</li>
  * </ul>
  */
-public interface StateObject extends java.io.Closeable, java.io.Serializable {
+public interface StateObject extends java.io.Serializable {
 
 	/**
 	 * Discards the state referred to by this handle, to free up resources in

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
index aa28404..a4799bf 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtil.java
@@ -18,8 +18,6 @@
 
 package org.apache.flink.runtime.state;
 
-import java.io.IOException;
-
 /**
  * Helpers for {@link StateObject} related code.
  */
@@ -63,39 +61,4 @@ public class StateUtil {
 			}
 		}
 	}
-
-	/**
-	 * Iterates through the passed state handles and calls discardState() on each handle that is not null. All
-	 * occurring exceptions are suppressed and collected until the iteration is over and emitted as a single exception.
-	 *
-	 * @param handlesToDiscard State handles to discard. Passed iterable is allowed to deliver null values.
-	 * @throws IOException exception that is a collection of all suppressed exceptions that were caught during iteration
-	 */
-	public static void bestEffortCloseAllStateObjects(
-			Iterable<? extends StateObject> handlesToDiscard) throws IOException {
-
-		if (handlesToDiscard != null) {
-
-			IOException suppressedExceptions = null;
-
-			for (StateObject state : handlesToDiscard) {
-
-				if (state != null) {
-					try {
-						state.close();
-					} catch (Exception ex) {
-						//best effort to still cleanup other states and deliver exceptions in the end
-						if (suppressedExceptions == null) {
-							suppressedExceptions = new IOException(ex);
-						}
-						suppressedExceptions.addSuppressed(ex);
-					}
-				}
-			}
-
-			if (suppressedExceptions != null) {
-				throw suppressedExceptions;
-			}
-		}
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStateHandle.java
index f361263..29e905c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FileStateHandle.java
@@ -21,7 +21,6 @@ package org.apache.flink.runtime.state.filesystem;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.runtime.state.AbstractCloseableHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
 import java.io.IOException;
@@ -34,7 +33,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  * {@link StreamStateHandle} for state that was written to a file stream. The written data is
 * identified by the file path. The state can be read again by calling {@link #openInputStream()}.
  */
-public class FileStateHandle extends AbstractCloseableHandle implements StreamStateHandle {
+public class FileStateHandle implements StreamStateHandle {
 
 	private static final long serialVersionUID = 350284443258002355L;
 
@@ -69,10 +68,7 @@ public class FileStateHandle extends AbstractCloseableHandle implements StreamSt
 
 	@Override
 	public FSDataInputStream openInputStream() throws IOException {
-		ensureNotClosed();
-		FSDataInputStream inputStream = getFileSystem().open(filePath);
-		registerCloseable(inputStream);
-		return inputStream;
+		return getFileSystem().open(filePath);
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
index 99e3684..e027632 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
@@ -24,11 +24,11 @@ import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -175,7 +175,7 @@ public class FsStateBackend extends AbstractStateBackend {
 	}
 
 	@Override
-	public <K> KeyedStateBackend<K> createKeyedStateBackend(
+	public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
 			Environment env,
 			JobID jobID,
 			String operatorIdentifier,
@@ -192,7 +192,7 @@ public class FsStateBackend extends AbstractStateBackend {
 	}
 
 	@Override
-	public <K> KeyedStateBackend<K> restoreKeyedStateBackend(
+	public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
 			Environment env,
 			JobID jobID,
 			String operatorIdentifier,

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index c13be70..a766373 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state.heap;
 
+import org.apache.commons.io.IOUtils;
 import org.apache.flink.api.common.state.FoldingState;
 import org.apache.flink.api.common.state.FoldingStateDescriptor;
 import org.apache.flink.api.common.state.ListState;
@@ -27,17 +28,18 @@ import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.ArrayListSerializer;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
@@ -51,13 +53,13 @@ import java.util.Map;
 import java.util.concurrent.RunnableFuture;
 
 /**
- * A {@link KeyedStateBackend} that keeps state on the Java Heap and will serialize state to
+ * A {@link AbstractKeyedStateBackend} that keeps state on the Java Heap and will serialize state to
  * streams provided by a {@link org.apache.flink.runtime.state.CheckpointStreamFactory} upon
  * checkpointing.
  *
  * @param <K> The key by which state is keyed.
  */
-public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> {
+public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 	private static final Logger LOG = LoggerFactory.getLogger(HeapKeyedStateBackend.class);
 
@@ -165,85 +167,83 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> {
 			long timestamp,
 			CheckpointStreamFactory streamFactory) throws Exception {
 
-		CheckpointStreamFactory.CheckpointStateOutputStream stream =
-				streamFactory.createCheckpointStateOutputStream(
-						checkpointId,
-						timestamp);
-
 		if (stateTables.isEmpty()) {
 			return new DoneFuture<>(null);
 		}
 
-		DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(stream);
+		try (CheckpointStreamFactory.CheckpointStateOutputStream stream = streamFactory.
+				createCheckpointStateOutputStream(checkpointId, timestamp)) {
 
-		Preconditions.checkState(stateTables.size() <= Short.MAX_VALUE,
-				"Too many KV-States: " + stateTables.size() +
-						". Currently at most " + Short.MAX_VALUE + " states are supported");
+			DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(stream);
 
-		outView.writeShort(stateTables.size());
+			Preconditions.checkState(stateTables.size() <= Short.MAX_VALUE,
+					"Too many KV-States: " + stateTables.size() +
+							". Currently at most " + Short.MAX_VALUE + " states are supported");
 
-		Map<String, Integer> kVStateToId = new HashMap<>(stateTables.size());
+			outView.writeShort(stateTables.size());
 
-		for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
+			Map<String, Integer> kVStateToId = new HashMap<>(stateTables.size());
 
-			outView.writeUTF(kvState.getKey());
+			for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
 
-			TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
-			TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
+				outView.writeUTF(kvState.getKey());
 
-			InstantiationUtil.serializeObject(stream, namespaceSerializer);
-			InstantiationUtil.serializeObject(stream, stateSerializer);
+				TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
+				TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
 
-			kVStateToId.put(kvState.getKey(), kVStateToId.size());
-		}
+				InstantiationUtil.serializeObject(stream, namespaceSerializer);
+				InstantiationUtil.serializeObject(stream, stateSerializer);
 
-		int offsetCounter = 0;
-		long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()];
+				kVStateToId.put(kvState.getKey(), kVStateToId.size());
+			}
 
-		for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); keyGroupIndex++) {
-			keyGroupRangeOffsets[offsetCounter++] = stream.getPos();
-			outView.writeInt(keyGroupIndex);
+			int offsetCounter = 0;
+			long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()];
 
-			for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
+			for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); keyGroupIndex++) {
+				keyGroupRangeOffsets[offsetCounter++] = stream.getPos();
+				outView.writeInt(keyGroupIndex);
 
-				outView.writeShort(kVStateToId.get(kvState.getKey()));
+				for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
 
-				TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
-				TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
+					outView.writeShort(kVStateToId.get(kvState.getKey()));
 
-				// Map<NamespaceT, Map<KeyT, StateT>>
-				Map<?, ? extends Map<K, ?>> namespaceMap = kvState.getValue().get(keyGroupIndex);
-				if (namespaceMap == null) {
-					outView.writeByte(0);
-					continue;
-				}
+					TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
+					TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
+
+					// Map<NamespaceT, Map<KeyT, StateT>>
+					Map<?, ? extends Map<K, ?>> namespaceMap = kvState.getValue().get(keyGroupIndex);
+					if (namespaceMap == null) {
+						outView.writeByte(0);
+						continue;
+					}
 
-				outView.writeByte(1);
+					outView.writeByte(1);
 
-				// number of namespaces
-				outView.writeInt(namespaceMap.size());
-				for (Map.Entry<?, ? extends Map<K, ?>> namespace : namespaceMap.entrySet()) {
-					namespaceSerializer.serialize(namespace.getKey(), outView);
+					// number of namespaces
+					outView.writeInt(namespaceMap.size());
+					for (Map.Entry<?, ? extends Map<K, ?>> namespace : namespaceMap.entrySet()) {
+						namespaceSerializer.serialize(namespace.getKey(), outView);
 
-					Map<K, ?> entryMap = namespace.getValue();
+						Map<K, ?> entryMap = namespace.getValue();
 
-					// number of entries
-					outView.writeInt(entryMap.size());
-					for (Map.Entry<K, ?> entry : entryMap.entrySet()) {
-						keySerializer.serialize(entry.getKey(), outView);
-						stateSerializer.serialize(entry.getValue(), outView);
+						// number of entries
+						outView.writeInt(entryMap.size());
+						for (Map.Entry<K, ?> entry : entryMap.entrySet()) {
+							keySerializer.serialize(entry.getKey(), outView);
+							stateSerializer.serialize(entry.getValue(), outView);
+						}
 					}
 				}
+				outView.flush();
 			}
-			outView.flush();
-		}
-
-		StreamStateHandle streamStateHandle = stream.closeAndGetHandle();
 
-		KeyGroupRangeOffsets offsets = new KeyGroupRangeOffsets(keyGroupRange, keyGroupRangeOffsets);
-		final KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(offsets, streamStateHandle);
+			StreamStateHandle streamStateHandle = stream.closeAndGetHandle();
 
-		return new DoneFuture(keyGroupsStateHandle);
+			KeyGroupRangeOffsets offsets = new KeyGroupRangeOffsets(keyGroupRange, keyGroupRangeOffsets);
+			final KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(offsets, streamStateHandle);
+			return new DoneFuture<>(keyGroupsStateHandle);
+		}
 	}
 
 	@SuppressWarnings({"unchecked", "rawtypes"})
@@ -251,71 +251,81 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> {
 
 		for (KeyGroupsStateHandle keyGroupsHandle : state) {
 
-			if(keyGroupsHandle == null) {
+			if (keyGroupsHandle == null) {
 				continue;
 			}
 
-			FSDataInputStream fsDataInputStream = keyGroupsHandle.getStateHandle().openInputStream();
-			DataInputViewStreamWrapper inView = new DataInputViewStreamWrapper(fsDataInputStream);
+			FSDataInputStream fsDataInputStream = null;
 
-			int numKvStates = inView.readShort();
+			try {
 
-			Map<Integer, String> kvStatesById = new HashMap<>(numKvStates);
+				fsDataInputStream = keyGroupsHandle.getStateHandle().openInputStream();
+				cancelStreamRegistry.registerClosable(fsDataInputStream);
 
-			for (int i = 0; i < numKvStates; ++i) {
-				String stateName = inView.readUTF();
+				DataInputViewStreamWrapper inView = new DataInputViewStreamWrapper(fsDataInputStream);
 
-				TypeSerializer namespaceSerializer =
-						InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
-				TypeSerializer stateSerializer =
-						InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
+				int numKvStates = inView.readShort();
 
-				StateTable<K, ?, ?> stateTable = new StateTable(
-						stateSerializer,
-						namespaceSerializer,
-						keyGroupRange);
-				stateTables.put(stateName, stateTable);
-				kvStatesById.put(i, stateName);
-			}
+				Map<Integer, String> kvStatesById = new HashMap<>(numKvStates);
+
+				for (int i = 0; i < numKvStates; ++i) {
+					String stateName = inView.readUTF();
 
-			for (int keyGroupIndex = keyGroupRange.getStartKeyGroup();  keyGroupIndex <= keyGroupRange.getEndKeyGroup(); ++keyGroupIndex) {
-				long offset = keyGroupsHandle.getOffsetForKeyGroup(keyGroupIndex);
-				fsDataInputStream.seek(offset);
+					TypeSerializer namespaceSerializer =
+							InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
+					TypeSerializer stateSerializer =
+							InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
 
-				int writtenKeyGroupIndex = inView.readInt();
-				assert writtenKeyGroupIndex == keyGroupIndex;
+					StateTable<K, ?, ?> stateTable = new StateTable(stateSerializer,
+							namespaceSerializer,
+							keyGroupRange);
+					stateTables.put(stateName, stateTable);
+					kvStatesById.put(i, stateName);
+				}
 
-				for (int i = 0; i < numKvStates; i++) {
-					int kvStateId = inView.readShort();
+				for (Tuple2<Integer, Long> groupOffset : keyGroupsHandle.getGroupRangeOffsets()) {
+					int keyGroupIndex = groupOffset.f0;
+					long offset = groupOffset.f1;
+					fsDataInputStream.seek(offset);
 
-					byte isPresent = inView.readByte();
-					if (isPresent == 0) {
-						continue;
-					}
+					int writtenKeyGroupIndex = inView.readInt();
+					assert writtenKeyGroupIndex == keyGroupIndex;
+
+					for (int i = 0; i < numKvStates; i++) {
+						int kvStateId = inView.readShort();
+
+						byte isPresent = inView.readByte();
+						if (isPresent == 0) {
+							continue;
+						}
 
-					StateTable<K, ?, ?> stateTable = stateTables.get(kvStatesById.get(kvStateId));
-					Preconditions.checkNotNull(stateTable);
+						StateTable<K, ?, ?> stateTable = stateTables.get(kvStatesById.get(kvStateId));
+						Preconditions.checkNotNull(stateTable);
 
-					TypeSerializer namespaceSerializer = stateTable.getNamespaceSerializer();
-					TypeSerializer stateSerializer = stateTable.getStateSerializer();
+						TypeSerializer namespaceSerializer = stateTable.getNamespaceSerializer();
+						TypeSerializer stateSerializer = stateTable.getStateSerializer();
 
-					Map namespaceMap = new HashMap<>();
-					stateTable.set(keyGroupIndex, namespaceMap);
+						Map namespaceMap = new HashMap<>();
+						stateTable.set(keyGroupIndex, namespaceMap);
 
-					int numNamespaces = inView.readInt();
-					for (int k = 0; k < numNamespaces; k++) {
-						Object namespace = namespaceSerializer.deserialize(inView);
-						Map entryMap = new HashMap<>();
-						namespaceMap.put(namespace, entryMap);
+						int numNamespaces = inView.readInt();
+						for (int k = 0; k < numNamespaces; k++) {
+							Object namespace = namespaceSerializer.deserialize(inView);
+							Map entryMap = new HashMap<>();
+							namespaceMap.put(namespace, entryMap);
 
-						int numEntries = inView.readInt();
-						for (int l = 0; l < numEntries; l++) {
-							Object key = keySerializer.deserialize(inView);
-							Object value = stateSerializer.deserialize(inView);
-							entryMap.put(key, value);
+							int numEntries = inView.readInt();
+							for (int l = 0; l < numEntries; l++) {
+								Object key = keySerializer.deserialize(inView);
+								Object value = stateSerializer.deserialize(inView);
+								entryMap.put(key, value);
+							}
 						}
 					}
 				}
+			} finally {
+				cancelStreamRegistry.unregisterClosable(fsDataInputStream);
+				IOUtils.closeQuietly(fsDataInputStream);
 			}
 		}
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
index b9ff255..7d8b6ce 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
@@ -19,10 +19,8 @@
 package org.apache.flink.runtime.state.memory;
 
 import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.runtime.state.AbstractCloseableHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.InstantiationUtil;
-
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
@@ -32,7 +30,7 @@ import java.util.Arrays;
 /**
  * A state handle that contains stream state in a byte array.
  */
-public class ByteStreamStateHandle extends AbstractCloseableHandle implements StreamStateHandle {
+public class ByteStreamStateHandle implements StreamStateHandle {
 
 	private static final long serialVersionUID = -5280226231200217594L;
 
@@ -52,9 +50,8 @@ public class ByteStreamStateHandle extends AbstractCloseableHandle implements St
 
 	@Override
 	public FSDataInputStream openInputStream() throws IOException {
-		ensureNotClosed();
 
-		FSDataInputStream inputStream = new FSDataInputStream() {
+		return new FSDataInputStream() {
 			int index = 0;
 
 			@Override
@@ -73,8 +70,6 @@ public class ByteStreamStateHandle extends AbstractCloseableHandle implements St
 				return index < data.length ? data[index++] & 0xFF : -1;
 			}
 		};
-		registerCloseable(inputStream);
-		return inputStream;
 	}
 
 	public byte[] getData() {
@@ -106,9 +101,7 @@ public class ByteStreamStateHandle extends AbstractCloseableHandle implements St
 
 	@Override
 	public int hashCode() {
-		int result = super.hashCode();
-		result = 31 * result + Arrays.hashCode(data);
-		return result;
+		return Arrays.hashCode(data);
 	}
 
 	public static StreamStateHandle fromSerializable(Serializable value) throws IOException {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
index cc145ff..1772dbe 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
@@ -22,11 +22,11 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 
 import java.io.IOException;
@@ -71,12 +71,13 @@ public class MemoryStateBackend extends AbstractStateBackend {
 	}
 
 	@Override
-	public CheckpointStreamFactory createStreamFactory(JobID jobId, String operatorIdentifier) throws IOException {
+	public CheckpointStreamFactory createStreamFactory(
+			JobID jobId, String operatorIdentifier) throws IOException {
 		return new MemCheckpointStreamFactory(maxStateSize);
 	}
 
 	@Override
-	public <K> KeyedStateBackend<K> createKeyedStateBackend(
+	public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
 			Environment env, JobID jobID,
 			String operatorIdentifier,
 			TypeSerializer<K> keySerializer,
@@ -93,7 +94,7 @@ public class MemoryStateBackend extends AbstractStateBackend {
 	}
 
 	@Override
-	public <K> KeyedStateBackend<K> restoreKeyedStateBackend(
+	public <K> AbstractKeyedStateBackend<K> restoreKeyedStateBackend(
 			Environment env, JobID jobID,
 			String operatorIdentifier,
 			TypeSerializer<K> keySerializer,

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
index c317bed..8bf1127 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ActorGatewayCheckpointResponder.java
@@ -23,13 +23,9 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.util.Preconditions;
 
-import java.util.List;
-
 /**
  * Implementation using {@link ActorGateway} to forward the messages.
  */
@@ -46,8 +42,7 @@ public class ActorGatewayCheckpointResponder implements CheckpointResponder {
 			JobID jobID,
 			ExecutionAttemptID executionAttemptID,
 			long checkpointID,
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle,
-			List<KeyGroupsStateHandle> keyGroupStateHandles,
+			CheckpointStateHandles checkpointStateHandles,
 			long synchronousDurationMillis,
 			long asynchronousDurationMillis,
 			long bytesBufferedInAlignment,
@@ -55,7 +50,7 @@ public class ActorGatewayCheckpointResponder implements CheckpointResponder {
 
 		AcknowledgeCheckpoint message = new AcknowledgeCheckpoint(
 				jobID, executionAttemptID, checkpointID,
-				chainedStateHandle, keyGroupStateHandles,
+				checkpointStateHandles,
 				synchronousDurationMillis, asynchronousDurationMillis,
 				bytesBufferedInAlignment, alignmentDurationNanos);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
index b3f9827..698a7f4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/CheckpointResponder.java
@@ -20,11 +20,7 @@ package org.apache.flink.runtime.taskmanager;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
-
-import java.util.List;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 
 /**
  * Responder for checkpoint acknowledge and decline messages in the {@link Task}.
@@ -40,10 +36,8 @@ public interface CheckpointResponder {
 	 *             Execution attempt ID of the running task
 	 * @param checkpointID
 	 *             Checkpoint ID of the checkpoint
-	 * @param chainedStateHandle
-	 *             Chained state handle
-	 * @param keyGroupStateHandles
-	 *             State handles for key groups
+	 * @param checkpointStateHandles 
+	 *             State handles for the checkpoint
 	 * @param synchronousDurationMillis
 	 *             The duration (in milliseconds) of the synchronous part of the operator checkpoint
 	 * @param asynchronousDurationMillis
@@ -57,8 +51,7 @@ public interface CheckpointResponder {
 		JobID jobID,
 		ExecutionAttemptID executionAttemptID,
 		long checkpointID,
-		ChainedStateHandle<StreamStateHandle> chainedStateHandle,
-		List<KeyGroupsStateHandle> keyGroupStateHandles,
+		CheckpointStateHandles checkpointStateHandles,
 		long synchronousDurationMillis,
 		long asynchronousDurationMillis,
 		long bytesBufferedInAlignment,

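Not part of the patch, only to make the reshaped signature concrete: all per-checkpoint state now travels through the single CheckpointStateHandles argument instead of a chained stream-state handle plus a list of key-group handles. The small forwarding helper below exists purely to spell out the new parameter list against the interface shown above; the variable names are placeholders.

    // Sketch only: forward an acknowledgement through the consolidated interface.
    static void ackWithBundledState(
            CheckpointResponder responder,
            JobID jobId,
            ExecutionAttemptID attemptId,
            long checkpointId,
            CheckpointStateHandles stateHandles, // replaces the former chained + key-group arguments
            long syncMillis,
            long asyncMillis,
            long bytesBuffered,
            long alignmentNanos) {

        responder.acknowledgeCheckpoint(
                jobId, attemptId, checkpointId,
                stateHandles,
                syncMillis, asyncMillis,
                bytesBuffered, alignmentNanos);
    }
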
http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
index 23b6f82..c2ba7ef 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
@@ -35,11 +35,8 @@ import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 
-import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Future;
 
@@ -246,7 +243,7 @@ public class RuntimeEnvironment implements Environment {
 			long bytesBufferedInAlignment,
 			long alignmentDurationNanos) {
 
-		acknowledgeCheckpoint(checkpointId, null, null,
+		acknowledgeCheckpoint(checkpointId, null,
 				synchronousDurationMillis, asynchronousDurationMillis,
 				bytesBufferedInAlignment, alignmentDurationNanos);
 	}
@@ -254,8 +251,7 @@ public class RuntimeEnvironment implements Environment {
 	@Override
 	public void acknowledgeCheckpoint(
 			long checkpointId,
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle,
-			List<KeyGroupsStateHandle> keyGroupStateHandles,
+			CheckpointStateHandles checkpointStateHandles,
 			long synchronousDurationMillis,
 			long asynchronousDurationMillis,
 			long bytesBufferedInAlignment,
@@ -264,7 +260,7 @@ public class RuntimeEnvironment implements Environment {
 
 		checkpointResponder.acknowledgeCheckpoint(
 				jobId, executionId, checkpointId,
-				chainedStateHandle, keyGroupStateHandles,
+				checkpointStateHandles,
 				synchronousDurationMillis, asynchronousDurationMillis,
 				bytesBufferedInAlignment, alignmentDurationNanos);
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index 62dc8b7..8463fa0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -59,6 +59,7 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
 import org.apache.flink.util.Preconditions;
@@ -68,6 +69,7 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.net.URL;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -241,6 +243,8 @@ public class Task implements Runnable, TaskActions {
 	 */
 	private volatile List<KeyGroupsStateHandle> keyGroupStates;
 
+	private volatile List<Collection<OperatorStateHandle>> partitionableOperatorState;
+
 	/** Initialized from the Flink configuration. May also be set at the ExecutionConfig */
 	private long taskCancellationInterval;
 
@@ -278,6 +282,7 @@ public class Task implements Runnable, TaskActions {
 		this.chainedOperatorState = tdd.getOperatorState();
 		this.serializedExecutionConfig = checkNotNull(tdd.getSerializedExecutionConfig());
 		this.keyGroupStates = tdd.getKeyGroupState();
+		this.partitionableOperatorState = tdd.getPartitionableOperatorState();
 
 		this.taskCancellationInterval = jobConfiguration.getLong(
 			ConfigConstants.TASK_CANCELLATION_INTERVAL_MILLIS,
@@ -488,7 +493,7 @@ public class Task implements Runnable, TaskActions {
 		Map<String, Future<Path>> distributedCacheEntries = new HashMap<String, Future<Path>>();
 		AbstractInvokable invokable = null;
 
-		ClassLoader userCodeClassLoader = null;
+		ClassLoader userCodeClassLoader;
 		try {
 			// ----------------------------
 			//  Task Bootstrap - We periodically
@@ -564,10 +569,10 @@ public class Task implements Runnable, TaskActions {
 			// the state into the task. the state is non-empty if this is an execution
 			// of a task that failed but had backuped state from a checkpoint
 
-			if (chainedOperatorState != null || keyGroupStates != null) {
+			if (chainedOperatorState != null || keyGroupStates != null || partitionableOperatorState != null) {
 				if (invokable instanceof StatefulTask) {
 					StatefulTask op = (StatefulTask) invokable;
-					op.setInitialState(chainedOperatorState, keyGroupStates);
+					op.setInitialState(chainedOperatorState, keyGroupStates, partitionableOperatorState);
 				} else {
 					throw new IllegalStateException("Found operator state for a non-stateful task invokable");
 				}
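
Not part of the patch, only to illustrate the third restore argument: Task now also carries the partitionable (re-distributable) operator state from the deployment descriptor and hands it to the invokable together with the chained and key-group state. A rough sketch of the receiving side follows; the class name, field names and the exact StatefulTask declaration (for example checked exceptions) are assumptions, and only the parameter types are read off the Task fields above.

    // Sketch under assumptions: an invokable that stores the three restore collections
    // handed over by Task before invoke() runs. Field names are illustrative.
    class RestorableInvokable /* implements StatefulTask in the real code */ {

        private ChainedStateHandle<StreamStateHandle> restoredChainedState;
        private List<KeyGroupsStateHandle> restoredKeyGroupsState;
        private List<Collection<OperatorStateHandle>> restoredPartitionableState;

        public void setInitialState(
                ChainedStateHandle<StreamStateHandle> chainedState,
                List<KeyGroupsStateHandle> keyGroupsState,
                List<Collection<OperatorStateHandle>> partitionableState) {

            this.restoredChainedState = chainedState;
            this.restoredKeyGroupsState = keyGroupsState;
            this.restoredPartitionableState = partitionableState;
        }
    }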