You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2018/10/31 22:51:13 UTC

[kafka] branch trunk updated: MINOR: Fix a few blocking calls in PlaintextConsumerTest (#5859)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 8065a0b  MINOR: Fix a few blocking calls in PlaintextConsumerTest (#5859)
8065a0b is described below

commit 8065a0bef41b6f3d0f8b0fcc01a008d7916ac5c6
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Wed Oct 31 15:51:02 2018 -0700

    MINOR: Fix a few blocking calls in PlaintextConsumerTest (#5859)
    
    We've been seeing some hanging builds recently (see KAFKA-7553). The culprit consistently appears to be a test case in PlaintextConsumerTest. This patch doesn't fix the underlying issue, but it eliminates a few places where these test cases could block:
    
    1. It replaces several calls to the deprecated `poll(long)`, which in the worst case can block indefinitely while joining the group, with `poll(Duration)`, which respects the timeout.
    2. It also fixes a consume utility in `TestUtils` whose worst-case blocking time grew with the number of records it was expected to consume.
    
    Reviewers: Ismael Juma <is...@juma.me.uk>, Colin Patrick McCabe <co...@cmccabe.xyz>
---
 .../integration/kafka/api/BaseConsumerTest.scala   |  51 ++--
 .../kafka/api/PlaintextConsumerTest.scala          | 270 +++++++--------------
 .../integration/kafka/api/TransactionsTest.scala   |   2 +-
 .../test/scala/unit/kafka/utils/TestUtils.scala    |  74 +++---
 4 files changed, 169 insertions(+), 228 deletions(-)

diff --git a/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala b/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
index 488874d..3e67b18 100644
--- a/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala
@@ -12,13 +12,14 @@
  */
 package kafka.api
 
+import java.time.Duration
 import java.util
 
 import org.apache.kafka.clients.consumer._
 import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord}
 import org.apache.kafka.common.record.TimestampType
 import org.apache.kafka.common.{PartitionInfo, TopicPartition}
-import kafka.utils.ShutdownableThread
+import kafka.utils.{ShutdownableThread, TestUtils}
 import kafka.server.KafkaConfig
 import org.junit.Assert._
 import org.junit.{Before, Test}
@@ -98,8 +99,7 @@ abstract class BaseConsumerTest extends IntegrationTestHarness {
     consumer.subscribe(List(topic).asJava, listener)
 
     // the initial subscription should cause a callback execution
-    consumer.poll(2000)
-
+    awaitRebalance(consumer, listener)
     assertEquals(1, listener.callsToAssigned)
 
     // get metadata for the topic
@@ -113,11 +113,8 @@ abstract class BaseConsumerTest extends IntegrationTestHarness {
     val coordinator = parts.head.leader().id()
     this.servers(coordinator).shutdown()
 
-    consumer.poll(5000)
-
     // the failover should not cause a rebalance
-    assertEquals(1, listener.callsToAssigned)
-    assertEquals(1, listener.callsToRevoked)
+    ensureNoRebalance(consumer, listener)
   }
 
   protected class TestConsumerReassignmentListener extends ConsumerRebalanceListener {
@@ -183,29 +180,41 @@ abstract class BaseConsumerTest extends IntegrationTestHarness {
                                      numRecords: Int,
                                      maxPollRecords: Int = Int.MaxValue): ArrayBuffer[ConsumerRecord[K, V]] = {
     val records = new ArrayBuffer[ConsumerRecord[K, V]]
-    val maxIters = numRecords * 300
-    var iters = 0
-    while (records.size < numRecords) {
-      val polledRecords = consumer.poll(50).asScala
-      assertTrue(polledRecords.size <= maxPollRecords)
-      for (record <- polledRecords)
-        records += record
-      if (iters > maxIters)
-        throw new IllegalStateException("Failed to consume the expected records after " + iters + " iterations.")
-      iters += 1
+    def pollAction(polledRecords: ConsumerRecords[K, V]): Boolean = {
+      assertTrue(polledRecords.asScala.size <= maxPollRecords)
+      records ++= polledRecords.asScala
+      records.size >= numRecords
     }
+    TestUtils.pollRecordsUntilTrue(consumer, pollAction, waitTimeMs = 60000,
+      msg = s"Timed out before consuming expected $numRecords records. " +
+        s"The number consumed was ${records.size}.")
     records
   }
 
   protected def awaitCommitCallback[K, V](consumer: Consumer[K, V],
                                           commitCallback: CountConsumerCommitCallback,
                                           count: Int = 1): Unit = {
-    val started = System.currentTimeMillis()
-    while (commitCallback.successCount < count && System.currentTimeMillis() - started < 10000)
-      consumer.poll(50)
+    TestUtils.pollUntilTrue(consumer, () => commitCallback.successCount >= count,
+      "Failed to observe commit callback before timeout", waitTimeMs = 10000)
     assertEquals(count, commitCallback.successCount)
   }
 
+  protected def awaitRebalance(consumer: Consumer[_, _], rebalanceListener: TestConsumerReassignmentListener): Unit = {
+    val numReassignments = rebalanceListener.callsToAssigned
+    TestUtils.pollUntilTrue(consumer, () => rebalanceListener.callsToAssigned > numReassignments,
+      "Timed out before expected rebalance completed")
+  }
+
+  protected def ensureNoRebalance(consumer: Consumer[_, _], rebalanceListener: TestConsumerReassignmentListener): Unit = {
+    // The best way to verify that the current membership is still active is to commit offsets.
+    // This would fail if the group had rebalanced.
+    val initialRevokeCalls = rebalanceListener.callsToRevoked
+    val commitCallback = new CountConsumerCommitCallback
+    consumer.commitAsync(commitCallback)
+    awaitCommitCallback(consumer, commitCallback)
+    assertEquals(initialRevokeCalls, rebalanceListener.callsToRevoked)
+  }
+
   protected class CountConsumerCommitCallback extends OffsetCommitCallback {
     var successCount = 0
     var failCount = 0
@@ -274,7 +283,7 @@ abstract class BaseConsumerTest extends IntegrationTestHarness {
         subscriptionChanged = false
       }
       try {
-        consumer.poll(50)
+        consumer.poll(Duration.ofMillis(50))
       } catch {
         case _: WakeupException => // ignore for shutdown
       }
diff --git a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
index a23513f..0c63775 100644
--- a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
@@ -12,6 +12,7 @@
   */
 package kafka.api
 
+import java.time.Duration
 import java.util
 import java.util.regex.Pattern
 import java.util.{Collections, Locale, Optional, Properties}
@@ -126,6 +127,14 @@ class PlaintextConsumerTest extends BaseConsumerTest {
   }
 
   @Test
+  def testDeprecatedPollBlocksForAssignment(): Unit = {
+    val consumer = createConsumer()
+    consumer.subscribe(Set(topic).asJava)
+    consumer.poll(0)
+    assertEquals(Set(tp, tp2), consumer.assignment().asScala)
+  }
+
+  @Test
   def testHeadersExtendedSerializerDeserializer(): Unit = {
     val extendedSerializer = new ExtendedSerializer[Array[Byte]] with SerializerImpl
 
@@ -168,15 +177,15 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     val listener = new TestConsumerReassignmentListener()
     consumer.subscribe(List(topic).asJava, listener)
 
-    // poll once to get the initial assignment
-    consumer.poll(0)
+    // rebalance to get the initial assignment
+    awaitRebalance(consumer, listener)
     assertEquals(1, listener.callsToAssigned)
     assertEquals(1, listener.callsToRevoked)
 
     Thread.sleep(3500)
 
     // we should fall out of the group and need to rebalance
-    consumer.poll(0)
+    awaitRebalance(consumer, listener)
     assertEquals(2, listener.callsToAssigned)
     assertEquals(2, listener.callsToRevoked)
   }
@@ -209,12 +218,12 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     consumer.subscribe(List(topic).asJava, listener)
 
-    // poll once to join the group and get the initial assignment
-    consumer.poll(0)
+    // rebalance to get the initial assignment
+    awaitRebalance(consumer, listener)
 
     // force a rebalance to trigger an invocation of the revocation callback while in the group
     consumer.subscribe(List("otherTopic").asJava, listener)
-    consumer.poll(0)
+    awaitRebalance(consumer, listener)
 
     assertEquals(0, committedPosition)
     assertTrue(commitCompleted)
@@ -237,14 +246,11 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     }
     consumer.subscribe(List(topic).asJava, listener)
 
-    // poll once to join the group and get the initial assignment
-    consumer.poll(0)
+    // rebalance to get the initial assignment
+    awaitRebalance(consumer, listener)
 
-    // we should still be in the group after this invocation
-    consumer.poll(0)
-
-    assertEquals(1, listener.callsToAssigned)
-    assertEquals(1, listener.callsToRevoked)
+    // We should still be in the group after this invocation
+    ensureNoRebalance(consumer, listener)
   }
 
   @Test
@@ -257,12 +263,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     sendRecords(producer, numRecords, tp)
 
     consumer.subscribe(List(topic).asJava)
-
-    val assignment = Set(tp, tp2)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == assignment.asJava
-    }, s"Expected partitions ${assignment.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, Set(tp, tp2))
 
     // should auto-commit seeked positions before closing
     consumer.seek(tp, 300)
@@ -285,12 +286,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     sendRecords(producer, numRecords, tp)
 
     consumer.subscribe(List(topic).asJava)
-
-    val assignment = Set(tp, tp2)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == assignment.asJava
-    }, s"Expected partitions ${assignment.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, Set(tp, tp2))
 
     // should auto-commit seeked positions before closing
     consumer.seek(tp, 300)
@@ -362,32 +358,23 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     val pattern = Pattern.compile("t.*c")
     consumer.subscribe(pattern, new TestConsumerReassignmentListener)
-    consumer.poll(50)
 
-    var subscriptions = Set(
+    var assignment = Set(
       new TopicPartition(topic, 0),
       new TopicPartition(topic, 1),
       new TopicPartition(topic1, 0),
       new TopicPartition(topic1, 1))
-
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == subscriptions.asJava
-    }, s"Expected partitions ${subscriptions.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, assignment)
 
     val topic4 = "tsomec" // matches subscribed pattern
     createTopic(topic4, 2, serverCount)
     sendRecords(producer, numRecords = 1000, new TopicPartition(topic4, 0))
     sendRecords(producer, numRecords = 1000, new TopicPartition(topic4, 1))
 
-    subscriptions ++= Set(
+    assignment ++= Set(
       new TopicPartition(topic4, 0),
       new TopicPartition(topic4, 1))
-
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == subscriptions.asJava
-    }, s"Expected partitions ${subscriptions.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, assignment)
 
     consumer.unsubscribe()
     assertEquals(0, consumer.assignment().size)
@@ -421,17 +408,12 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     val pattern1 = Pattern.compile(".*o.*") // only 'topic' and 'foo' match this
     consumer.subscribe(pattern1, new TestConsumerReassignmentListener)
-    consumer.poll(50)
 
-    var subscriptions = Set(
+    var assignment = Set(
       new TopicPartition(topic, 0),
       new TopicPartition(topic, 1),
       new TopicPartition(fooTopic, 0))
-
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == subscriptions.asJava
-    }, s"Expected partitions ${subscriptions.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, assignment)
 
     val barTopic = "bar" // matches the next subscription pattern
     createTopic(barTopic, 1, serverCount)
@@ -439,19 +421,12 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     val pattern2 = Pattern.compile("...") // only 'foo' and 'bar' match this
     consumer.subscribe(pattern2, new TestConsumerReassignmentListener)
-    consumer.poll(50)
-
-    subscriptions --= Set(
+    assignment --= Set(
       new TopicPartition(topic, 0),
       new TopicPartition(topic, 1))
-
-    subscriptions ++= Set(
+    assignment ++= Set(
       new TopicPartition(barTopic, 0))
-
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == subscriptions.asJava
-    }, s"Expected partitions ${subscriptions.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, assignment)
 
     consumer.unsubscribe()
     assertEquals(0, consumer.assignment().size)
@@ -480,18 +455,12 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     assertEquals(0, consumer.assignment().size)
 
     consumer.subscribe(Pattern.compile("t.*c"), new TestConsumerReassignmentListener)
-    consumer.poll(50)
-
-    val subscriptions = Set(
+    val assignment = Set(
       new TopicPartition(topic, 0),
       new TopicPartition(topic, 1),
       new TopicPartition(topic1, 0),
       new TopicPartition(topic1, 1))
-
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == subscriptions.asJava
-    }, s"Expected partitions ${subscriptions.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, assignment)
 
     consumer.unsubscribe()
     assertEquals(0, consumer.assignment().size)
@@ -524,7 +493,6 @@ class PlaintextConsumerTest extends BaseConsumerTest {
   def testAsyncCommit() {
     val consumer = createConsumer()
     consumer.assign(List(tp).asJava)
-    consumer.poll(0)
 
     val callback = new CountConsumerCommitCallback
     val count = 5
@@ -538,41 +506,29 @@ class PlaintextConsumerTest extends BaseConsumerTest {
   @Test
   def testExpandingTopicSubscriptions() {
     val otherTopic = "other"
-    val subscriptions = Set(new TopicPartition(topic, 0), new TopicPartition(topic, 1))
-    val expandedSubscriptions = subscriptions ++ Set(new TopicPartition(otherTopic, 0), new TopicPartition(otherTopic, 1))
+    val initialAssignment = Set(new TopicPartition(topic, 0), new TopicPartition(topic, 1))
     val consumer = createConsumer()
     consumer.subscribe(List(topic).asJava)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment == subscriptions.asJava
-    }, s"Expected partitions ${subscriptions.asJava} but actually got ${consumer.assignment}")
+    awaitAssignment(consumer, initialAssignment)
 
     createTopic(otherTopic, 2, serverCount)
+    val expandedAssignment = initialAssignment ++ Set(new TopicPartition(otherTopic, 0), new TopicPartition(otherTopic, 1))
     consumer.subscribe(List(topic, otherTopic).asJava)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment == expandedSubscriptions.asJava
-    }, s"Expected partitions ${expandedSubscriptions.asJava} but actually got ${consumer.assignment}")
+    awaitAssignment(consumer, expandedAssignment)
   }
 
   @Test
   def testShrinkingTopicSubscriptions() {
     val otherTopic = "other"
     createTopic(otherTopic, 2, serverCount)
-    val subscriptions = Set(new TopicPartition(topic, 0), new TopicPartition(topic, 1), new TopicPartition(otherTopic, 0), new TopicPartition(otherTopic, 1))
-    val shrunkenSubscriptions = Set(new TopicPartition(topic, 0), new TopicPartition(topic, 1))
+    val initialAssignment = Set(new TopicPartition(topic, 0), new TopicPartition(topic, 1), new TopicPartition(otherTopic, 0), new TopicPartition(otherTopic, 1))
     val consumer = createConsumer()
     consumer.subscribe(List(topic, otherTopic).asJava)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment == subscriptions.asJava
-    }, s"Expected partitions ${subscriptions.asJava} but actually got ${consumer.assignment}")
+    awaitAssignment(consumer, initialAssignment)
 
