You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@flink.apache.org by se...@apache.org on 2015/10/16 18:08:39 UTC

[12/24] flink git commit: [FLINK-2550] [streaming] Make fast-path processing time windows fault tolerant

http://git-wip-us.apache.org/repos/asf/flink/blob/c24dca50/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
index 671544e..dd76a67 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AccumulatingAlignedProcessingTimeWindowOperatorTest.java
@@ -20,23 +20,30 @@ package org.apache.flink.streaming.runtime.operators.windowing;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.accumulators.Accumulator;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.state.StateBackend;
+import org.apache.flink.streaming.api.state.memory.MemoryStateBackend;
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
 import org.apache.flink.streaming.runtime.operators.Triggerable;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskState;
 import org.apache.flink.util.Collector;
 
 import org.junit.After;
 import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
+import org.mockito.stubbing.OngoingStubbing;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
@@ -119,27 +126,31 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 	@Test
 	public void testWindowSizeAndSlide() {
 		try {
-			AbstractAlignedProcessingTimeWindowOperator<String, String, String, ?> op;
+			AccumulatingProcessingTimeWindowOperator<String, String, String> op;
 			
-			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 5000, 1000);
+			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 5000, 1000);
 			assertEquals(5000, op.getWindowSize());
 			assertEquals(1000, op.getWindowSlide());
 			assertEquals(1000, op.getPaneSize());
 			assertEquals(5, op.getNumPanesPerWindow());
 
-			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1000, 1000);
+			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1000, 1000);
 			assertEquals(1000, op.getWindowSize());
 			assertEquals(1000, op.getWindowSlide());
 			assertEquals(1000, op.getPaneSize());
 			assertEquals(1, op.getNumPanesPerWindow());
 
-			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1500, 1000);
+			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1500, 1000);
 			assertEquals(1500, op.getWindowSize());
 			assertEquals(1000, op.getWindowSlide());
 			assertEquals(500, op.getPaneSize());
 			assertEquals(3, op.getNumPanesPerWindow());
 
-			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1200, 1100);
+			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1200, 1100);
 			assertEquals(1200, op.getWindowSize());
 			assertEquals(1100, op.getWindowSlide());
 			assertEquals(100, op.getPaneSize());
@@ -157,31 +168,35 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 			@SuppressWarnings("unchecked")
 			final Output<StreamRecord<String>> mockOut = mock(Output.class);
 			final StreamTask<?, ?> mockTask = createMockTask();
-			
-			AbstractAlignedProcessingTimeWindowOperator<String, String, String, ?> op;
 
-			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 5000, 1000);
+			AccumulatingProcessingTimeWindowOperator<String, String, String> op;
+
+			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 5000, 1000);
 			op.setup(mockTask, new StreamConfig(new Configuration()), mockOut);
 			op.open();
 			assertTrue(op.getNextSlideTime() % 1000 == 0);
 			assertTrue(op.getNextEvaluationTime() % 1000 == 0);
 			op.dispose();
 
-			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1000, 1000);
+			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1000, 1000);
 			op.setup(mockTask, new StreamConfig(new Configuration()), mockOut);
 			op.open();
 			assertTrue(op.getNextSlideTime() % 1000 == 0);
 			assertTrue(op.getNextEvaluationTime() % 1000 == 0);
 			op.dispose();
 
-			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1500, 1000);
+			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1500, 1000);
 			op.setup(mockTask, new StreamConfig(new Configuration()), mockOut);
 			op.open();
 			assertTrue(op.getNextSlideTime() % 500 == 0);
 			assertTrue(op.getNextEvaluationTime() % 1000 == 0);
 			op.dispose();
 
-			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1200, 1100);
+			op = new AccumulatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1200, 1100);
 			op.setup(mockTask, new StreamConfig(new Configuration()), mockOut);
 			op.open();
 			assertTrue(op.getNextSlideTime() % 100 == 0);
@@ -204,9 +219,11 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 			final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
 
 			// tumbling window that triggers every 20 milliseconds
-			AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer, ?> op =
+			AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> op =
 					new AccumulatingProcessingTimeWindowOperator<>(
-							validatingIdentityFunction, identitySelector, windowSize, windowSize);
+							validatingIdentityFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+							windowSize, windowSize);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out);
 			op.open();
@@ -220,7 +237,9 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 				Thread.sleep(1);
 			}
 
-			op.close();
+			synchronized (lock) {
+				op.close();
+			}
 			op.dispose();
 
 			// get and verify the result
@@ -250,8 +269,10 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 			final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
 			
 			// tumbling window that triggers every 20 milliseconds
-			AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer, ?> op =
-					new AccumulatingProcessingTimeWindowOperator<>(validatingIdentityFunction, identitySelector, 150, 50);
+			AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> op =
+					new AccumulatingProcessingTimeWindowOperator<>(
+							validatingIdentityFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE, 150, 50);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out);
 			op.open();
