You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tz...@apache.org on 2017/05/07 09:35:28 UTC

[1/2] flink git commit: [FLINK-4821] Implement rescalable non-partitioned state for Kinesis Connector

Repository: flink
Updated Branches:
  refs/heads/master fde4f9097 -> e5b65a7fc


[FLINK-4821] Implement rescalable non-partitioned state for Kinesis Connector


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/a05b574c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/a05b574c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/a05b574c

Branch: refs/heads/master
Commit: a05b574cc68d3f652d11fece46c23bbc24f35430
Parents: fde4f90
Author: Tony Wei <to...@gmail.com>
Authored: Wed Dec 14 10:18:25 2016 +0800
Committer: Tzu-Li (Gordon) Tai <tz...@apache.org>
Committed: Sun May 7 16:28:52 2017 +0800

----------------------------------------------------------------------
 .../flink-connector-kinesis/pom.xml             |  16 +
 .../kinesis/FlinkKinesisConsumer.java           | 150 +++++--
 .../kinesis/internals/KinesisDataFetcher.java   |   2 +-
 .../FlinkKinesisConsumerMigrationTest.java      | 149 +++++++
 .../kinesis/FlinkKinesisConsumerTest.java       | 440 +++++++++++++++++--
 ...is-consumer-migration-test-flink1.1-snapshot | Bin 0 -> 1140 bytes
 ...sumer-migration-test-flink1.1-snapshot-empty | Bin 0 -> 468 bytes
 7 files changed, 690 insertions(+), 67 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/pom.xml
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/pom.xml b/flink-connectors/flink-connector-kinesis/pom.xml
index d457199..080626f 100644
--- a/flink-connectors/flink-connector-kinesis/pom.xml
+++ b/flink-connectors/flink-connector-kinesis/pom.xml
@@ -57,6 +57,14 @@ under the License.
 
 		<dependency>
 			<groupId>org.apache.flink</groupId>
+			<artifactId>flink-runtime_2.10</artifactId>
+			<version>${project.version}</version>
+			<type>test-jar</type>
+			<scope>test</scope>
+		</dependency>
+
+		<dependency>
+			<groupId>org.apache.flink</groupId>
 			<artifactId>flink-tests_2.10</artifactId>
 			<version>${project.version}</version>
 			<scope>test</scope>
@@ -65,6 +73,14 @@ under the License.
 
 		<dependency>
 			<groupId>org.apache.flink</groupId>
+			<artifactId>flink-streaming-java_2.10</artifactId>
+			<version>${project.version}</version>
+			<type>test-jar</type>
+			<scope>test</scope>
+		</dependency>
+
+		<dependency>
+			<groupId>org.apache.flink</groupId>
 			<artifactId>flink-test-utils_2.10</artifactId>
 			<version>${project.version}</version>
 			<scope>test</scope>

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
index a62dc10..dfcd552 100644
--- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
+++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
@@ -17,15 +17,26 @@
 
 package org.apache.flink.streaming.connectors.kinesis;
 
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
-import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
+import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
+import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
 import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper;
 import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil;
@@ -55,8 +66,10 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  *
  * @param <T> the type of data emitted
  */
-public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
-	implements CheckpointedAsynchronously<HashMap<KinesisStreamShard, SequenceNumber>>, ResultTypeQueryable<T> {
+public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> implements
+	ResultTypeQueryable<T>,
+	CheckpointedFunction,
+	CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {
 
 	private static final long serialVersionUID = 4724006128720664870L;
 
@@ -91,6 +104,14 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
 
 	private volatile boolean running = true;
 
+	// ------------------------------------------------------------------------
+	//  State for Checkpoint
+	// ------------------------------------------------------------------------
+
+	/** The name is the key for sequence numbers state, and cannot be changed. */
+	private static final String sequenceNumsStateStoreName = "Kinesis-Stream-Shard-State";
+
+	private transient ListState<Tuple2<KinesisStreamShard, SequenceNumber>> sequenceNumsStateForCheckpoint;
 
 	// ------------------------------------------------------------------------
 	//  Constructors
@@ -194,8 +215,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
 		// all subtasks will run a fetcher, regardless of whether or not the subtask will initially have
 		// shards to subscribe to; fetchers will continuously poll for changes in the shard list, so all subtasks
 		// can potentially have new shards to subscribe to later on
-		fetcher = new KinesisDataFetcher<>(
-			streams, sourceContext, getRuntimeContext(), configProps, deserializer);
+		fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer);
 
 		boolean isRestoringFromFailure = (sequenceNumsToRestore != null);
 		fetcher.setIsRestoringFromFailure(isRestoringFromFailure);
@@ -203,17 +223,35 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
 		// if we are restoring from a checkpoint, we iterate over the restored
 		// state and accordingly seed the fetcher with subscribed shards states
 		if (isRestoringFromFailure) {
-			for (Map.Entry<KinesisStreamShard, SequenceNumber> restored : lastStateSnapshot.entrySet()) {
-				fetcher.advanceLastDiscoveredShardOfStream(
-					restored.getKey().getStreamName(), restored.getKey().getShard().getShardId());
-
-				if (LOG.isInfoEnabled()) {
-					LOG.info("Subtask {} is seeding the fetcher with restored shard {}," +
-							" starting state set to the restored sequence number {}",
-						getRuntimeContext().getIndexOfThisSubtask(), restored.getKey().toString(), restored.getValue());
+			// Some subtasks may not have finished shard discovery before a rescale, and
+			// KinesisDataFetcher always discovers shards starting from the largest shard id seen so far.
+			// To avoid missing shards that were never discovered and whose ids are not the largest,
+			// we force the consumer to discover once from the smallest id and make sure each shard gets its
+			// initial sequence number from the restored state or SENTINEL_EARLIEST_SEQUENCE_NUM.
+			List<KinesisStreamShard> newShardsCreatedWhileNotRunning = fetcher.discoverNewShardsToSubscribe();
+			for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) {
+				SequenceNumber startingStateForNewShard;
+
+				if (lastStateSnapshot.containsKey(shard)) {
+					startingStateForNewShard = lastStateSnapshot.get(shard);
+
+					if (LOG.isInfoEnabled()) {
+						LOG.info("Subtask {} is seeding the fetcher with restored shard {}," +
+								" starting state set to the restored sequence number {}",
+							getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingStateForNewShard);
+					}
+				} else {
+					startingStateForNewShard = SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get();
+
+					if (LOG.isInfoEnabled()) {
+						LOG.info("Subtask {} is seeding the fetcher with new discovered shard {}," +
+								" starting state set to the SENTINEL_EARLIEST_SEQUENCE_NUM",
+							getRuntimeContext().getIndexOfThisSubtask(), shard.toString());
+					}
 				}
+
 				fetcher.registerNewSubscribedShardState(
-					new KinesisStreamShardState(restored.getKey(), restored.getValue()));
+					new KinesisStreamShardState(shard, startingStateForNewShard));
 			}
 		}
 
@@ -267,38 +305,78 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
 	// ------------------------------------------------------------------------
 
 	@Override
-	public HashMap<KinesisStreamShard, SequenceNumber> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
-		if (lastStateSnapshot == null) {
-			LOG.debug("snapshotState() requested on not yet opened source; returning null.");
-			return null;
-		}
+	public void initializeState(FunctionInitializationContext context) throws Exception {
+		TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> tuple = new TupleTypeInfo<>(
+			TypeInformation.of(KinesisStreamShard.class),
+			TypeInformation.of(SequenceNumber.class)
+		);
+
+		sequenceNumsStateForCheckpoint = context.getOperatorStateStore().getUnionListState(
+			new ListStateDescriptor<>(sequenceNumsStateStoreName, tuple));
+
+		if (context.isRestored()) {
+			if (sequenceNumsToRestore == null) {
+				sequenceNumsToRestore = new HashMap<>();
+				for (Tuple2<KinesisStreamShard, SequenceNumber> kinesisSequenceNumber : sequenceNumsStateForCheckpoint.get()) {
+					sequenceNumsToRestore.put(kinesisSequenceNumber.f0, kinesisSequenceNumber.f1);
+				}
 
-		if (fetcher == null) {
-			LOG.debug("snapshotState() requested on not yet running source; returning null.");
-			return null;
+				LOG.info("Setting restore state in the FlinkKinesisConsumer. Using the following offsets: {}",
+					sequenceNumsToRestore);
+			} else if (sequenceNumsToRestore.isEmpty()) {
+				sequenceNumsToRestore = null;
+			}
+		} else {
+			LOG.info("No restore state for FlinkKinesisConsumer.");
 		}
+	}
 
-		if (!running) {
+	@Override
+	public void snapshotState(FunctionSnapshotContext context) throws Exception {
+		if (lastStateSnapshot == null) {
+			LOG.debug("snapshotState() requested on not yet opened source; returning null.");
+		} else if (fetcher == null) {
+			LOG.debug("snapshotState() requested on not yet running source; returning null.");
+		} else if (!running) {
 			LOG.debug("snapshotState() called on closed source; returning null.");
-			return null;
-		}
+		} else {
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("Snapshotting state ...");
+			}
 
-		if (LOG.isDebugEnabled()) {
-			LOG.debug("Snapshotting state ...");
-		}
+			sequenceNumsStateForCheckpoint.clear();
+			lastStateSnapshot = fetcher.snapshotState();
 
-		lastStateSnapshot = fetcher.snapshotState();
+			if (LOG.isDebugEnabled()) {
+				LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
+					lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp());
+			}
 
-		if (LOG.isDebugEnabled()) {
-			LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
-				lastStateSnapshot.toString(), checkpointId, checkpointTimestamp);
+			for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) {
+				sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+			}
 		}
