Posted to commits@flink.apache.org by tz...@apache.org on 2018/01/12 13:31:08 UTC

[16/19] flink git commit: [FLINK-8296] [kafka] Rework FlinkKafkaConsumerBaseTest to not rely on Java reflection

[FLINK-8296] [kafka] Rework FlinkKafkaConsumerBaseTest to not rely on Java reflection

Reflection was mainly used to inject mocks into private fields of the
FlinkKafkaConsumerBase, without the need to fully execute all operator
life cycle methods. This, however, caused the unit tests to be too
implementation-specific.

This commit reworks the FlinkKafkaConsumerBaseTest to remove the test
consumer instantiation methods that relied on reflection for dependency
injection. All tests now instantiate dummy test consumers normally and
properly execute all operator life cycle methods, regardless of the
logic under test.
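
For context, the shape of the change is roughly as follows. This is a
minimal before/after sketch: consumer, mockFetcher, mockPartitionDiscoverer,
and listState are illustrative stand-ins, while DummyFlinkKafkaConsumer and
setupConsumer are the test helpers introduced in the diff below.

    // Before: reach into a private field via reflection to inject a mock.
    // This couples the test to the field name and skips the operator life cycle.
    Field fetcherField = FlinkKafkaConsumerBase.class.getDeclaredField("kafkaFetcher");
    fetcherField.setAccessible(true);
    fetcherField.set(consumer, mockFetcher);

    // After: pass test doubles through the constructor and run the regular
    // life cycle (setRuntimeContext -> initializeState -> open), so the test
    // depends only on observable behavior.
    FlinkKafkaConsumerBase<String> consumer =
        new DummyFlinkKafkaConsumer<>(mockFetcher, mockPartitionDiscoverer, false);
    setupConsumer(consumer, false, listState, true, 0, 1);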

This closes #5188.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/37cdaf97
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/37cdaf97
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/37cdaf97

Branch: refs/heads/master
Commit: 37cdaf976ff198a6e5c1d0e6e38a50de185cec1e
Parents: faaa135
Author: Tzu-Li (Gordon) Tai <tz...@apache.org>
Authored: Tue Dec 19 16:10:44 2017 -0800
Committer: Tzu-Li (Gordon) Tai <tz...@apache.org>
Committed: Fri Jan 12 19:43:28 2018 +0800

----------------------------------------------------------------------
 .../kafka/FlinkKafkaConsumerBase.java           |   5 +
 .../kafka/FlinkKafkaConsumerBaseTest.java       | 532 ++++++++++++-------
 2 files changed, 338 insertions(+), 199 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/37cdaf97/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
index c350442..7a87f4d 100644
--- a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
+++ b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBase.java
@@ -884,4 +884,9 @@ public abstract class FlinkKafkaConsumerBase<T> extends RichParallelSourceFuncti
 	OffsetCommitMode getOffsetCommitMode() {
 		return offsetCommitMode;
 	}
+
+	@VisibleForTesting
+	LinkedMap getPendingOffsetsToCommit() {
+		return pendingOffsetsToCommit;
+	}
 }
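
The new @VisibleForTesting accessor above is what lets the reworked tests
assert on the commit bookkeeping without reflective field access. Assuming a
consumer that has been set up with checkpointing enabled and whose fetcher is
running, as in the tests further below, an assertion reads roughly like this
(sketch only):

    // taking a checkpoint parks the offsets until the checkpoint is acknowledged
    consumer.snapshotState(new StateSnapshotContextSynchronousImpl(138, 138));
    assertEquals(1, consumer.getPendingOffsetsToCommit().size());

    // acknowledging the checkpoint commits and removes the pending entry
    consumer.notifyCheckpointComplete(138L);
    assertEquals(0, consumer.getPendingOffsetsToCommit().size());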

http://git-wip-us.apache.org/repos/asf/flink/blob/37cdaf97/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
index 180b12a..f8aeea2 100644
--- a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
+++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java
@@ -18,17 +18,27 @@
 
 package org.apache.flink.streaming.connectors.kafka;
 
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.accumulators.Accumulator;
+import org.apache.flink.api.common.state.KeyedStateStore;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.core.testutils.CheckedThread;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.operators.testutils.MockEnvironment;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.functions.AssignerWithPeriodicWatermarks;
 import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.connectors.kafka.config.OffsetCommitMode;
@@ -44,16 +54,13 @@ import org.apache.flink.streaming.util.serialization.KeyedDeserializationSchema;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.SerializedValue;
 
-import org.apache.commons.collections.map.LinkedMap;
 import org.junit.Assert;
 import org.junit.Test;
 import org.mockito.Matchers;
 import org.mockito.Mockito;
 
 import java.io.Serializable;
-import java.lang.reflect.Field;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -86,20 +93,19 @@ public class FlinkKafkaConsumerBaseTest {
 	 * Tests that not both types of timestamp extractors / watermark generators can be used.
 	 */
 	@Test
+	@SuppressWarnings("unchecked")
 	public void testEitherWatermarkExtractor() {
 		try {
-			new DummyFlinkKafkaConsumer<>().assignTimestampsAndWatermarks((AssignerWithPeriodicWatermarks<Object>) null);
+			new DummyFlinkKafkaConsumer<String>().assignTimestampsAndWatermarks((AssignerWithPeriodicWatermarks<String>) null);
 			fail();
 		} catch (NullPointerException ignored) {}
 
 		try {
-			new DummyFlinkKafkaConsumer<>().assignTimestampsAndWatermarks((AssignerWithPunctuatedWatermarks<Object>) null);
+			new DummyFlinkKafkaConsumer<String>().assignTimestampsAndWatermarks((AssignerWithPunctuatedWatermarks<String>) null);
 			fail();
 		} catch (NullPointerException ignored) {}
 
-		@SuppressWarnings("unchecked")
 		final AssignerWithPeriodicWatermarks<String> periodicAssigner = mock(AssignerWithPeriodicWatermarks.class);
-		@SuppressWarnings("unchecked")
 		final AssignerWithPunctuatedWatermarks<String> punctuatedAssigner = mock(AssignerWithPunctuatedWatermarks.class);
 
 		DummyFlinkKafkaConsumer<String> c1 = new DummyFlinkKafkaConsumer<>();
@@ -123,17 +129,16 @@ public class FlinkKafkaConsumerBaseTest {
 	@Test
 	public void ignoreCheckpointWhenNotRunning() throws Exception {
 		@SuppressWarnings("unchecked")
-		final AbstractFetcher<String, ?> fetcher = mock(AbstractFetcher.class);
+		final FlinkKafkaConsumerBase<String> consumer = new DummyFlinkKafkaConsumer<>();
 
-		FlinkKafkaConsumerBase<String> consumer = getConsumer(fetcher, new LinkedMap(), false);
-		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
-		TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
-		when(operatorStateStore.getListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+		final TestingListState<Tuple2<KafkaTopicPartition, Long>> listState = new TestingListState<>();
+		setupConsumer(consumer, false, listState, true, 0, 1);
 
+		// snapshot before the fetcher starts running
 		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(1, 1));
 
+		// no state should have been checkpointed
 		assertFalse(listState.get().iterator().hasNext());
-		consumer.notifyCheckpointComplete(66L);
 	}
 
 	/**
@@ -142,32 +147,13 @@ public class FlinkKafkaConsumerBaseTest {
 	 */
 	@Test
 	public void checkRestoredCheckpointWhenFetcherNotReady() throws Exception {
-		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
-
-		TestingListState<Serializable> restoredListState = new TestingListState<>();
-		restoredListState.add(Tuple2.of(new KafkaTopicPartition("abc", 13), 16768L));
-		restoredListState.add(Tuple2.of(new KafkaTopicPartition("def", 7), 987654321L));
-
-		FlinkKafkaConsumerBase<String> consumer = getConsumer(null, new LinkedMap(), true);
-		StreamingRuntimeContext context = mock(StreamingRuntimeContext.class);
-		when(context.getNumberOfParallelSubtasks()).thenReturn(1);
-		when(context.getIndexOfThisSubtask()).thenReturn(0);
-		consumer.setRuntimeContext(context);
-
-		// mock old 1.2 state (empty)
-		when(operatorStateStore.getSerializableListState(Matchers.any(String.class))).thenReturn(new TestingListState<Serializable>());
-		// mock 1.3 state
-		when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(restoredListState);
-
-		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
-
-		when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
-		when(initializationContext.isRestored()).thenReturn(true);
-
-		consumer.initializeState(initializationContext);
+		@SuppressWarnings("unchecked")
+		final FlinkKafkaConsumerBase<String> consumer = new DummyFlinkKafkaConsumer<>();
 
-		consumer.open(new Configuration());
+		final TestingListState<Tuple2<KafkaTopicPartition, Long>> restoredListState = new TestingListState<>();
+		setupConsumer(consumer, true, restoredListState, true, 0, 1);
 
+		// snapshot before the fetcher starts running
 		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(17, 17));
 
 		// ensure that the list was cleared and refilled. while this is an implementation detail, we use it here
@@ -192,67 +178,68 @@ public class FlinkKafkaConsumerBaseTest {
 
 	@Test
 	public void testConfigureOnCheckpointsCommitMode() throws Exception {
+		@SuppressWarnings("unchecked")
+		// auto-commit enabled; this should be ignored in this case
+		final DummyFlinkKafkaConsumer<String> consumer = new DummyFlinkKafkaConsumer<>(true);
 
-		DummyFlinkKafkaConsumer consumer = new DummyFlinkKafkaConsumer();
-		consumer.setIsAutoCommitEnabled(true); // this should be ignored
-
-		StreamingRuntimeContext context = mock(StreamingRuntimeContext.class);
-		when(context.getIndexOfThisSubtask()).thenReturn(0);
-		when(context.getNumberOfParallelSubtasks()).thenReturn(1);
-		when(context.isCheckpointingEnabled()).thenReturn(true); // enable checkpointing, auto commit should be ignored
-		consumer.setRuntimeContext(context);
+		setupConsumer(
+			consumer,
+			false,
+			null,
+			true, // enable checkpointing; auto commit should be ignored
+			0,
+			1);
 
-		consumer.open(new Configuration());
 		assertEquals(OffsetCommitMode.ON_CHECKPOINTS, consumer.getOffsetCommitMode());
 	}
 
 	@Test
 	public void testConfigureAutoCommitMode() throws Exception {
+		@SuppressWarnings("unchecked")
+		final DummyFlinkKafkaConsumer<String> consumer = new DummyFlinkKafkaConsumer<>(true);
 
-		DummyFlinkKafkaConsumer consumer = new DummyFlinkKafkaConsumer();
-		consumer.setIsAutoCommitEnabled(true);
-
-		StreamingRuntimeContext context = mock(StreamingRuntimeContext.class);
-		when(context.getIndexOfThisSubtask()).thenReturn(0);
-		when(context.getNumberOfParallelSubtasks()).thenReturn(1);
-		when(context.isCheckpointingEnabled()).thenReturn(false); // disable checkpointing, auto commit should be respected
-		consumer.setRuntimeContext(context);
+		setupConsumer(
+			consumer,
+			false,
+			null,
+			false, // disable checkpointing; auto commit should be respected
+			0,
+			1);
 
-		consumer.open(new Configuration());
 		assertEquals(OffsetCommitMode.KAFKA_PERIODIC, consumer.getOffsetCommitMode());
 	}
 
 	@Test
 	public void testConfigureDisableOffsetCommitWithCheckpointing() throws Exception {
-
-		DummyFlinkKafkaConsumer consumer = new DummyFlinkKafkaConsumer();
-		consumer.setIsAutoCommitEnabled(true); // this should be ignored
-
-		StreamingRuntimeContext context = mock(StreamingRuntimeContext.class);
-		when(context.getIndexOfThisSubtask()).thenReturn(0);
-		when(context.getNumberOfParallelSubtasks()).thenReturn(1);
-		when(context.isCheckpointingEnabled()).thenReturn(true); // enable checkpointing, auto commit should be ignored
-		consumer.setRuntimeContext(context);
-
+		@SuppressWarnings("unchecked")
+		// auto-commit enabled; this should be ignored in this case
+		final DummyFlinkKafkaConsumer<String> consumer = new DummyFlinkKafkaConsumer<>(true);
 		consumer.setCommitOffsetsOnCheckpoints(false); // disabling offset committing should override everything
 
-		consumer.open(new Configuration());
+		setupConsumer(
+			consumer,
+			false,
+			null,
+			true, // enable checkpointing; auto commit should be ignored
+			0,
+			1);
+
 		assertEquals(OffsetCommitMode.DISABLED, consumer.getOffsetCommitMode());
 	}
 
 	@Test
 	public void testConfigureDisableOffsetCommitWithoutCheckpointing() throws Exception {
+		@SuppressWarnings("unchecked")
+		final DummyFlinkKafkaConsumer<String> consumer = new DummyFlinkKafkaConsumer<>(false);
 
-		DummyFlinkKafkaConsumer consumer = new DummyFlinkKafkaConsumer();
-		consumer.setIsAutoCommitEnabled(false);
-
-		StreamingRuntimeContext context = mock(StreamingRuntimeContext.class);
-		when(context.getIndexOfThisSubtask()).thenReturn(0);
-		when(context.getNumberOfParallelSubtasks()).thenReturn(1);
-		when(context.isCheckpointingEnabled()).thenReturn(false); // disable checkpointing, auto commit should be respected
-		consumer.setRuntimeContext(context);
+		setupConsumer(
+			consumer,
+			false,
+			null,
+			false, // disable checkpointing; auto commit should be respected
+			0,
+			1);
 
-		consumer.open(new Configuration());
 		assertEquals(OffsetCommitMode.DISABLED, consumer.getOffsetCommitMode());
 	}
 
@@ -278,36 +265,37 @@ public class FlinkKafkaConsumerBaseTest {
 
 		// --------------------------------------------------------------------
 
-		final AbstractFetcher<String, ?> fetcher = mock(AbstractFetcher.class);
+		final OneShotLatch runLatch = new OneShotLatch();
+		final OneShotLatch stopLatch = new OneShotLatch();
+		final AbstractFetcher<String, ?> fetcher = getRunnableMockFetcher(runLatch, stopLatch);
 		when(fetcher.snapshotCurrentState()).thenReturn(state1, state2, state3);
 
-		final LinkedMap pendingOffsetsToCommit = new LinkedMap();
-
-		FlinkKafkaConsumerBase<String> consumer = getConsumer(fetcher, pendingOffsetsToCommit, true);
-		StreamingRuntimeContext mockRuntimeContext = mock(StreamingRuntimeContext.class);
-		when(mockRuntimeContext.isCheckpointingEnabled()).thenReturn(true); // enable checkpointing
-		when(mockRuntimeContext.getIndexOfThisSubtask()).thenReturn(0);
-		when(mockRuntimeContext.getNumberOfParallelSubtasks()).thenReturn(1);
-		consumer.setRuntimeContext(mockRuntimeContext);
-
-		assertEquals(0, pendingOffsetsToCommit.size());
-
-		OperatorStateStore backend = mock(OperatorStateStore.class);
+		final FlinkKafkaConsumerBase<String> consumer = new DummyFlinkKafkaConsumer<>(
+				fetcher,
+				mock(AbstractPartitionDiscoverer.class),
+				false);
 
-		TestingListState<Serializable> listState = new TestingListState<>();
-		// mock old 1.2 state (empty)
-		when(backend.getSerializableListState(Matchers.any(String.class))).thenReturn(new TestingListState<Serializable>());
-		// mock 1.3 state
-		when(backend.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+		final TestingListState<Serializable> listState = new TestingListState<>();
 
-		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
+		// setup and run the consumer; wait until the consumer reaches the main fetch loop before continuing test
+		setupConsumer(consumer, false, listState, true, 0, 1);
 
-		when(initializationContext.getOperatorStateStore()).thenReturn(backend);
-		when(initializationContext.isRestored()).thenReturn(false, true, true, true);
+		final CheckedThread runThread = new CheckedThread() {
+			@Override
+			public void go() throws Exception {
+				consumer.run(mock(SourceFunction.SourceContext.class));
+			}
 
-		consumer.initializeState(initializationContext);
+			@Override
+			public void sync() throws Exception {
+				stopLatch.trigger();
+				super.sync();
+			}
+		};
+		runThread.start();
+		runLatch.await();
 
-		consumer.open(new Configuration());
+		assertEquals(0, consumer.getPendingOffsetsToCommit().size());
 
 		// checkpoint 1
 		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(138, 138));
@@ -320,8 +308,8 @@ public class FlinkKafkaConsumerBaseTest {
 		}
 
 		assertEquals(state1, snapshot1);
-		assertEquals(1, pendingOffsetsToCommit.size());
-		assertEquals(state1, pendingOffsetsToCommit.get(138L));
+		assertEquals(1, consumer.getPendingOffsetsToCommit().size());
+		assertEquals(state1, consumer.getPendingOffsetsToCommit().get(138L));
 
 		// checkpoint 2
 		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(140, 140));
@@ -334,13 +322,13 @@ public class FlinkKafkaConsumerBaseTest {
 		}
 
 		assertEquals(state2, snapshot2);
-		assertEquals(2, pendingOffsetsToCommit.size());
-		assertEquals(state2, pendingOffsetsToCommit.get(140L));
+		assertEquals(2, consumer.getPendingOffsetsToCommit().size());
+		assertEquals(state2, consumer.getPendingOffsetsToCommit().get(140L));
 
 		// ack checkpoint 1
 		consumer.notifyCheckpointComplete(138L);
-		assertEquals(1, pendingOffsetsToCommit.size());
-		assertTrue(pendingOffsetsToCommit.containsKey(140L));
+		assertEquals(1, consumer.getPendingOffsetsToCommit().size());
+		assertTrue(consumer.getPendingOffsetsToCommit().containsKey(140L));
 
 		// checkpoint 3
 		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(141, 141));
@@ -353,37 +341,35 @@ public class FlinkKafkaConsumerBaseTest {
 		}
 
 		assertEquals(state3, snapshot3);
-		assertEquals(2, pendingOffsetsToCommit.size());
-		assertEquals(state3, pendingOffsetsToCommit.get(141L));
+		assertEquals(2, consumer.getPendingOffsetsToCommit().size());
+		assertEquals(state3, consumer.getPendingOffsetsToCommit().get(141L));
 
 		// ack checkpoint 3, subsumes number 2
 		consumer.notifyCheckpointComplete(141L);
-		assertEquals(0, pendingOffsetsToCommit.size());
+		assertEquals(0, consumer.getPendingOffsetsToCommit().size());
 
 		consumer.notifyCheckpointComplete(666); // invalid checkpoint
-		assertEquals(0, pendingOffsetsToCommit.size());
-
-		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
-		listState = new TestingListState<>();
-		when(operatorStateStore.getListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+		assertEquals(0, consumer.getPendingOffsetsToCommit().size());
 
 		// create 500 snapshots
 		for (int i = 100; i < 600; i++) {
 			consumer.snapshotState(new StateSnapshotContextSynchronousImpl(i, i));
 			listState.clear();
 		}
-		assertEquals(FlinkKafkaConsumerBase.MAX_NUM_PENDING_CHECKPOINTS, pendingOffsetsToCommit.size());
+		assertEquals(FlinkKafkaConsumerBase.MAX_NUM_PENDING_CHECKPOINTS, consumer.getPendingOffsetsToCommit().size());
 
 		// commit only the second last
 		consumer.notifyCheckpointComplete(598);
-		assertEquals(1, pendingOffsetsToCommit.size());
+		assertEquals(1, consumer.getPendingOffsetsToCommit().size());
 
 		// access invalid checkpoint
 		consumer.notifyCheckpointComplete(590);
 
 		// and the last
 		consumer.notifyCheckpointComplete(599);
-		assertEquals(0, pendingOffsetsToCommit.size());
+		assertEquals(0, consumer.getPendingOffsetsToCommit().size());
+
+		runThread.sync();
 	}
 
 	@Test
@@ -407,38 +393,38 @@ public class FlinkKafkaConsumerBaseTest {
 
 		// --------------------------------------------------------------------
 
-		final AbstractFetcher<String, ?> fetcher = mock(AbstractFetcher.class);
+		final OneShotLatch runLatch = new OneShotLatch();
+		final OneShotLatch stopLatch = new OneShotLatch();
+		final AbstractFetcher<String, ?> fetcher = getRunnableMockFetcher(runLatch, stopLatch);
 		when(fetcher.snapshotCurrentState()).thenReturn(state1, state2, state3);
 
-		final LinkedMap pendingOffsetsToCommit = new LinkedMap();
-
-		FlinkKafkaConsumerBase<String> consumer = getConsumer(fetcher, pendingOffsetsToCommit, true);
-		StreamingRuntimeContext mockRuntimeContext = mock(StreamingRuntimeContext.class);
-		when(mockRuntimeContext.isCheckpointingEnabled()).thenReturn(true); // enable checkpointing
-		when(mockRuntimeContext.getIndexOfThisSubtask()).thenReturn(0);
-		when(mockRuntimeContext.getNumberOfParallelSubtasks()).thenReturn(1);
-		consumer.setRuntimeContext(mockRuntimeContext);
-
+		final FlinkKafkaConsumerBase<String> consumer = new DummyFlinkKafkaConsumer<>(
+				fetcher,
+				mock(AbstractPartitionDiscoverer.class),
+				false);
 		consumer.setCommitOffsetsOnCheckpoints(false); // disable offset committing
 
-		assertEquals(0, pendingOffsetsToCommit.size());
-
-		OperatorStateStore backend = mock(OperatorStateStore.class);
+		final TestingListState<Serializable> listState = new TestingListState<>();
 
-		TestingListState<Serializable> listState = new TestingListState<>();
-		// mock old 1.2 state (empty)
-		when(backend.getSerializableListState(Matchers.any(String.class))).thenReturn(new TestingListState<Serializable>());
-		// mock 1.3 state
-		when(backend.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
+		// setup and run the consumer; wait until the consumer reaches the main fetch loop before continuing test
+		setupConsumer(consumer, false, listState, true, 0, 1);
 
-		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
-
-		when(initializationContext.getOperatorStateStore()).thenReturn(backend);
-		when(initializationContext.isRestored()).thenReturn(false, true, true, true);
+		final CheckedThread runThread = new CheckedThread() {
+			@Override
+			public void go() throws Exception {
+				consumer.run(mock(SourceFunction.SourceContext.class));
+			}
 
-		consumer.initializeState(initializationContext);
+			@Override
+			public void sync() throws Exception {
+				stopLatch.trigger();
+				super.sync();
+			}
+		};
+		runThread.start();
+		runLatch.await();
 
-		consumer.open(new Configuration());
+		assertEquals(0, consumer.getPendingOffsetsToCommit().size());
 
 		// checkpoint 1
 		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(138, 138));
@@ -451,7 +437,7 @@ public class FlinkKafkaConsumerBaseTest {
 		}
 
 		assertEquals(state1, snapshot1);
-		assertEquals(0, pendingOffsetsToCommit.size()); // pending offsets to commit should not be updated
+		assertEquals(0, consumer.getPendingOffsetsToCommit().size()); // pending offsets to commit should not be updated
 
 		// checkpoint 2
 		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(140, 140));
@@ -464,7 +450,7 @@ public class FlinkKafkaConsumerBaseTest {
 		}
 
 		assertEquals(state2, snapshot2);
-		assertEquals(0, pendingOffsetsToCommit.size()); // pending offsets to commit should not be updated
+		assertEquals(0, consumer.getPendingOffsetsToCommit().size()); // pending offsets to commit should not be updated
 
 		// ack checkpoint 1
 		consumer.notifyCheckpointComplete(138L);
@@ -481,7 +467,7 @@ public class FlinkKafkaConsumerBaseTest {
 		}
 
 		assertEquals(state3, snapshot3);
-		assertEquals(0, pendingOffsetsToCommit.size()); // pending offsets to commit should not be updated
+		assertEquals(0, consumer.getPendingOffsetsToCommit().size()); // pending offsets to commit should not be updated
 
 		// ack checkpoint 3, subsumes number 2
 		consumer.notifyCheckpointComplete(141L);
@@ -490,16 +476,12 @@ public class FlinkKafkaConsumerBaseTest {
 		consumer.notifyCheckpointComplete(666); // invalid checkpoint
 		verify(fetcher, never()).commitInternalOffsetsToKafka(anyMap(), Matchers.any(KafkaCommitCallback.class)); // no offsets should be committed
 
-		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
-		listState = new TestingListState<>();
-		when(operatorStateStore.getListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
-
 		// create 500 snapshots
 		for (int i = 100; i < 600; i++) {
 			consumer.snapshotState(new StateSnapshotContextSynchronousImpl(i, i));
 			listState.clear();
 		}
-		assertEquals(0, pendingOffsetsToCommit.size()); // pending offsets to commit should not be updated
+		assertEquals(0, consumer.getPendingOffsetsToCommit().size()); // pending offsets to commit should not be updated
 
 		// commit only the second last
 		consumer.notifyCheckpointComplete(598);
@@ -532,7 +514,7 @@ public class FlinkKafkaConsumerBaseTest {
 	 * of topics fetched from Kafka.
 	 */
 	@SuppressWarnings("unchecked")
-	void testRescaling(
+	private void testRescaling(
 		final int initialParallelism,
 		final int numPartitions,
 		final int restoredParallelism,
@@ -554,8 +536,14 @@ public class FlinkKafkaConsumerBaseTest {
 			new AbstractStreamOperatorTestHarness[initialParallelism];
 
 		for (int i = 0; i < initialParallelism; i++) {
-			consumers[i] = new DummyFlinkKafkaConsumer<>(
-				Collections.singletonList("test-topic"), mockFetchedPartitionsOnStartup);
+			TestPartitionDiscoverer partitionDiscoverer = new TestPartitionDiscoverer(
+				new KafkaTopicsDescriptor(Collections.singletonList("test-topic"), null),
+				i,
+				initialParallelism,
+				TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList("test-topic")),
+				TestPartitionDiscoverer.createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockFetchedPartitionsOnStartup));
+
+			consumers[i] = new DummyFlinkKafkaConsumer<>(mock(AbstractFetcher.class), partitionDiscoverer, false);
 			testHarnesses[i] = createTestHarness(consumers[i], initialParallelism, i);
 
 			// initializeState() is always called, null signals that we didn't restore
@@ -602,8 +590,14 @@ public class FlinkKafkaConsumerBaseTest {
 			new AbstractStreamOperatorTestHarness[restoredParallelism];
 
 		for (int i = 0; i < restoredParallelism; i++) {
-			restoredConsumers[i] = new DummyFlinkKafkaConsumer<>(
-				Collections.singletonList("test-topic"), mockFetchedPartitionsAfterRestore);
+			TestPartitionDiscoverer partitionDiscoverer = new TestPartitionDiscoverer(
+				new KafkaTopicsDescriptor(Collections.singletonList("test-topic"), null),
+				i,
+				restoredParallelism,
+				TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(Collections.singletonList("test-topic")),
+				TestPartitionDiscoverer.createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(mockFetchedPartitionsAfterRestore));
+
+			restoredConsumers[i] = new DummyFlinkKafkaConsumer<>(mock(AbstractFetcher.class), partitionDiscoverer, false);
 			restoredTestHarnesses[i] = createTestHarness(restoredConsumers[i], restoredParallelism, i);
 
 			// initializeState() is always called, null signals that we didn't restore
@@ -630,28 +624,6 @@ public class FlinkKafkaConsumerBaseTest {
 
 	// ------------------------------------------------------------------------
 
-	private static <T> FlinkKafkaConsumerBase<T> getConsumer(
-			AbstractFetcher<T, ?> fetcher, LinkedMap pendingOffsetsToCommit, boolean running) throws Exception {
-		FlinkKafkaConsumerBase<T> consumer = new DummyFlinkKafkaConsumer<>();
-		StreamingRuntimeContext mockRuntimeContext = mock(StreamingRuntimeContext.class);
-		Mockito.when(mockRuntimeContext.isCheckpointingEnabled()).thenReturn(true);
-		consumer.setRuntimeContext(mockRuntimeContext);
-
-		Field fetcherField = FlinkKafkaConsumerBase.class.getDeclaredField("kafkaFetcher");
-		fetcherField.setAccessible(true);
-		fetcherField.set(consumer, fetcher);
-
-		Field mapField = FlinkKafkaConsumerBase.class.getDeclaredField("pendingOffsetsToCommit");
-		mapField.setAccessible(true);
-		mapField.set(consumer, pendingOffsetsToCommit);
-
-		Field runningField = FlinkKafkaConsumerBase.class.getDeclaredField("running");
-		runningField.setAccessible(true);
-		runningField.set(consumer, running);
-
-		return consumer;
-	}
-
 	private static <T> AbstractStreamOperatorTestHarness<T> createTestHarness(
 		SourceFunction<T> source, int numSubtasks, int subtaskIndex) throws Exception {
 
@@ -667,25 +639,43 @@ public class FlinkKafkaConsumerBaseTest {
 
 	// ------------------------------------------------------------------------
 
+	/**
+	 * An instantiable dummy {@link FlinkKafkaConsumerBase} that supports injecting
+	 * mocks for {@link FlinkKafkaConsumerBase#kafkaFetcher}, {@link FlinkKafkaConsumerBase#partitionDiscoverer},
+	 * and {@link FlinkKafkaConsumerBase#getIsAutoCommitEnabled()}.
+	 */
 	private static class DummyFlinkKafkaConsumer<T> extends FlinkKafkaConsumerBase<T> {
 		private static final long serialVersionUID = 1L;
 
-		boolean isAutoCommitEnabled = false;
+		private AbstractFetcher<T, ?> testFetcher;
+		private AbstractPartitionDiscoverer testPartitionDiscoverer;
+		private boolean isAutoCommitEnabled;
 
-		private List<String> fixedMockGetAllTopicsReturnSequence;
-		private List<KafkaTopicPartition> fixedMockGetAllPartitionsForTopicsReturnSequence;
+		@SuppressWarnings("unchecked")
+		DummyFlinkKafkaConsumer() {
+			this(false);
+		}
 
-		public DummyFlinkKafkaConsumer() {
-			this(Collections.singletonList("dummy-topic"), Collections.singletonList(new KafkaTopicPartition("dummy-topic", 0)));
+		@SuppressWarnings("unchecked")
+		DummyFlinkKafkaConsumer(boolean isAutoCommitEnabled) {
+			this(mock(AbstractFetcher.class), mock(AbstractPartitionDiscoverer.class), isAutoCommitEnabled);
 		}
 
 		@SuppressWarnings("unchecked")
-		public DummyFlinkKafkaConsumer(
-				List<String> fixedMockGetAllTopicsReturnSequence,
-				List<KafkaTopicPartition> fixedMockGetAllPartitionsForTopicsReturnSequence) {
-			super(Arrays.asList("dummy-topic"), null, (KeyedDeserializationSchema < T >) mock(KeyedDeserializationSchema.class), 0);
-			this.fixedMockGetAllTopicsReturnSequence = Preconditions.checkNotNull(fixedMockGetAllTopicsReturnSequence);
-			this.fixedMockGetAllPartitionsForTopicsReturnSequence = Preconditions.checkNotNull(fixedMockGetAllPartitionsForTopicsReturnSequence);
+		DummyFlinkKafkaConsumer(
+				AbstractFetcher<T, ?> testFetcher,
+				AbstractPartitionDiscoverer testPartitionDiscoverer,
+				boolean isAutoCommitEnabled) {
+
+			super(
+					Collections.singletonList("dummy-topic"),
+					null,
+					(KeyedDeserializationSchema<T>) mock(KeyedDeserializationSchema.class),
+					PARTITION_DISCOVERY_DISABLED);
+
+			this.testFetcher = testFetcher;
+			this.testPartitionDiscoverer = testPartitionDiscoverer;
+			this.isAutoCommitEnabled = isAutoCommitEnabled;
 		}
 
 		@Override
@@ -697,7 +687,7 @@ public class FlinkKafkaConsumerBaseTest {
 				SerializedValue<AssignerWithPunctuatedWatermarks<T>> watermarksPunctuated,
 				StreamingRuntimeContext runtimeContext,
 				OffsetCommitMode offsetCommitMode) throws Exception {
-			return mock(AbstractFetcher.class);
+			return this.testFetcher;
 		}
 
 		@Override
@@ -705,21 +695,12 @@ public class FlinkKafkaConsumerBaseTest {
 				KafkaTopicsDescriptor topicsDescriptor,
 				int indexOfThisSubtask,
 				int numParallelSubtasks) {
-			return new TestPartitionDiscoverer(
-				topicsDescriptor,
-				indexOfThisSubtask,
-				numParallelSubtasks,
-				TestPartitionDiscoverer.createMockGetAllTopicsSequenceFromFixedReturn(fixedMockGetAllTopicsReturnSequence),
-				TestPartitionDiscoverer.createMockGetAllPartitionsFromTopicsSequenceFromFixedReturn(fixedMockGetAllPartitionsForTopicsReturnSequence));
+			return this.testPartitionDiscoverer;
 		}
 
 		@Override
 		protected boolean getIsAutoCommitEnabled() {
-			return isAutoCommitEnabled;
-		}
-
-		public void setIsAutoCommitEnabled(boolean isAutoCommitEnabled) {
-			this.isAutoCommitEnabled = isAutoCommitEnabled;
+			return this.isAutoCommitEnabled;
 		}
 	}
 
@@ -748,7 +729,7 @@ public class FlinkKafkaConsumerBaseTest {
 			return list;
 		}
 
-		public boolean isClearCalled() {
+		boolean isClearCalled() {
 			return clearCalled;
 		}
 
@@ -761,4 +742,157 @@ public class FlinkKafkaConsumerBaseTest {
 			}
 		}
 	}
+
+	/**
+	 * Returns a mock {@link AbstractFetcher}, with run / stop latches injected in
+	 * the {@link AbstractFetcher#runFetchLoop()} method.
+	 */
+	private static <T> AbstractFetcher<T, ?> getRunnableMockFetcher(
+			OneShotLatch runLatch,
+			OneShotLatch stopLatch) throws Exception {
+
+		@SuppressWarnings("unchecked")
+		final AbstractFetcher<T, ?> fetcher = mock(AbstractFetcher.class);
+
+		Mockito.doAnswer(invocationOnMock -> {
+			runLatch.trigger();
+			stopLatch.await();
+			return null;
+		}).when(fetcher).runFetchLoop();
+
+		return fetcher;
+	}
+
+	@SuppressWarnings("unchecked")
+	private static <T, S> void setupConsumer(
+			FlinkKafkaConsumerBase<T> consumer,
+			boolean isRestored,
+			ListState<S> restoredListState,
+			boolean isCheckpointingEnabled,
+			int subtaskIndex,
+			int totalNumSubtasks) throws Exception {
+
+		// run setup procedure in operator life cycle
+		consumer.setRuntimeContext(new MockRuntimeContext(isCheckpointingEnabled, totalNumSubtasks, subtaskIndex));
+		consumer.initializeState(new MockFunctionInitializationContext(isRestored, new MockOperatorStateStore(restoredListState)));
+		consumer.open(new Configuration());
+	}
+
+	private static class MockRuntimeContext extends StreamingRuntimeContext {
+
+		private final boolean isCheckpointingEnabled;
+
+		private final int numParallelSubtasks;
+		private final int subtaskIndex;
+
+		private MockRuntimeContext(
+				boolean isCheckpointingEnabled,
+				int numParallelSubtasks,
+				int subtaskIndex) {
+
+			super(
+				new MockStreamOperator(),
+				new MockEnvironment("mockTask", 4 * MemoryManager.DEFAULT_PAGE_SIZE, null, 16),
+				Collections.<String, Accumulator<?, ?>>emptyMap());
+
+			this.isCheckpointingEnabled = isCheckpointingEnabled;
+			this.numParallelSubtasks = numParallelSubtasks;
+			this.subtaskIndex = subtaskIndex;
+		}
+
+		@Override
+		public MetricGroup getMetricGroup() {
+			return new UnregisteredMetricsGroup();
+		}
+
+		@Override
+		public boolean isCheckpointingEnabled() {
+			return isCheckpointingEnabled;
+		}
+
+		@Override
+		public int getIndexOfThisSubtask() {
+			return subtaskIndex;
+		}
+
+		@Override
+		public int getNumberOfParallelSubtasks() {
+			return numParallelSubtasks;
+		}
+
+		// ------------------------------------------------------------------------
+
+		private static class MockStreamOperator extends AbstractStreamOperator<Integer> {
+			private static final long serialVersionUID = -1153976702711944427L;
+
+			@Override
+			public ExecutionConfig getExecutionConfig() {
+				return new ExecutionConfig();
+			}
+		}
+	}
+
+	private static class MockOperatorStateStore implements OperatorStateStore {
+
+		private final ListState<?> mockRestoredUnionListState;
+
+		private MockOperatorStateStore(ListState<?> restoredUnionListState) {
+			this.mockRestoredUnionListState = restoredUnionListState;
+		}
+
+		@Override
+		@SuppressWarnings("unchecked")
+		public <S> ListState<S> getUnionListState(ListStateDescriptor<S> stateDescriptor) throws Exception {
+			return (ListState<S>) mockRestoredUnionListState;
+		}
+
+		@Override
+		public <T extends Serializable> ListState<T> getSerializableListState(String stateName) throws Exception {
+			// return empty state for the legacy 1.2 Kafka consumer state
+			return new TestingListState<>();
+		}
+
+		// ------------------------------------------------------------------------
+
+		@Override
+		public <S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public <S> ListState<S> getListState(ListStateDescriptor<S> stateDescriptor) throws Exception {
+			throw new UnsupportedOperationException();
+		}
+
+		@Override
+		public Set<String> getRegisteredStateNames() {
+			throw new UnsupportedOperationException();
+		}
+	}
+
+	private static class MockFunctionInitializationContext implements FunctionInitializationContext {
+
+		private final boolean isRestored;
+		private final OperatorStateStore operatorStateStore;
+
+		private MockFunctionInitializationContext(boolean isRestored, OperatorStateStore operatorStateStore) {
+			this.isRestored = isRestored;
+			this.operatorStateStore = operatorStateStore;
+		}
+
+		@Override
+		public boolean isRestored() {
+			return isRestored;
+		}
+
+		@Override
+		public OperatorStateStore getOperatorStateStore() {
+			return operatorStateStore;
+		}
+
+		@Override
+		public KeyedStateStore getKeyedStateStore() {
+			throw new UnsupportedOperationException();
+		}
+	}
 }
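
Taken together, the reworked tests follow this skeleton (a condensed sketch
assembled from the helpers in the diff above; the snapshot and commit
assertions are elided):

    final OneShotLatch runLatch = new OneShotLatch();
    final OneShotLatch stopLatch = new OneShotLatch();
    final AbstractFetcher<String, ?> fetcher = getRunnableMockFetcher(runLatch, stopLatch);

    final FlinkKafkaConsumerBase<String> consumer = new DummyFlinkKafkaConsumer<>(
            fetcher, mock(AbstractPartitionDiscoverer.class), false);

    // full operator life cycle instead of reflective field injection
    setupConsumer(consumer, false, new TestingListState<>(), true, 0, 1);

    final CheckedThread runThread = new CheckedThread() {
        @Override
        public void go() throws Exception {
            consumer.run(mock(SourceFunction.SourceContext.class));
        }
    };
    runThread.start();
    runLatch.await();      // the mocked fetch loop is now running

    // ... exercise snapshotState() / notifyCheckpointComplete() here ...

    stopLatch.trigger();   // let the mocked fetch loop return
    runThread.sync();      // rethrows any failure from the source thread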