@@ -265,7 +286,9 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 				Thread.sleep(1);
 			}
 
-			op.close();
+			synchronized (lock) {
+				op.close();
+			}
 			op.dispose();
 
 			// get and verify the result
@@ -312,8 +335,10 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 			final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
 
 			// tumbling window that triggers every 20 milliseconds
-			AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer, ?> op =
-					new AccumulatingProcessingTimeWindowOperator<>(validatingIdentityFunction, identitySelector, 50, 50);
+			AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> op =
+					new AccumulatingProcessingTimeWindowOperator<>(
+							validatingIdentityFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE, 50, 50);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out);
 			op.open();
@@ -342,7 +367,9 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 			Collections.sort(result);
 			assertEquals(Arrays.asList(1, 2, 3, 4, 5, 6), result);
 
-			op.close();
+			synchronized (lock) {
+				op.close();
+			}
 			op.dispose();
 		}
 		catch (Exception e) {
@@ -364,8 +391,10 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 			final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
 
 			// tumbling window that triggers every 20 milliseconds
-			AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer, ?> op =
-					new AccumulatingProcessingTimeWindowOperator<>(validatingIdentityFunction, identitySelector, 150, 50);
+			AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> op =
+					new AccumulatingProcessingTimeWindowOperator<>(
+							validatingIdentityFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE, 150, 50);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out);
 			op.open();
@@ -384,8 +413,10 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 			
 			Collections.sort(result);
 			assertEquals(Arrays.asList(1, 1, 1, 2, 2, 2), result);
-			
-			op.close();
+
+			synchronized (lock) {
+				op.close();
+			}
 			op.dispose();
 		}
 		catch (Exception e) {
@@ -407,8 +438,10 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 			
 			// the operator has a window time that is so long that it will not fire in this test
 			final long oneYear = 365L * 24 * 60 * 60 * 1000;
-			AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer, ?> op = 
-					new AccumulatingProcessingTimeWindowOperator<>(validatingIdentityFunction, identitySelector,
+			AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> op =
+					new AccumulatingProcessingTimeWindowOperator<>(
+							validatingIdentityFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE,
 							oneYear, oneYear);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out);
@@ -420,8 +453,10 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 					op.processElement(new StreamRecord<Integer>(i));
 				}
 			}
-			
-			op.close();
+
+			synchronized (lock) {
+				op.close();
+			}
 			op.dispose();
 			
 			// get and verify the result
@@ -450,9 +485,11 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 
 			// the operator has a window time that is so long that it will not fire in this test
 			final long hundredYears = 100L * 365 * 24 * 60 * 60 * 1000;
-			AbstractAlignedProcessingTimeWindowOperator<Integer, Integer, Integer, ?> op =
+			AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> op =
 					new AccumulatingProcessingTimeWindowOperator<>(
-							failingFunction, identitySelector, hundredYears, hundredYears);
+							failingFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+							hundredYears, hundredYears);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out);
 			op.open();
@@ -464,7 +501,9 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 			}
 			
 			try {
-				op.close();
+				synchronized (lock) {
+					op.close();
+				}
 				fail("This should fail with an exception");
 			}
 			catch (Exception e) {
@@ -484,12 +523,216 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 		}
 	}
 	
