You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by gy...@apache.org on 2015/07/30 16:45:38 UTC

[1/3] flink git commit: [FLINK-2324] [streaming] Added test for different StateHandle wrappers

Repository: flink
Updated Branches:
  refs/heads/master 1b3bdce5c -> 83e14cb15


[FLINK-2324] [streaming] Added test for different StateHandle wrappers

Closes #937


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/83e14cb1
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/83e14cb1
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/83e14cb1

Branch: refs/heads/master
Commit: 83e14cb15a41c417fd7d024c383a4f4d50ec5a19
Parents: 58cd4ea
Author: Gyula Fora <gy...@apache.org>
Authored: Thu Jul 30 07:26:40 2015 +0200
Committer: Gyula Fora <gy...@apache.org>
Committed: Thu Jul 30 16:44:53 2015 +0200

----------------------------------------------------------------------
 .../operators/AbstractUdfStreamOperator.java    |   8 +-
 .../api/state/OperatorStateHandle.java          |   4 +
 .../streaming/api/state/StateHandleTest.java    | 134 +++++++++++++++++++
 .../StreamCheckpointingITCase.java              |  14 +-
 4 files changed, 156 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/83e14cb1/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
index f21aacc..585b4ce 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
@@ -79,7 +79,7 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function & Serial
 	@SuppressWarnings({ "unchecked", "rawtypes" })
 	public void restoreInitialState(Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> snapshots) throws Exception {
 		// Restore state using the Checkpointed interface
-		if (userFunction instanceof Checkpointed) {
+		if (userFunction instanceof Checkpointed && snapshots.f0 != null) {
 			((Checkpointed) userFunction).restoreState(snapshots.f0.getState());
 		}
 		
@@ -122,8 +122,10 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function & Serial
 		// if the UDF implements the Checkpointed interface we draw a snapshot
 		if (userFunction instanceof Checkpointed) {
 			StateHandleProvider<Serializable> provider = runtimeContext.getStateHandleProvider();
-			checkpointedSnapshot = provider.createStateHandle(((Checkpointed) userFunction)
-					.snapshotState(checkpointId, timestamp));
+			Serializable state = ((Checkpointed) userFunction).snapshotState(checkpointId, timestamp);
+			if (state != null) {
+				checkpointedSnapshot = provider.createStateHandle(state);
+			}
 		}
 		
 		// if we have either operator or checkpointed state we store it in a

http://git-wip-us.apache.org/repos/asf/flink/blob/83e14cb1/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/OperatorStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/OperatorStateHandle.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/OperatorStateHandle.java
index 87536ed..f308ba8 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/OperatorStateHandle.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/OperatorStateHandle.java
@@ -46,5 +46,9 @@ public class OperatorStateHandle implements StateHandle<Serializable> {
 	public void discardState() throws Exception {
 		handle.discardState();
 	}
+	
+	public StateHandle<Serializable> getHandle() {
+		return handle;
+	}
 
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/83e14cb1/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StateHandleTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StateHandleTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StateHandleTest.java
new file mode 100644
index 0000000..38117e8
--- /dev/null
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StateHandleTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.state;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.util.InstantiationUtil;
+import org.junit.Test;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+public class StateHandleTest {
+
+	@Test
+	public void operatorStateHandleTest() throws Exception {
+
+		MockHandle<Serializable> h1 = new MockHandle<Serializable>(1);
+
+		OperatorStateHandle opHandle = new OperatorStateHandle(h1, true);
+		assertEquals(1, opHandle.getState());
+
+		OperatorStateHandle dsHandle = serializeDeserialize(opHandle);
+		MockHandle<Serializable> h2 = (MockHandle<Serializable>) dsHandle.getHandle();
+		assertFalse(h2.discarded);
+		assertNotNull(h1.state);
+		assertNull(h2.state);
+
+		dsHandle.discardState();
+
+		assertTrue(h2.discarded);
+	}
+
+	@Test
+	public void wrapperStateHandleTest() throws Exception {
+
+		MockHandle<Serializable> h1 = new MockHandle<Serializable>(1);
+		MockHandle<Serializable> h2 = new MockHandle<Serializable>(2);
+		StateHandle<Serializable> h3 = new MockHandle<Serializable>(3);
+
+		OperatorStateHandle opH1 = new OperatorStateHandle(h1, true);
+		OperatorStateHandle opH2 = new OperatorStateHandle(h2, false);
+
+		Map<String, OperatorStateHandle> opHandles = ImmutableMap.of("h1", opH1, "h2", opH2);
+
+		Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> fullState = Tuple2.of(h3,
+				opHandles);
+
+		List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>> chainedStates = ImmutableList
+				.of(fullState);
+
+		WrapperStateHandle wrapperHandle = new WrapperStateHandle(chainedStates);
+
+		WrapperStateHandle dsWrapper = serializeDeserialize(wrapperHandle);
+
+		@SuppressWarnings("unchecked")
+		Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> dsFullState = ((List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>>) dsWrapper
+				.getState()).get(0);
+
+		Map<String, OperatorStateHandle> dsOpHandles = dsFullState.f1;
+
+		assertNull(dsFullState.f0.getState());
+		assertFalse(((MockHandle<?>) dsFullState.f0).discarded);
+		assertFalse(((MockHandle<?>) dsOpHandles.get("h1").getHandle()).discarded);
+		assertNull(dsOpHandles.get("h1").getState());
+		assertFalse(((MockHandle<?>) dsOpHandles.get("h2").getHandle()).discarded);
+		assertNull(dsOpHandles.get("h2").getState());
+
+		dsWrapper.discardState();
+
+		assertTrue(((MockHandle<?>) dsFullState.f0).discarded);
+		assertTrue(((MockHandle<?>) dsOpHandles.get("h1").getHandle()).discarded);
+		assertTrue(((MockHandle<?>) dsOpHandles.get("h2").getHandle()).discarded);
+
+	}
+
+	@SuppressWarnings("unchecked")
+	private <X extends StateHandle<?>> X serializeDeserialize(X handle) throws IOException,
+			ClassNotFoundException {
+		byte[] serialized = InstantiationUtil.serializeObject(handle);
+		return (X) InstantiationUtil.deserializeObject(serialized, Thread.currentThread()
+				.getContextClassLoader());
+	}
+
+	@SuppressWarnings("serial")
+	private static class MockHandle<T> implements StateHandle<T> {
+
+		boolean discarded = false;
+		transient T state;
+
+		public MockHandle(T state) {
+			this.state = state;
+		}
+
+		@Override
+		public void discardState() {
+			state = null;
+			discarded = true;
+		}
+
+		@Override
+		public T getState() {
+			return state;
+		}
+	}
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/83e14cb1/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
index 3f99fa0..93dda5f 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
@@ -372,7 +372,8 @@ public class StreamCheckpointingITCase {
 		}
 	}
 
-	private static class StringPrefixCountRichMapFunction extends RichMapFunction<String, PrefixCount> {
+	private static class StringPrefixCountRichMapFunction extends RichMapFunction<String, PrefixCount>
+			implements Checkpointed<Integer> {
 
 		OperatorState<Long> count;
 		static final long[] counts = new long[PARALLELISM];
@@ -392,5 +393,16 @@ public class StreamCheckpointingITCase {
 		public void close() throws IOException {
 			counts[getRuntimeContext().getIndexOfThisSubtask()] = count.value();
 		}
+
+		@Override
+		public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+			return null;
+		}
+
+		@Override
+		public void restoreState(Integer state) {
+			// verify that we never store/restore null state
+			fail();
+		}
 	}
 }


[2/3] flink git commit: [FLINK-2324] [streaming] ITCase added for partitioned states

Posted by gy...@apache.org.
[FLINK-2324] [streaming] ITCase added for partitioned states


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/58cd4ea6
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/58cd4ea6
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/58cd4ea6

Branch: refs/heads/master
Commit: 58cd4ea615bbbbec682ec82f6380bc30c12c5899
Parents: 0558644
Author: Gyula Fora <gy...@apache.org>
Authored: Tue Jul 28 15:46:13 2015 +0200
Committer: Gyula Fora <gy...@apache.org>
Committed: Thu Jul 30 16:44:53 2015 +0200

----------------------------------------------------------------------
 .../PartitionedStateCheckpointingITCase.java    | 257 +++++++++++++++++++
 1 file changed, 257 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/58cd4ea6/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
new file mode 100644
index 0000000..88361e2
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.java
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.OperatorState;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.ConfigConstants;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.test.util.ForkableFlinkMiniCluster;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * A simple test that runs a streaming topology with checkpointing enabled.
+ * 
+ * The test triggers a failure after a while and verifies that, after
+ * completion, the state reflects the "exactly once" semantics.
+ * 
+ * It is designed to check partitioned states.
+ */
+@SuppressWarnings("serial")
+public class PartitionedStateCheckpointingITCase {
+
+	private static final int NUM_TASK_MANAGERS = 2;
+	private static final int NUM_TASK_SLOTS = 3;
+	private static final int PARALLELISM = NUM_TASK_MANAGERS * NUM_TASK_SLOTS;
+
+	private static ForkableFlinkMiniCluster cluster;
+
+	@BeforeClass
+	public static void startCluster() {
+		try {
+			Configuration config = new Configuration();
+			config.setInteger(ConfigConstants.LOCAL_INSTANCE_MANAGER_NUMBER_TASK_MANAGER, NUM_TASK_MANAGERS);
+			config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, NUM_TASK_SLOTS);
+			config.setString(ConfigConstants.DEFAULT_EXECUTION_RETRY_DELAY_KEY, "0 ms");
+			config.setInteger(ConfigConstants.TASK_MANAGER_MEMORY_SIZE_KEY, 12);
+
+			cluster = new ForkableFlinkMiniCluster(config, false);
+		} catch (Exception e) {
+			e.printStackTrace();
+			fail("Failed to start test cluster: " + e.getMessage());
+		}
+	}
+
+	@AfterClass
+	public static void shutdownCluster() {
+		try {
+			cluster.shutdown();
+			cluster = null;
+		} catch (Exception e) {
+			e.printStackTrace();
+			fail("Failed to stop test cluster: " + e.getMessage());
+		}
+	}
+
+	/** Runs the keyed sum/count job, in which the map fails once, and checks the per-key results. */
+	@SuppressWarnings("unchecked")
+	@Test
+	public void runCheckpointedProgram() {
+		final long NUM_STRINGS = 10000000L;
+		assertTrue("Broken test setup", (NUM_STRINGS/2) % 40 == 0);
+
+		try {
+			StreamExecutionEnvironment env = StreamExecutionEnvironment.createRemoteEnvironment("localhost",
+					cluster.getJobManagerRPCPort());
+			env.setParallelism(PARALLELISM);
+			env.enableCheckpointing(500);
+			env.getConfig().disableSysoutLogging();
+
+			DataStream<Integer> stream1 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));
+			DataStream<Integer> stream2 = env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2));
+
+			stream1.union(stream2)
+					.groupBy(new IdentityKeySelector<Integer>())
+					.map(new OnceFailingPartitionedSum(NUM_STRINGS))
+					.keyBy(0)
+					.addSink(new CounterSink());
+
+			env.execute();
+
+			// verify that we counted exactly right: key k must have summed up to k * NUM_STRINGS / 40
+			for (Entry<Integer, Long> sum : OnceFailingPartitionedSum.allSums.entrySet()) {
+				assertEquals(Long.valueOf(sum.getKey() * NUM_STRINGS / 40), sum.getValue());
+			}
+			// ... and the sink must have counted NUM_STRINGS / 40 records for every key
+			for (Long count : CounterSink.allCounts.values()) {
+				assertEquals(Long.valueOf(NUM_STRINGS / 40), count);
+			}
+
+			assertEquals(40, CounterSink.allCounts.size());
+			assertEquals(40, OnceFailingPartitionedSum.allSums.size());
+
+		} catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+	// --------------------------------------------------------------------------------------------
+	// Custom Functions
+	// --------------------------------------------------------------------------------------------
+
+	private static class IntGeneratingSourceFunction extends RichParallelSourceFunction<Integer> {
+
+		private final long numElements;
+
+		private OperatorState<Integer> index;
+		private int step;
+
+		private volatile boolean isRunning;
+
+		static final long[] counts = new long[PARALLELISM];
+
+		@Override
+		public void close() throws IOException {
+			counts[getRuntimeContext().getIndexOfThisSubtask()] = index.value();
+		}
+
+		IntGeneratingSourceFunction(long numElements) {
+			this.numElements = numElements;
+		}
+
+		@Override
+		public void open(Configuration parameters) throws IOException {
+			step = getRuntimeContext().getNumberOfParallelSubtasks();
+
+			index = getRuntimeContext().getOperatorState("index",
+					getRuntimeContext().getIndexOfThisSubtask(), false);
+
+			isRunning = true;
+		}
+
+		@Override
+		public void run(SourceContext<Integer> ctx) throws Exception {
+			final Object lockingObject = ctx.getCheckpointLock();
+
+			while (isRunning && index.value() < numElements) {
+
+				synchronized (lockingObject) {
+					index.update(index.value() + step);
+					ctx.collect(index.value() % 40);
+				}
+			}
+		}
+
+		@Override
+		public void cancel() {
+			isRunning = false;
+		}
+	}
+
+	private static class OnceFailingPartitionedSum extends RichMapFunction<Integer, Tuple2<Integer, Long>> {
+
+		private static Map<Integer, Long> allSums = new ConcurrentHashMap<Integer, Long>();
+		private static volatile boolean hasFailed = false;
+
+		private final long numElements;
+
+		private long failurePos;
+		private long count;
+
+		private OperatorState<Long> sum;
+
+		OnceFailingPartitionedSum(long numElements) {
+			this.numElements = numElements;
+		}
+
+		/** Picks the position of the one-time failure and obtains the "sum" operator state. */
+		@Override
+		public void open(Configuration parameters) throws IOException {
+			int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+			long failurePosMin = (long) (0.4 * numElements / parallelism);
+			long failurePosMax = (long) (0.7 * numElements / parallelism);
+			// nextLong() % range may be negative (below failurePosMin); nextDouble() stays in [min, max)
+			failurePos = failurePosMin + (long) (new Random().nextDouble() * (failurePosMax - failurePosMin));
+			count = 0;
+			sum = getRuntimeContext().getOperatorState("sum", 0L, true);
+		}
+
+		@Override
+		public Tuple2<Integer, Long> map(Integer value) throws Exception {
+			count++;
+			if (!hasFailed && count >= failurePos) {
+				hasFailed = true;
+				throw new Exception("Test Failure");
+			}
+
+			long currentSum = sum.value() + value;
+			sum.update(currentSum);
+			allSums.put(value, currentSum);
+			return new Tuple2<Integer, Long>(value, currentSum);
+		}
+	}
+
+	private static class CounterSink extends RichSinkFunction<Tuple2<Integer, Long>> {
+
+		private static Map<Integer, Long> allCounts = new ConcurrentHashMap<Integer, Long>();
+
+		private OperatorState<Long> counts;
+
+		@Override
+		public void open(Configuration parameters) throws IOException {
+			counts = getRuntimeContext().getOperatorState("count", 0L, true);
+		}
+
+		@Override
+		public void invoke(Tuple2<Integer, Long> value) throws Exception {
+			long currentCount = counts.value() + 1;
+			counts.update(currentCount);
+			allCounts.put(value.f0, currentCount);
+
+		}
+	}
+	
+	private static class IdentityKeySelector<T> implements KeySelector<T, T> {
+
+		@Override
+		public T getKey(T value) throws Exception {
+			return value;
+		}
+
+	}
+}


