You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by kk...@apache.org on 2017/02/28 18:06:03 UTC

flink git commit: [FLINK-5420] [cep] Make the CEP operators rescalable

Repository: flink
Updated Branches:
  refs/heads/master 788b83921 -> daf0ccda4


[FLINK-5420] [cep] Make the CEP operators rescalable

Introduces the KeyRegistry (InternalWatermarkCallbackService) in the
TimeServiceHandler, which allows specifying a callback and registering
keys for which this callback should be invoked on each
watermark.

With this service, the CEP operator now has only
keyed state; the former non-keyed state (the set of
seen keys) is handled by the KeyRegistry.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/daf0ccda
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/daf0ccda
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/daf0ccda

Branch: refs/heads/master
Commit: daf0ccda4dc60a267be7b8074d40e48d22ccb13f
Parents: 788b839
Author: kl0u <kk...@gmail.com>
Authored: Fri Feb 24 10:34:43 2017 +0100
Committer: kl0u <kk...@gmail.com>
Committed: Tue Feb 28 19:04:27 2017 +0100

----------------------------------------------------------------------
 .../AbstractKeyedCEPPatternOperator.java        | 132 +++---
 .../flink/cep/operator/CEPOperatorUtils.java    |   1 -
 .../flink/cep/operator/CEPOperatorTest.java     |  26 +-
 .../flink/cep/operator/CEPRescalingTest.java    | 417 +++++++++++++++++++
 .../api/operators/AbstractStreamOperator.java   | 133 +++---
 .../operators/InternalTimeServiceManager.java   | 191 +++++++++
 .../InternalWatermarkCallbackService.java       | 239 +++++++++++
 .../api/operators/OnWatermarkCallback.java      |  41 ++
 .../operators/AbstractStreamOperatorTest.java   | 234 ++++++++++-
 9 files changed, 1245 insertions(+), 169 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/daf0ccda/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java
index 90ee846..f534bec 100644
--- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java
+++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/AbstractKeyedCEPPatternOperator.java
@@ -24,14 +24,12 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.cep.nfa.NFA;
 import org.apache.flink.cep.nfa.compiler.NFACompiler;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.fs.FSDataOutputStream;
 import org.apache.flink.core.memory.DataInputView;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.streaming.api.operators.StreamCheckpointedOperator;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.operators.InternalWatermarkCallbackService;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OnWatermarkCallback;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
@@ -40,10 +38,8 @@ import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
 import java.io.Serializable;
-import java.util.HashSet;
 import java.util.Objects;
 import java.util.PriorityQueue;
-import java.util.Set;
 
 /**
  * Abstract CEP pattern operator for a keyed input stream. For each key, the operator creates
@@ -58,7 +54,7 @@ import java.util.Set;
  */
 public abstract class AbstractKeyedCEPPatternOperator<IN, KEY, OUT>
 	extends AbstractStreamOperator<OUT>