+    val shrunkenAssignment = Set(new TopicPartition(topic, 0), new TopicPartition(topic, 1))
     consumer.subscribe(List(topic).asJava)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment == shrunkenSubscriptions.asJava
-    }, s"Expected partitions ${shrunkenSubscriptions.asJava} but actually got ${consumer.assignment}")
+    awaitAssignment(consumer, shrunkenAssignment)
   }
 
   @Test
@@ -611,7 +567,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     consumer.seekToEnd(List(tp).asJava)
     assertEquals(totalRecords, consumer.position(tp))
-    assertFalse(consumer.poll(totalRecords).iterator().hasNext)
+    assertTrue(consumer.poll(Duration.ofMillis(50)).isEmpty)
 
     consumer.seekToBeginning(List(tp).asJava)
     assertEquals(0, consumer.position(tp), 0)
@@ -629,7 +585,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     consumer.seekToEnd(List(tp2).asJava)
     assertEquals(totalRecords, consumer.position(tp2))
-    assertFalse(consumer.poll(totalRecords).iterator().hasNext)
+    assertTrue(consumer.poll(Duration.ofMillis(50)).isEmpty)
 
     consumer.seekToBeginning(List(tp2).asJava)
     assertEquals(0, consumer.position(tp2), 0)
@@ -695,7 +651,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumeAndVerifyRecords(consumer = consumer, numRecords = 5, startingOffset = 0)
     consumer.pause(partitions)
     sendRecords(producer, numRecords = 5, tp)
