You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tz...@apache.org on 2017/05/07 09:35:28 UTC
[1/2] flink git commit: [FLINK-4821] Implement rescalable
non-partitioned state for Kinesis Connector
Repository: flink
Updated Branches:
refs/heads/master fde4f9097 -> e5b65a7fc
[FLINK-4821] Implement rescalable non-partitioned state for Kinesis Connector
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/a05b574c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/a05b574c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/a05b574c
Branch: refs/heads/master
Commit: a05b574cc68d3f652d11fece46c23bbc24f35430
Parents: fde4f90
Author: Tony Wei <to...@gmail.com>
Authored: Wed Dec 14 10:18:25 2016 +0800
Committer: Tzu-Li (Gordon) Tai <tz...@apache.org>
Committed: Sun May 7 16:28:52 2017 +0800
----------------------------------------------------------------------
.../flink-connector-kinesis/pom.xml | 16 +
.../kinesis/FlinkKinesisConsumer.java | 150 +++++--
.../kinesis/internals/KinesisDataFetcher.java | 2 +-
.../FlinkKinesisConsumerMigrationTest.java | 149 +++++++
.../kinesis/FlinkKinesisConsumerTest.java | 440 +++++++++++++++++--
...is-consumer-migration-test-flink1.1-snapshot | Bin 0 -> 1140 bytes
...sumer-migration-test-flink1.1-snapshot-empty | Bin 0 -> 468 bytes
7 files changed, 690 insertions(+), 67 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/pom.xml
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/pom.xml b/flink-connectors/flink-connector-kinesis/pom.xml
index d457199..080626f 100644
--- a/flink-connectors/flink-connector-kinesis/pom.xml
+++ b/flink-connectors/flink-connector-kinesis/pom.xml
@@ -57,6 +57,14 @@ under the License.
<dependency>
<groupId>org.apache.flink</groupId>
+ <artifactId>flink-runtime_2.10</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.flink</groupId>
<artifactId>flink-tests_2.10</artifactId>
<version>${project.version}</version>
<scope>test</scope>
@@ -65,6 +73,14 @@ under the License.
<dependency>
<groupId>org.apache.flink</groupId>
+ <artifactId>flink-streaming-java_2.10</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.flink</groupId>
<artifactId>flink-test-utils_2.10</artifactId>
<version>${project.version}</version>
<scope>test</scope>
http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
index a62dc10..dfcd552 100644
--- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
+++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
@@ -17,15 +17,26 @@
package org.apache.flink.streaming.connectors.kinesis;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.configuration.Configuration;
-import org.apache.flink.streaming.api.checkpoint.CheckpointedAsynchronously;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
-import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
+import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
+import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper;
import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil;
@@ -55,8 +66,10 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
*
* @param <T> the type of data emitted
*/
-public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
- implements CheckpointedAsynchronously<HashMap<KinesisStreamShard, SequenceNumber>>, ResultTypeQueryable<T> {
+public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> implements
+ ResultTypeQueryable<T>,
+ CheckpointedFunction,
+ CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {
private static final long serialVersionUID = 4724006128720664870L;
@@ -91,6 +104,14 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
private volatile boolean running = true;
+ // ------------------------------------------------------------------------
+ // State for Checkpoint
+ // ------------------------------------------------------------------------
+
+ /** The name under which the sequence number state is stored; it acts as the state's key and must not be changed. */
+ private static final String sequenceNumsStateStoreName = "Kinesis-Stream-Shard-State";
+
+ private transient ListState<Tuple2<KinesisStreamShard, SequenceNumber>> sequenceNumsStateForCheckpoint;
// ------------------------------------------------------------------------
// Constructors
@@ -194,8 +215,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
// all subtasks will run a fetcher, regardless of whether or not the subtask will initially have
// shards to subscribe to; fetchers will continuously poll for changes in the shard list, so all subtasks
// can potentially have new shards to subscribe to later on
- fetcher = new KinesisDataFetcher<>(
- streams, sourceContext, getRuntimeContext(), configProps, deserializer);
+ fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer);
boolean isRestoringFromFailure = (sequenceNumsToRestore != null);
fetcher.setIsRestoringFromFailure(isRestoringFromFailure);
@@ -203,17 +223,35 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
// if we are restoring from a checkpoint, we iterate over the restored
// state and accordingly seed the fetcher with subscribed shards states
if (isRestoringFromFailure) {
- for (Map.Entry<KinesisStreamShard, SequenceNumber> restored : lastStateSnapshot.entrySet()) {
- fetcher.advanceLastDiscoveredShardOfStream(
- restored.getKey().getStreamName(), restored.getKey().getShard().getShardId());
-
- if (LOG.isInfoEnabled()) {
- LOG.info("Subtask {} is seeding the fetcher with restored shard {}," +
- " starting state set to the restored sequence number {}",
- getRuntimeContext().getIndexOfThisSubtask(), restored.getKey().toString(), restored.getValue());
+ // Some subtasks may not have finished shard discovery before a rescale, and the
+ // KinesisDataFetcher always resumes discovery from the largest seen shard id. To avoid
+ // missing shards that were never discovered and whose ids are not the largest, we force the
+ // consumer to discover once starting from the smallest id, and make sure each shard gets its
+ // initial sequence number either from the restored state or SENTINEL_EARLIEST_SEQUENCE_NUM.
+ List<KinesisStreamShard> newShardsCreatedWhileNotRunning = fetcher.discoverNewShardsToSubscribe();
+ for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) {
+ SequenceNumber startingStateForNewShard;
+
+ if (lastStateSnapshot.containsKey(shard)) {
+ startingStateForNewShard = lastStateSnapshot.get(shard);
+
+ if (LOG.isInfoEnabled()) {
+ LOG.info("Subtask {} is seeding the fetcher with restored shard {}," +
+ " starting state set to the restored sequence number {}",
+ getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingStateForNewShard);
+ }
+ } else {
+ startingStateForNewShard = SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get();
+
+ if (LOG.isInfoEnabled()) {
+ LOG.info("Subtask {} is seeding the fetcher with new discovered shard {}," +
+ " starting state set to the SENTINEL_EARLIEST_SEQUENCE_NUM",
+ getRuntimeContext().getIndexOfThisSubtask(), shard.toString());
+ }
}
+
fetcher.registerNewSubscribedShardState(
- new KinesisStreamShardState(restored.getKey(), restored.getValue()));
+ new KinesisStreamShardState(shard, startingStateForNewShard));
}
}
@@ -267,38 +305,78 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
// ------------------------------------------------------------------------
@Override
- public HashMap<KinesisStreamShard, SequenceNumber> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
- if (lastStateSnapshot == null) {
- LOG.debug("snapshotState() requested on not yet opened source; returning null.");
- return null;
- }
+ public void initializeState(FunctionInitializationContext context) throws Exception {
+ TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> tuple = new TupleTypeInfo<>(
+ TypeInformation.of(KinesisStreamShard.class),
+ TypeInformation.of(SequenceNumber.class)
+ );
+
+ sequenceNumsStateForCheckpoint = context.getOperatorStateStore().getUnionListState(
+ new ListStateDescriptor<>(sequenceNumsStateStoreName, tuple));
+
+ if (context.isRestored()) {
+ if (sequenceNumsToRestore == null) {
+ sequenceNumsToRestore = new HashMap<>();
+ for (Tuple2<KinesisStreamShard, SequenceNumber> kinesisSequenceNumber : sequenceNumsStateForCheckpoint.get()) {
+ sequenceNumsToRestore.put(kinesisSequenceNumber.f0, kinesisSequenceNumber.f1);
+ }
- if (fetcher == null) {
- LOG.debug("snapshotState() requested on not yet running source; returning null.");
- return null;
+ LOG.info("Setting restore state in the FlinkKinesisConsumer. Using the following offsets: {}",
+ sequenceNumsToRestore);
+ } else if (sequenceNumsToRestore.isEmpty()) {
+ sequenceNumsToRestore = null;
+ }
+ } else {
+ LOG.info("No restore state for FlinkKinesisConsumer.");
}
+ }
- if (!running) {
+ @Override
+ public void snapshotState(FunctionSnapshotContext context) throws Exception {
+ if (lastStateSnapshot == null) {
+ LOG.debug("snapshotState() requested on not yet opened source; returning null.");
+ } else if (fetcher == null) {
+ LOG.debug("snapshotState() requested on not yet running source; returning null.");
+ } else if (!running) {
LOG.debug("snapshotState() called on closed source; returning null.");
- return null;
- }
+ } else {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Snapshotting state ...");
+ }
- if (LOG.isDebugEnabled()) {
- LOG.debug("Snapshotting state ...");
- }
+ sequenceNumsStateForCheckpoint.clear();
+ lastStateSnapshot = fetcher.snapshotState();
- lastStateSnapshot = fetcher.snapshotState();
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
+ lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp());
+ }
- if (LOG.isDebugEnabled()) {
- LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
- lastStateSnapshot.toString(), checkpointId, checkpointTimestamp);
+ for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) {
+ sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+ }
}
-
- return lastStateSnapshot;
}
@Override
public void restoreState(HashMap<KinesisStreamShard, SequenceNumber> restoredState) throws Exception {
- sequenceNumsToRestore = restoredState;
+ LOG.info("Subtask {} restoring offsets from an older Flink version: {}",
+ getRuntimeContext().getIndexOfThisSubtask(), sequenceNumsToRestore);
+
+ sequenceNumsToRestore = restoredState.isEmpty() ? null : restoredState;
+ }
+
+ /** This method is exposed so that tests can inject a mocked KinesisDataFetcher into the consumer. */
+ protected KinesisDataFetcher<T> createFetcher(List<String> streams,
+ SourceFunction.SourceContext<T> sourceContext,
+ RuntimeContext runtimeContext,
+ Properties configProps,
+ KinesisDeserializationSchema<T> deserializationSchema) {
+ return new KinesisDataFetcher<>(streams, sourceContext, runtimeContext, configProps, deserializationSchema);
+ }
+
+ @VisibleForTesting
+ HashMap<KinesisStreamShard, SequenceNumber> getRestoredState() {
+ return sequenceNumsToRestore;
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
index 8f7ca6c..c5b4b04 100644
--- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
+++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
@@ -461,7 +461,7 @@ public class KinesisDataFetcher<T> {
* 3. Update the subscribedStreamsToLastDiscoveredShardIds state so that we won't get shards
* that we have already seen before the next time this function is called
*/
- private List<KinesisStreamShard> discoverNewShardsToSubscribe() throws InterruptedException {
+ public List<KinesisStreamShard> discoverNewShardsToSubscribe() throws InterruptedException {
List<KinesisStreamShard> newShardsToSubscribe = new LinkedList<>();
http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
new file mode 100644
index 0000000..2f46e09
--- /dev/null
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kinesis;
+
+import com.amazonaws.services.kinesis.model.Shard;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.streaming.api.TimeCharacteristic;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
+import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
+import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
+import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
+import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
+import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+import org.junit.Test;
+
+import java.net.URL;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Properties;
+
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for checking whether {@link FlinkKinesisConsumer} can restore from snapshots that were
+ * done using the Flink 1.1 {@link FlinkKinesisConsumer}.
+ *
+ * <p>For regenerating the binary snapshot file you have to run the commented out portion
+ * of each test on a checkout of the Flink 1.1 branch.
+ */
+public class FlinkKinesisConsumerMigrationTest {
+
+ @Test
+ public void testRestoreFromFlink11WithEmptyState() throws Exception {
+ Properties testConfig = new Properties();
+ testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1");
+ testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC");
+ testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
+ testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+
+ final DummyFlinkKafkaConsumer<String> consumerFunction = new DummyFlinkKafkaConsumer<>(testConfig);
+
+ StreamSource<String, DummyFlinkKafkaConsumer<String>> consumerOperator = new StreamSource<>(consumerFunction);
+
+ final AbstractStreamOperatorTestHarness<String> testHarness =
+ new AbstractStreamOperatorTestHarness<>(consumerOperator, 1, 1, 0);
+
+ testHarness.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+ testHarness.setup();
+ // restore state from binary snapshot file using legacy method
+ testHarness.initializeStateFromLegacyCheckpoint(
+ getResourceFilename("kinesis-consumer-migration-test-flink1.1-snapshot-empty"));
+ testHarness.open();
+
+ // assert that no state was restored
+ assertEquals(null, consumerFunction.getRestoredState());
+
+ consumerOperator.close();
+ consumerOperator.cancel();
+ }
+
+ @Test
+ public void testRestoreFromFlink11() throws Exception {
+ Properties testConfig = new Properties();
+ testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1");
+ testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC");
+ testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
+ testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+
+ final DummyFlinkKafkaConsumer<String> consumerFunction = new DummyFlinkKafkaConsumer<>(testConfig);
+
+ StreamSource<String, DummyFlinkKafkaConsumer<String>> consumerOperator =
+ new StreamSource<>(consumerFunction);
+
+ final AbstractStreamOperatorTestHarness<String> testHarness =
+ new AbstractStreamOperatorTestHarness<>(consumerOperator, 1, 1, 0);
+
+ testHarness.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+ testHarness.setup();
+ // restore state from binary snapshot file using legacy method
+ testHarness.initializeStateFromLegacyCheckpoint(
+ getResourceFilename("kinesis-consumer-migration-test-flink1.1-snapshot"));
+ testHarness.open();
+
+ // the expected state in "kinesis-consumer-migration-test-flink1.1-snapshot"
+ final HashMap<KinesisStreamShard, SequenceNumber> expectedState = new HashMap<>();
+ expectedState.put(new KinesisStreamShard("fakeStream1",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+ new SequenceNumber("987654321"));
+
+ // assert that state is correctly restored from legacy checkpoint
+ assertNotEquals(null, consumerFunction.getRestoredState());
+ assertEquals(1, consumerFunction.getRestoredState().size());
+ assertEquals(expectedState, consumerFunction.getRestoredState());
+
+ consumerOperator.close();
+ consumerOperator.cancel();
+ }
+
+ // ------------------------------------------------------------------------
+
+ private static String getResourceFilename(String filename) {
+ ClassLoader cl = FlinkKinesisConsumerMigrationTest.class.getClassLoader();
+ URL resource = cl.getResource(filename);
+ if (resource == null) {
+ throw new NullPointerException("Missing snapshot resource.");
+ }
+ return resource.getFile();
+ }
+
+ private static class DummyFlinkKafkaConsumer<T> extends FlinkKinesisConsumer<T> {
+ private static final long serialVersionUID = 1L;
+
+ @SuppressWarnings("unchecked")
+ DummyFlinkKafkaConsumer(Properties properties) {
+ super("test", mock(KinesisDeserializationSchema.class), properties);
+ }
+
+ @Override
+ protected KinesisDataFetcher<T> createFetcher(List<String> streams,
+ SourceFunction.SourceContext<T> sourceContext,
+ RuntimeContext runtimeContext,
+ Properties configProps,
+ KinesisDeserializationSchema<T> deserializationSchema) {
+ return mock(KinesisDataFetcher.class);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
index 741f0ca..bf8e44f 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
@@ -18,13 +18,22 @@
package org.apache.flink.streaming.connectors.kinesis;
import com.amazonaws.services.kinesis.model.Shard;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.kinesis.config.AWSConfigConstants;
import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
import org.apache.flink.streaming.connectors.kinesis.config.ProducerConfigConstants;
import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
+import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
@@ -35,19 +44,29 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
+import org.mockito.Matchers;
import org.mockito.Mockito;
+import org.mockito.internal.util.reflection.Whitebox;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
-import java.text.SimpleDateFormat;
import java.util.HashMap;
import java.util.Map;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Properties;
import java.util.UUID;
+import java.io.Serializable;
-import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.never;
/**
* Suite of FlinkKinesisConsumer tests for the methods called throughout the source life cycle.
@@ -511,28 +530,149 @@ public class FlinkKinesisConsumerTest {
// ----------------------------------------------------------------------
@Test
- public void testSnapshotStateShouldBeNullIfSourceNotOpened() throws Exception {
+ public void testSnapshotStateShouldNotClearListStateIfSourceNotOpened() throws Exception {
Properties config = new Properties();
config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+ OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+
+ TestingListState<Serializable> listState = new TestingListState<>();
+
FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
- assertTrue(consumer.snapshotState(123, 123) == null); //arbitrary checkpoint id and timestamp
+ when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+ StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+
+ when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+ when(initializationContext.isRestored()).thenReturn(false);
+
+ consumer.initializeState(initializationContext);
+
+ consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp
+
+ assertFalse(listState.isClearCalled());
}
@Test
- public void testSnapshotStateShouldBeNullIfSourceNotRun() throws Exception {
+ public void testSnapshotStateShouldNotClearListStateIfSourceNotRun() throws Exception {
Properties config = new Properties();
config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+ OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+
+ TestingListState<Serializable> listState = new TestingListState<>();
+
FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
+
+ when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+ StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+
+ when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+ when(initializationContext.isRestored()).thenReturn(false);
+
+ consumer.initializeState(initializationContext);
+
consumer.open(new Configuration()); // only opened, not run
- assertTrue(consumer.snapshotState(123, 123) == null); //arbitrary checkpoint id and timestamp
+ consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp
+
+ assertFalse(listState.isClearCalled());
+ }
+
+ @Test
+ public void testListStateChangedAfterSnapshotState() throws Exception {
+ // ----------------------------------------------------------------------
+ // setting config, initial state and state after snapshot
+ // ----------------------------------------------------------------------
+ Properties config = new Properties();
+ config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
+ config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
+ config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+
+ ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> initialState = new ArrayList<>(1);
+ initialState.add(Tuple2.of(
+ new KinesisStreamShard("fakeStream1",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+ new SequenceNumber("1")));
+
+ ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> snapShotState = new ArrayList<>(3);
+ snapShotState.add(Tuple2.of(
+ new KinesisStreamShard("fakeStream1",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+ new SequenceNumber("12")));
+ snapShotState.add(Tuple2.of(
+ new KinesisStreamShard("fakeStream1",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+ new SequenceNumber("11")));
+ snapShotState.add(Tuple2.of(
+ new KinesisStreamShard("fakeStream1",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+ new SequenceNumber("31")));
+
+ // ----------------------------------------------------------------------
+ // mock operator state backend and initial state for initializeState()
+ // ----------------------------------------------------------------------
+ TestingListState<Serializable> listState = new TestingListState<>();
+ for (Serializable state: initialState) {
+ listState.add(state);
+ }
+
+ OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+ when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+ StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+ when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+ when(initializationContext.isRestored()).thenReturn(true);
+
+ // ----------------------------------------------------------------------
+ // mock a running fetcher and its state for snapshot
+ // ----------------------------------------------------------------------
+ HashMap<KinesisStreamShard, SequenceNumber> stateSnapshot = new HashMap<>();
+ for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: snapShotState) {
+ stateSnapshot.put(tuple.f0, tuple.f1);
+ }
+
+ KinesisDataFetcher mockedFetcher = mock(KinesisDataFetcher.class);
+ when(mockedFetcher.snapshotState()).thenReturn(stateSnapshot);
+
+ // ----------------------------------------------------------------------
+ // create a consumer and test the snapshotState()
+ // ----------------------------------------------------------------------
+ FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
+ FlinkKinesisConsumer<?> mockedConsumer = spy(consumer);
+
+ RuntimeContext context = mock(RuntimeContext.class);
+ when(context.getIndexOfThisSubtask()).thenReturn(1);
+
+ mockedConsumer.setRuntimeContext(context);
+ mockedConsumer.initializeState(initializationContext);
+ mockedConsumer.open(new Configuration());
+ Whitebox.setInternalState(mockedConsumer, "fetcher", mockedFetcher); // inject the mocked fetcher so the consumer appears to be running.
+
+ mockedConsumer.snapshotState(mock(FunctionSnapshotContext.class));
+
+ assertEquals(true, listState.clearCalled);
+ assertEquals(3, listState.getList().size());
+
+ for (Serializable state: initialState) {
+ for (Serializable currentState: listState.getList()) {
+ assertNotEquals(state, currentState);
+ }
+ }
+
+ for (Serializable state: snapShotState) {
+ boolean hasOneIsSame = false;
+ for (Serializable currentState: listState.getList()) {
+ hasOneIsSame = hasOneIsSame || state.equals(currentState);
+ }
+ assertEquals(true, hasOneIsSame);
+ }
}
// ----------------------------------------------------------------------
@@ -559,48 +699,288 @@ public class FlinkKinesisConsumerTest {
@Test
@SuppressWarnings("unchecked")
+ public void testFetcherShouldBeCorrectlySeededIfRestoringFromLegacyCheckpoint() throws Exception {
+ HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
+
+ KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
+ List<KinesisStreamShard> shards = new ArrayList<>();
+ shards.addAll(fakeRestoredState.keySet());
+ when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
+ PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
+
+ // assume the given config is correct
+ PowerMockito.mockStatic(KinesisConfigUtil.class);
+ PowerMockito.doNothing().when(KinesisConfigUtil.class);
+
+ TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
+ "fakeStream", new Properties(), 10, 2);
+ consumer.restoreState(fakeRestoredState);
+ consumer.open(new Configuration());
+ consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
+
+ Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+ for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
+ Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
+ new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+ }
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
public void testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws Exception {
+ // ----------------------------------------------------------------------
+ // setting initial state
+ // ----------------------------------------------------------------------
+ HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
+
+ // ----------------------------------------------------------------------
+ // mock operator state backend and initial state for initializeState()
+ // ----------------------------------------------------------------------
+ TestingListState<Serializable> listState = new TestingListState<>();
+ for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
+ listState.add(Tuple2.of(state.getKey(), state.getValue()));
+ }
+
+ OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+ when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+ StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+ when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+ when(initializationContext.isRestored()).thenReturn(true);
+
+ // ----------------------------------------------------------------------
+ // mock fetcher
+ // ----------------------------------------------------------------------
KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
+ List<KinesisStreamShard> shards = new ArrayList<>();
+ shards.addAll(fakeRestoredState.keySet());
+ when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
// assume the given config is correct
PowerMockito.mockStatic(KinesisConfigUtil.class);
PowerMockito.doNothing().when(KinesisConfigUtil.class);
- HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = new HashMap<>();
- fakeRestoredState.put(
- new KinesisStreamShard("fakeStream1",
- new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
- new SequenceNumber(UUID.randomUUID().toString()));
- fakeRestoredState.put(
- new KinesisStreamShard("fakeStream1",
- new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
- new SequenceNumber(UUID.randomUUID().toString()));
- fakeRestoredState.put(
- new KinesisStreamShard("fakeStream1",
- new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
- new SequenceNumber(UUID.randomUUID().toString()));
- fakeRestoredState.put(
- new KinesisStreamShard("fakeStream2",
- new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
- new SequenceNumber(UUID.randomUUID().toString()));
- fakeRestoredState.put(
- new KinesisStreamShard("fakeStream2",
- new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
- new SequenceNumber(UUID.randomUUID().toString()));
+ // ----------------------------------------------------------------------
+ // start to test seed initial state to fetcher
+ // ----------------------------------------------------------------------
+ TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
+ "fakeStream", new Properties(), 10, 2);
+ consumer.initializeState(initializationContext);
+ consumer.open(new Configuration());
+ consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
+
+ Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+ for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
+ Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
+ new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+ }
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws Exception {
+ // ----------------------------------------------------------------------
+ // setting initial state
+ // ----------------------------------------------------------------------
+ HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("fakeStream1");
+
+ HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredStateForOthers = getFakeRestoredStore("fakeStream2");
+
+ // ----------------------------------------------------------------------
+ // mock operator state backend and initial state for initializeState()
+ // ----------------------------------------------------------------------
+ TestingListState<Serializable> listState = new TestingListState<>();
+ for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
+ listState.add(Tuple2.of(state.getKey(), state.getValue()));
+ }
+ for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredStateForOthers.entrySet()) {
+ listState.add(Tuple2.of(state.getKey(), state.getValue()));
+ }
+ OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+ when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+ StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+ when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+ when(initializationContext.isRestored()).thenReturn(true);
+
+ // ----------------------------------------------------------------------
+ // mock fetcher
+ // ----------------------------------------------------------------------
+ KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
+ List<KinesisStreamShard> shards = new ArrayList<>();
+ shards.addAll(fakeRestoredState.keySet());
+ when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
+ PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
+
+ // assume the given config is correct
+ PowerMockito.mockStatic(KinesisConfigUtil.class);
+ PowerMockito.doNothing().when(KinesisConfigUtil.class);
+
+ // ----------------------------------------------------------------------
+ // start to test seed initial state to fetcher
+ // ----------------------------------------------------------------------
TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
"fakeStream", new Properties(), 10, 2);
- consumer.restoreState(fakeRestoredState);
+ consumer.initializeState(initializationContext);
consumer.open(new Configuration());
consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+ for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredStateForOthers.entrySet()) {
+ // should never get restored state not belonging to itself
+ Mockito.verify(mockedFetcher, never()).registerNewSubscribedShardState(
+ new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+ }
for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
- Mockito.verify(mockedFetcher).advanceLastDiscoveredShardOfStream(
- restoredShard.getKey().getStreamName(), restoredShard.getKey().getShard().getShardId());
+ // should get restored state belonging to itself
Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
}
}
+
+ /*
+ * If the original parallelism is 2 and the state is:
+ * Consumer subtask 1:
+ * stream1, shard1, SequenceNumber(xxx)
+ * Consumer subtask 2:
+ * stream1, shard2, SequenceNumber(yyy)
+ * After discoverNewShardsToSubscribe(), if two shards (shard3, shard4) have been created:
+ * Consumer subtask 1 (late for discoverNewShardsToSubscribe()):
+ * stream1, shard1, SequenceNumber(xxx)
+ * Consumer subtask 2:
+ * stream1, shard2, SequenceNumber(yyy)
+ * stream1, shard4, SequenceNumber(zzz)
+ * If snapshotState() occurs and parallelism is changed to 1:
+ * Union state will be:
+ * stream1, shard1, SequenceNumber(xxx)
+ * stream1, shard2, SequenceNumber(yyy)
+ * stream1, shard4, SequenceNumber(zzz)
+ * Fetcher should be seeded with:
+ * stream1, shard1, SequenceNumber(xxx)
+ * stream1, shard2, SequenceNumber(yyy)
+ * stream1, shard3, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
+ * stream1, shard4, SequenceNumber(zzz)
+ *
+ * This test is to guarantee the fetcher will be seeded correctly for such a situation.
+ */
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testFetcherShouldBeCorrectlySeededWithNewDiscoveredKinesisStreamShard() throws Exception {
+ // ----------------------------------------------------------------------
+ // setting initial state
+ // ----------------------------------------------------------------------
+ HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
+
+ // ----------------------------------------------------------------------
+ // mock operator state backend and initial state for initializeState()
+ // ----------------------------------------------------------------------
+ TestingListState<Serializable> listState = new TestingListState<>();
+ for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
+ listState.add(Tuple2.of(state.getKey(), state.getValue()));
+ }
+
+ OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+ when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+
+ StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+ when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
+ when(initializationContext.isRestored()).thenReturn(true);
+
+ // ----------------------------------------------------------------------
+ // mock fetcher
+ // ----------------------------------------------------------------------
+ KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
+ List<KinesisStreamShard> shards = new ArrayList<>();
+ shards.addAll(fakeRestoredState.keySet());
+ shards.add(new KinesisStreamShard("fakeStream2",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))));
+ when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards);
+ PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher);
+
+ // assume the given config is correct
+ PowerMockito.mockStatic(KinesisConfigUtil.class);
+ PowerMockito.doNothing().when(KinesisConfigUtil.class);
+
+ // ----------------------------------------------------------------------
+ // start to test seed initial state to fetcher
+ // ----------------------------------------------------------------------
+ TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
+ "fakeStream", new Properties(), 10, 2);
+ consumer.initializeState(initializationContext);
+ consumer.open(new Configuration());
+ consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
+
+ fakeRestoredState.put(new KinesisStreamShard("fakeStream2",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+ SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get());
+ Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
+ for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
+ Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
+ new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
+ }
+ }
+
+ private static final class TestingListState<T> implements ListState<T> {
+
+ private final List<T> list = new ArrayList<>();
+ private boolean clearCalled = false;
+
+ @Override
+ public void clear() {
+ list.clear();
+ clearCalled = true;
+ }
+
+ @Override
+ public Iterable<T> get() throws Exception {
+ return list;
+ }
+
+ @Override
+ public void add(T value) throws Exception {
+ list.add(value);
+ }
+
+ public List<T> getList() {
+ return list;
+ }
+
+ public boolean isClearCalled() {
+ return clearCalled;
+ }
+ }
+
+ private HashMap<KinesisStreamShard, SequenceNumber> getFakeRestoredStore(String streamName) {
+ HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = new HashMap<>();
+
+ if (streamName.equals("fakeStream1") || streamName.equals("all")) {
+ fakeRestoredState.put(
+ new KinesisStreamShard("fakeStream1",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+ new SequenceNumber(UUID.randomUUID().toString()));
+ fakeRestoredState.put(
+ new KinesisStreamShard("fakeStream1",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+ new SequenceNumber(UUID.randomUUID().toString()));
+ fakeRestoredState.put(
+ new KinesisStreamShard("fakeStream1",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+ new SequenceNumber(UUID.randomUUID().toString()));
+ }
+
+ if (streamName.equals("fakeStream2") || streamName.equals("all")) {
+ fakeRestoredState.put(
+ new KinesisStreamShard("fakeStream2",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+ new SequenceNumber(UUID.randomUUID().toString()));
+ fakeRestoredState.put(
+ new KinesisStreamShard("fakeStream2",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+ new SequenceNumber(UUID.randomUUID().toString()));
+ }
+
+ return fakeRestoredState;
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot
new file mode 100644
index 0000000..b60402e
Binary files /dev/null and b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot differ
http://git-wip-us.apache.org/repos/asf/flink/blob/a05b574c/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty
new file mode 100644
index 0000000..f4dd96d
Binary files /dev/null and b/flink-connectors/flink-connector-kinesis/src/test/resources/kinesis-consumer-migration-test-flink1.1-snapshot-empty differ
[2/2] flink git commit: [FLINK-4821] [kinesis] General improvements
to rescalable FlinkKinesisConsumer
Posted by tz...@apache.org.
[FLINK-4821] [kinesis] General improvements to rescalable FlinkKinesisConsumer
This commit adds some general improvements to the rescalable
implementation of FlinkKinesisConsumer, including:
- Refactor setup procedures in KinesisDataFetcher so that duplicate work
isn't done on a restored run
- Strengthen corner cases where fetcher was not fully seeded with
initial state when snapshot is taken
This closes #3001.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e5b65a7f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e5b65a7f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e5b65a7f
Branch: refs/heads/master
Commit: e5b65a7fc2b4a7532ca40748f81bcbf8ace46815
Parents: a05b574
Author: Tzu-Li (Gordon) Tai <tz...@apache.org>
Authored: Sun May 7 16:29:32 2017 +0800
Committer: Tzu-Li (Gordon) Tai <tz...@apache.org>
Committed: Sun May 7 17:33:04 2017 +0800
----------------------------------------------------------------------
.../kinesis/FlinkKinesisConsumer.java | 150 ++++++++---------
.../kinesis/internals/KinesisDataFetcher.java | 52 +-----
.../FlinkKinesisConsumerMigrationTest.java | 5 +-
.../kinesis/FlinkKinesisConsumerTest.java | 159 +++++++++++--------
.../internals/KinesisDataFetcherTest.java | 65 ++++++--
.../testutils/TestableKinesisDataFetcher.java | 14 ++
6 files changed, 233 insertions(+), 212 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
index dfcd552..4982f7f 100644
--- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
+++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
@@ -25,13 +25,14 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
-import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
+import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition;
import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
@@ -67,9 +68,9 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
* @param <T> the type of data emitted
*/
public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> implements
- ResultTypeQueryable<T>,
- CheckpointedFunction,
- CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {
+ ResultTypeQueryable<T>,
+ CheckpointedFunction,
+ CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {
private static final long serialVersionUID = 4724006128720664870L;
@@ -86,7 +87,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
* shard list retrieval behaviours, etc */
private final Properties configProps;
- /** User supplied deseriliazation schema to convert Kinesis byte messages to Flink objects */
+ /** User supplied deserialization schema to convert Kinesis byte messages to Flink objects */
private final KinesisDeserializationSchema<T> deserializer;
// ------------------------------------------------------------------------
@@ -96,9 +97,6 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
/** Per-task fetcher for Kinesis data records, where each fetcher pulls data from one or more Kinesis shards */
private transient KinesisDataFetcher<T> fetcher;
- /** The sequence numbers in the last state snapshot of this subtask */
- private transient HashMap<KinesisStreamShard, SequenceNumber> lastStateSnapshot;
-
/** The sequence numbers to restore to upon restore from failure */
private transient HashMap<KinesisStreamShard, SequenceNumber> sequenceNumsToRestore;
@@ -108,7 +106,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
// State for Checkpoint
// ------------------------------------------------------------------------
- /** The name is the key for sequence numbers state, and cannot be changed. */
+ /** State name to access shard sequence number states; cannot be changed */
private static final String sequenceNumsStateStoreName = "Kinesis-Stream-Shard-State";
private transient ListState<Tuple2<KinesisStreamShard, SequenceNumber>> sequenceNumsStateForCheckpoint;
@@ -191,57 +189,33 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
// ------------------------------------------------------------------------
@Override
- public void open(Configuration parameters) throws Exception {
- super.open(parameters);
-
- // restore to the last known sequence numbers from the latest complete snapshot
- if (sequenceNumsToRestore != null) {
- if (LOG.isInfoEnabled()) {
- LOG.info("Subtask {} is restoring sequence numbers {} from previous checkpointed state",
- getRuntimeContext().getIndexOfThisSubtask(), sequenceNumsToRestore.toString());
- }
-
- // initialize sequence numbers with restored state
- lastStateSnapshot = sequenceNumsToRestore;
- } else {
- // start fresh with empty sequence numbers if there are no snapshots to restore from.
- lastStateSnapshot = new HashMap<>();
- }
- }
-
- @Override
public void run(SourceContext<T> sourceContext) throws Exception {
// all subtasks will run a fetcher, regardless of whether or not the subtask will initially have
// shards to subscribe to; fetchers will continuously poll for changes in the shard list, so all subtasks
// can potentially have new shards to subscribe to later on
- fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer);
-
- boolean isRestoringFromFailure = (sequenceNumsToRestore != null);
- fetcher.setIsRestoringFromFailure(isRestoringFromFailure);
-
- // if we are restoring from a checkpoint, we iterate over the restored
- // state and accordingly seed the fetcher with subscribed shards states
- if (isRestoringFromFailure) {
- // Since there may have a situation that some subtasks did not finish discovering before rescale,
- // and KinesisDataFetcher will always discover the shard from the largest shard id. To prevent from
- // missing some shards which didn't be discovered and whose id is not the largest one, we force the
- // consumer to discover once from the smallest id and make sure each shard have its initial sequence
- // number from restored state or SENTINEL_EARLIEST_SEQUENCE_NUM.
- List<KinesisStreamShard> newShardsCreatedWhileNotRunning = fetcher.discoverNewShardsToSubscribe();
- for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) {
- SequenceNumber startingStateForNewShard;
-
- if (lastStateSnapshot.containsKey(shard)) {
- startingStateForNewShard = lastStateSnapshot.get(shard);
+ KinesisDataFetcher<T> fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer);
+
+ // initial discovery
+ List<KinesisStreamShard> allShards = fetcher.discoverNewShardsToSubscribe();
+
+ for (KinesisStreamShard shard : allShards) {
+ if (sequenceNumsToRestore != null) {
+ if (sequenceNumsToRestore.containsKey(shard)) {
+ // if the shard was already seen and is contained in the state,
+ // just use the sequence number stored in the state
+ fetcher.registerNewSubscribedShardState(
+ new KinesisStreamShardState(shard, sequenceNumsToRestore.get(shard)));
if (LOG.isInfoEnabled()) {
LOG.info("Subtask {} is seeding the fetcher with restored shard {}," +
" starting state set to the restored sequence number {}",
- getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingStateForNewShard);
+ getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), sequenceNumsToRestore.get(shard));
}
} else {
- startingStateForNewShard = SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get();
+ // the shard wasn't discovered in the previous run, therefore should be consumed from the beginning
+ fetcher.registerNewSubscribedShardState(
+ new KinesisStreamShardState(shard, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get()));
if (LOG.isInfoEnabled()) {
LOG.info("Subtask {} is seeding the fetcher with new discovered shard {}," +
@@ -249,9 +223,20 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
getRuntimeContext().getIndexOfThisSubtask(), shard.toString());
}
}
+ } else {
+ // we're starting fresh; use the configured start position as initial state
+ SentinelSequenceNumber startingSeqNum =
+ InitialPosition.valueOf(configProps.getProperty(
+ ConsumerConfigConstants.STREAM_INITIAL_POSITION,
+ ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION)).toSentinelSequenceNumber();
fetcher.registerNewSubscribedShardState(
- new KinesisStreamShardState(shard, startingStateForNewShard));
+ new KinesisStreamShardState(shard, startingSeqNum.get()));
+
+ if (LOG.isInfoEnabled()) {
+ LOG.info("Subtask {} will be seeded with initial shard {}, starting state set as sequence number {}",
+ getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingSeqNum.get());
+ }
}
}
@@ -260,6 +245,10 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
return;
}
+ // expose the fetcher from this point, so that state
+ // snapshots can be taken from the fetcher's state holders
+ this.fetcher = fetcher;
+
// start the fetcher loop. The fetcher will stop running only when cancel() or
// close() is called, or an error is thrown by threads created by the fetcher
fetcher.runFetcher();
@@ -306,13 +295,12 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
- TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> tuple = new TupleTypeInfo<>(
+ TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> shardsStateTypeInfo = new TupleTypeInfo<>(
TypeInformation.of(KinesisStreamShard.class),
- TypeInformation.of(SequenceNumber.class)
- );
+ TypeInformation.of(SequenceNumber.class));
sequenceNumsStateForCheckpoint = context.getOperatorStateStore().getUnionListState(
- new ListStateDescriptor<>(sequenceNumsStateStoreName, tuple));
+ new ListStateDescriptor<>(sequenceNumsStateStoreName, shardsStateTypeInfo));
if (context.isRestored()) {
if (sequenceNumsToRestore == null) {
@@ -323,8 +311,6 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
LOG.info("Setting restore state in the FlinkKinesisConsumer. Using the following offsets: {}",
sequenceNumsToRestore);
- } else if (sequenceNumsToRestore.isEmpty()) {
- sequenceNumsToRestore = null;
}
} else {
LOG.info("No restore state for FlinkKinesisConsumer.");
@@ -333,11 +319,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
- if (lastStateSnapshot == null) {
- LOG.debug("snapshotState() requested on not yet opened source; returning null.");
- } else if (fetcher == null) {
- LOG.debug("snapshotState() requested on not yet running source; returning null.");
- } else if (!running) {
+ if (!running) {
LOG.debug("snapshotState() called on closed source; returning null.");
} else {
if (LOG.isDebugEnabled()) {
@@ -345,15 +327,33 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
}
sequenceNumsStateForCheckpoint.clear();
- lastStateSnapshot = fetcher.snapshotState();
- if (LOG.isDebugEnabled()) {
- LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
- lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp());
- }
+ if (fetcher == null) {
+ if (sequenceNumsToRestore != null) {
+ for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : sequenceNumsToRestore.entrySet()) {
+ // sequenceNumsToRestore is the restored global union state;
+ // should only snapshot shards that actually belong to us
+
+ if (KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(
+ entry.getKey(),
+ getRuntimeContext().getNumberOfParallelSubtasks(),
+ getRuntimeContext().getIndexOfThisSubtask())) {
+
+ sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+ }
+ }
+ }
+ } else {
+ HashMap<KinesisStreamShard, SequenceNumber> lastStateSnapshot = fetcher.snapshotState();
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
+ lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp());
+ }
- for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) {
- sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+ for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) {
+ sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+ }
}
}
}
@@ -366,12 +366,14 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
sequenceNumsToRestore = restoredState.isEmpty() ? null : restoredState;
}
- /** This method is created for tests that can mock the KinesisDataFetcher in the consumer. */
- protected KinesisDataFetcher<T> createFetcher(List<String> streams,
- SourceFunction.SourceContext<T> sourceContext,
- RuntimeContext runtimeContext,
- Properties configProps,
- KinesisDeserializationSchema<T> deserializationSchema) {
+ /** This method is exposed for tests that need to mock the KinesisDataFetcher in the consumer. */
+ protected KinesisDataFetcher<T> createFetcher(
+ List<String> streams,
+ SourceFunction.SourceContext<T> sourceContext,
+ RuntimeContext runtimeContext,
+ Properties configProps,
+ KinesisDeserializationSchema<T> deserializationSchema) {
+
return new KinesisDataFetcher<>(streams, sourceContext, runtimeContext, configProps, deserializationSchema);
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
index c5b4b04..99305cb 100644
--- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
+++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
@@ -19,9 +19,7 @@ package org.apache.flink.streaming.connectors.kinesis.internals;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
-import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer;
import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
-import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
@@ -99,12 +97,6 @@ public class KinesisDataFetcher<T> {
private final int indexOfThisConsumerSubtask;
- /**
- * This flag should be set by {@link FlinkKinesisConsumer} using
- * {@link KinesisDataFetcher#setIsRestoringFromFailure(boolean)}
- */
- private boolean isRestoredFromFailure;
-
// ------------------------------------------------------------------------
// Executor services to run created threads
// ------------------------------------------------------------------------
@@ -235,41 +227,7 @@ public class KinesisDataFetcher<T> {
// Procedures before starting the infinite while loop:
// ------------------------------------------------------------------------
- // 1. query for any new shards that may have been created while the Kinesis consumer was not running,
- // and register them to the subscribedShardState list.
- if (LOG.isDebugEnabled()) {
- String logFormat = (!isRestoredFromFailure)
- ? "Subtask {} is trying to discover initial shards ..."
- : "Subtask {} is trying to discover any new shards that were created while the consumer wasn't " +
- "running due to failure ...";
-
- LOG.debug(logFormat, indexOfThisConsumerSubtask);
- }
- List<KinesisStreamShard> newShardsCreatedWhileNotRunning = discoverNewShardsToSubscribe();
- for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) {
- // the starting state for new shards created while the consumer wasn't running depends on whether or not
- // we are starting fresh (not restoring from a checkpoint); when we are starting fresh, this simply means
- // all existing shards of streams we are subscribing to are new shards; when we are restoring from checkpoint,
- // any new shards due to Kinesis resharding from the time of the checkpoint will be considered new shards.
- InitialPosition initialPosition = InitialPosition.valueOf(configProps.getProperty(
- ConsumerConfigConstants.STREAM_INITIAL_POSITION, ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION));
-
- SentinelSequenceNumber startingStateForNewShard = (isRestoredFromFailure)
- ? SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
- : initialPosition.toSentinelSequenceNumber();
-
- if (LOG.isInfoEnabled()) {
- String logFormat = (!isRestoredFromFailure)
- ? "Subtask {} will be seeded with initial shard {}, starting state set as sequence number {}"
- : "Subtask {} will be seeded with new shard {} that was created while the consumer wasn't " +
- "running due to failure, starting state set as sequence number {}";
-
- LOG.info(logFormat, indexOfThisConsumerSubtask, shard.toString(), startingStateForNewShard.get());
- }
- registerNewSubscribedShardState(new KinesisStreamShardState(shard, startingStateForNewShard.get()));
- }
-
- // 2. check that there is at least one shard in the subscribed streams to consume from (can be done by
+ // 1. check that there is at least one shard in the subscribed streams to consume from (can be done by
// checking if at least one value in subscribedStreamsToLastDiscoveredShardIds is not null)
boolean hasShards = false;
StringBuilder streamsWithNoShardsFound = new StringBuilder();
@@ -290,7 +248,7 @@ public class KinesisDataFetcher<T> {
throw new RuntimeException("No shards can be found for all subscribed streams: " + streams);
}
- // 3. start consuming any shard state we already have in the subscribedShardState up to this point; the
+ // 2. start consuming any shard state we already have in the subscribedShardState up to this point; the
// subscribedShardState may already be seeded with values due to step 1., or explicitly added by the
// consumer using a restored state checkpoint
for (int seededStateIndex = 0; seededStateIndex < subscribedShardsState.size(); seededStateIndex++) {
@@ -489,10 +447,6 @@ public class KinesisDataFetcher<T> {
// Functions to get / set information about the consumer
// ------------------------------------------------------------------------
- public void setIsRestoringFromFailure(boolean bool) {
- this.isRestoredFromFailure = bool;
- }
-
protected Properties getConsumerConfiguration() {
return configProps;
}
@@ -595,7 +549,7 @@ public class KinesisDataFetcher<T> {
* @param totalNumberOfConsumerSubtasks total number of consumer subtasks
* @param indexOfThisConsumerSubtask index of this consumer subtask
*/
- private static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard,
+ public static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard,
int totalNumberOfConsumerSubtasks,
int indexOfThisConsumerSubtask) {
return (Math.abs(shard.hashCode() % totalNumberOfConsumerSubtasks)) == indexOfThisConsumerSubtask;
http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
index 2f46e09..7629f9d 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
@@ -42,10 +42,7 @@ import static org.mockito.Mockito.mock;
/**
* Tests for checking whether {@link FlinkKinesisConsumer} can restore from snapshots that were
- * done using the Flink 1.1 {@link FlinkKinesisConsumer}.
- *
- * <p>For regenerating the binary snapshot file you have to run the commented out portion
- * of each test on a checkout of the Flink 1.1 branch.
+ * done using the Flink 1.1 {@code FlinkKinesisConsumer}.
*/
public class FlinkKinesisConsumerMigrationTest {
http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
index bf8e44f..4b178c7 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
@@ -40,6 +40,7 @@ import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGen
import org.apache.flink.streaming.connectors.kinesis.testutils.TestableFlinkKinesisConsumer;
import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil;
import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
+import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -57,10 +58,8 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.UUID;
-import java.io.Serializable;
import static org.junit.Assert.fail;
-import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.mockito.Mockito.mock;
@@ -530,7 +529,7 @@ public class FlinkKinesisConsumerTest {
// ----------------------------------------------------------------------
@Test
- public void testSnapshotStateShouldNotClearListStateIfSourceNotOpened() throws Exception {
+ public void testUseRestoredStateForSnapshotIfFetcherNotInitialized() throws Exception {
Properties config = new Properties();
config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
@@ -538,57 +537,63 @@ public class FlinkKinesisConsumerTest {
OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
- TestingListState<Serializable> listState = new TestingListState<>();
-
- FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
-
- when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
-
- StateInitializationContext initializationContext = mock(StateInitializationContext.class);
-
- when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
- when(initializationContext.isRestored()).thenReturn(false);
-
- consumer.initializeState(initializationContext);
-
- consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp
-
- assertFalse(listState.isClearCalled());
- }
-
- @Test
- public void testSnapshotStateShouldNotClearListStateIfSourceNotRun() throws Exception {
- Properties config = new Properties();
- config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
- config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
- config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
-
- OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+ List<Tuple2<KinesisStreamShard, SequenceNumber>> globalUnionState = new ArrayList<>(4);
+ globalUnionState.add(Tuple2.of(
+ new KinesisStreamShard("fakeStream",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+ new SequenceNumber("1")));
+ globalUnionState.add(Tuple2.of(
+ new KinesisStreamShard("fakeStream",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+ new SequenceNumber("1")));
+ globalUnionState.add(Tuple2.of(
+ new KinesisStreamShard("fakeStream",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+ new SequenceNumber("1")));
+ globalUnionState.add(Tuple2.of(
+ new KinesisStreamShard("fakeStream",
+ new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(3))),
+ new SequenceNumber("1")));
- TestingListState<Serializable> listState = new TestingListState<>();
+ TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
+ for (Tuple2<KinesisStreamShard, SequenceNumber> state : globalUnionState) {
+ listState.add(state);
+ }
FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
+ RuntimeContext context = mock(RuntimeContext.class);
+ when(context.getIndexOfThisSubtask()).thenReturn(0);
+ when(context.getNumberOfParallelSubtasks()).thenReturn(2);
+ consumer.setRuntimeContext(context);
when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
StateInitializationContext initializationContext = mock(StateInitializationContext.class);
when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
- when(initializationContext.isRestored()).thenReturn(false);
+ when(initializationContext.isRestored()).thenReturn(true);
consumer.initializeState(initializationContext);
- consumer.open(new Configuration()); // only opened, not run
+ // only opened, not run
+ consumer.open(new Configuration());
+
+ // arbitrary checkpoint id and timestamp
+ consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123));
- consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp
+ Assert.assertTrue(listState.isClearCalled());
- assertFalse(listState.isClearCalled());
+ // the checkpointed list state should contain only the shards that it should subscribe to
+ Assert.assertEquals(globalUnionState.size() / 2, listState.getList().size());
+ Assert.assertTrue(listState.getList().contains(globalUnionState.get(0)));
+ Assert.assertTrue(listState.getList().contains(globalUnionState.get(2)));
}
@Test
public void testListStateChangedAfterSnapshotState() throws Exception {
+
// ----------------------------------------------------------------------
- // setting config, initial state and state after snapshot
+ // setup config, initial state and expected state snapshot
// ----------------------------------------------------------------------
Properties config = new Properties();
config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
@@ -601,16 +606,16 @@ public class FlinkKinesisConsumerTest {
new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
new SequenceNumber("1")));
- ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> snapShotState = new ArrayList<>(3);
- snapShotState.add(Tuple2.of(
+ ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> expectedStateSnapshot = new ArrayList<>(3);
+ expectedStateSnapshot.add(Tuple2.of(
new KinesisStreamShard("fakeStream1",
new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
new SequenceNumber("12")));
- snapShotState.add(Tuple2.of(
+ expectedStateSnapshot.add(Tuple2.of(
new KinesisStreamShard("fakeStream1",
new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
new SequenceNumber("11")));
- snapShotState.add(Tuple2.of(
+ expectedStateSnapshot.add(Tuple2.of(
new KinesisStreamShard("fakeStream1",
new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
new SequenceNumber("31")));
@@ -618,8 +623,9 @@ public class FlinkKinesisConsumerTest {
// ----------------------------------------------------------------------
// mock operator state backend and initial state for initializeState()
// ----------------------------------------------------------------------
- TestingListState<Serializable> listState = new TestingListState<>();
- for (Serializable state: initialState) {
+
+ TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
+ for (Tuple2<KinesisStreamShard, SequenceNumber> state: initialState) {
listState.add(state);
}
@@ -633,8 +639,9 @@ public class FlinkKinesisConsumerTest {
// ----------------------------------------------------------------------
// mock a running fetcher and its state for snapshot
// ----------------------------------------------------------------------
+
HashMap<KinesisStreamShard, SequenceNumber> stateSnapshot = new HashMap<>();
- for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: snapShotState) {
+ for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: expectedStateSnapshot) {
stateSnapshot.put(tuple.f0, tuple.f1);
}
@@ -644,6 +651,7 @@ public class FlinkKinesisConsumerTest {
// ----------------------------------------------------------------------
// create a consumer and test the snapshotState()
// ----------------------------------------------------------------------
+
FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
FlinkKinesisConsumer<?> mockedConsumer = spy(consumer);
@@ -653,22 +661,22 @@ public class FlinkKinesisConsumerTest {
mockedConsumer.setRuntimeContext(context);
mockedConsumer.initializeState(initializationContext);
mockedConsumer.open(new Configuration());
- Whitebox.setInternalState(mockedConsumer, "fetcher", mockedFetcher); // mock as consumer is running.
+ Whitebox.setInternalState(mockedConsumer, "fetcher", mockedFetcher); // mock consumer as running.
mockedConsumer.snapshotState(mock(FunctionSnapshotContext.class));
assertEquals(true, listState.clearCalled);
assertEquals(3, listState.getList().size());
- for (Serializable state: initialState) {
- for (Serializable currentState: listState.getList()) {
+ for (Tuple2<KinesisStreamShard, SequenceNumber> state: initialState) {
+ for (Tuple2<KinesisStreamShard, SequenceNumber> currentState: listState.getList()) {
assertNotEquals(state, currentState);
}
}
- for (Serializable state: snapShotState) {
+ for (Tuple2<KinesisStreamShard, SequenceNumber> state: expectedStateSnapshot) {
boolean hasOneIsSame = false;
- for (Serializable currentState: listState.getList()) {
+ for (Tuple2<KinesisStreamShard, SequenceNumber> currentState: listState.getList()) {
hasOneIsSame = hasOneIsSame || state.equals(currentState);
}
assertEquals(true, hasOneIsSame);
@@ -693,8 +701,6 @@ public class FlinkKinesisConsumerTest {
"fakeStream", new Properties(), 10, 2);
consumer.open(new Configuration());
consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
-
- Mockito.verify(mockedFetcher).setIsRestoringFromFailure(false);
}
@Test
@@ -718,7 +724,6 @@ public class FlinkKinesisConsumerTest {
consumer.open(new Configuration());
consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
- Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
@@ -728,15 +733,18 @@ public class FlinkKinesisConsumerTest {
@Test
@SuppressWarnings("unchecked")
public void testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws Exception {
+
// ----------------------------------------------------------------------
- // setting initial state
+ // setup initial state
// ----------------------------------------------------------------------
+
HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
// ----------------------------------------------------------------------
// mock operator state backend and initial state for initializeState()
// ----------------------------------------------------------------------
- TestingListState<Serializable> listState = new TestingListState<>();
+
+ TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
listState.add(Tuple2.of(state.getKey(), state.getValue()));
}
@@ -751,6 +759,7 @@ public class FlinkKinesisConsumerTest {
// ----------------------------------------------------------------------
// mock fetcher
// ----------------------------------------------------------------------
+
KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
List<KinesisStreamShard> shards = new ArrayList<>();
shards.addAll(fakeRestoredState.keySet());
@@ -762,15 +771,15 @@ public class FlinkKinesisConsumerTest {
PowerMockito.doNothing().when(KinesisConfigUtil.class);
// ----------------------------------------------------------------------
- // start to test seed initial state to fetcher
+ // start to test fetcher's initial state seeding
// ----------------------------------------------------------------------
+
TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
"fakeStream", new Properties(), 10, 2);
consumer.initializeState(initializationContext);
consumer.open(new Configuration());
consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
- Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
@@ -780,9 +789,11 @@ public class FlinkKinesisConsumerTest {
@Test
@SuppressWarnings("unchecked")
public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws Exception {
+
// ----------------------------------------------------------------------
- // setting initial state
+ // setup initial state
// ----------------------------------------------------------------------
+
HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("fakeStream1");
HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredStateForOthers = getFakeRestoredStore("fakeStream2");
@@ -790,7 +801,8 @@ public class FlinkKinesisConsumerTest {
// ----------------------------------------------------------------------
// mock operator state backend and initial state for initializeState()
// ----------------------------------------------------------------------
- TestingListState<Serializable> listState = new TestingListState<>();
+
+ TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
listState.add(Tuple2.of(state.getKey(), state.getValue()));
}
@@ -808,6 +820,7 @@ public class FlinkKinesisConsumerTest {
// ----------------------------------------------------------------------
// mock fetcher
// ----------------------------------------------------------------------
+
KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
List<KinesisStreamShard> shards = new ArrayList<>();
shards.addAll(fakeRestoredState.keySet());
@@ -819,15 +832,15 @@ public class FlinkKinesisConsumerTest {
PowerMockito.doNothing().when(KinesisConfigUtil.class);
// ----------------------------------------------------------------------
- // start to test seed initial state to fetcher
+ // start to test fetcher's initial state seeding
// ----------------------------------------------------------------------
+
TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
"fakeStream", new Properties(), 10, 2);
consumer.initializeState(initializationContext);
consumer.open(new Configuration());
consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
- Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredStateForOthers.entrySet()) {
// should never get restored state not belonging to itself
Mockito.verify(mockedFetcher, never()).registerNewSubscribedShardState(
@@ -841,42 +854,49 @@ public class FlinkKinesisConsumerTest {
}
/*
- * If the original parallelism is 2 and states is:
+ * This tests that the consumer correctly picks up shards that were not discovered on the previous run.
+ *
+ * Case under test:
+ *
+ * If the original parallelism is 2 and states are:
* Consumer subtask 1:
* stream1, shard1, SequentialNumber(xxx)
* Consumer subtask 2:
* stream1, shard2, SequentialNumber(yyy)
- * After discoverNewShardsToSubscribe() if there are two shards (shard3, shard4) been created:
+ *
+ * After discoverNewShardsToSubscribe() if there were two shards (shard3, shard4) created:
* Consumer subtask 1 (late for discoverNewShardsToSubscribe()):
* stream1, shard1, SequentialNumber(xxx)
* Consumer subtask 2:
* stream1, shard2, SequentialNumber(yyy)
* stream1, shard4, SequentialNumber(zzz)
- * If snapshotState() occur and parallelism is changed to 1:
- * Union state will be:
+ *
+ * If snapshotState() occurs and parallelism is changed to 1:
+ * Union state will be:
* stream1, shard1, SequentialNumber(xxx)
* stream1, shard2, SequentialNumber(yyy)
* stream1, shard4, SequentialNumber(zzz)
- * Fetcher should be seeded with:
+ * Fetcher should be seeded with:
* stream1, shard1, SequentialNumber(xxx)
* stream1, shard2, SequentialNumber(yyy)
* stream1, share3, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
* stream1, shard4, SequentialNumber(zzz)
- *
- * This test is to guarantee the fetcher will be seeded correctly for such situation.
*/
@Test
@SuppressWarnings("unchecked")
public void testFetcherShouldBeCorrectlySeededWithNewDiscoveredKinesisStreamShard() throws Exception {
+
// ----------------------------------------------------------------------
- // setting initial state
+ // setup initial state
// ----------------------------------------------------------------------
+
HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
// ----------------------------------------------------------------------
// mock operator state backend and initial state for initializeState()
// ----------------------------------------------------------------------
- TestingListState<Serializable> listState = new TestingListState<>();
+
+ TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
listState.add(Tuple2.of(state.getKey(), state.getValue()));
}
@@ -891,6 +911,7 @@ public class FlinkKinesisConsumerTest {
// ----------------------------------------------------------------------
// mock fetcher
// ----------------------------------------------------------------------
+
KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
List<KinesisStreamShard> shards = new ArrayList<>();
shards.addAll(fakeRestoredState.keySet());
@@ -904,8 +925,9 @@ public class FlinkKinesisConsumerTest {
PowerMockito.doNothing().when(KinesisConfigUtil.class);
// ----------------------------------------------------------------------
- // start to test seed initial state to fetcher
+ // start to test fetcher's initial state seeding
// ----------------------------------------------------------------------
+
TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
"fakeStream", new Properties(), 10, 2);
consumer.initializeState(initializationContext);
@@ -915,7 +937,6 @@ public class FlinkKinesisConsumerTest {
fakeRestoredState.put(new KinesisStreamShard("fakeStream2",
new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get());
- Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
index e79f9b1..800fde5 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
@@ -18,9 +18,14 @@
package org.apache.flink.streaming.connectors.kinesis.internals;
import com.amazonaws.services.kinesis.model.Shard;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer;
+import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
+import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
import org.apache.flink.streaming.connectors.kinesis.testutils.FakeKinesisBehavioursFactory;
import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
import org.apache.flink.streaming.connectors.kinesis.testutils.TestableKinesisDataFetcher;
@@ -42,6 +47,8 @@ import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
@RunWith(PowerMockRunner.class)
@PrepareForTest(TestableKinesisDataFetcher.class)
@@ -67,8 +74,6 @@ public class KinesisDataFetcherTest {
subscribedStreamsToLastSeenShardIdsUnderTest,
FakeKinesisBehavioursFactory.noShardsFoundForRequestedStreamsBehaviour());
- fetcher.setIsRestoringFromFailure(false); // not restoring
-
fetcher.runFetcher(); // this should throw RuntimeException
}
@@ -100,23 +105,30 @@ public class KinesisDataFetcherTest {
subscribedStreamsToLastSeenShardIdsUnderTest,
FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount));
- fetcher.setIsRestoringFromFailure(false);
+ Properties testConfig = new Properties();
+ testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1");
+ testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC");
+ testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
+ testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+
+ final DummyFlinkKafkaConsumer<String> consumer = new DummyFlinkKafkaConsumer<>(testConfig, fetcher);
PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
- Thread runFetcherThread = new Thread(new Runnable() {
+ Thread consumerThread = new Thread(new Runnable() {
@Override
public void run() {
try {
- fetcher.runFetcher();
+ consumer.run(mock(SourceFunction.SourceContext.class));
} catch (Exception e) {
//
}
}
});
- runFetcherThread.start();
- Thread.sleep(1000); // sleep a while before closing
- fetcher.shutdownFetcher();
+ consumerThread.start();
+ fetcher.waitUntilRun();
+ consumer.cancel();
+ consumerThread.join();
// assert that the streams tracked in the state are identical to the subscribed streams
Set<String> streamsInState = subscribedStreamsToLastSeenShardIdsUnderTest.keySet();
@@ -192,8 +204,6 @@ public class KinesisDataFetcherTest {
new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
}
- fetcher.setIsRestoringFromFailure(true);
-
PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
Thread runFetcherThread = new Thread(new Runnable() {
@Override
@@ -284,8 +294,6 @@ public class KinesisDataFetcherTest {
new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
}
- fetcher.setIsRestoringFromFailure(true);
-
PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
Thread runFetcherThread = new Thread(new Runnable() {
@Override
@@ -380,8 +388,6 @@ public class KinesisDataFetcherTest {
new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
}
- fetcher.setIsRestoringFromFailure(true);
-
PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
Thread runFetcherThread = new Thread(new Runnable() {
@Override
@@ -477,8 +483,6 @@ public class KinesisDataFetcherTest {
new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
}
- fetcher.setIsRestoringFromFailure(true);
-
PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
Thread runFetcherThread = new Thread(new Runnable() {
@Override
@@ -507,4 +511,33 @@ public class KinesisDataFetcherTest {
assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream3") == null);
assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream4") == null);
}
+
+ private static class DummyFlinkKafkaConsumer<T> extends FlinkKinesisConsumer<T> {
+ private static final long serialVersionUID = 1L;
+
+ private KinesisDataFetcher<T> fetcher;
+
+ @SuppressWarnings("unchecked")
+ DummyFlinkKafkaConsumer(Properties properties, KinesisDataFetcher<T> fetcher) {
+ super("test", mock(KinesisDeserializationSchema.class), properties);
+ this.fetcher = fetcher;
+ }
+
+ @Override
+ protected KinesisDataFetcher<T> createFetcher(List<String> streams,
+ SourceFunction.SourceContext<T> sourceContext,
+ RuntimeContext runtimeContext,
+ Properties configProps,
+ KinesisDeserializationSchema<T> deserializationSchema) {
+ return fetcher;
+ }
+
+ @Override
+ public RuntimeContext getRuntimeContext() {
+ RuntimeContext context = mock(RuntimeContext.class);
+ when(context.getIndexOfThisSubtask()).thenReturn(0);
+ when(context.getNumberOfParallelSubtasks()).thenReturn(1);
+ return context;
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
index 57886fe..bb644ba 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
@@ -18,6 +18,7 @@
package org.apache.flink.streaming.connectors.kinesis.testutils;
import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
@@ -42,6 +43,8 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {
private long numElementsCollected;
+ private OneShotLatch runWaiter;
+
public TestableKinesisDataFetcher(List<String> fakeStreams,
Properties fakeConfiguration,
int fakeTotalCountOfSubtasks,
@@ -62,6 +65,7 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {
fakeKinesis);
this.numElementsCollected = 0;
+ this.runWaiter = new OneShotLatch();
}
public long getNumOfElementsCollected() {
@@ -81,6 +85,16 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {
}
}
+ @Override
+ public void runFetcher() throws Exception {
+ runWaiter.trigger();
+ super.runFetcher();
+ }
+
+ public void waitUntilRun() throws Exception {
+ runWaiter.await();
+ }
+
@SuppressWarnings("unchecked")
private static SourceFunction.SourceContext<String> getMockedSourceContext() {
return Mockito.mock(SourceFunction.SourceContext.class);