Posted to commits@flink.apache.org by gy...@apache.org on 2015/06/25 19:21:43 UTC

[10/12] flink git commit: [streaming] Fix shallow discard + proper checkpoint commit

[streaming] Fix shallow discard + proper checkpoint commit


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/56ae08e2
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/56ae08e2
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/56ae08e2

Branch: refs/heads/master
Commit: 56ae08e2b3ef6b1183d8f88b36cb42749b5c221e
Parents: 5ddd232
Author: Gyula Fora <gy...@apache.org>
Authored: Fri Jun 19 08:13:41 2015 +0200
Committer: Gyula Fora <gy...@apache.org>
Committed: Thu Jun 25 16:38:07 2015 +0200

----------------------------------------------------------------------
 .../checkpoint/SuccessfulCheckpoint.java        |  9 +-
 .../runtime/state/PartitionedStateHandle.java   |  5 +
 .../apache/flink/runtime/state/StateUtils.java  |  7 --
 .../api/persistent/PersistentKafkaSource.java   |  4 +-
 .../streaming/connectors/kafka/KafkaITCase.java |  3 +-
 .../api/checkpoint/CheckpointCommitter.java     |  3 +-
 .../source/StatefulSequenceSource.java          | 34 +++----
 .../operators/AbstractUdfStreamOperator.java    | 21 ++---
 .../api/operators/StatefulStreamOperator.java   |  8 +-
 .../state/PartitionedStreamOperatorState.java   |  3 +
 .../api/state/StreamOperatorState.java          |  7 +-
 .../streaming/api/state/WrapperStateHandle.java | 57 ++++++++++++
 .../streaming/runtime/tasks/OutputHandler.java  |  5 +-
 .../streaming/runtime/tasks/StreamTask.java     | 97 +++++++++++---------
 .../runtime/tasks/StreamingRuntimeContext.java  | 27 +++---
 .../api/state/StatefulOperatorTest.java         | 24 ++++-
 .../streaming/util/SourceFunctionUtil.java      |  5 +-
 17 files changed, 202 insertions(+), 117 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SuccessfulCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SuccessfulCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SuccessfulCheckpoint.java
index c277ea3..5432d33 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SuccessfulCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SuccessfulCheckpoint.java
@@ -80,13 +80,8 @@ public class SuccessfulCheckpoint {
 	 * @param jobVertexID
 	 * @return
 	 */
-	public StateForTask getState(JobVertexID jobVertexID)
-	{
-		if(vertexToState.containsKey(jobVertexID)) {
-			return vertexToState.get(jobVertexID);
-		}
-		
-		return null;
+	public StateForTask getState(JobVertexID jobVertexID) {
+		return vertexToState.get(jobVertexID);
 	}
 
 	// --------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionedStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionedStateHandle.java
index 4119df1..b6981c3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionedStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionedStateHandle.java
@@ -21,6 +21,11 @@ package org.apache.flink.runtime.state;
 import java.io.Serializable;
 import java.util.Map;
 
+/**
+ * Wrapper for storing the handles for each state in a partitioned form. It can
+ * be used to repartition the state before re-injecting to the tasks.
+ * 
+ */
 public class PartitionedStateHandle implements
 		StateHandle<Map<Serializable, StateHandle<Serializable>>> {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java
index 7977e09..30daec1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateUtils.java
@@ -18,8 +18,6 @@
 
 package org.apache.flink.runtime.state;
 
-import java.util.List;
-
 import org.apache.flink.runtime.jobgraph.tasks.OperatorStateCarrier;
 
 /**
@@ -53,11 +51,6 @@ public class StateUtils {
 		typedOp.setInitialState(typedHandle);
 	}
 
-	public static List<PartitionedStateHandle> rePartitionHandles(
-			List<PartitionedStateHandle> handles, int numPartitions) {
-		return null;
-	}
-
 	// ------------------------------------------------------------------------
 
 	/** Do not instantiate */

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/api/persistent/PersistentKafkaSource.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/api/persistent/PersistentKafkaSource.java b/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/api/persistent/PersistentKafkaSource.java
index 90427ee..d3c651d 100644
--- a/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/api/persistent/PersistentKafkaSource.java
+++ b/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/main/java/org/apache/flink/streaming/connectors/kafka/api/persistent/PersistentKafkaSource.java
@@ -148,7 +148,7 @@ public class PersistentKafkaSource<OUT> extends RichParallelSourceFunction<OUT>
 		this.lastOffsets = getRuntimeContext().getOperatorState("offset", new long[numPartitions]);
 		this.commitedOffsets = new long[numPartitions];
 		// check if there are offsets to restore
-		if (Arrays.equals(lastOffsets.getState(), new long[numPartitions])) {
+		if (!Arrays.equals(lastOffsets.getState(), new long[numPartitions])) {
 			if (lastOffsets.getState().length != numPartitions) {
 				throw new IllegalStateException("There are "+lastOffsets.getState().length+" offsets to restore for topic "+topicName+" but " +
 						"there are only "+numPartitions+" in the topic");
@@ -217,7 +217,7 @@ public class PersistentKafkaSource<OUT> extends RichParallelSourceFunction<OUT>
 	 * @throws Exception 
 	 */
 	@Override
-	public void commitCheckpoint(long checkpointId, StateHandle<Serializable> state) throws Exception {
+	public void commitCheckpoint(long checkpointId, String stateName, StateHandle<Serializable> state) throws Exception {
 		LOG.info("Commit checkpoint {}", checkpointId);
 
 		long[] checkpointOffsets;

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaITCase.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaITCase.java b/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaITCase.java
index 8fd8973..562dc47 100644
--- a/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaITCase.java
+++ b/flink-staging/flink-streaming/flink-streaming-connectors/flink-connector-kafka/src/test/java/org/apache/flink/streaming/connectors/kafka/KafkaITCase.java
@@ -227,7 +227,6 @@ public class KafkaITCase {
 	 *
 	 */
 	@Test
-	@Ignore
 	public void testPersistentSourceWithOffsetUpdates() throws Exception {
 		LOG.info("Starting testPersistentSourceWithOffsetUpdates()");
 
@@ -314,7 +313,7 @@ public class KafkaITCase {
 
 				LOG.info("Reader " + getRuntimeContext().getIndexOfThisSubtask() + " got " + value + " count=" + count + "/" + finalCount);
 				// verify if we've seen everything
-
+				
 				if (count == finalCount) {
 					LOG.info("Received all values");
 					for (int i = 0; i < values.length; i++) {

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointCommitter.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointCommitter.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointCommitter.java
index 758c19f..306df71 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointCommitter.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/checkpoint/CheckpointCommitter.java
@@ -36,8 +36,9 @@ public interface CheckpointCommitter {
 	 * fail any more.
 	 * 
 	 * @param checkpointId The ID of the checkpoint that has been completed.
+	 * @param stateName The name of the committed state
 	 * @param checkPointedState Handle to the state that was checkpointed with this checkpoint id.
 	 * @throws Exception 
 	 */
-	void commitCheckpoint(long checkpointId, StateHandle<Serializable> checkPointedState) throws Exception;
+	void commitCheckpoint(long checkpointId, String stateName, StateHandle<Serializable> checkPointedState) throws Exception;
 }
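
As an illustration of the updated user-facing signature, a committer now also receives the
name of the state being committed, so a function with several registered states can tell
them apart. A minimal sketch (the class name and the body are illustrative only, not part
of this commit):

    import java.io.Serializable;

    import org.apache.flink.runtime.state.StateHandle;
    import org.apache.flink.streaming.api.checkpoint.CheckpointCommitter;

    public class LoggingCheckpointCommitter implements CheckpointCommitter {

        @Override
        public void commitCheckpoint(long checkpointId, String stateName,
                StateHandle<Serializable> checkPointedState) throws Exception {
            // stateName identifies which registered operator state this handle belongs to,
            // e.g. a state registered as getOperatorState("offset", ...).
            Serializable committedValue = checkPointedState.getState();
            System.out.println("Checkpoint " + checkpointId + " committed state '"
                    + stateName + "' with value " + committedValue);
        }
    }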

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
index e213363..13d0b6c 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
@@ -19,19 +19,20 @@ package org.apache.flink.streaming.api.functions.source;
 
 
 import org.apache.flink.api.common.functions.RuntimeContext;
-import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.api.common.state.OperatorState;
+import org.apache.flink.configuration.Configuration;
 
 /**
  * A stateful streaming source that emits each number from a given interval exactly once,
  * possibly in parallel.
  */
-public class StatefulSequenceSource extends RichParallelSourceFunction<Long> implements Checkpointed<Long> {
+public class StatefulSequenceSource extends RichParallelSourceFunction<Long> {
 	private static final long serialVersionUID = 1L;
 
 	private final long start;
 	private final long end;
 
-	private long collected;
+	private OperatorState<Long> collected;
 
 	private volatile boolean isRunning = true;
 
@@ -44,7 +45,6 @@ public class StatefulSequenceSource extends RichParallelSourceFunction<Long> imp
 	public StatefulSequenceSource(long start, long end) {
 		this.start = start;
 		this.end = end;
-		this.collected = 0;
 	}
 
 	@Override
@@ -60,29 +60,25 @@ public class StatefulSequenceSource extends RichParallelSourceFunction<Long> imp
 				((end - start + 1) % stepSize > (congruence - start)) ?
 					((end - start + 1) / stepSize + 1) :
 					((end - start + 1) / stepSize);
+					
+		Long currentCollected = collected.getState();
 
-		while (isRunning && collected < toCollect) {
-
+		while (isRunning && currentCollected < toCollect) {
 			synchronized (checkpointLock) {
-				ctx.collect(collected * stepSize + congruence);
-				collected++;
+				ctx.collect(currentCollected * stepSize + congruence);
+				collected.updateState(currentCollected + 1);
 			}
+			currentCollected = collected.getState();
 		}
 	}
-
-	@Override
-	public void cancel() {
-		isRunning = false;
-	}
-
+	
 	@Override
-	public Long snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
-		return collected;
+	public void open(Configuration conf){
+		collected = getRuntimeContext().getOperatorState("collected", 0L);
 	}
 
 	@Override
-	public void restoreState(Long state) {
-		collected = state;
+	public void cancel() {
+		isRunning = false;
 	}
-
 }
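
As the change above shows for the sequence source, partition-local state is now kept in a
named OperatorState registered in open() instead of implementing Checkpointed. The same
registration pattern applies to other rich functions as well; a rough sketch with a map
function (illustrative only, not part of this commit; it mirrors the StatefulOperatorTest
further below):

    import org.apache.flink.api.common.functions.RichMapFunction;
    import org.apache.flink.api.common.state.OperatorState;
    import org.apache.flink.configuration.Configuration;

    public class CountingMapper extends RichMapFunction<Integer, String> {

        private static final long serialVersionUID = 1L;

        private OperatorState<Integer> counter;

        @Override
        public void open(Configuration conf) {
            // Register the state under a name with a non-null default value;
            // it is restored automatically when the task recovers.
            counter = getRuntimeContext().getOperatorState("counter", 0);
        }

        @Override
        public String map(Integer value) throws Exception {
            counter.updateState(counter.getState() + 1);
            return value.toString();
        }
    }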

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
index 9d9896d..b37c426 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
@@ -26,6 +26,7 @@ import java.util.Map.Entry;
 import org.apache.flink.api.common.functions.Function;
 import org.apache.flink.api.common.functions.util.FunctionUtils;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.PartitionedStateHandle;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.streaming.api.checkpoint.CheckpointCommitter;
 import org.apache.flink.streaming.api.state.StreamOperatorState;
@@ -70,45 +71,43 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function & Serial
 	}
 
 	@SuppressWarnings({ "unchecked", "rawtypes" })
-	public void restoreInitialState(Serializable state) throws Exception {
-
-		Map<String, Map<Serializable, StateHandle<Serializable>>> snapshots = (Map<String, Map<Serializable, StateHandle<Serializable>>>) state;
+	public void restoreInitialState(Map<String, PartitionedStateHandle> snapshots) throws Exception {
 
 		Map<String, StreamOperatorState> operatorStates = runtimeContext.getOperatorStates();
 		
-		for (Entry<String, Map<Serializable, StateHandle<Serializable>>> snapshot : snapshots.entrySet()) {
+		for (Entry<String, PartitionedStateHandle> snapshot : snapshots.entrySet()) {
 			StreamOperatorState restoredState = runtimeContext.createRawState();
-			restoredState.restoreState(snapshot.getValue());
+			restoredState.restoreState(snapshot.getValue().getState());
 			operatorStates.put(snapshot.getKey(), restoredState);
 		}
 
 	}
 
 	@SuppressWarnings({ "rawtypes", "unchecked" })
-	public Serializable getStateSnapshotFromFunction(long checkpointId, long timestamp)
+	public Map<String, PartitionedStateHandle> getStateSnapshotFromFunction(long checkpointId, long timestamp)
 			throws Exception {
 
 		Map<String, StreamOperatorState> operatorStates = runtimeContext.getOperatorStates();
 		if (operatorStates.isEmpty()) {
 			return null;
 		} else {
-			Map<String, Map<Serializable, StateHandle<Serializable>>> snapshots = new HashMap<String, Map<Serializable, StateHandle<Serializable>>>();
+			Map<String, PartitionedStateHandle> snapshots = new HashMap<String, PartitionedStateHandle>();
 
 			for (Entry<String, StreamOperatorState> state : operatorStates.entrySet()) {
 				snapshots.put(state.getKey(),
-						state.getValue().snapshotState(checkpointId, timestamp));
+						new PartitionedStateHandle(state.getValue().snapshotState(checkpointId, timestamp)));
 			}
 
-			return (Serializable) snapshots;
+			return snapshots;
 		}
 
 	}
 
-	public void confirmCheckpointCompleted(long checkpointId,
+	public void confirmCheckpointCompleted(long checkpointId, String stateName,
 			StateHandle<Serializable> checkpointedState) throws Exception {
 		if (userFunction instanceof CheckpointCommitter) {
 			try {
-				((CheckpointCommitter) userFunction).commitCheckpoint(checkpointId, checkpointedState);
+				((CheckpointCommitter) userFunction).commitCheckpoint(checkpointId, stateName, checkpointedState);
 			} catch (Exception e) {
 				throw new Exception("Error while confirming checkpoint " + checkpointId + " to the stream function", e);
 			}

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/StatefulStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/StatefulStreamOperator.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/StatefulStreamOperator.java
index 13012d3..0fac2f8 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/StatefulStreamOperator.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/StatefulStreamOperator.java
@@ -18,7 +18,9 @@
 package org.apache.flink.streaming.api.operators;
 
 import java.io.Serializable;
+import java.util.Map;
 
+import org.apache.flink.runtime.state.PartitionedStateHandle;
 import org.apache.flink.runtime.state.StateHandle;
 
 /**
@@ -29,9 +31,9 @@ import org.apache.flink.runtime.state.StateHandle;
  */
 public interface StatefulStreamOperator<OUT> extends StreamOperator<OUT> {
 
-	void restoreInitialState(Serializable state) throws Exception;
+	void restoreInitialState(Map<String, PartitionedStateHandle> state) throws Exception;
 
-	Serializable getStateSnapshotFromFunction(long checkpointId, long timestamp) throws Exception;
+	Map<String, PartitionedStateHandle> getStateSnapshotFromFunction(long checkpointId, long timestamp) throws Exception;
 
-	void confirmCheckpointCompleted(long checkpointId, StateHandle<Serializable> checkpointedState) throws Exception;
+	void confirmCheckpointCompleted(long checkpointId, String stateName, StateHandle<Serializable> checkpointedState) throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
index 4f0200f..abe32da 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
@@ -87,6 +87,9 @@ public class PartitionedStreamOperatorState<IN, S, C extends Serializable> exten
 
 	@Override
 	public void updateState(S state) {
+		if (state == null) {
+			throw new RuntimeException("Cannot set state to null.");
+		}
 		if (currentInput == null) {
 			throw new RuntimeException("Need a valid input for updating a state.");
 		} else {

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
index bab3421..c19b485 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
@@ -25,7 +25,8 @@ import org.apache.flink.api.common.state.OperatorState;
 import org.apache.flink.api.common.state.StateCheckpointer;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.runtime.state.StateHandleProvider;
-import org.apache.flink.shaded.com.google.common.collect.ImmutableMap;
+
+import com.google.common.collect.ImmutableMap;
 
 /**
  * Implementation of the {@link OperatorState} interface for non-partitioned
@@ -63,11 +64,13 @@ public class StreamOperatorState<S, C extends Serializable> implements OperatorS
 
 	@Override
 	public void updateState(S state) {
+		if (state == null) {
+			throw new RuntimeException("Cannot set state to null.");
+		}
 		this.state = state;
 	}
 	
 	public void setDefaultState(S defaultState) {
-		// reconsider this as it might cause issues when setting the state to null
 		if (getState() == null) {
 			updateState(defaultState);
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/WrapperStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/WrapperStateHandle.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/WrapperStateHandle.java
new file mode 100644
index 0000000..276f9e9
--- /dev/null
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/WrapperStateHandle.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.state;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.flink.runtime.state.LocalStateHandle;
+import org.apache.flink.runtime.state.PartitionedStateHandle;
+import org.apache.flink.runtime.state.StateHandle;
+
+/**
+ * StateHandle that wraps the StateHandles for the operator states of chained
+ * tasks. This is needed so the wrapped handles are properly discarded.
+ * 
+ */
+public class WrapperStateHandle extends LocalStateHandle<Serializable> {
+
+	private static final long serialVersionUID = 1L;
+
+	public WrapperStateHandle(List<Map<String, PartitionedStateHandle>> state) {
+		super((Serializable) state);
+	}
+
+	@Override
+	public void discardState() throws Exception {
+		@SuppressWarnings("unchecked")
+		List<Map<String, PartitionedStateHandle>> chainedStates = (List<Map<String, PartitionedStateHandle>>) getState();
+		for (Map<String, PartitionedStateHandle> stateMap : chainedStates) {
+			if(stateMap != null) {
+				for (PartitionedStateHandle statePartitions : stateMap.values()) {
+					for (StateHandle<Serializable> handle : statePartitions.getState().values()) {
+						handle.discardState();
+					}
+				}
+			}
+		}
+	}
+
+}
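
This wrapper is the fix for the "shallow discard" part of the commit: discarding the
checkpoint's outer handle now also discards every handle nested inside the chained
operator snapshots, instead of dropping only the outer object. A small sketch of how such
a wrapper is built and discarded (the values are placeholders; StreamTask constructs the
real list during checkpointing):

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    import org.apache.flink.runtime.state.PartitionedStateHandle;
    import org.apache.flink.streaming.api.state.WrapperStateHandle;

    public class WrapperStateHandleSketch {

        public static void main(String[] args) throws Exception {
            // One map of named states per operator in the chain; an entry may be
            // null when the corresponding operator had nothing to snapshot.
            List<Map<String, PartitionedStateHandle>> chainedStates =
                    new ArrayList<Map<String, PartitionedStateHandle>>();
            chainedStates.add(new HashMap<String, PartitionedStateHandle>());
            chainedStates.add(null);

            WrapperStateHandle wrapper = new WrapperStateHandle(chainedStates);

            // Discarding the wrapper walks every contained PartitionedStateHandle
            // and discards each wrapped handle in turn.
            wrapper.discardState();
        }
    }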

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java
index a7755f6..6e88610 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/OutputHandler.java
@@ -234,7 +234,6 @@ public class OutputHandler<OUT> {
 		}
 
 		@Override
-		@SuppressWarnings("unchecked")
 		public void collect(T record) {
 			try {
 				operator.getRuntimeContext().setNextInput(record);
@@ -262,13 +261,13 @@ public class OutputHandler<OUT> {
 	private static class CopyingOperatorCollector<T> extends OperatorCollector<T> {
 		private final TypeSerializer<T> serializer;
 
-		public CopyingOperatorCollector(OneInputStreamOperator<?, T> operator, TypeSerializer<T> serializer) {
+		@SuppressWarnings({ "rawtypes", "unchecked" })
+		public CopyingOperatorCollector(OneInputStreamOperator operator, TypeSerializer<T> serializer) {
 			super(operator);
 			this.serializer = serializer;
 		}
 
 		@Override
-		@SuppressWarnings("unchecked")
 		public void collect(T record) {
 			try {
 				operator.processElement(serializer.copy(record));

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index d4ec51c..2b8964d 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -21,6 +21,8 @@ import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
 
 import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.collections.functors.NotNullPredicate;
@@ -35,6 +37,7 @@ import org.apache.flink.runtime.jobgraph.tasks.CheckpointedOperator;
 import org.apache.flink.runtime.jobgraph.tasks.OperatorStateCarrier;
 import org.apache.flink.runtime.state.FileStateHandle;
 import org.apache.flink.runtime.state.LocalStateHandle;
+import org.apache.flink.runtime.state.PartitionedStateHandle;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.runtime.state.StateHandleProvider;
 import org.apache.flink.runtime.util.SerializedValue;
@@ -43,6 +46,7 @@ import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.StatefulStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.state.WrapperStateHandle;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -199,40 +203,35 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 	//  Checkpoint and Restore
 	// ------------------------------------------------------------------------
 
+	@SuppressWarnings("unchecked")
 	@Override
 	public void setInitialState(StateHandle<Serializable> stateHandle) throws Exception {
 		// here, we later resolve the state handle into the actual state by
 		// loading the state described by the handle from the backup store
 		Serializable state = stateHandle.getState();
 
-		if (hasChainedOperators) {
-			@SuppressWarnings("unchecked")
-			List<Serializable> chainedStates = (List<Serializable>) state;
+		List<Serializable> chainedStates = (List<Serializable>) state;
 
-			Serializable headState = chainedStates.get(0);
-			if (headState != null) {
-				if (streamOperator instanceof StatefulStreamOperator) {
-					((StatefulStreamOperator<?>) streamOperator).restoreInitialState(headState);
-				}
+		Serializable headState = chainedStates.get(0);
+		if (headState != null) {
+			if (streamOperator instanceof StatefulStreamOperator) {
+				((StatefulStreamOperator<?>) streamOperator)
+						.restoreInitialState((Map<String, PartitionedStateHandle>) headState);
 			}
+		}
 
-			for (int i = 1; i < chainedStates.size(); i++) {
-				Serializable chainedState = chainedStates.get(i);
-				if (chainedState != null) {
-					StreamOperator<?> chainedOperator = outputHandler.getChainedOperators().get(i - 1);
-					if (chainedOperator instanceof StatefulStreamOperator) {
-						((StatefulStreamOperator<?>) chainedOperator).restoreInitialState(chainedState);
-					}
-
+		for (int i = 1; i < chainedStates.size(); i++) {
+			Serializable chainedState = chainedStates.get(i);
+			if (chainedState != null) {
+				StreamOperator<?> chainedOperator = outputHandler.getChainedOperators().get(i - 1);
+				if (chainedOperator instanceof StatefulStreamOperator) {
+					((StatefulStreamOperator<?>) chainedOperator)
+							.restoreInitialState((Map<String, PartitionedStateHandle>) chainedState);
 				}
-			}
 
-		} else {
-			if (streamOperator instanceof StatefulStreamOperator) {
-				((StatefulStreamOperator<?>) streamOperator).restoreInitialState(state);
 			}
-
 		}
+
 	}
 
 	@Override
@@ -244,34 +243,26 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 					LOG.debug("Starting checkpoint {} on task {}", checkpointId, getName());
 					
 					// first draw the state that should go into checkpoint
-					StateHandle<Serializable> state;
+					List<Map<String, PartitionedStateHandle>> chainedStates = new ArrayList<Map<String, PartitionedStateHandle>>();
+					StateHandle<Serializable> stateHandle;
 					try {
 
-						Serializable userState = null;
-
 						if (streamOperator instanceof StatefulStreamOperator) {
-							userState = ((StatefulStreamOperator<?>) streamOperator).getStateSnapshotFromFunction(checkpointId, timestamp);
+							chainedStates.add(((StatefulStreamOperator<?>) streamOperator).getStateSnapshotFromFunction(checkpointId, timestamp));
 						}
 
 
 						if (hasChainedOperators) {
 							// We construct a list of states for chained tasks
-							List<Serializable> chainedStates = new ArrayList<Serializable>();
-
-							chainedStates.add(userState);
-
 							for (OneInputStreamOperator<?, ?> chainedOperator : outputHandler.getChainedOperators()) {
 								if (chainedOperator instanceof StatefulStreamOperator) {
 									chainedStates.add(((StatefulStreamOperator<?>) chainedOperator).getStateSnapshotFromFunction(checkpointId, timestamp));
 								}
 							}
-
-							userState = CollectionUtils.exists(chainedStates,
-									NotNullPredicate.INSTANCE) ? (Serializable) chainedStates
-									: null;
 						}
 						
-						state = userState == null ? null : stateHandleProvider.createStateHandle(userState);
+						stateHandle = CollectionUtils.exists(chainedStates,
+								NotNullPredicate.INSTANCE) ? new WrapperStateHandle(chainedStates) : null;
 					}
 					catch (Exception e) {
 						throw new Exception("Error while drawing snapshot of the user state.", e);
@@ -281,10 +272,10 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 					outputHandler.broadcastBarrier(checkpointId, timestamp);
 					
 					// now confirm the checkpoint
-					if (state == null) {
+					if (stateHandle == null) {
 						getEnvironment().acknowledgeCheckpoint(checkpointId);
 					} else {
-						getEnvironment().acknowledgeCheckpoint(checkpointId, state);
+						getEnvironment().acknowledgeCheckpoint(checkpointId, stateHandle);
 					}
 				}
 				catch (Exception e) {
@@ -299,19 +290,37 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 
 	@SuppressWarnings({ "unchecked", "rawtypes" })
 	@Override
-	public void confirmCheckpoint(long checkpointId, SerializedValue<StateHandle<?>> state) throws Exception {
+	public void confirmCheckpoint(long checkpointId, SerializedValue<StateHandle<?>> stateHandle) throws Exception {
 		// we do nothing here so far. this should call commit on the source function, for example
 		synchronized (checkpointLock) {
-			if (streamOperator instanceof StatefulStreamOperator) {
-				((StatefulStreamOperator) streamOperator).confirmCheckpointCompleted(checkpointId,
-						state.deserializeValue(getUserCodeClassLoader()));
+			
+			List<Map<String, PartitionedStateHandle>> chainedStates = (List<Map<String, PartitionedStateHandle>>) stateHandle.deserializeValue(getUserCodeClassLoader()).getState();
+
+			Map<String, PartitionedStateHandle> headState = chainedStates.get(0);
+			if (headState != null) {
+				if (streamOperator instanceof StatefulStreamOperator) {
+					for (Entry<String, PartitionedStateHandle> stateEntry : headState
+							.entrySet()) {
+						for (StateHandle<Serializable> handle : stateEntry.getValue().getState().values()) {
+							((StatefulStreamOperator) streamOperator).confirmCheckpointCompleted(
+									checkpointId, stateEntry.getKey(), handle);
+						}
+					}
+				}
 			}
 
-			if (hasChainedOperators) {
-				for (OneInputStreamOperator<?, ?> chainedOperator : outputHandler.getChainedOperators()) {
+			for (int i = 1; i < chainedStates.size(); i++) {
+				Map<String, PartitionedStateHandle> chainedState = chainedStates.get(i);
+				if (chainedState != null) {
+					StreamOperator<?> chainedOperator = outputHandler.getChainedOperators().get(i - 1);
 					if (chainedOperator instanceof StatefulStreamOperator) {
-						((StatefulStreamOperator) chainedOperator).confirmCheckpointCompleted(checkpointId,
-								state.deserializeValue(getUserCodeClassLoader()));
+						for (Entry<String, PartitionedStateHandle> stateEntry : chainedState
+								.entrySet()) {
+							for (StateHandle<Serializable> handle : stateEntry.getValue().getState().values()) {
+								((StatefulStreamOperator) chainedOperator).confirmCheckpointCompleted(
+										checkpointId, stateEntry.getKey(), handle);
+							}
+						}
 					}
 				}
 			}

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamingRuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamingRuntimeContext.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamingRuntimeContext.java
index 77639a0..15e5f1a 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamingRuntimeContext.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamingRuntimeContext.java
@@ -82,13 +82,10 @@ public class StreamingRuntimeContext extends RuntimeUDFContext {
 	@Override
 	public <S, C extends Serializable> OperatorState<S> getOperatorState(String name,
 			S defaultState, StateCheckpointer<S, C> checkpointer) {
-		StreamOperatorState state;
-		if (states.containsKey(name)) {
-			state = states.get(name);
-		} else {
-			state = createRawState();
-			states.put(name, state);
+		if (defaultState == null) {
+			throw new RuntimeException("Cannot set default state to null.");
 		}
+		StreamOperatorState<S, C> state = (StreamOperatorState<S, C>) getState(name);
 		state.setDefaultState(defaultState);
 		state.setCheckpointer(checkpointer);
 
@@ -98,18 +95,24 @@ public class StreamingRuntimeContext extends RuntimeUDFContext {
 	@SuppressWarnings("unchecked")
 	@Override
 	public <S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState) {
-		StreamOperatorState state;
-		if (states.containsKey(name)) {
-			state = states.get(name);
-		} else {
-			state = createRawState();
-			states.put(name, state);
+		if (defaultState == null) {
+			throw new RuntimeException("Cannot set default state to null.");
 		}
+		StreamOperatorState<S, S> state = (StreamOperatorState<S, S>) getState(name);
 		state.setDefaultState(defaultState);
 
 		return (OperatorState<S>) state;
 	}
 
+	private StreamOperatorState<?, ?> getState(String name) {
+		StreamOperatorState state = states.get(name);
+		if (state == null) {
+			state = createRawState();
+			states.put(name, state);
+		}
+		return state;
+	}
+
 	/**
 	 * Creates an empty state depending on the partitioning state.
 	 * 
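
With the new getState(name) helper, repeated lookups under the same name resolve to the
same registered state, and a null default is rejected up front. A test-style sketch of
that behaviour, reusing the mock setup from SourceFunctionUtil below (illustrative only,
not part of this commit):

    import java.io.Serializable;

    import org.apache.flink.api.common.ExecutionConfig;
    import org.apache.flink.api.common.state.OperatorState;
    import org.apache.flink.runtime.operators.testutils.MockEnvironment;
    import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
    import org.apache.flink.runtime.state.LocalStateHandle.LocalStateHandleProvider;
    import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;

    public class OperatorStateRegistrationSketch {

        public static void main(String[] args) throws Exception {
            StreamingRuntimeContext context = new StreamingRuntimeContext("MockTask",
                    new MockEnvironment(3 * 1024 * 1024, new MockInputSplitProvider(), 1024),
                    null, new ExecutionConfig(), null,
                    new LocalStateHandleProvider<Serializable>());

            // Both lookups under the same name resolve to the same registered state.
            OperatorState<Long> first = context.getOperatorState("counter", 0L);
            first.updateState(5L);
            OperatorState<Long> again = context.getOperatorState("counter", 0L);
            System.out.println(again.getState()); // prints 5, not the default 0

            // A null default is rejected before any state is registered.
            try {
                context.getOperatorState("other", null);
            } catch (RuntimeException e) {
                System.out.println("null default rejected: " + e.getMessage());
            }
        }
    }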

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
index 99015c9..3b4e0e2 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
@@ -19,11 +19,13 @@
 package org.apache.flink.streaming.api.state;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
 
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.RichMapFunction;
@@ -33,13 +35,15 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.LocalStateHandle.LocalStateHandleProvider;
-import org.apache.flink.shaded.com.google.common.collect.ImmutableMap;
+import org.apache.flink.runtime.state.PartitionedStateHandle;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamMap;
 import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;
 import org.apache.flink.util.InstantiationUtil;
 import org.junit.Test;
 
+import com.google.common.collect.ImmutableMap;
+
 /**
  * Test the functionality supported by stateful user functions for both
  * partitioned and non-partitioned user states. This test mimics the runtime
@@ -120,6 +124,7 @@ public class StatefulOperatorTest {
 		}
 	}
 
+	@SuppressWarnings("unchecked")
 	private StreamMap<Integer, String> createOperatorWithContext(List<String> output,
 			KeySelector<Integer, Serializable> partitioner, byte[] serializedState) throws Exception {
 		final List<String> outputList = output;
@@ -143,7 +148,7 @@ public class StatefulOperatorTest {
 		}, context);
 
 		if (serializedState != null) {
-			op.restoreInitialState((Serializable) InstantiationUtil.deserializeObject(serializedState, Thread
+			op.restoreInitialState((Map<String, PartitionedStateHandle>) InstantiationUtil.deserializeObject(serializedState, Thread
 					.currentThread().getContextClassLoader()));
 		}
 
@@ -162,6 +167,11 @@ public class StatefulOperatorTest {
 		public String map(Integer value) throws Exception {
 			counter.updateState(counter.getState() + 1);
 			concat.updateState(concat.getState() + value.toString());
+			try {
+				counter.updateState(null);
+				fail();
+			} catch (RuntimeException e){
+			}
 			return value.toString();
 		}
 
@@ -169,6 +179,16 @@ public class StatefulOperatorTest {
 		public void open(Configuration conf) {
 			counter = getRuntimeContext().getOperatorState("counter", 0);
 			concat = getRuntimeContext().getOperatorState("concat", "");
+			try {
+				getRuntimeContext().getOperatorState("test", null);
+				fail();
+			} catch (RuntimeException e){
+			}
+			try {
+				getRuntimeContext().getOperatorState("test", null, null);
+				fail();
+			} catch (RuntimeException e){
+			}
 		}
 	}
 	

http://git-wip-us.apache.org/repos/asf/flink/blob/56ae08e2/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java
index 5cc5346..dafd9a3 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java
@@ -17,6 +17,7 @@
 
 package org.apache.flink.streaming.util;
 
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.List;
 
@@ -26,7 +27,7 @@ import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.runtime.state.LocalStateHandle.LocalStateHandleProvider;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;
 import org.apache.flink.util.Collector;
@@ -37,7 +38,7 @@ public class SourceFunctionUtil<T> {
 		List<T> outputs = new ArrayList<T>();
 		if (sourceFunction instanceof RichFunction) {
 			RuntimeContext runtimeContext =  new StreamingRuntimeContext("MockTask", new MockEnvironment(3 * 1024 * 1024, new MockInputSplitProvider(), 1024), null,
-					new ExecutionConfig());
+					new ExecutionConfig(), null, new LocalStateHandleProvider<Serializable>());
 			((RichFunction) sourceFunction).setRuntimeContext(runtimeContext);
 
 			((RichFunction) sourceFunction).open(new Configuration());