-    assertTrue(consumer.poll(0).isEmpty)
+    assertTrue(consumer.poll(Duration.ofMillis(100)).isEmpty)
     consumer.resume(partitions)
     consumeAndVerifyRecords(consumer = consumer, numRecords = 5, startingOffset = 5)
   }
@@ -713,14 +669,14 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     // poll should fail because there is no offset reset strategy set
     intercept[NoOffsetForPartitionException] {
-      consumer.poll(50)
+      consumer.poll(Duration.ofMillis(50))
     }
 
     // seek to out of range position
     val outOfRangePos = totalRecords + 1
     consumer.seek(tp, outOfRangePos)
     val e = intercept[OffsetOutOfRangeException] {
-      consumer.poll(20000)
+      consumer.poll(Duration.ofMillis(20000))
     }
     val outOfRangePartitions = e.offsetOutOfRangePartitions()
     assertNotNull(outOfRangePartitions)
@@ -746,7 +702,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     // consuming a record that is too large should succeed since KIP-74
     consumer.assign(List(tp).asJava)
-    val records = consumer.poll(20000)
+    val records = consumer.poll(Duration.ofMillis(20000))
     assertEquals(1, records.count)
     val consumerRecord = records.iterator().next()
     assertEquals(0L, consumerRecord.offset)
