You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tz...@apache.org on 2017/07/28 13:53:32 UTC

[2/7] flink git commit: [FLINK-7143] [kafka] Add test for Kafka Consumer rescaling

[FLINK-7143] [kafka] Add test for Kafka Consumer rescaling

This verifies that the consumer always correctly knows whether it is
restored or not and is not affected by changes in the partitions as
reported by Kafka.

Previously, operator state reshuffling could lead to partitions being
subscribed to multiple times.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e111d773
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e111d773
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e111d773

Branch: refs/heads/master
Commit: e111d7730ec6032dc14579bd274e7822f7176e39
Parents: 888fabe
Author: Aljoscha Krettek <al...@gmail.com>
Authored: Tue Jul 18 11:57:46 2017 +0200
Committer: Tzu-Li (Gordon) Tai <tz...@apache.org>
Committed: Fri Jul 28 21:52:29 2017 +0800

----------------------------------------------------------------------
 .../kafka/FlinkKafkaConsumerBaseTest.java       | 169 ++++++++++++++++++-
 .../AbstractPartitionDiscovererTest.java        | 115 ++-----------
 .../testutils/TestPartitionDiscoverer.java      | 125 ++++++++++++++
 3 files changed, 307 insertions(+), 102 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/e111d773/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
index e0508ce..fef2820 100644
--- a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
+++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
@@ -25,15 +25,22 @@ import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
+import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.connectors.kafka.config.OffsetCommitMode;
 import org.apache.flink.streaming.connectors.kafka.internals.AbstractFetcher;
 import org.apache.flink.streaming.connectors.kafka.internals.AbstractPartitionDiscoverer;
 import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
 import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicsDescriptor;
+import org.apache.flink.streaming.connectors.kafka.testutils.TestPartitionDiscoverer;
+import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.SerializedValue;
 
 import org.apache.commons.collections.map.LinkedMap;
@@ -47,14 +54,21 @@ import java.io.Serializable;
 import java.lang.reflect.Field;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import static org.hamcrest.Matchers.everyItem;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.collection.IsIn.isIn;