-
-		return lastStateSnapshot;
 	}
 
 	@Override
 	public void restoreState(HashMap<KinesisStreamShard, SequenceNumber> restoredState) throws Exception {
-		sequenceNumsToRestore = restoredState;
+		LOG.info("Subtask {} restoring offsets from an older Flink version: {}",
+			getRuntimeContext().getIndexOfThisSubtask(), sequenceNumsToRestore);
+
+		sequenceNumsToRestore = restoredState.isEmpty() ? null : restoredState;
+	}
+
+	/** This method is exposed so that tests can override it to inject a mocked KinesisDataFetcher into the consumer. */
+	protected KinesisDataFetcher<T> createFetcher(List<String> streams,
+													SourceFunction.SourceContext<T> sourceContext,
+													RuntimeContext runtimeContext,
+													Properties configProps,
+													KinesisDeserializationSchema<T> deserializationSchema) {
+		return new KinesisDataFetcher<>(streams, sourceContext, runtimeContext, configProps, deserializationSchema);
+	}
+
+	@VisibleForTesting
+	HashMap<KinesisStreamShard, SequenceNumber> getRestoredState() {
+		return sequenceNumsToRestore;
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
index 8f7ca6c..c5b4b04 100644
--- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
+++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
@@ -461,7 +461,7 @@ public class KinesisDataFetcher<T> {
 	 * 3. Update the subscribedStreamsToLastDiscoveredShardIds state so that we won't get shards
 	 *    that we have already seen before the next time this function is called
 	 */
-	private List<KinesisStreamShard> discoverNewShardsToSubscribe() throws InterruptedException {
+	public List<KinesisStreamShard> discoverNewShardsToSubscribe() throws InterruptedException {
 
 		List<KinesisStreamShard> newShardsToSubscribe = new LinkedList<>();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
new file mode 100644
index 0000000..2f46e09
--- /dev/null
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kinesis;
+
+import com.amazonaws.services.kinesis.model.Shard;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
+import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
+import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
+import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
+import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
+import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+import org.junit.Test;
+
+import java.net.URL;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Properties;
+
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for checking whether {@link FlinkKinesisConsumer} can restore from snapshots that were
+ * done using the Flink 1.1 {@link FlinkKinesisConsumer}.
+ *
+ * <p>For regenerating the binary snapshot file you have to run the commented out portion
+ * of each test on a checkout of the Flink 1.1 branch.
+ */
+public class FlinkKinesisConsumerMigrationTest {
+
+	@Test
+	public void testRestoreFromFlink11WithEmptyState() throws Exception {
+		Properties testConfig = new Properties();
+		testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+
+		final DummyFlinkKafkaConsumer<String> consumerFunction = new DummyFlinkKafkaConsumer<>(testConfig);
+
+		StreamSource<String, DummyFlinkKafkaConsumer<String>> consumerOperator = new StreamSource<>(consumerFunction);
+
+		final AbstractStreamOperatorTestHarness<String> testHarness =
+			new AbstractStreamOperatorTestHarness<>(consumerOperator, 1, 1, 0);
+
+		testHarness.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+		testHarness.setup();
+		// restore state from binary snapshot file using legacy method
+		testHarness.initializeStateFromLegacyCheckpoint(
+			getResourceFilename("kinesis-consumer-migration-test-flink1.1-snapshot-empty"));
+		testHarness.open();
+
+		// assert that no state was restored
+		assertEquals(null, consumerFunction.getRestoredState());
+
+		consumerOperator.close();
+		consumerOperator.cancel();
+	}
+
+	@Test
+	public void testRestoreFromFlink11() throws Exception {
+		Properties testConfig = new Properties();
+		testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+
+		final DummyFlinkKafkaConsumer<String> consumerFunction = new DummyFlinkKafkaConsumer<>(testConfig);
+
+		StreamSource<String, DummyFlinkKafkaConsumer<String>> consumerOperator =
+			new StreamSource<>(consumerFunction);
+
+		final AbstractStreamOperatorTestHarness<String> testHarness =
+			new AbstractStreamOperatorTestHarness<>(consumerOperator, 1, 1, 0);
+
+		testHarness.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+		testHarness.setup();
+		// restore state from binary snapshot file using legacy method
+		testHarness.initializeStateFromLegacyCheckpoint(
+			getResourceFilename("kinesis-consumer-migration-test-flink1.1-snapshot"));
+		testHarness.open();
+
+		// the expected state in "kinesis-consumer-migration-test-flink1.1-snapshot"
+		final HashMap<KinesisStreamShard, SequenceNumber> expectedState = new HashMap<>();
+		expectedState.put(new KinesisStreamShard("fakeStream1",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+			new SequenceNumber("987654321"));
+
+		// assert that state is correctly restored from legacy checkpoint
+		assertNotEquals(null, consumerFunction.getRestoredState());
+		assertEquals(1, consumerFunction.getRestoredState().size());
+		assertEquals(expectedState, consumerFunction.getRestoredState());
+
+		consumerOperator.close();
+		consumerOperator.cancel();
+	}
+
+	// ------------------------------------------------------------------------
+
+	private static String getResourceFilename(String filename) {
+		ClassLoader cl = FlinkKinesisConsumerMigrationTest.class.getClassLoader();
+		URL resource = cl.getResource(filename);
+		if (resource == null) {
+			throw new NullPointerException("Missing snapshot resource.");
+		}
+		return resource.getFile();
+	}
+
+	private static class DummyFlinkKafkaConsumer<T> extends FlinkKinesisConsumer<T> {
+		private static final long serialVersionUID = 1L;
+
+		@SuppressWarnings("unchecked")
+		DummyFlinkKafkaConsumer(Properties properties) {
+			super("test", mock(KinesisDeserializationSchema.class), properties);
+		}
+
+		@Override
+		protected KinesisDataFetcher<T> createFetcher(List<String> streams,
+													  	SourceFunction.SourceContext<T> sourceContext,
+													  	RuntimeContext runtimeContext,
+													  	Properties configProps,
+													  	KinesisDeserializationSchema<T> deserializationSchema) {
+			return mock(KinesisDataFetcher.class);
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
index 741f0ca..bf8e44f 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
@@ -18,13 +18,22 @@
 package org.apache.flink.streaming.connectors.kinesis;
 
 import com.amazonaws.services.kinesis.model.Shard;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.connectors.kinesis.config.AWSConfigConstants;
 import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
 import org.apache.flink.streaming.connectors.kinesis.config.ProducerConfigConstants;
 import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
+import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
 import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
@@ -35,19 +44,29 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
+import org.mockito.Matchers;
 import org.mockito.Mockito;
+import org.mockito.internal.util.reflection.Whitebox;
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
 
-import java.text.SimpleDateFormat;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Properties;
 import java.util.UUID;
+import java.io.Serializable;
 
-import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.never;
 
 /**
  * Suite of FlinkKinesisConsumer tests for the methods called throughout the source life cycle.
@@ -511,28 +530,149 @@ public class FlinkKinesisConsumerTest {
 	// ----------------------------------------------------------------------
 
 	@Test
-	public void testSnapshotStateShouldBeNullIfSourceNotOpened() throws Exception {
+	public void testSnapshotStateShouldNotClearListStateIfSourceNotOpened() throws Exception {
 		Properties config = new Properties();
 		config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
 		config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
 		config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
 
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+
+		TestingListState<Serializable> listState = new TestingListState<>();
+
 		FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
 
-		assertTrue(consumer.snapshotState(123, 123) == null); //arbitrary checkpoint id and timestamp
+		when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+
+		when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+		when(initializationContext.isRestored()).thenReturn(false);
+
+		consumer.initializeState(initializationContext);
+
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp
+
+		assertFalse(listState.isClearCalled());
 	}
 
 	@Test
-	public void testSnapshotStateShouldBeNullIfSourceNotRun() throws Exception {
+	public void testSnapshotStateShouldNotClearListStateIfSourceNotRun() throws Exception {
 		Properties config = new Properties();
 		config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
 		config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
 		config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
 
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+
+		TestingListState<Serializable> listState = new TestingListState<>();
+
 		FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
+
+		when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+
+		when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+		when(initializationContext.isRestored()).thenReturn(false);
+
+		consumer.initializeState(initializationContext);
+
 		consumer.open(new Configuration()); // only opened, not run
 
-		assertTrue(consumer.snapshotState(123, 123) == null); //arbitrary checkpoint id and timestamp
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp
+
+		assertFalse(listState.isClearCalled());
+	}
+
+	@Test
+	public void testListStateChangedAfterSnapshotState() throws Exception {
+		// ----------------------------------------------------------------------
+		// setting config, initial state and state after snapshot
+		// ----------------------------------------------------------------------
+		Properties config = new Properties();
+		config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
+		config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
+		config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+
+		ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> initialState = new ArrayList<>(1);
+		initialState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream1",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+			new SequenceNumber("1")));
+
+		ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> snapShotState = new ArrayList<>(3);
+		snapShotState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream1",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+			new SequenceNumber("12")));
+		snapShotState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream1",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+			new SequenceNumber("11")));
+		snapShotState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream1",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+			new SequenceNumber("31")));
+
+		// ----------------------------------------------------------------------
+		// mock operator state backend and initial state for initializeState()
+		// ----------------------------------------------------------------------
+		TestingListState<Serializable> listState = new TestingListState<>();
+		for (Serializable state: initialState) {
+			listState.add(state);
+		}
+
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+		when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+		when(initializationContext.isRestored()).thenReturn(true);
+
+		// ----------------------------------------------------------------------
+		// mock a running fetcher and its state for snapshot
+		// ----------------------------------------------------------------------
+		HashMap<KinesisStreamShard, SequenceNumber> stateSnapshot = new HashMap<>();
+		for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: snapShotState) {
+			stateSnapshot.put(tuple.f0, tuple.f1);
+		}
+
+		KinesisDataFetcher mockedFetcher = mock(KinesisDataFetcher.class);
+		when(mockedFetcher.snapshotState()).thenReturn(stateSnapshot);
+
+		// ----------------------------------------------------------------------
+		// create a consumer and test the snapshotState()
+		// ----------------------------------------------------------------------
+		FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
+		FlinkKinesisConsumer<?> mockedConsumer = spy(consumer);
+
+		RuntimeContext context = mock(RuntimeContext.class);
+		when(context.getIndexOfThisSubtask()).thenReturn(1);
+
+		mockedConsumer.setRuntimeContext(context);
+		mockedConsumer.initializeState(initializationContext);
+		mockedConsumer.open(new Configuration());
+		Whitebox.setInternalState(mockedConsumer, "fetcher", mockedFetcher); // mock as consumer is running.
+
+		mockedConsumer.snapshotState(mock(FunctionSnapshotContext.class));
+
+		assertEquals(true, listState.clearCalled);
+		assertEquals(3, listState.getList().size());
+
+		for (Serializable state: initialState) {
+			for (Serializable currentState: listState.getList()) {
+				assertNotEquals(state, currentState);
+			}
+		}
+
+		for (Serializable state: snapShotState) {
+			boolean hasOneIsSame = false;
+			for (Serializable currentState: listState.getList()) {
+				hasOneIsSame = hasOneIsSame || state.equals(currentState);
+			}
+			assertEquals(true, hasOneIsSame);
+		}
 	}
 
 	// ----------------------------------------------------------------------
@@ -559,48 +699,288 @@ public class FlinkKinesisConsumerTest {
 
 	@Test
 	@SuppressWarnings("unchecked")
+	public void testFetcherShouldBeCorrectlySeededIfRestoringFromLegacyCheckpoint() throws Exception {
+		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
+
+		KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
+		List<KinesisStreamShard> shards = new ArrayList<>();
+		shards.addAll(fakeRestoredState.keySet());
+		when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
+		PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
+
+		// assume the given config is correct
+		PowerMockito.mockStatic(KinesisConfigUtil.class);
+		PowerMockito.doNothing().when(KinesisConfigUtil.class);
+
+		TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
+			"fakeStream", new Properties(), 10, 2);
+		consumer.restoreState(fakeRestoredState);
+		consumer.open(new Configuration());
+		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
+
+		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
+			Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
+				new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+		}
+	}
+
+	@Test
+	@SuppressWarnings("unchecked")
 	public void testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws Exception {
+		// ----------------------------------------------------------------------
+		// setting initial state
+		// ----------------------------------------------------------------------
+		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
+
+		// ----------------------------------------------------------------------
+		// mock operator state backend and initial state for initializeState()
+		// ----------------------------------------------------------------------
+		TestingListState<Serializable> listState = new TestingListState<>();
+		for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
+			listState.add(Tuple2.of(state.getKey(), state.getValue()));
+		}
+
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+		when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+		when(initializationContext.isRestored()).thenReturn(true);
+
+		// ----------------------------------------------------------------------
+		// mock fetcher
+		// ----------------------------------------------------------------------
 		KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
+		List<KinesisStreamShard> shards = new ArrayList<>();
+		shards.addAll(fakeRestoredState.keySet());
+		when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
 		PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
 
 		// assume the given config is correct
 		PowerMockito.mockStatic(KinesisConfigUtil.class);
 		PowerMockito.doNothing().when(KinesisConfigUtil.class);
 
-		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = new HashMap<>();
-		fakeRestoredState.put(
-			new KinesisStreamShard("fakeStream1",
-				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
-			new SequenceNumber(UUID.randomUUID().toString()));
-		fakeRestoredState.put(
-			new KinesisStreamShard("fakeStream1",
-				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
-			new SequenceNumber(UUID.randomUUID().toString()));
-		fakeRestoredState.put(
-			new KinesisStreamShard("fakeStream1",
-				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
-			new SequenceNumber(UUID.randomUUID().toString()));
-		fakeRestoredState.put(
-			new KinesisStreamShard("fakeStream2",
-				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
-			new SequenceNumber(UUID.randomUUID().toString()));
-		fakeRestoredState.put(
-			new KinesisStreamShard("fakeStream2",
-				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
-			new SequenceNumber(UUID.randomUUID().toString()));
+		// ----------------------------------------------------------------------
+		// start to test seed initial state to fetcher
+		// ----------------------------------------------------------------------
+		TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
+			"fakeStream", new Properties(), 10, 2);
+		consumer.initializeState(initializationContext);
+		consumer.open(new Configuration());
+		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
+
+		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
+			Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
+				new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+		}
+	}
+
+	@Test
+	@SuppressWarnings("unchecked")
+	public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws Exception {
+		// ----------------------------------------------------------------------
+		// setting initial state
+		// ----------------------------------------------------------------------
+		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("fakeStream1");
+
+		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredStateForOthers = getFakeRestoredStore("fakeStream2");
+
+		// ----------------------------------------------------------------------
+		// mock operator state backend and initial state for initializeState()
+		// ----------------------------------------------------------------------
+		TestingListState<Serializable> listState = new TestingListState<>();
+		for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
+			listState.add(Tuple2.of(state.getKey(), state.getValue()));
+		}
+		for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredStateForOthers.entrySet()) {
+			listState.add(Tuple2.of(state.getKey(), state.getValue()));
+		}
 
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+		when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+		when(initializationContext.isRestored()).thenReturn(true);
+
+		// ----------------------------------------------------------------------
+		// mock fetcher
+		// ----------------------------------------------------------------------
+		KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
+		List<KinesisStreamShard> shards = new ArrayList<>();
+		shards.addAll(fakeRestoredState.keySet());
+		when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
+		PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
+
+		// assume the given config is correct
+		PowerMockito.mockStatic(KinesisConfigUtil.class);
+		PowerMockito.doNothing().when(KinesisConfigUtil.class);
+
+		// ----------------------------------------------------------------------
+		// start to test seed initial state to fetcher
+		// ----------------------------------------------------------------------
 		TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
 			"fakeStream", new Properties(), 10, 2);
-		consumer.restoreState(fakeRestoredState);
+		consumer.initializeState(initializationContext);
 		consumer.open(new Configuration());
 		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
 
 		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredStateForOthers.entrySet()) {
+			// should never get restored state not belonging to itself
+			Mockito.verify(mockedFetcher, never()).registerNewSubscribedShardState(
+				new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+		}
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
-			Mockito.verify(mockedFetcher).advanceLastDiscoveredShardOfStream(
-				restoredShard.getKey().getStreamName(), restoredShard.getKey().getShard().getShardId());
+			// should get restored state belonging to itself
 			Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
 				new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
 		}
 	}
+
+	/*
+	 * If the original parallelism is 2 and the state is:
+	 *   Consumer subtask 1:
+	 *     stream1, shard1, SequenceNumber(xxx)
+	 *   Consumer subtask 2:
+	 *     stream1, shard2, SequenceNumber(yyy)
+	 * After discoverNewShardsToSubscribe(), if two new shards (shard3, shard4) have been created:
+	 *   Consumer subtask 1 (late for discoverNewShardsToSubscribe()):
+	 *     stream1, shard1, SequenceNumber(xxx)
+	 *   Consumer subtask 2:
+	 *     stream1, shard2, SequenceNumber(yyy)
+	 *     stream1, shard4, SequenceNumber(zzz)
+	 *  If snapshotState() occurs and the parallelism is changed to 1:
+	 *    The union state will be:
+	 *     stream1, shard1, SequenceNumber(xxx)
+	 *     stream1, shard2, SequenceNumber(yyy)
+	 *     stream1, shard4, SequenceNumber(zzz)
+	 *    The fetcher should be seeded with:
+	 *     stream1, shard1, SequenceNumber(xxx)
+	 *     stream1, shard2, SequenceNumber(yyy)
+	 *     stream1, shard3, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
+	 *     stream1, shard4, SequenceNumber(zzz)
+	 *
+	 *  This test guarantees that the fetcher will be seeded correctly in such a situation.
+	 */
+	@Test
+	@SuppressWarnings("unchecked")
+	public void testFetcherShouldBeCorrectlySeededWithNewDiscoveredKinesisStreamShard() throws Exception {
+		// ----------------------------------------------------------------------
+		// setting initial state
+		// ----------------------------------------------------------------------
+		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
+
+		// ----------------------------------------------------------------------
+		// mock operator state backend and initial state for initializeState()
+		// ----------------------------------------------------------------------
+		TestingListState<Serializable> listState = new TestingListState<>();
+		for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
+			listState.add(Tuple2.of(state.getKey(), state.getValue()));
+		}
+
+		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+		when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+		when(initializationContext.isRestored()).thenReturn(true);
+
+		// ----------------------------------------------------------------------
+		// mock fetcher
+		// ----------------------------------------------------------------------
+		KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
+		List<KinesisStreamShard> shards = new ArrayList<>();
+		shards.addAll(fakeRestoredState.keySet());
+		shards.add(new KinesisStreamShard("fakeStream2",
+			new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))));
+		when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
+		PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
+
+		// assume the given config is correct
+		PowerMockito.mockStatic(KinesisConfigUtil.class);
+		PowerMockito.doNothing().when(KinesisConfigUtil.class);
+
+		// ----------------------------------------------------------------------
+		// start to test seed initial state to fetcher
+		// ----------------------------------------------------------------------
+		TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
+			"fakeStream", new Properties(), 10, 2);
+		consumer.initializeState(initializationContext);
+		consumer.open(new Configuration());
+		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
+
+		fakeRestoredState.put(new KinesisStreamShard("fakeStream2",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+			SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get());
+		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
+			Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
+				new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+		}
+	}
+
+	private static final class TestingListState<T> implements ListState<T> {
+
+		private final List<T> list = new ArrayList<>();
+		private boolean clearCalled = false;
+
+		@Override
+		public void clear() {
+			list.clear();
+			clearCalled = true;
+		}
+
+		@Override
+		public Iterable<T> get() throws Exception {
+			return list;
+		}
+
+		@Override
+		public void add(T value) throws Exception {
+			list.add(value);
+		}
+
+		public List<T> getList() {
+			return list;
+		}
+
+		public boolean isClearCalled() {
+			return clearCalled;
+		}
+	}
+
+	private HashMap<KinesisStreamShard, SequenceNumber> getFakeRestoredStore(String streamName) {
+		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = new HashMap<>();
+
+		if (streamName.equals("fakeStream1") || streamName.equals("all")) {
+			fakeRestoredState.put(
+				new KinesisStreamShard("fakeStream1",
+					new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+				new SequenceNumber(UUID.randomUUID().toString()));
+			fakeRestoredState.put(
+				new KinesisStreamShard("fakeStream1",
+					new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+				new SequenceNumber(UUID.randomUUID().toString()));
+			fakeRestoredState.put(
+				new KinesisStreamShard("fakeStream1",
+					new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+				new SequenceNumber(UUID.randomUUID().toString()));
+		}
+
+		if (streamName.equals("fakeStream2") || streamName.equals("all")) {
+			fakeRestoredState.put(
+				new KinesisStreamShard("fakeStream2",
+					new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+				new SequenceNumber(UUID.randomUUID().toString()));
+			fakeRestoredState.put(
+				new KinesisStreamShard("fakeStream2",
+					new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+				new SequenceNumber(UUID.randomUUID().toString()));
+		}
+
+		return fakeRestoredState;
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot
new file mode 100644
index 0000000..b60402e
Binary files /dev/null and b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot differ

http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty
new file mode 100644
index 0000000..f4dd96d
Binary files /dev/null and b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty differ


[2/2] flink git commit: [FLINK-4821] [kinesis] General improvements to rescalable FlinkKinesisConsumer

Posted by tz...@apache.org.
[FLINK-4821] [kinesis] General improvements to rescalable FlinkKinesisConsumer

This commit adds some general improvements to the rescalable
implementation of FlinkKinesisConsumer, including:
- Refactor setup procedures in KinesisDataFetcher so that duplicate work
  isn't done on a restored run
- Strengthen corner cases where fetcher was not fully seeded with
  initial state when snapshot is taken

This closes #3001.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e5b65a7f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e5b65a7f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e5b65a7f

Branch: refs/heads/master
Commit: e5b65a7fc2b4a7532ca40748f81bcbf8ace46815
Parents: a05b574
Author: Tzu-Li (Gordon) Tai <tz...@apache.org>
Authored: Sun May 7 16:29:32 2017 +0800
Committer: Tzu-Li (Gordon) Tai <tz...@apache.org>
Committed: Sun May 7 17:33:04 2017 +0800

----------------------------------------------------------------------
 .../kinesis/FlinkKinesisConsumer.java           | 150 ++++++++---------
 .../kinesis/internals/KinesisDataFetcher.java   |  52 +-----
 .../FlinkKinesisConsumerMigrationTest.java      |   5 +-
 .../kinesis/FlinkKinesisConsumerTest.java       | 159 +++++++++++--------
 .../internals/KinesisDataFetcherTest.java       |  65 ++++++--
 .../testutils/TestableKinesisDataFetcher.java   |  14 ++
 6 files changed, 233 insertions(+), 212 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
index dfcd552..4982f7f 100644
--- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
+++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
@@ -25,13 +25,14 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
-import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
+import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition;
 import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
@@ -67,9 +68,9 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  * @param <T> the type of data emitted
  */
 public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> implements
-	ResultTypeQueryable<T>,
-	CheckpointedFunction,
-	CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {
+		ResultTypeQueryable<T>,
+		CheckpointedFunction,
+		CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {
 
 	private static final long serialVersionUID = 4724006128720664870L;
 
@@ -86,7 +87,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 	 * shard list retrieval behaviours, etc */
 	private final Properties configProps;
 
-	/** User supplied deseriliazation schema to convert Kinesis byte messages to Flink objects */
+	/** User supplied deserialization schema to convert Kinesis byte messages to Flink objects */
 	private final KinesisDeserializationSchema<T> deserializer;
 
 	// ------------------------------------------------------------------------
@@ -96,9 +97,6 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 	/** Per-task fetcher for Kinesis data records, where each fetcher pulls data from one or more Kinesis shards */
 	private transient KinesisDataFetcher<T> fetcher;
 
-	/** The sequence numbers in the last state snapshot of this subtask */
-	private transient HashMap<KinesisStreamShard, SequenceNumber> lastStateSnapshot;
-
 	/** The sequence numbers to restore to upon restore from failure */
 	private transient HashMap<KinesisStreamShard, SequenceNumber> sequenceNumsToRestore;
 
@@ -108,7 +106,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 	//  State for Checkpoint
 	// ------------------------------------------------------------------------
 
-	/** The name is the key for sequence numbers state, and cannot be changed. */
+	/** State name to access shard sequence number states; cannot be changed */
 	private static final String sequenceNumsStateStoreName = "Kinesis-Stream-Shard-State";
 
 	private transient ListState<Tuple2<KinesisStreamShard, SequenceNumber>> sequenceNumsStateForCheckpoint;
@@ -191,57 +189,33 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 	// ------------------------------------------------------------------------
 
 	@Override
-	public void open(Configuration parameters) throws Exception {
-		super.open(parameters);
-
-		// restore to the last known sequence numbers from the latest complete snapshot
-		if (sequenceNumsToRestore != null) {
-			if (LOG.isInfoEnabled()) {
-				LOG.info("Subtask {} is restoring sequence numbers {} from previous checkpointed state",
-					getRuntimeContext().getIndexOfThisSubtask(), sequenceNumsToRestore.toString());
-			}
-
-			// initialize sequence numbers with restored state
-			lastStateSnapshot = sequenceNumsToRestore;
-		} else {
-			// start fresh with empty sequence numbers if there are no snapshots to restore from.
-			lastStateSnapshot = new HashMap<>();
-		}
-	}
-
-	@Override
 	public void run(SourceContext<T> sourceContext) throws Exception {
 
 		// all subtasks will run a fetcher, regardless of whether or not the subtask will initially have
 		// shards to subscribe to; fetchers will continuously poll for changes in the shard list, so all subtasks
 		// can potentially have new shards to subscribe to later on
-		fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer);
-
-		boolean isRestoringFromFailure = (sequenceNumsToRestore != null);
-		fetcher.setIsRestoringFromFailure(isRestoringFromFailure);
-
-		// if we are restoring from a checkpoint, we iterate over the restored
-		// state and accordingly seed the fetcher with subscribed shards states
-		if (isRestoringFromFailure) {
-			// Since there may have a situation that some subtasks did not finish discovering before rescale,
-			// and KinesisDataFetcher will always discover the shard from the largest shard id. To prevent from
-			// missing some shards which didn't be discovered and whose id is not the largest one, we force the
-			// consumer to discover once from the smallest id and make sure each shard have its initial sequence
-			// number from restored state or SENTINEL_EARLIEST_SEQUENCE_NUM.
-			List<KinesisStreamShard> newShardsCreatedWhileNotRunning = fetcher.discoverNewShardsToSubscribe();
-			for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) {
-				SequenceNumber startingStateForNewShard;
-
-				if (lastStateSnapshot.containsKey(shard)) {
-					startingStateForNewShard = lastStateSnapshot.get(shard);
+		KinesisDataFetcher<T> fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer);
+
+		// initial discovery
+		List<KinesisStreamShard> allShards = fetcher.discoverNewShardsToSubscribe();
+
+		for (KinesisStreamShard shard : allShards) {
+			if (sequenceNumsToRestore != null) {
+				if (sequenceNumsToRestore.containsKey(shard)) {
+					// if the shard was already seen and is contained in the state,
+					// just use the sequence number stored in the state
+					fetcher.registerNewSubscribedShardState(
+						new KinesisStreamShardState(shard, sequenceNumsToRestore.get(shard)));
 
 					if (LOG.isInfoEnabled()) {
 						LOG.info("Subtask {} is seeding the fetcher with restored shard {}," +
 								" starting state set to the restored sequence number {}",
-							getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingStateForNewShard);
+							getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), sequenceNumsToRestore.get(shard));
 					}
 				} else {
-					startingStateForNewShard = SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get();
+					// the shard wasn't discovered in the previous run, therefore should be consumed from the beginning
+					fetcher.registerNewSubscribedShardState(
+						new KinesisStreamShardState(shard, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get()));
 
 					if (LOG.isInfoEnabled()) {
 						LOG.info("Subtask {} is seeding the fetcher with new discovered shard {}," +
@@ -249,9 +223,20 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 							getRuntimeContext().getIndexOfThisSubtask(), shard.toString());
 					}
 				}
+			} else {
+				// we're starting fresh; use the configured start position as initial state
+				SentinelSequenceNumber startingSeqNum =
+					InitialPosition.valueOf(configProps.getProperty(
+						ConsumerConfigConstants.STREAM_INITIAL_POSITION,
+						ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION)).toSentinelSequenceNumber();
 
 				fetcher.registerNewSubscribedShardState(
-					new KinesisStreamShardState(shard, startingStateForNewShard));
+					new KinesisStreamShardState(shard, startingSeqNum.get()));
+
+				if (LOG.isInfoEnabled()) {
+					LOG.info("Subtask {} will be seeded with initial shard {}, starting state set as sequence number {}",
+						getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingSeqNum.get());
+				}
 			}
 		}
 
@@ -260,6 +245,10 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 			return;
 		}
 
+		// expose the fetcher from this point, so that state
+		// snapshots can be taken from the fetcher's state holders
+		this.fetcher = fetcher;
+
 		// start the fetcher loop. The fetcher will stop running only when cancel() or
 		// close() is called, or an error is thrown by threads created by the fetcher
 		fetcher.runFetcher();
@@ -306,13 +295,12 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 
 	@Override
 	public void initializeState(FunctionInitializationContext context) throws Exception {
-		TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> tuple = new TupleTypeInfo<>(
+		TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> shardsStateTypeInfo = new TupleTypeInfo<>(
 			TypeInformation.of(KinesisStreamShard.class),
-			TypeInformation.of(SequenceNumber.class)
-		);
+			TypeInformation.of(SequenceNumber.class));
 
 		sequenceNumsStateForCheckpoint = context.getOperatorStateStore().getUnionListState(
-			new ListStateDescriptor<>(sequenceNumsStateStoreName, tuple));
+			new ListStateDescriptor<>(sequenceNumsStateStoreName, shardsStateTypeInfo));
 
 		if (context.isRestored()) {
 			if (sequenceNumsToRestore == null) {
@@ -323,8 +311,6 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 
 				LOG.info("Setting restore state in the FlinkKinesisConsumer. Using the following offsets: {}",
 					sequenceNumsToRestore);
-			} else if (sequenceNumsToRestore.isEmpty()) {
-				sequenceNumsToRestore = null;
 			}
 		} else {
 			LOG.info("No restore state for FlinkKinesisConsumer.");
@@ -333,11 +319,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 
 	@Override
 	public void snapshotState(FunctionSnapshotContext context) throws Exception {
-		if (lastStateSnapshot == null) {
-			LOG.debug("snapshotState() requested on not yet opened source; returning null.");
-		} else if (fetcher == null) {
-			LOG.debug("snapshotState() requested on not yet running source; returning null.");
-		} else if (!running) {
+		if (!running) {
 			LOG.debug("snapshotState() called on closed source; returning null.");
 		} else {
 			if (LOG.isDebugEnabled()) {
@@ -345,15 +327,33 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 			}
 
 			sequenceNumsStateForCheckpoint.clear();
-			lastStateSnapshot = fetcher.snapshotState();
 
-			if (LOG.isDebugEnabled()) {
-				LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
-					lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp());
-			}
+			if (fetcher == null) {
+				if (sequenceNumsToRestore != null) {
+					for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : sequenceNumsToRestore.entrySet()) {
+						// sequenceNumsToRestore is the restored global union state;
+						// should only snapshot shards that actually belong to us
+
+						if (KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(
+								entry.getKey(),
+								getRuntimeContext().getNumberOfParallelSubtasks(),
+								getRuntimeContext().getIndexOfThisSubtask())) {
+
+							sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+						}
+					}
+				}
+			} else {
+				HashMap<KinesisStreamShard, SequenceNumber> lastStateSnapshot = fetcher.snapshotState();
+
+				if (LOG.isDebugEnabled()) {
+					LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
+						lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp());
+				}
 
-			for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) {
-				sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+				for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) {
+					sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+				}
 			}
 		}
 	}
@@ -366,12 +366,14 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 		sequenceNumsToRestore = restoredState.isEmpty() ? null : restoredState;
 	}
 
-	/** This method is created for tests that can mock the KinesisDataFetcher in the consumer. */
-	protected KinesisDataFetcher<T> createFetcher(List<String> streams,
-													SourceFunction.SourceContext<T> sourceContext,
-													RuntimeContext runtimeContext,
-													Properties configProps,
-													KinesisDeserializationSchema<T> deserializationSchema) {
+	/** This method is exposed for tests that need to mock the KinesisDataFetcher in the consumer. */
+	protected KinesisDataFetcher<T> createFetcher(
+			List<String> streams,
+			SourceFunction.SourceContext<T> sourceContext,
+			RuntimeContext runtimeContext,
+			Properties configProps,
+			KinesisDeserializationSchema<T> deserializationSchema) {
+
 		return new KinesisDataFetcher<>(streams, sourceContext, runtimeContext, configProps, deserializationSchema);
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
index c5b4b04..99305cb 100644
--- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
+++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
@@ -19,9 +19,7 @@ package org.apache.flink.streaming.connectors.kinesis.internals;
 
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
-import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer;
 import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
-import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
 import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
@@ -99,12 +97,6 @@ public class KinesisDataFetcher<T> {
 
 	private final int indexOfThisConsumerSubtask;
 
-	/**
-	 * This flag should be set by {@link FlinkKinesisConsumer} using
-	 * {@link KinesisDataFetcher#setIsRestoringFromFailure(boolean)}
-	 */
-	private boolean isRestoredFromFailure;
-
 	// ------------------------------------------------------------------------
 	//  Executor services to run created threads
 	// ------------------------------------------------------------------------
@@ -235,41 +227,7 @@ public class KinesisDataFetcher<T> {
 		//  Procedures before starting the infinite while loop:
 		// ------------------------------------------------------------------------
 
-		//  1. query for any new shards that may have been created while the Kinesis consumer was not running,
-		//     and register them to the subscribedShardState list.
-		if (LOG.isDebugEnabled()) {
-			String logFormat = (!isRestoredFromFailure)
-				? "Subtask {} is trying to discover initial shards ..."
-				: "Subtask {} is trying to discover any new shards that were created while the consumer wasn't " +
-				"running due to failure ...";
-
-			LOG.debug(logFormat, indexOfThisConsumerSubtask);
-		}
-		List<KinesisStreamShard> newShardsCreatedWhileNotRunning = discoverNewShardsToSubscribe();
-		for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) {
-			// the starting state for new shards created while the consumer wasn't running depends on whether or not
-			// we are starting fresh (not restoring from a checkpoint); when we are starting fresh, this simply means
-			// all existing shards of streams we are subscribing to are new shards; when we are restoring from checkpoint,
-			// any new shards due to Kinesis resharding from the time of the checkpoint will be considered new shards.
-			InitialPosition initialPosition = InitialPosition.valueOf(configProps.getProperty(
-				ConsumerConfigConstants.STREAM_INITIAL_POSITION, ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION));
-
-			SentinelSequenceNumber startingStateForNewShard = (isRestoredFromFailure)
-				? SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
-				: initialPosition.toSentinelSequenceNumber();
-
-			if (LOG.isInfoEnabled()) {
-				String logFormat = (!isRestoredFromFailure)
-					? "Subtask {} will be seeded with initial shard {}, starting state set as sequence number {}"
-					: "Subtask {} will be seeded with new shard {} that was created while the consumer wasn't " +
-					"running due to failure, starting state set as sequence number {}";
-
-				LOG.info(logFormat, indexOfThisConsumerSubtask, shard.toString(), startingStateForNewShard.get());
-			}
-			registerNewSubscribedShardState(new KinesisStreamShardState(shard, startingStateForNewShard.get()));
-		}
-
-		//  2. check that there is at least one shard in the subscribed streams to consume from (can be done by
+		//  1. check that there is at least one shard in the subscribed streams to consume from (can be done by
 		//     checking if at least one value in subscribedStreamsToLastDiscoveredShardIds is not null)
 		boolean hasShards = false;
 		StringBuilder streamsWithNoShardsFound = new StringBuilder();
@@ -290,7 +248,7 @@ public class KinesisDataFetcher<T> {
 			throw new RuntimeException("No shards can be found for all subscribed streams: " + streams);
 		}
 
-		//  3. start consuming any shard state we already have in the subscribedShardState up to this point; the
+		//  2. start consuming any shard state we already have in the subscribedShardState up to this point; the
 		//     subscribedShardState may already be seeded with values due to step 1., or explicitly added by the
 		//     consumer using a restored state checkpoint
 		for (int seededStateIndex = 0; seededStateIndex < subscribedShardsState.size(); seededStateIndex++) {
@@ -489,10 +447,6 @@ public class KinesisDataFetcher<T> {
 	//  Functions to get / set information about the consumer
 	// ------------------------------------------------------------------------
 
-	public void setIsRestoringFromFailure(boolean bool) {
-		this.isRestoredFromFailure = bool;
-	}
-
 	protected Properties getConsumerConfiguration() {
 		return configProps;
 	}
@@ -595,7 +549,7 @@ public class KinesisDataFetcher<T> {
 	 * @param totalNumberOfConsumerSubtasks total number of consumer subtasks
 	 * @param indexOfThisConsumerSubtask index of this consumer subtask
 	 */
-	private static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard,
+	public static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard,
 														int totalNumberOfConsumerSubtasks,
 														int indexOfThisConsumerSubtask) {
 		return (Math.abs(shard.hashCode() % totalNumberOfConsumerSubtasks)) == indexOfThisConsumerSubtask;

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
index 2f46e09..7629f9d 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
@@ -42,10 +42,7 @@ import static org.mockito.Mockito.mock;
 
 /**
  * Tests for checking whether {@link FlinkKinesisConsumer} can restore from snapshots that were
- * done using the Flink 1.1 {@link FlinkKinesisConsumer}.
- *
- * <p>For regenerating the binary snapshot file you have to run the commented out portion
- * of each test on a checkout of the Flink 1.1 branch.
+ * done using the Flink 1.1 {@code FlinkKinesisConsumer}.
  */
 public class FlinkKinesisConsumerMigrationTest {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
index bf8e44f..4b178c7 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
@@ -40,6 +40,7 @@ import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGen
 import org.apache.flink.streaming.connectors.kinesis.testutils.TestableFlinkKinesisConsumer;
 import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil;
 import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -57,10 +58,8 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Properties;
 import java.util.UUID;
-import java.io.Serializable;
 
 import static org.junit.Assert.fail;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;
 import static org.mockito.Mockito.mock;
@@ -530,7 +529,7 @@ public class FlinkKinesisConsumerTest {
 	// ----------------------------------------------------------------------
 
 	@Test
-	public void testSnapshotStateShouldNotClearListStateIfSourceNotOpened() throws Exception {
+	public void testUseRestoredStateForSnapshotIfFetcherNotInitialized() throws Exception {
 		Properties config = new Properties();
 		config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
 		config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
@@ -538,57 +537,63 @@ public class FlinkKinesisConsumerTest {
 
 		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
 
-		TestingListState<Serializable> listState = new TestingListState<>();
-
-		FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
-
-		when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
-
-		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
-
-		when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
-		when(initializationContext.isRestored()).thenReturn(false);
-
-		consumer.initializeState(initializationContext);
-
-		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp
-
-		assertFalse(listState.isClearCalled());
-	}
-
-	@Test
-	public void testSnapshotStateShouldNotClearListStateIfSourceNotRun() throws Exception {
-		Properties config = new Properties();
-		config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
-		config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
-		config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
-
-		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		List<Tuple2<KinesisStreamShard, SequenceNumber>> globalUnionState = new ArrayList<>(4);
+		globalUnionState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+			new SequenceNumber("1")));
+		globalUnionState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+			new SequenceNumber("1")));
+		globalUnionState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+			new SequenceNumber("1")));
+		globalUnionState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(3))),
+			new SequenceNumber("1")));
 
-		TestingListState<Serializable> listState = new TestingListState<>();
+		TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
+		for (Tuple2<KinesisStreamShard, SequenceNumber> state : globalUnionState) {
+			listState.add(state);
+		}
 
 		FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
+		RuntimeContext context = mock(RuntimeContext.class);
+		when(context.getIndexOfThisSubtask()).thenReturn(0);
+		when(context.getNumberOfParallelSubtasks()).thenReturn(2);
+		consumer.setRuntimeContext(context);
 
 		when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
 
 		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
 
 		when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
-		when(initializationContext.isRestored()).thenReturn(false);
+		when(initializationContext.isRestored()).thenReturn(true);
 
 		consumer.initializeState(initializationContext);
 
-		consumer.open(new Configuration()); // only opened, not run
+		// only opened, not run
+		consumer.open(new Configuration());
+
+		// arbitrary checkpoint id and timestamp
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123));
 
-		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp
+		Assert.assertTrue(listState.isClearCalled());
 
-		assertFalse(listState.isClearCalled());
+		// the checkpointed list state should contain only the shards that it should subscribe to
+		Assert.assertEquals(globalUnionState.size() / 2, listState.getList().size());
+		Assert.assertTrue(listState.getList().contains(globalUnionState.get(0)));
+		Assert.assertTrue(listState.getList().contains(globalUnionState.get(2)));
 	}
 
 	@Test
 	public void testListStateChangedAfterSnapshotState() throws Exception {
+
 		// ----------------------------------------------------------------------
-		// setting config, initial state and state after snapshot
+		// setup config, initial state and expected state snapshot
 		// ----------------------------------------------------------------------
 		Properties config = new Properties();
 		config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
@@ -601,16 +606,16 @@ public class FlinkKinesisConsumerTest {
 				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
 			new SequenceNumber("1")));
 
-		ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> snapShotState = new ArrayList<>(3);
-		snapShotState.add(Tuple2.of(
+		ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> expectedStateSnapshot = new ArrayList<>(3);
+		expectedStateSnapshot.add(Tuple2.of(
 			new KinesisStreamShard("fakeStream1",
 				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
 			new SequenceNumber("12")));
-		snapShotState.add(Tuple2.of(
+		expectedStateSnapshot.add(Tuple2.of(
 			new KinesisStreamShard("fakeStream1",
 				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
 			new SequenceNumber("11")));
-		snapShotState.add(Tuple2.of(
+		expectedStateSnapshot.add(Tuple2.of(
 			new KinesisStreamShard("fakeStream1",
 				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
 			new SequenceNumber("31")));
@@ -618,8 +623,9 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock operator state backend and initial state for initializeState()
 		// ----------------------------------------------------------------------
-		TestingListState<Serializable> listState = new TestingListState<>();
-		for (Serializable state: initialState) {
+
+		TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
+		for (Tuple2<KinesisStreamShard, SequenceNumber> state: initialState) {
 			listState.add(state);
 		}
 
@@ -633,8 +639,9 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock a running fetcher and its state for snapshot
 		// ----------------------------------------------------------------------
+
 		HashMap<KinesisStreamShard, SequenceNumber> stateSnapshot = new HashMap<>();
-		for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: snapShotState) {
+		for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: expectedStateSnapshot) {
 			stateSnapshot.put(tuple.f0, tuple.f1);
 		}
 
@@ -644,6 +651,7 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// create a consumer and test the snapshotState()
 		// ----------------------------------------------------------------------
+
 		FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
 		FlinkKinesisConsumer<?> mockedConsumer = spy(consumer);
 
@@ -653,22 +661,22 @@ public class FlinkKinesisConsumerTest {
 		mockedConsumer.setRuntimeContext(context);
 		mockedConsumer.initializeState(initializationContext);
 		mockedConsumer.open(new Configuration());
-		Whitebox.setInternalState(mockedConsumer, "fetcher", mockedFetcher); // mock as consumer is running.
+		Whitebox.setInternalState(mockedConsumer, "fetcher", mockedFetcher); // mock consumer as running.
 
 		mockedConsumer.snapshotState(mock(FunctionSnapshotContext.class));
 
 		assertEquals(true, listState.clearCalled);
 		assertEquals(3, listState.getList().size());
 
-		for (Serializable state: initialState) {
-			for (Serializable currentState: listState.getList()) {
+		for (Tuple2<KinesisStreamShard, SequenceNumber> state: initialState) {
+			for (Tuple2<KinesisStreamShard, SequenceNumber> currentState: listState.getList()) {
 				assertNotEquals(state, currentState);
 			}
 		}
 
-		for (Serializable state: snapShotState) {
+		for (Tuple2<KinesisStreamShard, SequenceNumber> state: expectedStateSnapshot) {
 			boolean hasOneIsSame = false;
-			for (Serializable currentState: listState.getList()) {
+			for (Tuple2<KinesisStreamShard, SequenceNumber> currentState: listState.getList()) {
 				hasOneIsSame = hasOneIsSame || state.equals(currentState);
 			}
 			assertEquals(true, hasOneIsSame);
@@ -693,8 +701,6 @@ public class FlinkKinesisConsumerTest {
 			"fakeStream", new Properties(), 10, 2);
 		consumer.open(new Configuration());
 		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
-
-		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(false);
 	}
 
 	@Test
@@ -718,7 +724,6 @@ public class FlinkKinesisConsumerTest {
 		consumer.open(new Configuration());
 		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
 
-		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
 			Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
 				new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
@@ -728,15 +733,18 @@ public class FlinkKinesisConsumerTest {
 	@Test
 	@SuppressWarnings("unchecked")
 	public void testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws Exception {
+
 		// ----------------------------------------------------------------------
-		// setting initial state
+		// setup initial state
 		// ----------------------------------------------------------------------
+
 		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
 
 		// ----------------------------------------------------------------------
 		// mock operator state backend and initial state for initializeState()
 		// ----------------------------------------------------------------------
-		TestingListState<Serializable> listState = new TestingListState<>();
+
+		TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
 			listState.add(Tuple2.of(state.getKey(), state.getValue()));
 		}
@@ -751,6 +759,7 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock fetcher
 		// ----------------------------------------------------------------------
+
 		KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
 		List<KinesisStreamShard> shards = new ArrayList<>();
 		shards.addAll(fakeRestoredState.keySet());
@@ -762,15 +771,15 @@ public class FlinkKinesisConsumerTest {
 		PowerMockito.doNothing().when(KinesisConfigUtil.class);
 
 		// ----------------------------------------------------------------------
-		// start to test seed initial state to fetcher
+		// start to test fetcher's initial state seeding
 		// ----------------------------------------------------------------------
+
 		TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
 			"fakeStream", new Properties(), 10, 2);
 		consumer.initializeState(initializationContext);
 		consumer.open(new Configuration());
 		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
 
-		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
 			Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
 				new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
@@ -780,9 +789,11 @@ public class FlinkKinesisConsumerTest {
 	@Test
 	@SuppressWarnings("unchecked")
 	public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws Exception {
+
 		// ----------------------------------------------------------------------
-		// setting initial state
+		// setup initial state
 		// ----------------------------------------------------------------------
+
 		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("fakeStream1");
 
 		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredStateForOthers = getFakeRestoredStore("fakeStream2");
@@ -790,7 +801,8 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock operator state backend and initial state for initializeState()
 		// ----------------------------------------------------------------------
-		TestingListState<Serializable> listState = new TestingListState<>();
+
+		TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
 			listState.add(Tuple2.of(state.getKey(), state.getValue()));
 		}
@@ -808,6 +820,7 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock fetcher
 		// ----------------------------------------------------------------------
+
 		KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
 		List<KinesisStreamShard> shards = new ArrayList<>();
 		shards.addAll(fakeRestoredState.keySet());
@@ -819,15 +832,15 @@ public class FlinkKinesisConsumerTest {
 		PowerMockito.doNothing().when(KinesisConfigUtil.class);
 
 		// ----------------------------------------------------------------------
-		// start to test seed initial state to fetcher
+		// start to test fetcher's initial state seeding
 		// ----------------------------------------------------------------------
+
 		TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
 			"fakeStream", new Properties(), 10, 2);
 		consumer.initializeState(initializationContext);
 		consumer.open(new Configuration());
 		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
 
-		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredStateForOthers.entrySet()) {
 			// should never get restored state not belonging to itself
 			Mockito.verify(mockedFetcher, never()).registerNewSubscribedShardState(
@@ -841,42 +854,49 @@ public class FlinkKinesisConsumerTest {
 	}
 
 	/*
-	 * If the original parallelism is 2 and states is:
+	 * This tests that the consumer correctly picks up shards that were not discovered on the previous run.
+	 *
+	 * Case under test:
+	 *
+	 * If the original parallelism is 2 and states are:
 	 *   Consumer subtask 1:
 	 *     stream1, shard1, SequentialNumber(xxx)
 	 *   Consumer subtask 2:
 	 *     stream1, shard2, SequentialNumber(yyy)
-	 * After discoverNewShardsToSubscribe() if there are two shards (shard3, shard4) been created:
+	 *
+	 * After discoverNewShardsToSubscribe() if there were two shards (shard3, shard4) created:
 	 *   Consumer subtask 1 (late for discoverNewShardsToSubscribe()):
 	 *     stream1, shard1, SequentialNumber(xxx)
 	 *   Consumer subtask 2:
 	 *     stream1, shard2, SequentialNumber(yyy)
 	 *     stream1, shard4, SequentialNumber(zzz)
-	 *  If snapshotState() occur and parallelism is changed to 1:
-	 *    Union state will be:
+	 *
+	 * If snapshotState() occurs and parallelism is changed to 1:
+	 *   Union state will be:
 	 *     stream1, shard1, SequentialNumber(xxx)
 	 *     stream1, shard2, SequentialNumber(yyy)
 	 *     stream1, shard4, SequentialNumber(zzz)
-	 *    Fetcher should be seeded with:
+	 *   Fetcher should be seeded with:
 	 *     stream1, shard1, SequentialNumber(xxx)
 	 *     stream1, shard2, SequentialNumber(yyy)
 	 *     stream1, shard3, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
 	 *     stream1, shard4, SequentialNumber(zzz)
-	 *
-	 *  This test is to guarantee the fetcher will be seeded correctly for such situation.
 	 */
 	@Test
 	@SuppressWarnings("unchecked")
 	public void testFetcherShouldBeCorrectlySeededWithNewDiscoveredKinesisStreamShard() throws Exception {
+
 		// ----------------------------------------------------------------------
-		// setting initial state
+		// setup initial state
 		// ----------------------------------------------------------------------
+
 		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
 
 		// ----------------------------------------------------------------------
 		// mock operator state backend and initial state for initializeState()
 		// ----------------------------------------------------------------------
-		TestingListState<Serializable> listState = new TestingListState<>();
+
+		TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
 			listState.add(Tuple2.of(state.getKey(), state.getValue()));
 		}
@@ -891,6 +911,7 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock fetcher
 		// ----------------------------------------------------------------------
+
 		KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
 		List<KinesisStreamShard> shards = new ArrayList<>();
 		shards.addAll(fakeRestoredState.keySet());
@@ -904,8 +925,9 @@ public class FlinkKinesisConsumerTest {
 		PowerMockito.doNothing().when(KinesisConfigUtil.class);
 
 		// ----------------------------------------------------------------------
-		// start to test seed initial state to fetcher
+		// start to test fetcher's initial state seeding
 		// ----------------------------------------------------------------------
+
 		TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
 			"fakeStream", new Properties(), 10, 2);
 		consumer.initializeState(initializationContext);
@@ -915,7 +937,6 @@ public class FlinkKinesisConsumerTest {
 		fakeRestoredState.put(new KinesisStreamShard("fakeStream2",
 				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
 			SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get());
-		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
 			Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
 				new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
index e79f9b1..800fde5 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
@@ -18,9 +18,14 @@
 package org.apache.flink.streaming.connectors.kinesis.internals;
 
 import com.amazonaws.services.kinesis.model.Shard;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer;
+import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
 import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
+import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
 import org.apache.flink.streaming.connectors.kinesis.testutils.FakeKinesisBehavioursFactory;
 import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
 import org.apache.flink.streaming.connectors.kinesis.testutils.TestableKinesisDataFetcher;
@@ -42,6 +47,8 @@ import java.util.UUID;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 @RunWith(PowerMockRunner.class)
 @PrepareForTest(TestableKinesisDataFetcher.class)
@@ -67,8 +74,6 @@ public class KinesisDataFetcherTest {
 				subscribedStreamsToLastSeenShardIdsUnderTest,
 				FakeKinesisBehavioursFactory.noShardsFoundForRequestedStreamsBehaviour());
 
-		fetcher.setIsRestoringFromFailure(false); // not restoring
-
 		fetcher.runFetcher(); // this should throw RuntimeException
 	}
 
@@ -100,23 +105,30 @@ public class KinesisDataFetcherTest {
 				subscribedStreamsToLastSeenShardIdsUnderTest,
 				FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount));
 
-		fetcher.setIsRestoringFromFailure(false);
+		Properties testConfig = new Properties();
+		testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+
+		final DummyFlinkKafkaConsumer<String> consumer = new DummyFlinkKafkaConsumer<>(testConfig, fetcher);
 
 		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
-		Thread runFetcherThread = new Thread(new Runnable() {
+		Thread consumerThread = new Thread(new Runnable() {
 			@Override
 			public void run() {
 				try {
-					fetcher.runFetcher();
+					consumer.run(mock(SourceFunction.SourceContext.class));
 				} catch (Exception e) {
 					//
 				}
 			}
 		});
-		runFetcherThread.start();
-		Thread.sleep(1000); // sleep a while before closing
-		fetcher.shutdownFetcher();
+		consumerThread.start();
 
+		fetcher.waitUntilRun();
+		consumer.cancel();
+		consumerThread.join();
 
 		// assert that the streams tracked in the state are identical to the subscribed streams
 		Set<String> streamsInState = subscribedStreamsToLastSeenShardIdsUnderTest.keySet();
@@ -192,8 +204,6 @@ public class KinesisDataFetcherTest {
 				new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		fetcher.setIsRestoringFromFailure(true);
-
 		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
 		Thread runFetcherThread = new Thread(new Runnable() {
 			@Override
@@ -284,8 +294,6 @@ public class KinesisDataFetcherTest {
 				new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		fetcher.setIsRestoringFromFailure(true);
-
 		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
 		Thread runFetcherThread = new Thread(new Runnable() {
 			@Override
@@ -380,8 +388,6 @@ public class KinesisDataFetcherTest {
 				new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		fetcher.setIsRestoringFromFailure(true);
-
 		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
 		Thread runFetcherThread = new Thread(new Runnable() {
 			@Override
@@ -477,8 +483,6 @@ public class KinesisDataFetcherTest {
 				new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		fetcher.setIsRestoringFromFailure(true);
-
 		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
 		Thread runFetcherThread = new Thread(new Runnable() {
 			@Override
@@ -507,4 +511,33 @@ public class KinesisDataFetcherTest {
 		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream3") == null);
 		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream4") == null);
 	}
+
+	private static class DummyFlinkKafkaConsumer<T> extends FlinkKinesisConsumer<T> {
+		private static final long serialVersionUID = 1L;
+
+		private KinesisDataFetcher<T> fetcher;
+
+		@SuppressWarnings("unchecked")
+		DummyFlinkKafkaConsumer(Properties properties, KinesisDataFetcher<T> fetcher) {
+			super("test", mock(KinesisDeserializationSchema.class), properties);
+			this.fetcher = fetcher;
+		}
+
+		@Override
+		protected KinesisDataFetcher<T> createFetcher(List<String> streams,
+													  SourceFunction.SourceContext<T> sourceContext,
+													  RuntimeContext runtimeContext,
+													  Properties configProps,
+													  KinesisDeserializationSchema<T> deserializationSchema) {
+			return fetcher;
+		}
+
+		@Override
+		public RuntimeContext getRuntimeContext() {
+			RuntimeContext context = mock(RuntimeContext.class);
+			when(context.getIndexOfThisSubtask()).thenReturn(0);
+			when(context.getNumberOfParallelSubtasks()).thenReturn(1);
+			return context;
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
index 57886fe..bb644ba 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
@@ -18,6 +18,7 @@
 package org.apache.flink.streaming.connectors.kinesis.testutils;
 
 import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
@@ -42,6 +43,8 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {
 
 	private long numElementsCollected;
 
+	private OneShotLatch runWaiter;
+
 	public TestableKinesisDataFetcher(List<String> fakeStreams,
 									  Properties fakeConfiguration,
 									  int fakeTotalCountOfSubtasks,
@@ -62,6 +65,7 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {
 			fakeKinesis);
 
 		this.numElementsCollected = 0;
+		this.runWaiter = new OneShotLatch();
 	}
 
 	public long getNumOfElementsCollected() {
@@ -81,6 +85,16 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {
 		}
 	}
 
+	@Override
+	public void runFetcher() throws Exception {
+		runWaiter.trigger();
+		super.runFetcher();
+	}
+
+	public void waitUntilRun() throws Exception {
+		runWaiter.await();
+	}
+
 	@SuppressWarnings("unchecked")
 	private static SourceFunction.SourceContext<String> getMockedSourceContext() {
 		return Mockito.mock(SourceFunction.SourceContext.class);