@@ -778,7 +734,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     // we should only get the small record in the first `poll`
     consumer.assign(List(tp).asJava)
-    val records = consumer.poll(20000)
+    val records = consumer.poll(Duration.ofMillis(20000))
     assertEquals(1, records.count)
     val consumerRecord = records.iterator().next()
     assertEquals(0L, consumerRecord.offset)
@@ -830,10 +786,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     consumer.subscribe(List(topic1, topic2, topic3).asJava)
 
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == partitions.toSet.asJava
-    }, s"Expected partitions ${partitions.asJava} but actually got ${consumer.assignment}")
+    awaitAssignment(consumer, partitions.toSet)
 
     val producer = createProducer()
     val producerRecords = partitions.flatMap(sendRecords(producer, partitionCount, _))
@@ -868,10 +821,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     // subscribe to two topics
     consumer.subscribe(List(topic1, topic2).asJava)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == expectedAssignment.asJava
-    }, s"Expected partitions ${expectedAssignment.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, expectedAssignment)
 
     // add one more topic with 2 partitions
     val topic3 = "topic3"
@@ -879,17 +829,11 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     val newExpectedAssignment = expectedAssignment ++ Set(new TopicPartition(topic3, 0), new TopicPartition(topic3, 1))
     consumer.subscribe(List(topic1, topic2, topic3).asJava)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == newExpectedAssignment.asJava
-    }, s"Expected partitions ${newExpectedAssignment.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, newExpectedAssignment)
 
     // remove the topic we just added
     consumer.subscribe(List(topic1, topic2).asJava)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == expectedAssignment.asJava
