Posted to commits@flink.apache.org by gy...@apache.org on 2015/06/25 19:21:42 UTC

[09/12] flink git commit: [streaming] Allow using both partitioned and non-partitioned state in an operator + refactor

[streaming] Allow using both partitioned and non-partitioned state in an operator + refactor


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/474ff4df
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/474ff4df
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/474ff4df

Branch: refs/heads/master
Commit: 474ff4dfd93a57730969ffdd221b7ecec06ed479
Parents: 0a4144e
Author: Gyula Fora <gy...@apache.org>
Authored: Sun Jun 21 16:22:21 2015 +0200
Committer: Gyula Fora <gy...@apache.org>
Committed: Thu Jun 25 16:38:07 2015 +0200

----------------------------------------------------------------------
 docs/apis/streaming_guide.md                    |  9 +-
 .../api/common/functions/RuntimeContext.java    | 16 +++-
 .../util/AbstractRuntimeUDFContext.java         | 10 ++-
 .../api/persistent/PersistentKafkaSource.java   |  2 +-
 .../source/StatefulSequenceSource.java          |  2 +-
 .../operators/AbstractUdfStreamOperator.java    | 16 ++--
 .../api/state/StreamOperatorState.java          |  2 +-
 .../streaming/runtime/tasks/OutputHandler.java  | 10 ++-
 .../streaming/runtime/tasks/StreamTask.java     | 93 +++++++-------------
 .../runtime/tasks/StreamingRuntimeContext.java  | 49 +++++++----
 .../api/state/StatefulOperatorTest.java         | 61 ++++---------
 .../runtime/tasks/SourceStreamTaskTest.java     |  2 +-
 .../StreamCheckpointingITCase.java              |  8 +-
 .../ProcessFailureStreamingRecoveryITCase.java  |  5 +-
 14 files changed, 128 insertions(+), 157 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/docs/apis/streaming_guide.md
----------------------------------------------------------------------
diff --git a/docs/apis/streaming_guide.md b/docs/apis/streaming_guide.md
index 997a245..f252f1e 100644
--- a/docs/apis/streaming_guide.md
+++ b/docs/apis/streaming_guide.md
@@ -1191,17 +1191,14 @@ Stateful computation
 Flink supports the checkpointing and persistence of user defined operator state, so in case of a failure this state can be restored to the latest checkpoint and the processing will continue from there. This gives exactly once processing semantics with respect to the operator states when the sources follow this stateful pattern as well. In practice this usually means that sources keep track of their current offset as their OperatorState. The `PersistentKafkaSource` provides this stateful functionality for reading streams from Kafka. 
 
 Flink supports two ways of accessing operator states: partitioned and non-partitioned state access.
-In case of non-partitioned state access, an operator state is maintained for each parallel instance of a given operator. When `OperatorState.getState()` is called a separate state is returned in each parallel instance. In practice this means if we keep a counter for the received inputs in a mapper, `getState()` will return number of inputs processed by each parallel mapper.
 
-In case of partitioned state access the user needs to define a `KeyExtractor` which will assign a key to each input of the stateful operator:
+In case of non-partitioned state access, an operator state is maintained for each parallel instance of a given operator. When `OperatorState.getState()` is called, a separate state is returned in each parallel instance. In practice this means that if we keep a counter for the received inputs in a mapper, `getState()` will return the number of inputs processed by each parallel mapper.
 
-`stream.map(counter).setStatePartitioner(…)`
-
-A separate `OperatorState` is maintained for each received key which can be used for instance to count received inputs by different keys, or store and update summary statistics of different sub-streams.
+In case of partitioned `OperatorState`, a separate state is maintained for each received key. This can be used, for instance, to count received inputs by different keys, or to store and update summary statistics of different sub-streams.
 
 Checkpointing of the states needs to be enabled from the `StreamExecutionEnvironment` using the `enableCheckpointing(…)` where additional parameters can be passed to modify the default 5 second checkpoint interval.
 
-Operators can be accessed from the `RuntimeContext` using the `getOperatorState(“name”, defaultValue)` method so it is only accessible in `RichFunction`s. A recommended usage pattern is to retrieve the operator state in the `open(…)` method of the operator and set it as a field in the operator instance for runtime usage. Multiple `OperatorState`s can be used simultaneously by the same operator by using different names to identify them.
+Operator states can be accessed from the `RuntimeContext` using the `getOperatorState("name", defaultValue, partitioned)` method, where the `partitioned` flag selects between the two access modes, so they are only accessible in `RichFunction`s. A recommended usage pattern is to retrieve the operator state in the `open(…)` method of the operator and set it as a field in the operator instance for runtime usage. Multiple `OperatorState`s can be used simultaneously by the same operator by using different names to identify them.
 
 By default operator states are checkpointed using default java serialization thus they need to be `Serializable`. The user can gain more control over the state checkpoint mechanism by passing a `StateCheckpointer` instance when retrieving the `OperatorState` from the `RuntimeContext`. The `StateCheckpointer` allows custom implementations for the checkpointing logic for increased efficiency and to store arbitrary non-serializable states.
 

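A minimal, Flink-free Java sketch of the two access modes described in the guide changes above. All names here (`OperatorStateModel`, `State`, `SingleState`, `KeyedState`, `countAllAndPerKey`) are hypothetical stand-ins, not Flink classes; in Flink the state object would come from `getRuntimeContext().getOperatorState("name", defaultValue, partitioned)` in `open(…)`, and the key from the operator's state partitioner.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical, Flink-free model of partitioned vs non-partitioned state access.
// None of these types are Flink APIs; they only mimic the described behavior.
final class OperatorStateModel {

    /** Minimal stand-in for an operator state: a value that starts at a default. */
    interface State<S> {
        S getState();
        void updateState(S newState);
    }

    /** Non-partitioned access: one value per parallel operator instance. */
    static final class SingleState<S> implements State<S> {
        private S value;

        SingleState(S defaultState) {
            this.value = defaultState;
        }

        public S getState() {
            return value;
        }

        public void updateState(S newState) {
            value = newState;
        }
    }

    /** Partitioned access: one value per key extracted from the current input. */
    static final class KeyedState<S, IN, K> implements State<S> {
        private final Map<K, S> values = new HashMap<>();
        private final S defaultState;
        private final Function<IN, K> keyExtractor;
        private K currentKey;

        KeyedState(S defaultState, Function<IN, K> keyExtractor) {
            this.defaultState = defaultState;
            this.keyExtractor = keyExtractor;
        }

        /** Plays the role of the runtime setting the next input before each record. */
        void setCurrentInput(IN input) {
            currentKey = keyExtractor.apply(input);
        }

        public S getState() {
            // A key that was never updated returns the default state.
            return values.getOrDefault(currentKey, defaultState);
        }

        public void updateState(S newState) {
            values.put(currentKey, newState);
        }
    }

    /** Counts all inputs and inputs per key; returns {total, count for queryKey}. */
    static long[] countAllAndPerKey(String[] inputs, String queryKey) {
        SingleState<Long> total = new SingleState<>(0L);
        KeyedState<Long, String, String> perKey = new KeyedState<>(0L, s -> s);
        for (String in : inputs) {
            perKey.setCurrentInput(in);                // key is set before processing
            total.updateState(total.getState() + 1);   // one counter for this instance
            perKey.updateState(perKey.getState() + 1); // one counter per key
        }
        perKey.setCurrentInput(queryKey);
        return new long[] {total.getState(), perKey.getState()};
    }
}
```

For the inputs `a`, `b`, `a`, the non-partitioned counter ends at 3, while the partitioned counter reads 2 for key `a` and falls back to the default 0 for a key it has never seen.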
http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
index 2c2ec53..8046327 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
@@ -186,12 +186,16 @@ public interface RuntimeContext {
 	 *            the first time {@link OperatorState#getState()} (for every
 	 *            state partition) is called before
 	 *            {@link OperatorState#updateState(Object)}.
+	 * @param partitioned
+	 *            Sets whether partitioning should be applied for the given
+	 *            state. If true a partitioner key must be used.
 	 * @param checkpointer
 	 *            The {@link StateCheckpointer} that will be used to draw
 	 *            snapshots from the user state.
 	 * @return The {@link OperatorState} for the underlying operator.
 	 */
-	<S,C extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState, StateCheckpointer<S,C> checkpointer);
+	<S, C extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState,
+			boolean partitioned, StateCheckpointer<S, C> checkpointer);
 
 	/**
 	 * Returns the {@link OperatorState} with the given name of the underlying
@@ -205,14 +209,18 @@ public interface RuntimeContext {
 	 * </p>
 	 * 
 	 * @param name
-	 *            Identifier for the state allowing that more operator states can be
-	 *            used by the same operator.
+	 *            Identifier for the state allowing that more operator states
+	 *            can be used by the same operator.
 	 * @param defaultState
 	 *            Default value for the operator state. This will be returned
 	 *            the first time {@link OperatorState#getState()} (for every
 	 *            state partition) is called before
 	 *            {@link OperatorState#updateState(Object)}.
+	 * @param partitioned
+	 *            Sets whether partitioning should be applied for the given
+	 *            state. If true a partitioner key must be used.
 	 * @return The {@link OperatorState} for the underlying operator.
 	 */
-	<S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState);
+	<S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState,
+			boolean partitioned);
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
index 9f17bd1..33a7abc 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
@@ -174,12 +174,14 @@ public abstract class AbstractRuntimeUDFContext implements RuntimeContext {
 	}
 	
 	@Override
-	public <S, C extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState, StateCheckpointer<S, C> checkpointer) {
-		throw new UnsupportedOperationException("Operator state is only accessible for streaming operators.");
+	public <S, C extends Serializable> OperatorState<S> getOperatorState(String name,
+			S defaultState, boolean partitioned, StateCheckpointer<S, C> checkpointer) {
+	throw new UnsupportedOperationException("Operator state is only accessible for streaming operators.");
 	}
 
 	@Override
-	public <S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState) {
-		throw new UnsupportedOperationException("Operator state is only accessible for streaming operators.");
+	public <S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState,
+			boolean partitioned) {
+	throw new UnsupportedOperationException("Operator state is only accessible for streaming operators.");
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/api/persistent/PersistentKafkaSource.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/api/persistent/PersistentKafkaSource.java b/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/api/persistent/PersistentKafkaSource.java
index d3c651d..befbef6 100644
--- a/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/api/persistent/PersistentKafkaSource.java
+++ b/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/api/persistent/PersistentKafkaSource.java
@@ -145,7 +145,7 @@ public class PersistentKafkaSource<OUT> extends RichParallelSourceFunction<OUT>
 		// most likely the number of offsets we're going to store here will be lower than the number of partitions.
 		int numPartitions = getNumberOfPartitions();
 		LOG.debug("The topic {} has {} partitions", topicName, numPartitions);
-		this.lastOffsets = getRuntimeContext().getOperatorState("offset", new long[numPartitions]);
+		this.lastOffsets = getRuntimeContext().getOperatorState("offset", new long[numPartitions], false);
 		this.commitedOffsets = new long[numPartitions];
 		// check if there are offsets to restore
 		if (!Arrays.equals(lastOffsets.getState(), new long[numPartitions])) {

http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
index 13d0b6c..3cb8b90 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
@@ -74,7 +74,7 @@ public class StatefulSequenceSource extends RichParallelSourceFunction<Long> {
 	
 	@Override
 	public void open(Configuration conf){
-		collected = getRuntimeContext().getOperatorState("collected", 0L);
+		collected = getRuntimeContext().getOperatorState("collected", 0L, false);
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
index b37c426..9faf0c0 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
@@ -32,6 +32,8 @@ import org.apache.flink.streaming.api.checkpoint.CheckpointCommitter;
 import org.apache.flink.streaming.api.state.StreamOperatorState;
 import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;
 
+import com.google.common.collect.ImmutableMap;
+
 /**
  * This is used as the base class for operators that have a user-defined
  * function.
@@ -72,25 +74,25 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function & Serial
 
 	@SuppressWarnings({ "unchecked", "rawtypes" })
 	public void restoreInitialState(Map<String, PartitionedStateHandle> snapshots) throws Exception {
-
-		Map<String, StreamOperatorState> operatorStates = runtimeContext.getOperatorStates();
-		
+		// We iterate over the states registered for this operator, initialize and restore it
 		for (Entry<String, PartitionedStateHandle> snapshot : snapshots.entrySet()) {
-			StreamOperatorState restoredState = runtimeContext.createRawState();
+			Map<Serializable, StateHandle<Serializable>> handles = snapshot.getValue().getState();
+			StreamOperatorState restoredState = runtimeContext.getState(snapshot.getKey(),
+					!(handles instanceof ImmutableMap));
 			restoredState.restoreState(snapshot.getValue().getState());
-			operatorStates.put(snapshot.getKey(), restoredState);
 		}
-
 	}
 
 	@SuppressWarnings({ "rawtypes", "unchecked" })
 	public Map<String, PartitionedStateHandle> getStateSnapshotFromFunction(long checkpointId, long timestamp)
 			throws Exception {
-
+		// Get all the states for the operator
 		Map<String, StreamOperatorState> operatorStates = runtimeContext.getOperatorStates();
 		if (operatorStates.isEmpty()) {
+			// We return null to signal that there is nothing to checkpoint
 			return null;
 		} else {
+			// Checkpoint the states and store the handles in a map
 			Map<String, PartitionedStateHandle> snapshots = new HashMap<String, PartitionedStateHandle>();
 
 			for (Entry<String, StreamOperatorState> state : operatorStates.entrySet()) {

http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
index c19b485..59c624e 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
@@ -41,7 +41,7 @@ import com.google.common.collect.ImmutableMap;
  */
 public class StreamOperatorState<S, C extends Serializable> implements OperatorState<S> {
 
-	protected static final Serializable DEFAULTKEY = -1;
+	public static final Serializable DEFAULTKEY = -1;
 
 	private S state;
 	private StateCheckpointer<S, C> checkpointer;

http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java
index 6e88610..2d2f29b 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java
@@ -35,6 +35,7 @@ import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.runtime.io.RecordWriterFactory;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -51,7 +52,7 @@ public class OutputHandler<OUT> {
 	private ClassLoader cl;
 	private Output<OUT> outerOutput;
 
-	public List<OneInputStreamOperator<?, ?>> chainedOperators;
+	public List<StreamOperator<?>> chainedOperators;
 
 	private Map<StreamEdge, StreamOutput<?>> outputMap;
 
@@ -63,7 +64,7 @@ public class OutputHandler<OUT> {
 		// Initialize some fields
 		this.vertex = vertex;
 		this.configuration = new StreamConfig(vertex.getTaskConfiguration());
-		this.chainedOperators = new ArrayList<OneInputStreamOperator<?, ?>>();
+		this.chainedOperators = new ArrayList<StreamOperator<?>>();
 		this.outputMap = new HashMap<StreamEdge, StreamOutput<?>>();
 		this.cl = vertex.getUserCodeClassLoader();
 
@@ -88,6 +89,9 @@ public class OutputHandler<OUT> {
 		// We create the outer output that will be passed to the first task
 		// in the chain
 		this.outerOutput = createChainedCollector(configuration);
+		
+		// Add the head operator to the end of the list
+		this.chainedOperators.add(vertex.streamOperator);
 	}
 
 	public void broadcastBarrier(long id, long timestamp) throws IOException, InterruptedException {
@@ -101,7 +105,7 @@ public class OutputHandler<OUT> {
 		return outputMap.values();
 	}
 	
-	public List<OneInputStreamOperator<?, ?>> getChainedOperators(){
+	public List<StreamOperator<?>> getChainedOperators(){
 		return chainedOperators;
 	}
 

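The `OutputHandler` change above appends the head operator to `chainedOperators`, and the `StreamTask` diff that follows opens every operator in list order and closes them in reverse. A simplified Java model of that ordering, assuming hypothetical `Op`/`RecordingOp` types in place of `StreamOperator` and not modeling how `createChainedCollector` orders the chained entries themselves:

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical model of the refactored operator list: chained operators first,
// head operator appended last. Op and RecordingOp are illustrative, not Flink types.
final class ChainLifecycleModel {

    interface Op {
        void open();
        void close();
    }

    /** Records lifecycle calls so the ordering can be inspected. */
    static final class RecordingOp implements Op {
        private final String name;
        private final List<String> log;

        RecordingOp(String name, List<String> log) {
            this.name = name;
            this.log = log;
        }

        public void open() {
            log.add("open " + name);
        }

        public void close() {
            log.add("close " + name);
        }
    }

    /** Like OutputHandler after this commit: the head operator goes to the end of the list. */
    static List<Op> buildChain(List<? extends Op> chained, Op head) {
        List<Op> all = new ArrayList<>(chained);
        all.add(head);
        return all;
    }

    /** The list always holds the head, so "has chained operators" means size != 1. */
    static boolean hasChainedOperators(List<Op> all) {
        return all.size() != 1;
    }

    /** Mirrors openOperator(): a single pass in list order. */
    static void openAll(List<Op> all) {
        for (Op op : all) {
            op.open();
        }
    }

    /** Mirrors closeOperator(): reverse list order, so the head (last entry) closes first. */
    static void closeAll(List<Op> all) {
        for (int i = all.size() - 1; i >= 0; i--) {
            all.get(i).close();
        }
    }
}
```

With one chained operator the recorded order is `open chained`, `open head`, `close head`, `close chained`. Because the list now always contains the head, an emptiness check no longer detects chaining, which is why `hasChainedOperators` becomes a size check.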
http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 2b8964d..8f6bc43 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -43,7 +43,6 @@ import org.apache.flink.runtime.state.StateHandleProvider;
 import org.apache.flink.runtime.util.SerializedValue;
 import org.apache.flink.runtime.util.event.EventListener;
 import org.apache.flink.streaming.api.graph.StreamConfig;
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.StatefulStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.state.WrapperStateHandle;
@@ -74,8 +73,6 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 
 	protected ClassLoader userClassLoader;
 	
-	private StateHandleProvider<Serializable> stateHandleProvider;
-
 	private EventListener<TaskEvent> superstepListener;
 
 	public StreamTask() {
@@ -88,11 +85,10 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 	public void registerInputOutput() {
 		this.userClassLoader = getUserCodeClassLoader();
 		this.configuration = new StreamConfig(getTaskConfiguration());
-		this.stateHandleProvider = getStateHandleProvider();
+		
+		streamOperator = configuration.getStreamOperator(userClassLoader);
 
 		outputHandler = new OutputHandler<OUT>(this);
-
-		streamOperator = configuration.getStreamOperator(userClassLoader);
 		
 		if (streamOperator != null) {
 			// IterationHead and IterationTail don't have an Operator...
@@ -103,7 +99,7 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 			streamOperator.setup(outputHandler.getOutput(), headContext);
 		}
 
-		hasChainedOperators = !outputHandler.getChainedOperators().isEmpty();
+		hasChainedOperators = !(outputHandler.getChainedOperators().size() == 1);
 	}
 
 	public String getName() {
@@ -167,20 +163,16 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 	}
 
 	protected void openOperator() throws Exception {
-		streamOperator.open(getTaskConfiguration());
-
-		for (OneInputStreamOperator<?, ?> operator : outputHandler.chainedOperators) {
+		for (StreamOperator<?> operator : outputHandler.getChainedOperators()) {
 			operator.open(getTaskConfiguration());
 		}
 	}
 
 	protected void closeOperator() throws Exception {
-		streamOperator.close();
-
 		// We need to close them first to last, since upstream operators in the chain might emit
 		// elements in their close methods.
-		for (int i = outputHandler.chainedOperators.size()-1; i >= 0; i--) {
-			outputHandler.chainedOperators.get(i).close();
+		for (int i = outputHandler.getChainedOperators().size()-1; i >= 0; i--) {
+			outputHandler.getChainedOperators().get(i).close();
 		}
 	}
 
@@ -206,28 +198,19 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 	@SuppressWarnings("unchecked")
 	@Override
 	public void setInitialState(StateHandle<Serializable> stateHandle) throws Exception {
-		// here, we later resolve the state handle into the actual state by
-		// loading the state described by the handle from the backup store
-		Serializable state = stateHandle.getState();
+		
+		// We retrieve and restore the states for the chained operators.
+		List<Serializable> chainedStates = (List<Serializable>) stateHandle.getState();
 
-		List<Serializable> chainedStates = (List<Serializable>) state;
+		// We restore all stateful chained operators
+		for (int i = 0; i < chainedStates.size(); i++) {
+			Serializable state = chainedStates.get(i);
+			// If state is not null we need to restore it
+			if (state != null) {
+				StreamOperator<?> chainedOperator = outputHandler.getChainedOperators().get(i);
 
-		Serializable headState = chainedStates.get(0);
-		if (headState != null) {
-			if (streamOperator instanceof StatefulStreamOperator) {
-				((StatefulStreamOperator<?>) streamOperator)
-						.restoreInitialState((Map<String, PartitionedStateHandle>) headState);
-			}
-		}
-
-		for (int i = 1; i < chainedStates.size(); i++) {
-			Serializable chainedState = chainedStates.get(i);
-			if (chainedState != null) {
-				StreamOperator<?> chainedOperator = outputHandler.getChainedOperators().get(i - 1);
-				if (chainedOperator instanceof StatefulStreamOperator) {
-					((StatefulStreamOperator<?>) chainedOperator)
-							.restoreInitialState((Map<String, PartitionedStateHandle>) chainedState);
-				}
+				((StatefulStreamOperator<?>) chainedOperator)
+						.restoreInitialState((Map<String, PartitionedStateHandle>) state);
 
 			}
 		}
@@ -242,22 +225,21 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 				try {
 					LOG.debug("Starting checkpoint {} on task {}", checkpointId, getName());
 					
-					// first draw the state that should go into checkpoint
+					// We wrap the states of the chained operators in a list, marking non-stateful operators with null
 					List<Map<String, PartitionedStateHandle>> chainedStates = new ArrayList<Map<String, PartitionedStateHandle>>();
-					StateHandle<Serializable> stateHandle;
+					
+					// A wrapper handle is created for the List of statehandles
+					WrapperStateHandle stateHandle;
 					try {
 
-						if (streamOperator instanceof StatefulStreamOperator) {
-							chainedStates.add(((StatefulStreamOperator<?>) streamOperator).getStateSnapshotFromFunction(checkpointId, timestamp));
-						}
-
-
-						if (hasChainedOperators) {
-							// We construct a list of states for chained tasks
-							for (OneInputStreamOperator<?, ?> chainedOperator : outputHandler.getChainedOperators()) {
-								if (chainedOperator instanceof StatefulStreamOperator) {
-									chainedStates.add(((StatefulStreamOperator<?>) chainedOperator).getStateSnapshotFromFunction(checkpointId, timestamp));
-								}
+						// We construct a list of states for chained tasks
+						for (StreamOperator<?> chainedOperator : outputHandler
+								.getChainedOperators()) {
+							if (chainedOperator instanceof StatefulStreamOperator) {
+								chainedStates.add(((StatefulStreamOperator<?>) chainedOperator)
+										.getStateSnapshotFromFunction(checkpointId, timestamp));
+							}else{
+								chainedStates.add(null);
 							}
 						}
 						
@@ -296,23 +278,10 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 			
 			List<Map<String, PartitionedStateHandle>> chainedStates = (List<Map<String, PartitionedStateHandle>>) stateHandle.deserializeValue(getUserCodeClassLoader()).getState();
 
-			Map<String, PartitionedStateHandle> headState = chainedStates.get(0);
-			if (headState != null) {
-				if (streamOperator instanceof StatefulStreamOperator) {
-					for (Entry<String, PartitionedStateHandle> stateEntry : headState
-							.entrySet()) {
-						for (StateHandle<Serializable> handle : stateEntry.getValue().getState().values()) {
-							((StatefulStreamOperator) streamOperator).confirmCheckpointCompleted(
-									checkpointId, stateEntry.getKey(), handle);
-						}
-					}
-				}
-			}
-
-			for (int i = 1; i < chainedStates.size(); i++) {
+			for (int i = 0; i < chainedStates.size(); i++) {
 				Map<String, PartitionedStateHandle> chainedState = chainedStates.get(i);
 				if (chainedState != null) {
-					StreamOperator<?> chainedOperator = outputHandler.getChainedOperators().get(i - 1);
+					StreamOperator<?> chainedOperator = outputHandler.getChainedOperators().get(i);
 					if (chainedOperator instanceof StatefulStreamOperator) {
 						for (Entry<String, PartitionedStateHandle> stateEntry : chainedState
 								.entrySet()) {

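After this refactor the head operator is no longer special-cased: the checkpoint code builds one list entry per operator in `getChainedOperators()`, using `null` for operators that are not `StatefulStreamOperator`s, and both `setInitialState` and the checkpoint-confirmation loop use index `i` to find operator `i`. A hedged sketch of that alignment, with hypothetical `Op`, `StatefulOp` and `CountingOp` types and a `Map<String, Long>` standing in for `Map<String, PartitionedStateHandle>`:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical model of the chained-state layout: entry i of the snapshot list
// belongs to operator i, and stateless operators contribute a null placeholder.
// Op, StatefulOp and the Map<String, Long> state type are illustrative only.
final class ChainedStateModel {

    interface Op {}

    interface StatefulOp extends Op {
        Map<String, Long> snapshot();
        void restore(Map<String, Long> state);
    }

    static final class StatelessOp implements Op {}

    static final class CountingOp implements StatefulOp {
        final Map<String, Long> states = new HashMap<>();

        public Map<String, Long> snapshot() {
            return new HashMap<>(states); // copy, so later updates don't leak into the snapshot
        }

        public void restore(Map<String, Long> state) {
            states.clear();
            states.putAll(state);
        }
    }

    /** One list entry per operator, null for operators without state. */
    static List<Map<String, Long>> snapshotAll(List<? extends Op> ops) {
        List<Map<String, Long>> chainedStates = new ArrayList<>();
        for (Op op : ops) {
            if (op instanceof StatefulOp) {
                chainedStates.add(((StatefulOp) op).snapshot());
            } else {
                chainedStates.add(null); // keeps indexes aligned with the operator list
            }
        }
        return chainedStates;
    }

    /** Index i of the snapshot list is restored into operator i; nulls are skipped. */
    static void restoreAll(List<? extends Op> ops, List<Map<String, Long>> chainedStates) {
        for (int i = 0; i < chainedStates.size(); i++) {
            Map<String, Long> state = chainedStates.get(i);
            if (state != null) {
                ((StatefulOp) ops.get(i)).restore(state);
            }
        }
    }
}
```

In the diff, a stateful operator with no registered states also yields `null` (its `getStateSnapshotFromFunction` returns `null` when there is nothing to checkpoint), and the restore loop skips that entry the same way.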
http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamingRuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamingRuntimeContext.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamingRuntimeContext.java
index 14ea5ea..de5ef06 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamingRuntimeContext.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamingRuntimeContext.java
@@ -20,6 +20,8 @@ package org.apache.flink.streaming.runtime.tasks;
 
 import java.io.Serializable;
 import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
 import java.util.Map;
 
 import org.apache.flink.api.common.ExecutionConfig;
@@ -45,6 +47,7 @@ public class StreamingRuntimeContext extends RuntimeUDFContext {
 
 	private final Environment env;
 	private final Map<String, StreamOperatorState> states;
+	private final List<PartitionedStreamOperatorState> partitionedStates;
 	private final KeySelector<?, ?> statePartitioner;
 	private final StateHandleProvider<?> provider;
 
@@ -56,6 +59,7 @@ public class StreamingRuntimeContext extends RuntimeUDFContext {
 		this.env = env;
 		this.statePartitioner = statePartitioner;
 		this.states = new HashMap<String, StreamOperatorState>();
+		this.partitionedStates = new LinkedList<PartitionedStreamOperatorState>();
 		this.provider = provider;
 	}
 
@@ -81,11 +85,11 @@ public class StreamingRuntimeContext extends RuntimeUDFContext {
 	@SuppressWarnings("unchecked")
 	@Override
 	public <S, C extends Serializable> OperatorState<S> getOperatorState(String name,
-			S defaultState, StateCheckpointer<S, C> checkpointer) {
+			S defaultState, boolean partitioned, StateCheckpointer<S, C> checkpointer) {
 		if (defaultState == null) {
 			throw new RuntimeException("Cannot set default state to null.");
 		}
-		StreamOperatorState<S, C> state = (StreamOperatorState<S, C>) getState(name);
+		StreamOperatorState<S, C> state = (StreamOperatorState<S, C>) getState(name, partitioned);
 		state.setDefaultState(defaultState);
 		state.setCheckpointer(checkpointer);
 
@@ -94,21 +98,28 @@ public class StreamingRuntimeContext extends RuntimeUDFContext {
 
 	@SuppressWarnings("unchecked")
 	@Override
-	public <S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState) {
+	public <S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState,
+			boolean partitioned) {
 		if (defaultState == null) {
 			throw new RuntimeException("Cannot set default state to null.");
 		}
-		StreamOperatorState<S, S> state = (StreamOperatorState<S, S>) getState(name);
+		StreamOperatorState<S, S> state = (StreamOperatorState<S, S>) getState(name, partitioned);
 		state.setDefaultState(defaultState);
 
 		return (OperatorState<S>) state;
 	}
 
-	private StreamOperatorState<?, ?> getState(String name) {
+	public StreamOperatorState<?, ?> getState(String name, boolean partitioned) {
+		// Try fetching state from the map
 		StreamOperatorState state = states.get(name);
 		if (state == null) {
-			state = createRawState();
+			// If not found, create empty state and add to the map
+			state = createRawState(partitioned);
 			states.put(name, state);
+			// We keep a reference to all partitioned states for registering input
+			if (state instanceof PartitionedStreamOperatorState) {
+				partitionedStates.add((PartitionedStreamOperatorState) state);
+			}
 		}
 		return state;
 	}
@@ -119,14 +130,19 @@ public class StreamingRuntimeContext extends RuntimeUDFContext {
 	 * @return An empty operator state.
 	 */
 	@SuppressWarnings("unchecked")
-	public StreamOperatorState createRawState() {
-		if (statePartitioner == null) {
-			return new StreamOperatorState(provider);
+	public StreamOperatorState createRawState(boolean partitioned) {
+		if (partitioned) {
+			if (statePartitioner != null) {
+				return new PartitionedStreamOperatorState(provider, statePartitioner);
+			} else {
+				throw new RuntimeException(
+						"A partitioning key must be provided for partitioned state.");
+			}
 		} else {
-			return new PartitionedStreamOperatorState(provider, statePartitioner);
+			return new StreamOperatorState(provider);
 		}
 	}
-	
+
 	/**
 	 * Provides access to the all the states contained in the context
 	 * 
@@ -137,14 +153,17 @@ public class StreamingRuntimeContext extends RuntimeUDFContext {
 	}
 
 	/**
-	 * Sets the next input of the underlying operators, used to access partitioned states.
-	 * @param nextRecord Next input of the operator.
+	 * Sets the next input of the underlying operators, used to access
+	 * partitioned states.
+	 * 
+	 * @param nextRecord
+	 *            Next input of the operator.
 	 */
 	@SuppressWarnings("unchecked")
 	public void setNextInput(Object nextRecord) {
 		if (statePartitioner != null) {
-			for (StreamOperatorState state : states.values()) {
-				((PartitionedStreamOperatorState) state).setCurrentInput(nextRecord);
+			for (PartitionedStreamOperatorState state : partitionedStates) {
+				state.setCurrentInput(nextRecord);
 			}
 		}
 	}
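The core of this change is that `getState` now records partitioned states in a separate `partitionedStates` list. As a result, `setNextInput` forwards each record only to those states. The old loop cast every entry of `states` to `PartitionedStreamOperatorState`, and that cast would throw a `ClassCastException` now that non-partitioned states can sit in the same map.

The sketch below shows that bookkeeping in a small, self-contained form. It is built on simplifying assumptions:

- `SimpleState`, `PartitionedSimpleState`, `StateRegistry` and `StateRegistryDemo` are hypothetical stand-ins, not Flink classes.
- The key extractor is a plain `java.util.function.Function` in place of Flink's state partitioner.
- Checkpointing is omitted.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical stand-in for a non-partitioned operator state: one value per operator.
class SimpleState<S> {
    protected S defaultState;
    private S value;

    void setDefaultState(S defaultState) { this.defaultState = defaultState; }

    S getState() { return value != null ? value : defaultState; }

    void updateState(S newValue) { value = newValue; }
}

// Hypothetical stand-in for a partitioned state: one value per key of the current input.
class PartitionedSimpleState<S> extends SimpleState<S> {
    private final Function<Object, Object> partitioner;
    private final Map<Object, S> partitioned = new HashMap<>();
    private Object currentKey;

    PartitionedSimpleState(Function<Object, Object> partitioner) { this.partitioner = partitioner; }

    void setCurrentInput(Object record) { currentKey = partitioner.apply(record); }

    @Override
    S getState() { return partitioned.getOrDefault(currentKey, defaultState); }

    @Override
    void updateState(S newValue) { partitioned.put(currentKey, newValue); }

    Map<Object, S> getPartitionedState() { return partitioned; }
}

// Mirrors the registry logic of StreamingRuntimeContext after this commit (simplified).
class StateRegistry {
    private final Function<Object, Object> partitioner; // null when no key was given
    private final Map<String, SimpleState<?>> states = new HashMap<>();
    private final List<PartitionedSimpleState<?>> partitionedStates = new ArrayList<>();

    StateRegistry(Function<Object, Object> partitioner) { this.partitioner = partitioner; }

    @SuppressWarnings("unchecked")
    <S> SimpleState<S> getOperatorState(String name, S defaultState, boolean partitioned) {
        if (defaultState == null) {
            throw new RuntimeException("Cannot set default state to null.");
        }
        SimpleState<S> state = (SimpleState<S>) states.get(name);
        if (state == null) {
            // Create the state on first access and register it by name.
            state = createRawState(partitioned);
            states.put(name, state);
            // Partitioned states are also tracked separately so they can receive inputs.
            if (state instanceof PartitionedSimpleState) {
                partitionedStates.add((PartitionedSimpleState<?>) state);
            }
        }
        state.setDefaultState(defaultState);
        return state;
    }

    private <S> SimpleState<S> createRawState(boolean partitioned) {
        if (!partitioned) {
            return new SimpleState<>();
        }
        if (partitioner == null) {
            throw new RuntimeException("A partitioning key must be provided for partitioned state.");
        }
        return new PartitionedSimpleState<>(partitioner);
    }

    // Only partitioned states see the record; non-partitioned states are never cast.
    void setNextInput(Object record) {
        for (PartitionedSimpleState<?> state : partitionedStates) {
            state.setCurrentInput(record);
        }
    }
}

class StateRegistryDemo {
    public static void main(String[] args) {
        StateRegistry ctx = new StateRegistry(r -> ((Integer) r) % 2);
        SimpleState<Integer> counter = ctx.getOperatorState("counter", 0, false);
        SimpleState<Integer> groupCounter = ctx.getOperatorState("groupCounter", 0, true);
        for (int value : new int[] {1, 2, 3, 4, 5}) {
            ctx.setNextInput(value);
            counter.updateState(counter.getState() + 1);
            groupCounter.updateState(groupCounter.getState() + 1);
        }
        System.out.println(counter.getState()); // 5
        System.out.println(((PartitionedSimpleState<Integer>) groupCounter).getPartitionedState()); // {0=2, 1=3}
    }
}
```

In this sketch, requesting a partitioned state from a registry built without a key extractor throws an exception. This mirrors the new `partitioned` branch in `createRawState`. Non-partitioned states still work without a key.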

http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
index 3b4e0e2..32a638a 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
@@ -56,65 +56,33 @@ public class StatefulOperatorTest {
 
 		List<String> out = new ArrayList<String>();
 
-		StreamMap<Integer, String> map = createOperatorWithContext(out, null, null);
-		StreamingRuntimeContext context = map.getRuntimeContext();
-
-		processInputs(map, Arrays.asList(1, 2, 3, 4, 5));
-
-		assertEquals(Arrays.asList("1", "2", "3", "4", "5"), out);
-		assertEquals((Integer) 5, context.getOperatorState("counter", 0).getState());
-		assertEquals("12345", context.getOperatorState("concat", "").getState());
-
-		byte[] serializedState = InstantiationUtil.serializeObject(map.getStateSnapshotFromFunction(1, 1));
-
-		StreamMap<Integer, String> restoredMap = createOperatorWithContext(out, null, serializedState);
-		StreamingRuntimeContext restoredContext = restoredMap.getRuntimeContext();
-
-		assertEquals((Integer) 5, restoredContext.getOperatorState("counter", 0).getState());
-		assertEquals("12345", restoredContext.getOperatorState("concat", "").getState());
-		out.clear();
-
-		processInputs(restoredMap, Arrays.asList(7, 8));
-
-		assertEquals(Arrays.asList("7", "8"), out);
-		assertEquals((Integer) 7, restoredContext.getOperatorState("counter", 0).getState());
-		assertEquals("1234578", restoredContext.getOperatorState("concat", "").getState());
-
-	}
-
-	@Test
-	public void partitionedStateTest() throws Exception {
-		List<String> out = new ArrayList<String>();
-
 		StreamMap<Integer, String> map = createOperatorWithContext(out, new ModKey(2), null);
 		StreamingRuntimeContext context = map.getRuntimeContext();
 
 		processInputs(map, Arrays.asList(1, 2, 3, 4, 5));
 
 		assertEquals(Arrays.asList("1", "2", "3", "4", "5"), out);
-		assertEquals(ImmutableMap.of(0, 2, 1, 3), context.getOperatorStates().get("counter").getPartitionedState());
-		assertEquals(ImmutableMap.of(0, "24", 1, "135"), context.getOperatorStates().get("concat")
-				.getPartitionedState());
+		assertEquals((Integer) 5, context.getOperatorState("counter", 0, false).getState());
+		assertEquals(ImmutableMap.of(0, 2, 1, 3), context.getOperatorStates().get("groupCounter").getPartitionedState());
+		assertEquals("12345", context.getOperatorState("concat", "", false).getState());
 
 		byte[] serializedState = InstantiationUtil.serializeObject(map.getStateSnapshotFromFunction(1, 1));
 
 		StreamMap<Integer, String> restoredMap = createOperatorWithContext(out, new ModKey(2), serializedState);
 		StreamingRuntimeContext restoredContext = restoredMap.getRuntimeContext();
 
-		assertEquals(ImmutableMap.of(0, 2, 1, 3), restoredContext.getOperatorStates().get("counter")
-				.getPartitionedState());
-		assertEquals(ImmutableMap.of(0, "24", 1, "135"), restoredContext.getOperatorStates().get("concat")
-				.getPartitionedState());
+		assertEquals((Integer) 5, restoredContext.getOperatorState("counter", 0, false).getState());
+		assertEquals(ImmutableMap.of(0, 2, 1, 3), restoredContext.getOperatorStates().get("groupCounter").getPartitionedState());
+		assertEquals("12345", restoredContext.getOperatorState("concat", "", false).getState());
 		out.clear();
 
 		processInputs(restoredMap, Arrays.asList(7, 8));
 
 		assertEquals(Arrays.asList("7", "8"), out);
-		assertEquals(ImmutableMap.of(0, 3, 1, 4), restoredContext.getOperatorStates().get("counter")
-				.getPartitionedState());
-		assertEquals(ImmutableMap.of(0, "248", 1, "1357"), restoredContext.getOperatorStates().get("concat")
+		assertEquals((Integer) 7, restoredContext.getOperatorState("counter", 0, false).getState());
+		assertEquals(ImmutableMap.of(0, 3, 1, 4), restoredContext.getOperatorStates().get("groupCounter")
 				.getPartitionedState());
-
+		assertEquals("1234578", restoredContext.getOperatorState("concat", "", false).getState());
 	}
 
 	private void processInputs(StreamMap<Integer, ?> map, List<Integer> input) throws Exception {
@@ -161,11 +129,13 @@ public class StatefulOperatorTest {
 
 		private static final long serialVersionUID = -9007873655253339356L;
 		OperatorState<Integer> counter;
+		OperatorState<Integer> groupCounter;
 		OperatorState<String> concat;
 
 		@Override
 		public String map(Integer value) throws Exception {
 			counter.updateState(counter.getState() + 1);
+			groupCounter.updateState(groupCounter.getState() + 1);
 			concat.updateState(concat.getState() + value.toString());
 			try {
 				counter.updateState(null);
@@ -177,15 +147,16 @@ public class StatefulOperatorTest {
 
 		@Override
 		public void open(Configuration conf) {
-			counter = getRuntimeContext().getOperatorState("counter", 0);
-			concat = getRuntimeContext().getOperatorState("concat", "");
+			counter = getRuntimeContext().getOperatorState("counter", 0, false);
+			groupCounter = getRuntimeContext().getOperatorState("groupCounter", 0, true);
+			concat = getRuntimeContext().getOperatorState("concat", "", false);
 			try {
-				getRuntimeContext().getOperatorState("test", null);
+				getRuntimeContext().getOperatorState("test", null, true);
 				fail();
 			} catch (RuntimeException e){
 			}
 			try {
-				getRuntimeContext().getOperatorState("test", null, null);
+				getRuntimeContext().getOperatorState("test", null, true, null);
 				fail();
 			} catch (RuntimeException e){
 			}
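The updated `StatefulOperatorTest` expects one operator to hold both kinds of state, and the asserted values can be checked by hand. With two keys, odd inputs fall into partition 1 and even inputs into partition 0.

- After `1..5`, the non-partitioned `counter` is 5 and `concat` is `"12345"`, while `groupCounter` is `{0=2, 1=3}`.
- After restoring and processing `7, 8`, they become 7, `"1234578"` and `{0=3, 1=4}`.

The sketch below replays those updates with plain Java collections; `ExpectedStateReplay` is a hypothetical helper, not part of the test. It relies on two assumptions:

- `ModKey(2)` keys each integer by `value % 2`. Its definition is not shown in this diff.
- Restoring from the snapshot yields exactly the snapshotted values, so the sketch keeps processing on the same object instead of serializing it.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical replay of the state updates made by the test's map function; not Flink code.
class ExpectedStateReplay {
    int counter = 0;                                       // non-partitioned "counter"
    StringBuilder concat = new StringBuilder();            // non-partitioned "concat"
    Map<Integer, Integer> groupCounter = new TreeMap<>();  // partitioned "groupCounter", one count per key

    void process(int value) {
        int key = value % 2; // assumed ModKey(2) behaviour
        counter++;
        groupCounter.merge(key, 1, Integer::sum);
        concat.append(value);
    }

    public static void main(String[] args) {
        ExpectedStateReplay s = new ExpectedStateReplay();
        for (int v : new int[] {1, 2, 3, 4, 5}) s.process(v);
        System.out.println(s.counter + " " + s.concat + " " + s.groupCounter); // 5 12345 {0=2, 1=3}
        // Assumed restore: continue from the snapshotted values.
        for (int v : new int[] {7, 8}) s.process(v);
        System.out.println(s.counter + " " + s.concat + " " + s.groupCounter); // 7 1234578 {0=3, 1=4}
    }
}
```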

http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
index 4e0523e..5cd34cc 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
@@ -193,7 +193,7 @@ public class SourceStreamTaskTest extends StreamTaskTestBase {
 		
 		@Override
 		public void open(Configuration conf){
-			state = getRuntimeContext().getOperatorState("state", 1, this);
+			state = getRuntimeContext().getOperatorState("state", 1, false, this);
 		}
 
 

http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
index d58e332..7e046af 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
@@ -211,7 +211,7 @@ public class StreamCheckpointingITCase {
 			step = getRuntimeContext().getNumberOfParallelSubtasks();
 			
 			
-			index = getRuntimeContext().getOperatorState("index", getRuntimeContext().getIndexOfThisSubtask());
+			index = getRuntimeContext().getOperatorState("index", getRuntimeContext().getIndexOfThisSubtask(), false);
 			
 			isRunning = true;
 		}
@@ -265,7 +265,7 @@ public class StreamCheckpointingITCase {
 
 		@Override
 		public void open(Configuration conf) {
-			count = getRuntimeContext().getOperatorState("count", 0L);
+			count = getRuntimeContext().getOperatorState("count", 0L, false);
 		}
 
 		@Override
@@ -347,7 +347,7 @@ public class StreamCheckpointingITCase {
 		
 		@Override
 		public void open(Configuration conf) {
-			this.count = getRuntimeContext().getOperatorState("count", 0L);
+			this.count = getRuntimeContext().getOperatorState("count", 0L, false);
 		}
 
 		@Override
@@ -369,7 +369,7 @@ public class StreamCheckpointingITCase {
 		
 		@Override
 		public void open(Configuration conf) {
-			this.count = getRuntimeContext().getOperatorState("count", 0L);
+			this.count = getRuntimeContext().getOperatorState("count", 0L, false);
 		}
 
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/474ff4df/flink-tests/src/test/java/org/apache/flink/test/recovery/ProcessFailureStreamingRecoveryITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/recovery/ProcessFailureStreamingRecoveryITCase.java b/flink-tests/src/test/java/org/apache/flink/test/recovery/ProcessFailureStreamingRecoveryITCase.java
index 809db1b..688a212 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/recovery/ProcessFailureStreamingRecoveryITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/recovery/ProcessFailureStreamingRecoveryITCase.java
@@ -37,7 +37,6 @@ import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
-import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;
 import org.junit.Assert;
 
 /**
@@ -154,7 +153,7 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur
 		
 		@Override
 		public void open(Configuration conf) {
-			collected = getRuntimeContext().getOperatorState("count", 0L);
+			collected = getRuntimeContext().getOperatorState("count", 0L, false);
 		}
 
 		@Override
@@ -199,7 +198,7 @@ public class ProcessFailureStreamingRecoveryITCase extends AbstractProcessFailur
 			stepSize = getRuntimeContext().getNumberOfParallelSubtasks();
 			congruence = getRuntimeContext().getIndexOfThisSubtask();
 			toCollect = (end % stepSize > congruence) ? (end / stepSize + 1) : (end / stepSize);
-			collected = getRuntimeContext().getOperatorState("count", 0L);
+			collected = getRuntimeContext().getOperatorState("count", 0L, false);
 		}
 
 		@Override