+	@Test
+	public void checkpointRestoreWithPendingWindowTumbling() {
+		final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor();
+		try {
+			final int windowSize = 200;
+			final CollectingOutput<Integer> out = new CollectingOutput<>(windowSize);
+			final Object lock = new Object();
+			final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
+
+			// tumbling window (size == slide == windowSize, i.e. 200 milliseconds)
+			AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> op =
+					new AccumulatingProcessingTimeWindowOperator<>(
+							validatingIdentityFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+							windowSize, windowSize);
+
+			op.setup(mockTask, new StreamConfig(new Configuration()), out);
+			op.open();
+
+			// inject the first numElementsFirst elements (0 .. numElementsFirst-1)
+			final int numElementsFirst = 700;
+			for (int i = 0; i < numElementsFirst; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
+				Thread.sleep(1);
+			}
+
+			// draw a snapshot under the lock and check that snapshotting itself emits no results
+			StreamTaskState state;
+			List<Integer> resultAtSnapshot;
+			synchronized (lock) {
+				int beforeSnapShot = out.getElements().size(); 
+				state = op.snapshotOperatorState(1L, System.currentTimeMillis());
+				resultAtSnapshot = new ArrayList<>(out.getElements());
+				int afterSnapShot = out.getElements().size();
+				assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
+			}
+
+			// inject elements numElementsFirst .. numElementsFirst+299 after the snapshot; they must not show up in the state (they are replayed after the restore)
+			for (int i = 0; i < 300; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i + numElementsFirst));
+				}
+				Thread.sleep(1);
+			}
+			
+			op.dispose();
+			
+			// re-create the operator and restore the snapshot state (restoreState is called before open)
+			final CollectingOutput<Integer> out2 = new CollectingOutput<>(windowSize);
+			op = new AccumulatingProcessingTimeWindowOperator<>(
+							validatingIdentityFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+							windowSize, windowSize);
+
+			op.setup(mockTask, new StreamConfig(new Configuration()), out2);
+			op.restoreState(state);
+			op.open();
+
+			// replay the post-snapshot elements numElementsFirst .. numElements-1 into the restored operator
+			final int numElements = 1000;
+			for (int i = numElementsFirst; i < numElements; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
+				Thread.sleep(1);
+			}
+
+			synchronized (lock) {
+				op.close();
+			}
+			op.dispose();
+
+			// verify: pre-snapshot output plus restored output must contain each element 0 .. numElements-1 exactly once
+			List<Integer> finalResult = new ArrayList<>(resultAtSnapshot);
+			finalResult.addAll(out2.getElements());
+			assertEquals(numElements, finalResult.size());
+
+			Collections.sort(finalResult);
+			for (int i = 0; i < numElements; i++) {
+				assertEquals(i, finalResult.get(i).intValue());
+			}
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+		finally {
+			timerService.shutdown();
+		}
+	}
+
+	@Test
+	public void checkpointRestoreWithPendingWindowSliding() {
+		final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor();
+		try {
+			final int factor = 4;
+			final int windowSlide = 50;
+			final int windowSize = factor * windowSlide;
+			
+			final CollectingOutput<Integer> out = new CollectingOutput<>(windowSlide);
+			final Object lock = new Object();
+			final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
+
+			// sliding window of windowSize (200 msecs), sliding every windowSlide (50 msecs)
+			AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> op =
+					new AccumulatingProcessingTimeWindowOperator<>(
+							validatingIdentityFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+							windowSize, windowSlide);
+
+			op.setup(mockTask, new StreamConfig(new Configuration()), out);
+			op.open();
+
+			// inject the first numElementsFirst elements (0 .. numElementsFirst-1)
+			final int numElements = 1000;
+			final int numElementsFirst = 700;
+			// the assertions below expect each element to be emitted exactly factor (= windowSize / windowSlide) times in total
+			for (int i = 0; i < numElementsFirst; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
+				Thread.sleep(1);
+			}
+
+			// draw a snapshot under the lock and check that snapshotting itself emits no results
+			StreamTaskState state;
+			List<Integer> resultAtSnapshot;
+			synchronized (lock) {
+				int beforeSnapShot = out.getElements().size();
+				state = op.snapshotOperatorState(1L, System.currentTimeMillis());
+				resultAtSnapshot = new ArrayList<>(out.getElements());
+				int afterSnapShot = out.getElements().size();
+				assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
+			}
+			// before the snapshot, at most factor results per injected element can have been emitted
+			assertTrue(resultAtSnapshot.size() <= factor * numElementsFirst);
+
+			// inject the remaining elements - these should not influence the snapshot
+			for (int i = numElementsFirst; i < numElements; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
+				Thread.sleep(1);
+			}
+			
+			op.dispose();
+			
+			// re-create the operator and restore the snapshot state (restoreState is called before open)
+			final CollectingOutput<Integer> out2 = new CollectingOutput<>(windowSlide);
+			op = new AccumulatingProcessingTimeWindowOperator<>(
+					validatingIdentityFunction, identitySelector,
+					IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+					windowSize, windowSlide);
+
+			op.setup(mockTask, new StreamConfig(new Configuration()), out2);
+			op.restoreState(state);
+			op.open();
+			
+
+			// replay the post-snapshot elements numElementsFirst .. numElements-1 into the restored operator
+			for (int i = numElementsFirst; i < numElements; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
+				Thread.sleep(1);
+			}
+
+			// for a deterministic result, we need to wait until all pending triggers
+			// have fired and emitted their results
+			long deadline = System.currentTimeMillis() + 120000;
+			do {
+				Thread.sleep(20);
+			}
+			while (resultAtSnapshot.size() + out2.getElements().size() < factor * numElements
+					&& System.currentTimeMillis() < deadline);
+
+			synchronized (lock) {
+				op.close();
+			}
+			op.dispose();
+
+			// verify: pre-snapshot output plus restored output must contain each element exactly factor times
+			List<Integer> finalResult = new ArrayList<>(resultAtSnapshot);
+			finalResult.addAll(out2.getElements());
+			assertEquals(factor * numElements, finalResult.size());
+
+			Collections.sort(finalResult);
+			for (int i = 0; i < factor * numElements; i++) {
+				assertEquals(i / factor, finalResult.get(i).intValue());
+			}
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+		finally {
+			timerService.shutdown();
+		}
+	}
+	
 	// ------------------------------------------------------------------------
 	
 	private void assertInvalidParameter(long windowSize, long windowSlide) {
 		try {
 			new AccumulatingProcessingTimeWindowOperator<String, String, String>(
-					mockFunction, mockKeySelector, windowSize, windowSlide);
+					mockFunction, mockKeySelector, 
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE,
+					windowSize, windowSlide);
 			fail("This should fail with an IllegalArgumentException");
 		}
 		catch (IllegalArgumentException e) {
@@ -541,6 +784,12 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
 
 		when(task.getEnvironment()).thenReturn(env);
 
+		// ugly java generic hacks to get the state backend into the mock
+		@SuppressWarnings("unchecked")
+		OngoingStubbing<StateBackend<?>> stubbing =
+				(OngoingStubbing<StateBackend<?>>) (OngoingStubbing<?>) when(task.getStateBackend());
+		stubbing.thenReturn(MemoryStateBackend.defaultInstance());
+		
 		return task;
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/c24dca50/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
index 106e833..ab8e551 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AggregatingAlignedProcessingTimeWindowOperatorTest.java
@@ -21,20 +21,27 @@ package org.apache.flink.streaming.runtime.operators.windowing;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.accumulators.Accumulator;
 import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.state.StateBackend;
+import org.apache.flink.streaming.api.state.memory.MemoryStateBackend;
 import org.apache.flink.streaming.runtime.operators.Triggerable;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 
+import org.apache.flink.streaming.runtime.tasks.StreamTaskState;
 import org.junit.After;
 import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
+import org.mockito.stubbing.OngoingStubbing;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
@@ -118,25 +125,29 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 		try {
 			AggregatingProcessingTimeWindowOperator<String, String> op;
 			
-			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 5000, 1000);
+			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 5000, 1000);
 			assertEquals(5000, op.getWindowSize());
 			assertEquals(1000, op.getWindowSlide());
 			assertEquals(1000, op.getPaneSize());
 			assertEquals(5, op.getNumPanesPerWindow());
 
-			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1000, 1000);
+			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1000, 1000);
 			assertEquals(1000, op.getWindowSize());
 			assertEquals(1000, op.getWindowSlide());
 			assertEquals(1000, op.getPaneSize());
 			assertEquals(1, op.getNumPanesPerWindow());
 