-    }, s"Expected partitions ${expectedAssignment.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, expectedAssignment)
 
     consumer.unsubscribe()
     assertEquals(0, consumer.assignment().size)
@@ -1321,8 +1265,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumer.subscribe(List(topic).asJava, listener)
 
     // the initial subscription should cause a callback execution
-    while (listener.callsToAssigned == 0)
-      consumer.poll(50)
+    awaitRebalance(consumer, listener)
 
     consumer.subscribe(List[String]().asJava)
     assertEquals(0, consumer.assignment.size())
@@ -1357,8 +1300,6 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     val consumer = createConsumer()
     consumer.assign(List(tp, tp2).asJava)
 
-    // Need to poll to join the group
-    consumer.poll(50)
     val pos1 = consumer.position(tp)
     val pos2 = consumer.position(tp2)
     consumer.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp, new OffsetAndMetadata(3L))).asJava)
@@ -1402,11 +1343,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
     consumer.subscribe(List(topic).asJava, rebalanceListener)
 
-    val assignment = Set(tp, tp2)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == assignment.asJava
-    }, s"Expected partitions ${assignment.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, Set(tp, tp2))
 
     consumer.seek(tp, 300)
     consumer.seek(tp2, 500)
@@ -1415,10 +1352,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumer.subscribe(List(topic, topic2).asJava, rebalanceListener)
 
     val newAssignment = Set(tp, tp2, new TopicPartition(topic2, 0), new TopicPartition(topic2, 1))
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == newAssignment.asJava
-    }, s"Expected partitions ${newAssignment.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, newAssignment)
 
     // after rebalancing, we should have reset to the committed positions
     assertEquals(300, consumer.committed(tp).offset)
@@ -1438,14 +1372,10 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "testPerPartitionLeadMetricsCleanUpWithSubscribe")
     consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLeadMetricsCleanUpWithSubscribe")
     val consumer = createConsumer()
-    val listener0 = new TestConsumerReassignmentListener
-    consumer.subscribe(List(topic, topic2).asJava, listener0)
-    var records: ConsumerRecords[Array[Byte], Array[Byte]] = ConsumerRecords.empty()
-    TestUtils.waitUntilTrue(() => {
-      records = consumer.poll(100)
-      !records.records(tp).isEmpty
-    }, "Consumer did not consume any message before timeout.")
-    assertEquals("should be assigned once", 1, listener0.callsToAssigned)
+    val listener = new TestConsumerReassignmentListener
+    consumer.subscribe(List(topic, topic2).asJava, listener)
+    val records = awaitNonEmptyRecords(consumer, tp)
+    assertEquals("should be assigned once", 1, listener.callsToAssigned)
     // Verify the metric exist.
     val tags1 = new util.HashMap[String, String]()
     tags1.put("client-id", "testPerPartitionLeadMetricsCleanUpWithSubscribe")
@@ -1458,14 +1388,11 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     tags2.put("partition", String.valueOf(tp2.partition()))
     val fetchLead0 = consumer.metrics.get(new MetricName("records-lead", "consumer-fetch-manager-metrics", "", tags1))
     assertNotNull(fetchLead0)
-    assertTrue(s"The lead should be ${records.count}", fetchLead0.metricValue() == records.count)
+    assertEquals(s"The lead should be ${records.count}", records.count.toDouble, fetchLead0.metricValue())
 
     // Remove topic from subscription
-    consumer.subscribe(List(topic2).asJava, listener0)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(100)
-      listener0.callsToAssigned >= 2
-    }, "Expected rebalance did not occur.")
+    consumer.subscribe(List(topic2).asJava, listener)
+    awaitRebalance(consumer, listener)
     // Verify the metric has gone
     assertNull(consumer.metrics.get(new MetricName("records-lead", "consumer-fetch-manager-metrics", "", tags1)))
     assertNull(consumer.metrics.get(new MetricName("records-lead", "consumer-fetch-manager-metrics", "", tags2)))
@@ -1484,14 +1411,10 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "testPerPartitionLagMetricsCleanUpWithSubscribe")
     consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLagMetricsCleanUpWithSubscribe")
     val consumer = createConsumer()
-    val listener0 = new TestConsumerReassignmentListener
-    consumer.subscribe(List(topic, topic2).asJava, listener0)
-    var records: ConsumerRecords[Array[Byte], Array[Byte]] = ConsumerRecords.empty()
-    TestUtils.waitUntilTrue(() => {
-      records = consumer.poll(100)
-      !records.records(tp).isEmpty
-    }, "Consumer did not consume any message before timeout.")
-    assertEquals("should be assigned once", 1, listener0.callsToAssigned)
+    val listener = new TestConsumerReassignmentListener
+    consumer.subscribe(List(topic, topic2).asJava, listener)
+    val records = awaitNonEmptyRecords(consumer, tp)
+    assertEquals("should be assigned once", 1, listener.callsToAssigned)
     // Verify the metric exist.
     val tags1 = new util.HashMap[String, String]()
     tags1.put("client-id", "testPerPartitionLagMetricsCleanUpWithSubscribe")
