You are viewing a plain-text version of this content; the link to its canonical version was lost in the plain-text conversion.
Posted to commits@flink.apache.org by se...@apache.org on 2016/09/30 12:47:53 UTC

[03/10] flink git commit: [FLINK-4379] [checkpoints] Introduce rescalable operator state

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index a73f3b2..0ca89ef 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -18,29 +18,35 @@
 
 package org.apache.flink.streaming.api.operators;
 
+import org.apache.commons.io.IOUtils;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.metrics.Counter;
 import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.watermark.Watermark;
-import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.TimeServiceProvider;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.Collection;
+import java.util.concurrent.RunnableFuture;
+
 /**
  * Base class for all stream operators. Operators that contain a user function should extend the class 
  * {@link AbstractUdfStreamOperator} instead (which is a specialized subclass of this class). 
@@ -90,7 +96,12 @@ public abstract class AbstractStreamOperator<OUT>
 	private transient KeySelector<?, ?> stateKeySelector2;
 
 	/** Backend for keyed state. This might be empty if we're not on a keyed stream. */
-	private transient KeyedStateBackend<?> keyedStateBackend;
+	private transient AbstractKeyedStateBackend<?> keyedStateBackend;
+
+	/** Operator state backend */
+	private transient OperatorStateBackend operatorStateBackend;
+
+	private transient Collection<OperatorStateHandle> lazyRestoreStateHandles;
 
 	protected transient MetricGroup metrics;
 
@@ -116,9 +127,14 @@ public abstract class AbstractStreamOperator<OUT>
 		return metrics;
 	}
 
+	@Override
+	public void restoreState(Collection<OperatorStateHandle> stateHandles) {
+		this.lazyRestoreStateHandles = stateHandles;
+	}
+
 	/**
 	 * This method is called immediately before any elements are processed, it should contain the
-	 * operator's initialization logic.
+	 * operator's initialization logic, e.g. state initialization.
 	 *
 	 * <p>The default implementation does nothing.
 	 * 
@@ -126,24 +142,39 @@ public abstract class AbstractStreamOperator<OUT>
 	 */
 	@Override
 	public void open() throws Exception {
+		initOperatorState();
+		initKeyedState();
+	}
+
+	private void initKeyedState() {
 		try {
 			TypeSerializer<Object> keySerializer = config.getStateKeySerializer(getUserCodeClassloader());
 			// create a keyed state backend if there is keyed state, as indicated by the presence of a key serializer
 			if (null != keySerializer) {
-				ExecutionConfig execConf = container.getEnvironment().getExecutionConfig();;
 
 				KeyGroupRange subTaskKeyGroupRange = KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
 						container.getEnvironment().getTaskInfo().getNumberOfKeyGroups(),
 						container.getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(),
 						container.getIndexInSubtaskGroup());
 
-				keyedStateBackend = container.createKeyedStateBackend(
+				this.keyedStateBackend = container.createKeyedStateBackend(
 						keySerializer,
 						container.getConfiguration().getNumberOfKeyGroups(getUserCodeClassloader()),
 						subTaskKeyGroupRange);
+
 			}
+
+		} catch (Exception e) {
+			throw new IllegalStateException("Could not initialize keyed state backend.", e);
+		}
+	}
+
+	private void initOperatorState() {
+		try {
+			// create an operator state backend
+			this.operatorStateBackend = container.createOperatorStateBackend(this, lazyRestoreStateHandles);
 		} catch (Exception e) {
-			throw new RuntimeException("Could not initialize keyed state backend.", e);
+			throw new IllegalStateException("Could not initialize operator state backend.", e);
 		}
 	}
 
@@ -171,18 +202,25 @@ public abstract class AbstractStreamOperator<OUT>
 	 */
 	@Override
 	public void dispose() throws Exception {
+
+		if (operatorStateBackend != null) {
+			IOUtils.closeQuietly(operatorStateBackend);
+			operatorStateBackend.dispose();
+		}
+
 		if (keyedStateBackend != null) {
-			keyedStateBackend.close();
+			IOUtils.closeQuietly(keyedStateBackend);
+			keyedStateBackend.dispose();
 		}
 	}
 
 	@Override
-	public void snapshotState(FSDataOutputStream out,
-			long checkpointId,
-			long timestamp) throws Exception {}
+	public RunnableFuture<OperatorStateHandle> snapshotState(
+			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
 
-	@Override
-	public void restoreState(FSDataInputStream in) throws Exception {}
+		return operatorStateBackend != null ?
+				operatorStateBackend.snapshot(checkpointId, timestamp, streamFactory) : null;
+	}
 
 	@Override
 	public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {}
@@ -223,10 +261,24 @@ public abstract class AbstractStreamOperator<OUT>
 	}
 
 	@SuppressWarnings("rawtypes, unchecked")