-			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1500, 1000);
+			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1500, 1000);
 			assertEquals(1500, op.getWindowSize());
 			assertEquals(1000, op.getWindowSlide());
 			assertEquals(500, op.getPaneSize());
 			assertEquals(3, op.getNumPanesPerWindow());
 
-			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1200, 1100);
+			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1200, 1100);
 			assertEquals(1200, op.getWindowSize());
 			assertEquals(1100, op.getWindowSlide());
 			assertEquals(100, op.getPaneSize());
@@ -157,28 +168,32 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			
 			AggregatingProcessingTimeWindowOperator<String, String> op;
 
-			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 5000, 1000);
+			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 5000, 1000);
 			op.setup(mockTask, new StreamConfig(new Configuration()), mockOut);
 			op.open();
 			assertTrue(op.getNextSlideTime() % 1000 == 0);
 			assertTrue(op.getNextEvaluationTime() % 1000 == 0);
 			op.dispose();
 
-			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1000, 1000);
+			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1000, 1000);
 			op.setup(mockTask, new StreamConfig(new Configuration()), mockOut);
 			op.open();
 			assertTrue(op.getNextSlideTime() % 1000 == 0);
 			assertTrue(op.getNextEvaluationTime() % 1000 == 0);
 			op.dispose();
 
-			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1500, 1000);
+			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1500, 1000);
 			op.setup(mockTask, new StreamConfig(new Configuration()), mockOut);
 			op.open();
 			assertTrue(op.getNextSlideTime() % 500 == 0);
 			assertTrue(op.getNextEvaluationTime() % 1000 == 0);
 			op.dispose();
 
-			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector, 1200, 1100);
+			op = new AggregatingProcessingTimeWindowOperator<>(mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE, 1200, 1100);
 			op.setup(mockTask, new StreamConfig(new Configuration()), mockOut);
 			op.open();
 			assertTrue(op.getNextSlideTime() % 100 == 0);
@@ -200,7 +215,9 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			
 			AggregatingProcessingTimeWindowOperator<Integer, Integer> op =
 					new AggregatingProcessingTimeWindowOperator<>(
-							sumFunction, identitySelector, windowSize, windowSize);
+							sumFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+							windowSize, windowSize);
 			
 			final Object lock = new Object();
 			final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
@@ -211,11 +228,15 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			final int numElements = 1000;
 
 			for (int i = 0; i < numElements; i++) {
-				op.processElement(new StreamRecord<Integer>(i));
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
 				Thread.sleep(1);
 			}
 
-			op.close();
+			synchronized (lock) {
+				op.close();
+			}
 			op.dispose();
 
 			// get and verify the result