@@ -1508,11 +1431,8 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     assertEquals(s"The lag should be $expectedLag", expectedLag, fetchLag0.metricValue.asInstanceOf[Double], epsilon)
 
     // Remove topic from subscription
-    consumer.subscribe(List(topic2).asJava, listener0)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(100)
-      listener0.callsToAssigned >= 2
-    }, "Expected rebalance did not occur.")
+    consumer.subscribe(List(topic2).asJava, listener)
+    awaitRebalance(consumer, listener)
     // Verify the metric has gone
     assertNull(consumer.metrics.get(new MetricName("records-lag", "consumer-fetch-manager-metrics", "", tags1)))
     assertNull(consumer.metrics.get(new MetricName("records-lag", "consumer-fetch-manager-metrics", "", tags2)))
@@ -1531,11 +1451,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLeadMetricsCleanUpWithAssign")
     val consumer = createConsumer()
     consumer.assign(List(tp).asJava)
-    var records: ConsumerRecords[Array[Byte], Array[Byte]] = ConsumerRecords.empty()
-    TestUtils.waitUntilTrue(() => {
-      records = consumer.poll(100)
-      !records.records(tp).isEmpty
-    }, "Consumer did not consume any message before timeout.")
+    val records = awaitNonEmptyRecords(consumer, tp)
     // Verify the metric exist.
     val tags = new util.HashMap[String, String]()
     tags.put("client-id", "testPerPartitionLeadMetricsCleanUpWithAssign")
@@ -1547,7 +1463,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     assertTrue(s"The lead should be ${records.count}", records.count == fetchLead.metricValue())
 
     consumer.assign(List(tp2).asJava)
-    TestUtils.waitUntilTrue(() => !consumer.poll(100).isEmpty, "Consumer did not consume any message before timeout.")
+    awaitNonEmptyRecords(consumer ,tp2)
     assertNull(consumer.metrics.get(new MetricName("records-lead", "consumer-fetch-manager-metrics", "", tags)))
   }
 
@@ -1564,11 +1480,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLagMetricsCleanUpWithAssign")
     val consumer = createConsumer()
     consumer.assign(List(tp).asJava)
-    var records: ConsumerRecords[Array[Byte], Array[Byte]] = ConsumerRecords.empty()
-    TestUtils.waitUntilTrue(() => {
-      records = consumer.poll(100)
-      !records.records(tp).isEmpty
-    }, "Consumer did not consume any message before timeout.")
+    val records = awaitNonEmptyRecords(consumer, tp)
     // Verify the metric exist.
     val tags = new util.HashMap[String, String]()
     tags.put("client-id", "testPerPartitionLagMetricsCleanUpWithAssign")
@@ -1581,7 +1493,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     assertEquals(s"The lag should be $expectedLag", expectedLag, fetchLag.metricValue.asInstanceOf[Double], epsilon)
 
     consumer.assign(List(tp2).asJava)
-    TestUtils.waitUntilTrue(() => !consumer.poll(100).isEmpty, "Consumer did not consume any message before timeout.")
+    awaitNonEmptyRecords(consumer, tp2)
     assertNull(consumer.metrics.get(new MetricName(tp + ".records-lag", "consumer-fetch-manager-metrics", "", tags)))
     assertNull(consumer.metrics.get(new MetricName("records-lag", "consumer-fetch-manager-metrics", "", tags)))
   }
@@ -1599,11 +1511,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLagMetricsCleanUpWithAssign")
     val consumer = createConsumer()
     consumer.assign(List(tp).asJava)
-    var records: ConsumerRecords[Array[Byte], Array[Byte]] = ConsumerRecords.empty()
-    TestUtils.waitUntilTrue(() => {
-      records = consumer.poll(100)
-      !records.records(tp).isEmpty
-    }, "Consumer did not consume any message before timeout.")
+    val records = awaitNonEmptyRecords(consumer, tp)
     // Verify the metric exist.
     val tags = new util.HashMap[String, String]()
     tags.put("client-id", "testPerPartitionLagMetricsCleanUpWithAssign")
@@ -1625,11 +1533,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumerConfig.setProperty(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, maxPollRecords.toString)
     val consumer = createConsumer()
     consumer.assign(List(tp).asJava)
-    var records: ConsumerRecords[Array[Byte], Array[Byte]] = ConsumerRecords.empty()
-    TestUtils.waitUntilTrue(() => {
-      records = consumer.poll(100)
-      !records.isEmpty
-    }, "Consumer did not consume any message before timeout.")
+    awaitNonEmptyRecords(consumer, tp)
 
     val tags = new util.HashMap[String, String]()
     tags.put("client-id", "testPerPartitionLeadWithMaxPollRecords")