-	implements OneInputStreamOperator<IN, OUT>, StreamCheckpointedOperator {
+	implements OneInputStreamOperator<IN, OUT> {
 
 	private static final long serialVersionUID = -4166778210774160757L;
 
@@ -76,11 +72,6 @@ public abstract class AbstractKeyedCEPPatternOperator<IN, KEY, OUT>
 
 	///////////////			State			//////////////
 
-	// stores the keys we've already seen to trigger execution upon receiving a watermark
-	// this can be problematic, since it is never cleared
-	// TODO: fix once the state refactoring is completed
-	private transient Set<KEY> keys;
-
 	private static final String NFA_OPERATOR_STATE_NAME = "nfaOperatorState";
 	private static final String PRIORITY_QUEUE_STATE_NAME = "priorityQueueStateName";
 
@@ -109,25 +100,31 @@ public abstract class AbstractKeyedCEPPatternOperator<IN, KEY, OUT>
 	}
 
 	@Override
-	@SuppressWarnings("unchecked")
-	public void open() throws Exception {
-		super.open();
+	public void initializeState(StateInitializationContext context) throws Exception {
+		super.initializeState(context);
 
-		if (keys == null) {
-			keys = new HashSet<>();
-		}
+		// we have to call initializeState here and in the migration restore()
+		// method because the restore() (from legacy) is called before the
+		// initializeState().
+
+		initializeState();
+	}
+
+	private void initializeState() {
 
 		if (nfaOperatorState == null) {
-			nfaOperatorState = getPartitionedState(
-				new ValueStateDescriptor<>(NFA_OPERATOR_STATE_NAME, new NFA.Serializer<IN>()));
+			nfaOperatorState = getRuntimeContext().getState(
+				new ValueStateDescriptor<>(
+					NFA_OPERATOR_STATE_NAME,
+					new NFA.Serializer<IN>()));
 		}
 
 		@SuppressWarnings("unchecked,rawtypes")
 		TypeSerializer<StreamRecord<IN>> streamRecordSerializer =
-			(TypeSerializer) new StreamElementSerializer<>(getInputSerializer());
+			(TypeSerializer) new StreamElementSerializer<>(inputSerializer);
 
 		if (priorityQueueOperatorState == null) {
-			priorityQueueOperatorState = getPartitionedState(
+			priorityQueueOperatorState = getRuntimeContext().getState(
 				new ValueStateDescriptor<>(
 					PRIORITY_QUEUE_STATE_NAME,
 					new PriorityQueueSerializer<>(
@@ -139,6 +136,39 @@ public abstract class AbstractKeyedCEPPatternOperator<IN, KEY, OUT>
 		}
 	}
 
+	@Override
+	public void open() throws Exception {
+		super.open();
+
+		InternalWatermarkCallbackService<KEY> watermarkCallbackService = getInternalWatermarkCallbackService();
+
+		watermarkCallbackService.setWatermarkCallback(
+			new OnWatermarkCallback<KEY>() {
+
+				@Override
+				public void onWatermark(KEY key, Watermark watermark) throws IOException {
+					setCurrentKey(key);
+
+					PriorityQueue<StreamRecord<IN>> priorityQueue = getPriorityQueue();
+					NFA<IN> nfa = getNFA();
+
+					if (priorityQueue.isEmpty()) {
+						advanceTime(nfa, watermark.getTimestamp());
+					} else {
+						while (!priorityQueue.isEmpty() && priorityQueue.peek().getTimestamp() <= watermark.getTimestamp()) {
+							StreamRecord<IN> streamRecord = priorityQueue.poll();
+							processEvent(nfa, streamRecord.getValue(), streamRecord.getTimestamp());
+						}
+					}
+
+					updateNFA(nfa);
+					updatePriorityQueue(priorityQueue);
+				}
+			},
+			keySerializer
+		);
+	}
+
 	private NFA<IN> getNFA() throws IOException {
 		NFA<IN> nfa = nfaOperatorState.value();
 
@@ -173,7 +203,7 @@ public abstract class AbstractKeyedCEPPatternOperator<IN, KEY, OUT>
 
 	@Override
 	public void processElement(StreamRecord<IN> element) throws Exception {
-		keys.add(keySelector.getKey(element.getValue()));
+		getInternalWatermarkCallbackService().registerKeyForWatermarkCallback(keySelector.getKey(element.getValue()));
 
 		if (isProcessingTime) {
 			// there can be no out of order elements in processing time
@@ -195,35 +225,6 @@ public abstract class AbstractKeyedCEPPatternOperator<IN, KEY, OUT>
 		}
 	}
 
-	@Override
-	public void processWatermark(Watermark mark) throws Exception {
-		// we do our own watermark handling, no super call. we will never be able to use
-		// the timer service like this, however.
-
-		// iterate over all keys to trigger the execution of the buffered elements
-		for (KEY key: keys) {
-			setCurrentKey(key);
-
-			PriorityQueue<StreamRecord<IN>> priorityQueue = getPriorityQueue();
-			NFA<IN> nfa = getNFA();
-
-			if (priorityQueue.isEmpty()) {
-				advanceTime(nfa, mark.getTimestamp());
-			} else {
-				while (!priorityQueue.isEmpty() && priorityQueue.peek().getTimestamp() <= mark.getTimestamp()) {
-					StreamRecord<IN> streamRecord = priorityQueue.poll();
-
-					processEvent(nfa, streamRecord.getValue(), streamRecord.getTimestamp());
-				}
-			}
-
-			updateNFA(nfa);
-			updatePriorityQueue(priorityQueue);
-		}
-
-		output.emitWatermark(mark);
-	}
-
 	/**
 	 * Process the given event by giving it to the NFA and outputting the produced set of matched
 	 * event sequences.
@@ -243,31 +244,6 @@ public abstract class AbstractKeyedCEPPatternOperator<IN, KEY, OUT>
 	 */
 	protected abstract void advanceTime(NFA<IN> nfa, long timestamp);
 
-	@Override
-	public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception {
-		DataOutputView ov = new DataOutputViewStreamWrapper(out);
-		ov.writeInt(keys.size());
-
-		for (KEY key: keys) {
-			keySerializer.serialize(key, ov);
-		}
-	}
-
-	@Override
-	public void restoreState(FSDataInputStream state) throws Exception {
-		DataInputView inputView = new DataInputViewStreamWrapper(state);
-
-		if (keys == null) {
-			keys = new HashSet<>();
-		}
-
-		int numberEntries = inputView.readInt();
-
-		for (int i = 0; i <numberEntries; i++) {
-			keys.add(keySerializer.deserialize(inputView));
-		}
-	}
-
 	//////////////////////			Utility Classes			//////////////////////
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/daf0ccda/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/CEPOperatorUtils.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/CEPOperatorUtils.java b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/CEPOperatorUtils.java
index 36f2e7a..56ecb17 100644
--- a/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/CEPOperatorUtils.java
+++ b/flink-libraries/flink-cep/src/main/java/org/apache/flink/cep/operator/CEPOperatorUtils.java
@@ -137,7 +137,6 @@ public class CEPOperatorUtils {
 		} else {
 
 			KeySelector<T, Byte> keySelector = new NullByteKeySelector<>();
-
 			TypeSerializer<Byte> keySerializer = ByteSerializer.INSTANCE;
 
 			patternStream = inputStream.keyBy(keySelector).transform(

http://git-wip-us.apache.org/repos/asf/flink/blob/daf0ccda/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
index f90b670..1899cb4 100644
--- a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
+++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPOperatorTest.java
@@ -31,11 +31,11 @@ import org.apache.flink.cep.nfa.NFA;
 import org.apache.flink.cep.nfa.compiler.NFACompiler;
 import org.apache.flink.cep.pattern.Pattern;
 import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
-import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.api.windowing.time.Time;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.types.Either;
@@ -123,7 +123,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processElement(new StreamRecord<Event>(new Event(42, "foobar", 1.0), 2));
 
 		// simulate snapshot/restore with some elements in internal sorting queue
-		StreamStateHandle snapshot = harness.snapshotLegacy(0, 0);
+		OperatorStateHandles snapshot = harness.snapshot(0, 0);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(
@@ -137,7 +137,7 @@ public class CEPOperatorTest extends TestLogger {
 				BasicTypeInfo.INT_TYPE_INFO);
 
 		harness.setup();
-		harness.restore(snapshot);
+		harness.initializeState(snapshot);
 		harness.open();
 
 		harness.processWatermark(new Watermark(Long.MIN_VALUE));
@@ -148,8 +148,12 @@ public class CEPOperatorTest extends TestLogger {
 		// a pruning time underflow exception in NFA
 		harness.processWatermark(new Watermark(2));
 
+		harness.processElement(new StreamRecord<Event>(middleEvent, 3));
+		harness.processElement(new StreamRecord<Event>(new Event(42, "start", 1.0), 4));
+		harness.processElement(new StreamRecord<Event>(endEvent, 5));
+
 		// simulate snapshot/restore with empty element queue but NFA state
-		StreamStateHandle snapshot2 = harness.snapshotLegacy(1, 1);
+		OperatorStateHandles snapshot2 = harness.snapshot(1, 1);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(
@@ -163,13 +167,9 @@ public class CEPOperatorTest extends TestLogger {
 				BasicTypeInfo.INT_TYPE_INFO);
 
 		harness.setup();
-		harness.restore(snapshot2);
+		harness.initializeState(snapshot2);
 		harness.open();
 
-		harness.processElement(new StreamRecord<Event>(middleEvent, 3));
-		harness.processElement(new StreamRecord<Event>(new Event(42, "start", 1.0), 4));
-		harness.processElement(new StreamRecord<Event>(endEvent, 5));
-
 		harness.processWatermark(new Watermark(Long.MAX_VALUE));
 
 		ConcurrentLinkedQueue<Object> result = harness.getOutput();
@@ -230,7 +230,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processElement(new StreamRecord<Event>(new Event(42, "foobar", 1.0), 2));
 
 		// simulate snapshot/restore with some elements in internal sorting queue
-		StreamStateHandle snapshot = harness.snapshotLegacy(0, 0);
+		OperatorStateHandles snapshot = harness.snapshot(0, 0);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(
@@ -248,7 +248,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.setStateBackend(rocksDBStateBackend);
 
 		harness.setup();
-		harness.restore(snapshot);
+		harness.initializeState(snapshot);
 		harness.open();
 
 		harness.processWatermark(new Watermark(Long.MIN_VALUE));
@@ -260,7 +260,7 @@ public class CEPOperatorTest extends TestLogger {
 		harness.processWatermark(new Watermark(2));
 
 		// simulate snapshot/restore with empty element queue but NFA state
-		StreamStateHandle snapshot2 = harness.snapshotLegacy(1, 1);
+		OperatorStateHandles snapshot2 = harness.snapshot(1, 1);
 		harness.close();
 
 		harness = new KeyedOneInputStreamOperatorTestHarness<>(
@@ -277,7 +277,7 @@ public class CEPOperatorTest extends TestLogger {
 		rocksDBStateBackend.setDbStoragePath(rocksDbPath);
 		harness.setStateBackend(rocksDBStateBackend);
 		harness.setup();
-		harness.restore(snapshot2);
+		harness.initializeState(snapshot2);
 		harness.open();
 
 		harness.processElement(new StreamRecord<Event>(middleEvent, 3));

http://git-wip-us.apache.org/repos/asf/flink/blob/daf0ccda/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPRescalingTest.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPRescalingTest.java b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPRescalingTest.java
new file mode 100644
index 0000000..78765c0
--- /dev/null
+++ b/flink-libraries/flink-cep/src/test/java/org/apache/flink/cep/operator/CEPRescalingTest.java
@@ -0,0 +1,417 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.cep.operator;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.cep.Event;
+import org.apache.flink.cep.SubEvent;
+import org.apache.flink.cep.nfa.NFA;
+import org.apache.flink.cep.nfa.compiler.NFACompiler;
+import org.apache.flink.cep.pattern.Pattern;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.time.Time;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.junit.Test;
+
+import java.util.Map;
+import java.util.Queue;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class CEPRescalingTest {
+
+	@Test
+	public void testCEPFunctionScalingUp() throws Exception {
+		int maxParallelism = 10;
+
+		KeySelector<Event, Integer> keySelector = new KeySelector<Event, Integer>() {
+			private static final long serialVersionUID = -4873366487571254798L;
+
+			@Override
+			public Integer getKey(Event value) throws Exception {
+				return value.getId();
+			}
+		};
+
+		// valid pattern events belong to different keygroups
+		// that will be shipped to different tasks when changing parallelism.
+
+		Event startEvent1 = new Event(7, "start", 1.0);
+		SubEvent middleEvent1 = new SubEvent(7, "foo", 1.0, 10.0);
+		Event endEvent1=  new Event(7, "end", 1.0);
+
+		int keygroup = KeyGroupRangeAssignment.assignToKeyGroup(keySelector.getKey(startEvent1), maxParallelism);
+		assertEquals(1, keygroup);
+		assertEquals(0, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 2, keygroup));
+
+		Event startEvent2 = new Event(10, "start", 1.0);				// this will go to task index 2
+		SubEvent middleEvent2 = new SubEvent(10, "foo", 1.0, 10.0);
+		Event endEvent2 = new Event(10, "end", 1.0);
+
+		keygroup = KeyGroupRangeAssignment.assignToKeyGroup(keySelector.getKey(startEvent2), maxParallelism);
+		assertEquals(9, keygroup);
+		assertEquals(1, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 2, keygroup));
+
+		// now we start the test, we go from parallelism 1 to 2.
+
+		OneInputStreamOperatorTestHarness<Event, Map<String, Event>> harness =
+			getTestHarness(maxParallelism, 1, 0);
+		harness.open();
+
+		harness.processElement(new StreamRecord<>(startEvent1, 1));						// valid element
+		harness.processElement(new StreamRecord<>(new Event(7, "foobar", 1.0), 2));
+
+		harness.processElement(new StreamRecord<>(startEvent2, 3));						// valid element
+		harness.processElement(new StreamRecord<Event>(middleEvent2, 4));				// valid element
+
+		// take a snapshot with some elements in internal sorting queue
+		OperatorStateHandles snapshot = harness.snapshot(0, 0);
+		harness.close();
+
+		// initialize two sub-tasks with the previously snapshotted state to simulate scaling up
+
+		// we know that the valid element will go to index 0,
+		// so we initialize the two tasks and we put the rest of
+		// the valid elements for the pattern on task 0.
+
+		OneInputStreamOperatorTestHarness<Event, Map<String, Event>> harness1 =
+			getTestHarness(maxParallelism, 2, 0);
+
+		harness1.setup();
+		harness1.initializeState(snapshot);
+		harness1.open();
+
+		// if element timestamps are not correctly checkpointed/restored this will lead to
+		// a pruning time underflow exception in NFA
+		harness1.processWatermark(new Watermark(2));
+
+		harness1.processElement(new StreamRecord<Event>(middleEvent1, 3));				// valid element
+		harness1.processElement(new StreamRecord<>(endEvent1, 5));						// valid element
+
+		harness1.processWatermark(new Watermark(Long.MAX_VALUE));
+
+		// watermarks and the result
+		assertEquals(3, harness1.getOutput().size());
+		verifyWatermark(harness1.getOutput().poll(), 2);
+		verifyPattern(harness1.getOutput().poll(), startEvent1, middleEvent1, endEvent1);
+
+		OneInputStreamOperatorTestHarness<Event, Map<String, Event>> harness2 =
+			getTestHarness(maxParallelism, 2, 1);
+
+		harness2.setup();
+		harness2.initializeState(snapshot);
+		harness2.open();
+
+		// now we move to the second parallel task
+		harness2.processWatermark(new Watermark(2));
+
+		harness2.processElement(new StreamRecord<>(endEvent2, 5));
+		harness2.processElement(new StreamRecord<>(new Event(42, "start", 1.0), 4));
+
+		harness2.processWatermark(new Watermark(Long.MAX_VALUE));
+
+		assertEquals(3, harness2.getOutput().size());
+		verifyWatermark(harness2.getOutput().poll(), 2);
+		verifyPattern(harness2.getOutput().poll(), startEvent2, middleEvent2, endEvent2);
+
+		harness.close();
+		harness1.close();
+		harness2.close();
+	}
+
+	@Test
+	public void testCEPFunctionScalingDown() throws Exception {
+		int maxParallelism = 10;
+
+		KeySelector<Event, Integer> keySelector = new KeySelector<Event, Integer>() {
+			private static final long serialVersionUID = -4873366487571254798L;
+
+			@Override
+			public Integer getKey(Event value) throws Exception {
+				return value.getId();
+			}
+		};
+
+		// create some valid pattern events on predetermined key groups and task indices
+
+		Event startEvent1 = new Event(7, "start", 1.0);					// this will go to task index 0
+		SubEvent middleEvent1 = new SubEvent(7, "foo", 1.0, 10.0);
+		Event endEvent1 = new Event(7, "end", 1.0);
+
+		// verification of the key choice
+		int keygroup = KeyGroupRangeAssignment.assignToKeyGroup(keySelector.getKey(startEvent1), maxParallelism);
+		assertEquals(1, keygroup);
+		assertEquals(0, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 3, keygroup));
+		assertEquals(0, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 2, keygroup));
+
+		Event startEvent2 = new Event(45, "start", 1.0);				// this will go to task index 1
+		SubEvent middleEvent2 = new SubEvent(45, "foo", 1.0, 10.0);
+		Event endEvent2 = new Event(45, "end", 1.0);
+
+		keygroup = KeyGroupRangeAssignment.assignToKeyGroup(keySelector.getKey(startEvent2), maxParallelism);
+		assertEquals(6, keygroup);
+		assertEquals(1, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 3, keygroup));
+		assertEquals(1, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 2, keygroup));
+
+		Event startEvent3 = new Event(90, "start", 1.0);				// this will go to task index 0
+		SubEvent middleEvent3 = new SubEvent(90, "foo", 1.0, 10.0);
+		Event endEvent3 = new Event(90, "end", 1.0);
+
+		keygroup = KeyGroupRangeAssignment.assignToKeyGroup(keySelector.getKey(startEvent3), maxParallelism);
+		assertEquals(2, keygroup);
+		assertEquals(0, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 3, keygroup));
+		assertEquals(0, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 2, keygroup));
+
+		Event startEvent4 = new Event(10, "start", 1.0);				// this will go to task index 2
+		SubEvent middleEvent4 = new SubEvent(10, "foo", 1.0, 10.0);
+		Event endEvent4 = new Event(10, "end", 1.0);
+
+		keygroup = KeyGroupRangeAssignment.assignToKeyGroup(keySelector.getKey(startEvent4), maxParallelism);
+		assertEquals(9, keygroup);
+		assertEquals(2, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 3, keygroup));
+		assertEquals(1, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 2, keygroup));
+
+		// starting the test, we will go from parallelism of 3 to parallelism of 2
+
+		OneInputStreamOperatorTestHarness<Event, Map<String, Event>> harness1 =
+			getTestHarness(maxParallelism, 3, 0);
+		harness1.open();
+
+		OneInputStreamOperatorTestHarness<Event, Map<String, Event>> harness2 =
+			getTestHarness(maxParallelism, 3, 1);
+		harness2.open();
+
+		OneInputStreamOperatorTestHarness<Event, Map<String, Event>> harness3 =
+			getTestHarness(maxParallelism, 3, 2);
+		harness3.open();
+
+		harness1.processWatermark(Long.MIN_VALUE);
+		harness2.processWatermark(Long.MIN_VALUE);
+		harness3.processWatermark(Long.MIN_VALUE);
+
+		harness1.processElement(new StreamRecord<>(startEvent1, 1));						// valid element
+		harness1.processElement(new StreamRecord<>(new Event(7, "foobar", 1.0), 2));
+		harness1.processElement(new StreamRecord<Event>(middleEvent1, 3));					// valid element
+		harness1.processElement(new StreamRecord<>(endEvent1, 5));							// valid element
+
+		// till here we have a valid sequence, so after creating the
+		// new instance and sending it a watermark, we expect it to fire,
+		// even with no new elements.
+
+		harness1.processElement(new StreamRecord<>(startEvent3, 10));
+		harness1.processElement(new StreamRecord<>(startEvent1, 10));
+
+		harness2.processElement(new StreamRecord<>(startEvent2, 7));
+		harness2.processElement(new StreamRecord<Event>(middleEvent2, 8));
+
+		harness3.processElement(new StreamRecord<>(startEvent4, 15));
+		harness3.processElement(new StreamRecord<Event>(middleEvent4, 16));
+		harness3.processElement(new StreamRecord<>(endEvent4, 17));
+
+		// so far we only have the initial watermark
+		assertEquals(1, harness1.getOutput().size());
+		verifyWatermark(harness1.getOutput().poll(), Long.MIN_VALUE);
+
+		assertEquals(1, harness2.getOutput().size());
+		verifyWatermark(harness2.getOutput().poll(), Long.MIN_VALUE);
+
+		assertEquals(1, harness3.getOutput().size());
+		verifyWatermark(harness3.getOutput().poll(), Long.MIN_VALUE);
+
+		// we take a snapshot and make it look as a single operator
+		// this will be the initial state of all downstream tasks.
+		OperatorStateHandles snapshot = AbstractStreamOperatorTestHarness.repackageState(
+			harness2.snapshot(0, 0),
+			harness1.snapshot(0, 0),
+			harness3.snapshot(0, 0)
+		);
+
+		OneInputStreamOperatorTestHarness<Event, Map<String, Event>> harness4 =
+			getTestHarness(maxParallelism, 2, 0);
+		harness4.setup();
+		harness4.initializeState(snapshot);
+		harness4.open();
+
+		OneInputStreamOperatorTestHarness<Event, Map<String, Event>> harness5 =
+			getTestHarness(maxParallelism, 2, 1);
+		harness5.setup();
+		harness5.initializeState(snapshot);
+		harness5.open();
+
+		harness5.processElement(new StreamRecord<>(endEvent2, 11));
+		harness5.processWatermark(new Watermark(12));
+
+		verifyPattern(harness5.getOutput().poll(), startEvent2, middleEvent2, endEvent2);
+		verifyWatermark(harness5.getOutput().poll(), 12);
+
+		// if element timestamps are not correctly checkpointed/restored this will lead to
+		// a pruning time underflow exception in NFA
+		harness4.processWatermark(new Watermark(12));
+
+		assertEquals(2, harness4.getOutput().size());
+		verifyPattern(harness4.getOutput().poll(), startEvent1, middleEvent1, endEvent1);
+		verifyWatermark(harness4.getOutput().poll(), 12);
+
+		harness4.processElement(new StreamRecord<Event>(middleEvent3, 15));			// valid element
+		harness4.processElement(new StreamRecord<>(endEvent3, 16));					// valid element
+
+		harness4.processElement(new StreamRecord<Event>(middleEvent1, 15));			// valid element
+		harness4.processElement(new StreamRecord<>(endEvent1, 16));					// valid element
+
+		harness4.processWatermark(new Watermark(Long.MAX_VALUE));
+		harness5.processWatermark(new Watermark(Long.MAX_VALUE));
+
+		// verify result
+		assertEquals(3, harness4.getOutput().size());
+
+		// check the order of the events in the output
+		Queue<Object> output = harness4.getOutput();
+		StreamRecord<?> resultRecord = (StreamRecord<?>) output.peek();
+		assertTrue(resultRecord.getValue() instanceof Map);
+
+		@SuppressWarnings("unchecked")
+		Map<String, Event> patternMap = (Map<String, Event>) resultRecord.getValue();
+		if (patternMap.get("start").getId() == 7) {
+			verifyPattern(harness4.getOutput().poll(), startEvent1, middleEvent1, endEvent1);
+			verifyPattern(harness4.getOutput().poll(), startEvent3, middleEvent3, endEvent3);
+		} else {
+			verifyPattern(harness4.getOutput().poll(), startEvent3, middleEvent3, endEvent3);
+			verifyPattern(harness4.getOutput().poll(), startEvent1, middleEvent1, endEvent1);
+		}
+
+		// after scaling down this should end up here
+		assertEquals(2, harness5.getOutput().size());
+		verifyPattern(harness5.getOutput().poll(), startEvent4, middleEvent4, endEvent4);
+
+		harness1.close();
+		harness2.close();
+		harness3.close();
+		harness4.close();
+		harness5.close();
+	}
+
+	private void verifyWatermark(Object outputObject, long timestamp) {
+		assertTrue(outputObject instanceof Watermark);
+		assertEquals(timestamp, ((Watermark) outputObject).getTimestamp());
+	}
+
+	private void verifyPattern(Object outputObject, Event start, SubEvent middle, Event end) {
+		assertTrue(outputObject instanceof StreamRecord);
+
+		StreamRecord<?> resultRecord = (StreamRecord<?>) outputObject;
+		assertTrue(resultRecord.getValue() instanceof Map);
+
+		@SuppressWarnings("unchecked")
+		Map<String, Event> patternMap = (Map<String, Event>) resultRecord.getValue();
+		assertEquals(start, patternMap.get("start"));
+		assertEquals(middle, patternMap.get("middle"));
+		assertEquals(end, patternMap.get("end"));
+	}
+
+	private KeyedOneInputStreamOperatorTestHarness<Integer, Event, Map<String, Event>> getTestHarness(
+		int maxParallelism,
+		int taskParallelism,
+		int subtaskIdx) throws Exception {
+
+		KeySelector<Event, Integer> keySelector = new TestKeySelector();
+		return new KeyedOneInputStreamOperatorTestHarness<>(
+			new KeyedCEPPatternOperator<>(
+				Event.createTypeSerializer(),
+				false,
+				keySelector,
+				BasicTypeInfo.INT_TYPE_INFO.createSerializer(new ExecutionConfig()),
+				new NFAFactory()),
+			keySelector,
+			BasicTypeInfo.INT_TYPE_INFO,
+			maxParallelism,
+			taskParallelism,
+			subtaskIdx);
+	}
+
+	private static class NFAFactory implements NFACompiler.NFAFactory<Event> {
+
+		private static final long serialVersionUID = 1173020762472766713L;
+
+		private final boolean handleTimeout;
+
+		private NFAFactory() {
+			this(false);
+		}
+
+		private NFAFactory(boolean handleTimeout) {
+			this.handleTimeout = handleTimeout;
+		}
+
+		@Override
+		public NFA<Event> createNFA() {
+
+			Pattern<Event, ?> pattern = Pattern.<Event>begin("start").where(new FilterFunction<Event>() {
+				private static final long serialVersionUID = 5726188262756267490L;
+
+				@Override
+				public boolean filter(Event value) throws Exception {
+					return value.getName().equals("start");
+				}
+			})
+				.followedBy("middle").subtype(SubEvent.class).where(new FilterFunction<SubEvent>() {
+					private static final long serialVersionUID = 6215754202506583964L;
+
+					@Override
+					public boolean filter(SubEvent value) throws Exception {
+						return value.getVolume() > 5.0;
+					}
+				})
+				.followedBy("end").where(new FilterFunction<Event>() {
+					private static final long serialVersionUID = 7056763917392056548L;
+
+					@Override
+					public boolean filter(Event value) throws Exception {
+						return value.getName().equals("end");
+					}
+				})
+				// add a window timeout to test whether timestamps of elements in the
+				// priority queue in CEP operator are correctly checkpointed/restored
+				.within(Time.milliseconds(10L));
+
+			return NFACompiler.compile(pattern, Event.createTypeSerializer(), handleTimeout);
+		}
+	}
+
+	/**
+	 * A simple {@link KeySelector} that returns as key the id of the {@link Event}
+	 * provided as argument in the {@link #getKey(Event)}.
+	 * */
+	private static class TestKeySelector implements KeySelector<Event, Integer> {
+		private static final long serialVersionUID = -4873366487571254798L;
+
+		@Override
+		public Integer getKey(Event value) throws Exception {
+			return value.getId();
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/daf0ccda/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index 05fda28..a81056f 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -69,6 +69,7 @@ import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.Serializable;
 import java.util.Collection;
 import java.util.ConcurrentModificationException;
 import java.util.HashMap;
@@ -82,8 +83,7 @@ import static org.apache.flink.util.Preconditions.checkArgument;
  * 
  * <p>For concrete implementations, one of the following two interfaces must also be implemented, to
  * mark the operator as unary or binary:
- * {@link org.apache.flink.streaming.api.operators.OneInputStreamOperator} or
- * {@link org.apache.flink.streaming.api.operators.TwoInputStreamOperator}.
+ * {@link OneInputStreamOperator} or {@link TwoInputStreamOperator}.
  *
  * <p>Methods of {@code StreamOperator} are guaranteed not to be called concurrently. Also, if using
  * the timer service, timer callbacks are also guaranteed not to be called concurrently with
@@ -93,7 +93,7 @@ import static org.apache.flink.util.Preconditions.checkArgument;
  */
 @PublicEvolving
 public abstract class AbstractStreamOperator<OUT>
-		implements StreamOperator<OUT>, java.io.Serializable, KeyContext {
+		implements StreamOperator<OUT>, Serializable, KeyContext {
 
 	private static final long serialVersionUID = 1L;
 	
@@ -146,10 +146,9 @@ public abstract class AbstractStreamOperator<OUT>
 
 	protected transient LatencyGauge latencyGauge;
 
-	// ---------------- timers ------------------
-
-	private transient Map<String, HeapInternalTimerService<?, ?>> timerServices;
+	// ---------------- time handler ------------------
 
+	private transient InternalTimeServiceManager<?, ?> timeServiceManager;
 
 	// ---------------- two-input operator watermarks ------------------
 
@@ -206,6 +205,14 @@ public abstract class AbstractStreamOperator<OUT>
 
 		initKeyedState(); //TODO we should move the actual initialization of this from StreamTask to this class
 
+		if (getKeyedStateBackend() != null && timeServiceManager == null) {
+			timeServiceManager = new InternalTimeServiceManager<>(
+				getKeyedStateBackend().getNumberOfKeyGroups(),
+				getKeyedStateBackend().getKeyGroupRange(),
+				this,
+				getRuntimeContext().getProcessingTimeService());
+		}
+
 		if (restoring) {
 
 			restoreStreamCheckpointed(stateHandles);
@@ -268,11 +275,9 @@ public abstract class AbstractStreamOperator<OUT>
 	 * @throws Exception An exception in this method causes the operator to fail.
 	 */
 	@Override
-	public void open() throws Exception {
-		if (timerServices == null) {
-			timerServices = new HashMap<>();
-		}
-	}
+	public void open() throws Exception {
+		// nothing to do by default
+	}
 
 	private void initKeyedState() {
 		try {
@@ -308,9 +311,9 @@ public abstract class AbstractStreamOperator<OUT>
 
 	/**
 	 * This method is called after all records have been added to the operators via the methods
-	 * {@link org.apache.flink.streaming.api.operators.OneInputStreamOperator#processElement(StreamRecord)}, or
-	 * {@link org.apache.flink.streaming.api.operators.TwoInputStreamOperator#processElement1(StreamRecord)} and
-	 * {@link org.apache.flink.streaming.api.operators.TwoInputStreamOperator#processElement2(StreamRecord)}.
+	 * {@link OneInputStreamOperator#processElement(StreamRecord)}, or
+	 * {@link TwoInputStreamOperator#processElement1(StreamRecord)} and
+	 * {@link TwoInputStreamOperator#processElement2(StreamRecord)}.
 
 	 * <p>The method is expected to flush all remaining buffered data. Exceptions during this flushing
 	 * of buffered should be propagated, in order to cause the operation to be recognized asa failed,
@@ -408,16 +411,8 @@ public abstract class AbstractStreamOperator<OUT>
 				for (int keyGroupIdx : allKeyGroups) {
 					out.startNewKeyGroup(keyGroupIdx);
 
-					DataOutputViewStreamWrapper dov = new DataOutputViewStreamWrapper(out);
-					dov.writeInt(timerServices.size());
-
-					for (Map.Entry<String, HeapInternalTimerService<?, ?>> entry : timerServices.entrySet()) {
-						String serviceName = entry.getKey();
-						HeapInternalTimerService<?, ?> timerService = entry.getValue();
-
-						dov.writeUTF(serviceName);
-						timerService.snapshotTimersForKeyGroup(dov, keyGroupIdx);
-					}
+					timeServiceManager.snapshotStateForKeyGroup(
+						new DataOutputViewStreamWrapper(out), keyGroupIdx);
 				}
 			} catch (Exception exception) {
 				throw new Exception("Could not write timer service of " + getOperatorName() +
@@ -465,35 +460,17 @@ public abstract class AbstractStreamOperator<OUT>
 	 */
 	public void initializeState(StateInitializationContext context) throws Exception {
 		if (getKeyedStateBackend() != null) {
-			int totalKeyGroups = getKeyedStateBackend().getNumberOfKeyGroups();
 			KeyGroupsList localKeyGroupRange = getKeyedStateBackend().getKeyGroupRange();
 
-			// initialize the map with the timer services
-			this.timerServices = new HashMap<>();
-
-			// and then initialize the timer services
+			// restore the state of the time services (timers and watermark-callback keys) of each local key-group
 			for (KeyGroupStatePartitionStreamProvider streamProvider : context.getRawKeyedStateInputs()) {
-				DataInputViewStreamWrapper div = new DataInputViewStreamWrapper(streamProvider.getStream());
-
 				int keyGroupIdx = streamProvider.getKeyGroupId();
 				checkArgument(localKeyGroupRange.contains(keyGroupIdx),
 					"Key Group " + keyGroupIdx + " does not belong to the local range.");
 
-				int noOfTimerServices = div.readInt();
-				for (int i = 0; i < noOfTimerServices; i++) {
-					String serviceName = div.readUTF();
-
-					HeapInternalTimerService<?, ?> timerService = this.timerServices.get(serviceName);
-					if (timerService == null) {
-						timerService = new HeapInternalTimerService<>(
-							totalKeyGroups,
-							localKeyGroupRange,
-							this,
-							getRuntimeContext().getProcessingTimeService());
-						this.timerServices.put(serviceName, timerService);
-					}
-					timerService.restoreTimersForKeyGroup(div, keyGroupIdx, getUserCodeClassloader());
-				}
+				timeServiceManager.restoreStateForKeyGroup(
+					new DataInputViewStreamWrapper(streamProvider.getStream()),
+					keyGroupIdx, getUserCodeClassloader());
 			}
 		}
 	}
@@ -888,6 +865,25 @@ public abstract class AbstractStreamOperator<OUT>
 	// ------------------------------------------------------------------------
 
 	/**
+	 * Returns the {@link InternalWatermarkCallbackService} of this operator, which allows to register
+	 * an {@link OnWatermarkCallback} and multiple keys for which the callback will be invoked every
+	 * time a new {@link Watermark} is received.
+	 *
+	 * <p><b>NOTE: </b> This service is only available to <b>keyed</b> operators.
+	 *
+	 * @param <K> The type of the keys registered with the service.
+	 * @throws UnsupportedOperationException If the operator is not keyed.
+	 */
+	public <K> InternalWatermarkCallbackService<K> getInternalWatermarkCallbackService() {
+		checkTimerServiceInitialization();
+
+		// the manager field is typed with wildcards, so the key type chosen by the caller is unchecked
+		@SuppressWarnings("unchecked")
+		InternalTimeServiceManager<K, ?> keyedTimeServiceHandler = (InternalTimeServiceManager<K, ?>) timeServiceManager;
+		return keyedTimeServiceHandler.getWatermarkCallbackService();
+	}
+
+	/**
 	 * Returns a {@link InternalTimerService} that can be used to query current processing time
 	 * and event time and to set timers. An operator can have several timer services, where
 	 * each has its own namespace serializer. Timer services are differentiated by the string
@@ -908,38 +899,43 @@ public abstract class AbstractStreamOperator<OUT>
 	 *
 	 * @param <N> The type of the timer namespace.
 	 */
-	public <N> InternalTimerService<N> getInternalTimerService(
+	public <K, N> InternalTimerService<N> getInternalTimerService(
 			String name,
 			TypeSerializer<N> namespaceSerializer,
-			Triggerable<?, N> triggerable) {
-		if (getKeyedStateBackend() == null) {
-			throw new UnsupportedOperationException("Timers can only be used on keyed operators.");
-		}
+			Triggerable<K, N> triggerable) {
 
-		@SuppressWarnings("unchecked")
-		HeapInternalTimerService<Object, N> timerService = (HeapInternalTimerService<Object, N>) timerServices.get(name);
+		checkTimerServiceInitialization();
 
-		if (timerService == null) {
-			timerService = new HeapInternalTimerService<>(
-				getKeyedStateBackend().getNumberOfKeyGroups(),
-				getKeyedStateBackend().getKeyGroupRange(),
-				this,
-				getRuntimeContext().getProcessingTimeService());
-			timerServices.put(name, timerService);
-		}
-		@SuppressWarnings({"unchecked", "rawtypes"})
-		Triggerable rawTriggerable = (Triggerable) triggerable;
-		timerService.startTimerService(getKeyedStateBackend().getKeySerializer(), namespaceSerializer, rawTriggerable);
-		return timerService;
+		// neither the backend's key serializer nor the (wildcard-typed) time service manager carry
+		// the key type K chosen by the caller, so it can only be established via unchecked casts
+		@SuppressWarnings("unchecked")
+		TypeSerializer<K> keySerializer = (TypeSerializer<K>) getKeyedStateBackend().getKeySerializer();
+		@SuppressWarnings("unchecked")
+		InternalTimeServiceManager<K, N> keyedTimeServiceHandler = (InternalTimeServiceManager<K, N>) timeServiceManager;
+		return keyedTimeServiceHandler.getInternalTimerService(name, keySerializer, namespaceSerializer, triggerable);
 	}
 
 	public void processWatermark(Watermark mark) throws Exception {
-		for (HeapInternalTimerService<?, ?> service : timerServices.values()) {
-			service.advanceWatermark(mark.getTimestamp());
+		if (timeServiceManager != null) {
+			timeServiceManager.advanceWatermark(mark);
 		}
 		output.emitWatermark(mark);
 	}
 
+	/**
+	 * Checks that the time services of this operator can be used.
+	 *
+	 * @throws UnsupportedOperationException If the operator is not keyed.
+	 * @throws IllegalStateException If the time service manager has not been created yet.
+	 */
+	private void checkTimerServiceInitialization() {
+		if (getKeyedStateBackend() == null) {
+			throw new UnsupportedOperationException("Timers can only be used on keyed operators.");
+		} else if (timeServiceManager == null) {
+			throw new IllegalStateException("The timer service has not been initialized.");
+		}
+	}
+
 	public void processWatermark1(Watermark mark) throws Exception {
 		input1Watermark = mark.getTimestamp();
 		long newMin = Math.min(input1Watermark, input2Watermark);
@@ -960,19 +947,19 @@ public abstract class AbstractStreamOperator<OUT>
 
 	@VisibleForTesting
 	public int numProcessingTimeTimers() {
-		int count = 0;
-		for (HeapInternalTimerService<?, ?> timerService : timerServices.values()) {
-			count += timerService.numProcessingTimeTimers();
-		}
-		return count;
+		if (timeServiceManager == null) {
+			// no time service manager has been created, hence there are no timers to count
+			return 0;
+		}
+		return timeServiceManager.numProcessingTimeTimers();
 	}
 
 	@VisibleForTesting
 	public int numEventTimeTimers() {
-		int count = 0;
-		for (HeapInternalTimerService<?, ?> timerService : timerServices.values()) {
-			count += timerService.numEventTimeTimers();
-		}
-		return count;
+		if (timeServiceManager == null) {
+			// no time service manager has been created, hence there are no timers to count
+			return 0;
+		}
+		return timeServiceManager.numEventTimeTimers();
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/daf0ccda/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
new file mode 100644
index 0000000..71ffbd2
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.state.KeyGroupsList;
+import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An entity keeping all the time-related services available to all operators extending the
+ * {@link AbstractStreamOperator}. These are the different {@link HeapInternalTimerService timer services}
+ * and the {@link InternalWatermarkCallbackService}.
+ *
+ * <p><b>NOTE:</b> These services are only available to keyed operators.
+ *
+ * @param <K> The type of keys used for the timers and the registry.
+ * @param <N> The type of namespace used for the timers.
+ */
+@Internal
+class InternalTimeServiceManager<K, N> {
+
+	/** Snapshot marker byte: no watermark-callback state follows the timers of a key-group. */
+	private static final byte NO_WATERMARK_CALLBACK_STATE = 0;
+
+	/** Snapshot marker byte: the watermark-callback state of the key-group follows the timers. */
+	private static final byte HAS_WATERMARK_CALLBACK_STATE = 1;
+
+	private final int totalKeyGroups;
+	private final KeyGroupsList localKeyGroupRange;
+	private final KeyContext keyContext;
+
+	private final ProcessingTimeService processingTimeService;
+
+	/** The timer services of the operator, by the name under which they were requested. */
+	private final Map<String, HeapInternalTimerService<K, N>> timerServices;
+
+	/** Invokes the registered callback for the registered keys upon each watermark. */
+	private final InternalWatermarkCallbackService<K> watermarkCallbackService;
+
+	InternalTimeServiceManager(
+			int totalKeyGroups,
+			KeyGroupsList localKeyGroupRange,
+			KeyContext keyContext,
+			ProcessingTimeService processingTimeService) {
+
+		Preconditions.checkArgument(totalKeyGroups > 0, "The total number of key-groups must be positive.");
+		this.totalKeyGroups = totalKeyGroups;
+		this.localKeyGroupRange = Preconditions.checkNotNull(localKeyGroupRange);
+
+		this.keyContext = Preconditions.checkNotNull(keyContext);
+		this.processingTimeService = Preconditions.checkNotNull(processingTimeService);
+
+		this.timerServices = new HashMap<>();
+		this.watermarkCallbackService = new InternalWatermarkCallbackService<>(totalKeyGroups, localKeyGroupRange, keyContext);
+	}
+
+	/**
+	 * Returns the {@link InternalWatermarkCallbackService} which allows to register an
+	 * {@link OnWatermarkCallback} and multiple keys, for which
+	 * the callback will be invoked every time a new {@link Watermark} is received.
+	 */
+	public InternalWatermarkCallbackService<K> getWatermarkCallbackService() {
+		return watermarkCallbackService;
+	}
+
+	/**
+	 * Returns a {@link InternalTimerService} that can be used to query current processing time
+	 * and event time and to set timers. An operator can have several timer services, where
+	 * each has its own namespace serializer. Timer services are differentiated by the string
+	 * key that is given when requesting them, if you call this method with the same key
+	 * multiple times you will get the same timer service instance in subsequent requests.
+	 *
+	 * <p>Timers are always scoped to a key, the currently active key of a keyed stream operation.
+	 * When a timer fires, this key will also be set as the currently active key.
+	 *
+	 * <p>Each timer has attached metadata, the namespace. Different timer services
+	 * can have a different namespace type. If you don't need namespace differentiation you
+	 * can use {@link VoidNamespaceSerializer} as the namespace serializer.
+	 *
+	 * @param name The name of the requested timer service. If no service exists under the given
+	 *             name a new one will be created and returned.
+	 * @param keySerializer {@code TypeSerializer} for the timer keys.
+	 * @param namespaceSerializer {@code TypeSerializer} for the timer namespace.
+	 * @param triggerable The {@link Triggerable} that should be invoked when timers fire
+	 */
+	public InternalTimerService<N> getInternalTimerService(String name, TypeSerializer<K> keySerializer,
+														TypeSerializer<N> namespaceSerializer, Triggerable<K, N> triggerable) {
+
+		HeapInternalTimerService<K, N> timerService = getOrCreateTimerService(name);
+		timerService.startTimerService(keySerializer, namespaceSerializer, triggerable);
+		return timerService;
+	}
+
+	/**
+	 * Returns the timer service registered under the given name. If there is none, a new one is
+	 * created and registered under that name, without calling {@code startTimerService} on it.
+	 */
+	private HeapInternalTimerService<K, N> getOrCreateTimerService(String name) {
+		HeapInternalTimerService<K, N> timerService = timerServices.get(name);
+		if (timerService == null) {
+			timerService = new HeapInternalTimerService<>(
+				totalKeyGroups,
+				localKeyGroupRange,
+				keyContext,
+				processingTimeService);
+			timerServices.put(name, timerService);
+		}
+		return timerService;
+	}
+
+	/**
+	 * Forwards the timestamp of the given watermark to all timer services and then invokes the
+	 * watermark callback for all keys registered with the {@link InternalWatermarkCallbackService}.
+	 *
+	 * @param watermark The received watermark.
+	 */
+	public void advanceWatermark(Watermark watermark) throws Exception {
+		for (HeapInternalTimerService<?, ?> service : timerServices.values()) {
+			service.advanceWatermark(watermark.getTimestamp());
+		}
+		watermarkCallbackService.invokeOnWatermarkCallback(watermark);
+	}
+
+	//////////////////				Fault Tolerance Methods				///////////////////
+
+	/**
+	 * Writes the state of the given key-group: the number of timer services, then the name and
+	 * the timers of each service, then a marker byte followed by the key-group's state of the
+	 * {@link InternalWatermarkCallbackService}.
+	 */
+	public void snapshotStateForKeyGroup(DataOutputViewStreamWrapper stream, int keyGroupIdx) throws Exception {
+		stream.writeInt(timerServices.size());
+
+		for (Map.Entry<String, HeapInternalTimerService<K, N>> entry : timerServices.entrySet()) {
+			String serviceName = entry.getKey();
+			HeapInternalTimerService<?, ?> timerService = entry.getValue();
+
+			stream.writeUTF(serviceName);
+			timerService.snapshotTimersForKeyGroup(stream, keyGroupIdx);
+		}
+
+		// the watermark callback service is created in the constructor and never null, so its
+		// state always follows; the marker byte is kept for the format read by restoreStateForKeyGroup
+		stream.writeByte(HAS_WATERMARK_CALLBACK_STATE);
+		watermarkCallbackService.snapshotKeysForKeyGroup(stream, keyGroupIdx);
+	}
+
+	/**
+	 * Restores the state of the given key-group, as written by
+	 * {@link #snapshotStateForKeyGroup(DataOutputViewStreamWrapper, int)}. Timer services that do
+	 * not exist yet are created, without calling {@code startTimerService} on them.
+	 *
+	 * @throws IOException If reading from the stream fails or the marker byte has an unknown value.
+	 */
+	public void restoreStateForKeyGroup(DataInputViewStreamWrapper stream, int keyGroupIdx,
+										ClassLoader userCodeClassLoader) throws IOException, ClassNotFoundException {
+
+		int noOfTimerServices = stream.readInt();
+		for (int i = 0; i < noOfTimerServices; i++) {
+			String serviceName = stream.readUTF();
+			HeapInternalTimerService<K, N> timerService = getOrCreateTimerService(serviceName);
+			timerService.restoreTimersForKeyGroup(stream, keyGroupIdx, userCodeClassLoader);
+		}
+
+		byte watermarkCallbackStateMarker = stream.readByte();
+		if (watermarkCallbackStateMarker == HAS_WATERMARK_CALLBACK_STATE) {
+			watermarkCallbackService.restoreKeysForKeyGroup(stream, keyGroupIdx, userCodeClassLoader);
+		} else if (watermarkCallbackStateMarker != NO_WATERMARK_CALLBACK_STATE) {
+			throw new IOException("Unknown watermark callback state marker " + watermarkCallbackStateMarker
+				+ " in the snapshot of key-group " + keyGroupIdx + '.');
+		}
+	}
+
+	////////////////////			Methods used ONLY IN TESTS				////////////////////
+
+	@VisibleForTesting
+	public int numProcessingTimeTimers() {
+		int count = 0;
+		for (HeapInternalTimerService<?, ?> timerService : timerServices.values()) {
+			count += timerService.numProcessingTimeTimers();
+		}
+		return count;
+	}
+
+	@VisibleForTesting
+	public int numEventTimeTimers() {
+		int count = 0;
+		for (HeapInternalTimerService<?, ?> timerService : timerServices.values()) {
+			count += timerService.numEventTimeTimers();
+		}
+		return count;
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/daf0ccda/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalWatermarkCallbackService.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalWatermarkCallbackService.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalWatermarkCallbackService.java
new file mode 100644
index 0000000..a4263e4
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalWatermarkCallbackService.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.KeyGroupsList;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.util.InstantiationUtil;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * The watermark callback service allows to register a {@link OnWatermarkCallback OnWatermarkCallback}
+ * and multiple keys, for which the callback will be invoked every time a new {@link Watermark} is received
+ * (after the registration of the key).
+ *
+ * <p><b>NOTE: </b> This service is only available to <b>keyed</b> operators.
+ *
+ * @param <K> The type of key returned by the {@code KeySelector}.
+ */
+@Internal
+public class InternalWatermarkCallbackService<K> {
+
+	////////////			Information about the keyed state				//////////
+
+	private final KeyGroupsList localKeyGroupRange;
+	private final int totalKeyGroups;
+	private final int localKeyGroupRangeStartIdx;
+
+	private final KeyContext keyContext;
+
+	/**
+	 * An array of sets of keys keeping the registered keys split by the key-group they belong to.
+	 * Each local key-group has one slot; its set is only allocated once a key of that key-group
+	 * is registered, and the slot is reset to {@code null} when its last key is unregistered.
+	 */
+	private final Set<K>[] keysByKeygroup;
+
+	/** A serializer for the registered keys. */
+	private TypeSerializer<K> keySerializer;
+
+	/**
+	 * The {@link OnWatermarkCallback} to be invoked for each
+	 * registered key upon reception of the watermark.
+	 */
+	private OnWatermarkCallback<K> callback;
+
+	/**
+	 * Creates a service for the given range of local key-groups.
+	 *
+	 * @param totalKeyGroups The total number of key-groups, used to assign keys to key-groups.
+	 * @param localKeyGroupRange The key-groups this service keeps keys for.
+	 * @param keyContext The context whose current key is set before the callback is invoked for a key.
+	 */
+	public InternalWatermarkCallbackService(int totalKeyGroups, KeyGroupsList localKeyGroupRange, KeyContext keyContext) {
+
+		this.totalKeyGroups = totalKeyGroups;
+		this.localKeyGroupRange = checkNotNull(localKeyGroupRange);
+		this.keyContext = checkNotNull(keyContext);
+
+		// find the starting index of the local key-group range
+		int startIdx = Integer.MAX_VALUE;
+		for (Integer keyGroupIdx : localKeyGroupRange) {
+			startIdx = Math.min(keyGroupIdx, startIdx);
+		}
+		this.localKeyGroupRangeStartIdx = startIdx;
+
+		// one slot per key-group this task is responsible for; generic arrays cannot be
+		// instantiated directly, hence the unchecked conversion from the raw array
+		@SuppressWarnings("unchecked")
+		Set<K>[] keySets = new Set[this.localKeyGroupRange.getNumberOfKeyGroups()];
+		this.keysByKeygroup = keySets;
+	}
+
+	/**
+	 * Registers a {@link OnWatermarkCallback} with the current {@link InternalWatermarkCallbackService} service.
+	 * Before this method is called and the callback is set, the service is unusable.
+	 *
+	 * @param watermarkCallback The callback to be registered.
+	 * @param keySerializer A serializer for the registered keys.
+	 * @throws IllegalStateException If a callback has already been set.
+	 */
+	public void setWatermarkCallback(OnWatermarkCallback<K> watermarkCallback, TypeSerializer<K> keySerializer) {
+		checkNotNull(watermarkCallback, "The watermark callback must not be null.");
+		checkNotNull(keySerializer, "The key serializer must not be null.");
+
+		if (callback != null) {
+			throw new IllegalStateException("The watermark callback has already been initialized.");
+		}
+		this.keySerializer = keySerializer;
+		this.callback = watermarkCallback;
+	}
+
+	/**
+	 * Registers a key with the service. This will lead to the {@link OnWatermarkCallback}
+	 * being invoked for this key upon reception of each subsequent watermark.
+	 *
+	 * @param key The key to be registered.
+	 * @return {@code true} if the key was not registered before, {@code false} otherwise.
+	 */
+	public boolean registerKeyForWatermarkCallback(K key) {
+		return getOrCreateKeySet(getIndexForKey(key)).add(key);
+	}
+
+	/**
+	 * Unregisters the provided key from the service.
+	 *
+	 * @param key The key to be unregistered.
+	 * @return {@code true} if the key was registered, {@code false} otherwise.
+	 */
+	public boolean unregisterKeyFromWatermarkCallback(K key) {
+		int localIdx = getIndexForKey(key);
+		Set<K> keys = keysByKeygroup[localIdx];
+		if (keys == null) {
+			// no key of this key-group is registered; avoid allocating a set just to remove from it
+			return false;
+		}
+
+		boolean removed = keys.remove(key);
+		if (keys.isEmpty()) {
+			keysByKeygroup[localIdx] = null;
+		}
+		return removed;
+	}
+
+	/**
+	 * Invokes the registered callback for all the registered keys, setting each key as the
+	 * current key of the key context first. Does nothing if no callback has been set.
+	 *
+	 * <p>The callback may register and unregister keys. A key unregistered during an invocation
+	 * round is not called back for the rest of that round; a key registered during the round is
+	 * only called back in that round if its key-group comes after the one being processed.
+	 *
+	 * @param watermark The watermark that triggered the invocation.
+	 */
+	public void invokeOnWatermarkCallback(Watermark watermark) throws IOException {
+		if (callback == null) {
+			return;
+		}
+
+		for (int localIdx = 0; localIdx < keysByKeygroup.length; localIdx++) {
+			Set<K> keySet = keysByKeygroup[localIdx];
+			if (keySet == null) {
+				continue;
+			}
+
+			// iterate over a copy: the callback may (un)register keys, which would otherwise make
+			// the iteration fail with a ConcurrentModificationException
+			for (K key : new ArrayList<>(keySet)) {
+				Set<K> currentKeySet = keysByKeygroup[localIdx];
+				if (currentKeySet != null && currentKeySet.contains(key)) {
+					keyContext.setCurrentKey(key);
+					callback.onWatermark(key, watermark);
+				}
+			}
+		}
+	}
+
+	/**
+	 * Returns the set of keys stored at the given local index, creating it if it does not exist.
+	 *
+	 * @param localIdx the index of the key-group in {@link #keysByKeygroup}.
+	 * @return the set of registered keys for the key-group.
+	 */
+	private Set<K> getOrCreateKeySet(int localIdx) {
+		Set<K> keys = keysByKeygroup[localIdx];
+		if (keys == null) {
+			keys = new HashSet<>();
+			keysByKeygroup[localIdx] = keys;
+		}
+		return keys;
+	}
+
+	/**
+	 * Computes the index in the local datastructures of the key-group the given key belongs to.
+	 */
+	private int getIndexForKey(K key) {
+		return getIndexForKeyGroup(KeyGroupRangeAssignment.assignToKeyGroup(key, totalKeyGroups));
+	}
+
+	/**
+	 * Computes the index of the requested key-group in the local datastructures.
+	 *
+	 * <p>Currently we assume that each task is assigned a continuous range of key-groups,
+	 * e.g. 1,2,3,4, and not 1,3,5. We leverage this to keep the different states
+	 * key-grouped in arrays instead of maps, where the offset for each key-group is
+	 * the key-group id (an int) minus the id of the first key-group in the local range.
+	 * This is for performance reasons.
+	 *
+	 * @throws IllegalArgumentException If the key-group does not belong to the local range.
+	 */
+	private int getIndexForKeyGroup(int keyGroupIdx) {
+		checkArgument(localKeyGroupRange.contains(keyGroupIdx),
+			"Key Group " + keyGroupIdx + " does not belong to the local range.");
+		return keyGroupIdx - localKeyGroupRangeStartIdx;
+	}
+
+	//////////////////				Fault Tolerance Methods				///////////////////
+
+	/**
+	 * Writes the keys registered for the given key-group: their number and, only if there is at
+	 * least one key, the key serializer followed by the serialized keys.
+	 *
+	 * @throws IllegalStateException If keys are registered but no key serializer is known.
+	 */
+	public void snapshotKeysForKeyGroup(DataOutputViewStreamWrapper stream, int keyGroupIdx) throws Exception {
+		// read the slot directly, so that snapshotting does not allocate sets for empty key-groups
+		Set<K> keySet = keysByKeygroup[getIndexForKeyGroup(keyGroupIdx)];
+		if (keySet == null || keySet.isEmpty()) {
+			// restoreKeysForKeyGroup reads nothing after a zero count, so nothing else may follow
+			stream.writeInt(0);
+			return;
+		}
+
+		checkState(keySerializer != null, "No key serializer is available to snapshot the registered keys.");
+		stream.writeInt(keySet.size());
+		InstantiationUtil.serializeObject(stream, keySerializer);
+		for (K key : keySet) {
+			keySerializer.serialize(key, stream);
+		}
+	}
+
+	/**
+	 * Restores the keys written by {@link #snapshotKeysForKeyGroup(DataOutputViewStreamWrapper, int)}
+	 * for the given key-group and registers them with this service.
+	 *
+	 * @throws IllegalArgumentException If the key-group is not local, or if the restored key serializer
+	 *                                  differs from the one already known to this service.
+	 */
+	public void restoreKeysForKeyGroup(DataInputViewStreamWrapper stream, int keyGroupIdx,
+									ClassLoader userCodeClassLoader) throws IOException, ClassNotFoundException {
+
+		int localIdx = getIndexForKeyGroup(keyGroupIdx);
+
+		int numKeys = stream.readInt();
+		if (numKeys <= 0) {
+			return;
+		}
+
+		TypeSerializer<K> tmpKeyDeserializer = InstantiationUtil.deserializeObject(stream, userCodeClassLoader);
+		if (keySerializer != null && !keySerializer.equals(tmpKeyDeserializer)) {
+			throw new IllegalArgumentException("Tried to restore keys " +
+				"for the same service with different serializers.");
+		}
+		this.keySerializer = tmpKeyDeserializer;
+
+		Set<K> keys = getOrCreateKeySet(localIdx);
+		for (int i = 0; i < numKeys; i++) {
+			keys.add(keySerializer.deserialize(stream));
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/daf0ccda/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OnWatermarkCallback.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OnWatermarkCallback.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OnWatermarkCallback.java
new file mode 100644
index 0000000..bc317a9
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/OnWatermarkCallback.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.streaming.api.watermark.Watermark;
+
+import java.io.IOException;
+
+/**
+ * A callback registered with an {@link InternalWatermarkCallbackService}. Upon reception of a
+ * {@link Watermark}, the service invokes it once for each key registered with the service.
+ */
+@Internal
+public interface OnWatermarkCallback<KEY> {
+
+	/**
+	 * Invoked for one registered key, after the service has made it the current key of its key context.
+	 *
+	 * @param key The registered key the callback is invoked for.
+	 * @param watermark The watermark that triggered the invocation.
+	 * @throws IOException Propagated to the caller of the service's callback invocation.
+	 */
+	void onWatermark(KEY key, Watermark watermark) throws IOException;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/daf0ccda/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java
index 8507200..33def9e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/AbstractStreamOperatorTest.java
@@ -34,12 +34,14 @@ import static org.powermock.api.mockito.PowerMockito.spy;
 import static org.powermock.api.mockito.PowerMockito.when;
 import static org.powermock.api.mockito.PowerMockito.whenNew;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Random;
 import java.util.concurrent.RunnableFuture;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
@@ -55,12 +57,14 @@ import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.junit.Assert;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.internal.util.reflection.Whitebox;
@@ -610,6 +614,208 @@ public class AbstractStreamOperatorTest {
 		verify(futureKeyGroupStateHandle).cancel(anyBoolean());
 	}
 
+	/**
+	 * Verifies that keys registered with the watermark callback service survive rescaling from one
+	 * task to two: after restoring, each new task invokes the callback (which emits the key) only for
+	 * the key whose key group it is assigned, followed by the forwarded watermark.
+	 */
+	@Test
+	public void testWatermarkCallbackServiceScalingUp() throws Exception {
+		final int MAX_PARALLELISM = 10;
+
+		KeySelector<Tuple2<Integer, String>, Integer> keySelector = new TestKeySelector();
+
+		Tuple2<Integer, String> element1 = new Tuple2<>(7, "first");
+		Tuple2<Integer, String> element2 = new Tuple2<>(10, "start");
+
+		// sanity-check the key-group layout that the assertions below depend on
+		int keygroup = KeyGroupRangeAssignment.assignToKeyGroup(keySelector.getKey(element1), MAX_PARALLELISM);
+		assertEquals(1, keygroup);
+		assertEquals(0, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(MAX_PARALLELISM, 2, keygroup));
+
+		keygroup = KeyGroupRangeAssignment.assignToKeyGroup(keySelector.getKey(element2), MAX_PARALLELISM);
+		assertEquals(9, keygroup);
+		assertEquals(1, KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(MAX_PARALLELISM, 2, keygroup));
+
+		// now we start the test, we go from parallelism 1 to 2.
+
+		KeyedOneInputStreamOperatorTestHarness<Integer, Tuple2<Integer, String>, Integer> testHarness1 =
+			getTestHarness(MAX_PARALLELISM, 1, 0);
+		testHarness1.open();
+
+		// processing an element only registers its key; nothing is emitted until a watermark arrives
+		testHarness1.processElement(new StreamRecord<>(element1));
+		testHarness1.processElement(new StreamRecord<>(element2));
+
+		assertEquals(0, testHarness1.getOutput().size());
+
+		// take a snapshot while both keys are registered for the watermark callback
+		OperatorStateHandles snapshot = testHarness1.snapshot(0, 0);
+		testHarness1.close();
+
+		// initialize two sub-tasks with the previously snapshotted state to simulate scaling up
+
+		KeyedOneInputStreamOperatorTestHarness<Integer, Tuple2<Integer, String>, Integer> testHarness2 =
+			getTestHarness(MAX_PARALLELISM, 2, 0);
+
+		testHarness2.setup();
+		testHarness2.initializeState(snapshot);
+		testHarness2.open();
+
+		KeyedOneInputStreamOperatorTestHarness<Integer, Tuple2<Integer, String>, Integer> testHarness3 =
+			getTestHarness(MAX_PARALLELISM, 2, 1);
+
+		testHarness3.setup();
+		testHarness3.initializeState(snapshot);
+		testHarness3.open();
+
+		testHarness2.processWatermark(new Watermark(10));
+		testHarness3.processWatermark(new Watermark(10));
+
+		assertEquals(2, testHarness2.getOutput().size());
+		verifyElement(testHarness2.getOutput().poll(), 7);
+		verifyWatermark(testHarness2.getOutput().poll(), 10);
+
+		assertEquals(2, testHarness3.getOutput().size());
+		verifyElement(testHarness3.getOutput().poll(), 10);
+		verifyWatermark(testHarness3.getOutput().poll(), 10);
+
+		// testHarness1 was already closed right after taking the snapshot; don't close it twice
+		testHarness2.close();
+		testHarness3.close();
+	}
+
+	/**
+	 * Verifies that keys registered with the watermark callback service on three parallel tasks are
+	 * redistributed by key group when their merged state is restored into two tasks, and that each
+	 * restored task invokes the callback for its keys on the next watermark.
+	 */
+	@Test
+	public void testWatermarkCallbackServiceScalingDown() throws Exception {
+		final int maxParallelism = 10;
+
+		final KeySelector<Tuple2<Integer, String>, Integer> selector = new TestKeySelector();
+
+		final Tuple2<Integer, String> keyed7 = new Tuple2<>(7, "first");
+		final Tuple2<Integer, String> keyed45 = new Tuple2<>(45, "start");
+		final Tuple2<Integer, String> keyed90 = new Tuple2<>(90, "start");
+		final Tuple2<Integer, String> keyed10 = new Tuple2<>(10, "start");
+
+		// sanity-check the key-group layout the assertions below depend on;
+		// trailing arguments: expected key group, index at parallelism 3, index at parallelism 2
+		assertKeyGroupPlacement(selector, keyed7, maxParallelism, 1, 0, 0);
+		assertKeyGroupPlacement(selector, keyed45, maxParallelism, 6, 1, 1);
+		assertKeyGroupPlacement(selector, keyed90, maxParallelism, 2, 0, 0);
+		assertKeyGroupPlacement(selector, keyed10, maxParallelism, 9, 2, 1);
+
+		// the three subtasks before scaling down
+		KeyedOneInputStreamOperatorTestHarness<Integer, Tuple2<Integer, String>, Integer> oldSubtask0 =
+			getTestHarness(maxParallelism, 3, 0);
+		oldSubtask0.open();
+
+		KeyedOneInputStreamOperatorTestHarness<Integer, Tuple2<Integer, String>, Integer> oldSubtask1 =
+			getTestHarness(maxParallelism, 3, 1);
+		oldSubtask1.open();
+
+		KeyedOneInputStreamOperatorTestHarness<Integer, Tuple2<Integer, String>, Integer> oldSubtask2 =
+			getTestHarness(maxParallelism, 3, 2);
+		oldSubtask2.open();
+
+		oldSubtask0.processWatermark(Long.MIN_VALUE);
+		oldSubtask1.processWatermark(Long.MIN_VALUE);
+		oldSubtask2.processWatermark(Long.MIN_VALUE);
+
+		// each element goes to the subtask that its key group maps to at parallelism 3
+		oldSubtask0.processElement(new StreamRecord<>(keyed7));
+		oldSubtask0.processElement(new StreamRecord<>(keyed90));
+		oldSubtask1.processElement(new StreamRecord<>(keyed45));
+		oldSubtask2.processElement(new StreamRecord<>(keyed10));
+
+		// only the initial watermark has been emitted so far
+		assertEquals(1, oldSubtask0.getOutput().size());
+		verifyWatermark(oldSubtask0.getOutput().poll(), Long.MIN_VALUE);
+
+		assertEquals(1, oldSubtask1.getOutput().size());
+		verifyWatermark(oldSubtask1.getOutput().poll(), Long.MIN_VALUE);
+
+		assertEquals(1, oldSubtask2.getOutput().size());
+		verifyWatermark(oldSubtask2.getOutput().poll(), Long.MIN_VALUE);
+
+		// merge the three snapshots so they look like the state of a single operator;
+		// this merged state is the initial state of every rescaled subtask
+		OperatorStateHandles snapshotOfSubtask1 = oldSubtask1.snapshot(0, 0);
+		OperatorStateHandles snapshotOfSubtask0 = oldSubtask0.snapshot(0, 0);
+		OperatorStateHandles snapshotOfSubtask2 = oldSubtask2.snapshot(0, 0);
+		OperatorStateHandles merged = AbstractStreamOperatorTestHarness.repackageState(
+			snapshotOfSubtask1, snapshotOfSubtask0, snapshotOfSubtask2);
+
+		// the two subtasks after scaling down
+		KeyedOneInputStreamOperatorTestHarness<Integer, Tuple2<Integer, String>, Integer> newSubtask0 =
+			getTestHarness(maxParallelism, 2, 0);
+		newSubtask0.setup();
+		newSubtask0.initializeState(merged);
+		newSubtask0.open();
+
+		KeyedOneInputStreamOperatorTestHarness<Integer, Tuple2<Integer, String>, Integer> newSubtask1 =
+			getTestHarness(maxParallelism, 2, 1);
+		newSubtask1.setup();
+		newSubtask1.initializeState(merged);
+		newSubtask1.open();
+
+		newSubtask0.processWatermark(10);
+		newSubtask1.processWatermark(10);
+
+		// key groups 1 and 2 (keys 7, 90) map to new subtask 0; key groups 6 and 9 (keys 45, 10) to subtask 1
+		assertEquals(3, newSubtask0.getOutput().size());
+		verifyElement(newSubtask0.getOutput().poll(), 7);
+		verifyElement(newSubtask0.getOutput().poll(), 90);
+		verifyWatermark(newSubtask0.getOutput().poll(), 10);
+
+		assertEquals(3, newSubtask1.getOutput().size());
+		verifyElement(newSubtask1.getOutput().poll(), 45);
+		verifyElement(newSubtask1.getOutput().poll(), 10);
+		verifyWatermark(newSubtask1.getOutput().poll(), 10);
+
+		oldSubtask0.close();
+		oldSubtask1.close();
+		oldSubtask2.close();
+		newSubtask0.close();
+		newSubtask1.close();
+	}
+
+	/**
+	 * Asserts the key group of {@code element}'s key, and the operator index that this key group is
+	 * assigned to at parallelism 3 and at parallelism 2.
+	 */
+	private static void assertKeyGroupPlacement(
+			KeySelector<Tuple2<Integer, String>, Integer> selector,
+			Tuple2<Integer, String> element,
+			int maxParallelism,
+			int expectedKeyGroup,
+			int expectedIndexAtParallelismThree,
+			int expectedIndexAtParallelismTwo) throws Exception {
+
+		int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(selector.getKey(element), maxParallelism);
+		assertEquals(expectedKeyGroup, keyGroup);
+		assertEquals(expectedIndexAtParallelismThree,
+			KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 3, keyGroup));
+		assertEquals(expectedIndexAtParallelismTwo,
+			KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, 2, keyGroup));
+	}
+
+	/**
+	 * Creates a keyed test harness around a fresh {@link TestOperatorWithCallback}, keyed by the
+	 * integer field of the input tuple.
+	 *
+	 * @param maxParallelism the maximum parallelism (number of key groups)
+	 * @param noOfTasks the number of parallel subtasks
+	 * @param taskIdx the index of the subtask this harness simulates
+	 */
+	private KeyedOneInputStreamOperatorTestHarness<Integer, Tuple2<Integer, String>, Integer> getTestHarness(
+			int maxParallelism, int noOfTasks, int taskIdx) throws Exception {
+
+		TestOperatorWithCallback operator = new TestOperatorWithCallback();
+		KeySelector<Tuple2<Integer, String>, Integer> keySelector = new TestKeySelector();
+
+		return new KeyedOneInputStreamOperatorTestHarness<>(
+			operator, keySelector, BasicTypeInfo.INT_TYPE_INFO, maxParallelism, noOfTasks, taskIdx);
+	}
+
+	private void verifyWatermark(Object outputObject, long timestamp) {
+		Assert.assertTrue(outputObject instanceof Watermark);
+		assertEquals(timestamp, ((Watermark) outputObject).getTimestamp());
+	}
+
+	/**
+	 * Asserts that {@code outputObject} is a {@link StreamRecord} whose value is the given integer.
+	 *
+	 * @param outputObject an element taken from a test harness output queue
+	 * @param expected the expected record value
+	 */
+	private void verifyElement(Object outputObject, int expected) {
+		Assert.assertTrue("Expected a StreamRecord but got: " + outputObject,
+			outputObject instanceof StreamRecord);
+
+		StreamRecord<?> resultRecord = (StreamRecord<?>) outputObject;
+		Object value = resultRecord.getValue();
+		Assert.assertTrue("Expected an Integer record value but got: " + value, value instanceof Integer);
+
+		// a plain checked cast from Object, so no @SuppressWarnings("unchecked") is needed
+		int actual = (Integer) value;
+		assertEquals(expected, actual);
+	}
+
 	/**
 	 * Extracts the result values form the test harness and clear the output queue.
 	 */
@@ -626,7 +832,6 @@ public class AbstractStreamOperatorTest {
 		return result;
 	}
 
-
 	private static class TestKeySelector implements KeySelector<Tuple2<Integer, String>, Integer> {
 		private static final long serialVersionUID = 1L;
 
@@ -636,6 +841,33 @@ public class AbstractStreamOperatorTest {
 		}
 	}
 
+	/**
+	 * Test operator that registers the key (field {@code f0}) of every incoming element with the
+	 * watermark callback service, and whose callback emits each registered key on every watermark.
+	 */
+	private static class TestOperatorWithCallback
+			extends AbstractStreamOperator<Integer>
+			implements OneInputStreamOperator<Tuple2<Integer, String>, Integer> {
+
+		private static final long serialVersionUID = 9215057823264582305L;
+
+		@Override
+		public void open() throws Exception {
+			super.open();
+
+			final InternalWatermarkCallbackService<Integer> service = getInternalWatermarkCallbackService();
+
+			// forwards every registered key downstream when a watermark arrives
+			final OnWatermarkCallback<Integer> emitKey = new OnWatermarkCallback<Integer>() {
+
+				@Override
+				public void onWatermark(Integer key, Watermark watermark) throws IOException {
+					output.collect(new StreamRecord<>(key));
+				}
+			};
+
+			service.setWatermarkCallback(emitKey, IntSerializer.INSTANCE);
+		}
+
+		@Override
+		public void processElement(StreamRecord<Tuple2<Integer, String>> element) throws Exception {
+			// only registers the key; output is produced by the callback on the next watermark
+			final InternalWatermarkCallbackService<Integer> service = getInternalWatermarkCallbackService();
+			service.registerKeyForWatermarkCallback(element.getValue().f0);
+		}
+	}
+
 	/**
 	 * Testing operator that can respond to commands by either setting/deleting state, emitting
 	 * state or setting timers.