[3/3] flink git commit: [FLINK-2324] [streaming] Partitioned state checkpointing rework + test update

Posted by gy...@apache.org.
[FLINK-2324] [streaming] Partitioned state checkpointing rework + test update


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0558644a
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0558644a
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0558644a

Branch: refs/heads/master
Commit: 0558644ae1a1c8e0f21867ce1963aaa625170690
Parents: 1b3bdce
Author: Gyula Fora <gy...@apache.org>
Authored: Mon Jul 27 08:56:13 2015 +0200
Committer: Gyula Fora <gy...@apache.org>
Committed: Thu Jul 30 16:44:53 2015 +0200

----------------------------------------------------------------------
 .../runtime/state/PartitionedStateStore.java    |  51 --------
 .../operators/AbstractUdfStreamOperator.java    |  35 +++---
 .../api/operators/StatefulStreamOperator.java   |   6 +-
 .../streaming/api/state/EagerStateStore.java    |  34 +++---
 .../streaming/api/state/LazyStateStore.java     | 122 -------------------
 .../api/state/OperatorStateHandle.java          |  50 ++++++++
 .../api/state/PartitionedStateStore.java        |  52 ++++++++
 .../state/PartitionedStreamOperatorState.java   |   7 +-
 .../api/state/StreamOperatorState.java          |  24 ++--
 .../streaming/api/state/WrapperStateHandle.java |  13 +-
 .../streaming/runtime/tasks/StreamTask.java     |  16 ++-
 .../api/state/StatefulOperatorTest.java         |   3 +-
 .../StreamCheckpointingITCase.java              |  99 ++++++++-------
 13 files changed, 223 insertions(+), 289 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionedStateStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionedStateStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionedStateStore.java