@@ -1651,11 +1555,7 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     consumerConfig.setProperty(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, maxPollRecords.toString)
     val consumer = createConsumer()
     consumer.assign(List(tp).asJava)
-    var records: ConsumerRecords[Array[Byte], Array[Byte]] = ConsumerRecords.empty()
-    TestUtils.waitUntilTrue(() => {
-      records = consumer.poll(100)
-      !records.isEmpty
-    }, "Consumer did not consume any message before timeout.")
+    val records = awaitNonEmptyRecords(consumer, tp)
 
     val tags = new util.HashMap[String, String]()
     tags.put("client-id", "testPerPartitionLagWithMaxPollRecords")
@@ -1881,13 +1781,25 @@ class PlaintextConsumerTest extends BaseConsumerTest {
 
   def changeConsumerSubscriptionAndValidateAssignment[K, V](consumer: Consumer[K, V],
                                                             topicsToSubscribe: List[String],
-                                                            subscriptions: Set[TopicPartition],
+                                                            expectedAssignment: Set[TopicPartition],
                                                             rebalanceListener: ConsumerRebalanceListener): Unit = {
     consumer.subscribe(topicsToSubscribe.asJava, rebalanceListener)
-    TestUtils.waitUntilTrue(() => {
-      consumer.poll(50)
-      consumer.assignment() == subscriptions.asJava
-    }, s"Expected partitions ${subscriptions.asJava} but actually got ${consumer.assignment()}")
+    awaitAssignment(consumer, expectedAssignment)
+  }
+
+  private def awaitNonEmptyRecords[K, V](consumer: Consumer[K, V], partition: TopicPartition): ConsumerRecords[K, V] = {
+    TestUtils.pollRecordsUntilTrue(consumer, (polledRecords: ConsumerRecords[K, V]) => {
+      if (polledRecords.records(partition).asScala.nonEmpty)
+        return polledRecords
+      false
+    }, s"Consumer did not consume any messages for partition $partition before timeout.")
+    throw new IllegalStateException("Should have timed out before reaching here")
+  }
+
+  private def awaitAssignment(consumer: Consumer[_, _], expectedAssignment: Set[TopicPartition]): Unit = {
+    TestUtils.pollUntilTrue(consumer, () => consumer.assignment() == expectedAssignment.asJava,
+      s"Timed out while awaiting expected assignment $expectedAssignment. " +
+        s"The current assignment is ${consumer.assignment()}")
   }
 
 }
diff --git a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
index ab14db4..e3b447e 100644
--- a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
+++ b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
@@ -209,7 +209,7 @@ class TransactionsTest extends KafkaServerTestHarness {
     val readCommittedConsumer = createReadCommittedConsumer(props = consumerProps)
 
     readCommittedConsumer.assign(Set(new TopicPartition(topic1, 0)).asJava)
-    val records = consumeRecords(readCommittedConsumer, numMessages = 2)
+    val records = consumeRecords(readCommittedConsumer, numRecords = 2)
     assertEquals(2, records.size)
 
     val first = records.head
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index 3949543..47d45ad 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -720,32 +720,51 @@ object TestUtils extends Logging {
     }
   }
 
+  def pollUntilTrue(consumer: Consumer[_, _],
+                    action: () => Boolean,
+                    msg: => String,
+                    waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Unit = {
+    waitUntilTrue(() => {
+      consumer.poll(Duration.ofMillis(50))
+      action()
+    }, msg = msg, pause = 0L, waitTimeMs = waitTimeMs)
+  }
+
+  def pollRecordsUntilTrue[K, V](consumer: Consumer[K, V],
+                                 action: ConsumerRecords[K, V] => Boolean,
+                                 msg: => String,
+                                 waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Unit = {
+    waitUntilTrue(() => {
+      val records = consumer.poll(Duration.ofMillis(50))
+      action(records)
+    }, msg = msg, pause = 0L, waitTimeMs = waitTimeMs)
+  }
+
   /**
     *  Wait until the given condition is true or throw an exception if the given wait time elapses.
     *
     * @param condition condition to check
     * @param msg error message
-    * @param waitTime maximum time to wait and retest the condition before failing the test
+    * @param waitTimeMs maximum time to wait and retest the condition before failing the test
     * @param pause delay between condition checks
     * @param maxRetries maximum number of retries to check the given condition if a retriable exception is thrown
     */
   def waitUntilTrue(condition: () => Boolean, msg: => String,
-                    waitTime: Long = JTestUtils.DEFAULT_MAX_WAIT_MS, pause: Long = 100L, maxRetries: Int = 0): Unit = {
+                    waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS, pause: Long = 100L, maxRetries: Int = 0): Unit = {
     val startTime = System.currentTimeMillis()
     var retry = 0
     while (true) {
       try {
         if (condition())
           return
-        if (System.currentTimeMillis() > startTime + waitTime)
+        if (System.currentTimeMillis() > startTime + waitTimeMs)
           fail(msg)
-        Thread.sleep(waitTime.min(pause))
+        Thread.sleep(waitTimeMs.min(pause))
       }
       catch {
-        case e: RetriableException if retry < maxRetries => {
+        case e: RetriableException if retry < maxRetries =>
           debug("Retrying after error", e)
           retry += 1
-        }
         case e : Throwable => throw e
       }
     }
@@ -840,7 +859,7 @@ object TestUtils extends Logging {
           }
       },
       "Partition [%s,%d] metadata not propagated after %d ms".format(topic, partition, timeout),
-      waitTime = timeout)
+      waitTimeMs = timeout)
 
     leader
   }