@@ -238,7 +259,6 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 
 	@Test
 	public void  testTumblingWindowDuplicateElements() {
-
 		final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor();
 
 		try {
@@ -250,7 +270,9 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			
 			AggregatingProcessingTimeWindowOperator<Integer, Integer> op =
 					new AggregatingProcessingTimeWindowOperator<>(
-							sumFunction, identitySelector, windowSize, windowSize);
+							sumFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+							windowSize, windowSize);
 			
 			op.setup(mockTask, new StreamConfig(new Configuration()), out);
 			op.open();
@@ -261,22 +283,23 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			int window = 1;
 			
 			while (window <= numWindows) {
-				long nextTime = op.getNextEvaluationTime();
-				int val = ((int) nextTime) ^ ((int) (nextTime >>> 32));
-
 				synchronized (lock) {
+					long nextTime = op.getNextEvaluationTime();
+					int val = ((int) nextTime) ^ ((int) (nextTime >>> 32));
+					
 					op.processElement(new StreamRecord<Integer>(val));
+
+					if (nextTime != previousNextTime) {
+						window++;
+						previousNextTime = nextTime;
+					}
 				}
-				
-				if (nextTime != previousNextTime) {
-					window++;
-					previousNextTime = nextTime;
-				}
-				
 				Thread.sleep(1);
 			}
 
-			op.close();
+			synchronized (lock) {
+				op.close();
+			}
 			op.dispose();
 			
 			List<Integer> result = out.getElements();
@@ -287,12 +310,13 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 
 			// deduplicate for more accurate checks
 			HashSet<Integer> set = new HashSet<>(result);
-			assertTrue(set.size() == 10 || set.size() == 11);
+			assertTrue(set.size() == 10);
 		}
 		catch (Exception e) {
 			e.printStackTrace();
 			fail(e.getMessage());
-		} finally {
+		}
+		finally {
 			timerService.shutdown();
 		}
 	}
@@ -308,7 +332,10 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 
 			// tumbling window that triggers every 20 milliseconds
 			AggregatingProcessingTimeWindowOperator<Integer, Integer> op =
-					new AggregatingProcessingTimeWindowOperator<>(sumFunction, identitySelector, 150, 50);
+					new AggregatingProcessingTimeWindowOperator<>(
+							sumFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+							150, 50);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out);
 			op.open();
@@ -322,7 +349,9 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 				Thread.sleep(1);
 			}
 
-			op.close();
+			synchronized (lock) {
+				op.close();
+			}
 			op.dispose();
 
 			// get and verify the result
@@ -369,7 +398,9 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 
 			// tumbling window that triggers every 20 milliseconds
 			AggregatingProcessingTimeWindowOperator<Integer, Integer> op =
-					new AggregatingProcessingTimeWindowOperator<>(sumFunction, identitySelector, 150, 50);
+					new AggregatingProcessingTimeWindowOperator<>(
+							sumFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE, 150, 50);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out);
 			op.open();
@@ -388,8 +419,10 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			
 			Collections.sort(result);
 			assertEquals(Arrays.asList(1, 1, 1, 2, 2, 2), result);
-			
-			op.close();
+
+			synchronized (lock) {
+				op.close();
+			}
 			op.dispose();
 		}
 		catch (Exception e) {
@@ -412,7 +445,8 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			// the operator has a window time that is so long that it will not fire in this test
 			final long oneYear = 365L * 24 * 60 * 60 * 1000;
 			AggregatingProcessingTimeWindowOperator<Integer, Integer> op = 
-					new AggregatingProcessingTimeWindowOperator<>(sumFunction, identitySelector, oneYear, oneYear);
+					new AggregatingProcessingTimeWindowOperator<>(sumFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE, oneYear, oneYear); // NOTE(review): serializers presumably for key and value (both Integer here) — confirm order
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out);
 			op.open();
@@ -423,8 +457,10 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 					op.processElement(new StreamRecord<Integer>(i));
 				}
 			}
-			
-			op.close();
+
+			synchronized (lock) { // close() under the same lock that guards processElement() just above
+				op.close();
+			}
 			op.dispose();
 			
 			// get and verify the result
@@ -455,7 +491,9 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			final long hundredYears = 100L * 365 * 24 * 60 * 60 * 1000;
 			AggregatingProcessingTimeWindowOperator<Integer, Integer> op =
 					new AggregatingProcessingTimeWindowOperator<>(
-							failingFunction, identitySelector, hundredYears, hundredYears);
+							failingFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE, // NOTE(review): presumably key and value serializers — confirm order
+							hundredYears, hundredYears);
 
 			op.setup(mockTask, new StreamConfig(new Configuration()), out);
 			op.open();
@@ -484,13 +522,220 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 			timerService.shutdown();
 		}
 	}