deleted file mode 100644
index 6353eda..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/PartitionedStateStore.java
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state;
-
-import java.io.Serializable;
-import java.util.Map;
-
-import org.apache.flink.api.common.state.StateCheckpointer;
-
-/**
- * Interface for storing and accessing partitioned state. The interface is
- * designed in a way that allows implementations for lazily state access.
- * 
- * @param <S>
- *            Type of the state.
- * @param <C>
- *            Type of the state snapshot.
- */
-public interface PartitionedStateStore<S, C extends Serializable> {
-
-	S getStateForKey(Serializable key) throws Exception;
-
-	void setStateForKey(Serializable key, S state);
-
-	Map<Serializable, S> getPartitionedState() throws Exception;
-
-	Map<Serializable, StateHandle<C>> snapshotStates(long checkpointId, long checkpointTimestamp) throws Exception;
-
-	void restoreStates(Map<Serializable, StateHandle<C>> snapshots) throws Exception;
-
-	boolean containsKey(Serializable key);
-	
-	void setCheckPointer(StateCheckpointer<S, C> checkpointer);
-
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
index 23c4ab8..f21aacc 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/AbstractUdfStreamOperator.java
@@ -27,17 +27,16 @@ import org.apache.flink.api.common.functions.Function;
 import org.apache.flink.api.common.functions.util.FunctionUtils;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.state.PartitionedStateHandle;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.runtime.state.StateHandleProvider;
 import org.apache.flink.streaming.api.checkpoint.CheckpointNotifier;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.api.state.OperatorStateHandle;
+import org.apache.flink.streaming.api.state.PartitionedStreamOperatorState;
 import org.apache.flink.streaming.api.state.StreamOperatorState;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamingRuntimeContext;
 
-import com.google.common.collect.ImmutableMap;
-
 /**
  * This is used as the base class for operators that have a user-defined
  * function.
@@ -78,7 +77,7 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function & Serial
 
 	@Override
 	@SuppressWarnings({ "unchecked", "rawtypes" })
-	public void restoreInitialState(Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>> snapshots) throws Exception {
+	public void restoreInitialState(Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> snapshots) throws Exception {
 		// Restore state using the Checkpointed interface
 		if (userFunction instanceof Checkpointed) {
 			((Checkpointed) userFunction).restoreState(snapshots.f0.getState());
@@ -86,49 +85,51 @@ public abstract class AbstractUdfStreamOperator<OUT, F extends Function & Serial
 		
 		if (snapshots.f1 != null) {
 			// We iterate over the states registered for this operator, initialize and restore it
-			for (Entry<String, PartitionedStateHandle> snapshot : snapshots.f1.entrySet()) {
-				Map<Serializable, StateHandle<Serializable>> handles = snapshot.getValue().getState();
-				StreamOperatorState restoredState = runtimeContext.getState(snapshot.getKey(),
-						!(handles instanceof ImmutableMap));
-				restoredState.restoreState(snapshot.getValue().getState());
+			for (Entry<String, OperatorStateHandle> snapshot : snapshots.f1.entrySet()) {
+				StreamOperatorState restoredOpState = runtimeContext.getState(snapshot.getKey(), snapshot.getValue().isPartitioned());
+				StateHandle<Serializable> checkpointHandle = snapshot.getValue();
+				restoredOpState.restoreState(checkpointHandle);
 			}
 		}
 		
 	}
 
 	@SuppressWarnings({ "rawtypes", "unchecked" })
-	public Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>> getStateSnapshotFromFunction(long checkpointId, long timestamp)
+	public Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> getStateSnapshotFromFunction(long checkpointId, long timestamp)
 			throws Exception {
 		// Get all the states for the operator
 		Map<String, StreamOperatorState> operatorStates = runtimeContext.getOperatorStates();
 		
-		Map<String, PartitionedStateHandle> operatorStateSnapshots;
+		Map<String, OperatorStateHandle> operatorStateSnapshots;
 		if (operatorStates.isEmpty()) {
 			// We return null to signal that there is nothing to checkpoint
 			operatorStateSnapshots = null;
 		} else {
 			// Checkpoint the states and store the handles in a map
-			Map<String, PartitionedStateHandle> snapshots = new HashMap<String, PartitionedStateHandle>();
+			Map<String, OperatorStateHandle> snapshots = new HashMap<String, OperatorStateHandle>();
 
 			for (Entry<String, StreamOperatorState> state : operatorStates.entrySet()) {
+				boolean isPartitioned = state.getValue() instanceof PartitionedStreamOperatorState;
 				snapshots.put(state.getKey(),
-						new PartitionedStateHandle(state.getValue().snapshotState(checkpointId, timestamp)));
+						new OperatorStateHandle(state.getValue().snapshotState(checkpointId, timestamp),
+								isPartitioned));
 			}
 
 			operatorStateSnapshots = snapshots;
 		}
 		
 		StateHandle<Serializable> checkpointedSnapshot = null;
-
+		// if the UDF implements the Checkpointed interface we draw a snapshot
 		if (userFunction instanceof Checkpointed) {
 			StateHandleProvider<Serializable> provider = runtimeContext.getStateHandleProvider();
 			checkpointedSnapshot = provider.createStateHandle(((Checkpointed) userFunction)
 					.snapshotState(checkpointId, timestamp));
 		}
 		
+		// if we have either operator or checkpointed state we store it in a
+		// tuple2 otherwise return null
 		if (operatorStateSnapshots != null || checkpointedSnapshot != null) {
-			return new Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>>(
-					checkpointedSnapshot, operatorStateSnapshots);
+			return Tuple2.of(checkpointedSnapshot, operatorStateSnapshots);
 		} else {
 			return null;
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/StatefulStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/StatefulStreamOperator.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/StatefulStreamOperator.java
index afc36e0..d400fc4 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/StatefulStreamOperator.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/operators/StatefulStreamOperator.java
@@ -21,8 +21,8 @@ import java.io.Serializable;
 import java.util.Map;
 
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.runtime.state.PartitionedStateHandle;
 import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.streaming.api.state.OperatorStateHandle;
 
 /**
  * Interface for Stream operators that can have state. This interface is used for checkpointing
@@ -32,9 +32,9 @@ import org.apache.flink.runtime.state.StateHandle;
  */
 public interface StatefulStreamOperator<OUT> extends StreamOperator<OUT> {
 
-	void restoreInitialState(Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>> state) throws Exception;
+	void restoreInitialState(Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> state) throws Exception;
 
-	Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>> getStateSnapshotFromFunction(long checkpointId, long timestamp) throws Exception;
+	Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> getStateSnapshotFromFunction(long checkpointId, long timestamp) throws Exception;
 
 	void notifyCheckpointComplete(long checkpointId) throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/EagerStateStore.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/EagerStateStore.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/EagerStateStore.java
index 6d3bad6..213303a 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/EagerStateStore.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/EagerStateStore.java
@@ -24,20 +24,20 @@ import java.util.Map;
 import java.util.Map.Entry;
 
 import org.apache.flink.api.common.state.StateCheckpointer;
-import org.apache.flink.runtime.state.PartitionedStateStore;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.runtime.state.StateHandleProvider;
 
 public class EagerStateStore<S, C extends Serializable> implements PartitionedStateStore<S, C> {
 
-	private StateCheckpointer<S, C> checkpointer;
-	private final StateHandleProvider<C> provider;
+	private StateCheckpointer<S,C> checkpointer;
+	private final StateHandleProvider<Serializable> provider;
 
 	private Map<Serializable, S> fetchedState;
 
+	@SuppressWarnings("unchecked")
 	public EagerStateStore(StateCheckpointer<S, C> checkpointer, StateHandleProvider<C> provider) {
 		this.checkpointer = checkpointer;
-		this.provider = provider;
+		this.provider = (StateHandleProvider<Serializable>) provider;
 
 		fetchedState = new HashMap<Serializable, S>();
 	}
@@ -58,23 +58,25 @@ public class EagerStateStore<S, C extends Serializable> implements PartitionedSt
 	}
 
 	@Override
-	public Map<Serializable, StateHandle<C>> snapshotStates(long checkpointId,
-			long checkpointTimestamp) {
-
-		Map<Serializable, StateHandle<C>> handles = new HashMap<Serializable, StateHandle<C>>();
-
+	public StateHandle<Serializable> snapshotStates(long checkpointId, long checkpointTimestamp) {
+		// we map the values in the state-map using the state-checkpointer and store it as a checkpoint
+		Map<Serializable, C> checkpoints = new HashMap<Serializable, C>();
 		for (Entry<Serializable, S> stateEntry : fetchedState.entrySet()) {
-			handles.put(stateEntry.getKey(), provider.createStateHandle(checkpointer.snapshotState(
-					stateEntry.getValue(), checkpointId, checkpointTimestamp)));
+			checkpoints.put(stateEntry.getKey(),
+					checkpointer.snapshotState(stateEntry.getValue(), checkpointId, checkpointTimestamp));
 		}
-		return handles;
+		return provider.createStateHandle((Serializable) checkpoints);
 	}
 
 	@Override
-	public void restoreStates(Map<Serializable, StateHandle<C>> snapshots) throws Exception {
-		for (Entry<Serializable, StateHandle<C>> snapshotEntry : snapshots.entrySet()) {
-			fetchedState.put(snapshotEntry.getKey(),
-					checkpointer.restoreState(snapshotEntry.getValue().getState()));
+	public void restoreStates(StateHandle<Serializable> snapshot) throws Exception {
+		
+		@SuppressWarnings("unchecked")
+		Map<Serializable, C> checkpoints = (Map<Serializable, C>) snapshot.getState();
+		
+		// we map the values back to the state from the checkpoints
+		for (Entry<Serializable, C> snapshotEntry : checkpoints.entrySet()) {
+			fetchedState.put(snapshotEntry.getKey(), (S) checkpointer.restoreState(snapshotEntry.getValue()));
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/LazyStateStore.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/LazyStateStore.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/LazyStateStore.java
deleted file mode 100644
index 14484ea..0000000
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/LazyStateStore.java
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.api.state;
-
-import java.io.Serializable;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Map.Entry;
-
-import org.apache.flink.api.common.state.StateCheckpointer;
-import org.apache.flink.runtime.state.PartitionedStateStore;
-import org.apache.flink.runtime.state.StateHandle;
-import org.apache.flink.runtime.state.StateHandleProvider;
-
-/**
- * Implementation of the {@link PartitionedStateStore} interface for lazy
- * retrieval and snapshotting of the partitioned operator states. Lazy state
- * access considerably speeds up recovery and makes resource access smoother by
- * avoiding request congestion in the persistent storage layer.
- * 
- * <p>
- * The logic implemented here can also be used later to push unused state to the
- * persistent layer and also avoids re-snapshotting the unmodified states.
- * </p>
- * 
- * @param <S>
- *            Type of the operator states.
- * @param <C>
- *            Type of the state checkpoints.
- */
-public class LazyStateStore<S, C extends Serializable> implements PartitionedStateStore<S, C> {
-
-	protected StateCheckpointer<S, C> checkpointer;
-	protected final StateHandleProvider<C> provider;
-
-	private final Map<Serializable, StateHandle<C>> unfetchedState;
-	private final Map<Serializable, S> fetchedState;
-
-	public LazyStateStore(StateCheckpointer<S, C> checkpointer, StateHandleProvider<C> provider) {
-		this.checkpointer = checkpointer;
-		this.provider = provider;
-
-		unfetchedState = new HashMap<Serializable, StateHandle<C>>();
-		fetchedState = new HashMap<Serializable, S>();
-	}
-
-	@Override
-	public S getStateForKey(Serializable key) throws Exception {
-		S state = fetchedState.get(key);
-		if (state != null) {
-			return state;
-		} else {
-			StateHandle<C> handle = unfetchedState.get(key);
-			if (handle != null) {
-				state = checkpointer.restoreState(handle.getState());
-				fetchedState.put(key, state);
-				unfetchedState.remove(key);
-				return state;
-			} else {
-				return null;
-			}
-		}
-	}
-
-	@Override
-	public void setStateForKey(Serializable key, S state) {
-		fetchedState.put(key, state);
-		unfetchedState.remove(key);
-	}
-
-	@Override
-	public Map<Serializable, S> getPartitionedState() throws Exception {
-		for (Entry<Serializable, StateHandle<C>> handleEntry : unfetchedState.entrySet()) {
-			fetchedState.put(handleEntry.getKey(),
-					checkpointer.restoreState(handleEntry.getValue().getState()));
-		}
-		unfetchedState.clear();
-		return fetchedState;
-	}
-
-	@Override
-	public Map<Serializable, StateHandle<C>> snapshotStates(long checkpointId,
-			long checkpointTimestamp) {
-		for (Entry<Serializable, S> stateEntry : fetchedState.entrySet()) {
-			unfetchedState.put(stateEntry.getKey(), provider.createStateHandle(checkpointer
-					.snapshotState(stateEntry.getValue(), checkpointId, checkpointTimestamp)));
-		}
-		return unfetchedState;
-	}
-
-	@Override
-	public void restoreStates(Map<Serializable, StateHandle<C>> snapshots) {
-		unfetchedState.putAll(snapshots);
-	}
-
-	@Override
-	public boolean containsKey(Serializable key) {
-		return fetchedState.containsKey(key) || unfetchedState.containsKey(key);
-	}
-
-	@Override
-	public void setCheckPointer(StateCheckpointer<S, C> checkpointer) {
-		this.checkpointer = checkpointer;
-	}
-
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/OperatorStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/OperatorStateHandle.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/OperatorStateHandle.java
new file mode 100644
index 0000000..87536ed
--- /dev/null
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/OperatorStateHandle.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.state;
+
+import java.io.Serializable;
+
+import org.apache.flink.runtime.state.StateHandle;
+
+public class OperatorStateHandle implements StateHandle<Serializable> {
+	
+	private static final long serialVersionUID = 1L;
+	
+	private final StateHandle<Serializable> handle;
+	private final boolean isPartitioned;
+	
+	public OperatorStateHandle(StateHandle<Serializable> handle, boolean isPartitioned){
+		this.handle = handle;
+		this.isPartitioned = isPartitioned;
+	}
+	
+	public boolean isPartitioned(){
+		return isPartitioned;
+	}
+
+	@Override
+	public Serializable getState() throws Exception {
+		return handle.getState();
+	}
+
+	@Override
+	public void discardState() throws Exception {
+		handle.discardState();
+	}
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStateStore.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStateStore.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStateStore.java
new file mode 100644
index 0000000..5201058
--- /dev/null
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStateStore.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.state;
+
+import java.io.Serializable;
+import java.util.Map;
+
+import org.apache.flink.api.common.state.StateCheckpointer;
+import org.apache.flink.runtime.state.StateHandle;
+
+/**
+ * Interface for storing and accessing partitioned state. The interface is
+ * designed in a way that allows implementations with lazy state access.
+ * 
+ * @param <S>
+ *            Type of the state.
+ * @param <C>
+ *            Type of the state snapshot.
+ */
+public interface PartitionedStateStore<S, C extends Serializable> {
+
+	S getStateForKey(Serializable key) throws Exception;
+
+	void setStateForKey(Serializable key, S state);
+
+	Map<Serializable, S> getPartitionedState() throws Exception;
+
+	StateHandle<Serializable> snapshotStates(long checkpointId, long checkpointTimestamp) throws Exception;
+
+	void restoreStates(StateHandle<Serializable> snapshot) throws Exception;
+
+	boolean containsKey(Serializable key);
+	
+	void setCheckPointer(StateCheckpointer<S, C> checkpointer);
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
index b22aed4..b165a94 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
@@ -25,7 +25,6 @@ import java.util.Map;
 import org.apache.flink.api.common.state.OperatorState;
 import org.apache.flink.api.common.state.StateCheckpointer;
 import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.runtime.state.PartitionedStateStore;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.runtime.state.StateHandleProvider;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
@@ -83,7 +82,7 @@ public class PartitionedStreamOperatorState<IN, S, C extends Serializable> exten
 				if (stateStore.containsKey(key)) {
 					return stateStore.getStateForKey(key);
 				} else {
-					return checkpointer.restoreState((C) InstantiationUtil.deserializeObject(
+					return (S) checkpointer.restoreState((C) InstantiationUtil.deserializeObject(
 							defaultState, cl));
 				}
 			} catch (Exception e) {
@@ -123,13 +122,13 @@ public class PartitionedStreamOperatorState<IN, S, C extends Serializable> exten
 	}
 
 	@Override
-	public Map<Serializable, StateHandle<C>> snapshotState(long checkpointId,
+	public StateHandle<Serializable> snapshotState(long checkpointId,
 			long checkpointTimestamp) throws Exception {
 		return stateStore.snapshotStates(checkpointId, checkpointTimestamp);
 	}
 
 	@Override
-	public void restoreState(Map<Serializable, StateHandle<C>> snapshots) throws Exception {
+	public void restoreState(StateHandle<Serializable> snapshots) throws Exception {
 		stateStore.restoreStates(snapshots);
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
index 2724efb..6e0a3ea 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/StreamOperatorState.java
@@ -42,15 +42,14 @@ import com.google.common.collect.ImmutableMap;
  */
 public class StreamOperatorState<S, C extends Serializable> implements OperatorState<S> {
 
-	public static final Serializable DEFAULTKEY = -1;
-
 	private S state;
 	protected StateCheckpointer<S, C> checkpointer;
-	private final StateHandleProvider<C> provider;
+	private final StateHandleProvider<Serializable> provider;
 
+	@SuppressWarnings("unchecked")
 	public StreamOperatorState(StateCheckpointer<S, C> checkpointer, StateHandleProvider<C> provider) {
 		this.checkpointer = checkpointer;
-		this.provider = provider;
+		this.provider = (StateHandleProvider<Serializable>) provider;
 	}
 	
 	@SuppressWarnings("unchecked")
@@ -85,23 +84,24 @@ public class StreamOperatorState<S, C extends Serializable> implements OperatorS
 		this.checkpointer = checkpointer;
 	}
 
-	protected StateHandleProvider<C> getStateHandleProvider() {
+	protected StateHandleProvider<Serializable> getStateHandleProvider() {
 		return provider;
 	}
 
-	public Map<Serializable, StateHandle<C>> snapshotState(long checkpointId,
-			long checkpointTimestamp) throws Exception {
-		return ImmutableMap.of(DEFAULTKEY, provider.createStateHandle(checkpointer.snapshotState(
-				value(), checkpointId, checkpointTimestamp)));
+	public StateHandle<Serializable> snapshotState(long checkpointId, long checkpointTimestamp)
+			throws Exception {
+		return provider.createStateHandle(checkpointer.snapshotState(value(), checkpointId,
+				checkpointTimestamp));
 
 	}
 
-	public void restoreState(Map<Serializable, StateHandle<C>> snapshots) throws Exception {
-		update(checkpointer.restoreState(snapshots.get(DEFAULTKEY).getState()));
+	@SuppressWarnings("unchecked")
+	public void restoreState(StateHandle<Serializable> snapshot) throws Exception {
+		update((S) checkpointer.restoreState((C) snapshot.getState()));
 	}
 
 	public Map<Serializable, S> getPartitionedState() throws Exception {
-		return ImmutableMap.of(DEFAULTKEY, state);
+		return ImmutableMap.of((Serializable) 0, state);
 	}
 	
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/WrapperStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/WrapperStateHandle.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/WrapperStateHandle.java
index 1adef48..27c697a 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/WrapperStateHandle.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/WrapperStateHandle.java
@@ -24,7 +24,6 @@ import java.util.Map;
 
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.state.LocalStateHandle;
-import org.apache.flink.runtime.state.PartitionedStateHandle;
 import org.apache.flink.runtime.state.StateHandle;
 
 /**
@@ -36,24 +35,22 @@ public class WrapperStateHandle extends LocalStateHandle<Serializable> {
 
 	private static final long serialVersionUID = 1L;
 
-	public WrapperStateHandle(List<Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>>> state) {
+	public WrapperStateHandle(List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>> state) {
 		super((Serializable) state);
 	}
 
 	@Override
 	public void discardState() throws Exception {
 		@SuppressWarnings("unchecked")
-		List<Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>>> chainedStates = (List<Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>>>) getState();
-		for (Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>> state : chainedStates) {
+		List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>> chainedStates = (List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>>) getState();
+		for (Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> state : chainedStates) {
 			if (state != null) {
 				if (state.f0 != null) {
 					state.f0.discardState();
 				}
 				if (state.f1 != null) {
-					for (PartitionedStateHandle statePartitions : state.f1.values()) {
-						for (StateHandle<Serializable> handle : statePartitions.getState().values()) {
-							handle.discardState();
-						}
+					for (StateHandle<Serializable> opState : state.f1.values()) {
+						opState.discardState();
 					}
 				}
 			}

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index d829833..2098da8 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -38,13 +38,13 @@ import org.apache.flink.runtime.jobgraph.tasks.CheckpointedOperator;
 import org.apache.flink.runtime.jobgraph.tasks.OperatorStateCarrier;
 import org.apache.flink.runtime.state.FileStateHandle;
 import org.apache.flink.runtime.state.LocalStateHandle;
-import org.apache.flink.runtime.state.PartitionedStateHandle;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.runtime.state.StateHandleProvider;
 import org.apache.flink.runtime.util.event.EventListener;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.StatefulStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.state.OperatorStateHandle;
 import org.apache.flink.streaming.api.state.WrapperStateHandle;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -210,12 +210,12 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 	@Override
 	public void setInitialState(StateHandle<Serializable> stateHandle) throws Exception {
 
-		// We retrieve end restore the states for the chained operators.
-		List<Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>>> chainedStates = (List<Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>>>) stateHandle.getState();
+		// We retrieve and restore the states for the chained operators.
+		List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>> chainedStates = (List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>>) stateHandle.getState();
 
-		// We restore all stateful chained operators
+		// We restore all stateful operators
 		for (int i = 0; i < chainedStates.size(); i++) {
-			Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>> state = chainedStates.get(i);
+			Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>> state = chainedStates.get(i);
 			// If state is not null we need to restore it
 			if (state != null) {
 				StreamOperator<?> chainedOperator = outputHandler.getChainedOperators().get(i);
@@ -234,15 +234,14 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 					LOG.debug("Starting checkpoint {} on task {}", checkpointId, getName());
 
 					// We wrap the states of the chained operators in a list, marking non-stateful oeprators with null
-					List<Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>>> chainedStates = new ArrayList<Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>>>();
+					List<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>> chainedStates = new ArrayList<Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>>();
 
 					// A wrapper handle is created for the List of statehandles
 					WrapperStateHandle stateHandle;
 					try {
 
 						// We construct a list of states for chained tasks
-						for (StreamOperator<?> chainedOperator : outputHandler
-								.getChainedOperators()) {
+						for (StreamOperator<?> chainedOperator : outputHandler.getChainedOperators()) {
 							if (chainedOperator instanceof StatefulStreamOperator) {
 								chainedStates.add(((StatefulStreamOperator<?>) chainedOperator)
 										.getStateSnapshotFromFunction(checkpointId, timestamp));
@@ -281,7 +280,6 @@ public abstract class StreamTask<OUT, O extends StreamOperator<OUT>> extends Abs
 	@SuppressWarnings("rawtypes")
 	@Override
 	public void notifyCheckpointComplete(long checkpointId) throws Exception {
-		// we do nothing here so far. this should call commit on the source function, for example
 		synchronized (checkpointLock) {
 
 			for (StreamOperator<?> chainedOperator : outputHandler.getChainedOperators()) {

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
index a7a8a09..6ca38b7 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
@@ -41,7 +41,6 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.LocalStateHandle.LocalStateHandleProvider;
-import org.apache.flink.runtime.state.PartitionedStateHandle;
 import org.apache.flink.runtime.state.StateHandle;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
 import org.apache.flink.streaming.api.datastream.KeyedDataStream;
@@ -166,7 +165,7 @@ public class StatefulOperatorTest {
 		}, context);
 
 		if (serializedState != null) {
-			op.restoreInitialState((Tuple2<StateHandle<Serializable>, Map<String, PartitionedStateHandle>>) InstantiationUtil
+			op.restoreInitialState((Tuple2<StateHandle<Serializable>, Map<String, OperatorStateHandle>>) InstantiationUtil
 					.deserializeObject(serializedState, Thread.currentThread()
 							.getContextClassLoader()));
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/0558644a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
index 438e980..3f99fa0 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/StreamCheckpointingITCase.java
@@ -18,32 +18,33 @@
 
 package org.apache.flink.test.checkpointing;
 
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+
 import org.apache.flink.api.common.functions.RichFilterFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
-import org.apache.flink.api.common.functions.RichReduceFunction;
 import org.apache.flink.api.common.state.OperatorState;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
 import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
 import org.apache.flink.test.util.ForkableFlinkMiniCluster;
-
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
 
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Random;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
 
 /**
  * A simple test that runs a streaming topology with checkpointing enabled.
@@ -94,7 +95,7 @@ public class StreamCheckpointingITCase {
 	 * Runs the following program:
 	 *
 	 * <pre>
-	 *     [ (source)->(filter)->(map) ] -> [ (map) ] -> [ (groupBy/reduce)->(sink) ]
+	 *     [ (source)->(filter) ]-s->[ (map) ] -> [ (map) ] -> [ (groupBy/count)->(sink) ]
 	 * </pre>
 	 */
 	@Test
@@ -114,37 +115,22 @@ public class StreamCheckpointingITCase {
 			
 			stream
 					// -------------- first vertex, chained to the source ----------------
-					.filter(new StringRichFilterFunction())
+					.filter(new StringRichFilterFunction()).shuffle()
 
 					// -------------- seconds vertex - the stateful one that also fails ----------------
 					.map(new StringPrefixCountRichMapFunction())
 					.startNewChain()
 					.map(new StatefulCounterFunction())
 
-					// -------------- third vertex - reducer and the sink ----------------
+					// -------------- third vertex - counter and the sink ----------------
 					.groupBy("prefix")
-					.reduce(new OnceFailingReducer(NUM_STRINGS))
-					.addSink(new RichSinkFunction<PrefixCount>() {
-
-						private Map<Character, Long> counts = new HashMap<Character, Long>();
+					.map(new OnceFailingPrefixCounter(NUM_STRINGS))
+					.addSink(new SinkFunction<PrefixCount>() {
 
 						@Override
-						public void invoke(PrefixCount value) {
-							Character first = value.prefix.charAt(0);
-							Long previous = counts.get(first);
-							if (previous == null) {
-								counts.put(first, value.count);
-							} else {
-								counts.put(first, Math.max(previous, value.count));
-							}
+						public void invoke(PrefixCount value) throws Exception {
+							// Intentionally empty: the test asserts on the operators' static counters, not on sink output
 						}
-
-//						@Override
-//						public void close() {
-//							for (Long count : counts.values()) {
-//								assertEquals(NUM_STRINGS / 40, count.longValue());
-//							}
-//						}
 					});
 
 			env.execute();
@@ -163,14 +149,20 @@ public class StreamCheckpointingITCase {
 			for (long l : StatefulCounterFunction.counts) {
 				countSum += l;
 			}
-
-			// verify that we counted exactly right
 			
-			// this line should be uncommented once the "exactly one off by one" is fixed
-			// if this fails we see at which point the count is off
+			long reduceInputCount = 0;
+			for(long l: OnceFailingPrefixCounter.counts){
+				reduceInputCount += l;
+			}
+			
 			assertEquals(NUM_STRINGS, filterSum);
 			assertEquals(NUM_STRINGS, mapSum);
 			assertEquals(NUM_STRINGS, countSum);
+			assertEquals(NUM_STRINGS, reduceInputCount);
+			// verify that every prefix ended up with a count of exactly NUM_STRINGS / 40
+			for (Long count : OnceFailingPrefixCounter.prefixCounts.values()) {
+				assertEquals(new Long(NUM_STRINGS / 40), count);
+			}
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -277,7 +269,10 @@ public class StreamCheckpointingITCase {
 		
 	}
 	
-	private static class OnceFailingReducer extends RichReduceFunction<PrefixCount> {
+	private static class OnceFailingPrefixCounter extends RichMapFunction<PrefixCount, PrefixCount> {
+		
+		private static Map<String, Long> prefixCounts = new ConcurrentHashMap<String, Long>();
+		static final long[] counts = new long[PARALLELISM];
 
 		private static volatile boolean hasFailed = false;
 
@@ -285,30 +280,44 @@ public class StreamCheckpointingITCase {
 		
 		private long failurePos;
 		private long count;
+		
+		private OperatorState<Long> pCount;
+		private OperatorState<Long> inputCount;
 
-		OnceFailingReducer(long numElements) {
+		OnceFailingPrefixCounter(long numElements) {
 			this.numElements = numElements;
 		}
 		
 		@Override
-		public void open(Configuration parameters) {
+		public void open(Configuration parameters) throws IOException {
 			long failurePosMin = (long) (0.4 * numElements / getRuntimeContext().getNumberOfParallelSubtasks());
 			long failurePosMax = (long) (0.7 * numElements / getRuntimeContext().getNumberOfParallelSubtasks());
 
 			failurePos = (new Random().nextLong() % (failurePosMax - failurePosMin)) + failurePosMin;
 			count = 0;
+			pCount = getRuntimeContext().getOperatorState("prefix-count", 0L, true);
+			inputCount = getRuntimeContext().getOperatorState("input-count", 0L, false);
 		}
 		
 		@Override
-		public PrefixCount reduce(PrefixCount value1, PrefixCount value2) throws Exception {
+		public void close() throws IOException {
+			counts[getRuntimeContext().getIndexOfThisSubtask()] = inputCount.value();
+		}
+
+		@Override
+		public PrefixCount map(PrefixCount value) throws Exception {
 			count++;
 			if (!hasFailed && count >= failurePos) {
 				hasFailed = true;
 				throw new Exception("Test Failure");
 			}
-			
-			value1.count += value2.count;
-			return value1;
+			inputCount.update(inputCount.value() + 1);
+		
+			long currentPrefixCount = pCount.value() + value.count;
+			pCount.update(currentPrefixCount);
+			prefixCounts.put(value.prefix, currentPrefixCount);
+			value.count = currentPrefixCount;
+			return value;
 		}
 	}
 	
@@ -316,7 +325,7 @@ public class StreamCheckpointingITCase {
 	//  Custom Type Classes
 	// --------------------------------------------------------------------------------------------
 
-	public static class PrefixCount {
+	public static class PrefixCount implements Serializable {
 
 		public String prefix;
 		public String value;