@@ -862,7 +881,7 @@ object TestUtils extends Logging {
     }
 
     TestUtils.waitUntilTrue(() => newLeaderExists.isDefined,
-      s"Did not observe leader change for partition $tp after $timeout ms", waitTime = timeout)
+      s"Did not observe leader change for partition $tp after $timeout ms", waitTimeMs = timeout)
 
     newLeaderExists.get
   }
@@ -877,7 +896,7 @@ object TestUtils extends Logging {
     }
 
     TestUtils.waitUntilTrue(() => leaderIfExists.isDefined,
-      s"Partition $tp leaders not made yet after $timeout ms", waitTime = timeout)
+      s"Partition $tp leaders not made yet after $timeout ms", waitTimeMs = timeout)
 
     leaderIfExists.get
   }
@@ -1086,7 +1105,7 @@ object TestUtils extends Logging {
 
     TestUtils.waitUntilTrue(() => authorizer.getAcls(resource) == expected,
       s"expected acls:${expected.mkString(newLine + "\t", newLine + "\t", newLine)}" +
-        s"but got:${authorizer.getAcls(resource).mkString(newLine + "\t", newLine + "\t", newLine)}", waitTime = JTestUtils.DEFAULT_MAX_WAIT_MS)
+        s"but got:${authorizer.getAcls(resource).mkString(newLine + "\t", newLine + "\t", newLine)}", waitTimeMs = JTestUtils.DEFAULT_MAX_WAIT_MS)
   }
 
   /**
@@ -1192,7 +1211,6 @@ object TestUtils extends Logging {
       threadPool.shutdownNow()
     }
     assertTrue(s"$message failed with exception(s) $exceptions", exceptions.isEmpty)
-
   }
 
   def consumeTopicRecords[K, V](servers: Seq[KafkaServer],
@@ -1212,14 +1230,25 @@ object TestUtils extends Logging {
     } finally consumer.close()
   }
 
-  def consumeRecords[K, V](consumer: KafkaConsumer[K, V], numMessages: Int,
-                           waitTime: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Seq[ConsumerRecord[K, V]] = {
+  def pollUntilAtLeastNumRecords[K, V](consumer: KafkaConsumer[K, V],
+                                       numRecords: Int,
+                                       waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Seq[ConsumerRecord[K, V]] = {
     val records = new ArrayBuffer[ConsumerRecord[K, V]]()
-    waitUntilTrue(() => {
-      records ++= consumer.poll(Duration.ofMillis(50)).asScala
-      records.size >= numMessages
-    }, s"Consumed ${records.size} records until timeout instead of the expected $numMessages records", waitTime)
-    assertEquals("Consumed more records than expected", numMessages, records.size)
+    def pollAction(polledRecords: ConsumerRecords[K, V]): Boolean = {
+      records ++= polledRecords.asScala
+      records.size >= numRecords
+    }
+    pollRecordsUntilTrue(consumer, pollAction,
+      waitTimeMs = waitTimeMs,
+      msg = s"Consumed ${records.size} records before timeout instead of the expected $numRecords records")
+    records
+  }
+
+  def consumeRecords[K, V](consumer: KafkaConsumer[K, V],
+                           numRecords: Int,
+                           waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Seq[ConsumerRecord[K, V]] = {
+    val records = pollUntilAtLeastNumRecords(consumer, numRecords, waitTimeMs)
+    assertEquals("Consumed more records than expected", numRecords, records.size)
     records
   }
 
@@ -1318,15 +1347,6 @@ object TestUtils extends Logging {
     offsetsToCommit.toMap
   }
 
-  def pollUntilAtLeastNumRecords(consumer: KafkaConsumer[Array[Byte], Array[Byte]], numRecords: Int): Seq[ConsumerRecord[Array[Byte], Array[Byte]]] = {
-    val records = new ArrayBuffer[ConsumerRecord[Array[Byte], Array[Byte]]]()
-    TestUtils.waitUntilTrue(() => {
-      records ++= consumer.poll(Duration.ofMillis(50)).asScala
-      records.size >= numRecords
-    }, s"Consumed ${records.size} records until timeout, but expected $numRecords records.")
-    records
-  }
-
   def resetToCommittedPositions(consumer: KafkaConsumer[Array[Byte], Array[Byte]]) = {
     consumer.assignment.asScala.foreach { topicPartition =>
       val offset = consumer.committed(topicPartition)