Posted to commits@flink.apache.org by al...@apache.org on 2016/12/13 12:44:18 UTC

[1/4] flink git commit: [FLINK-5163] Port the StatefulSequenceSource to the new state abstractions.

Repository: flink
Updated Branches:
  refs/heads/master 369837971 -> 81d8fe16a


[FLINK-5163] Port the StatefulSequenceSource to the new state abstractions.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/81d8fe16
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/81d8fe16
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/81d8fe16

Branch: refs/heads/master
Commit: 81d8fe16a04a7826ba72c89fdc98fa50e3f86f5e
Parents: 956ffa6
Author: kl0u <kk...@gmail.com>
Authored: Mon Nov 21 18:50:30 2016 +0100
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Tue Dec 13 13:38:18 2016 +0100

----------------------------------------------------------------------
 .../source/StatefulSequenceSource.java          |  98 ++++++--
 .../flink/streaming/api/SourceFunctionTest.java |   8 -
 .../functions/StatefulSequenceSourceTest.java   | 242 +++++++++++++++++++
 3 files changed, 315 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/81d8fe16/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
index 563f6ef..bdb12f3 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/StatefulSequenceSource.java
@@ -1,4 +1,4 @@
-/**
+/*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
@@ -18,25 +18,42 @@
 package org.apache.flink.streaming.api.functions.source;
 
 import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.api.common.functions.RuntimeContext;
-import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
 
 /**
  * A stateful streaming source that emits each number from a given interval exactly once,
  * possibly in parallel.
+ *
+ * <p>To make the source re-scalable, the first time the job is run we precompute all the elements
+ * that each task should emit. Upon checkpointing, each element constitutes its own partition;
+ * when rescaling, these partitions are randomly re-assigned to the new tasks.
+ *
+ * <p>This strategy guarantees that each element will be emitted exactly once, but elements will
+ * not necessarily be emitted in ascending order, even within the same task.
  */
 @PublicEvolving
-public class StatefulSequenceSource extends RichParallelSourceFunction<Long> implements Checkpointed<Long> {
+public class StatefulSequenceSource extends RichParallelSourceFunction<Long> implements CheckpointedFunction {
 	
 	private static final long serialVersionUID = 1L;
 
 	private final long start;
 	private final long end;
 
-	private long collected;
-
 	private volatile boolean isRunning = true;
 
+	private transient Deque<Long> valuesToEmit;
+
+	private transient ListState<Long> checkpointedState;
+
 	/**
 	 * Creates a source that emits all numbers from the given interval exactly once.
 	 *
@@ -49,24 +66,47 @@ public class StatefulSequenceSource extends RichParallelSourceFunction<Long> imp
 	}
 
 	@Override
-	public void run(SourceContext<Long> ctx) throws Exception {
-		final Object checkpointLock = ctx.getCheckpointLock();
+	public void initializeState(FunctionInitializationContext context) throws Exception {
 
-		RuntimeContext context = getRuntimeContext();
+		Preconditions.checkState(this.checkpointedState == null,
+			"The " + getClass().getSimpleName() + " has already been initialized.");
 
-		final long stepSize = context.getNumberOfParallelSubtasks();
-		final long congruence = start + context.getIndexOfThisSubtask();
+		this.checkpointedState = context.getOperatorStateStore().getOperatorState(
+			new ListStateDescriptor<>(
+				"stateful-sequence-source-state",
+				LongSerializer.INSTANCE
+			)
+		);
 
-		final long toCollect =
-				((end - start + 1) % stepSize > (congruence - start)) ?
-					((end - start + 1) / stepSize + 1) :
-					((end - start + 1) / stepSize);
-		
+		this.valuesToEmit = new ArrayDeque<>();
+		if (context.isRestored()) {
+			// upon restoring
 
-		while (isRunning && collected < toCollect) {
-			synchronized (checkpointLock) {
-				ctx.collect(collected * stepSize + congruence);
-				collected++;
+			for (Long v : this.checkpointedState.get()) {
+				this.valuesToEmit.add(v);
+			}
+		} else {
+			// the first time the job is executed
+
+			final int stepSize = getRuntimeContext().getNumberOfParallelSubtasks();
+			final int taskIdx = getRuntimeContext().getIndexOfThisSubtask();
+			final long congruence = start + taskIdx;
+
+			long totalNoOfElements = Math.abs(end - start + 1);
+			final int baseSize = safeDivide(totalNoOfElements, stepSize);
+			final int toCollect = (totalNoOfElements % stepSize > taskIdx) ? baseSize + 1 : baseSize;
+
+			for (long collected = 0; collected < toCollect; collected++) {
+				this.valuesToEmit.add(collected * stepSize + congruence);
+			}
+		}
+	}
+
+	@Override
+	public void run(SourceContext<Long> ctx) throws Exception {
+		while (isRunning && !this.valuesToEmit.isEmpty()) {
+			synchronized (ctx.getCheckpointLock()) {
+				ctx.collect(this.valuesToEmit.poll());
 			}
 		}
 	}
@@ -77,12 +117,20 @@ public class StatefulSequenceSource extends RichParallelSourceFunction<Long> imp
 	}
 
 	@Override
-	public Long snapshotState(long checkpointId, long checkpointTimestamp) {
-		return collected;
+	public void snapshotState(FunctionSnapshotContext context) throws Exception {
+		Preconditions.checkState(this.checkpointedState != null,
+			"The " + getClass().getSimpleName() + " state has not been properly initialized.");
+
+		this.checkpointedState.clear();
+		for (Long v : this.valuesToEmit) {
+			this.checkpointedState.add(v);
+		}
 	}
 
-	@Override
-	public void restoreState(Long state) {
-		collected = state;
+	private static int safeDivide(long left, long right) {
+		Preconditions.checkArgument(right > 0);
+		Preconditions.checkArgument(left >= 0);
+		Preconditions.checkArgument(left <= Integer.MAX_VALUE * right);
+		return (int) (left / right);
 	}
 }
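
For illustration, a minimal standalone sketch (not part of the commit) of the interval-splitting
logic that the new initializeState() applies the first time the job runs: subtask i emits every
stepSize-th value starting at start + i, so each value belongs to exactly one subtask. The
interval and parallelism below are made-up values.

    public class SequencePartitioningSketch {
        public static void main(String[] args) {
            final long start = 0, end = 10;  // assumed interval, for illustration only
            final int stepSize = 3;          // assumed parallelism
            for (int taskIdx = 0; taskIdx < stepSize; taskIdx++) {
                final long congruence = start + taskIdx;
                final long total = Math.abs(end - start + 1);
                final long baseSize = total / stepSize;
                final long toCollect = (total % stepSize > taskIdx) ? baseSize + 1 : baseSize;
                StringBuilder sb = new StringBuilder("subtask " + taskIdx + " emits:");
                for (long collected = 0; collected < toCollect; collected++) {
                    sb.append(' ').append(collected * stepSize + congruence);
                }
                System.out.println(sb);
            }
            // prints:
            // subtask 0 emits: 0 3 6 9
            // subtask 1 emits: 1 4 7 10
            // subtask 2 emits: 2 5 8
        }
    }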

http://git-wip-us.apache.org/repos/asf/flink/blob/81d8fe16/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java
index 946b474..dd4ff33 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/SourceFunctionTest.java
@@ -52,12 +52,4 @@ public class SourceFunctionTest {
 						Arrays.asList(1, 2, 3))));
 		assertEquals(expectedList, actualList);
 	}
-
-	@Test
-	public void generateSequenceTest() throws Exception {
-		List<Long> expectedList = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L);
-		List<Long> actualList = SourceFunctionUtil.runSourceFunction(new StatefulSequenceSource(1,
-				7));
-		assertEquals(expectedList, actualList);
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/81d8fe16/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java
new file mode 100644
index 0000000..8332cb3
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/StatefulSequenceSourceTest.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.functions;
+
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.functions.source.StatefulSequenceSource;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+
+public class StatefulSequenceSourceTest {
+
+	@Test
+	public void testCheckpointRestore() throws Exception {
+		final int initElement = 0;
+		final int maxElement = 100;
+
+		final Set<Long> expectedOutput = new HashSet<>();
+		for (long i = initElement; i <= maxElement; i++) {
+			expectedOutput.add(i);
+		}
+
+		final ConcurrentHashMap<String, List<Long>> outputCollector = new ConcurrentHashMap<>();
+		final OneShotLatch latchToTrigger1 = new OneShotLatch();
+		final OneShotLatch latchToWait1 = new OneShotLatch();
+		final OneShotLatch latchToTrigger2 = new OneShotLatch();
+		final OneShotLatch latchToWait2 = new OneShotLatch();
+
+		final StatefulSequenceSource source1 = new StatefulSequenceSource(initElement, maxElement);
+		StreamSource<Long, StatefulSequenceSource> src1 = new StreamSource<>(source1);
+
+		final AbstractStreamOperatorTestHarness<Long> testHarness1 =
+			new AbstractStreamOperatorTestHarness<>(src1, 2, 2, 0);
+		testHarness1.open();
+
+		final StatefulSequenceSource source2 = new StatefulSequenceSource(initElement, maxElement);
+		StreamSource<Long, StatefulSequenceSource> src2 = new StreamSource<>(source2);
+
+		final AbstractStreamOperatorTestHarness<Long> testHarness2 =
+			new AbstractStreamOperatorTestHarness<>(src2, 2, 2, 1);
+		testHarness2.open();
+
+		final Throwable[] error = new Throwable[3];
+
+		// run the source asynchronously
+		Thread runner1 = new Thread() {
+			@Override
+			public void run() {
+				try {
+					source1.run(new BlockingSourceContext("1", latchToTrigger1, latchToWait1, outputCollector, 21));
+				}
+				catch (Throwable t) {
+					t.printStackTrace();
+					error[0] = t;
+				}
+			}
+		};
+
+		// run the source asynchronously
+		Thread runner2 = new Thread() {
+			@Override
+			public void run() {
+				try {
+					source2.run(new BlockingSourceContext("2", latchToTrigger2, latchToWait2, outputCollector, 32));
+				}
+				catch (Throwable t) {
+					t.printStackTrace();
+					error[1] = t;
+				}
+			}
+		};
+
+		runner1.start();
+		runner2.start();
+
+		if (!latchToTrigger1.isTriggered()) {
+			latchToTrigger1.await();
+		}
+
+		if (!latchToTrigger2.isTriggered()) {
+			latchToTrigger2.await();
+		}
+
+		OperatorStateHandles snapshot = AbstractStreamOperatorTestHarness.repackageState(
+			testHarness1.snapshot(0L, 0L),
+			testHarness2.snapshot(0L, 0L)
+		);
+
+		final StatefulSequenceSource source3 = new StatefulSequenceSource(initElement, maxElement);
+		StreamSource<Long, StatefulSequenceSource> src3 = new StreamSource<>(source3);
+
+		final AbstractStreamOperatorTestHarness<Long> testHarness3 =
+			new AbstractStreamOperatorTestHarness<>(src3, 2, 1, 0);
+		testHarness3.setup();
+		testHarness3.initializeState(snapshot);
+		testHarness3.open();
+
+		final OneShotLatch latchToTrigger3 = new OneShotLatch();
+		final OneShotLatch latchToWait3 = new OneShotLatch();
+		latchToWait3.trigger();
+
+		// run the source asynchronously
+		Thread runner3 = new Thread() {
+			@Override
+			public void run() {
+				try {
+					source3.run(new BlockingSourceContext("3", latchToTrigger3, latchToWait3, outputCollector, 3));
+				}
+				catch (Throwable t) {
+					t.printStackTrace();
+					error[2] = t;
+				}
+			}
+		};
+		runner3.start();
+		runner3.join();
+
+		Assert.assertEquals(3, outputCollector.size()); // we have 3 tasks.
+
+		// test for at-most-once
+		Set<Long> dedupRes = new HashSet<>(Math.abs(maxElement - initElement) + 1);
+		for (Map.Entry<String, List<Long>> elementsPerTask: outputCollector.entrySet()) {
+			String key = elementsPerTask.getKey();
+			List<Long> elements = outputCollector.get(key);
+
+			// this tests the correctness of the latches in the test
+			Assert.assertTrue(elements.size() > 0);
+
+			for (Long elem : elements) {
+				if (!dedupRes.add(elem)) {
+					Assert.fail("Duplicate entry: " + elem);
+				}
+
+				if (!expectedOutput.contains(elem)) {
+					Assert.fail("Unexpected element: " + elem);
+				}
+			}
+		}
+
+		// test for exactly-once
+		Assert.assertEquals(Math.abs(initElement - maxElement) + 1, dedupRes.size());
+
+		latchToWait1.trigger();
+		latchToWait2.trigger();
+
+		// wait for everybody to finish.
+		runner1.join();
+		runner2.join();
+	}
+
+	private static class BlockingSourceContext implements SourceFunction.SourceContext<Long> {
+
+		private final String name;
+
+		private final Object lock;
+		private final OneShotLatch latchToTrigger;
+		private final OneShotLatch latchToWait;
+		private final ConcurrentHashMap<String, List<Long>> collector;
+
+		private final int threshold;
+		private int counter = 0;
+
+		private final List<Long> localOutput;
+
+		public BlockingSourceContext(String name, OneShotLatch latchToTrigger, OneShotLatch latchToWait,
+									 ConcurrentHashMap<String, List<Long>> output, int elemToFire) {
+			this.name = name;
+			this.lock = new Object();
+			this.latchToTrigger = latchToTrigger;
+			this.latchToWait = latchToWait;
+			this.collector = output;
+			this.threshold = elemToFire;
+
+			this.localOutput = new ArrayList<>();
+			List<Long> prev = collector.put(name, localOutput);
+			if (prev != null) {
+				Assert.fail();
+			}
+		}
+
+		@Override
+		public void collectWithTimestamp(Long element, long timestamp) {
+			collect(element);
+		}
+
+		@Override
+		public void collect(Long element) {
+			localOutput.add(element);
+			if (++counter == threshold) {
+				latchToTrigger.trigger();
+				try {
+					if (!latchToWait.isTriggered()) {
+						latchToWait.await();
+					}
+				} catch (InterruptedException e) {
+					e.printStackTrace();
+				}
+			}
+		}
+
+
+		@Override
+		public void emitWatermark(Watermark mark) {
+		}
+
+		@Override
+		public Object getCheckpointLock() {
+			return lock;
+		}
+
+		@Override
+		public void close() {
+		}
+	}
+}
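
The same migration pattern recurs in all four commits: replace Checkpointed<T> with
CheckpointedFunction, move restore logic into initializeState(), and rewrite snapshotState()
to clear-and-refill a ListState. A minimal sketch of that pattern, assuming the Flink 1.2-era
APIs shown in the diffs above (the CountingSource class itself is illustrative, not from the
commit):

    import org.apache.flink.api.common.state.ListState;
    import org.apache.flink.api.common.state.ListStateDescriptor;
    import org.apache.flink.api.common.typeutils.base.LongSerializer;
    import org.apache.flink.runtime.state.FunctionInitializationContext;
    import org.apache.flink.runtime.state.FunctionSnapshotContext;
    import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
    import org.apache.flink.streaming.api.functions.source.SourceFunction;

    public class CountingSource implements SourceFunction<Long>, CheckpointedFunction {

        private long count;
        private volatile boolean isRunning = true;
        private transient ListState<Long> checkpointedState;

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            checkpointedState = context.getOperatorStateStore().getOperatorState(
                new ListStateDescriptor<>("count-state", LongSerializer.INSTANCE));
            if (context.isRestored()) {
                for (Long c : checkpointedState.get()) {
                    count += c;  // merge the partial counts of the previous tasks
                }
            }
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            checkpointedState.clear();
            checkpointedState.add(count);
        }

        @Override
        public void run(SourceContext<Long> ctx) throws Exception {
            while (isRunning) {
                synchronized (ctx.getCheckpointLock()) {
                    ctx.collect(count++);
                }
            }
        }

        @Override
        public void cancel() {
            isRunning = false;
        }
    }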


[2/4] flink git commit: [FLINK-5163] Port the FromElementsFunction to the new state abstractions.

Posted by al...@apache.org.
[FLINK-5163] Port the FromElementsFunction to the new state abstractions.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/d24833db
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/d24833db
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/d24833db

Branch: refs/heads/master
Commit: d24833dbb641979008ba18be8b2ddb694a5d43e1
Parents: 685c4f8
Author: kl0u <kk...@gmail.com>
Authored: Thu Nov 17 16:52:50 2016 +0100
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Tue Dec 13 13:38:18 2016 +0100

----------------------------------------------------------------------
 .../functions/source/FromElementsFunction.java  | 61 ++++++++++++++++----
 .../api/functions/FromElementsFunctionTest.java | 35 +++++++----
 2 files changed, 72 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/d24833db/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/FromElementsFunction.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/FromElementsFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/FromElementsFunction.java
index 98bc10e..e3a9d54 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/FromElementsFunction.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/FromElementsFunction.java
@@ -18,17 +18,25 @@
 package org.apache.flink.streaming.api.functions.source;
 
 import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.util.Preconditions;
 
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.List;
 
 /**
  * A stream source function that returns a sequence of elements.
@@ -36,11 +44,14 @@ import java.util.Collection;
  * <p>Upon construction, this source function serializes the elements using Flink's type information.
  * That way, any object transport using Java serialization will not be affected by the serializability
  * of the elements.</p>
- * 
+ *
+ * <p>
+ * <b>NOTE:</b> This source has a parallelism of 1.
+ *
  * @param <T> The type of elements returned by this function.
  */
 @PublicEvolving
-public class FromElementsFunction<T> implements SourceFunction<T>, CheckpointedAsynchronously<Integer> {
+public class FromElementsFunction<T> implements SourceFunction<T>, CheckpointedFunction {
 	
 	private static final long serialVersionUID = 1L;
 
@@ -62,7 +73,8 @@ public class FromElementsFunction<T> implements SourceFunction<T>, CheckpointedA
 	/** Flag to make the source cancelable */
 	private volatile boolean isRunning = true;
 
-	
+	private transient ListState<Integer> checkpointedState;
+
 	public FromElementsFunction(TypeSerializer<T> serializer, T... elements) throws IOException {
 		this(serializer, Arrays.asList(elements));
 	}
@@ -88,6 +100,32 @@ public class FromElementsFunction<T> implements SourceFunction<T>, CheckpointedA
 	}
 
 	@Override
+	public void initializeState(FunctionInitializationContext context) throws Exception {
+		Preconditions.checkState(this.checkpointedState == null,
+			"The " + getClass().getSimpleName() + " has already been initialized.");
+
+		this.checkpointedState = context.getOperatorStateStore().getOperatorState(
+			new ListStateDescriptor<>(
+				"from-elements-state",
+				IntSerializer.INSTANCE
+			)
+		);
+
+		if (context.isRestored()) {
+			List<Integer> retrievedStates = new ArrayList<>();
+			for (Integer entry : this.checkpointedState.get()) {
+				retrievedStates.add(entry);
+			}
+
+			// given that the parallelism of the function is 1, we can only have 1 state
+			Preconditions.checkArgument(retrievedStates.size() == 1,
+				getClass().getSimpleName() + " retrieved invalid state.");
+
+			this.numElementsToSkip = retrievedStates.get(0);
+		}
+	}
+
+	@Override
 	public void run(SourceContext<T> ctx) throws Exception {
 		ByteArrayInputStream bais = new ByteArrayInputStream(elementsSerialized);
 		final DataInputView input = new DataInputViewStreamWrapper(bais);
@@ -157,17 +195,16 @@ public class FromElementsFunction<T> implements SourceFunction<T>, CheckpointedA
 	// ------------------------------------------------------------------------
 	//  Checkpointing
 	// ------------------------------------------------------------------------
-	
-	@Override
-	public Integer snapshotState(long checkpointId, long checkpointTimestamp) {
-		return this.numElementsEmitted;
-	}
 
 	@Override
-	public void restoreState(Integer state) {
-		this.numElementsToSkip = state;
+	public void snapshotState(FunctionSnapshotContext context) throws Exception {
+		Preconditions.checkState(this.checkpointedState != null,
+			"The " + getClass().getSimpleName() + " has not been properly initialized.");
+
+		this.checkpointedState.clear();
+		this.checkpointedState.add(this.numElementsEmitted);
 	}
-	
+
 	// ------------------------------------------------------------------------
 	//  Utilities
 	// ------------------------------------------------------------------------
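
As the Javadoc above notes, FromElementsFunction serializes its elements eagerly with Flink's
TypeSerializer, so only a byte[] travels through Java serialization. A hedged standalone sketch
of that round trip, using the same serializer classes imported by the diff (values and class
name are illustrative):

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.util.Arrays;
    import java.util.List;

    import org.apache.flink.api.common.typeutils.base.IntSerializer;
    import org.apache.flink.core.memory.DataInputViewStreamWrapper;
    import org.apache.flink.core.memory.DataOutputViewStreamWrapper;

    public class ElementSerializationSketch {
        public static void main(String[] args) throws Exception {
            List<Integer> elements = Arrays.asList(1, 2, 3);

            // write side: what the constructor does once, up front
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(baos);
            for (Integer element : elements) {
                IntSerializer.INSTANCE.serialize(element, out);
            }
            byte[] elementsSerialized = baos.toByteArray();

            // read side: what run() does, element by element
            DataInputViewStreamWrapper in =
                new DataInputViewStreamWrapper(new ByteArrayInputStream(elementsSerialized));
            for (int i = 0; i < elements.size(); i++) {
                System.out.println(IntSerializer.INSTANCE.deserialize(in));
            }
        }
    }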

http://git-wip-us.apache.org/repos/asf/flink/blob/d24833db/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/FromElementsFunctionTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/FromElementsFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/FromElementsFunctionTest.java
index 41bd381..01540da 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/FromElementsFunctionTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/FromElementsFunctionTest.java
@@ -26,9 +26,11 @@ import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.api.java.typeutils.ValueTypeInfo;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.streaming.api.functions.source.FromElementsFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
 import org.apache.flink.types.Value;
 import org.apache.flink.util.ExceptionUtils;
 
@@ -141,9 +143,12 @@ public class FromElementsFunctionTest {
 				data.add(i);
 			}
 			
-			final FromElementsFunction<Integer> source = new FromElementsFunction<Integer>(IntSerializer.INSTANCE, data);
-			final FromElementsFunction<Integer> sourceCopy = CommonTestUtils.createCopySerializable(source);
-			
+			final FromElementsFunction<Integer> source = new FromElementsFunction<>(IntSerializer.INSTANCE, data);
+			StreamSource<Integer, FromElementsFunction<Integer>> src = new StreamSource<>(source);
+			AbstractStreamOperatorTestHarness<Integer> testHarness =
+				new AbstractStreamOperatorTestHarness<>(src, 1, 1, 0);
+			testHarness.open();
+
 			final SourceFunction.SourceContext<Integer> ctx = new ListSourceContext<Integer>(result, 2L);
 			
 			final Throwable[] error = new Throwable[1];
@@ -166,11 +171,10 @@ public class FromElementsFunctionTest {
 			Thread.sleep(1000);
 			
 			// make a checkpoint
-			int count;
-			List<Integer> checkpointData = new ArrayList<Integer>(NUM_ELEMENTS);
-			
+			List<Integer> checkpointData = new ArrayList<>(NUM_ELEMENTS);
+			OperatorStateHandles handles = null;
 			synchronized (ctx.getCheckpointLock()) {
-				count = source.snapshotState(566, System.currentTimeMillis());
+				handles = testHarness.snapshot(566, System.currentTimeMillis());
 				checkpointData.addAll(result);
 			}
 			
@@ -184,11 +188,18 @@ public class FromElementsFunctionTest {
 				error[0].printStackTrace();
 				fail("Error in asynchronous source runner");
 			}
-			
+
+			final FromElementsFunction<Integer> sourceCopy = new FromElementsFunction<>(IntSerializer.INSTANCE, data);
+			StreamSource<Integer, FromElementsFunction<Integer>> srcCopy = new StreamSource<>(sourceCopy);
+			AbstractStreamOperatorTestHarness<Integer> testHarnessCopy =
+				new AbstractStreamOperatorTestHarness<>(srcCopy, 1, 1, 0);
+			testHarnessCopy.setup();
+			testHarnessCopy.initializeState(handles);
+			testHarnessCopy.open();
+
 			// recovery run
-			SourceFunction.SourceContext<Integer> newCtx = new ListSourceContext<Integer>(checkpointData);
-			sourceCopy.restoreState(count);
-			
+			SourceFunction.SourceContext<Integer> newCtx = new ListSourceContext<>(checkpointData);
+
 			sourceCopy.run(newCtx);
 			
 			assertEquals(data, checkpointData);
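
Condensed, the checkpoint/restore cycle that the reworked test drives through the harness looks
as follows. This is a sketch under the assumption that the test-harness APIs behave as used
above; the class name is illustrative:

    import java.util.Arrays;
    import java.util.List;

    import org.apache.flink.api.common.typeutils.base.IntSerializer;
    import org.apache.flink.streaming.api.functions.source.FromElementsFunction;
    import org.apache.flink.streaming.api.operators.StreamSource;
    import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
    import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;

    public class HarnessRoundTripSketch {
        public static void main(String[] args) throws Exception {
            List<Integer> data = Arrays.asList(1, 2, 3);

            // snapshot a source through the harness
            FromElementsFunction<Integer> source =
                new FromElementsFunction<>(IntSerializer.INSTANCE, data);
            AbstractStreamOperatorTestHarness<Integer> harness =
                new AbstractStreamOperatorTestHarness<>(new StreamSource<>(source), 1, 1, 0);
            harness.open();
            OperatorStateHandles handles = harness.snapshot(0L, 0L);

            // feed the snapshot to a fresh instance; its initializeState()
            // then runs with context.isRestored() == true
            FromElementsFunction<Integer> copy =
                new FromElementsFunction<>(IntSerializer.INSTANCE, data);
            AbstractStreamOperatorTestHarness<Integer> harnessCopy =
                new AbstractStreamOperatorTestHarness<>(new StreamSource<>(copy), 1, 1, 0);
            harnessCopy.setup();
            harnessCopy.initializeState(handles);
            harnessCopy.open();
        }
    }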


[4/4] flink git commit: [FLINK-5163] Port the ContinuousFileMonitoringFunction to the new state abstractions.

Posted by al...@apache.org.
[FLINK-5163] Port the ContinuousFileMonitoringFunction to the new state abstractions.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/685c4f83
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/685c4f83
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/685c4f83

Branch: refs/heads/master
Commit: 685c4f836bdb79181fd1f62642736606eb81d847
Parents: 3698379
Author: kl0u <kk...@gmail.com>
Authored: Thu Nov 17 14:54:08 2016 +0100
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Tue Dec 13 13:38:18 2016 +0100

----------------------------------------------------------------------
 .../ContinuousFileProcessingITCase.java         |  2 +-
 .../hdfstests/ContinuousFileProcessingTest.java | 95 ++++++++++++++++++--
 .../environment/StreamExecutionEnvironment.java |  4 +-
 .../ContinuousFileMonitoringFunction.java       | 79 +++++++++++++---
 4 files changed, 157 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/685c4f83/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingITCase.java
----------------------------------------------------------------------
diff --git a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingITCase.java b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingITCase.java
index 3211a20..df68a76 100644
--- a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingITCase.java
+++ b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingITCase.java
@@ -124,7 +124,7 @@ public class ContinuousFileProcessingITCase extends StreamingProgramTestBase {
 		env.setParallelism(PARALLELISM);
 
 		ContinuousFileMonitoringFunction<String> monitoringFunction =
-			new ContinuousFileMonitoringFunction<>(format, hdfsURI,
+			new ContinuousFileMonitoringFunction<>(format,
 				FileProcessingMode.PROCESS_CONTINUOUSLY,
 				env.getParallelism(), INTERVAL);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/685c4f83/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingTest.java
----------------------------------------------------------------------
diff --git a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingTest.java b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingTest.java
index 6454c11..0cb1bad 100644
--- a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingTest.java
+++ b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/ContinuousFileProcessingTest.java
@@ -35,9 +35,11 @@ import org.apache.flink.streaming.api.functions.source.ContinuousFileReaderOpera
 import org.apache.flink.streaming.api.functions.source.TimestampedFileInputSplit;
 import org.apache.flink.streaming.api.functions.source.FileProcessingMode;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.util.Preconditions;
 import org.apache.hadoop.fs.FSDataOutputStream;
@@ -117,10 +119,10 @@ public class ContinuousFileProcessingTest {
 	public void testInvalidPathSpecification() throws Exception {
 
 		String invalidPath = "hdfs://" + hdfsCluster.getURI().getHost() + ":" + hdfsCluster.getNameNodePort() +"/invalid/";
-		TextInputFormat format = new TextInputFormat(new Path(hdfsURI));
+		TextInputFormat format = new TextInputFormat(new Path(invalidPath));
 
 		ContinuousFileMonitoringFunction<String> monitoringFunction =
-			new ContinuousFileMonitoringFunction<>(format, invalidPath,
+			new ContinuousFileMonitoringFunction<>(format,
 				FileProcessingMode.PROCESS_ONCE, 1, INTERVAL);
 		try {
 			monitoringFunction.run(new DummySourceContext() {
@@ -135,7 +137,7 @@ public class ContinuousFileProcessingTest {
 			Assert.fail("Test passed with an invalid path.");
 
 		} catch (FileNotFoundException e) {
-			Assert.assertEquals("The provided file path " + invalidPath + " does not exist.", e.getMessage());
+			Assert.assertEquals("The provided file path " + format.getFilePath().toString() + " does not exist.", e.getMessage());
 		}
 	}
 
@@ -491,6 +493,8 @@ public class ContinuousFileProcessingTest {
 
 	private static class BlockingFileInputFormat extends FileInputFormat<FileInputSplit> {
 
+		private static final long serialVersionUID = -6727603565381560267L;
+
 		private final OneShotLatch latch;
 
 		private FileInputSplit split;
@@ -556,6 +560,9 @@ public class ContinuousFileProcessingTest {
 
 		TextInputFormat format = new TextInputFormat(new Path(hdfsURI));
 		format.setFilesFilter(new FilePathFilter() {
+
+			private static final long serialVersionUID = 2611449927338589804L;
+
 			@Override
 			public boolean filterPath(Path filePath) {
 				return filePath.getName().startsWith("**");
@@ -563,7 +570,7 @@ public class ContinuousFileProcessingTest {
 		});
 
 		ContinuousFileMonitoringFunction<String> monitoringFunction =
-			new ContinuousFileMonitoringFunction<>(format, hdfsURI,
+			new ContinuousFileMonitoringFunction<>(format,
 				FileProcessingMode.PROCESS_ONCE, 1, INTERVAL);
 
 		final FileVerifyingSourceContext context =
@@ -601,7 +608,7 @@ public class ContinuousFileProcessingTest {
 		FileInputSplit[] splits = format.createInputSplits(1);
 
 		ContinuousFileMonitoringFunction<String> monitoringFunction =
-			new ContinuousFileMonitoringFunction<>(format, hdfsURI,
+			new ContinuousFileMonitoringFunction<>(format,
 				FileProcessingMode.PROCESS_ONCE, 1, INTERVAL);
 
 		ModTimeVerifyingSourceContext context = new ModTimeVerifyingSourceContext(modTimes);
@@ -633,7 +640,7 @@ public class ContinuousFileProcessingTest {
 		format.setFilesFilter(FilePathFilter.createDefaultFilter());
 
 		final ContinuousFileMonitoringFunction<String> monitoringFunction =
-			new ContinuousFileMonitoringFunction<>(format, hdfsURI,
+			new ContinuousFileMonitoringFunction<>(format,
 				FileProcessingMode.PROCESS_ONCE, 1, INTERVAL);
 
 		final FileVerifyingSourceContext context = new FileVerifyingSourceContext(latch, monitoringFunction);
@@ -683,6 +690,80 @@ public class ContinuousFileProcessingTest {
 	}
 
 	@Test
+	public void testFunctionRestore() throws Exception {
+
+		org.apache.hadoop.fs.Path path = null;
+		long fileModTime = Long.MIN_VALUE;
+		for (int i = 0; i < 1; i++) {
+			Tuple2<org.apache.hadoop.fs.Path, String> file = createFileAndFillWithData(hdfsURI, "file", i, "This is test line.");
+			path = file.f0;
+			fileModTime = hdfs.getFileStatus(file.f0).getModificationTime();
+		}
+
+		TextInputFormat format = new TextInputFormat(new Path(hdfsURI));
+
+		final ContinuousFileMonitoringFunction<String> monitoringFunction =
+			new ContinuousFileMonitoringFunction<>(format, FileProcessingMode.PROCESS_CONTINUOUSLY, 1, INTERVAL);
+
+		StreamSource<TimestampedFileInputSplit, ContinuousFileMonitoringFunction<String>> src =
+			new StreamSource<>(monitoringFunction);
+
+		final AbstractStreamOperatorTestHarness<TimestampedFileInputSplit> testHarness =
+			new AbstractStreamOperatorTestHarness<>(src, 1, 1, 0);
+		testHarness.open();
+
+		final Throwable[] error = new Throwable[1];
+
+		final OneShotLatch latch = new OneShotLatch();
+
+		// run the source asynchronously
+		Thread runner = new Thread() {
+			@Override
+			public void run() {
+				try {
+					monitoringFunction.run(new DummySourceContext() {
+						@Override
+						public void collect(TimestampedFileInputSplit element) {
+							latch.trigger();
+						}
+					});
+				}
+				catch (Throwable t) {
+					t.printStackTrace();
+					error[0] = t;
+				}
+			}
+		};
+		runner.start();
+
+		if (!latch.isTriggered()) {
+			latch.await();
+		}
+
+		OperatorStateHandles snapshot = testHarness.snapshot(0, 0);
+		monitoringFunction.cancel();
+		runner.join();
+
+		testHarness.close();
+
+		final ContinuousFileMonitoringFunction<String> monitoringFunctionCopy =
+			new ContinuousFileMonitoringFunction<>(format, FileProcessingMode.PROCESS_CONTINUOUSLY, 1, INTERVAL);
+
+		StreamSource<TimestampedFileInputSplit, ContinuousFileMonitoringFunction<String>> srcCopy =
+			new StreamSource<>(monitoringFunctionCopy);
+
+		AbstractStreamOperatorTestHarness<TimestampedFileInputSplit> testHarnessCopy =
+			new AbstractStreamOperatorTestHarness<>(srcCopy, 1, 1, 0);
+		testHarnessCopy.initializeState(snapshot);
+		testHarnessCopy.open();
+
+		Assert.assertNull(error[0]);
+		Assert.assertEquals(fileModTime, monitoringFunctionCopy.getGlobalModificationTime());
+
+		hdfs.delete(path, false);
+	}
+
+	@Test
 	public void testProcessContinuously() throws Exception {
 		final OneShotLatch latch = new OneShotLatch();
 
@@ -698,7 +779,7 @@ public class ContinuousFileProcessingTest {
 		format.setFilesFilter(FilePathFilter.createDefaultFilter());
 
 		final ContinuousFileMonitoringFunction<String> monitoringFunction =
-			new ContinuousFileMonitoringFunction<>(format, hdfsURI,
+			new ContinuousFileMonitoringFunction<>(format,
 				FileProcessingMode.PROCESS_CONTINUOUSLY, 1, INTERVAL);
 
 		final int totalNoOfFilesToBeRead = NO_OF_FILES + 1; // 1 for the bootstrap + NO_OF_FILES

http://git-wip-us.apache.org/repos/asf/flink/blob/685c4f83/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java
index 08e17a1..99784e9 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/StreamExecutionEnvironment.java
@@ -1351,9 +1351,7 @@ public abstract class StreamExecutionEnvironment {
 					ContinuousFileMonitoringFunction.MIN_MONITORING_INTERVAL + " ms.");
 
 		ContinuousFileMonitoringFunction<OUT> monitoringFunction =
-			new ContinuousFileMonitoringFunction<>(
-				inputFormat, inputFormat.getFilePath().toString(),
-				monitoringMode, getParallelism(), interval);
+			new ContinuousFileMonitoringFunction<>(inputFormat, monitoringMode, getParallelism(), interval);
 
 		ContinuousFileReaderOperator<OUT> reader =
 			new ContinuousFileReaderOperator<>(inputFormat);
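
With this change the monitored path is derived from the input format itself, so callers no
longer pass it twice. A hedged usage sketch of the new constructor (the path and interval are
made-up values; PROCESS_CONTINUOUSLY requires an interval of at least
ContinuousFileMonitoringFunction.MIN_MONITORING_INTERVAL):

    import org.apache.flink.api.java.io.TextInputFormat;
    import org.apache.flink.core.fs.Path;
    import org.apache.flink.streaming.api.functions.source.ContinuousFileMonitoringFunction;
    import org.apache.flink.streaming.api.functions.source.FileProcessingMode;

    public class MonitoringFunctionSketch {
        public static void main(String[] args) {
            // the monitored path now lives only in the input format
            TextInputFormat format = new TextInputFormat(new Path("hdfs:///tmp/monitored-dir"));
            ContinuousFileMonitoringFunction<String> monitoringFunction =
                new ContinuousFileMonitoringFunction<>(
                    format, FileProcessingMode.PROCESS_CONTINUOUSLY, 1, 1000L);
        }
    }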

http://git-wip-us.apache.org/repos/asf/flink/blob/685c4f83/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileMonitoringFunction.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileMonitoringFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileMonitoringFunction.java
index 54ab0ab..8723853 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileMonitoringFunction.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/ContinuousFileMonitoringFunction.java
@@ -17,14 +17,20 @@
 package org.apache.flink.streaming.api.functions.source;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.io.FileInputFormat;
 import org.apache.flink.api.common.io.FilePathFilter;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.FileInputSplit;
 import org.apache.flink.core.fs.FileStatus;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -58,7 +64,7 @@ import java.util.TreeMap;
  */
 @Internal
 public class ContinuousFileMonitoringFunction<OUT>
-	extends RichSourceFunction<TimestampedFileInputSplit> implements Checkpointed<Long> {
+	extends RichSourceFunction<TimestampedFileInputSplit> implements CheckpointedFunction {
 
 	private static final long serialVersionUID = 1L;
 
@@ -92,10 +98,13 @@ public class ContinuousFileMonitoringFunction<OUT>
 
 	private volatile boolean isRunning = true;
 
+	private transient ListState<Long> checkpointedState;
+
 	public ContinuousFileMonitoringFunction(
-		FileInputFormat<OUT> format, String path,
+		FileInputFormat<OUT> format,
 		FileProcessingMode watchType,
-		int readerParallelism, long interval) {
+		int readerParallelism,
+		long interval) {
 
 		Preconditions.checkArgument(
 			watchType == FileProcessingMode.PROCESS_ONCE || interval >= MIN_MONITORING_INTERVAL,
@@ -104,7 +113,7 @@ public class ContinuousFileMonitoringFunction<OUT>
 		);
 
 		this.format = Preconditions.checkNotNull(format, "Unspecified File Input Format.");
-		this.path = Preconditions.checkNotNull(path, "Unspecified Path.");
+		this.path = Preconditions.checkNotNull(format.getFilePath().toString(), "Unspecified Path.");
 
 		this.interval = interval;
 		this.watchType = watchType;
@@ -112,13 +121,56 @@ public class ContinuousFileMonitoringFunction<OUT>
 		this.globalModificationTime = Long.MIN_VALUE;
 	}
 
+	@VisibleForTesting
+	public long getGlobalModificationTime() {
+		return this.globalModificationTime;
+	}
+
+	@Override
+	public void initializeState(FunctionInitializationContext context) throws Exception {
+
+		Preconditions.checkState(this.checkpointedState == null,
+			"The " + getClass().getSimpleName() + " has already been initialized.");
+
+		this.checkpointedState = context.getOperatorStateStore().getOperatorState(
+			new ListStateDescriptor<>(
+				"file-monitoring-state",
+				LongSerializer.INSTANCE
+			)
+		);
+
+		if (context.isRestored()) {
+			LOG.info("Restoring state for the {}.", getClass().getSimpleName());
+
+			List<Long> retrievedStates = new ArrayList<>();
+			for (Long entry : this.checkpointedState.get()) {
+				retrievedStates.add(entry);
+			}
+
+			// given that the parallelism of the function is 1, we can only have 1 state
+			Preconditions.checkArgument(retrievedStates.size() == 1,
+				getClass().getSimpleName() + " retrieved invalid state.");
+
+			this.globalModificationTime = retrievedStates.get(0);
+
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("{} retrieved a global mod time of {}.",
+					getClass().getSimpleName(), globalModificationTime);
+			}
+
+		} else {
+			LOG.info("No state to restore for the {}.", getClass().getSimpleName());
+		}
+	}
+
 	@Override
 	public void open(Configuration parameters) throws Exception {
 		super.open(parameters);
 		format.configure(parameters);
 
 		if (LOG.isDebugEnabled()) {
-			LOG.debug("Opened File Monitoring Source for path: " + path + ".");
+			LOG.debug("Opened {} (taskIdx= {}) for path: {}",
+				getClass().getSimpleName(), getRuntimeContext().getIndexOfThisSubtask(), path);
 		}
 	}
 
@@ -294,12 +346,15 @@ public class ContinuousFileMonitoringFunction<OUT>
 	//	---------------------			Checkpointing			--------------------------
 
 	@Override
-	public Long snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
-		return this.globalModificationTime;
-	}
+	public void snapshotState(FunctionSnapshotContext context) throws Exception {
+		Preconditions.checkState(this.checkpointedState != null,
+			"The " + getClass().getSimpleName() + " state has not been properly initialized.");
 
-	@Override
-	public void restoreState(Long state) throws Exception {
-		this.globalModificationTime = state;
+		this.checkpointedState.clear();
+		this.checkpointedState.add(this.globalModificationTime);
+
+		if (LOG.isDebugEnabled()) {
+			LOG.debug("{} checkpointed {}.", getClass().getSimpleName(), globalModificationTime);
+		}
 	}
 }


[3/4] flink git commit: [FLINK-5163] Port the MessageAcknowledgingSourceBase to the new state abstractions.

Posted by al...@apache.org.
[FLINK-5163] Port the MessageAcknowledgingSourceBase to the new state abstractions.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/956ffa69
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/956ffa69
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/956ffa69

Branch: refs/heads/master
Commit: 956ffa69a96ee74990c032c647a887feef4b2508
Parents: d24833d
Author: kl0u <kk...@gmail.com>
Authored: Fri Nov 18 16:07:45 2016 +0100
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Tue Dec 13 13:38:18 2016 +0100

----------------------------------------------------------------------
 .../flink-connector-rabbitmq/pom.xml            | 16 ++++
 .../connectors/rabbitmq/RMQSourceTest.java      | 54 +++++++++++--
 .../source/MessageAcknowledgingSourceBase.java  | 82 +++++++++++++-------
 ...ltipleIdsMessageAcknowledgingSourceBase.java |  8 +-
 4 files changed, 120 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/956ffa69/flink-connectors/flink-connector-rabbitmq/pom.xml
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-rabbitmq/pom.xml b/flink-connectors/flink-connector-rabbitmq/pom.xml
index 0b69d66..2710e53 100644
--- a/flink-connectors/flink-connector-rabbitmq/pom.xml
+++ b/flink-connectors/flink-connector-rabbitmq/pom.xml
@@ -55,6 +55,22 @@ under the License.
 			<version>${rabbitmq.version}</version>
 		</dependency>
 
+		<dependency>
+			<groupId>org.apache.flink</groupId>
+			<artifactId>flink-streaming-java_2.10</artifactId>
+			<version>${project.version}</version>
+			<scope>test</scope>
+			<type>test-jar</type>
+		</dependency>
+
+		<dependency>
+			<groupId>org.apache.flink</groupId>
+			<artifactId>flink-runtime_2.10</artifactId>
+			<version>${project.version}</version>
+			<type>test-jar</type>
+			<scope>test</scope>
+		</dependency>
+
 	</dependencies>
 
 </project>

http://git-wip-us.apache.org/repos/asf/flink/blob/956ffa69/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
index b63c835..8474f8a 100644
--- a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
+++ b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
@@ -23,16 +23,19 @@ import com.rabbitmq.client.ConnectionFactory;
 import com.rabbitmq.client.Envelope;
 import com.rabbitmq.client.QueueingConsumer;
 import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.state.SerializedCheckpointData;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.connectors.rabbitmq.common.RMQConnectionConfig;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.serialization.DeserializationSchema;
 import org.junit.After;
 import org.junit.Before;
@@ -53,6 +56,7 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
 
 
 /**
@@ -83,7 +87,13 @@ public class RMQSourceTest {
 	@Before
 	public void beforeTest() throws Exception {
 
+		OperatorStateStore mockStore = Mockito.mock(OperatorStateStore.class);
+		FunctionInitializationContext mockContext = Mockito.mock(FunctionInitializationContext.class);
+		Mockito.when(mockContext.getOperatorStateStore()).thenReturn(mockStore);
+		Mockito.when(mockStore.getSerializableListState(any(String.class))).thenReturn(null);
+
 		source = new RMQTestSource();
+		source.initializeState(mockContext);
 		source.open(config);
 
 		messageId = 0;
@@ -128,6 +138,12 @@ public class RMQSourceTest {
 	@Test
 	public void testCheckpointing() throws Exception {
 		source.autoAck = false;
+
+		StreamSource<String, RMQSource<String>> src = new StreamSource<>(source);
+		AbstractStreamOperatorTestHarness<String> testHarness =
+			new AbstractStreamOperatorTestHarness<>(src, 1, 1, 0);
+		testHarness.open();
+
 		sourceThread.start();
 
 		Thread.sleep(5);
@@ -141,10 +157,10 @@ public class RMQSourceTest {
 
 		for (int i=0; i < numSnapshots; i++) {
 			long snapshotId = random.nextLong();
-			SerializedCheckpointData[] data;
+			OperatorStateHandles data;
 
 			synchronized (DummySourceContext.lock) {
-				data = source.snapshotState(snapshotId, System.currentTimeMillis());
+				data = testHarness.snapshot(snapshotId, System.currentTimeMillis());
 				previousSnapshotId = lastSnapshotId;
 				lastSnapshotId = messageId;
 			}
@@ -153,15 +169,25 @@ public class RMQSourceTest {
 
 			// check if the correct number of messages have been snapshotted
 			final long numIds = lastSnapshotId - previousSnapshotId;
-			assertEquals(numIds, data[0].getNumIds());
-			// deserialize and check if the last id equals the last snapshotted id
-			ArrayDeque<Tuple2<Long, List<String>>> deque = SerializedCheckpointData.toDeque(data, new StringSerializer());
+
+			RMQTestSource sourceCopy = new RMQTestSource();
+			StreamSource<String, RMQTestSource> srcCopy = new StreamSource<>(sourceCopy);
+			AbstractStreamOperatorTestHarness<String> testHarnessCopy =
+				new AbstractStreamOperatorTestHarness<>(srcCopy, 1, 1, 0);
+
+			testHarnessCopy.setup();
+			testHarnessCopy.initializeState(data);
+			testHarnessCopy.open();
+
+			ArrayDeque<Tuple2<Long, List<String>>> deque = sourceCopy.getRestoredState();
 			List<String> messageIds = deque.getLast().f1;
+
+			assertEquals(numIds, messageIds.size());
 			if (messageIds.size() > 0) {
 				assertEquals(lastSnapshotId, (long) Long.valueOf(messageIds.get(messageIds.size() - 1)));
 			}
 
-			// check if the messages are being acknowledged and the transaction comitted
+			// check if the messages are being acknowledged and the transaction committed
 			synchronized (DummySourceContext.lock) {
 				source.notifyCheckpointComplete(snapshotId);
 			}
@@ -313,6 +339,8 @@ public class RMQSourceTest {
 
 	private class RMQTestSource extends RMQSource<String> {
 
+		private ArrayDeque<Tuple2<Long, List<String>>> restoredState;
+
 		public RMQTestSource() {
 			super(new RMQConnectionConfig.Builder().setHost("hostTest")
 					.setPort(999).setUserName("userTest").setPassword("passTest").setVirtualHost("/").build()
@@ -320,6 +348,16 @@ public class RMQSourceTest {
 		}
 
 		@Override
+		public void initializeState(FunctionInitializationContext context) throws Exception {
+			super.initializeState(context);
+			this.restoredState = this.pendingCheckpoints;
+		}
+
+		public ArrayDeque<Tuple2<Long, List<String>>> getRestoredState() {
+			return this.restoredState;
+		}
+
+		@Override
 		public void open(Configuration config) throws Exception {
 			super.open(config);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/956ffa69/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java
index 5c1b94e..035b7dd 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MessageAcknowledgingSourceBase.java
@@ -27,14 +27,17 @@ import java.util.Set;
 
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
-import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.state.CheckpointListener;
-import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.state.SerializedCheckpointData;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -73,14 +76,16 @@ import org.slf4j.LoggerFactory;
  *     }
  * }
  * }</pre>
- * 
+ *
+ * <b>NOTE:</b> This source has a parallelism of {@code 1}.
+ *
  * @param <Type> The type of the messages created by the source.
  * @param <UId> The type of unique IDs which may be used to acknowledge elements.
  */
 @PublicEvolving
 public abstract class MessageAcknowledgingSourceBase<Type, UId>
 	extends RichSourceFunction<Type>
-	implements Checkpointed<SerializedCheckpointData[]>, CheckpointListener {
+	implements CheckpointedFunction, CheckpointListener {
 
 	private static final long serialVersionUID = -8689291992192955579L;
 
@@ -93,7 +98,7 @@ public abstract class MessageAcknowledgingSourceBase<Type, UId>
 	private transient List<UId> idsForCurrentCheckpoint;
 
 	/** The list with IDs from checkpoints that were triggered, but not yet completed or notified of completion */
-	private transient ArrayDeque<Tuple2<Long, List<UId>>> pendingCheckpoints;
+	protected transient ArrayDeque<Tuple2<Long, List<UId>>> pendingCheckpoints;
 
 	/**
 	 * Set which contain all processed ids. Ids are acknowledged after checkpoints. When restoring
@@ -102,6 +107,8 @@ public abstract class MessageAcknowledgingSourceBase<Type, UId>
 	 */
 	private transient Set<UId> idsProcessedButNotAcknowledged;
 
+	private transient ListState<SerializedCheckpointData[]> checkpointedState;
+
 	// ------------------------------------------------------------------------
 
 	/**
@@ -123,13 +130,38 @@ public abstract class MessageAcknowledgingSourceBase<Type, UId>
 	}
 
 	@Override
-	public void open(Configuration parameters) throws Exception {
-		idsForCurrentCheckpoint = new ArrayList<>(64);
-		if (pendingCheckpoints == null) {
-			pendingCheckpoints = new ArrayDeque<>();
-		}
-		if (idsProcessedButNotAcknowledged == null) {
-			idsProcessedButNotAcknowledged = new HashSet<>();
+	public void initializeState(FunctionInitializationContext context) throws Exception {
+		Preconditions.checkState(this.checkpointedState == null,
+			"The " + getClass().getSimpleName() + " has already been initialized.");
+
+		this.checkpointedState = context
+			.getOperatorStateStore()
+			.getSerializableListState("message-acknowledging-source-state");
+
+		this.idsForCurrentCheckpoint = new ArrayList<>(64);
+		this.pendingCheckpoints = new ArrayDeque<>();
+		this.idsProcessedButNotAcknowledged = new HashSet<>();
+
+		if (context.isRestored()) {
+			LOG.info("Restoring state for the {}.", getClass().getSimpleName());
+
+			List<SerializedCheckpointData[]> retrievedStates = new ArrayList<>();
+			for (SerializedCheckpointData[] entry : this.checkpointedState.get()) {
+				retrievedStates.add(entry);
+			}
+
+			// given that the parallelism of the function is 1, we can only have at most 1 state
+			Preconditions.checkArgument(retrievedStates.size() == 1,
+				getClass().getSimpleName() + " retrieved invalid state.");
+
+			pendingCheckpoints = SerializedCheckpointData.toDeque(retrievedStates.get(0), idSerializer);
+			// build a set which contains all processed ids. It may be used to check if we have
+			// already processed an incoming message.
+			for (Tuple2<Long, List<UId>> checkpoint : pendingCheckpoints) {
+				idsProcessedButNotAcknowledged.addAll(checkpoint.f1);
+			}
+		} else {
+			LOG.info("No state to restore for the {}.", getClass().getSimpleName());
 		}
 	}
 
@@ -166,26 +198,20 @@ public abstract class MessageAcknowledgingSourceBase<Type, UId>
 	// ------------------------------------------------------------------------
 
 	@Override
-	public SerializedCheckpointData[] snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
-		LOG.debug("Snapshotting state. Messages: {}, checkpoint id: {}, timestamp: {}",
-					idsForCurrentCheckpoint, checkpointId, checkpointTimestamp);
+	public void snapshotState(FunctionSnapshotContext context) throws Exception {
+		Preconditions.checkState(this.checkpointedState != null,
+			"The " + getClass().getSimpleName() + " has not been properly initialized.");
 
-		pendingCheckpoints.addLast(new Tuple2<>(checkpointId, idsForCurrentCheckpoint));
+		if (LOG.isDebugEnabled()) {
+			LOG.debug("{} checkpointing: Messages: {}, checkpoint id: {}, timestamp: {}",
+				getClass().getSimpleName(), idsForCurrentCheckpoint, context.getCheckpointId(), context.getCheckpointTimestamp());
+		}
 
+		pendingCheckpoints.addLast(new Tuple2<>(context.getCheckpointId(), idsForCurrentCheckpoint));
 		idsForCurrentCheckpoint = new ArrayList<>(64);
 
-		return SerializedCheckpointData.fromDeque(pendingCheckpoints, idSerializer);
-	}
-
-	@Override
-	public void restoreState(SerializedCheckpointData[] state) throws Exception {
-		idsProcessedButNotAcknowledged = new HashSet<>();
-		pendingCheckpoints = SerializedCheckpointData.toDeque(state, idSerializer);
-		// build a set which contains all processed ids. It may be used to check if we have
-		// already processed an incoming message.
-		for (Tuple2<Long, List<UId>> checkpoint : pendingCheckpoints) {
-			idsProcessedButNotAcknowledged.addAll(checkpoint.f1);
-		}
+		this.checkpointedState.clear();
+		this.checkpointedState.add(SerializedCheckpointData.fromDeque(pendingCheckpoints, idSerializer));
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/956ffa69/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MultipleIdsMessageAcknowledgingSourceBase.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MultipleIdsMessageAcknowledgingSourceBase.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MultipleIdsMessageAcknowledgingSourceBase.java
index 2237c1c..188fdce 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MultipleIdsMessageAcknowledgingSourceBase.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/source/MultipleIdsMessageAcknowledgingSourceBase.java
@@ -22,7 +22,7 @@ import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.state.SerializedCheckpointData;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -133,9 +133,9 @@ public abstract class MultipleIdsMessageAcknowledgingSourceBase<Type, UId, Sessi
 
 
 	@Override
-	public SerializedCheckpointData[] snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
-		sessionIdsPerSnapshot.add(new Tuple2<>(checkpointId, sessionIds));
+	public void snapshotState(FunctionSnapshotContext context) throws Exception {
+		sessionIdsPerSnapshot.add(new Tuple2<>(context.getCheckpointId(), sessionIds));
 		sessionIds = new ArrayList<>(64);
-		return super.snapshotState(checkpointId, checkpointTimestamp);
+		super.snapshotState(context);
 	}
 }
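
The acknowledge-after-checkpoint flow that these two base classes implement can be summarized
with a small dependency-free model: ids collected since the last checkpoint are parked in a
pending deque at snapshot time and only acknowledged once the checkpoint completes. This toy
model only mirrors the described flow, assuming that completion acknowledges every pending
checkpoint up to the completed id; it is not the actual base class.

    import java.util.AbstractMap;
    import java.util.ArrayDeque;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map;

    public class AckFlowSketch {

        private final ArrayDeque<Map.Entry<Long, List<String>>> pendingCheckpoints = new ArrayDeque<>();
        private List<String> idsForCurrentCheckpoint = new ArrayList<>();

        void collect(String id) {
            idsForCurrentCheckpoint.add(id);
        }

        void snapshotState(long checkpointId) {
            // park the ids gathered since the last snapshot under this checkpoint id
            pendingCheckpoints.addLast(new AbstractMap.SimpleEntry<>(checkpointId, idsForCurrentCheckpoint));
            idsForCurrentCheckpoint = new ArrayList<>();
        }

        void notifyCheckpointComplete(long checkpointId) {
            // acknowledge everything up to and including the completed checkpoint
            while (!pendingCheckpoints.isEmpty() && pendingCheckpoints.peekFirst().getKey() <= checkpointId) {
                System.out.println("acknowledging " + pendingCheckpoints.pollFirst().getValue());
            }
        }

        public static void main(String[] args) {
            AckFlowSketch source = new AckFlowSketch();
            source.collect("m1");
            source.collect("m2");
            source.snapshotState(1);
            source.collect("m3");                 // belongs to the next checkpoint
            source.notifyCheckpointComplete(1);   // acknowledges m1, m2; m3 stays pending
        }
    }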