Posted to commits@flink.apache.org by tz...@apache.org on 2017/02/22 17:22:38 UTC

flink git commit: [FLINK-5701] [kafka] FlinkKafkaProducer should check asyncException on checkpoints

Repository: flink
Updated Branches:
  refs/heads/release-1.2 15bb4246f -> 576cc895c


[FLINK-5701] [kafka] FlinkKafkaProducer should check asyncException on checkpoints

This closes #3278.
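
The producer reports send failures asynchronously: the Kafka I/O thread invokes the
registered Callback, so an error cannot be thrown directly into the task thread. Before
this patch, snapshotState() only flushed pending records, so a record that had already
failed could be dropped while the checkpoint still completed. The fix re-checks the
stored asynchronous error at checkpoint time, both before and after the flush.

Below is a minimal sketch of the callback/rethrow pattern the patch relies on, assuming
the conventional asyncException field and checkErroneous() helper of
FlinkKafkaProducerBase (neither body appears in the hunks of this commit):

    // written by the Kafka I/O thread when a send fails (assumed field name)
    private transient volatile Exception asyncException;

    private transient Callback callback = new Callback() {
        @Override
        public void onCompletion(RecordMetadata metadata, Exception exception) {
            if (exception != null && asyncException == null) {
                asyncException = exception;
            }
            acknowledgeMessage(); // assumed helper; decrements pendingRecords under the lock
        }
    };

    // rethrows a stored async error on the task thread (assumed body)
    protected void checkErroneous() throws Exception {
        Exception e = asyncException;
        if (e != null) {
            asyncException = null; // clear it so the same error is not thrown twice
            throw new Exception("Failed to send data to Kafka: " + e.getMessage(), e);
        }
    }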


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/576cc895
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/576cc895
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/576cc895

Branch: refs/heads/release-1.2
Commit: 576cc895c836be5aea62e3d4a6b3ceb1b257c612
Parents: 15bb424
Author: Tzu-Li (Gordon) Tai <tz...@apache.org>
Authored: Tue Feb 7 00:37:13 2017 +0800
Committer: Tzu-Li (Gordon) Tai <tz...@apache.org>
Committed: Thu Feb 23 01:21:58 2017 +0800

----------------------------------------------------------------------
 .../kafka/FlinkKafkaProducerBase.java           |  15 +-
 .../kafka/FlinkKafkaProducerBaseTest.java       | 391 ++++++++++++-------
 2 files changed, 272 insertions(+), 134 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/576cc895/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
index 679b731..6a7b17f 100644
--- a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
+++ b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
@@ -17,6 +17,7 @@
 
 package org.apache.flink.streaming.connectors.kafka;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.java.ClosureCleaner;
@@ -348,6 +349,9 @@ public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> im
 
 	@Override
 	public void snapshotState(FunctionSnapshotContext ctx) throws Exception {
+		// check for asynchronous errors and fail the checkpoint if necessary
+		checkErroneous();
+
 		if (flushOnCheckpoint) {
 			// flushing is activated: We need to wait until pendingRecords is 0
 			flush();
@@ -355,7 +359,9 @@ public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> im
 				if (pendingRecords != 0) {
 					throw new IllegalStateException("Pending record count must be zero at this point: " + pendingRecords);
 				}
-				// pending records count is 0. We can now confirm the checkpoint
+
+				// if the flushed requests have errors, we should propagate them as well and fail the checkpoint
+				checkErroneous();
 			}
 		}
 	}
@@ -383,4 +389,11 @@ public abstract class FlinkKafkaProducerBase<IN> extends RichSinkFunction<IN> im
 		props.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList);
 		return props;
 	}
+
+	@VisibleForTesting
+	protected long numPendingRecords() {
+		synchronized (pendingRecordsLock) {
+			return pendingRecords;
+		}
+	}
 }
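
With this hunk applied, snapshotState() checks for asynchronous errors twice: once on
entry, so a failure from a record sent before the checkpoint barrier fails the checkpoint
immediately, and once after flush(), so a failure surfaced by one of the just-flushed
requests is propagated as well. Condensed control flow (the pendingRecords sanity check
inside the synchronized block is omitted here for brevity):

    public void snapshotState(FunctionSnapshotContext ctx) throws Exception {
        checkErroneous();        // errors from sends before the checkpoint
        if (flushOnCheckpoint) {
            flush();             // blocks until pendingRecords reaches zero
            checkErroneous();    // errors raised while flushing
        }
    }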

http://git-wip-us.apache.org/repos/asf/flink/blob/576cc895/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBaseTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBaseTest.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBaseTest.java
index 2e06160..1f16d8e 100644
--- a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBaseTest.java
+++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBaseTest.java
@@ -18,38 +18,36 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.flink.api.common.functions.RuntimeContext;
-import org.apache.flink.api.java.tuple.Tuple1;
+import org.apache.flink.core.testutils.CheckedThread;
+import org.apache.flink.core.testutils.MultiShotLatch;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.api.operators.StreamSink;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.connectors.kafka.partitioner.KafkaPartitioner;
 import org.apache.flink.streaming.connectors.kafka.testutils.FakeStandardProducerConfig;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchema;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.ProducerConfig;
-import org.apache.kafka.clients.producer.RecordMetadata;
 import org.apache.kafka.clients.producer.ProducerRecord;
-import org.apache.kafka.common.Metric;
-import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.serialization.ByteArraySerializer;
 import org.junit.Assert;
 import org.junit.Test;
-import scala.concurrent.duration.Deadline;
-import scala.concurrent.duration.FiniteDuration;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Map;
 import java.util.Properties;
-import java.util.concurrent.Future;
-import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyString;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -88,201 +86,328 @@ public class FlinkKafkaProducerBaseTest {
 	@Test
 	public void testPartitionerOpenedWithDeterminatePartitionList() throws Exception {
 		KafkaPartitioner mockPartitioner = mock(KafkaPartitioner.class);
+
 		RuntimeContext mockRuntimeContext = mock(RuntimeContext.class);
 		when(mockRuntimeContext.getIndexOfThisSubtask()).thenReturn(0);
 		when(mockRuntimeContext.getNumberOfParallelSubtasks()).thenReturn(1);
-
-		DummyFlinkKafkaProducer producer = new DummyFlinkKafkaProducer(
+		
+		// out-of-order list of 4 partitions
+		List<PartitionInfo> mockPartitionsList = new ArrayList<>(4);
+		mockPartitionsList.add(new PartitionInfo(DummyFlinkKafkaProducer.DUMMY_TOPIC, 3, null, null, null));
+		mockPartitionsList.add(new PartitionInfo(DummyFlinkKafkaProducer.DUMMY_TOPIC, 1, null, null, null));
+		mockPartitionsList.add(new PartitionInfo(DummyFlinkKafkaProducer.DUMMY_TOPIC, 0, null, null, null));
+		mockPartitionsList.add(new PartitionInfo(DummyFlinkKafkaProducer.DUMMY_TOPIC, 2, null, null, null));
+
+		final DummyFlinkKafkaProducer producer = new DummyFlinkKafkaProducer(
 			FakeStandardProducerConfig.get(), mockPartitioner);
 		producer.setRuntimeContext(mockRuntimeContext);
 
+		final KafkaProducer mockProducer = producer.getMockKafkaProducer();
+		when(mockProducer.partitionsFor(anyString())).thenReturn(mockPartitionsList);
+		when(mockProducer.metrics()).thenReturn(null);
+
 		producer.open(new Configuration());
 
-		// the internal mock KafkaProducer will return an out-of-order list of 4 partitions,
-		// which should be sorted before provided to the custom partitioner's open() method
+		// the out-of-order partitions list should be sorted before provided to the custom partitioner's open() method
 		int[] correctPartitionList = {0, 1, 2, 3};
 		verify(mockPartitioner).open(0, 1, correctPartitionList);
 	}
 
 	/**
-	 * Test ensuring that the producer is not dropping buffered records.;
-	 * we set a timeout because the test will not finish if the logic is broken
+	 * Test ensuring that if an invoke call happens right after an async exception is caught, it should be rethrown
 	 */
-	@Test(timeout=5000)
-	public void testAtLeastOnceProducer() throws Throwable {
-		runAtLeastOnceTest(true);
+	@Test
+	public void testAsyncErrorRethrownOnInvoke() throws Throwable {
+		final DummyFlinkKafkaProducer<String> producer = new DummyFlinkKafkaProducer<>(
+			FakeStandardProducerConfig.get(), null);
+
+		OneInputStreamOperatorTestHarness<String, Object> testHarness =
+			new OneInputStreamOperatorTestHarness<>(new StreamSink<>(producer));
+
+		testHarness.open();
+
+		testHarness.processElement(new StreamRecord<>("msg-1"));
+
+		// let the message request return an async exception
+		producer.getPendingCallbacks().get(0).onCompletion(null, new Exception("artificial async exception"));
+
+		try {
+			testHarness.processElement(new StreamRecord<>("msg-2"));
+		} catch (Exception e) {
+			// the next invoke should rethrow the async exception
+			Assert.assertTrue(e.getCause().getMessage().contains("artificial async exception"));
+
+			// test succeeded
+			return;
+		}
+
+		Assert.fail();
 	}
 
 	/**
-	 * Ensures that the at least once producing test fails if the flushing is disabled
+	 * Test ensuring that if a snapshot call happens right after an async exception is caught, it should be rethrown
 	 */
-	@Test(expected = AssertionError.class, timeout=5000)
-	public void testAtLeastOnceProducerFailsIfFlushingDisabled() throws Throwable {
-		runAtLeastOnceTest(false);
-	}
-
-	private void runAtLeastOnceTest(boolean flushOnCheckpoint) throws Throwable {
-		final AtomicBoolean snapshottingFinished = new AtomicBoolean(false);
+	@Test
+	public void testAsyncErrorRethrownOnCheckpoint() throws Throwable {
 		final DummyFlinkKafkaProducer<String> producer = new DummyFlinkKafkaProducer<>(
-			FakeStandardProducerConfig.get(), null, snapshottingFinished);
-		producer.setFlushOnCheckpoint(flushOnCheckpoint);
+			FakeStandardProducerConfig.get(), null);
 
 		OneInputStreamOperatorTestHarness<String, Object> testHarness =
-				new OneInputStreamOperatorTestHarness<>(new StreamSink(producer));
+			new OneInputStreamOperatorTestHarness<>(new StreamSink<>(producer));
 
 		testHarness.open();
 
-		for (int i = 0; i < 100; i++) {
-			testHarness.processElement(new StreamRecord<>("msg-" + i));
+		testHarness.processElement(new StreamRecord<>("msg-1"));
+
+		// let the message request return an async exception
+		producer.getPendingCallbacks().get(0).onCompletion(null, new Exception("artificial async exception"));
+
+		try {
+			testHarness.snapshot(123L, 123L);
+		} catch (Exception e) {
+			// the snapshot should rethrow the async exception
+			Assert.assertTrue(e.getCause().getMessage().contains("artificial async exception"));
+
+			// test succeeded
+			return;
 		}
 
-		// start a thread confirming all pending records
-		final Tuple1<Throwable> runnableError = new Tuple1<>(null);
-		final Thread threadA = Thread.currentThread();
+		Assert.fail();
+	}
+
+	/**
+	 * Test ensuring that if an async exception is caught for one of the flushed requests on checkpoint,
+	 * it should be rethrown; we set a timeout because the test will not finish if the logic is broken.
+	 *
+	 * Note that this test does not verify that the snapshot method blocks correctly when there are pending records.
+	 * That behavior is covered by testAtLeastOnceProducer.
+	 */
+	@SuppressWarnings("unchecked")
+	@Test(timeout=5000)
+	public void testAsyncErrorRethrownOnCheckpointAfterFlush() throws Throwable {
+		final DummyFlinkKafkaProducer<String> producer = new DummyFlinkKafkaProducer<>(
+			FakeStandardProducerConfig.get(), null);
+		producer.setFlushOnCheckpoint(true);
+
+		final KafkaProducer<?, ?> mockProducer = producer.getMockKafkaProducer();
+
+		final OneInputStreamOperatorTestHarness<String, Object> testHarness =
+			new OneInputStreamOperatorTestHarness<>(new StreamSink<>(producer));
+
+		testHarness.open();
+
+		testHarness.processElement(new StreamRecord<>("msg-1"));
+		testHarness.processElement(new StreamRecord<>("msg-2"));
+		testHarness.processElement(new StreamRecord<>("msg-3"));
+
+		verify(mockProducer, times(3)).send(any(ProducerRecord.class), any(Callback.class));
+
+		// only let the first callback succeed for now
+		producer.getPendingCallbacks().get(0).onCompletion(null, null);
 
-		Runnable confirmer = new Runnable() {
+		CheckedThread snapshotThread = new CheckedThread() {
 			@Override
-			public void run() {
-				try {
-					MockProducer mp = producer.getProducerInstance();
-					List<Callback> pending = mp.getPending();
-
-					// we need to find out if the snapshot() method blocks forever
-					// this is not possible. If snapshot() is running, it will
-					// start removing elements from the pending list.
-					synchronized (threadA) {
-						threadA.wait(500L);
-					}
-					// we now check that no records have been confirmed yet
-					Assert.assertEquals(100, pending.size());
-					Assert.assertFalse("Snapshot method returned before all records were confirmed",
-						snapshottingFinished.get());
-
-					// now confirm all checkpoints
-					for (Callback c: pending) {
-						c.onCompletion(null, null);
-					}
-					pending.clear();
-				} catch(Throwable t) {
-					runnableError.f0 = t;
-				}
+			public void go() throws Exception {
+				// this should block at first, since there are still two pending records that need to be flushed
+				testHarness.snapshot(123L, 123L);
 			}
 		};
-		Thread threadB = new Thread(confirmer);
-		threadB.start();
+		snapshotThread.start();
 
-		// this should block:
-		testHarness.snapshot(0, 0);
+		// let the 2nd message fail with an async exception
+		producer.getPendingCallbacks().get(1).onCompletion(null, new Exception("artificial async failure for 2nd message"));
+		producer.getPendingCallbacks().get(2).onCompletion(null, null);
 
-		synchronized (threadA) {
-			threadA.notifyAll(); // just in case, to let the test fail faster
-		}
-		Assert.assertEquals(0, producer.getProducerInstance().getPending().size());
-		Deadline deadline = FiniteDuration.apply(5, "s").fromNow();
-		while (deadline.hasTimeLeft() && threadB.isAlive()) {
-			threadB.join(500);
-		}
-		Assert.assertFalse("Thread A is expected to be finished at this point. If not, the test is prone to fail", threadB.isAlive());
-		if (runnableError.f0 != null) {
-			throw runnableError.f0;
+		try {
+			snapshotThread.sync();
+		} catch (Exception e) {
+			// the snapshot should have failed with the async exception
+			Assert.assertTrue(e.getCause().getMessage().contains("artificial async failure for 2nd message"));
+
+			// test succeeded
+			return;
 		}
 
+		Assert.fail();
+	}
+
+	/**
+	 * Test ensuring that the producer is not dropping buffered records;
+	 * we set a timeout because the test will not finish if the logic is broken
+	 */
+	@SuppressWarnings("unchecked")
+	@Test(timeout=10000)
+	public void testAtLeastOnceProducer() throws Throwable {
+		final DummyFlinkKafkaProducer<String> producer = new DummyFlinkKafkaProducer<>(
+			FakeStandardProducerConfig.get(), null);
+		producer.setFlushOnCheckpoint(true);
+
+		final KafkaProducer<?, ?> mockProducer = producer.getMockKafkaProducer();
+
+		final OneInputStreamOperatorTestHarness<String, Object> testHarness =
+			new OneInputStreamOperatorTestHarness<>(new StreamSink<>(producer));
+
+		testHarness.open();
+
+		testHarness.processElement(new StreamRecord<>("msg-1"));
+		testHarness.processElement(new StreamRecord<>("msg-2"));
+		testHarness.processElement(new StreamRecord<>("msg-3"));
+
+		verify(mockProducer, times(3)).send(any(ProducerRecord.class), any(Callback.class));
+		Assert.assertEquals(3, producer.getPendingSize());
+
+		// start a thread to perform checkpointing
+		CheckedThread snapshotThread = new CheckedThread() {
+			@Override
+			public void go() throws Exception {
+				// this should block until all records are flushed; if the snapshot
+				// returns before pending records are flushed, the override below fails the test
+				testHarness.snapshot(123L, 123L);
+			}
+		};
+		snapshotThread.start();
+
+		// before proceeding, make sure that flushing has started and that the snapshot is still blocked;
+		// this would block forever if the snapshot didn't perform a flush
+		producer.waitUntilFlushStarted();
+		Assert.assertTrue("Snapshot returned before all records were flushed", snapshotThread.isAlive());
+
+		// now, complete the callbacks
+		producer.getPendingCallbacks().get(0).onCompletion(null, null);
+		Assert.assertTrue("Snapshot returned before all records were flushed", snapshotThread.isAlive());
+		Assert.assertEquals(2, producer.getPendingSize());
+
+		producer.getPendingCallbacks().get(1).onCompletion(null, null);
+		Assert.assertTrue("Snapshot returned before all records were flushed", snapshotThread.isAlive());
+		Assert.assertEquals(1, producer.getPendingSize());
+
+		producer.getPendingCallbacks().get(2).onCompletion(null, null);
+		Assert.assertEquals(0, producer.getPendingSize());
+
+		// this would fail with an exception if flushing wasn't completed before the snapshot method returned
+		snapshotThread.sync();
+
 		testHarness.close();
 	}
 
+	/**
+	 * This test is meant to ensure that testAtLeastOnceProducer is valid, by verifying that if flushing is disabled,
+	 * the snapshot method does indeed finish without waiting for pending records;
+	 * we set a timeout because the test will not finish if the logic is broken
+	 */
+	@SuppressWarnings("unchecked")
+	@Test(timeout=5000)
+	public void testDoesNotWaitForPendingRecordsIfFlushingDisabled() throws Throwable {
+		final DummyFlinkKafkaProducer<String> producer = new DummyFlinkKafkaProducer<>(
+			FakeStandardProducerConfig.get(), null);
+		producer.setFlushOnCheckpoint(false);
+
+		final KafkaProducer<?, ?> mockProducer = producer.getMockKafkaProducer();
+
+		final OneInputStreamOperatorTestHarness<String, Object> testHarness =
+			new OneInputStreamOperatorTestHarness<>(new StreamSink<>(producer));
+
+		testHarness.open();
+
+		testHarness.processElement(new StreamRecord<>("msg"));
+
+		// verify that the record was sent, but deliberately leave its callback uncompleted
+		verify(mockProducer, times(1)).send(any(ProducerRecord.class), any(Callback.class));
+
+		// should return even if there are pending records
+		testHarness.snapshot(123L, 123L);
+
+		testHarness.close();
+	}
 
 	// ------------------------------------------------------------------------
 
 	private static class DummyFlinkKafkaProducer<T> extends FlinkKafkaProducerBase<T> {
 		private static final long serialVersionUID = 1L;
+		
+		private final static String DUMMY_TOPIC = "dummy-topic";
 
-		private transient MockProducer prod;
-		private AtomicBoolean snapshottingFinished;
+		private transient KafkaProducer<?, ?> mockProducer;
+		private transient List<Callback> pendingCallbacks;
+		private transient MultiShotLatch flushLatch;
+		private boolean isFlushed;
 
 		@SuppressWarnings("unchecked")
-		public DummyFlinkKafkaProducer(Properties producerConfig, KafkaPartitioner partitioner, AtomicBoolean snapshottingFinished) {
-			super("dummy-topic", (KeyedSerializationSchema< T >) mock(KeyedSerializationSchema.class), producerConfig, partitioner);
-			this.snapshottingFinished = snapshottingFinished;
-		}
+		DummyFlinkKafkaProducer(Properties producerConfig, KafkaPartitioner partitioner) {
 
-		// constructor variant for test irrelated to snapshotting
-		@SuppressWarnings("unchecked")
-		public DummyFlinkKafkaProducer(Properties producerConfig, KafkaPartitioner partitioner) {
-			super("dummy-topic", (KeyedSerializationSchema< T >) mock(KeyedSerializationSchema.class), producerConfig, partitioner);
-			this.snapshottingFinished = new AtomicBoolean(true);
-		}
+			super(DUMMY_TOPIC, (KeyedSerializationSchema<T>) mock(KeyedSerializationSchema.class), producerConfig, partitioner);
 
-		@Override
-		protected <K, V> KafkaProducer<K, V> getKafkaProducer(Properties props) {
-			this.prod = new MockProducer();
-			return this.prod;
-		}
+			this.mockProducer = mock(KafkaProducer.class);
+			when(mockProducer.send(any(ProducerRecord.class), any(Callback.class))).thenAnswer(new Answer<Object>() {
+				@Override
+				public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
+					pendingCallbacks.add(invocationOnMock.getArgumentAt(1, Callback.class));
+					return null;
+				}
+			});
 
-		@Override
-		public void snapshotState(FunctionSnapshotContext ctx) throws Exception {
-			// call the actual snapshot state
-			super.snapshotState(ctx);
-			// notify test that snapshotting has been done
-			snapshottingFinished.set(true);
+			this.pendingCallbacks = new ArrayList<>();
+			this.flushLatch = new MultiShotLatch();
 		}
 
-		@Override
-		protected void flush() {
-			this.prod.flush();
+		long getPendingSize() {
+			if (flushOnCheckpoint) {
+				return numPendingRecords();
+			} else {
+				// when flushing is disabled, the implementation does not
+				// maintain the current number of pending records to reduce
+				// the extra locking overhead required to do so
+				throw new UnsupportedOperationException("getPendingSize not supported when flushing is disabled");
+			}
 		}
 
-		public MockProducer getProducerInstance() {
-			return this.prod;
+		List<Callback> getPendingCallbacks() {
+			return pendingCallbacks;
 		}
-	}
-
-	private static class MockProducer<K, V> extends KafkaProducer<K, V> {
-		List<Callback> pendingCallbacks = new ArrayList<>();
 
-		public MockProducer() {
-			super(FakeStandardProducerConfig.get());
+		KafkaProducer<?, ?> getMockKafkaProducer() {
+			return mockProducer;
 		}
 
 		@Override
-		public Future<RecordMetadata> send(ProducerRecord<K, V> record) {
-			throw new UnsupportedOperationException("Unexpected");
-		}
+		public void snapshotState(FunctionSnapshotContext ctx) throws Exception {
+			isFlushed = false;
 
-		@Override
-		public Future<RecordMetadata> send(ProducerRecord<K, V> record, Callback callback) {
-			pendingCallbacks.add(callback);
-			return null;
+			super.snapshotState(ctx);
+
+			// if the snapshot implementation doesn't wait until all pending records are flushed, we should fail the test
+			if (flushOnCheckpoint && !isFlushed) {
+				throw new RuntimeException("Flushing is enabled; snapshots should be blocked until all pending records are flushed");
+			}
 		}
 
-		@Override
-		public List<PartitionInfo> partitionsFor(String topic) {
-			List<PartitionInfo> list = new ArrayList<>();
-			// deliberately return an out-of-order partition list
-			list.add(new PartitionInfo(topic, 3, null, null, null));
-			list.add(new PartitionInfo(topic, 1, null, null, null));
-			list.add(new PartitionInfo(topic, 0, null, null, null));
-			list.add(new PartitionInfo(topic, 2, null, null, null));
-			return list;
+		public void waitUntilFlushStarted() throws Exception {
+			flushLatch.await();
 		}
 
+		@SuppressWarnings("unchecked")
 		@Override
-		public Map<MetricName, ? extends Metric> metrics() {
-			return null;
+		protected <K, V> KafkaProducer<K, V> getKafkaProducer(Properties props) {
+			return (KafkaProducer<K, V>) mockProducer;
 		}
 
+		@Override
+		protected void flush() {
+			flushLatch.trigger();
 
-		public List<Callback> getPending() {
-			return this.pendingCallbacks;
-		}
-
-		public void flush() {
-			while (pendingCallbacks.size() > 0) {
+			// simply wait until the producer's pending records become zero.
+			// This relies on the producer's Callback implementation and its
+			// pending-records tracking logic being implemented correctly; otherwise,
+			// we will loop forever.
+			while (numPendingRecords() > 0) {
 				try {
 					Thread.sleep(10);
 				} catch (InterruptedException e) {
 					throw new RuntimeException("Unable to flush producer, task was interrupted");
 				}
 			}
+
+			isFlushed = true;
 		}
 	}
 }
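
Both blocking-snapshot tests above drive the checkpoint from Flink's CheckedThread test
utility, which is imported but not shown here: a Thread whose go() body may throw and
whose sync() joins the thread and rethrows whatever it threw. A simplified illustration
of that contract (the real class lives in org.apache.flink.core.testutils; this is a
sketch, not its exact source):

    public abstract class CheckedThread extends Thread {

        private volatile Throwable error;

        // the test body; any exception it throws is captured
        public abstract void go() throws Exception;

        @Override
        public final void run() {
            try {
                go();
            } catch (Throwable t) {
                error = t;
            }
        }

        // waits for completion and rethrows the captured error, if any
        public void sync() throws Exception {
            join();
            if (error instanceof Exception) {
                throw (Exception) error;
            } else if (error != null) {
                throw new Exception(error);
            }
        }
    }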