-	public <K> KeyedStateBackend<K> getStateBackend() {
+	public <K> KeyedStateBackend<K> getKeyedStateBackend() {
+
+		if (null == keyedStateBackend) {
+			initKeyedState();
+		}
+
 		return (KeyedStateBackend<K>) keyedStateBackend;
 	}
 
+	public OperatorStateBackend getOperatorStateBackend() {
+
+		if (null == operatorStateBackend) {
+			initOperatorState();
+		}
+
+		return operatorStateBackend;
+	}
+
 	/**
 	 * Returns the {@link TimeServiceProvider} responsible for getting  the current
 	 * processing time and registering timers.
@@ -268,18 +320,18 @@ public abstract class AbstractStreamOperator<OUT>
 	@Override
 	@SuppressWarnings({"unchecked", "rawtypes"})
 	public void setKeyContextElement1(StreamRecord record) throws Exception {
-		if (stateKeySelector1 != null) {
-			Object key = ((KeySelector) stateKeySelector1).getKey(record.getValue());
-			getStateBackend().setCurrentKey(key);
-		}
+		setRawKeyContextElement(record, stateKeySelector1);
 	}
 
 	@Override
 	@SuppressWarnings({"unchecked", "rawtypes"})
 	public void setKeyContextElement2(StreamRecord record) throws Exception {
-		if (stateKeySelector2 != null) {
-			Object key = ((KeySelector) stateKeySelector2).getKey(record.getValue());
+		setRawKeyContextElement(record, stateKeySelector2);
+	}
 
+	private void setRawKeyContextElement(StreamRecord record, KeySelector<?, ?> selector) throws Exception {
+		if (selector != null) {
+			Object key = ((KeySelector) selector).getKey(record.getValue());
 			setKeyContext(key);
 		}
 	}
@@ -290,7 +342,7 @@ public abstract class AbstractStreamOperator<OUT>
 			try {
 				// need to work around type restrictions
 				@SuppressWarnings("unchecked,rawtypes")
-				KeyedStateBackend rawBackend = (KeyedStateBackend) keyedStateBackend;
+				AbstractKeyedStateBackend rawBackend = (AbstractKeyedStateBackend) keyedStateBackend;
 
 				rawBackend.setCurrentKey(key);
 			} catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
index 6ac73e7..f683d9a 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
@@ -18,23 +18,31 @@
 
 package org.apache.flink.streaming.api.operators;
 
-import java.io.Serializable;
-
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.Function;
 import org.apache.flink.api.common.functions.util.FunctionUtils;
+import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.util.InstantiationUtil;
 
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.RunnableFuture;
+
 import static java.util.Objects.requireNonNull;
 
 /**
@@ -50,7 +58,8 @@ import static java.util.Objects.requireNonNull;
 @PublicEvolving
 public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 		extends AbstractStreamOperator<OUT>
-		implements OutputTypeConfigurable<OUT> {
+		implements OutputTypeConfigurable<OUT>,
+		StreamCheckpointedOperator {
 
 	private static final long serialVersionUID = 1L;
 	
@@ -91,6 +100,28 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 		super.open();
 		
 		FunctionUtils.openFunction(userFunction, new Configuration());
+
+		if (userFunction instanceof CheckpointedFunction) {
+			((CheckpointedFunction) userFunction).initializeState(getOperatorStateBackend());
+		} else if (userFunction instanceof ListCheckpointed) {
+			@SuppressWarnings("unchecked")
+			ListCheckpointed<Serializable> listCheckpointedFun = (ListCheckpointed<Serializable>) userFunction;
+
+			ListState<Serializable> listState =
+					getOperatorStateBackend().getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR);
+
+			List<Serializable> list = new ArrayList<>();
+
+			for (Serializable serializable : listState.get()) {
+				list.add(serializable);
+			}
+
+			try {
+				listCheckpointedFun.restoreState(list);
+			} catch (Exception e) {
+				throw new Exception("Failed to restore state to function: " + e.getMessage(), e);
+			}
+		}
 	}
 
 	@Override
@@ -115,7 +146,6 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	
 	@Override
 	public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
-		super.snapshotState(out, checkpointId, timestamp);
 
 		if (userFunction instanceof Checkpointed) {
 			@SuppressWarnings("unchecked")
@@ -138,7 +168,6 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 
 	@Override
 	public void restoreState(FSDataInputStream in) throws Exception {
-		super.restoreState(in);
 
 		if (userFunction instanceof Checkpointed) {
 			@SuppressWarnings("unchecked")
@@ -160,6 +189,32 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function>
 	}
 
 	@Override
+	public RunnableFuture<OperatorStateHandle> snapshotState(
+			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
+
+		if (userFunction instanceof CheckpointedFunction) {
+			((CheckpointedFunction) userFunction).prepareSnapshot(checkpointId, timestamp);
+		}
+
+		if (userFunction instanceof ListCheckpointed) {
+			@SuppressWarnings("unchecked")
+			List<Serializable> partitionableState =
+					((ListCheckpointed<Serializable>) userFunction).snapshotState(checkpointId, timestamp);
+
+			ListState<Serializable> listState =
+					getOperatorStateBackend().getPartitionableState(ListCheckpointed.DEFAULT_LIST_DESCRIPTOR);
+
+			listState.clear();
+
+			for (Serializable statePartition : partitionableState) {
+				listState.add(statePartition);
+			}
+		}
+
+		return super.snapshotState(checkpointId, timestamp, streamFactory);
+	}
+
+	@Override
 	public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {
 		super.notifyOfCompletedCheckpoint(checkpointId);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java
new file mode 100644
index 0000000..50cdc02
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamCheckpointedOperator.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+
+@Deprecated
+public interface StreamCheckpointedOperator {
+
+	/**
+	 * Called to draw a (deprecated, stream-based) state snapshot from the operator. This method
+	 * writes the operator state (if the operator is stateful) to the given output stream.
+	 *
+	 * @param out The stream to which we have to write our state.
+	 * @param checkpointId The ID of the checkpoint.
+	 * @param timestamp The timestamp of the checkpoint.
+	 *
+	 * @throws Exception Forwards exceptions that occur while drawing snapshots from the operator
+	 *                   and the key/value state.
+	 */
+	void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception;
+
+	/**
+	 * Restores the operator state, if this operator's execution is recovering from a checkpoint.
+	 * This method restores the operator state (if the operator is stateful) and the key/value state
+	 * (if it had been used and was initialized when the snapshot occurred).
+	 *
+	 * <p>This method is called after {@link StreamOperator#setup(StreamTask, StreamConfig, Output)}
+	 * and before {@link StreamOperator#open()}.
+	 *
+	 * @param in The stream from which we have to restore our state.
+	 *
+	 * @throws Exception Exceptions during state restore should be forwarded, so that the system can
+	 *                   properly react to failed state restore and fail the execution attempt.
+	 */
+	void restoreState(FSDataInputStream in) throws Exception;
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
index f1e8160..fae5fd0 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamOperator.java
@@ -17,16 +17,18 @@
 
 package org.apache.flink.streaming.api.operators;
 
-import java.io.Serializable;
-
 import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.metrics.MetricGroup;
-import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.concurrent.RunnableFuture;
+
 /**
  * Basic interface for stream operators. Implementers would implement one of
  * {@link org.apache.flink.streaming.api.operators.OneInputStreamOperator} or
@@ -91,32 +93,27 @@ public interface StreamOperator<OUT> extends Serializable {
 	// ------------------------------------------------------------------------
 
 	/**
-	 * Called to draw a state snapshot from the operator. This method snapshots the operator state
-	 * (if the operator is stateful).
-	 *
-	 * @param out The stream to which we have to write our state.
-	 * @param checkpointId The ID of the checkpoint.
-	 * @param timestamp The timestamp of the checkpoint.
+	 * Called to draw a state snapshot from the operator.
 	 *
-	 * @throws Exception Forwards exceptions that occur while drawing snapshots from the operator
-	 *                   and the key/value state.
+	 * @throws Exception Forwards exceptions that occur while preparing for the snapshot
 	 */
-	void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception;
 
 	/**
-	 * Restores the operator state, if this operator's execution is recovering from a checkpoint.
-	 * This method restores the operator state (if the operator is stateful) and the key/value state
-	 * (if it had been used and was initialized when the snapshot occurred).
+	 * Called to draw a state snapshot from the operator.
 	 *
-	 * <p>This method is called after {@link #setup(StreamTask, StreamConfig, Output)}
-	 * and before {@link #open()}.
-	 *
-	 * @param in The stream from which we have to restore our state.
+	 * @return a runnable future to the state handle that points to the snapshotted state. For synchronous implementations,
+	 * the runnable might already be finished.
+	 * @throws Exception exception that happened during snapshotting.
+	 */
+	RunnableFuture<OperatorStateHandle> snapshotState(
+			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception;
+
+	/**
+	 * Provides state handles to restore the operator state.
 	 *
-	 * @throws Exception Exceptions during state restore should be forwarded, so that the system can
-	 *                   properly react to failed state restore and fail the execution attempt.
+	 * @param stateHandles state handles to the operator state.
 	 */
-	void restoreState(FSDataInputStream in) throws Exception;
+	void restoreState(Collection<OperatorStateHandle> stateHandles);
 
 	/**
 	 * Called when the checkpoint with the given ID is completed and acknowledged on the JobManager.

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
index 4f85e3a..cc2e54b 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
@@ -24,13 +24,10 @@ import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
 import org.apache.flink.api.common.functions.util.AbstractRuntimeUDFContext;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.OperatorState;
 import org.apache.flink.api.common.state.ReducingState;
 import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.streaming.api.CheckpointingMode;
@@ -143,35 +140,6 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext {
 		}
 	}
 
-	@Override
-	@Deprecated
-	public <S> OperatorState<S> getKeyValueState(String name, Class<S> stateType, S defaultState) {
-		requireNonNull(stateType, "The state type class must not be null");
-
-		TypeInformation<S> typeInfo;
-		try {
-			typeInfo = TypeExtractor.getForClass(stateType);
-		}
-		catch (Exception e) {
-			throw new RuntimeException("Cannot analyze type '" + stateType.getName() +
-					"' from the class alone, due to generic type parameters. " +
-					"Please specify the TypeInformation directly.", e);
-		}
-
-		return getKeyValueState(name, typeInfo, defaultState);
-	}
-
-	@Override
-	@Deprecated
-	public <S> OperatorState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) {
-		requireNonNull(name, "The name of the state must not be null");
-		requireNonNull(stateType, "The state type information must not be null");
-
-		ValueStateDescriptor<S> stateProps = 
-				new ValueStateDescriptor<>(name, stateType, defaultState);
-		return getState(stateProps);
-	}
-
 	// ------------------ expose (read only) relevant information from the stream config -------- //
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
index 35d1108..b5500b7 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/GenericWriteAheadSink.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.io.disk.InputViewIterator;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.util.ReusingMutableToRegularIteratorWrapper;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -51,7 +52,9 @@ import java.util.UUID;
  *
  * @param <IN> Type of the elements emitted by this sink
  */
-public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<IN> implements OneInputStreamOperator<IN, IN> {
+public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<IN>
+		implements OneInputStreamOperator<IN, IN>, StreamCheckpointedOperator {
+
 	private static final long serialVersionUID = 1L;
 
 	protected static final Logger LOG = LoggerFactory.getLogger(GenericWriteAheadSink.class);
@@ -110,7 +113,6 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 	public void snapshotState(FSDataOutputStream out,
 			long checkpointId,
 			long timestamp) throws Exception {
-		super.snapshotState(out, checkpointId, timestamp);
 
 		saveHandleInState(checkpointId, timestamp);
 
@@ -119,7 +121,6 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 
 	@Override
 	public void restoreState(FSDataInputStream in) throws Exception {
-		super.restoreState(in);
 
 		this.state = InstantiationUtil.deserializeObject(in, getUserCodeClassloader());
 	}
@@ -151,11 +152,19 @@ public abstract class GenericWriteAheadSink<IN> extends AbstractStreamOperator<I
 					try {
 						if (!committer.isCheckpointCommitted(pastCheckpointId)) {
 							Tuple2<Long, StreamStateHandle> handle = state.pendingHandles.get(pastCheckpointId);
-							FSDataInputStream in = handle.f1.openInputStream();
-							boolean success = sendValues(new ReusingMutableToRegularIteratorWrapper<>(new InputViewIterator<>(new DataInputViewStreamWrapper(in), serializer), serializer), handle.f0);
-							if (success) { //if the sending has failed we will retry on the next notify
-								committer.commitCheckpoint(pastCheckpointId);
-								checkpointsToRemove.add(pastCheckpointId);
+							try (FSDataInputStream in = handle.f1.openInputStream()) {
+								boolean success = sendValues(
+										new ReusingMutableToRegularIteratorWrapper<>(
+												new InputViewIterator<>(
+														new DataInputViewStreamWrapper(
+																in),
+														serializer),
+												serializer),
+										handle.f0);
+								if (success) { //if the sending has failed we will retry on the next notify
+									committer.commitCheckpoint(pastCheckpointId);
+									checkpointsToRemove.add(pastCheckpointId);
+								}
 							}
 						} else {
 							checkpointsToRemove.add(pastCheckpointId);

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java
index 4de7729..a838faa 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/EvictingWindowOperator.java
@@ -88,7 +88,7 @@ public class EvictingWindowOperator<K, IN, OUT, W extends Window> extends Window
 				element.getTimestamp(),
 				windowAssignerContext);
 
-		final K key = (K) getStateBackend().getCurrentKey();
+		final K key = (K) getKeyedStateBackend().getCurrentKey();
 
 		if (windowAssigner instanceof MergingWindowAssigner) {
 
@@ -122,7 +122,7 @@ public class EvictingWindowOperator<K, IN, OUT, W extends Window> extends Window
 								}
 
 								// merge the merged state windows into the newly resulting state window
-								getStateBackend().mergePartitionedStates(
+								getKeyedStateBackend().mergePartitionedStates(
 									stateWindowResult,
 									mergedStateWindows,
 									windowSerializer,

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
index e4939db..ffdf334 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/operators/windowing/WindowOperator.java
@@ -298,7 +298,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window>
 		Collection<W> elementWindows = windowAssigner.assignWindows(
 			element.getValue(), element.getTimestamp(), windowAssignerContext);
 
-		final K key = (K) getStateBackend().getCurrentKey();
+		final K key = (K) getKeyedStateBackend().getCurrentKey();
 
 		if (windowAssigner instanceof MergingWindowAssigner) {
 			MergingWindowSet<W> mergingWindows = getMergingWindowSet();
@@ -329,7 +329,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window>
 						}
 
 						// merge the merged state windows into the newly resulting state window
-						getStateBackend().mergePartitionedStates(
+						getKeyedStateBackend().mergePartitionedStates(
 							stateWindowResult,
 							mergedStateWindows,
 							windowSerializer,
@@ -554,18 +554,18 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window>
 	 */
 	@SuppressWarnings("unchecked")
 	protected MergingWindowSet<W> getMergingWindowSet() throws Exception {
-		MergingWindowSet<W> mergingWindows = mergingWindowsByKey.get((K) getStateBackend().getCurrentKey());
+		MergingWindowSet<W> mergingWindows = mergingWindowsByKey.get((K) getKeyedStateBackend().getCurrentKey());
 		if (mergingWindows == null) {
 			// try to retrieve from state
 
 			TupleSerializer<Tuple2<W, W>> tupleSerializer = new TupleSerializer<>((Class) Tuple2.class, new TypeSerializer[] {windowSerializer, windowSerializer} );
 			ListStateDescriptor<Tuple2<W, W>> mergeStateDescriptor = new ListStateDescriptor<>("merging-window-set", tupleSerializer);
-			ListState<Tuple2<W, W>> mergeState = getStateBackend().getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor);
+			ListState<Tuple2<W, W>> mergeState = getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor);
 
 			mergingWindows = new MergingWindowSet<>((MergingWindowAssigner<? super IN, W>) windowAssigner, mergeState);
 			mergeState.clear();
 
-			mergingWindowsByKey.put((K) getStateBackend().getCurrentKey(), mergingWindows);
+			mergingWindowsByKey.put((K) getKeyedStateBackend().getCurrentKey(), mergingWindows);
 		}
 		return mergingWindows;
 	}
@@ -709,7 +709,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window>
 		public <S extends MergingState<?, ?>> void mergePartitionedState(StateDescriptor<S, ?> stateDescriptor) {
 			if (mergedWindows != null && mergedWindows.size() > 0) {
 				try {
-					WindowOperator.this.getStateBackend().mergePartitionedStates(window,
+					WindowOperator.this.getKeyedStateBackend().mergePartitionedStates(window,
 							mergedWindows,
 							windowSerializer,
 							stateDescriptor);
@@ -869,7 +869,7 @@ public class WindowOperator<K, IN, ACC, OUT, W extends Window>
 			ListStateDescriptor<Tuple2<W, W>> mergeStateDescriptor = new ListStateDescriptor<>("merging-window-set", tupleSerializer);
 			for (Map.Entry<K, MergingWindowSet<W>> key: mergingWindowsByKey.entrySet()) {
 				setKeyContext(key.getKey());
-				ListState<Tuple2<W, W>> mergeState = getStateBackend().getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor);
+				ListState<Tuple2<W, W>> mergeState = getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, mergeStateDescriptor);
 				mergeState.clear();
 				key.getValue().persist(mergeState);
 			}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
index 0e24516..9e96f5d 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
@@ -17,12 +17,6 @@
 
 package org.apache.flink.streaming.runtime.tasks;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
@@ -35,20 +29,25 @@ import org.apache.flink.runtime.plugable.SerializationDelegate;
 import org.apache.flink.streaming.api.collector.selector.CopyingDirectedOutput;
 import org.apache.flink.streaming.api.collector.selector.DirectedOutput;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
-import org.apache.flink.streaming.api.watermark.Watermark;
-import org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.runtime.io.StreamRecordWriter;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
 /**
  * The {@code OperatorChain} contains all operators that are executed as one chain within a single
  * {@link StreamTask}.
@@ -57,7 +56,7 @@ import org.slf4j.LoggerFactory;
  *              head operator.
  */
 @Internal
-public class OperatorChain<OUT> {
+public class OperatorChain<OUT, OP extends StreamOperator<OUT>> {
 	
 	private static final Logger LOG = LoggerFactory.getLogger(OperatorChain.class);
 	
@@ -66,16 +65,17 @@ public class OperatorChain<OUT> {
 	private final RecordWriterOutput<?>[] streamOutputs;
 	
 	private final Output<StreamRecord<OUT>> chainEntryPoint;
-	
 
-	public OperatorChain(StreamTask<OUT, ?> containingTask,
-							StreamOperator<OUT> headOperator,
-							AccumulatorRegistry.Reporter reporter) {
+	private final OP headOperator;
+
+	public OperatorChain(StreamTask<OUT, OP> containingTask, AccumulatorRegistry.Reporter reporter) {
 		
 		final ClassLoader userCodeClassloader = containingTask.getUserCodeClassLoader();
 		final StreamConfig configuration = containingTask.getConfiguration();
 		final boolean enableTimestamps = containingTask.isSerializingTimestamps();
 
+		headOperator = configuration.getStreamOperator(userCodeClassloader);
+
 		// we read the chained configs, and the order of record writer registrations by output name
 		Map<Integer, StreamConfig> chainedConfigs = configuration.getTransitiveChainedTaskConfigs(userCodeClassloader);
 		chainedConfigs.put(configuration.getVertexID(), configuration);
@@ -104,11 +104,15 @@ public class OperatorChain<OUT> {
 			List<StreamOperator<?>> allOps = new ArrayList<>(chainedConfigs.size());
 			this.chainEntryPoint = createOutputCollector(containingTask, configuration,
 					chainedConfigs, userCodeClassloader, streamOutputMap, allOps);
+
+			if (headOperator != null) {
+				headOperator.setup(containingTask, configuration, getChainEntryPoint());
+			}
+
+			// add head operator to end of chain
+			allOps.add(headOperator);
 			
-			this.allOperators = allOps.toArray(new StreamOperator<?>[allOps.size() + 1]);
-			
-			// add the head operator to the end of the list
-			this.allOperators[this.allOperators.length - 1] = headOperator;
+			this.allOperators = allOps.toArray(new StreamOperator<?>[allOps.size()]);
 			
 			success = true;
 		}
@@ -181,7 +185,15 @@ public class OperatorChain<OUT> {
 			}
 		}
 	}
-	
+
+	public OP getHeadOperator() {
+		return headOperator;
+	}
+
+	public int getChainLength() {
+		return allOperators == null ? 0 : allOperators.length;
+	}
+
 	// ------------------------------------------------------------------------
 	//  initialization utilities
 	// ------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 7976f01..1725eca 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -26,14 +26,19 @@ import org.apache.flink.configuration.IllegalConfigurationException;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.runtime.execution.CancelTaskException;
+import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.ClosableRegistry;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
-import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.OperatorStateBackend;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackendFactory;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
@@ -41,24 +46,23 @@ import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-
+import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.RunnableFuture;
@@ -70,19 +74,19 @@ import java.util.concurrent.ScheduledThreadPoolExecutor;
  * the Task's operator chain. Operators that are chained together execute synchronously in the
  * same thread and hence on the same stream partition. A common case for these chains
  * are successive map/flatmap/filter tasks.
- * 
- * <p>The task chain contains one "head" operator and multiple chained operators. 
+ *
+ * <p>The task chain contains one "head" operator and multiple chained operators.
  * The StreamTask is specialized for the type of the head operator: one-input and two-input tasks,
  * as well as for sources, iteration heads and iteration tails.
- * 
- * <p>The Task class deals with the setup of the streams read by the head operator, and the streams 
+ *
+ * <p>The Task class deals with the setup of the streams read by the head operator, and the streams
  * produced by the operators at the ends of the operator chain. Note that the chain may fork and
  * thus have multiple ends.
  *
- * The life cycle of the task is set up as follows: 
+ * The life cycle of the task is set up as follows:
  * <pre>{@code
- *  -- restoreState() -> restores state of all operators in the chain
- *  
+ *  -- restoreState() -> restores state of all operators in the chain
+ *
  *  -- invoke()
  *        |
  *        +----> Create basic utils (config, etc) and load the chain of operators
@@ -99,35 +103,35 @@ import java.util.concurrent.ScheduledThreadPoolExecutor;
  * <p> The {@code StreamTask} has a lock object called {@code lock}. All calls to methods on a
  * {@code StreamOperator} must be synchronized on this lock object to ensure that no methods
  * are called concurrently.
- * 
+ *
  * @param <OUT>
- * @param <Operator>
+ * @param <OP> the type of the head operator of this task's operator chain
  */
 @Internal
-public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
+public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		extends AbstractInvokable
 		implements StatefulTask, AsyncExceptionHandler {
 
 	/** The thread group that holds all trigger timer threads */
 	public static final ThreadGroup TRIGGER_THREAD_GROUP = new ThreadGroup("Triggers");
-	
+
 	/** The logger used by the StreamTask and its subclasses */
 	private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);
-	
+
 	// ------------------------------------------------------------------------
-	
+
 	/**
 	 * All interaction with the {@code StreamOperator} must be synchronized on this lock object to ensure that
 	 * we don't have concurrent method calls that void consistent checkpoints.
 	 */
 	private final Object lock = new Object();
-	
+
 	/** the head operator that consumes the input streams of this task */
-	protected Operator headOperator;
+	protected OP headOperator;
 
 	/** The chain of operators executed by this task */
-	private OperatorChain<OUT> operatorChain;
-	
+	private OperatorChain<OUT, OP> operatorChain;
+
 	/** The configuration of this streaming task */
 	private StreamConfig configuration;
 
@@ -135,7 +139,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	private AbstractStateBackend stateBackend;
 
 	/** Keyed state backend for the head operator, if it is keyed. There can only ever be one. */
-	private KeyedStateBackend<?> keyedStateBackend;
+	private AbstractKeyedStateBackend<?> keyedStateBackend;
 
 	/**
 	 * The internal {@link TimeServiceProvider} used to define the current
@@ -146,12 +150,14 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 
 	/** The map of user-defined accumulators of this task */
 	private Map<String, Accumulator<?, ?>> accumulatorMap;
-	
+
 	/** The chained operator state to be restored once the initialization is done */
 	private ChainedStateHandle<StreamStateHandle> lazyRestoreChainedOperatorState;
 
 	private List<KeyGroupsStateHandle> lazyRestoreKeyGroupStates;
 
+	private List<Collection<OperatorStateHandle>> lazyRestoreOperatorState;
+
 	/**
 	 * This field is used to forward an exception that is caught in the timer thread or other
 	 * asynchronous Threads. Subclasses must ensure that exceptions stored here get thrown on the
@@ -159,12 +165,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	private volatile AsynchronousException asyncException;
 
 	/** The currently active background materialization threads */
-	private final Set<Closeable> cancelables = new HashSet<>();
-	
+	private final ClosableRegistry cancelables = new ClosableRegistry();
+
 	/** Flag to mark the task "in operation", in which case check
 	 * needs to be initialized to true, so that early cancel() before invoke() behaves correctly */
 	private volatile boolean isRunning;
-	
+
 	/** Flag to mark this task as canceled */
 	private volatile boolean canceled;
 
@@ -178,11 +184,11 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	// ------------------------------------------------------------------------
 
 	protected abstract void init() throws Exception;
-	
+
 	protected abstract void run() throws Exception;
-	
+
 	protected abstract void cleanup() throws Exception;
-	
+
 	protected abstract void cancelTask() throws Exception;
 
 	// ------------------------------------------------------------------------
@@ -232,13 +238,8 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 				timerService = DefaultTimeServiceProvider.create(this, executor, getCheckpointLock());
 			}
 
-			headOperator = configuration.getStreamOperator(getUserCodeClassLoader());
-			operatorChain = new OperatorChain<>(this, headOperator, 
-						getEnvironment().getAccumulatorRegistry().getReadWriteReporter());
-
-			if (headOperator != null) {
-				headOperator.setup(this, configuration, operatorChain.getChainEntryPoint());
-			}
+			operatorChain = new OperatorChain<>(this, getEnvironment().getAccumulatorRegistry().getReadWriteReporter());
+			headOperator = operatorChain.getHeadOperator();
 
 			getEnvironment().getMetricGroup().gauge("lastCheckpointSize", new Gauge<Long>() {
 				@Override
@@ -249,12 +250,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 
 			// task specific initialization
 			init();
-			
+
 			// save the work of reloadig state, etc, if the task is already canceled
 			if (canceled) {
 				throw new CancelTaskException();
 			}
-			
+
 			// -------- Invoke --------
 			LOG.debug("Invoking {}", getName());
 
@@ -278,7 +279,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 			run();
 
 			LOG.debug("Finished task {}", getName());
-			
+
 			// make sure no further checkpoint and notification actions happen.
 			// we make sure that no other thread is currently in the locked scope before
 			// we close the operators by trying to acquire the checkpoint scope lock
@@ -286,13 +287,13 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 			// at the same time, this makes sure that during any "regular" exit where still
 			synchronized (lock) {
 				isRunning = false;
-				
+
 				// this is part of the main logic, so if this fails, the task is considered failed
 				closeAllOperators();
 			}
 
 			LOG.debug("Closed operators for task {}", getName());
-			
+
 			// make sure all buffered data is flushed
 			operatorChain.flushOutputs();
 
@@ -324,7 +325,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 
 			// stop all asynchronous checkpoint threads
 			try {
-				closeAllClosables();
+				cancelables.close();
 				shutdownAsyncThreads();
 			}
 			catch (Throwable t) {
@@ -371,13 +372,13 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 		isRunning = false;
 		canceled = true;
 		cancelTask();
-		closeAllClosables();
+		cancelables.close();
 	}
 
 	public final boolean isRunning() {
 		return isRunning;
 	}
-	
+
 	public final boolean isCanceled() {
 		return canceled;
 	}
@@ -476,36 +477,14 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 			}
 		}
 
-		closeAllClosables();
-	}
-
-	private void closeAllClosables() {
-		// first, create a copy of the cancelables to prevent concurrent modifications
-		// and to not hold the lock for too long. the copy can be a cheap list
-		List<Closeable> localCancelables = null;
-		synchronized (cancelables) {
-			if (cancelables.size() > 0) {
-				localCancelables = new ArrayList<>(cancelables);
-				cancelables.clear();
-			}
-		}
-
-		if (localCancelables != null) {
-			for (Closeable cancelable : localCancelables) {
-				try {
-					cancelable.close();
-				} catch (Throwable t) {
-					LOG.error("Error on canceling operation", t);
-				}
-			}
-		}
+		cancelables.close();
 	}
 
 	boolean isSerializingTimestamps() {
 		TimeCharacteristic tc = configuration.getTimeCharacteristic();
 		return tc == TimeCharacteristic.EventTime | tc == TimeCharacteristic.IngestionTime;
 	}
-	
+
 	// ------------------------------------------------------------------------
 	//  Access to properties and utilities
 	// ------------------------------------------------------------------------
@@ -525,7 +504,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	public Object getCheckpointLock() {
 		return lock;
 	}
-	
+
 	public StreamConfig getConfiguration() {
 		return configuration;
 	}
@@ -533,11 +512,11 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	public Map<String, Accumulator<?, ?>> getAccumulatorMap() {
 		return accumulatorMap;
 	}
-	
+
 	Output<StreamRecord<OUT>> getHeadOutput() {
 		return operatorChain.getChainEntryPoint();
 	}
-	
+
 	RecordWriterOutput<?>[] getStreamOutputs() {
 		return operatorChain.getStreamOutputs();
 	}
@@ -547,40 +526,59 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	// ------------------------------------------------------------------------
 
 	@Override
-	public void setInitialState(ChainedStateHandle<StreamStateHandle> chainedState, List<KeyGroupsStateHandle> keyGroupsState) {
+	public void setInitialState(
+		ChainedStateHandle<StreamStateHandle> chainedState,
+		List<KeyGroupsStateHandle> keyGroupsState,
+		List<Collection<OperatorStateHandle>> partitionableOperatorState) {
+
 		lazyRestoreChainedOperatorState = chainedState;
 		lazyRestoreKeyGroupStates = keyGroupsState;
+		lazyRestoreOperatorState = partitionableOperatorState;
 	}
 
 	private void restoreState() throws Exception {
 		final StreamOperator<?>[] allOperators = operatorChain.getAllOperators();
 
-		try {
-			if (lazyRestoreChainedOperatorState != null) {
+		if (lazyRestoreChainedOperatorState != null) {
+			Preconditions.checkState(lazyRestoreChainedOperatorState.getLength() == allOperators.length,
+					"Invalid Invalid number of operator states. Found :" + lazyRestoreChainedOperatorState.getLength() +
+							". Expected: " + allOperators.length);
+		}
 
-				synchronized (cancelables) {
-					cancelables.add(lazyRestoreChainedOperatorState);
-				}
+		if (lazyRestoreOperatorState != null) {
+			Preconditions.checkArgument(lazyRestoreOperatorState.isEmpty()
+							|| lazyRestoreOperatorState.size() == allOperators.length,
+					"Invalid number of operator states. Found :" + lazyRestoreOperatorState.size() +
+							". Expected: " + allOperators.length);
+		}
 
-				for (int i = 0; i < lazyRestoreChainedOperatorState.getLength(); i++) {
+		for (int i = 0; i < allOperators.length; i++) {
+			StreamOperator<?> operator = allOperators[i];
+
+			if (null != lazyRestoreOperatorState && !lazyRestoreOperatorState.isEmpty()) {
+				operator.restoreState(lazyRestoreOperatorState.get(i));
+			}
+
+			// TODO deprecated code path
+			if (operator instanceof StreamCheckpointedOperator) {
+
+				if (lazyRestoreChainedOperatorState != null) {
 					StreamStateHandle state = lazyRestoreChainedOperatorState.get(i);
-					if (state == null) {
-						continue;
-					}
-					StreamOperator<?> operator = allOperators[i];
 
-					if (operator != null) {
+					if (state != null) {
 						LOG.debug("Restore state of task {} in chain ({}).", i, getName());
-						try (FSDataInputStream inputStream = state.openInputStream()) {
-							operator.restoreState(inputStream);
+
+						FSDataInputStream is = state.openInputStream();
+						try {
+							cancelables.registerClosable(is);
+							((StreamCheckpointedOperator) operator).restoreState(is);
+						} finally {
+							cancelables.unregisterClosable(is);
+							is.close();
 						}
 					}
 				}
 			}
-		} finally {
-			synchronized (cancelables) {
-				cancelables.remove(lazyRestoreChainedOperatorState);
-			}
 		}
 	}
 
@@ -629,29 +627,58 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 				// Given this, we immediately emit the checkpoint barriers, so the downstream operators
 				// can start their checkpoint work as soon as possible
 				operatorChain.broadcastCheckpointBarrier(checkpointId, timestamp);
-				
+
 				// now draw the state snapshot
 				final StreamOperator<?>[] allOperators = operatorChain.getAllOperators();
-				final List<StreamStateHandle> nonPartitionedStates = Arrays.asList(new StreamStateHandle[allOperators.length]);
+
+				final List<StreamStateHandle> nonPartitionedStates =
+						Arrays.asList(new StreamStateHandle[allOperators.length]);
+
+				final List<OperatorStateHandle> operatorStates =
+						Arrays.asList(new OperatorStateHandle[allOperators.length]);
 
 				for (int i = 0; i < allOperators.length; i++) {
 					StreamOperator<?> operator = allOperators[i];
 
 					if (operator != null) {
+
+						final String operatorId = createOperatorIdentifier(operator, configuration.getVertexID());
+
 						CheckpointStreamFactory streamFactory =
-								stateBackend.createStreamFactory(
-										getEnvironment().getJobID(),
-										createOperatorIdentifier(
-												operator,
-												configuration.getVertexID()));
+								stateBackend.createStreamFactory(getEnvironment().getJobID(), operatorId);
+
+						//TODO deprecated code path
+						if (operator instanceof StreamCheckpointedOperator) {
+
+							CheckpointStreamFactory.CheckpointStateOutputStream outStream =
+									streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);
+
+
+							cancelables.registerClosable(outStream);
+
+							try {
+								((StreamCheckpointedOperator) operator).
+										snapshotState(outStream, checkpointId, timestamp);
+
+								nonPartitionedStates.set(i, outStream.closeAndGetHandle());
+							} finally {
+								cancelables.unregisterClosable(outStream);
+							}
+						}
 
-						CheckpointStreamFactory.CheckpointStateOutputStream outStream =
-								streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp);
+						RunnableFuture<OperatorStateHandle> handleFuture =
+								operator.snapshotState(checkpointId, timestamp, streamFactory);
 
-						operator.snapshotState(outStream, checkpointId, timestamp);
+						if (null != handleFuture) {
+							//TODO for now we assume there are only synchronous snapshots, no need to start the runnable.
+							if (!handleFuture.isDone()) {
+								throw new IllegalStateException("Currently only supports synchronous snapshots!");
+							}
 
-						nonPartitionedStates.set(i, outStream.closeAndGetHandle());
+							operatorStates.set(i, handleFuture.get());
+						}
 					}
+
 				}
 
 				RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture = null;
@@ -659,16 +686,16 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 				if (keyedStateBackend != null) {
 					CheckpointStreamFactory streamFactory = stateBackend.createStreamFactory(
 							getEnvironment().getJobID(),
-							createOperatorIdentifier(
-									headOperator,
-									configuration.getVertexID()));
-					keyGroupsStateHandleFuture = keyedStateBackend.snapshot(
-							checkpointId,
-							timestamp,
-							streamFactory);
+							createOperatorIdentifier(headOperator, configuration.getVertexID()));
+
+					keyGroupsStateHandleFuture = keyedStateBackend.snapshot(checkpointId, timestamp, streamFactory);
 				}
 
-				ChainedStateHandle<StreamStateHandle> chainedStateHandles = new ChainedStateHandle<>(nonPartitionedStates);
+				ChainedStateHandle<StreamStateHandle> chainedNonPartitionedStateHandles =
+						new ChainedStateHandle<>(nonPartitionedStates);
+
+				ChainedStateHandle<OperatorStateHandle> chainedPartitionedStateHandles =
+						new ChainedStateHandle<>(operatorStates);
 
 				LOG.debug("Finished synchronous checkpoints for checkpoint {} on task {}", checkpointId, getName());
 
@@ -679,7 +706,8 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 						"checkpoint-" + checkpointId + "-" + timestamp,
 						this,
 						cancelables,
-						chainedStateHandles,
+						chainedNonPartitionedStateHandles,
+						chainedPartitionedStateHandles,
 						keyGroupsStateHandleFuture,
 						checkpointId,
 						bytesBufferedAlignment,
@@ -687,9 +715,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 						syncDurationMillis,
 						endOfSyncPart);
 
-				synchronized (cancelables) {
-					cancelables.add(asyncCheckpointRunnable);
-				}
+				cancelables.registerClosable(asyncCheckpointRunnable);
 				asyncOperationsThreadPool.submit(asyncCheckpointRunnable);
 				return true;
 			} else {
@@ -707,7 +733,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 		synchronized (lock) {
 			if (isRunning) {
 				LOG.debug("Notification of complete checkpoint for task {}", getName());
-				
+
 				for (StreamOperator<?> operator : operatorChain.getAllOperators()) {
 					if (operator != null) {
 						operator.notifyOfCompletedCheckpoint(checkpointId);
@@ -760,7 +786,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 						Class<? extends StateBackendFactory> clazz =
 								Class.forName(backendName, false, getUserCodeClassLoader()).asSubclass(StateBackendFactory.class);
 
-						stateBackend = ((StateBackendFactory<?>) clazz.newInstance()).createFromConfig(flinkConfig);
+						stateBackend = clazz.newInstance().createFromConfig(flinkConfig);
 					} catch (ClassNotFoundException e) {
 						throw new IllegalConfigurationException("Cannot find configured state backend: " + backendName);
 					} catch (ClassCastException e) {
@@ -772,10 +798,26 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 					}
 			}
 		}
+
 		return stateBackend;
 	}
 
-	public <K> KeyedStateBackend<K> createKeyedStateBackend(
+	public OperatorStateBackend createOperatorStateBackend(
+			StreamOperator<?> op, Collection<OperatorStateHandle> restoreStateHandles) throws Exception {
+
+		Environment env = getEnvironment();
+		String opId = createOperatorIdentifier(op, configuration.getVertexID());
+
+		OperatorStateBackend newBackend = restoreStateHandles == null ?
+				stateBackend.createOperatorStateBackend(env, opId)
+				: stateBackend.restoreOperatorStateBackend(env, opId, restoreStateHandles);
+
+		cancelables.registerClosable(newBackend);
+
+		return newBackend;
+	}
+
+	public <K> AbstractKeyedStateBackend<K> createKeyedStateBackend(
 			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange) throws Exception {
@@ -811,8 +853,10 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 					getEnvironment().getTaskKvStateRegistry());
 		}
 
+		cancelables.registerClosable(keyedStateBackend);
+
 		@SuppressWarnings("unchecked")
-		KeyedStateBackend<K> typedBackend = (KeyedStateBackend<K>) keyedStateBackend;
+		AbstractKeyedStateBackend<K> typedBackend = (AbstractKeyedStateBackend<K>) keyedStateBackend;
 		return typedBackend;
 	}
 
@@ -825,9 +869,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	public CheckpointStreamFactory createCheckpointStreamFactory(StreamOperator<?> operator) throws IOException {
 		return stateBackend.createStreamFactory(
 				getEnvironment().getJobID(),
-				createOperatorIdentifier(
-						operator,
-						configuration.getVertexID()));
+				createOperatorIdentifier(operator, configuration.getVertexID()));
 
 	}
 
@@ -867,7 +909,6 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 		if (isRunning) {
 			LOG.error("Asynchronous exception registered.", exception);
 		}
-
 		if (this.asyncException == null) {
 			this.asyncException = exception;
 		}
@@ -877,20 +918,23 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 	//  Utilities
 	// ------------------------------------------------------------------------
 
+
 	@Override
 	public String toString() {
 		return getName();
 	}
 
 	// ------------------------------------------------------------------------
-	
+
 	private static class AsyncCheckpointRunnable implements Runnable, Closeable {
 
 		private final StreamTask<?, ?> owner;
 
-		private final Set<Closeable> cancelables;
+		private final ClosableRegistry cancelables;
+
+		private final ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles;
 
-		private final ChainedStateHandle<StreamStateHandle> chainedStateHandles;
+		private final ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles;
 
 		private final RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture;
 
@@ -909,8 +953,9 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 		AsyncCheckpointRunnable(
 				String name,
 				StreamTask<?, ?> owner,
-				Set<Closeable> cancelables,
-				ChainedStateHandle<StreamStateHandle> chainedStateHandles,
+				ClosableRegistry cancelables,
+				ChainedStateHandle<StreamStateHandle> nonPartitionedStateHandles,
+				ChainedStateHandle<OperatorStateHandle> partitioneableStateHandles,
 				RunnableFuture<KeyGroupsStateHandle> keyGroupsStateHandleFuture,
 				long checkpointId,
 				long bytesBufferedInAlignment,
@@ -921,7 +966,8 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 			this.name = name;
 			this.owner = owner;
 			this.cancelables = cancelables;
-			this.chainedStateHandles = chainedStateHandles;
+			this.nonPartitionedStateHandles = nonPartitionedStateHandles;
+			this.partitioneableStateHandles = partitioneableStateHandles;
 			this.keyGroupsStateHandleFuture = keyGroupsStateHandleFuture;
 			this.checkpointId = checkpointId;
 			this.bytesBufferedInAlignment = bytesBufferedInAlignment;
@@ -952,13 +998,19 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 				final long asyncEndNanos = System.nanoTime();
 				final long asyncDurationMillis = (asyncEndNanos - asyncStartNanos) / 1_000_000;
 
-				if (chainedStateHandles.isEmpty() && keyedStates.isEmpty()) {
+				if (nonPartitionedStateHandles.isEmpty() && partitioneableStateHandles.isEmpty() && keyedStates.isEmpty()) {
 					owner.getEnvironment().acknowledgeCheckpoint(checkpointId,
 							syncDurationMillies, asyncDurationMillis,
 							bytesBufferedInAlignment, alignmentDurationNanos);
 				} else  {
+					// Acknowledge non-partitioned, partitionable and keyed state together in one bundle.
+					CheckpointStateHandles allStateHandles = new CheckpointStateHandles(
+							nonPartitionedStateHandles,
+							partitioneableStateHandles,
+							keyedStates);
+
 					owner.getEnvironment().acknowledgeCheckpoint(checkpointId,
-							chainedStateHandles, keyedStates,
+							allStateHandles,
 							syncDurationMillies, asyncDurationMillis,
 							bytesBufferedInAlignment, alignmentDurationNanos);
 				}
@@ -974,9 +1026,7 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 				owner.registerAsyncException(asyncException);
 			}
 			finally {
-				synchronized (cancelables) {
-					cancelables.remove(this);
-				}
+				cancelables.unregisterClosable(this);
 			}
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
index fe09788..02409a3 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
@@ -36,8 +36,8 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
@@ -188,15 +188,15 @@ public class StreamingRuntimeContextTest {
					public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
						ListStateDescriptor<String> descr =
								(ListStateDescriptor<String>) invocationOnMock.getArguments()[0];
-						KeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend(
+						// Serve the requested list state from a real memory keyed backend covering only key group 0.
+						AbstractKeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend(
								new DummyEnvironment("test_task", 1, 0),
								new JobID(),
								"test_op",
								IntSerializer.INSTANCE,
								1,
								new KeyGroupRange(0, 0),
-								new KvStateRegistry().createTaskRegistry(new JobID(),
-										new JobVertexID()));
+								new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
						backend.setCurrentKey(0);
						return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr);
					}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
index b549ef8..5d68841 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferTest.java
@@ -21,6 +21,7 @@ package org.apache.flink.streaming.runtime.io;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
@@ -28,15 +29,15 @@ import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
-
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
 
 import java.io.File;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
@@ -974,7 +975,8 @@ public class BarrierBufferTest {
		@Override
		public void setInitialState(
				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState) throws Exception {
+				List<KeyGroupsStateHandle> keyGroupsState,
+				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception { // restoring is unsupported for this test task
			throw new UnsupportedOperationException("should never be called");
		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
index 314dcc4..f2f9092 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierTrackerTest.java
@@ -19,21 +19,25 @@
 package org.apache.flink.streaming.runtime.io;
 
 import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.jobgraph.tasks.StatefulTask;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
-
 import org.junit.Test;
 
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 /**
  * Tests for the behavior of the barrier tracker.
@@ -363,7 +367,8 @@ public class BarrierTrackerTest {
		@Override
		public void setInitialState(
				ChainedStateHandle<StreamStateHandle> chainedState,
-				List<KeyGroupsStateHandle> keyGroupsState) throws Exception {
+				List<KeyGroupsStateHandle> keyGroupsState,
+				List<Collection<OperatorStateHandle>> partitionableOperatorState) throws Exception { // restoring is unsupported for this test task

			throw new UnsupportedOperationException("should never be called");
		}

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
index f4ac5b2..32e8ea9 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/StreamOperatorChainingTest.java
@@ -19,7 +19,6 @@ package org.apache.flink.streaming.runtime.operators;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.MapFunction;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.execution.Environment;
@@ -27,8 +26,6 @@ import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.runtime.state.AbstractStateBackend;
-import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.streaming.api.collector.selector.OutputSelector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SplitStream;
@@ -42,19 +39,15 @@ import org.apache.flink.streaming.runtime.tasks.OperatorChain;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.junit.Assert;
 import org.junit.Test;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
 
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.contains;
-import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
-import static org.hamcrest.MatcherAssert.assertThat;
 
 /**
  * Tests for stream operator chaining behaviour.
@@ -156,9 +149,8 @@ public class StreamOperatorChainingTest {
		StreamTask<Integer, StreamMap<Integer, Integer>> mockTask =
				createMockTask(streamConfig, chainedVertex.getName());

-		OperatorChain<Integer> operatorChain = new OperatorChain<>(
+		OperatorChain<Integer, StreamMap<Integer, Integer>> operatorChain = new OperatorChain<>( // NOTE(review): headOperator is not passed to the chain; confirm it matches the chain's own head operator
				mockTask,
-				headOperator,
				mock(AccumulatorRegistry.Reporter.class));

		headOperator.setup(mockTask, streamConfig, operatorChain.getChainEntryPoint());
@@ -299,9 +291,8 @@ public class StreamOperatorChainingTest {
		StreamTask<Integer, StreamMap<Integer, Integer>> mockTask =
				createMockTask(streamConfig, chainedVertex.getName());

-		OperatorChain<Integer> operatorChain = new OperatorChain<>(
+		OperatorChain<Integer, StreamMap<Integer, Integer>> operatorChain = new OperatorChain<>( // NOTE(review): headOperator is not passed to the chain; confirm it matches the chain's own head operator
				mockTask,
-				headOperator,
				mock(AccumulatorRegistry.Reporter.class));

		headOperator.setup(mockTask, streamConfig, operatorChain.getChainEntryPoint());

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index 6a7b024..b5b6582 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -41,9 +41,9 @@ import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.AbstractCloseableHandle;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.taskmanager.CheckpointResponder;
 import org.apache.flink.runtime.taskmanager.Task;
@@ -56,19 +56,23 @@ import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.util.SerializedValue;
-
 import org.junit.Test;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.Serializable;
 import java.net.URL;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.Executor;
 
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 /**
  * This test checks that task restores that get stuck in the presence of interrupts
@@ -121,6 +125,7 @@ public class InterruptSensitiveRestoreTest {

		ChainedStateHandle<StreamStateHandle> operatorState = new ChainedStateHandle<>(Collections.singletonList(state));
		List<KeyGroupsStateHandle> keyGroupState = Collections.emptyList();
+		List<Collection<OperatorStateHandle>> partitionableOperatorState = Collections.emptyList(); // empty, like keyGroupState

		return new TaskDeploymentDescriptor(
				new JobID(),
@@ -139,42 +144,47 @@ public class InterruptSensitiveRestoreTest {
				Collections.<URL>emptyList(),
				0,
				operatorState,
-				keyGroupState);
+				keyGroupState,
+				partitionableOperatorState);
	}
-	
+
	private static Task createTask(TaskDeploymentDescriptor tdd) throws IOException {
		NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class);
		when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class)))
				.thenReturn(mock(TaskKvStateRegistry.class));

		return new Task(
-			tdd,
-			mock(MemoryManager.class),
-			mock(IOManager.class),
-			networkEnvironment,
-			mock(BroadcastVariableManager.class),
+				tdd,
+				mock(MemoryManager.class),
+				mock(IOManager.class),
+				networkEnvironment,
+				mock(BroadcastVariableManager.class),
				mock(TaskManagerConnection.class),
				mock(InputSplitProvider.class),
				mock(CheckpointResponder.class),
-			new FallbackLibraryCacheManager(),
-			new FileCache(new Configuration()),
-			new TaskManagerRuntimeInfo(
-					"localhost", new Configuration(), EnvironmentInformation.getTemporaryFileDirectory()),
-			new UnregisteredTaskMetricsGroup(),
-			mock(ResultPartitionConsumableNotifier.class),
-			mock(PartitionStateChecker.class),
-			mock(Executor.class));
-		
+				new FallbackLibraryCacheManager(),
+				new FileCache(new Configuration()),
+				new TaskManagerRuntimeInfo(
+						"localhost", new Configuration(), EnvironmentInformation.getTemporaryFileDirectory()),
+				new UnregisteredTaskMetricsGroup(),
+				mock(ResultPartitionConsumableNotifier.class),
+				mock(PartitionStateChecker.class),
+				mock(Executor.class));
+
	}

	// ------------------------------------------------------------------------

	@SuppressWarnings("serial")
-	private static class InterruptLockingStateHandle extends AbstractCloseableHandle implements StreamStateHandle {
+	private static class InterruptLockingStateHandle implements StreamStateHandle {
+		// Set by close() of the stream returned from openInputStream(); polled by the interrupt wait loop.
+		private volatile boolean closed;

		@Override
		public FSDataInputStream openInputStream() throws IOException {
-			ensureNotClosed();
+			// Reset for the newly opened stream; its close() sets the flag again.
+			closed = false;
+
			FSDataInputStream is = new FSDataInputStream() {

				@Override
@@ -191,8 +201,14 @@ public class InterruptSensitiveRestoreTest {
					block();
					throw new EOFException();
				}
+
+				@Override
+				public void close() throws IOException {
+					super.close();
+					closed = true; // NOTE(review): no notify here; waiters only re-check the flag when woken (e.g. by an interrupt)
+				}
			};
-			registerCloseable(is);
+
			return is;
		}
 
@@ -207,7 +223,7 @@ public class InterruptSensitiveRestoreTest {
				}
			}
			catch (InterruptedException e) {
-				while (!isClosed()) {
+				while (!closed) { // set once the stream returned by openInputStream() is closed
					try {
						synchronized (this) {
							wait();
@@ -227,7 +243,7 @@ public class InterruptSensitiveRestoreTest {
	}

	// ------------------------------------------------------------------------
-	
+
	private static class TestSource implements SourceFunction<Object>, Checkpointed<Serializable> {
		private static final long serialVersionUID = 1L;
 
@@ -250,4 +266,4 @@ public class InterruptSensitiveRestoreTest {
			fail("should never be called");
		}
	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/53ed6ada/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
index 88fb383..4003e59 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java
@@ -21,7 +21,10 @@ package org.apache.flink.streaming.runtime.tasks;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.FSDataInputStream;
@@ -31,8 +34,12 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStateHandles;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.graph.StreamNode;
@@ -56,17 +63,23 @@ import scala.concurrent.duration.FiniteDuration;
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
+import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.TimeUnit;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 /**
  * Tests for {@link OneInputStreamTask}.
@@ -82,6 +95,9 @@ import static org.junit.Assert.assertTrue;
 @PowerMockIgnore({"javax.management.*", "com.sun.jndi.*"})
 public class OneInputStreamTaskTest extends TestLogger {

+	private static final ListStateDescriptor<Integer> TEST_DESCRIPTOR =
+			new ListStateDescriptor<>("test", IntSerializer.INSTANCE); // state used by TestingStreamOperator
+
	/**
	 * This test verifies that open() and close() are correctly called. This test also verifies
	 * that timestamps of emitted elements are correct. {@link StreamMap} assigns the input
@@ -358,7 +374,7 @@ public class OneInputStreamTaskTest extends TestLogger {
		testHarness.invoke(env);
		testHarness.waitForTaskRunning(deadline.timeLeft().toMillis());

-		streamTask.triggerCheckpoint(checkpointId, checkpointTimestamp);
+		while (!streamTask.triggerCheckpoint(checkpointId, checkpointTimestamp)) { Thread.yield(); } // retry until accepted

		// since no state was set, there shouldn't be restore calls
		assertEquals(0, TestingStreamOperator.numberRestoreCalls);
@@ -371,7 +387,7 @@ public class OneInputStreamTaskTest extends TestLogger {
		testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis());

		final OneInputStreamTask<String, String> restoredTask = new OneInputStreamTask<String, String>();
-		restoredTask.setInitialState(env.getState(), env.getKeyGroupStates());
+		restoredTask.setInitialState(env.getState(), env.getKeyGroupStates(), env.getPartitionableOperatorState());

		final OneInputStreamTaskTestHarness<String, String> restoredTaskHarness = new OneInputStreamTaskTestHarness<String, String>(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
		restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO);
@@ -465,6 +481,7 @@ public class OneInputStreamTaskTest extends TestLogger {
		private volatile long checkpointId;
		private volatile ChainedStateHandle<StreamStateHandle> state;
		private volatile List<KeyGroupsStateHandle> keyGroupStates;
+		private volatile List<Collection<OperatorStateHandle>> partitionableOperatorState; // captured in acknowledgeCheckpoint

		private final OneShotLatch checkpointLatch = new OneShotLatch();

@@ -486,6 +503,10 @@ public class OneInputStreamTaskTest extends TestLogger {
			return result;
		}

+		List<Collection<OperatorStateHandle>> getPartitionableOperatorState() {
+			return partitionableOperatorState;
+		}
+
		AcknowledgeStreamMockEnvironment(
				Configuration jobConfig, Configuration taskConfig,
				ExecutionConfig executionConfig, long memorySize,
@@ -497,13 +518,21 @@ public class OneInputStreamTaskTest extends TestLogger {
		@Override
		public void acknowledgeCheckpoint(
				long checkpointId,
-				ChainedStateHandle<StreamStateHandle> state, List<KeyGroupsStateHandle> keyGroupStates,
+				CheckpointStateHandles checkpointStateHandles,
				long syncDuration, long asymcDuration, long alignmentByte, long alignmentDuration) {

			this.checkpointId = checkpointId;
-			this.state = state;
-			this.keyGroupStates = keyGroupStates;
-
+			if (checkpointStateHandles != null) {
+				this.state = checkpointStateHandles.getNonPartitionedStateHandles();
+				this.keyGroupStates = checkpointStateHandles.getKeyGroupsStateHandle();
+				ChainedStateHandle<OperatorStateHandle> chainedStateHandle = checkpointStateHandles.getPartitioneableStateHandles();
+				// Wrap each chained operator's handle in its own single-element collection (typed, no raw array).
+				List<Collection<OperatorStateHandle>> operatorStates = new ArrayList<>(chainedStateHandle.getLength());
+				for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
+					operatorStates.add(Collections.singletonList(chainedStateHandle.get(i)));
+				}
+				this.partitionableOperatorState = operatorStates;
+			}
			checkpointLatch.trigger();
		}
 
@@ -513,17 +542,56 @@ public class OneInputStreamTaskTest extends TestLogger {
 	}
 
 	private static class TestingStreamOperator<IN, OUT>
-			extends AbstractStreamOperator<OUT> implements OneInputStreamOperator<IN, OUT> {
+			extends AbstractStreamOperator<OUT>
+			implements OneInputStreamOperator<IN, OUT>, StreamCheckpointedOperator {
 
 		private static final long serialVersionUID = 774614855940397174L;
 
 		public static int numberRestoreCalls = 0;
+		public static int numberSnapshotCalls = 0;
 
 		private final long seed;
 		private final long recoveryTimestamp;
 
 		private transient Random random;
 
+		@Override
+		public void open() throws Exception {
+			super.open();
+
+			ListState<Integer> partitionableState = getOperatorStateBackend().getPartitionableState(TEST_DESCRIPTOR);
+
+			if (numberSnapshotCalls == 0) {
+				// Before the first snapshot, the partitionable state is expected to be empty.
+				assertTrue("unexpected partitionable operator state",
+						!partitionableState.get().iterator().hasNext());
+			} else {
+				Set<Integer> result = new HashSet<>();
+				for (Integer v : partitionableState.get()) {
+					result.add(v);
+				}
+
+				assertEquals(2, result.size());
+				assertTrue(result.contains(42));
+				assertTrue(result.contains(4711));
+			}
+		}
+
+		@Override
+		public RunnableFuture<OperatorStateHandle> snapshotState(
+				long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
+			// Overwrite the partitionable state with the fixed values that open() checks after a restore.
+			ListState<Integer> partitionableState =
+					getOperatorStateBackend().getPartitionableState(TEST_DESCRIPTOR);
+			partitionableState.clear();
+
+			partitionableState.add(42);
+			partitionableState.add(4711);
+
+			++numberSnapshotCalls;
+			return super.snapshotState(checkpointId, timestamp, streamFactory);
+		}
+
 		TestingStreamOperator(long seed, long recoveryTimestamp) {
 			this.seed = seed;
 			this.recoveryTimestamp = recoveryTimestamp;