+import static org.hamcrest.collection.IsMapContaining.hasKey;
+import static org.hamcrest.core.IsNot.not;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Matchers.anyMap;
@@ -181,6 +195,10 @@ public class FlinkKafkaConsumerBaseTest {
 	@Test
 	public void checkRestoredNullCheckpointWhenFetcherNotReady() throws Exception {
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new LinkedMap(), true);
+		StreamingRuntimeContext runtimeContext = mock(StreamingRuntimeContext.class);
+		when(runtimeContext.getIndexOfThisSubtask()).thenReturn(0);
+		when(runtimeContext.getNumberOfParallelSubtasks()).thenReturn(1);
+		consumer.setRuntimeContext(runtimeContext);
 
 		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
 		TestingListState<Serializable> listState = new TestingListState<>();
@@ -299,6 +317,8 @@ public class FlinkKafkaConsumerBaseTest {
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(fetcher, pendingOffsetsToCommit, true);
 		StreamingRuntimeContext mockRuntimeContext = mock(StreamingRuntimeContext.class);
 		when(mockRuntimeContext.isCheckpointingEnabled()).thenReturn(true); // enable checkpointing
+		when(mockRuntimeContext.getIndexOfThisSubtask()).thenReturn(0);
+		when(mockRuntimeContext.getNumberOfParallelSubtasks()).thenReturn(1);
 		consumer.setRuntimeContext(mockRuntimeContext);
 
 		assertEquals(0, pendingOffsetsToCommit.size());
@@ -426,6 +446,8 @@ public class FlinkKafkaConsumerBaseTest {
 		FlinkKafkaConsumerBase<String> consumer = getConsumer(fetcher, pendingOffsetsToCommit, true);
 		StreamingRuntimeContext mockRuntimeContext = mock(StreamingRuntimeContext.class);
 		when(mockRuntimeContext.isCheckpointingEnabled()).thenReturn(true); // enable checkpointing
+		when(mockRuntimeContext.getIndexOfThisSubtask()).thenReturn(0);
+		when(mockRuntimeContext.getNumberOfParallelSubtasks()).thenReturn(1);
 		consumer.setRuntimeContext(mockRuntimeContext);
 
 		consumer.setCommitOffsetsOnCheckpoints(false); // disable offset committing
@@ -523,6 +545,120 @@ public class FlinkKafkaConsumerBaseTest {
 		verify(fetcher, never()).commitInternalOffsetsToKafka(anyMap()); // not offsets should be committed
 	}
 
+	@Test
+	public void testScaleUp() throws Exception {
+		testRescaling(/* initialParallelism */ 5, /* numPartitions */ 2, /* restoredParallelism */ 15, /* restoredNumPartitions */ 1000);
+	}
+
+	@Test
+	public void testScaleDown() throws Exception {
+		testRescaling(/* initialParallelism */ 5, /* numPartitions */ 10, /* restoredParallelism */ 2, /* restoredNumPartitions */ 100);
+	}
+
+	/**
+	 * Tests whether the Kafka consumer behaves correctly when scaling the parallelism up/down,
+	 * which means that operator state is being reshuffled.
+	 *
+	 * <p>This also verifies that a restoring source is always impervious to changes in the list
+	 * of partitions fetched from Kafka.
+	 */
+	@SuppressWarnings("unchecked")
+	void testRescaling(
+		final int initialParallelism,
+		final int numPartitions,
+		final int restoredParallelism,
+		final int restoredNumPartitions) throws Exception {
+
+		Preconditions.checkArgument(
+			restoredNumPartitions >= numPartitions,
+			"invalid test case for Kafka repartitioning; Kafka only allows increasing partitions.");
+
+		List<KafkaTopicPartition> mockFetchedPartitionsOnStartup = new ArrayList<>();
+		for (int i = 0; i < numPartitions; i++) {
+			mockFetchedPartitionsOnStartup.add(new KafkaTopicPartition("test-topic", i));
+		}
+
+		DummyFlinkKafkaConsumer<String>[] consumers =
+			new DummyFlinkKafkaConsumer[initialParallelism];
+
+		AbstractStreamOperatorTestHarness<String>[] testHarnesses =
+			new AbstractStreamOperatorTestHarness[initialParallelism];
+
+		for (int i = 0; i < initialParallelism; i++) {
+			consumers[i] = new DummyFlinkKafkaConsumer<>(
+				Collections.singletonList("test-topic"), mockFetchedPartitionsOnStartup);
+			testHarnesses[i] = createTestHarness(consumers[i], initialParallelism, i);
+
+			// initializeState() is always called, null signals that we didn't restore
+			testHarnesses[i].initializeState(null);
+			testHarnesses[i].open();
+		}
+
+		Map<KafkaTopicPartition, Long> globalSubscribedPartitions = new HashMap<>();
+
+		for (int i = 0; i < initialParallelism; i++) {
+			Map<KafkaTopicPartition, Long> subscribedPartitions =
+				consumers[i].getSubscribedPartitionsToStartOffsets();
+
+			// make sure that no one else is subscribed to these partitions
+			for (KafkaTopicPartition partition : subscribedPartitions.keySet()) {
+				assertThat(globalSubscribedPartitions, not(hasKey(partition)));
+			}
+			globalSubscribedPartitions.putAll(subscribedPartitions);
+		}
+		// every fetched partition must be subscribed by exactly one subtask
+		assertThat(globalSubscribedPartitions.values(), hasSize(numPartitions));
+		assertThat(mockFetchedPartitionsOnStartup, everyItem(isIn(globalSubscribedPartitions.keySet())));
+
+		OperatorStateHandles[] state = new OperatorStateHandles[initialParallelism];
+
+		for (int i = 0; i < initialParallelism; i++) {
+			state[i] = testHarnesses[i].snapshot(0, 0);
+		}
+
+		OperatorStateHandles mergedState = AbstractStreamOperatorTestHarness.repackageState(state);
+
+		// -----------------------------------------------------------------------------------------
+		// restore with the new parallelism while Kafka now reports restoredNumPartitions partitions
+
+		List<KafkaTopicPartition> mockFetchedPartitionsAfterRestore = new ArrayList<>();
+		for (int i = 0; i < restoredNumPartitions; i++) {
+			mockFetchedPartitionsAfterRestore.add(new KafkaTopicPartition("test-topic", i));
+		}
+
+		DummyFlinkKafkaConsumer<String>[] restoredConsumers =
+			new DummyFlinkKafkaConsumer[restoredParallelism];
+
+		AbstractStreamOperatorTestHarness<String>[] restoredTestHarnesses =
+			new AbstractStreamOperatorTestHarness[restoredParallelism];
+
+		for (int i = 0; i < restoredParallelism; i++) {
+			restoredConsumers[i] = new DummyFlinkKafkaConsumer<>(
+				Collections.singletonList("test-topic"), mockFetchedPartitionsAfterRestore);
+			restoredTestHarnesses[i] = createTestHarness(restoredConsumers[i], restoredParallelism, i);
+
+			// every new subtask is initialized from the merged state of all initial subtasks
+			restoredTestHarnesses[i].initializeState(mergedState);
+			restoredTestHarnesses[i].open();
+		}
+
+		Map<KafkaTopicPartition, Long> restoredGlobalSubscribedPartitions = new HashMap<>();
+
+		for (int i = 0; i < restoredParallelism; i++) {
+			Map<KafkaTopicPartition, Long> subscribedPartitions =
+				restoredConsumers[i].getSubscribedPartitionsToStartOffsets();
+
+			// make sure that no one else is subscribed to these partitions
+			for (KafkaTopicPartition partition : subscribedPartitions.keySet()) {
+				assertThat(restoredGlobalSubscribedPartitions, not(hasKey(partition)));
+			}
+			restoredGlobalSubscribedPartitions.putAll(subscribedPartitions);
+		}
+		// restored and newly discovered partitions must all be subscribed exactly once
+		assertThat(restoredGlobalSubscribedPartitions.values(), hasSize(restoredNumPartitions));
+		assertThat(mockFetchedPartitionsAfterRestore, everyItem(isIn(restoredGlobalSubscribedPartitions.keySet())));
+	}
+
 	// ------------------------------------------------------------------------
 
 	private static <T> FlinkKafkaConsumerBase<T> getConsumer(
@@ -547,6 +683,19 @@ public class FlinkKafkaConsumerBaseTest {
 		return consumer;
 	}
 
+	/**
+	 * Wraps {@code source} in a {@link StreamSource} and returns a harness that runs it as
+	 * subtask {@code subtaskIndex} of {@code numSubtasks}, with the event-time characteristic set.
+	 */
+	private static <T> AbstractStreamOperatorTestHarness<T> createTestHarness(
+		SourceFunction<T> source, int numSubtasks, int subtaskIndex) throws Exception {
+		final int maxParallelism = Short.MAX_VALUE / 2;
+		final AbstractStreamOperatorTestHarness<T> harness = new AbstractStreamOperatorTestHarness<>(
+			new StreamSource<>(source), maxParallelism, numSubtasks, subtaskIndex);
+		harness.setTimeCharacteristic(TimeCharacteristic.EventTime);
+		return harness;
+	}
+
 	// ------------------------------------------------------------------------
 
 	private static class DummyFlinkKafkaConsumer<T> extends FlinkKafkaConsumerBase<T> {
@@ -554,9 +703,20 @@ public class FlinkKafkaConsumerBaseTest {
 
 		boolean isAutoCommitEnabled = false;
 
-		@SuppressWarnings("unchecked")
+		private List<String> fixedMockGetAllTopicsReturnSequence;
+		private List<KafkaTopicPartition> fixedMockGetAllPartitionsForTopicsReturnSequence;
+
 		public DummyFlinkKafkaConsumer() {
+			this(Collections.singletonList("dummy-topic"), Collections.singletonList(new KafkaTopicPartition("dummy-topic", 0))); // default: topic "dummy-topic" with only partition 0
+		}
+
+		@SuppressWarnings("unchecked") // for the KeyedDeserializationSchema mock cast passed to super(...)
+		public DummyFlinkKafkaConsumer(
+				List<String> fixedMockGetAllTopicsReturnSequence, // NOTE(review): not forwarded to super(), which always receives "dummy-topic"; only used for the TestPartitionDiscoverer
+				List<KafkaTopicPartition> fixedMockGetAllPartitionsForTopicsReturnSequence) { // returned by every partition-metadata fetch of the discoverer
 			super(Arrays.asList("dummy-topic"), null, (KeyedDeserializationSchema < T >) mock(KeyedDeserializationSchema.class), 0);
+			this.fixedMockGetAllTopicsReturnSequence = Preconditions.checkNotNull(fixedMockGetAllTopicsReturnSequence);
+			this.fixedMockGetAllPartitionsForTopicsReturnSequence = Preconditions.checkNotNull(fixedMockGetAllPartitionsForTopicsReturnSequence);
 		}
 
 		@Override
@@ -576,7 +736,12 @@ public class FlinkKafkaConsumerBaseTest {
 				KafkaTopicsDescriptor topicsDescriptor,
 				int indexOfThisSubtask,
 				int numParallelSubtasks) {
-			return mock(AbstractPartitionDiscoverer.class);
+			return new TestPartitionDiscoverer(
+				topicsDescriptor,
+				indexOfThisSubtask,
+				numParallelSubtasks,
+				TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(fixedMockGetAllTopicsReturnSequence),
+				TestPartitionDiscoverer.createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(fixedMockGetAllPartitionsForTopicsReturnSequence));
 		}
 
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/e111d773/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscovererTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscovererTest.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscovererTest.java
index 4d3e542..2633b95 100644
--- a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscovererTest.java
+++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/internals/AbstractPartitionDiscovererTest.java
@@ -18,11 +18,11 @@
 
 package org.apache.flink.streaming.connectors.kafka.internals;
 
+import org.apache.flink.streaming.connectors.kafka.testutils.TestPartitionDiscoverer;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -36,9 +36,6 @@ import java.util.regex.Pattern;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.mockito.Matchers.anyInt;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
 
 /**
  * Tests that the partition assignment in the partition discoverer is
@@ -82,8 +79,8 @@ public class AbstractPartitionDiscovererTest {
 					topicsDescriptor,
 					subtaskIndex,
 					mockGetAllPartitionsForTopicsReturn.size(),
-					createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
-					createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockGetAllPartitionsForTopicsReturn));
+					TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
+					TestPartitionDiscoverer.createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockGetAllPartitionsForTopicsReturn));
 			partitionDiscoverer.open();
 
 			List<KafkaTopicPartition> initialDiscovery = partitionDiscoverer.discoverPartitions();
@@ -127,8 +124,8 @@ public class AbstractPartitionDiscovererTest {
 						topicsDescriptor,
 						subtaskIndex,
 						numConsumers,
-						createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
-						createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockGetAllPartitionsForTopicsReturn));
+						TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
+						TestPartitionDiscoverer.createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockGetAllPartitionsForTopicsReturn));
 				partitionDiscoverer.open();
 
 				List<KafkaTopicPartition> initialDiscovery = partitionDiscoverer.discoverPartitions();
@@ -179,8 +176,8 @@ public class AbstractPartitionDiscovererTest {
 						topicsDescriptor,
 						subtaskIndex,
 						numConsumers,
-						createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
-						createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockGetAllPartitionsForTopicsReturn));
+						TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
+						TestPartitionDiscoverer.createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockGetAllPartitionsForTopicsReturn));
 				partitionDiscoverer.open();
 
 				List<KafkaTopicPartition> initialDiscovery = partitionDiscoverer.discoverPartitions();
@@ -240,7 +237,7 @@ public class AbstractPartitionDiscovererTest {
 					topicsDescriptor,
 					0,
 					numConsumers,
-					createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
+					TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
 					deepClone(mockGetAllPartitionsForTopicsReturnSequence));
 			partitionDiscovererSubtask0.open();
 
@@ -248,7 +245,7 @@ public class AbstractPartitionDiscovererTest {
 					topicsDescriptor,
 					1,
 					numConsumers,
-					createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
+					TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
 					deepClone(mockGetAllPartitionsForTopicsReturnSequence));
 			partitionDiscovererSubtask1.open();
 
@@ -256,7 +253,7 @@ public class AbstractPartitionDiscovererTest {
 					topicsDescriptor,
 					2,
 					numConsumers,
-					createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
+					TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList(TEST_TOPIC)),
 					deepClone(mockGetAllPartitionsForTopicsReturnSequence));
 			partitionDiscovererSubtask2.open();
 
@@ -375,16 +372,16 @@ public class AbstractPartitionDiscovererTest {
 					topicsDescriptor,
 					subtaskIndex,
 					numSubtasks,
-					createMockGetAllTopicsSequenceFromFixedReturn(Arrays.asList("test-topic", "test-topic2")),
-					createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockGetAllPartitionsForTopicsReturn));
+					TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(Arrays.asList("test-topic", "test-topic2")),
+					TestPartitionDiscoverer.createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockGetAllPartitionsForTopicsReturn));
 			partitionDiscoverer.open();
 
 			TestPartitionDiscoverer partitionDiscovererOutOfOrder = new TestPartitionDiscoverer(
 					topicsDescriptor,
 					subtaskIndex,
 					numSubtasks,
-					createMockGetAllTopicsSequenceFromFixedReturn(Arrays.asList("test-topic", "test-topic2")),
-					createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockGetAllPartitionsForTopicsReturnOutOfOrder));
+					TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(Arrays.asList("test-topic", "test-topic2")),
+					TestPartitionDiscoverer.createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockGetAllPartitionsForTopicsReturnOutOfOrder));
 			partitionDiscovererOutOfOrder.open();
 
 			List<KafkaTopicPartition> discoveredPartitions = partitionDiscoverer.discoverPartitions();
@@ -397,88 +394,6 @@ public class AbstractPartitionDiscovererTest {
 		}
 	}
 
-	private static class TestPartitionDiscoverer extends AbstractPartitionDiscoverer {
-
-		private final KafkaTopicsDescriptor topicsDescriptor;
-
-		private final List<List<String>> mockGetAllTopicsReturnSequence;
-		private final List<List<KafkaTopicPartition>> mockGetAllPartitionsForTopicsReturnSequence;
-
-		private int getAllTopicsInvokeCount = 0;
-		private int getAllPartitionsForTopicsInvokeCount = 0;
-
-		public TestPartitionDiscoverer(
-				KafkaTopicsDescriptor topicsDescriptor,
-				int indexOfThisSubtask,
-				int numParallelSubtasks,
-				List<List<String>> mockGetAllTopicsReturnSequence,
-				List<List<KafkaTopicPartition>> mockGetAllPartitionsForTopicsReturnSequence) {
-
-			super(topicsDescriptor, indexOfThisSubtask, numParallelSubtasks);
-
-			this.topicsDescriptor = topicsDescriptor;
-			this.mockGetAllTopicsReturnSequence = mockGetAllTopicsReturnSequence;
-			this.mockGetAllPartitionsForTopicsReturnSequence = mockGetAllPartitionsForTopicsReturnSequence;
-		}
-
-		@Override
-		protected List<String> getAllTopics() {
-			assertTrue(topicsDescriptor.isTopicPattern());
-			return mockGetAllTopicsReturnSequence.get(getAllTopicsInvokeCount++);
-		}
-
-		@Override
-		protected List<KafkaTopicPartition> getAllPartitionsForTopics(List<String> topics) {
-			if (topicsDescriptor.isFixedTopics()) {
-				assertEquals(topicsDescriptor.getFixedTopics(), topics);
-			} else {
-				assertEquals(mockGetAllTopicsReturnSequence.get(getAllPartitionsForTopicsInvokeCount - 1), topics);
-			}
-			return mockGetAllPartitionsForTopicsReturnSequence.get(getAllPartitionsForTopicsInvokeCount++);
-		}
-
-		@Override
-		protected void initializeConnections() {
-			// nothing to do
-		}
-
-		@Override
-		protected void wakeupConnections() {
-			// nothing to do
-		}
-
-		@Override
-		protected void closeConnections() {
-			// nothing to do
-		}
-	}
-
-	private static List<List<String>> createMockGetAllTopicsSequenceFromFixedReturn(final List<String> fixed) {
-		@SuppressWarnings("unchecked")
-		List<List<String>> mockSequence = mock(List.class);
-		when(mockSequence.get(anyInt())).thenAnswer(new Answer<List<String>>() {
-			@Override
-			public List<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
-				return new ArrayList<>(fixed);
-			}
-		});
-
-		return mockSequence;
-	}
-
-	private static List<List<KafkaTopicPartition>> createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(final List<KafkaTopicPartition> fixed) {
-		@SuppressWarnings("unchecked")
-		List<List<KafkaTopicPartition>> mockSequence = mock(List.class);
-		when(mockSequence.get(anyInt())).thenAnswer(new Answer<List<KafkaTopicPartition>>() {
-			@Override
-			public List<KafkaTopicPartition> answer(InvocationOnMock invocationOnMock) throws Throwable {
-				return new ArrayList<>(fixed);
-			}
-		});
-
-		return mockSequence;
-	}
-
 	private boolean contains(List<KafkaTopicPartition> partitions, int partition) {
 		for (KafkaTopicPartition ktp : partitions) {
 			if (ktp.getPartition() == partition) {

http://git-wip-us.apache.org/repos/asf/flink/blob/e111d773/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/TestPartitionDiscoverer.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/TestPartitionDiscoverer.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/TestPartitionDiscoverer.java
new file mode 100644
index 0000000..1f7c031
--- /dev/null
+++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/testutils/TestPartitionDiscoverer.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka.testutils;
+
+import org.apache.flink.streaming.connectors.kafka.internals.AbstractPartitionDiscoverer;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicPartition;
+import org.apache.flink.streaming.connectors.kafka.internals.KafkaTopicsDescriptor;
+
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Utility {@link AbstractPartitionDiscoverer} for tests that allows
+ * mocking the sequence of consecutive metadata fetch calls to Kafka.
+ */
+public class TestPartitionDiscoverer extends AbstractPartitionDiscoverer {
+
+	private final KafkaTopicsDescriptor topicsDescriptor;
+	// canned results: the i-th (0-based) call of the matching metadata method returns element i
+	private final List<List<String>> mockGetAllTopicsReturnSequence;
+	private final List<List<KafkaTopicPartition>> mockGetAllPartitionsForTopicsReturnSequence;
+	// number of calls made so far to getAllTopics() and getAllPartitionsForTopics()
+	private int getAllTopicsInvokeCount = 0;
+	private int getAllPartitionsForTopicsInvokeCount = 0;
+
+	public TestPartitionDiscoverer(
+			KafkaTopicsDescriptor topicsDescriptor,
+			int indexOfThisSubtask,
+			int numParallelSubtasks,
+			List<List<String>> mockGetAllTopicsReturnSequence,
+			List<List<KafkaTopicPartition>> mockGetAllPartitionsForTopicsReturnSequence) {
+
+		super(topicsDescriptor, indexOfThisSubtask, numParallelSubtasks);
+
+		this.topicsDescriptor = topicsDescriptor;
+		this.mockGetAllTopicsReturnSequence = mockGetAllTopicsReturnSequence;
+		this.mockGetAllPartitionsForTopicsReturnSequence = mockGetAllPartitionsForTopicsReturnSequence;
+	}
+
+	@Override
+	protected List<String> getAllTopics() {
+		assertTrue(topicsDescriptor.isTopicPattern());
+		return mockGetAllTopicsReturnSequence.get(getAllTopicsInvokeCount++);
+	}
+
+	@Override
+	protected List<KafkaTopicPartition> getAllPartitionsForTopics(List<String> topics) {
+		if (topicsDescriptor.isFixedTopics()) {
+			assertEquals(topicsDescriptor.getFixedTopics(), topics);
+		} else { // topic pattern: the topics must come from the most recent getAllTopics() call
+			assertEquals(mockGetAllTopicsReturnSequence.get(getAllTopicsInvokeCount - 1), topics);
+		}
+		return mockGetAllPartitionsForTopicsReturnSequence.get(getAllPartitionsForTopicsInvokeCount++);
+	}
+
+	@Override
+	protected void initializeConnections() {
+		// nothing to do
+	}
+
+	@Override
+	protected void wakeupConnections() {
+		// nothing to do
+	}
+
+	@Override
+	protected void closeConnections() {
+		// nothing to do
+	}
+
+	// ---------------------------------------------------------------------------------
+	//  Utilities to create mocked, fixed results for sequences of metadata fetches
+	// ---------------------------------------------------------------------------------
+
+	public static List<List<String>> createMockGetAllTopicsSequenceFromFixedReturn(final List<String> fixed) {
+		@SuppressWarnings("unchecked")
+		List<List<String>> mockSequence = mock(List.class);
+		when(mockSequence.get(anyInt())).thenAnswer(new Answer<List<String>>() {
+			@Override
+			public List<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
+				return new ArrayList<>(fixed);
+			}
+		});
+		// get(i) returns a fresh copy of 'fixed' for any i, so callers may mutate the result
+		return mockSequence;
+	}
+
+	public static List<List<KafkaTopicPartition>> createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(final List<KafkaTopicPartition> fixed) {
+		@SuppressWarnings("unchecked")
+		List<List<KafkaTopicPartition>> mockSequence = mock(List.class);
+		when(mockSequence.get(anyInt())).thenAnswer(new Answer<List<KafkaTopicPartition>>() {
+			@Override
+			public List<KafkaTopicPartition> answer(InvocationOnMock invocationOnMock) throws Throwable {
+				return new ArrayList<>(fixed);
+			}
+		});
+		// get(i) returns a fresh copy of 'fixed' for any i, so callers may mutate the result
+		return mockSequence;
+	}
+}