You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by wc...@apache.org on 2023/10/05 14:12:02 UTC

[kafka] branch trunk updated: KAFKA-15415: On producer-batch retry, skip-backoff on a new leader (#14384)

This is an automated email from the ASF dual-hosted git repository.

wcarlson pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new d817b1b5900 KAFKA-15415: On producer-batch retry, skip-backoff on a new leader (#14384)
d817b1b5900 is described below

commit d817b1b5900e16d76ceae570d6e93d4d57783b73
Author: Mayank Shekhar Narula <42...@users.noreply.github.com>
AuthorDate: Thu Oct 5 15:11:47 2023 +0100

    KAFKA-15415: On producer-batch retry, skip-backoff on a new leader (#14384)
    
    When a producer batch is being retried and a newer leader is known for the partition than the leader used in the last attempt, it is worthwhile to retry immediately against this new leader. A partition leader is considered newer if its epoch has advanced.
    
    Reviewers: Walker Carlson <wc...@apache.org>, Kirk True <ki...@kirktrue.pro>, Andrew Schofield <andrew_schofield@uk.ibm.com>
---
 checkstyle/suppressions.xml                        |   2 +-
 .../clients/producer/internals/ProducerBatch.java  |  49 +++
 .../producer/internals/RecordAccumulator.java      |  55 ++-
 .../kafka/clients/producer/internals/Sender.java   |   6 +-
 .../producer/internals/ProducerBatchTest.java      |  51 +++
 .../producer/internals/RecordAccumulatorTest.java  | 434 +++++++++++++++++----
 .../clients/producer/internals/SenderTest.java     |  89 ++++-
 .../producer/internals/TransactionManagerTest.java |  24 +-
 8 files changed, 603 insertions(+), 107 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index c6dfae8cc44..c9601753c1e 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -93,7 +93,7 @@
               files="(AbstractRequest|AbstractResponse|KerberosLogin|WorkerSinkTaskTest|TransactionManagerTest|SenderTest|KafkaAdminClient|ConsumerCoordinatorTest|KafkaAdminClientTest|KafkaRaftClientTest).java"/>
 
     <suppress checks="NPathComplexity"
-              files="(ConsumerCoordinator|BufferPool|MetricName|Node|ConfigDef|RecordBatch|SslFactory|SslTransportLayer|MetadataResponse|KerberosLogin|Selector|Sender|Serdes|TokenInformation|Agent|Values|PluginUtils|MiniTrogdorCluster|TasksRequest|KafkaProducer|AbstractStickyAssignor|KafkaRaftClient|Authorizer|FetchSessionHandler).java"/>
+              files="(ConsumerCoordinator|BufferPool|MetricName|Node|ConfigDef|RecordBatch|SslFactory|SslTransportLayer|MetadataResponse|KerberosLogin|Selector|Sender|Serdes|TokenInformation|Agent|Values|PluginUtils|MiniTrogdorCluster|TasksRequest|KafkaProducer|AbstractStickyAssignor|KafkaRaftClient|Authorizer|FetchSessionHandler|RecordAccumulator).java"/>
 
     <suppress checks="(JavaNCSS|CyclomaticComplexity|MethodLength)"
               files="CoordinatorClient.java"/>
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
index 4da03627be1..408b8316eb8 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.clients.producer.internals;
 
+import java.util.Optional;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.RecordMetadata;
 import org.apache.kafka.common.TopicPartition;
@@ -79,6 +80,11 @@ public final class ProducerBatch {
     private boolean retry;
     private boolean reopened;
 
+    // Tracks the current-leader's epoch to which this batch would be sent, in the current attempt to produce the batch.
+    private Optional<Integer> currentLeaderEpoch;
+    // Tracks the attempt in which leader was changed to currentLeaderEpoch for the 1st time.
+    private int attemptsWhenLeaderLastChanged;
+
     public ProducerBatch(TopicPartition tp, MemoryRecordsBuilder recordsBuilder, long createdMs) {
         this(tp, recordsBuilder, createdMs, false);
     }
@@ -94,9 +100,42 @@ public final class ProducerBatch {
         this.isSplitBatch = isSplitBatch;
         float compressionRatioEstimation = CompressionRatioEstimator.estimation(topicPartition.topic(),
                                                                                 recordsBuilder.compressionType());
+        this.currentLeaderEpoch = Optional.empty();
+        this.attemptsWhenLeaderLastChanged = 0;
         recordsBuilder.setEstimatedCompressionRatio(compressionRatioEstimation);
     }
 
+    /**
+     * It will update the leader to which this batch will be produced for the ongoing attempt, if a newer leader is known.
+     * @param latestLeaderEpoch latest leader's epoch.
+     */
+    void maybeUpdateLeaderEpoch(Optional<Integer> latestLeaderEpoch) {
+        if (!currentLeaderEpoch.equals(latestLeaderEpoch)) {
+            log.trace("For {}, leader will be updated, currentLeaderEpoch: {}, attemptsWhenLeaderLastChanged:{}, latestLeaderEpoch: {}, current attempt: {}",
+                this, currentLeaderEpoch, attemptsWhenLeaderLastChanged, latestLeaderEpoch, attempts);
+            attemptsWhenLeaderLastChanged = attempts();
+            currentLeaderEpoch = latestLeaderEpoch;
+        } else {
+            log.trace("For {}, leader wasn't updated, currentLeaderEpoch: {}, attemptsWhenLeaderLastChanged:{}, latestLeaderEpoch: {}, current attempt: {}",
+                this, currentLeaderEpoch, attemptsWhenLeaderLastChanged, latestLeaderEpoch, attempts);
+        }
+    }
+
+    /**
+     * Returns true if the batch is being retried and the leader changed for the ongoing retry attempt.
+     */
+
+    boolean hasLeaderChangedForTheOngoingRetry() {
+        int attempts = attempts();
+        boolean isRetry = attempts >= 1;
+        if (!isRetry)
+            return false;
+        if (attempts == attemptsWhenLeaderLastChanged)
+            return true;
+        return false;
+    }
+
+
     /**
      * Append the record to the current record set and return the relative offset within that record set
      *
@@ -517,4 +556,14 @@ public final class ProducerBatch {
     public boolean sequenceHasBeenReset() {
         return reopened;
     }
+
+    // VisibleForTesting
+    Optional<Integer> currentLeaderEpoch() {
+        return currentLeaderEpoch;
+    }
+
+    // VisibleForTesting
+    int attemptsWhenLeaderLastChanged() {
+        return attemptsWhenLeaderLastChanged;
+    }
 }
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
index a3f27bf59d9..5e1795cb2a1 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
@@ -32,6 +32,7 @@ import java.util.concurrent.atomic.AtomicInteger;
 
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.CommonClientConfigs;
+import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.RecordMetadata;
 import org.apache.kafka.common.utils.ExponentialBackoff;
@@ -650,7 +651,7 @@ public class RecordAccumulator {
      * into the set of ready nodes.  If partition has no leader, add the topic to the set of topics with
      * no leader.  This function also calculates stats for adaptive partitioning.
      *
-     * @param cluster The cluster metadata
+     * @param metadata The cluster metadata
      * @param nowMs The current time
      * @param topic The topic
      * @param topicInfo The topic info
@@ -659,14 +660,14 @@ public class RecordAccumulator {
      * @param unknownLeaderTopics The set of topics with no leader (to be filled in)
      * @return The delay for next check
      */
-    private long partitionReady(Cluster cluster, long nowMs, String topic,
+    private long partitionReady(Metadata metadata, long nowMs, String topic,
                                 TopicInfo topicInfo,
                                 long nextReadyCheckDelayMs, Set<Node> readyNodes, Set<String> unknownLeaderTopics) {
         ConcurrentMap<Integer, Deque<ProducerBatch>> batches = topicInfo.batches;
         // Collect the queue sizes for available partitions to be used in adaptive partitioning.
         int[] queueSizes = null;
         int[] partitionIds = null;
-        if (enableAdaptivePartitioning && batches.size() >= cluster.partitionsForTopic(topic).size()) {
+        if (enableAdaptivePartitioning && batches.size() >= metadata.fetch().partitionsForTopic(topic).size()) {
             // We don't do adaptive partitioning until we scheduled at least a batch for all
             // partitions (i.e. we have the corresponding entries in the batches map), we just
             // do uniform.  The reason is that we build queue sizes from the batches map,
@@ -682,7 +683,9 @@ public class RecordAccumulator {
             TopicPartition part = new TopicPartition(topic, entry.getKey());
             // Advance queueSizesIndex so that we properly index available
             // partitions.  Do it here so that it's done for all code paths.
-            Node leader = cluster.leaderFor(part);
+
+            Metadata.LeaderAndEpoch leaderAndEpoch = metadata.currentLeader(part);
+            Node leader = leaderAndEpoch.leader.orElse(null);
             if (leader != null && queueSizes != null) {
                 ++queueSizesIndex;
                 assert queueSizesIndex < queueSizes.length;
@@ -712,7 +715,8 @@ public class RecordAccumulator {
                 }
 
                 waitedTimeMs = batch.waitedTimeMs(nowMs);
-                backingOff = shouldBackoff(batch, waitedTimeMs);
+                batch.maybeUpdateLeaderEpoch(leaderAndEpoch.epoch);
+                backingOff = shouldBackoff(batch.hasLeaderChangedForTheOngoingRetry(), batch, waitedTimeMs);
                 backoffAttempts = batch.attempts();
                 dequeSize = deque.size();
                 full = dequeSize > 1 || batch.isFull();
@@ -772,7 +776,7 @@ public class RecordAccumulator {
      * </ul>
      * </ol>
      */
-    public ReadyCheckResult ready(Cluster cluster, long nowMs) {
+    public ReadyCheckResult ready(Metadata metadata, long nowMs) {
         Set<Node> readyNodes = new HashSet<>();
         long nextReadyCheckDelayMs = Long.MAX_VALUE;
         Set<String> unknownLeaderTopics = new HashSet<>();
@@ -780,7 +784,7 @@ public class RecordAccumulator {
         // cumulative frequency table (used in partitioner).
         for (Map.Entry<String, TopicInfo> topicInfoEntry : this.topicInfoMap.entrySet()) {
             final String topic = topicInfoEntry.getKey();
-            nextReadyCheckDelayMs = partitionReady(cluster, nowMs, topic, topicInfoEntry.getValue(), nextReadyCheckDelayMs, readyNodes, unknownLeaderTopics);
+            nextReadyCheckDelayMs = partitionReady(metadata, nowMs, topic, topicInfoEntry.getValue(), nextReadyCheckDelayMs, readyNodes, unknownLeaderTopics);
         }
         return new ReadyCheckResult(readyNodes, nextReadyCheckDelayMs, unknownLeaderTopics);
     }
@@ -800,8 +804,17 @@ public class RecordAccumulator {
         return false;
     }
 
-    private boolean shouldBackoff(final ProducerBatch batch, final long waitedTimeMs) {
-        return batch.attempts() > 0 && waitedTimeMs < retryBackoff.backoff(batch.attempts() - 1);
+    private boolean shouldBackoff(boolean hasLeaderChanged, final ProducerBatch batch, final long waitedTimeMs) {
+        boolean shouldWaitMore = batch.attempts() > 0 && waitedTimeMs < retryBackoff.backoff(batch.attempts() - 1);
+        boolean shouldBackoff = !hasLeaderChanged && shouldWaitMore;
+        if (shouldBackoff) {
+            log.trace(
+                "For {}, will backoff", batch);
+        } else {
+            log.trace(
+                "For {}, will not backoff, shouldWaitMore {}, hasLeaderChanged {}", batch, shouldWaitMore, hasLeaderChanged);
+        }
+        return shouldBackoff;
     }
 
     private boolean shouldStopDrainBatchesForPartition(ProducerBatch first, TopicPartition tp) {
@@ -842,22 +855,31 @@ public class RecordAccumulator {
         return false;
     }
 
-    private List<ProducerBatch> drainBatchesForOneNode(Cluster cluster, Node node, int maxSize, long now) {
+    private List<ProducerBatch> drainBatchesForOneNode(Metadata metadata, Node node, int maxSize, long now) {
         int size = 0;
-        List<PartitionInfo> parts = cluster.partitionsForNode(node.id());
+        List<PartitionInfo> parts = metadata.fetch().partitionsForNode(node.id());
         List<ProducerBatch> ready = new ArrayList<>();
+        if (parts.isEmpty())
+            return ready;
         /* to make starvation less likely each node has it's own drainIndex */
         int drainIndex = getDrainIndex(node.idString());
         int start = drainIndex = drainIndex % parts.size();
         do {
             PartitionInfo part = parts.get(drainIndex);
+
             TopicPartition tp = new TopicPartition(part.topic(), part.partition());
             updateDrainIndex(node.idString(), drainIndex);
             drainIndex = (drainIndex + 1) % parts.size();
             // Only proceed if the partition has no in-flight batches.
             if (isMuted(tp))
                 continue;
-
+            Metadata.LeaderAndEpoch leaderAndEpoch = metadata.currentLeader(tp);
+            // Although there is only a small chance, skip this partition if the leader has changed since the partition -> node assignment was obtained outside the loop.
+            // In this case, skip sending it to the old leader, as it would return a NO_LEADER_OR_FOLLOWER error.
+            if (!leaderAndEpoch.leader.isPresent())
+                continue;
+            if (!node.equals(leaderAndEpoch.leader.get()))
+                continue;
             Deque<ProducerBatch> deque = getDeque(tp);
             if (deque == null)
                 continue;
@@ -871,7 +893,8 @@ public class RecordAccumulator {
 
                 // first != null
                 // Only drain the batch if it is not during backoff period.
-                if (shouldBackoff(first, first.waitedTimeMs(now)))
+                first.maybeUpdateLeaderEpoch(leaderAndEpoch.epoch);
+                if (shouldBackoff(first.hasLeaderChangedForTheOngoingRetry(), first, first.waitedTimeMs(now)))
                     continue;
 
                 if (size + first.estimatedSizeInBytes() > maxSize && !ready.isEmpty()) {
@@ -937,19 +960,19 @@ public class RecordAccumulator {
      * Drain all the data for the given nodes and collate them into a list of batches that will fit within the specified
      * size on a per-node basis. This method attempts to avoid choosing the same topic-node over and over.
      *
-     * @param cluster The current cluster metadata
+     * @param metadata The current cluster metadata
      * @param nodes The list of node to drain
      * @param maxSize The maximum number of bytes to drain
      * @param now The current unix time in milliseconds
      * @return A list of {@link ProducerBatch} for each node specified with total size less than the requested maxSize.
      */
-    public Map<Integer, List<ProducerBatch>> drain(Cluster cluster, Set<Node> nodes, int maxSize, long now) {
+    public Map<Integer, List<ProducerBatch>> drain(Metadata metadata, Set<Node> nodes, int maxSize, long now) {
         if (nodes.isEmpty())
             return Collections.emptyMap();
 
         Map<Integer, List<ProducerBatch>> batches = new HashMap<>();
         for (Node node : nodes) {
-            List<ProducerBatch> ready = drainBatchesForOneNode(cluster, node, maxSize, now);
+            List<ProducerBatch> ready = drainBatchesForOneNode(metadata, node, maxSize, now);
             batches.put(node.id(), ready);
         }
         return batches;
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
index c18a6e44701..1214df496f5 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
@@ -23,7 +23,6 @@ import org.apache.kafka.clients.KafkaClient;
 import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.NetworkClientUtils;
 import org.apache.kafka.clients.RequestCompletionHandler;
-import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.InvalidRecordException;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.MetricName;
@@ -359,9 +358,8 @@ public class Sender implements Runnable {
     }
 
     private long sendProducerData(long now) {
-        Cluster cluster = metadata.fetch();
         // get the list of partitions with data ready to send
-        RecordAccumulator.ReadyCheckResult result = this.accumulator.ready(cluster, now);
+        RecordAccumulator.ReadyCheckResult result = this.accumulator.ready(metadata, now);
 
         // if there are any partitions whose leaders are not known yet, force metadata update
         if (!result.unknownLeaderTopics.isEmpty()) {
@@ -396,7 +394,7 @@ public class Sender implements Runnable {
         }
 
         // create produce requests
-        Map<Integer, List<ProducerBatch>> batches = this.accumulator.drain(cluster, result.readyNodes, this.maxRequestSize, now);
+        Map<Integer, List<ProducerBatch>> batches = this.accumulator.drain(metadata, result.readyNodes, this.maxRequestSize, now);
         addToInflightBatches(batches);
         if (guaranteeMessageOrder) {
             // Mute all the partitions drained
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerBatchTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerBatchTest.java
index 03f5e0b6aa8..24629b612b2 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerBatchTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerBatchTest.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.clients.producer.internals;
 
+import java.util.Optional;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.RecordMetadata;
 import org.apache.kafka.common.KafkaException;
@@ -263,6 +264,56 @@ public class ProducerBatchTest {
             testCompleteExceptionally(recordCount, topLevelException, null));
     }
 
+    /**
+     * This tests that leader is correctly maintained & leader-change is correctly detected across retries
+     * of the batch. It does so primarily by testing the methods
+     * 1. maybeUpdateLeaderEpoch
+     * 2. hasLeaderChangedForTheOngoingRetry
+     */
+
+    @Test
+    public void testWithLeaderChangesAcrossRetries() {
+        ProducerBatch batch = new ProducerBatch(new TopicPartition("topic", 1), memoryRecordsBuilder, now);
+
+        // Starting state for the batch, no attempt made to send it yet.
+        assertEquals(Optional.empty(), batch.currentLeaderEpoch());
+        assertEquals(0, batch.attemptsWhenLeaderLastChanged()); // default value
+        batch.maybeUpdateLeaderEpoch(Optional.empty());
+        assertFalse(batch.hasLeaderChangedForTheOngoingRetry());
+
+        // 1st attempt[Not a retry] to send the batch.
+        // Check leader isn't flagged as a new leader.
+        int batchLeaderEpoch = 100;
+        batch.maybeUpdateLeaderEpoch(Optional.of(batchLeaderEpoch));
+        assertFalse(batch.hasLeaderChangedForTheOngoingRetry(), "batch leader is assigned for 1st time");
+        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().get());
+        assertEquals(0, batch.attemptsWhenLeaderLastChanged());
+
+        // 2nd attempt[1st retry] to send the batch to a new leader.
+        // Check leader change is detected.
+        batchLeaderEpoch = 101;
+        batch.reenqueued(0);
+        batch.maybeUpdateLeaderEpoch(Optional.of(batchLeaderEpoch));
+        assertTrue(batch.hasLeaderChangedForTheOngoingRetry(), "batch leader has changed");
+        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().get());
+        assertEquals(1, batch.attemptsWhenLeaderLastChanged());
+
+        // 2nd attempt[1st retry] still ongoing, yet to be made.
+        // Check same leaderEpoch(101) is still considered as a leader-change.
+        batch.maybeUpdateLeaderEpoch(Optional.of(batchLeaderEpoch));
+        assertTrue(batch.hasLeaderChangedForTheOngoingRetry(), "batch leader has changed");
+        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().get());
+        assertEquals(1, batch.attemptsWhenLeaderLastChanged());
+
+        // 3rd attempt[2nd retry] to the same leader-epoch(101).
+        // Check same leaderEpoch(101) is not detected as a leader-change.
+        batch.reenqueued(0);
+        batch.maybeUpdateLeaderEpoch(Optional.of(batchLeaderEpoch));
+        assertFalse(batch.hasLeaderChangedForTheOngoingRetry(), "batch leader has not changed");
+        assertEquals(batchLeaderEpoch, batch.currentLeaderEpoch().get());
+        assertEquals(1, batch.attemptsWhenLeaderLastChanged());
+    }
+
     private void testCompleteExceptionally(
         int recordCount,
         RuntimeException topLevelException,
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
index b7710cdfff1..a046efe2cb2 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java
@@ -16,8 +16,11 @@
  */
 package org.apache.kafka.clients.producer.internals;
 
+import java.util.Optional;
+import java.util.function.Function;
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.CommonClientConfigs;
+import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.NodeApiVersions;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.Partitioner;
@@ -45,6 +48,7 @@ import org.apache.kafka.common.utils.ProducerIdAndEpoch;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.test.TestUtils;
 import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.mockito.Mockito;
 
@@ -64,6 +68,7 @@ import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
+
 import static java.util.Arrays.asList;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -80,6 +85,7 @@ public class RecordAccumulatorTest {
     private int partition3 = 2;
     private Node node1 = new Node(0, "localhost", 1111);
     private Node node2 = new Node(1, "localhost", 1112);
+
     private TopicPartition tp1 = new TopicPartition(topic, partition1);
     private TopicPartition tp2 = new TopicPartition(topic, partition2);
     private TopicPartition tp3 = new TopicPartition(topic, partition3);
@@ -90,15 +96,22 @@ public class RecordAccumulatorTest {
     private byte[] key = "key".getBytes();
     private byte[] value = "value".getBytes();
     private int msgSize = DefaultRecord.sizeInBytes(0, 0, key.length, value.length, Record.EMPTY_HEADERS);
+    Metadata metadataMock;
     private Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1, part2, part3),
         Collections.emptySet(), Collections.emptySet());
     private Metrics metrics = new Metrics(time);
     private final long maxBlockTimeMs = 1000;
     private final LogContext logContext = new LogContext();
 
+    @BeforeEach
+    public void setup() {
+        metadataMock = setupMetadata(cluster);
+    }
+
     @AfterEach
     public void teardown() {
         this.metrics.close();
+        Mockito.reset(metadataMock);
     }
 
     @Test
@@ -113,6 +126,7 @@ public class RecordAccumulatorTest {
         RecordAccumulator accum = createTestRecordAccumulator((int) batchSize, Integer.MAX_VALUE, CompressionType.NONE, 10);
         Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1, part2, part3, part4),
                 Collections.emptySet(), Collections.emptySet());
+        metadataMock = setupMetadata(cluster);
 
         //  initial data
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
@@ -121,7 +135,7 @@ public class RecordAccumulatorTest {
         accum.append(topic, partition4, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
 
         // drain batches from 2 nodes: node1 => tp1, node2 => tp3, because the max request size is full after the first batch drained
-        Map<Integer, List<ProducerBatch>> batches1 = accum.drain(cluster, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        Map<Integer, List<ProducerBatch>> batches1 = accum.drain(metadataMock, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
         verifyTopicPartitionInBatches(batches1, tp1, tp3);
 
         // add record for tp1, tp3
@@ -130,11 +144,11 @@ public class RecordAccumulatorTest {
 
         // drain batches from 2 nodes: node1 => tp2, node2 => tp4, because the max request size is full after the first batch drained
         // The drain index should start from next topic partition, that is, node1 => tp2, node2 => tp4
-        Map<Integer, List<ProducerBatch>> batches2 = accum.drain(cluster, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        Map<Integer, List<ProducerBatch>> batches2 = accum.drain(metadataMock, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
         verifyTopicPartitionInBatches(batches2, tp2, tp4);
 
         // make sure in next run, the drain index will start from the beginning
-        Map<Integer, List<ProducerBatch>> batches3 = accum.drain(cluster, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        Map<Integer, List<ProducerBatch>> batches3 = accum.drain(metadataMock, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
         verifyTopicPartitionInBatches(batches3, tp1, tp3);
 
         // add record for tp2, tp3, tp4 and mute the tp4
@@ -143,7 +157,7 @@ public class RecordAccumulatorTest {
         accum.append(topic, partition4, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
         accum.mutePartition(tp4);
         // drain batches from 2 nodes: node1 => tp2, node2 => tp3 (because tp4 is muted)
-        Map<Integer, List<ProducerBatch>> batches4 = accum.drain(cluster, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
+        Map<Integer, List<ProducerBatch>> batches4 = accum.drain(metadataMock, new HashSet<>(Arrays.asList(node1, node2)), (int) batchSize, 0);
         verifyTopicPartitionInBatches(batches4, tp2, tp3);
 
         // add record for tp1, tp2, tp3, and unmute tp4
@@ -152,7 +166,7 @@ public class RecordAccumulatorTest {
         accum.append(topic, partition3, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
         accum.unmutePartition(tp4);
         // set maxSize as a max value, so that the all partitions in 2 nodes should be drained: node1 => [tp1, tp2], node2 => [tp3, tp4]
-        Map<Integer, List<ProducerBatch>> batches5 = accum.drain(cluster, new HashSet<>(Arrays.asList(node1, node2)), Integer.MAX_VALUE, 0);
+        Map<Integer, List<ProducerBatch>> batches5 = accum.drain(metadataMock, new HashSet<>(Arrays.asList(node1, node2)), Integer.MAX_VALUE, 0);
         verifyTopicPartitionInBatches(batches5, tp1, tp2, tp3, tp4);
     }
 
@@ -189,7 +203,7 @@ public class RecordAccumulatorTest {
 
             ProducerBatch batch = partitionBatches.peekFirst();
             assertTrue(batch.isWritable());
-            assertEquals(0, accum.ready(cluster, now).readyNodes.size(), "No partitions should be ready.");
+            assertEquals(0, accum.ready(metadataMock, now).readyNodes.size(), "No partitions should be ready.");
         }
 
         // this append doesn't fit in the first batch, so a new batch is created and the first batch is closed
@@ -199,9 +213,9 @@ public class RecordAccumulatorTest {
         assertEquals(2, partitionBatches.size());
         Iterator<ProducerBatch> partitionBatchesIterator = partitionBatches.iterator();
         assertTrue(partitionBatchesIterator.next().isWritable());
-        assertEquals(Collections.singleton(node1), accum.ready(cluster, time.milliseconds()).readyNodes, "Our partition's leader should be ready");
+        assertEquals(Collections.singleton(node1), accum.ready(metadataMock, time.milliseconds()).readyNodes, "Our partition's leader should be ready");
 
-        List<ProducerBatch> batches = accum.drain(cluster, Collections.singleton(node1), Integer.MAX_VALUE, 0).get(node1.id());
+        List<ProducerBatch> batches = accum.drain(metadataMock, Collections.singleton(node1), Integer.MAX_VALUE, 0).get(node1.id());
         assertEquals(1, batches.size());
         ProducerBatch batch = batches.get(0);
 
@@ -230,7 +244,7 @@ public class RecordAccumulatorTest {
         RecordAccumulator accum = createTestRecordAccumulator(
                 batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 1024, compressionType, 0);
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        assertEquals(Collections.singleton(node1), accum.ready(cluster, time.milliseconds()).readyNodes, "Our partition's leader should be ready");
+        assertEquals(Collections.singleton(node1), accum.ready(metadataMock, time.milliseconds()).readyNodes, "Our partition's leader should be ready");
 
         Deque<ProducerBatch> batches = accum.getDeque(tp1);
         assertEquals(1, batches.size());
@@ -268,7 +282,7 @@ public class RecordAccumulatorTest {
         RecordAccumulator accum = createTestRecordAccumulator(
                 batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 1024, compressionType, 0);
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        assertEquals(Collections.singleton(node1), accum.ready(cluster, time.milliseconds()).readyNodes, "Our partition's leader should be ready");
+        assertEquals(Collections.singleton(node1), accum.ready(metadataMock, time.milliseconds()).readyNodes, "Our partition's leader should be ready");
 
         Deque<ProducerBatch> batches = accum.getDeque(tp1);
         assertEquals(1, batches.size());
@@ -292,10 +306,10 @@ public class RecordAccumulatorTest {
         RecordAccumulator accum = createTestRecordAccumulator(
                 1024 + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 1024, CompressionType.NONE, lingerMs);
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready");
+        assertEquals(0, accum.ready(metadataMock, time.milliseconds()).readyNodes.size(), "No partitions should be ready");
         time.sleep(10);
-        assertEquals(Collections.singleton(node1), accum.ready(cluster, time.milliseconds()).readyNodes, "Our partition's leader should be ready");
-        List<ProducerBatch> batches = accum.drain(cluster, Collections.singleton(node1), Integer.MAX_VALUE, 0).get(node1.id());
+        assertEquals(Collections.singleton(node1), accum.ready(metadataMock, time.milliseconds()).readyNodes, "Our partition's leader should be ready");
+        List<ProducerBatch> batches = accum.drain(metadataMock, Collections.singleton(node1), Integer.MAX_VALUE, 0).get(node1.id());
         assertEquals(1, batches.size());
         ProducerBatch batch = batches.get(0);
 
@@ -316,9 +330,9 @@ public class RecordAccumulatorTest {
             for (int i = 0; i < appends; i++)
                 accum.append(tp.topic(), tp.partition(), 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
         }
-        assertEquals(Collections.singleton(node1), accum.ready(cluster, time.milliseconds()).readyNodes, "Partition's leader should be ready");
+        assertEquals(Collections.singleton(node1), accum.ready(metadataMock, time.milliseconds()).readyNodes, "Partition's leader should be ready");
 
-        List<ProducerBatch> batches = accum.drain(cluster, Collections.singleton(node1), 1024, 0).get(node1.id());
+        List<ProducerBatch> batches = accum.drain(metadataMock, Collections.singleton(node1), 1024, 0).get(node1.id());
         assertEquals(1, batches.size(), "But due to size bound only one partition should have been retrieved");
     }
 
@@ -347,8 +361,8 @@ public class RecordAccumulatorTest {
         int read = 0;
         long now = time.milliseconds();
         while (read < numThreads * msgs) {
-            Set<Node> nodes = accum.ready(cluster, now).readyNodes;
-            List<ProducerBatch> batches = accum.drain(cluster, nodes, 5 * 1024, 0).get(node1.id());
+            Set<Node> nodes = accum.ready(metadataMock, now).readyNodes;
+            List<ProducerBatch> batches = accum.drain(metadataMock, nodes, 5 * 1024, 0).get(node1.id());
             if (batches != null) {
                 for (ProducerBatch batch : batches) {
                     for (Record record : batch.records().records())
@@ -379,7 +393,7 @@ public class RecordAccumulatorTest {
         // Partition on node1 only
         for (int i = 0; i < appends; i++)
             accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, time.milliseconds());
         assertEquals(0, result.readyNodes.size(), "No nodes should be ready.");
         assertEquals(lingerMs, result.nextReadyCheckDelayMs, "Next check time should be the linger time");
 
@@ -388,14 +402,14 @@ public class RecordAccumulatorTest {
         // Add partition on node2 only
         for (int i = 0; i < appends; i++)
             accum.append(topic, partition3, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        result = accum.ready(cluster, time.milliseconds());
+        result = accum.ready(metadataMock, time.milliseconds());
         assertEquals(0, result.readyNodes.size(), "No nodes should be ready.");
         assertEquals(lingerMs / 2, result.nextReadyCheckDelayMs, "Next check time should be defined by node1, half remaining linger time");
 
         // Add data for another partition on node1, enough to make data sendable immediately
         for (int i = 0; i < appends + 1; i++)
             accum.append(topic, partition2, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        result = accum.ready(cluster, time.milliseconds());
+        result = accum.ready(metadataMock, time.milliseconds());
         assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 should be ready");
         // Note this can actually be < linger time because it may use delays from partitions that aren't sendable
         // but have leaders with other sendable data.
@@ -419,9 +433,9 @@ public class RecordAccumulatorTest {
 
         long now = time.milliseconds();
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, now + lingerMs + 1);
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, now + lingerMs + 1);
         assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 should be ready");
-        Map<Integer, List<ProducerBatch>> batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, now + lingerMs + 1);
+        Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, now + lingerMs + 1);
         assertEquals(1, batches.size(), "Node1 should be the only ready node.");
         assertEquals(1, batches.get(0).size(), "Partition 0 should only have one batch drained.");
 
@@ -431,36 +445,37 @@ public class RecordAccumulatorTest {
 
         // Put message for partition 1 into accumulator
         accum.append(topic, partition2, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        result = accum.ready(cluster, now + lingerMs + 1);
+        result = accum.ready(metadataMock, now + lingerMs + 1);
         assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 should be ready");
 
         // tp1 should backoff while tp2 should not
-        batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, now + lingerMs + 1);
+        batches = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, now + lingerMs + 1);
         assertEquals(1, batches.size(), "Node1 should be the only ready node.");
         assertEquals(1, batches.get(0).size(), "Node1 should only have one batch drained.");
         assertEquals(tp2, batches.get(0).get(0).topicPartition, "Node1 should only have one batch for partition 1.");
 
         // Partition 0 can be drained after retry backoff
         long upperBoundBackoffMs = (long) (retryBackoffMs * (1 + CommonClientConfigs.RETRY_BACKOFF_JITTER));
-        result = accum.ready(cluster, now + upperBoundBackoffMs + 1);
+        result = accum.ready(metadataMock, now + upperBoundBackoffMs + 1);
         assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 should be ready");
-        batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, now + upperBoundBackoffMs + 1);
+        batches = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, now + upperBoundBackoffMs + 1);
         assertEquals(1, batches.size(), "Node1 should be the only ready node.");
         assertEquals(1, batches.get(0).size(), "Node1 should only have one batch drained.");
         assertEquals(tp1, batches.get(0).get(0).topicPartition, "Node1 should only have one batch for partition 0.");
     }
 
     private Map<Integer, List<ProducerBatch>> drainAndCheckBatchAmount(Cluster cluster, Node leader, RecordAccumulator accum, long now, int expected) {
-        RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, now);
+        metadataMock = setupMetadata(cluster);
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, now);
         if (expected > 0) {
             assertEquals(Collections.singleton(leader), result.readyNodes, "Leader should be ready");
-            Map<Integer, List<ProducerBatch>> batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, now);
+            Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, now);
             assertEquals(expected, batches.size(), "Leader should be the only ready node.");
             assertEquals(expected, batches.get(leader.id()).size(), "Partition should only have " + expected + " batch drained.");
             return batches;
         } else {
             assertEquals(0, result.readyNodes.size(), "Leader should not be ready");
-            Map<Integer, List<ProducerBatch>> batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, now);
+            Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, now);
             assertEquals(0, batches.size(), "Leader should not be drained.");
             return null;
         }
@@ -590,14 +605,14 @@ public class RecordAccumulatorTest {
             accum.append(topic, i % 3, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
             assertTrue(accum.hasIncomplete());
         }
-        RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, time.milliseconds());
         assertEquals(0, result.readyNodes.size(), "No nodes should be ready.");
 
         accum.beginFlush();
-        result = accum.ready(cluster, time.milliseconds());
+        result = accum.ready(metadataMock, time.milliseconds());
 
         // drain and deallocate all batches
-        Map<Integer, List<ProducerBatch>> results = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> results = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertTrue(accum.hasIncomplete());
 
         for (List<ProducerBatch> batches: results.values())
@@ -657,9 +672,9 @@ public class RecordAccumulatorTest {
         }
         for (int i = 0; i < numRecords; i++)
             accum.append(topic, i % 3, 0L, key, value, null, new TestCallback(), maxBlockTimeMs, false, time.milliseconds(), cluster);
-        RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, time.milliseconds());
         assertFalse(result.readyNodes.isEmpty());
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertTrue(accum.hasUndrained());
         assertTrue(accum.hasIncomplete());
 
@@ -702,9 +717,9 @@ public class RecordAccumulatorTest {
         }
         for (int i = 0; i < numRecords; i++)
             accum.append(topic, i % 3, 0L, key, value, null, new TestCallback(), maxBlockTimeMs, false, time.milliseconds(), cluster);
-        RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, time.milliseconds());
         assertFalse(result.readyNodes.isEmpty());
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE,
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE,
                 time.milliseconds());
         assertTrue(accum.hasUndrained());
         assertTrue(accum.hasIncomplete());
@@ -741,10 +756,10 @@ public class RecordAccumulatorTest {
             if (time.milliseconds() < System.currentTimeMillis())
                 time.setCurrentTimeMs(System.currentTimeMillis());
             accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-            assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partition should be ready.");
+            assertEquals(0, accum.ready(metadataMock, time.milliseconds()).readyNodes.size(), "No partition should be ready.");
 
             time.sleep(lingerMs);
-            readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes;
+            readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
             assertEquals(Collections.singleton(node1), readyNodes, "Our partition's leader should be ready");
 
             expiredBatches = accum.expiredBatches(time.milliseconds());
@@ -759,7 +774,7 @@ public class RecordAccumulatorTest {
             time.sleep(deliveryTimeoutMs - lingerMs);
             expiredBatches = accum.expiredBatches(time.milliseconds());
             assertEquals(1, expiredBatches.size(), "The batch may expire when the partition is muted");
-            assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
+            assertEquals(0, accum.ready(metadataMock, time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
         }
     }
 
@@ -790,11 +805,11 @@ public class RecordAccumulatorTest {
         // Test batches not in retry
         for (int i = 0; i < appends; i++) {
             accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-            assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
+            assertEquals(0, accum.ready(metadataMock, time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
         }
         // Make the batches ready due to batch full
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, 0, false, time.milliseconds(), cluster);
-        Set<Node> readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes;
+        Set<Node> readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
         assertEquals(Collections.singleton(node1), readyNodes, "Our partition's leader should be ready");
         // Advance the clock to expire the batch.
         time.sleep(deliveryTimeoutMs + 1);
@@ -805,7 +820,7 @@ public class RecordAccumulatorTest {
         accum.unmutePartition(tp1);
         expiredBatches = accum.expiredBatches(time.milliseconds());
         assertEquals(0, expiredBatches.size(), "All batches should have been expired earlier");
-        assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
+        assertEquals(0, accum.ready(metadataMock, time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
 
         // Advance the clock to make the next batch ready due to linger.ms
         time.sleep(lingerMs);
@@ -819,15 +834,15 @@ public class RecordAccumulatorTest {
         accum.unmutePartition(tp1);
         expiredBatches = accum.expiredBatches(time.milliseconds());
         assertEquals(0, expiredBatches.size(), "All batches should have been expired");
-        assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
+        assertEquals(0, accum.ready(metadataMock, time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
 
         // Test batches in retry.
         // Create a retried batch
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, 0, false, time.milliseconds(), cluster);
         time.sleep(lingerMs);
-        readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes;
+        readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
         assertEquals(Collections.singleton(node1), readyNodes, "Our partition's leader should be ready");
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(cluster, readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertEquals(drained.get(node1.id()).size(), 1, "There should be only one batch.");
         time.sleep(1000L);
         accum.reenqueue(drained.get(node1.id()).get(0), time.milliseconds());
@@ -849,7 +864,7 @@ public class RecordAccumulatorTest {
         // Test that when being throttled muted batches are expired before the throttle time is over.
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, 0, false, time.milliseconds(), cluster);
         time.sleep(lingerMs);
-        readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes;
+        readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
         assertEquals(Collections.singleton(node1), readyNodes, "Our partition's leader should be ready");
         // Advance the clock to expire the batch.
         time.sleep(requestTimeout + 1);
@@ -867,7 +882,7 @@ public class RecordAccumulatorTest {
         time.sleep(throttleTimeMs);
         expiredBatches = accum.expiredBatches(time.milliseconds());
         assertEquals(0, expiredBatches.size(), "All batches should have been expired earlier");
-        assertEquals(1, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
+        assertEquals(1, accum.ready(metadataMock, time.milliseconds()).readyNodes.size(), "No partitions should be ready.");
     }
 
     @Test
@@ -881,28 +896,28 @@ public class RecordAccumulatorTest {
         int appends = expectedNumAppends(batchSize);
         for (int i = 0; i < appends; i++) {
             accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-            assertEquals(0, accum.ready(cluster, now).readyNodes.size(), "No partitions should be ready.");
+            assertEquals(0, accum.ready(metadataMock, now).readyNodes.size(), "No partitions should be ready.");
         }
         time.sleep(2000);
 
         // Test ready with muted partition
         accum.mutePartition(tp1);
-        RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, time.milliseconds());
         assertEquals(0, result.readyNodes.size(), "No node should be ready");
 
         // Test ready without muted partition
         accum.unmutePartition(tp1);
-        result = accum.ready(cluster, time.milliseconds());
+        result = accum.ready(metadataMock, time.milliseconds());
         assertTrue(result.readyNodes.size() > 0, "The batch should be ready");
 
         // Test drain with muted partition
         accum.mutePartition(tp1);
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertEquals(0, drained.get(node1.id()).size(), "No batch should have been drained");
 
         // Test drain without muted partition.
         accum.unmutePartition(tp1);
-        drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        drained = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertTrue(drained.get(node1.id()).size() > 0, "The batch should have been drained.");
     }
 
@@ -952,20 +967,20 @@ public class RecordAccumulatorTest {
             false, time.milliseconds(), cluster);
         assertTrue(accumulator.hasUndrained());
 
-        RecordAccumulator.ReadyCheckResult firstResult = accumulator.ready(cluster, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult firstResult = accumulator.ready(metadataMock, time.milliseconds());
         assertEquals(0, firstResult.readyNodes.size());
-        Map<Integer, List<ProducerBatch>> firstDrained = accumulator.drain(cluster, firstResult.readyNodes,
+        Map<Integer, List<ProducerBatch>> firstDrained = accumulator.drain(metadataMock, firstResult.readyNodes,
             Integer.MAX_VALUE, time.milliseconds());
         assertEquals(0, firstDrained.size());
 
         // Once the transaction begins completion, then the batch should be drained immediately.
         Mockito.when(transactionManager.isCompleting()).thenReturn(true);
 
-        RecordAccumulator.ReadyCheckResult secondResult = accumulator.ready(cluster, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult secondResult = accumulator.ready(metadataMock, time.milliseconds());
         assertEquals(1, secondResult.readyNodes.size());
         Node readyNode = secondResult.readyNodes.iterator().next();
 
-        Map<Integer, List<ProducerBatch>> secondDrained = accumulator.drain(cluster, secondResult.readyNodes,
+        Map<Integer, List<ProducerBatch>> secondDrained = accumulator.drain(metadataMock, secondResult.readyNodes,
             Integer.MAX_VALUE, time.milliseconds());
         assertEquals(Collections.singleton(readyNode.id()), secondDrained.keySet());
         List<ProducerBatch> batches = secondDrained.get(readyNode.id());
@@ -996,16 +1011,16 @@ public class RecordAccumulatorTest {
         // Re-enqueuing counts as a second attempt, so the delay with jitter is 100 * (1 + 0.2) + 1
         time.sleep(121L);
         // Drain the batch.
-        RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, time.milliseconds());
         assertTrue(result.readyNodes.size() > 0, "The batch should be ready");
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertEquals(1, drained.size(), "Only node1 should be drained");
         assertEquals(1, drained.get(node1.id()).size(), "Only one batch should be drained");
         // Split and reenqueue the batch.
         accum.splitAndReenqueue(drained.get(node1.id()).get(0));
         time.sleep(101L);
 
-        drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        drained = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertFalse(drained.isEmpty());
         assertFalse(drained.get(node1.id()).isEmpty());
         drained.get(node1.id()).get(0).complete(acked.get(), 100L);
@@ -1013,7 +1028,7 @@ public class RecordAccumulatorTest {
         assertTrue(future1.isDone());
         assertEquals(0, future1.get().offset());
 
-        drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        drained = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertFalse(drained.isEmpty());
         assertFalse(drained.get(node1.id()).isEmpty());
         drained.get(node1.id()).get(0).complete(acked.get(), 100L);
@@ -1034,14 +1049,14 @@ public class RecordAccumulatorTest {
         int numSplitBatches = prepareSplitBatches(accum, seed, 100, 20);
         assertTrue(numSplitBatches > 0, "There should be some split batches");
         // Drain all the split batches.
-        RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, time.milliseconds());
         for (int i = 0; i < numSplitBatches; i++) {
             Map<Integer, List<ProducerBatch>> drained =
-                accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+                accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
             assertFalse(drained.isEmpty());
             assertFalse(drained.get(node1.id()).isEmpty());
         }
-        assertTrue(accum.ready(cluster, time.milliseconds()).readyNodes.isEmpty(), "All the batches should have been drained.");
+        assertTrue(accum.ready(metadataMock, time.milliseconds()).readyNodes.isEmpty(), "All the batches should have been drained.");
         assertEquals(bufferCapacity, accum.bufferPoolAvailableMemory(),
             "The split batches should be allocated off the accumulator");
     }
@@ -1088,16 +1103,16 @@ public class RecordAccumulatorTest {
             batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * batchSize, CompressionType.NONE, lingerMs);
 
         accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds(), cluster);
-        Set<Node> readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes;
-        Map<Integer, List<ProducerBatch>> drained = accum.drain(cluster, readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Set<Node> readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
+        Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertTrue(drained.isEmpty());
         //assertTrue(accum.soonToExpireInFlightBatches().isEmpty());
 
         // advanced clock and send one batch out but it should not be included in soon to expire inflight
         // batches because batch's expiry is quite far.
         time.sleep(lingerMs + 1);
-        readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes;
-        drained = accum.drain(cluster, readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
+        drained = accum.drain(metadataMock, readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertEquals(1, drained.size(), "A batch did not drain after linger");
         //assertTrue(accum.soonToExpireInFlightBatches().isEmpty());
 
@@ -1106,8 +1121,8 @@ public class RecordAccumulatorTest {
         time.sleep(lingerMs * 4);
 
         // Now drain and check that accumulator picked up the drained batch because its expiry is soon.
-        readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes;
-        drained = accum.drain(cluster, readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
+        drained = accum.drain(metadataMock, readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertEquals(1, drained.size(), "A batch did not drain after linger");
     }
 
@@ -1129,9 +1144,9 @@ public class RecordAccumulatorTest {
         for (Boolean mute : muteStates) {
             accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, 0, false, time.milliseconds(), cluster);
             time.sleep(lingerMs);
-            readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes;
+            readyNodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
             assertEquals(Collections.singleton(node1), readyNodes, "Our partition's leader should be ready");
-            Map<Integer, List<ProducerBatch>> drained = accum.drain(cluster, readyNodes, Integer.MAX_VALUE, time.milliseconds());
+            Map<Integer, List<ProducerBatch>> drained = accum.drain(metadataMock, readyNodes, Integer.MAX_VALUE, time.milliseconds());
             assertEquals(1, drained.get(node1.id()).size(), "There should be only one batch.");
             time.sleep(rtt);
             accum.reenqueue(drained.get(node1.id()).get(0), time.milliseconds());
@@ -1143,7 +1158,7 @@ public class RecordAccumulatorTest {
 
             // test expiration
             time.sleep(deliveryTimeoutMs - rtt);
-            accum.drain(cluster, Collections.singleton(node1), Integer.MAX_VALUE, time.milliseconds());
+            accum.drain(metadataMock, Collections.singleton(node1), Integer.MAX_VALUE, time.milliseconds());
             expiredBatches = accum.expiredBatches(time.milliseconds());
             assertEquals(mute ? 1 : 0, expiredBatches.size(), "RecordAccumulator has expired batches if the partition is not muted");
         }
@@ -1184,12 +1199,12 @@ public class RecordAccumulatorTest {
             // We only appended if we do not retry.
             if (!switchPartition) {
                 appends++;
-                assertEquals(0, accum.ready(cluster, now).readyNodes.size(), "No partitions should be ready.");
+                assertEquals(0, accum.ready(metadataMock, now).readyNodes.size(), "No partitions should be ready.");
             }
         }
 
         // Batch should be full.
-        assertEquals(1, accum.ready(cluster, time.milliseconds()).readyNodes.size());
+        assertEquals(1, accum.ready(metadataMock, time.milliseconds()).readyNodes.size());
         assertEquals(appends, expectedAppends);
         switchPartition = false;
 
@@ -1317,7 +1332,7 @@ public class RecordAccumulatorTest {
             }
 
             // Let the accumulator generate the probability tables.
-            accum.ready(cluster, time.milliseconds());
+            accum.ready(metadataMock, time.milliseconds());
 
             // Set up callbacks so that we know what partition is chosen.
             final AtomicInteger partition = new AtomicInteger(RecordMetadata.UNKNOWN_PARTITION);
@@ -1361,7 +1376,7 @@ public class RecordAccumulatorTest {
             // Test that partitions residing on high-latency nodes don't get switched to.
             accum.updateNodeLatencyStats(0, time.milliseconds() - 200, true);
             accum.updateNodeLatencyStats(0, time.milliseconds(), false);
-            accum.ready(cluster, time.milliseconds());
+            accum.ready(metadataMock, time.milliseconds());
 
             // Do one append, because partition gets switched after append.
             accum.append(topic, RecordMetadata.UNKNOWN_PARTITION, 0L, null, largeValue, Record.EMPTY_HEADERS,
@@ -1398,9 +1413,9 @@ public class RecordAccumulatorTest {
             time.sleep(10);
 
             // We should have one batch ready.
-            Set<Node> nodes = accum.ready(cluster, time.milliseconds()).readyNodes;
+            Set<Node> nodes = accum.ready(metadataMock, time.milliseconds()).readyNodes;
             assertEquals(1, nodes.size(), "Should have 1 leader ready");
-            List<ProducerBatch> batches = accum.drain(cluster, nodes, Integer.MAX_VALUE, 0).entrySet().iterator().next().getValue();
+            List<ProducerBatch> batches = accum.drain(metadataMock, nodes, Integer.MAX_VALUE, 0).entrySet().iterator().next().getValue();
             assertEquals(1, batches.size(), "Should have 1 batch ready");
             int actualBatchSize = batches.get(0).records().sizeInBytes();
             assertTrue(actualBatchSize > batchSize / 2, "Batch must be greater than half batch.size");
@@ -1408,6 +1423,238 @@ public class RecordAccumulatorTest {
         }
     }
 
+    /**
+     * For a batch being retried, this validates that ready() and drain() correctly decide whether the batch should skip backoff (retry immediately) or back off, based on -
+     * 1. how long it has waited between retry attempts.
+     * 2. change in leader hosting the partition.
+     */
+    @Test
+    public void testReadyAndDrainWhenABatchIsBeingRetried() throws InterruptedException {
+        int part1LeaderEpoch = 100;
+        // Create cluster metadata, partition1 being hosted by node1.
+        part1 = new PartitionInfo(topic, partition1, node1, null, null, null);
+        cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1),
+            Collections.emptySet(), Collections.emptySet());
+        final int finalEpoch = part1LeaderEpoch;
+        metadataMock = setupMetadata(cluster, tp -> finalEpoch);
+
+        int batchSize = 10;
+        int lingerMs = 10;
+        int retryBackoffMs = 100;
+        int retryBackoffMaxMs = 1000;
+        int deliveryTimeoutMs = Integer.MAX_VALUE;
+        long totalSize = 10 * 1024;
+        String metricGrpName = "producer-metrics";
+        final RecordAccumulator accum = new RecordAccumulator(logContext, batchSize,
+            CompressionType.NONE, lingerMs, retryBackoffMs, retryBackoffMaxMs,
+            deliveryTimeoutMs, metrics, metricGrpName, time, new ApiVersions(), null,
+            new BufferPool(totalSize, batchSize, metrics, time, metricGrpName));
+
+        // Create 1 batch(batchA) to be produced to partition1.
+        long now = time.milliseconds();
+        accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, now, cluster);
+
+        // 1st attempt(not a retry) to produce batchA, it should be ready & drained to be produced.
+        {
+            now += lingerMs + 1;
+            RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, now);
+            assertTrue(result.readyNodes.contains(node1), "Node1 is ready");
+
+            Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock,
+                result.readyNodes, 999999 /* maxSize */, now);
+            assertTrue(batches.containsKey(node1.id()) && batches.get(node1.id()).size() == 1, "Node1 has 1 batch ready & drained");
+            ProducerBatch batch = batches.get(node1.id()).get(0);
+            assertEquals(Optional.of(part1LeaderEpoch), batch.currentLeaderEpoch());
+            assertEquals(0, batch.attemptsWhenLeaderLastChanged());
+            // Re-enqueue batch for subsequent retries & test-cases
+            accum.reenqueue(batch, now);
+        }
+
+        // In this retry of batchA, wait-time between retries is less than configured and no leader change, so should backoff.
+        {
+            now += 1;
+            RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, now);
+            assertFalse(result.readyNodes.contains(node1), "Node1 is not ready");
+
+            // Try to drain from node1, it should return no batches.
+            Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock,
+                new HashSet<>(Arrays.asList(node1)), 999999 /* maxSize */, now);
+            assertTrue(batches.containsKey(node1.id()) && batches.get(node1.id()).isEmpty(),
+                "No batches ready to be drained on Node1");
+        }
+
+        // In this retry of batchA, wait-time between retries is less than configured and leader has changed, so should not backoff.
+        {
+            now += 1;
+            part1LeaderEpoch++;
+            // Create cluster metadata, with new leader epoch.
+            part1 = new PartitionInfo(topic, partition1, node1, null, null, null);
+            cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1),
+                Collections.emptySet(), Collections.emptySet());
+            final int finalPart1LeaderEpoch = part1LeaderEpoch;
+            metadataMock = setupMetadata(cluster, tp -> finalPart1LeaderEpoch);
+            RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, now);
+            assertTrue(result.readyNodes.contains(node1), "Node1 is ready");
+
+            Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock,
+                result.readyNodes, 999999 /* maxSize */, now);
+            assertTrue(batches.containsKey(node1.id()) && batches.get(node1.id()).size() == 1, "Node1 has 1 batch ready & drained");
+            ProducerBatch batch = batches.get(node1.id()).get(0);
+            assertEquals(Optional.of(part1LeaderEpoch), batch.currentLeaderEpoch());
+            assertEquals(1, batch.attemptsWhenLeaderLastChanged());
+
+            // Re-enqueue batch for subsequent retries/test-cases.
+            accum.reenqueue(batch, now);
+        }
+
+        // In this retry of batchA, wait-time between retries is more than configured and no leader change, so should not backoff.
+        {
+            now += 2 * retryBackoffMaxMs;
+            // Re-create cluster metadata, leader epoch is unchanged from the previous attempt.
+            part1 = new PartitionInfo(topic, partition1, node1, null, null, null);
+            cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1),
+                Collections.emptySet(), Collections.emptySet());
+            final int finalPart1LeaderEpoch = part1LeaderEpoch;
+            metadataMock = setupMetadata(cluster, tp -> finalPart1LeaderEpoch);
+            RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, now);
+            assertTrue(result.readyNodes.contains(node1), "Node1 is ready");
+
+            Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock,
+                result.readyNodes, 999999 /* maxSize */, now);
+            assertTrue(batches.containsKey(node1.id()) && batches.get(node1.id()).size() == 1, "Node1 has 1 batch ready & drained");
+            ProducerBatch batch = batches.get(node1.id()).get(0);
+            assertEquals(Optional.of(part1LeaderEpoch), batch.currentLeaderEpoch());
+            assertEquals(1, batch.attemptsWhenLeaderLastChanged());
+
+            // Re-enqueue batch for subsequent retries/test-cases.
+            accum.reenqueue(batch, now);
+        }
+
+        // In this retry of batchA, wait-time between retries is more than configured and leader has changed, so should not backoff.
+        {
+            now += 2 * retryBackoffMaxMs;
+            part1LeaderEpoch++;
+            // Create cluster metadata, with new leader epoch.
+            part1 = new PartitionInfo(topic, partition1, node1, null, null, null);
+            cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1),
+                Collections.emptySet(), Collections.emptySet());
+            final int finalPart1LeaderEpoch = part1LeaderEpoch;
+            metadataMock = setupMetadata(cluster, tp -> finalPart1LeaderEpoch);
+            RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, now);
+            assertTrue(result.readyNodes.contains(node1), "Node1 is ready");
+
+            Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock,
+                result.readyNodes, 999999 /* maxSize */, now);
+            assertTrue(batches.containsKey(node1.id()) && batches.get(node1.id()).size() == 1, "Node1 has 1 batch ready & drained");
+            ProducerBatch batch = batches.get(node1.id()).get(0);
+            assertEquals(Optional.of(part1LeaderEpoch), batch.currentLeaderEpoch());
+            assertEquals(3, batch.attemptsWhenLeaderLastChanged());
+
+            // Re-enqueue batch for subsequent retries/test-cases.
+            accum.reenqueue(batch, now);
+        }
+    }
+
+    @Test
+    public void testDrainWithANodeThatDoesntHostAnyPartitions() {
+        int batchSize = 10;
+        int lingerMs = 10;
+        long totalSize = 10 * 1024;
+        RecordAccumulator accum = createTestRecordAccumulator(batchSize, totalSize,
+            CompressionType.NONE, lingerMs);
+
+        // Create cluster metadata, node2 doesn't host any partitions.
+        part1 = new PartitionInfo(topic, partition1, node1, null, null, null);
+        cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1),
+            Collections.emptySet(), Collections.emptySet());
+        metadataMock = Mockito.mock(Metadata.class);
+        Mockito.when(metadataMock.fetch()).thenReturn(cluster);
+        Mockito.when(metadataMock.currentLeader(tp1)).thenReturn(
+            new Metadata.LeaderAndEpoch(Optional.of(node1),
+                Optional.of(999 /* dummy value */)));
+
+        // Drain for node2, it should return 0 batches.
+        Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock,
+            new HashSet<>(Arrays.asList(node2)), 999999 /* maxSize */, time.milliseconds());
+        assertTrue(batches.get(node2.id()).isEmpty());
+    }
+
+    @Test
+    public void testDrainOnANodeWhenItCeasesToBeALeader() throws InterruptedException {
+        int batchSize = 10;
+        int lingerMs = 10;
+        long totalSize = 10 * 1024;
+        RecordAccumulator accum = createTestRecordAccumulator(batchSize, totalSize,
+            CompressionType.NONE, lingerMs);
+
+        // While node1 is being drained, leader changes from node1 -> node2 for a partition.
+        {
+            // Create cluster metadata, partition1&2 being hosted by node1&2 resp.
+            part1 = new PartitionInfo(topic, partition1, node1, null, null, null);
+            part2 = new PartitionInfo(topic, partition2, node2, null, null, null);
+            cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1, part2),
+                Collections.emptySet(), Collections.emptySet());
+            metadataMock = Mockito.mock(Metadata.class);
+            Mockito.when(metadataMock.fetch()).thenReturn(cluster);
+            // But metadata has a newer leader for partition1 i.e node2.
+            Mockito.when(metadataMock.currentLeader(tp1)).thenReturn(
+                new Metadata.LeaderAndEpoch(Optional.of(node2),
+                    Optional.of(999 /* dummy value */)));
+            Mockito.when(metadataMock.currentLeader(tp2)).thenReturn(
+                new Metadata.LeaderAndEpoch(Optional.of(node2),
+                    Optional.of(999 /* dummy value */)));
+
+            // Create 1 batch each for partition1 & partition2.
+            long now = time.milliseconds();
+            accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null,
+                maxBlockTimeMs, false, now, cluster);
+            accum.append(topic, partition2, 0L, key, value, Record.EMPTY_HEADERS, null,
+                maxBlockTimeMs, false, now, cluster);
+
+            // Drain for node1, it should return 0 batches, as partition1's leader in metadata changed.
+            // Drain for node2, it should return 1 batch, for partition2.
+            Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock,
+                new HashSet<>(Arrays.asList(node1, node2)), 999999 /* maxSize */, now);
+            assertTrue(batches.get(node1.id()).isEmpty());
+            assertEquals(1, batches.get(node2.id()).size());
+        }
+
+        // Cleanup un-drained batches to have an empty accum before next test.
+        accum.abortUndrainedBatches(new RuntimeException());
+
+        // While node1 is being drained, leader changes from node1 -> "no-leader" for partition.
+        {
+            // Create cluster metadata, partition1&2 being hosted by node1&2 resp.
+            part1 = new PartitionInfo(topic, partition1, node1, null, null, null);
+            part2 = new PartitionInfo(topic, partition2, node2, null, null, null);
+            cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1, part2),
+                Collections.emptySet(), Collections.emptySet());
+            metadataMock = Mockito.mock(Metadata.class);
+            Mockito.when(metadataMock.fetch()).thenReturn(cluster);
+            // But metadata no longer has a leader for partition1.
+            Mockito.when(metadataMock.currentLeader(tp1)).thenReturn(
+                new Metadata.LeaderAndEpoch(Optional.empty(),
+                    Optional.of(999 /* dummy value */)));
+            Mockito.when(metadataMock.currentLeader(tp2)).thenReturn(
+                new Metadata.LeaderAndEpoch(Optional.of(node2),
+                    Optional.of(999 /* dummy value */)));
+
+            // Create 1 batch each for partition1 & partition2.
+            long now = time.milliseconds();
+            accum.append(topic, partition1, 0L, key, value, Record.EMPTY_HEADERS, null,
+                maxBlockTimeMs, false, now, cluster);
+            accum.append(topic, partition2, 0L, key, value, Record.EMPTY_HEADERS, null,
+                maxBlockTimeMs, false, now, cluster);
+
+            // Drain for node1, it should return 0 batches, as partition1 no longer has a leader in metadata.
+            // Drain for node2, it should return 1 batch, for partition2.
+            Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock,
+                new HashSet<>(Arrays.asList(node1, node2)), 999999 /* maxSize */, now);
+            assertTrue(batches.get(node1.id()).isEmpty());
+            assertEquals(1, batches.get(node2.id()).size());
+        }
+    }
+
     private int prepareSplitBatches(RecordAccumulator accum, long seed, int recordSize, int numRecords)
         throws InterruptedException {
         Random random = new Random();
@@ -1420,9 +1667,9 @@ public class RecordAccumulatorTest {
             accum.append(topic, partition1, 0L, null, bytesWithPoorCompression(random, recordSize), Record.EMPTY_HEADERS, null, 0, false, time.milliseconds(), cluster);
         }
 
-        RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds());
+        RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, time.milliseconds());
         assertFalse(result.readyNodes.isEmpty());
-        Map<Integer, List<ProducerBatch>> batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
         assertEquals(1, batches.size());
         assertEquals(1, batches.values().iterator().next().size());
         ProducerBatch batch = batches.values().iterator().next().get(0);
@@ -1438,8 +1685,8 @@ public class RecordAccumulatorTest {
         boolean batchDrained;
         do {
             batchDrained = false;
-            RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds());
-            Map<Integer, List<ProducerBatch>> batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
+            RecordAccumulator.ReadyCheckResult result = accum.ready(metadataMock, time.milliseconds());
+            Map<Integer, List<ProducerBatch>> batches = accum.drain(metadataMock, result.readyNodes, Integer.MAX_VALUE, time.milliseconds());
             for (List<ProducerBatch> batchList : batches.values()) {
                 for (ProducerBatch batch : batchList) {
                     batchDrained = true;
@@ -1558,4 +1805,27 @@ public class RecordAccumulatorTest {
             txnManager,
             new BufferPool(totalSize, batchSize, metrics, time, metricGrpName));
     }
+
+    /**
+     * Set up a mocked metadata object that reports a dummy leader epoch for every partition in the cluster.
+     */
+    private Metadata setupMetadata(Cluster cluster) {
+        return setupMetadata(cluster, tp -> 999 /* dummy epoch */);
+    }
+
+    /**
+     * Set up a mocked metadata object where each partition's leader comes from the cluster and its leader epoch from epochSupplier.
+     */
+    private Metadata setupMetadata(Cluster cluster, final Function<TopicPartition, Integer> epochSupplier) {
+        Metadata metadataMock = Mockito.mock(Metadata.class);
+        Mockito.when(metadataMock.fetch()).thenReturn(cluster);
+        for (String topic: cluster.topics()) {
+            for (PartitionInfo partInfo: cluster.partitionsForTopic(topic)) {
+                TopicPartition tp = new TopicPartition(partInfo.topic(), partInfo.partition());
+                Integer partLeaderEpoch = epochSupplier.apply(tp);
+                Mockito.when(metadataMock.currentLeader(tp)).thenReturn(new Metadata.LeaderAndEpoch(Optional.of(partInfo.leader()), Optional.of(partLeaderEpoch)));
+            }
+        }
+        return metadataMock;
+    }
 }
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
index 860b969f469..ba625d74408 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
@@ -497,7 +497,7 @@ public class SenderTest {
 
         Node clusterNode = metadata.fetch().nodes().get(0);
         Map<Integer, List<ProducerBatch>> drainedBatches =
-            accumulator.drain(metadata.fetch(), Collections.singleton(clusterNode), Integer.MAX_VALUE, time.milliseconds());
+            accumulator.drain(metadata, Collections.singleton(clusterNode), Integer.MAX_VALUE, time.milliseconds());
         sender.addToInflightBatches(drainedBatches);
 
         // Disconnect the target node for the pending produce request. This will ensure that sender will try to
@@ -3146,6 +3146,93 @@ public class SenderTest {
 
         txnManager.beginTransaction();
     }
+    @Test
+    public void testProducerBatchRetriesWhenPartitionLeaderChanges() throws Exception {
+        Metrics m = new Metrics();
+        SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m);
+        try {
+            // SETUP
+            String metricGrpName = "producer-metrics-test-stats-1";
+            long totalSize = 1024 * 1024;
+            BufferPool pool = new BufferPool(totalSize, batchSize, metrics, time,
+                metricGrpName);
+            long retryBackoffMaxMs = 100L;
+            // lingerMs is 0 to send batch as soon as any records are available on it.
+            this.accumulator = new RecordAccumulator(logContext, batchSize,
+                CompressionType.NONE, 0, 10L, retryBackoffMaxMs,
+                DELIVERY_TIMEOUT_MS, metrics, metricGrpName, time, apiVersions, null, pool);
+            Sender sender = new Sender(logContext, client, metadata, this.accumulator, false,
+                MAX_REQUEST_SIZE, ACKS_ALL,
+                10, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, null,
+                apiVersions);
+            // Update metadata with leader-epochs.
+            int tp0LeaderEpoch = 100;
+            int epoch = tp0LeaderEpoch;
+            this.client.updateMetadata(
+                RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("test", 2),
+                    tp -> {
+                        if (tp0.equals(tp)) {
+                            return epoch;
+                        }  else if (tp1.equals(tp)) {
+                            return 0;
+                        } else {
+                            throw new RuntimeException("unexpected tp " + tp);
+                        }
+                    }));
+
+            // Produce batch, it returns with a retry-able error like NOT_LEADER_OR_FOLLOWER, scheduled for retry.
+            Future<RecordMetadata> futureIsProduced = appendToAccumulator(tp0, 0L, "key", "value");
+            sender.runOnce(); // connect
+            sender.runOnce(); // send produce request
+            assertEquals(1, client.inFlightRequestCount(),
+                "We should have a single produce request in flight.");
+            assertEquals(1, sender.inFlightBatches(tp0).size());
+            assertTrue(client.hasInFlightRequests());
+            client.respond(produceResponse(tp0, -1, Errors.NOT_LEADER_OR_FOLLOWER, 0));
+            sender.runOnce(); // receive produce response, batch scheduled for retry
+            assertTrue(!futureIsProduced.isDone(), "Produce request is yet not done.");
+
+            // TEST that as new-leader(with epochA) is discovered, the batch is retried immediately i.e. skips any backoff period.
+            // Update leader epoch for tp0
+            int newEpoch = ++tp0LeaderEpoch;
+            this.client.updateMetadata(
+                RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("test", 2),
+                    tp -> {
+                        if (tp0.equals(tp)) {
+                            return newEpoch;
+                        } else if (tp1.equals(tp)) {
+                            return 0;
+                        } else {
+                            throw new RuntimeException("unexpected tp " + tp);
+                        }
+                    }));
+            sender.runOnce(); // send produce request, immediately.
+            assertEquals(1, sender.inFlightBatches(tp0).size());
+            assertTrue(client.hasInFlightRequests());
+            client.respond(produceResponse(tp0, -1, Errors.NOT_LEADER_OR_FOLLOWER, 0));
+            sender.runOnce(); // receive produce response, schedule batch for retry.
+            assertTrue(!futureIsProduced.isDone(), "Produce request is yet not done.");
+
+            // TEST that a subsequent retry to the same leader(epochA) waits the backoff period.
+            sender.runOnce(); // batch is still backing off, so no produce request is sent
+            // No batches in-flight
+            assertEquals(0, sender.inFlightBatches(tp0).size());
+            assertTrue(!client.hasInFlightRequests());
+
+            // TEST that after waiting for longer than backoff period, batch is retried again.
+            time.sleep(2 * retryBackoffMaxMs);
+            sender.runOnce(); // send produce request
+            assertEquals(1, sender.inFlightBatches(tp0).size());
+            assertTrue(client.hasInFlightRequests());
+            long offset = 999;
+            client.respond(produceResponse(tp0, offset, Errors.NONE, 0));
+            sender.runOnce(); // receive response.
+            assertTrue(futureIsProduced.isDone(), "Request to tp0 successfully done");
+            assertEquals(offset, futureIsProduced.get().offset());
+        } finally {
+            m.close();
+        }
+    }
 
     private void verifyErrorMessage(ProduceResponse response, String expectedMessage) throws Exception {
         Future<RecordMetadata> future = appendToAccumulator(tp0, 0L, "key", "value");
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
index e4ca0630c33..feb764228cb 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.clients.producer.internals;
 
 import org.apache.kafka.clients.ApiVersions;
+import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.NodeApiVersions;
 import org.apache.kafka.clients.consumer.CommitFailedException;
@@ -99,6 +100,7 @@ import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Supplier;
+import org.mockito.Mockito;
 
 import static java.util.Collections.singleton;
 import static java.util.Collections.singletonList;
@@ -2474,10 +2476,11 @@ public class TransactionManagerTest {
 
         Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1, part2),
                 Collections.emptySet(), Collections.emptySet());
+        Metadata metadataMock = setupMetadata(cluster);
         Set<Node> nodes = new HashSet<>();
         nodes.add(node1);
         nodes.add(node2);
-        Map<Integer, List<ProducerBatch>> drainedBatches = accumulator.drain(cluster, nodes, Integer.MAX_VALUE,
+        Map<Integer, List<ProducerBatch>> drainedBatches = accumulator.drain(metadataMock, nodes, Integer.MAX_VALUE,
                 time.milliseconds());
 
         // We shouldn't drain batches which haven't been added to the transaction yet.
@@ -2506,8 +2509,9 @@ public class TransactionManagerTest {
         PartitionInfo part1 = new PartitionInfo(topic, 1, node1, null, null);
         Cluster cluster = new Cluster(null, Collections.singletonList(node1), Collections.singletonList(part1),
                 Collections.emptySet(), Collections.emptySet());
+        Metadata metadataMock = setupMetadata(cluster);
         appendToAccumulator(tp1);
-        Map<Integer, List<ProducerBatch>> drainedBatches = accumulator.drain(cluster, Collections.singleton(node1),
+        Map<Integer, List<ProducerBatch>> drainedBatches = accumulator.drain(metadataMock, Collections.singleton(node1),
                 Integer.MAX_VALUE,
                 time.milliseconds());
 
@@ -2529,9 +2533,11 @@ public class TransactionManagerTest {
 
         Cluster cluster = new Cluster(null, Collections.singletonList(node1), Collections.singletonList(part1),
                 Collections.emptySet(), Collections.emptySet());
+        Metadata metadataMock = setupMetadata(cluster);
+
         Set<Node> nodes = new HashSet<>();
         nodes.add(node1);
-        Map<Integer, List<ProducerBatch>> drainedBatches = accumulator.drain(cluster, nodes, Integer.MAX_VALUE,
+        Map<Integer, List<ProducerBatch>> drainedBatches = accumulator.drain(metadataMock, nodes, Integer.MAX_VALUE,
                 time.milliseconds());
 
         // We shouldn't drain batches which haven't been added to the transaction yet.
@@ -3833,4 +3839,16 @@ public class TransactionManagerTest {
         ProducerTestUtils.runUntil(sender, condition);
     }
 
+    private Metadata setupMetadata(Cluster cluster) {
+        Metadata metadataMock = Mockito.mock(Metadata.class);
+        Mockito.when(metadataMock.fetch()).thenReturn(cluster);
+        for (String topic: cluster.topics()) {
+            for (PartitionInfo partInfo: cluster.partitionsForTopic(topic)) {
+                TopicPartition tp = new TopicPartition(partInfo.topic(), partInfo.partition());
+                Mockito.when(metadataMock.currentLeader(tp)).thenReturn(new Metadata.LeaderAndEpoch(Optional.of(partInfo.leader()), Optional.of(999 /* dummy value */)));
+            }
+        }
+        return metadataMock;
+    }
+
 }