+
+	@Test
+	public void checkpointRestoreWithPendingWindowTumbling() {
+		final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor();
+		try {
+			final int windowSize = 200;
+			final CollectingOutput<Integer> out = new CollectingOutput<>(windowSize);
+			final Object lock = new Object();
+			final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
+
+			// tumbling window that triggers every 200 milliseconds
+			AggregatingProcessingTimeWindowOperator<Integer, Integer> op =
+					new AggregatingProcessingTimeWindowOperator<>(
+							sumFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+							windowSize, windowSize);
+
+			op.setup(mockTask, new StreamConfig(new Configuration()), out);
+			op.open();
+
+			// inject some elements
+			final int numElementsFirst = 700;
+			final int numElements = 1000;
+			
+			for (int i = 0; i < numElementsFirst; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
+				Thread.sleep(1);
+			}
+
+			// draw a snapshot (the operator keeps running afterwards)
+			StreamTaskState state;
+			List<Integer> resultAtSnapshot;
+			synchronized (lock) {
+				int beforeSnapShot = out.getElements().size();
+				state = op.snapshotOperatorState(1L, System.currentTimeMillis());
+				resultAtSnapshot = new ArrayList<>(out.getElements());
+				int afterSnapShot = out.getElements().size();
+				assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
+			}
+			
+			assertTrue(resultAtSnapshot.size() <= numElementsFirst);
+
+			// inject the remaining elements, which must not show up in the snapshot
+			for (int i = numElementsFirst; i < numElements; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
+				Thread.sleep(1);
+			}
+
+			op.dispose();
+
+			// re-create the operator and restore the state
+			final CollectingOutput<Integer> out2 = new CollectingOutput<>(windowSize);
+			op = new AggregatingProcessingTimeWindowOperator<>(
+					sumFunction, identitySelector,
+					IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+					windowSize, windowSize);
+
+			op.setup(mockTask, new StreamConfig(new Configuration()), out2);
+			op.restoreState(state);
+			op.open();
+
+			// inject the remaining elements
+			for (int i = numElementsFirst; i < numElements; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
+				Thread.sleep(1);
+			}
+
+			synchronized (lock) {
+				op.close();
+			}
+			op.dispose();
+
+			// get and verify the result
+			List<Integer> finalResult = new ArrayList<>(resultAtSnapshot);
+			finalResult.addAll(out2.getElements());
+			assertEquals(numElements, finalResult.size());
+
+			Collections.sort(finalResult);
+			for (int i = 0; i < numElements; i++) {
+				assertEquals(i, finalResult.get(i).intValue());
+			}
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+		finally {
+			timerService.shutdown();
+		}
+	}
+
+	@Test
+	public void checkpointRestoreWithPendingWindowSliding() {
+		final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor();
+		try {
+			final int factor = 4;
+			final int windowSlide = 50;
+			final int windowSize = factor * windowSlide;
+
+			final CollectingOutput<Integer> out = new CollectingOutput<>(windowSlide);
+			final Object lock = new Object();
+			final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
+
+			// sliding window (200 msecs) every 50 msecs
+			AggregatingProcessingTimeWindowOperator<Integer, Integer> op =
+					new AggregatingProcessingTimeWindowOperator<>(
+							sumFunction, identitySelector,
+							IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+							windowSize, windowSlide);
+
+			op.setup(mockTask, new StreamConfig(new Configuration()), out);
+			op.open();
+
+			// inject some elements
+			final int numElements = 1000;
+			final int numElementsFirst = 700;
+
+			for (int i = 0; i < numElementsFirst; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
+				Thread.sleep(1);
+			}
+
+			// draw a snapshot
+			StreamTaskState state;
+			List<Integer> resultAtSnapshot;
+			synchronized (lock) {
+				int beforeSnapShot = out.getElements().size();
+				state = op.snapshotOperatorState(1L, System.currentTimeMillis());
+				resultAtSnapshot = new ArrayList<>(out.getElements());
+				int afterSnapShot = out.getElements().size();
+				assertEquals("operator performed computation during snapshot", beforeSnapShot, afterSnapShot);
+			}
+
+			assertTrue(resultAtSnapshot.size() <= factor * numElementsFirst);
+
+			// inject the remaining elements - these should not influence the snapshot
+			for (int i = numElementsFirst; i < numElements; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
+				Thread.sleep(1);
+			}
+
+			op.dispose();
+
+			// re-create the operator and restore the state
+			final CollectingOutput<Integer> out2 = new CollectingOutput<>(windowSlide);
+			op = new AggregatingProcessingTimeWindowOperator<>(
+					sumFunction, identitySelector,
+					IntSerializer.INSTANCE, IntSerializer.INSTANCE,
+					windowSize, windowSlide);
+
+			op.setup(mockTask, new StreamConfig(new Configuration()), out2);
+			op.restoreState(state);
+			op.open();
+
+
+			// re-inject the remaining elements into the restored operator
+			for (int i = numElementsFirst; i < numElements; i++) {
+				synchronized (lock) {
+					op.processElement(new StreamRecord<Integer>(i));
+				}
+				Thread.sleep(1);
+			}
+
+			// for a deterministic result, we need to wait until all pending triggers
+			// have fired and emitted their results
+			long deadline = System.currentTimeMillis() + 120000;
+			do {
+				Thread.sleep(20);
+			}
+			while (resultAtSnapshot.size() + out2.getElements().size() < factor * numElements
+					&& System.currentTimeMillis() < deadline);
+
+			synchronized (lock) {
+				op.close();
+			}
+			op.dispose();
+
+			// get and verify the result
+			List<Integer> finalResult = new ArrayList<>(resultAtSnapshot);
+			finalResult.addAll(out2.getElements());
+			assertEquals(factor * numElements, finalResult.size());
+
+			Collections.sort(finalResult);
+			for (int i = 0; i < factor * numElements; i++) {
+				assertEquals(i / factor, finalResult.get(i).intValue());
+			}
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+		finally {
+			timerService.shutdown();
+		}
+	}
 	
 	// ------------------------------------------------------------------------
 	
 	private void assertInvalidParameter(long windowSize, long windowSlide) {
 		try {
 			new AggregatingProcessingTimeWindowOperator<String, String>(
-					mockFunction, mockKeySelector, windowSize, windowSlide);
+					mockFunction, mockKeySelector,
+					StringSerializer.INSTANCE, StringSerializer.INSTANCE,
+					windowSize, windowSlide);
 			fail("This should fail with an IllegalArgumentException");
 		}
 		catch (IllegalArgumentException e) {
@@ -537,6 +782,12 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
 		when(env.getUserClassLoader()).thenReturn(AggregatingAlignedProcessingTimeWindowOperatorTest.class.getClassLoader());
 		
 		when(task.getEnvironment()).thenReturn(env);
+
+		// Java generics workaround: the stubbing is cast via OngoingStubbing<?> to OngoingStubbing<StateBackend<?>> so that thenReturn() accepts a concrete MemoryStateBackend (getStateBackend() presumably returns a wildcard type)
+		@SuppressWarnings("unchecked")
+		OngoingStubbing<StateBackend<?>> stubbing =
+				(OngoingStubbing<StateBackend<?>>) (OngoingStubbing<?>) when(task.getStateBackend());
+		stubbing.thenReturn(MemoryStateBackend.defaultInstance());
 		
 		return task;
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/c24dca50/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/MockContext.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/MockContext.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/MockContext.java
index 000a1a2..81d3a69 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/MockContext.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/MockContext.java
@@ -28,17 +28,22 @@ import java.util.concurrent.TimeUnit;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.accumulators.Accumulator;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.state.StateBackend;
+import org.apache.flink.streaming.api.state.memory.MemoryStateBackend;
 import org.apache.flink.streaming.runtime.operators.Triggerable;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
+import org.mockito.stubbing.OngoingStubbing;
 
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
@@ -48,13 +53,11 @@ import static org.mockito.Mockito.when;
 
 public class MockContext<IN, OUT> {
 	
-	private Collection<IN> inputs;
 	private List<OUT> outputs;
 
 	private MockOutput<OUT> output;
 
 	public MockContext(Collection<IN> inputs) {
-		this.inputs = inputs;
 		if (inputs.isEmpty()) {
-			throw new RuntimeException("Inputs must not be empty");
+			throw new IllegalArgumentException("Inputs must not be empty");
 		}
@@ -72,20 +75,35 @@ public class MockContext<IN, OUT> {
 	}
 
 	public static <IN, OUT> List<OUT> createAndExecute(OneInputStreamOperator<IN, OUT> operator, List<IN> inputs) {
+		return createAndExecuteForKeyedStream(operator, inputs, null, null); // non-keyed: no key selector, no key type
+	}
+	
+	public static <IN, OUT, KEY> List<OUT> createAndExecuteForKeyedStream( // like createAndExecute, but writes keyed-state config when a key selector and key type are given
+				OneInputStreamOperator<IN, OUT> operator, List<IN> inputs,
+				KeySelector<IN, KEY> keySelector, TypeInformation<KEY> keyType) {
+		
 		MockContext<IN, OUT> mockContext = new MockContext<IN, OUT>(inputs);
 
+		StreamConfig config = new StreamConfig(new Configuration());
+		if (keySelector != null && keyType != null) { // NOTE(review): if only one of the two is non-null, keyed config is silently skipped
+			config.setStateKeySerializer(keyType.createSerializer(new ExecutionConfig()));
+			config.setStatePartitioner(keySelector);
+		}
+		
 		final ScheduledExecutorService timerService = Executors.newSingleThreadScheduledExecutor();
 		final Object lock = new Object();
 		final StreamTask<?, ?> mockTask = createMockTaskWithTimer(timerService, lock);
 				
-		operator.setup(mockTask, new StreamConfig(new Configuration()), mockContext.output);
+		operator.setup(mockTask, config, mockContext.output);
 		try {
 			operator.open();
 
-			StreamRecord<IN> nextRecord;
+			StreamRecord<IN> record = new StreamRecord<IN>(null); // one record object, re-filled per input via replace() (presumably in place — operators should not retain it)
 			for (IN in: inputs) {
+				record = record.replace(in);
 				synchronized (lock) {
-					operator.processElement(new StreamRecord<IN>(in));
+					operator.setKeyContextElement(record); // set the key context before every processElement() call
+					operator.processElement(record);
 				}
 			}
 
@@ -130,6 +148,12 @@ public class MockContext<IN, OUT> {
 			}
 		}).when(task).registerTimer(anyLong(), any(Triggerable.class));
 
+		// Java generics workaround: the stubbing is cast via OngoingStubbing<?> to OngoingStubbing<StateBackend<?>> so that thenReturn() accepts a concrete MemoryStateBackend (getStateBackend() presumably returns a wildcard type)
+		@SuppressWarnings("unchecked")
+		OngoingStubbing<StateBackend<?>> stubbing =
+				(OngoingStubbing<StateBackend<?>>) (OngoingStubbing<?>) when(task.getStateBackend());
+		stubbing.thenReturn(MemoryStateBackend.defaultInstance());
+		
 		return task;
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/c24dca50/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
index edf3a09..b83feca 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
@@ -18,7 +18,10 @@
 package org.apache.flink.streaming.util;
 
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.ClosureCleaner;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.execution.Environment;
@@ -54,17 +57,17 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 
 	final ConcurrentLinkedQueue<Object> outputList;
 
+	final StreamConfig config; // passed to operator.setup() in the constructor; configureForKeyedStream() writes key settings into it
+	
 	final ExecutionConfig executionConfig;
 	
 	final Object checkpointLock;
-
-	public OneInputStreamOperatorTestHarness(OneInputStreamOperator<IN, OUT> operator) {
-		this(operator, new StreamConfig(new Configuration()));
-	}
 	
-	public OneInputStreamOperatorTestHarness(OneInputStreamOperator<IN, OUT> operator, StreamConfig config) {
+	/** Creates a harness around {@code operator} with a fresh, empty {@link StreamConfig}. */
+	public OneInputStreamOperatorTestHarness(OneInputStreamOperator<IN, OUT> operator) {
 		this.operator = operator;
 		this.outputList = new ConcurrentLinkedQueue<Object>();
+		this.config = new StreamConfig(new Configuration());
 		this.executionConfig = new ExecutionConfig();
 		this.checkpointLock = new Object();
 
@@ -82,9 +85,15 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 				(OngoingStubbing<StateBackend<?>>) (OngoingStubbing<?>) when(mockTask.getStateBackend());
 		stubbing.thenReturn(MemoryStateBackend.defaultInstance());
 
-		operator.setup(mockTask, new StreamConfig(new Configuration()), new MockOutput());
+		operator.setup(mockTask, config, new MockOutput()); // the shared config object, so later configureForKeyedStream() writes reach the operator (presumably read at open() — confirm)
 	}
 
+	public <K> void configureForKeyedStream(KeySelector<IN, K> keySelector, TypeInformation<K> keyType) { // NOTE(review): setup() already ran in the constructor; presumably must be called before open()
+		ClosureCleaner.clean(keySelector, false); // clean the selector before storing it in the config; second arg presumably disables the serializability check
+		config.setStatePartitioner(keySelector);
+		config.setStateKeySerializer(keyType.createSerializer(executionConfig));
+	}
+	
 	/**
 	 * Get all the output from the task. This contains StreamRecords and Events interleaved. Use
 	 * {@link org.apache.flink.streaming.util.TestHarnessUtil#getStreamRecordsFromOutput(java.util.List)}
@@ -109,11 +118,13 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 	}
 
 	public void processElement(StreamRecord<IN> element) throws Exception {
+		operator.setKeyContextElement(element); // set the key context before processElement(), as MockContext does
 		operator.processElement(element);
 	}
 
 	public void processElements(Collection<StreamRecord<IN>> elements) throws Exception {
 		for (StreamRecord<IN> element: elements) {
+			operator.setKeyContextElement(element); // same per-element key context as processElement()
 			operator.processElement(element);
 		}
 	}
@@ -127,13 +138,11 @@ public class OneInputStreamOperatorTestHarness<IN, OUT> {
 		private TypeSerializer<OUT> outputSerializer;
 
 		@Override
-		@SuppressWarnings("unchecked")
 		public void emitWatermark(Watermark mark) {
 			outputList.add(mark);
 		}
 
 		@Override
-		@SuppressWarnings("unchecked")
 		public void collect(StreamRecord<OUT> element) {
 			if (outputSerializer == null) {
 				outputSerializer = TypeExtractor.getForObject(element.getValue()).createSerializer(executionConfig);