You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2019/06/12 19:54:33 UTC

[kafka] branch trunk updated: KAFKA-8483/KAFKA-8484; Ensure safe handling of producerId resets (#6883)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new af28010  KAFKA-8483/KAFKA-8484; Ensure safe handling of producerId resets (#6883)
af28010 is described below

commit af2801031c455f86fb771d9f36c313f99459cd1a
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Wed Jun 12 12:54:17 2019 -0700

    KAFKA-8483/KAFKA-8484; Ensure safe handling of producerId resets (#6883)
    
    The idempotent producer attempts to detect spurious UNKNOWN_PRODUCER_ID errors and handle them by reassigning sequence numbers to the inflight batches. The inflight batches are tracked in a PriorityQueue. The problem is that the reassignment of sequence numbers depends on the iteration order of PriorityQueue, which does not guarantee any ordering. So this can result in sequence numbers being assigned in the wrong order. This patch fixes the problem by using a sorted set instead of a priority queue.
    
    This patch also fixes KAFKA-8484, which can cause an IllegalStateException when the producerId is reset while there are pending produce requests inflight. The solution is to ensure that sequence numbers are only reset if the producerId of a failed batch corresponds to the current producerId.
    
    Reviewers: Guozhang Wang <wa...@gmail.com>
---
 .../producer/internals/RecordAccumulator.java      |   2 +-
 .../kafka/clients/producer/internals/Sender.java   |  53 +---
 .../producer/internals/TransactionManager.java     | 152 ++++++++---
 .../producer/internals/TransactionManagerTest.java | 288 +++++++++++++++++++--
 4 files changed, 399 insertions(+), 96 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
index e6b29f3..91fb8c9 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java
@@ -558,7 +558,7 @@ public final class RecordAccumulator {
                     if (shouldStopDrainBatchesForPartition(first, tp))
                         break;
 
-                    boolean isTransactional = transactionManager != null ? transactionManager.isTransactional() : false;
+                    boolean isTransactional = transactionManager != null && transactionManager.isTransactional();
                     ProducerIdAndEpoch producerIdAndEpoch =
                         transactionManager != null ? transactionManager.producerIdAndEpoch() : null;
                     ProducerBatch batch = deque.pollFirst();
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
index cd1712a..121ddb2 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.clients.producer.internals;
 
-import java.util.ArrayList;
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.ClientRequest;
 import org.apache.kafka.clients.ClientResponse;
@@ -33,11 +32,9 @@ import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.errors.ClusterAuthorizationException;
 import org.apache.kafka.common.errors.InvalidMetadataException;
 import org.apache.kafka.common.errors.OutOfOrderSequenceException;
-import org.apache.kafka.common.errors.ProducerFencedException;
 import org.apache.kafka.common.errors.RetriableException;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.errors.TopicAuthorizationException;
-import org.apache.kafka.common.errors.TransactionalIdAuthorizationException;
 import org.apache.kafka.common.errors.UnknownTopicOrPartitionException;
 import org.apache.kafka.common.errors.UnsupportedVersionException;
 import org.apache.kafka.common.message.InitProducerIdRequestData;
@@ -60,6 +57,7 @@ import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
@@ -295,9 +293,7 @@ public class Sender implements Runnable {
     void runOnce() {
         if (transactionManager != null) {
             try {
-                if (transactionManager.shouldResetProducerStateAfterResolvingSequences())
-                    // Check if the previous run expired batches which requires a reset of the producer state.
-                    transactionManager.resetProducerId();
+                transactionManager.resetProducerIdIfNeeded();
 
                 if (!transactionManager.isTransactional()) {
                     // this is an idempotent producer, so make sure we have a producer id
@@ -694,16 +690,7 @@ public class Sender implements Runnable {
 
     private void completeBatch(ProducerBatch batch, ProduceResponse.PartitionResponse response) {
         if (transactionManager != null) {
-            if (transactionManager.hasProducerIdAndEpoch(batch.producerId(), batch.producerEpoch())) {
-                transactionManager
-                    .maybeUpdateLastAckedSequence(batch.topicPartition, batch.baseSequence() + batch.recordCount - 1);
-                log.debug("ProducerId: {}; Set last ack'd sequence number for topic-partition {} to {}",
-                    batch.producerId(),
-                    batch.topicPartition,
-                    transactionManager.lastAckedSequence(batch.topicPartition).orElse(-1));
-            }
-            transactionManager.updateLastAckedOffset(response, batch);
-            transactionManager.removeInFlightBatch(batch);
+            transactionManager.handleCompletedBatch(batch, response);
         }
 
         if (batch.done(response.baseOffset, response.logAppendTime, null)) {
@@ -712,36 +699,20 @@ public class Sender implements Runnable {
         }
     }
 
-    private void failBatch(ProducerBatch batch, ProduceResponse.PartitionResponse response, RuntimeException exception,
+    private void failBatch(ProducerBatch batch,
+                           ProduceResponse.PartitionResponse response,
+                           RuntimeException exception,
                            boolean adjustSequenceNumbers) {
         failBatch(batch, response.baseOffset, response.logAppendTime, exception, adjustSequenceNumbers);
     }
 
-    private void failBatch(ProducerBatch batch, long baseOffset, long logAppendTime, RuntimeException exception,
-        boolean adjustSequenceNumbers) {
+    private void failBatch(ProducerBatch batch,
+                           long baseOffset,
+                           long logAppendTime,
+                           RuntimeException exception,
+                           boolean adjustSequenceNumbers) {
         if (transactionManager != null) {
-            if (exception instanceof OutOfOrderSequenceException
-                    && !transactionManager.isTransactional()
-                    && transactionManager.hasProducerId(batch.producerId())) {
-                log.error("The broker returned {} for topic-partition " +
-                            "{} at offset {}. This indicates data loss on the broker, and should be investigated.",
-                        exception, batch.topicPartition, baseOffset);
-
-                // Reset the transaction state since we have hit an irrecoverable exception and cannot make any guarantees
-                // about the previously committed message. Note that this will discard the producer id and sequence
-                // numbers for all existing partitions.
-                transactionManager.resetProducerId();
-            } else if (exception instanceof ClusterAuthorizationException
-                    || exception instanceof TransactionalIdAuthorizationException
-                    || exception instanceof ProducerFencedException
-                    || exception instanceof UnsupportedVersionException) {
-                transactionManager.transitionToFatalError(exception);
-            } else if (transactionManager.isTransactional()) {
-                transactionManager.transitionToAbortableError(exception);
-            }
-            transactionManager.removeInFlightBatch(batch);
-            if (adjustSequenceNumbers)
-                transactionManager.adjustSequencesDueToFailedBatch(batch);
+            transactionManager.handleFailedBatch(batch, exception, adjustSequenceNumbers);
         }
 
         this.sensors.recordErrors(batch.topicPartition.topic(), batch.recordCount);
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
index 9ed0dde..182c92c 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
@@ -23,9 +23,13 @@ import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.AuthenticationException;
+import org.apache.kafka.common.errors.ClusterAuthorizationException;
 import org.apache.kafka.common.errors.GroupAuthorizationException;
+import org.apache.kafka.common.errors.OutOfOrderSequenceException;
 import org.apache.kafka.common.errors.ProducerFencedException;
 import org.apache.kafka.common.errors.TopicAuthorizationException;
+import org.apache.kafka.common.errors.TransactionalIdAuthorizationException;
+import org.apache.kafka.common.errors.UnsupportedVersionException;
 import org.apache.kafka.common.message.FindCoordinatorRequestData;
 import org.apache.kafka.common.message.InitProducerIdRequestData;
 import org.apache.kafka.common.protocol.Errors;
@@ -62,6 +66,10 @@ import java.util.OptionalInt;
 import java.util.OptionalLong;
 import java.util.PriorityQueue;
 import java.util.Set;
+import java.util.SortedSet;
+import java.util.TreeSet;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Consumer;
 import java.util.function.Supplier;
 
 import static org.apache.kafka.common.record.RecordBatch.NO_PRODUCER_EPOCH;
@@ -133,7 +141,7 @@ public class TransactionManager {
         // we continue to order batches by the sequence numbers even when the responses come back out of order during
         // leader failover. We add a batch to the queue when it is drained, and remove it when the batch completes
         // (either successfully or through a fatal failure).
-        private PriorityQueue<ProducerBatch> inflightBatchesBySequence;
+        private SortedSet<ProducerBatch> inflightBatchesBySequence;
 
         // We keep track of the last acknowledged offset on a per partition basis in order to disambiguate UnknownProducer
         // responses which are due to the retention period elapsing, and those which are due to actual lost data.
@@ -143,8 +151,18 @@ public class TransactionManager {
             this.nextSequence = 0;
             this.lastAckedSequence = NO_LAST_ACKED_SEQUENCE_NUMBER;
             this.lastAckedOffset = ProduceResponse.INVALID_OFFSET;
-            this.inflightBatchesBySequence = new PriorityQueue<>(5, Comparator.comparingInt(ProducerBatch::baseSequence));
+            this.inflightBatchesBySequence = new TreeSet<>(Comparator.comparingInt(ProducerBatch::baseSequence));
         }
+
+        public void resetSequenceNumbers(Consumer<ProducerBatch> resetSequence) {
+            TreeSet<ProducerBatch> newInflights = new TreeSet<>(Comparator.comparingInt(ProducerBatch::baseSequence));
+            for (ProducerBatch inflightBatch : inflightBatchesBySequence) {
+                resetSequence.accept(inflightBatch);
+                newInflights.add(inflightBatch);
+            }
+            inflightBatchesBySequence = newInflights;
+        }
+
     }
 
     private final TopicPartitionBookkeeper topicPartitionBookkeeper;
@@ -459,6 +477,12 @@ public class TransactionManager {
         this.partitionsWithUnresolvedSequences.clear();
     }
 
+    synchronized void resetProducerIdIfNeeded() {
+        if (shouldResetProducerStateAfterResolvingSequences())
+            // Check if the previous run expired batches which requires a reset of the producer state.
+            resetProducerId();
+    }
+
     /**
      * Returns the next sequence number to be written to the given TopicPartition.
      */
@@ -479,7 +503,7 @@ public class TransactionManager {
     synchronized void addInFlightBatch(ProducerBatch batch) {
         if (!batch.hasSequence())
             throw new IllegalStateException("Can't track batch for partition " + batch.topicPartition + " when sequence is not set.");
-        topicPartitionBookkeeper.getPartition(batch.topicPartition).inflightBatchesBySequence.offer(batch);
+        topicPartitionBookkeeper.getPartition(batch.topicPartition).inflightBatchesBySequence.add(batch);
     }
 
     /**
@@ -493,26 +517,25 @@ public class TransactionManager {
         if (!hasInflightBatches(topicPartition))
             return RecordBatch.NO_SEQUENCE;
 
-        ProducerBatch first = topicPartitionBookkeeper.getPartition(topicPartition).inflightBatchesBySequence.peek();
-        if (first == null)
+        SortedSet<ProducerBatch> inflightBatches = topicPartitionBookkeeper.getPartition(topicPartition).inflightBatchesBySequence;
+        if (inflightBatches.isEmpty())
             return RecordBatch.NO_SEQUENCE;
-
-        return first.baseSequence();
+        else
+            return inflightBatches.first().baseSequence();
     }
 
     synchronized ProducerBatch nextBatchBySequence(TopicPartition topicPartition) {
-        PriorityQueue<ProducerBatch> queue = topicPartitionBookkeeper.getPartition(topicPartition).inflightBatchesBySequence;
-        return queue.peek();
+        SortedSet<ProducerBatch> queue = topicPartitionBookkeeper.getPartition(topicPartition).inflightBatchesBySequence;
+        return queue.isEmpty() ? null : queue.first();
     }
 
     synchronized void removeInFlightBatch(ProducerBatch batch) {
         if (hasInflightBatches(batch.topicPartition)) {
-            PriorityQueue<ProducerBatch> queue = topicPartitionBookkeeper.getPartition(batch.topicPartition).inflightBatchesBySequence;
-            queue.remove(batch);
+            topicPartitionBookkeeper.getPartition(batch.topicPartition).inflightBatchesBySequence.remove(batch);
         }
     }
 
-    synchronized void maybeUpdateLastAckedSequence(TopicPartition topicPartition, int sequence) {
+    private void maybeUpdateLastAckedSequence(TopicPartition topicPartition, int sequence) {
         if (sequence > lastAckedSequence(topicPartition).orElse(NO_LAST_ACKED_SEQUENCE_NUMBER))
             topicPartitionBookkeeper.getPartition(topicPartition).lastAckedSequence = sequence;
     }
@@ -525,7 +548,7 @@ public class TransactionManager {
         return topicPartitionBookkeeper.lastAckedOffset(topicPartition);
     }
 
-    synchronized void updateLastAckedOffset(ProduceResponse.PartitionResponse response, ProducerBatch batch) {
+    private void updateLastAckedOffset(ProduceResponse.PartitionResponse response, ProducerBatch batch) {
         if (response.baseOffset == ProduceResponse.INVALID_OFFSET)
             return;
         long lastOffset = response.baseOffset + batch.recordCount - 1;
@@ -543,12 +566,66 @@ public class TransactionManager {
         }
     }
 
+    public synchronized void handleCompletedBatch(ProducerBatch batch, ProduceResponse.PartitionResponse response) {
+        if (!hasProducerIdAndEpoch(batch.producerId(), batch.producerEpoch())) {
+            log.debug("Ignoring completed batch {} with producer id {}, epoch {}, and sequence number {} " +
+                            "since the producerId has been reset internally", batch, batch.producerId(),
+                    batch.producerEpoch(), batch.baseSequence());
+            return;
+        }
+
+        maybeUpdateLastAckedSequence(batch.topicPartition, batch.baseSequence() + batch.recordCount - 1);
+        log.debug("ProducerId: {}; Set last ack'd sequence number for topic-partition {} to {}",
+                batch.producerId(),
+                batch.topicPartition,
+                lastAckedSequence(batch.topicPartition).orElse(-1));
+
+        updateLastAckedOffset(response, batch);
+        removeInFlightBatch(batch);
+    }
+
+    private void maybeTransitionToErrorState(RuntimeException exception) {
+        if (exception instanceof ClusterAuthorizationException
+                || exception instanceof TransactionalIdAuthorizationException
+                || exception instanceof ProducerFencedException
+                || exception instanceof UnsupportedVersionException) {
+            transitionToFatalError(exception);
+        } else if (isTransactional()) {
+            transitionToAbortableError(exception);
+        }
+    }
+
+    public synchronized void handleFailedBatch(ProducerBatch batch, RuntimeException exception, boolean adjustSequenceNumbers) {
+        maybeTransitionToErrorState(exception);
+
+        if (!hasProducerIdAndEpoch(batch.producerId(), batch.producerEpoch())) {
+            log.debug("Ignoring failed batch {} with producer id {}, epoch {}, and sequence number {} " +
+                    "since the producerId has been reset internally", batch, batch.producerId(),
+                    batch.producerEpoch(), batch.baseSequence(), exception);
+            return;
+        }
+
+        if (exception instanceof OutOfOrderSequenceException && !isTransactional()) {
+            log.error("The broker returned {} for topic-partition {} with producerId {}, epoch {}, and sequence number {}",
+                    exception, batch.topicPartition, batch.producerId(), batch.producerEpoch(), batch.baseSequence());
+
+            // Reset the producerId since we have hit an irrecoverable exception and cannot make any guarantees
+            // about the previously committed message. Note that this will discard the producer id and sequence
+            // numbers for all existing partitions.
+            resetProducerId();
+        } else {
+            removeInFlightBatch(batch);
+            if (adjustSequenceNumbers)
+                adjustSequencesDueToFailedBatch(batch);
+        }
+    }
+
     // If a batch is failed fatally, the sequence numbers for future batches bound for the partition must be adjusted
     // so that they don't fail with the OutOfOrderSequenceException.
     //
     // This method must only be called when we know that the batch is question has been unequivocally failed by the broker,
     // ie. it has received a confirmed fatal status code like 'Message Too Large' or something similar.
-    synchronized void adjustSequencesDueToFailedBatch(ProducerBatch batch) {
+    private void adjustSequencesDueToFailedBatch(ProducerBatch batch) {
         if (!topicPartitionBookkeeper.contains(batch.topicPartition))
             // Sequence numbers are not being tracked for this partition. This could happen if the producer id was just
             // reset due to a previous OutOfOrderSequenceException.
@@ -558,38 +635,39 @@ public class TransactionManager {
         int currentSequence = sequenceNumber(batch.topicPartition);
         currentSequence -= batch.recordCount;
         if (currentSequence < 0)
-            throw new IllegalStateException("Sequence number for partition " + batch.topicPartition + " is going to become negative : " + currentSequence);
+            throw new IllegalStateException("Sequence number for partition " + batch.topicPartition + " is going to become negative: " + currentSequence);
 
         setNextSequence(batch.topicPartition, currentSequence);
 
-        for (ProducerBatch inFlightBatch : topicPartitionBookkeeper.getPartition(batch.topicPartition).inflightBatchesBySequence) {
+        topicPartitionBookkeeper.getPartition(batch.topicPartition).resetSequenceNumbers(inFlightBatch -> {
             if (inFlightBatch.baseSequence() < batch.baseSequence())
-                continue;
+                return;
+
             int newSequence = inFlightBatch.baseSequence() - batch.recordCount;
             if (newSequence < 0)
                 throw new IllegalStateException("Sequence number for batch with sequence " + inFlightBatch.baseSequence()
-                        + " for partition " + batch.topicPartition + " is going to become negative :" + newSequence);
+                        + " for partition " + batch.topicPartition + " is going to become negative: " + newSequence);
 
             log.info("Resetting sequence number of batch with current sequence {} for partition {} to {}", inFlightBatch.baseSequence(), batch.topicPartition, newSequence);
             inFlightBatch.resetProducerState(new ProducerIdAndEpoch(inFlightBatch.producerId(), inFlightBatch.producerEpoch()), newSequence, inFlightBatch.isTransactional());
-        }
+
+        });
     }
 
-    private synchronized void startSequencesAtBeginning(TopicPartition topicPartition) {
-        int sequence = 0;
-        for (ProducerBatch inFlightBatch : topicPartitionBookkeeper.getPartition(topicPartition).inflightBatchesBySequence) {
+    private void startSequencesAtBeginning(TopicPartition topicPartition) {
+        final AtomicInteger sequence = new AtomicInteger(0);
+        topicPartitionBookkeeper.getPartition(topicPartition).resetSequenceNumbers(inFlightBatch -> {
             log.info("Resetting sequence number of batch with current sequence {} for partition {} to {}",
-                    inFlightBatch.baseSequence(), inFlightBatch.topicPartition, sequence);
+                    inFlightBatch.baseSequence(), inFlightBatch.topicPartition, sequence.get());
             inFlightBatch.resetProducerState(new ProducerIdAndEpoch(inFlightBatch.producerId(),
-                    inFlightBatch.producerEpoch()), sequence, inFlightBatch.isTransactional());
-
-            sequence += inFlightBatch.recordCount;
-        }
-        setNextSequence(topicPartition, sequence);
+                    inFlightBatch.producerEpoch()), sequence.get(), inFlightBatch.isTransactional());
+            sequence.getAndAdd(inFlightBatch.recordCount);
+        });
+        setNextSequence(topicPartition, sequence.get());
         topicPartitionBookkeeper.getPartition(topicPartition).lastAckedSequence = NO_LAST_ACKED_SEQUENCE_NUMBER;
     }
 
-    private synchronized boolean hasInflightBatches(TopicPartition topicPartition) {
+    private boolean hasInflightBatches(TopicPartition topicPartition) {
         return topicPartitionBookkeeper.contains(topicPartition)
                 && !topicPartitionBookkeeper.getPartition(topicPartition).inflightBatchesBySequence.isEmpty();
     }
@@ -609,7 +687,7 @@ public class TransactionManager {
 
     // Checks if there are any partitions with unresolved partitions which may now be resolved. Returns true if
     // the producer id needs a reset, false otherwise.
-    synchronized boolean shouldResetProducerStateAfterResolvingSequences() {
+    private boolean shouldResetProducerStateAfterResolvingSequences() {
         if (isTransactional())
             // We should not reset producer state if we are transactional. We will transition to a fatal error instead.
             return false;
@@ -634,11 +712,11 @@ public class TransactionManager {
         return false;
     }
 
-    private synchronized boolean isNextSequence(TopicPartition topicPartition, int sequence) {
+    private boolean isNextSequence(TopicPartition topicPartition, int sequence) {
         return sequence - lastAckedSequence(topicPartition).orElse(NO_LAST_ACKED_SEQUENCE_NUMBER) == 1;
     }
 
-    private synchronized void setNextSequence(TopicPartition topicPartition, int sequence) {
+    private void setNextSequence(TopicPartition topicPartition, int sequence) {
         topicPartitionBookkeeper.getPartition(topicPartition).nextSequence = sequence;
     }
 
@@ -755,7 +833,7 @@ public class TransactionManager {
     }
 
     synchronized boolean canRetry(ProduceResponse.PartitionResponse response, ProducerBatch batch) {
-        if (!hasProducerId(batch.producerId()))
+        if (!hasProducerIdAndEpoch(batch.producerId(), batch.producerEpoch()))
             return false;
 
         Errors error = response.error;
@@ -807,7 +885,7 @@ public class TransactionManager {
         transitionTo(target, null);
     }
 
-    private synchronized void transitionTo(State target, RuntimeException error) {
+    private void transitionTo(State target, RuntimeException error) {
         if (!currentState.isTransitionValid(currentState, target)) {
             String idString = transactionalId == null ?  "" : "TransactionalId " + transactionalId + ": ";
             throw new KafkaException(idString + "Invalid transition attempted from state "
@@ -865,7 +943,7 @@ public class TransactionManager {
         pendingRequests.add(requestHandler);
     }
 
-    private synchronized void lookupCoordinator(FindCoordinatorRequest.CoordinatorType type, String coordinatorKey) {
+    private void lookupCoordinator(FindCoordinatorRequest.CoordinatorType type, String coordinatorKey) {
         switch (type) {
             case GROUP:
                 consumerGroupCoordinator = null;
@@ -884,7 +962,7 @@ public class TransactionManager {
         enqueueRequest(new FindCoordinatorHandler(builder));
     }
 
-    private synchronized void completeTransaction() {
+    private void completeTransaction() {
         transitionTo(State.READY);
         lastError = null;
         transactionStarted = false;
@@ -893,7 +971,7 @@ public class TransactionManager {
         partitionsInTransaction.clear();
     }
 
-    private synchronized TxnRequestHandler addPartitionsToTransactionHandler() {
+    private TxnRequestHandler addPartitionsToTransactionHandler() {
         pendingPartitionsInTransaction.addAll(newPartitionsInTransaction);
         newPartitionsInTransaction.clear();
         AddPartitionsToTxnRequest.Builder builder = new AddPartitionsToTxnRequest.Builder(transactionalId,
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
index eceb8df..03b13d3 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
@@ -33,6 +33,7 @@ import org.apache.kafka.common.errors.TopicAuthorizationException;
 import org.apache.kafka.common.errors.TransactionalIdAuthorizationException;
 import org.apache.kafka.common.errors.UnsupportedForMessageFormatException;
 import org.apache.kafka.common.errors.UnsupportedVersionException;
+import org.apache.kafka.common.header.Header;
 import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.message.InitProducerIdResponseData;
 import org.apache.kafka.common.metrics.MetricConfig;
@@ -40,9 +41,11 @@ import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.record.CompressionType;
 import org.apache.kafka.common.record.MemoryRecords;
+import org.apache.kafka.common.record.MemoryRecordsBuilder;
 import org.apache.kafka.common.record.MutableRecordBatch;
 import org.apache.kafka.common.record.Record;
 import org.apache.kafka.common.record.RecordBatch;
+import org.apache.kafka.common.record.TimestampType;
 import org.apache.kafka.common.requests.AddOffsetsToTxnRequest;
 import org.apache.kafka.common.requests.AddOffsetsToTxnResponse;
 import org.apache.kafka.common.requests.AddPartitionsToTxnRequest;
@@ -65,6 +68,7 @@ import org.apache.kafka.test.TestUtils;
 import org.junit.Before;
 import org.junit.Test;
 
+import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
@@ -73,6 +77,7 @@ import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.OptionalInt;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
@@ -85,8 +90,8 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 public class TransactionManagerTest {
@@ -582,6 +587,164 @@ public class TransactionManagerTest {
     }
 
     @Test
+    public void testResetSequenceNumbersAfterUnknownProducerId() {
+        // Five idempotent batches are inflight on tp0. After the first is acked, the second is
+        // rejected with UNKNOWN_PRODUCER_ID; the remaining four are expected to be renumbered
+        // from zero in the order in which they were originally written.
+        TransactionManager txnManager = new TransactionManager();
+        txnManager.setProducerIdAndEpoch(new ProducerIdAndEpoch(13131L, (short) 1));
+
+        ProducerBatch[] inflight = new ProducerBatch[5];
+        for (int i = 0; i < inflight.length; i++) {
+            inflight[i] = writeIdempotentBatchWithValue(txnManager, tp0, String.valueOf(i + 1));
+        }
+        assertEquals(5, txnManager.sequenceNumber(tp0).intValue());
+
+        // Ack the first batch at offset 500
+        long ackTimeMs = time.milliseconds();
+        ProduceResponse.PartitionResponse ackResponse = new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, ackTimeMs, 0L);
+        inflight[0].done(500L, ackTimeMs, null);
+        txnManager.handleCompletedBatch(inflight[0], ackResponse);
+
+        // Retention moved the log start offset forward (600), so the second batch comes back
+        // with UNKNOWN_PRODUCER_ID. It is retriable and sequence numbers are set back to 0.
+        ProduceResponse.PartitionResponse unknownProducerResponse =
+                new ProduceResponse.PartitionResponse(Errors.UNKNOWN_PRODUCER_ID, -1, -1, 600L);
+        assertTrue(txnManager.canRetry(unknownProducerResponse, inflight[1]));
+        assertEquals(4, txnManager.sequenceNumber(tp0).intValue());
+
+        // Batches 2..5 now carry base sequences 0..3
+        for (int i = 1; i < inflight.length; i++) {
+            assertEquals(i - 1, inflight[i].baseSequence());
+        }
+    }
+
+    @Test
+    public void testAdjustSequenceNumbersAfterFatalError() {
+        // Five idempotent batches are inflight on tp0. The first is acked and the second fails
+        // with an error that cannot be retried. The three batches behind it are expected to
+        // move down by one sequence number and to keep those numbers on later retries.
+        TransactionManager txnManager = new TransactionManager();
+        txnManager.setProducerIdAndEpoch(new ProducerIdAndEpoch(13131L, (short) 1));
+
+        ProducerBatch[] inflight = new ProducerBatch[5];
+        for (int i = 0; i < inflight.length; i++) {
+            inflight[i] = writeIdempotentBatchWithValue(txnManager, tp0, String.valueOf(i + 1));
+        }
+        assertEquals(5, txnManager.sequenceNumber(tp0).intValue());
+
+        // Ack the first batch at offset 500
+        long ackTimeMs = time.milliseconds();
+        ProduceResponse.PartitionResponse ackResponse = new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, ackTimeMs, 0L);
+        inflight[0].done(500L, ackTimeMs, null);
+        txnManager.handleCompletedBatch(inflight[0], ackResponse);
+
+        // MESSAGE_TOO_LARGE on the second batch is fatal for that batch: no retry is allowed
+        ProducerBatch failedBatch = inflight[1];
+        ProduceResponse.PartitionResponse tooLargeResponse = new ProduceResponse.PartitionResponse(
+                Errors.MESSAGE_TOO_LARGE, -1, -1, 0L);
+        assertFalse(txnManager.canRetry(tooLargeResponse, failedBatch));
+
+        // Failing the batch releases its sequence number: the next sequence drops to 4 and
+        // batches 3..5 shift down to base sequences 1..3
+        failedBatch.done(-1L, -1L, Errors.MESSAGE_TOO_LARGE.exception());
+        txnManager.handleFailedBatch(failedBatch, Errors.MESSAGE_TOO_LARGE.exception(), true);
+        assertEquals(4, txnManager.sequenceNumber(tp0).intValue());
+        for (int i = 2; i < inflight.length; i++) {
+            assertEquals(i - 1, inflight[i].baseSequence());
+        }
+
+        // The remaining batches are doomed to fail with OUT_OF_ORDER_SEQUENCE_NUMBER but stay
+        // retriable, and a retry must not move any sequence number again.
+        ProduceResponse.PartitionResponse outOfOrderResponse = new ProduceResponse.PartitionResponse(
+                Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, -1, -1, 0L);
+        assertTrue(txnManager.canRetry(outOfOrderResponse, inflight[2]));
+        assertEquals(4, txnManager.sequenceNumber(tp0).intValue());
+        for (int i = 2; i < inflight.length; i++) {
+            assertEquals(i - 1, inflight[i].baseSequence());
+        }
+    }
+
+    @Test
+    public void testBatchFailureAfterProducerReset() {
+        // A batch written under the old producerId fails after the producerId has been reset.
+        // The late failure must leave the state tracked for the new producerId untouched.
+        final long initialProducerId = 13131L;
+        final short epoch = 1;
+        TransactionManager txnManager = new TransactionManager();
+        txnManager.setProducerIdAndEpoch(new ProducerIdAndEpoch(initialProducerId, epoch));
+
+        ProducerBatch staleBatch = writeIdempotentBatchWithValue(txnManager, tp0, "1");
+
+        // Switch to a new producerId while the first batch is still inflight
+        txnManager.resetProducerId();
+        txnManager.setProducerIdAndEpoch(new ProducerIdAndEpoch(initialProducerId + 1, epoch));
+
+        ProducerBatch currentBatch = writeIdempotentBatchWithValue(txnManager, tp0, "2");
+        assertEquals(1, txnManager.sequenceNumber(tp0).intValue());
+
+        // The stale batch fails with UNKNOWN_PRODUCER_ID and must not be retried
+        ProduceResponse.PartitionResponse staleResponse = new ProduceResponse.PartitionResponse(
+                Errors.UNKNOWN_PRODUCER_ID, -1, -1, 400L);
+        assertFalse(txnManager.canRetry(staleResponse, staleBatch));
+        txnManager.handleFailedBatch(staleBatch, Errors.UNKNOWN_PRODUCER_ID.exception(), true);
+
+        // The batch written under the new producerId is still the next one in sequence
+        assertEquals(1, txnManager.sequenceNumber(tp0).intValue());
+        assertEquals(currentBatch, txnManager.nextBatchBySequence(tp0));
+    }
+
+    @Test
+    public void testBatchCompletedAfterProducerReset() {
+        // A batch written under the old producerId is acked after the producerId has been reset.
+        final long initialProducerId = 13131L;
+        final short epoch = 1;
+        TransactionManager txnManager = new TransactionManager();
+        txnManager.setProducerIdAndEpoch(new ProducerIdAndEpoch(initialProducerId, epoch));
+
+        ProducerBatch staleBatch = writeIdempotentBatchWithValue(txnManager, tp0, "1");
+
+        // A failure on another partition could trigger a producerId reset like this one
+        txnManager.resetProducerId();
+        txnManager.setProducerIdAndEpoch(new ProducerIdAndEpoch(initialProducerId + 1, epoch));
+
+        ProducerBatch currentBatch = writeIdempotentBatchWithValue(txnManager, tp0, "2");
+        assertEquals(1, txnManager.sequenceNumber(tp0).intValue());
+
+        // A successful response for the stale batch must be ignored rather than applied
+        ProduceResponse.PartitionResponse staleResponse = new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, time.milliseconds(), 0L);
+        txnManager.handleCompletedBatch(staleBatch, staleResponse);
+
+        // Sequence tracking for the new producerId is unchanged
+        assertEquals(1, txnManager.sequenceNumber(tp0).intValue());
+        assertEquals(currentBatch, txnManager.nextBatchBySequence(tp0));
+    }
+
+    // Builds a batch for tp via batchWithValue, assigns it the manager's next sequence number
+    // and current producerId/epoch, registers it as inflight and closes it.
+    private ProducerBatch writeIdempotentBatchWithValue(TransactionManager manager, TopicPartition tp, String value) {
+        int baseSequence = manager.sequenceNumber(tp);
+        manager.incrementSequenceNumber(tp, 1);
+        ProducerBatch newBatch = batchWithValue(tp, value);
+        newBatch.setProducerState(manager.producerIdAndEpoch(), baseSequence, false);
+        manager.addInFlightBatch(newBatch);
+        newBatch.close();
+        return newBatch;
+    }
+
+    // Builds an unclosed batch for tp and appends one record carrying the given value.
+    private ProducerBatch batchWithValue(TopicPartition tp, String value) {
+        long nowMs = time.milliseconds();
+        ProducerBatch result = new ProducerBatch(tp, MemoryRecords.builder(ByteBuffer.allocate(64),
+                CompressionType.NONE, TimestampType.CREATE_TIME, 0L), nowMs);
+        result.tryAppend(nowMs, new byte[0], value.getBytes(), new Header[0], null, nowMs);
+        return result;
+    }
+
+    @Test
     public void testSequenceNumberOverflow() {
         TransactionManager transactionManager = new TransactionManager();
         assertEquals((int) transactionManager.sequenceNumber(tp0), 0);
@@ -2272,30 +2435,121 @@ public class TransactionManagerTest {
     }
 
     @Test
-    public void testShouldResetProducerStateAfterResolvingSequences() {
-        // Create a TransactionManager without a transactionalId to test
-        // shouldResetProducerStateAfterResolvingSequences.
+    public void testResetProducerIdAfterWithoutPendingInflightRequests() {
         TransactionManager manager = new TransactionManager(logContext, null, transactionTimeoutMs,
             DEFAULT_RETRY_BACKOFF_MS);
-        assertFalse(manager.shouldResetProducerStateAfterResolvingSequences());
+        long producerId = 15L;
+        short epoch = 5;
+        ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId, epoch);
+        manager.setProducerIdAndEpoch(producerIdAndEpoch);
+
+        // No partition has been marked unresolved yet, so the producerId must be left alone
+        manager.resetProducerIdIfNeeded();
+        assertEquals(producerIdAndEpoch, manager.producerIdAndEpoch());
+
         TopicPartition tp0 = new TopicPartition("foo", 0);
-        TopicPartition tp1 = new TopicPartition("foo", 1);
         assertEquals(Integer.valueOf(0), manager.sequenceNumber(tp0));
-        assertEquals(Integer.valueOf(0), manager.sequenceNumber(tp1));
 
-        manager.incrementSequenceNumber(tp0, 1);
-        manager.incrementSequenceNumber(tp1, 1);
-        manager.maybeUpdateLastAckedSequence(tp0, 0);
-        manager.maybeUpdateLastAckedSequence(tp1, 0);
+        ProducerBatch b1 = writeIdempotentBatchWithValue(manager, tp0, "1");
+        assertEquals(Integer.valueOf(1), manager.sequenceNumber(tp0));
+        manager.handleCompletedBatch(b1, new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, time.milliseconds(), 0L));
+        assertEquals(OptionalInt.of(0), manager.lastAckedSequence(tp0));
+
+        // The only batch has been acked, so marking tp0 unresolved is cleared without a reset.
+        manager.markSequenceUnresolved(tp0);
+        manager.resetProducerIdIfNeeded();
+        assertEquals(producerIdAndEpoch, manager.producerIdAndEpoch());
+        assertFalse(manager.hasUnresolvedSequences());
+
+        // Write a second batch, mark tp0 unresolved and fail the batch with a timeout
+        ProducerBatch b2 = writeIdempotentBatchWithValue(manager, tp0, "2");
+        assertEquals(Integer.valueOf(2), manager.sequenceNumber(tp0));
+        manager.markSequenceUnresolved(tp0);
+        manager.handleFailedBatch(b2, new TimeoutException(), false);
+        assertTrue(manager.hasUnresolvedSequences());
+
+        // The failed batch was the only one inflight, so the unresolved status is cleared
+        // and the producerId is reset
+        manager.resetProducerIdIfNeeded();
+        assertFalse(manager.hasUnresolvedSequences());
+        assertFalse(manager.hasProducerId());
+    }
+
+    @Test
+    public void testNoProducerIdResetAfterLastInFlightBatchSucceeds() {
+        TransactionManager manager = new TransactionManager(logContext, null, transactionTimeoutMs,
+                DEFAULT_RETRY_BACKOFF_MS);
+        long producerId = 15L;
+        short epoch = 5;
+        ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId, epoch);
+        manager.setProducerIdAndEpoch(producerIdAndEpoch);
+
+        TopicPartition tp0 = new TopicPartition("foo", 0);
+        ProducerBatch b1 = writeIdempotentBatchWithValue(manager, tp0, "1");
+        ProducerBatch b2 = writeIdempotentBatchWithValue(manager, tp0, "2");
+        ProducerBatch b3 = writeIdempotentBatchWithValue(manager, tp0, "3");
+        assertEquals(3, manager.sequenceNumber(tp0).intValue());
+
+        // Mark tp0 unresolved and fail the first of the three inflight batches with a timeout
         manager.markSequenceUnresolved(tp0);
-        manager.markSequenceUnresolved(tp1);
-        assertFalse(manager.shouldResetProducerStateAfterResolvingSequences());
+        manager.handleFailedBatch(b1, new TimeoutException(), false);
+        assertTrue(manager.hasUnresolvedSequences());
+
+        // With b2 and b3 still inflight the sequence is not yet resolved, so no reset may occur
+        manager.resetProducerIdIfNeeded();
+        assertEquals(producerIdAndEpoch, manager.producerIdAndEpoch());
+        assertTrue(manager.hasUnresolvedSequences());
+
+        // The second batch also times out; b3 is still inflight, so nothing changes
+        manager.handleFailedBatch(b2, new TimeoutException(), false);
+        manager.resetProducerIdIfNeeded();
+        assertEquals(producerIdAndEpoch, manager.producerIdAndEpoch());
+        assertTrue(manager.hasUnresolvedSequences());
+
+        // The last inflight batch succeeds, which resolves the sequence for tp0 while keeping
+        // the current producerId and the next sequence number (3).
+        manager.handleCompletedBatch(b3, new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, time.milliseconds(), 0L));
+        manager.resetProducerIdIfNeeded();
+        assertEquals(producerIdAndEpoch, manager.producerIdAndEpoch());
+        assertFalse(manager.hasUnresolvedSequences());
+        assertEquals(3, manager.sequenceNumber(tp0).intValue());
+    }
+
+    @Test
+    public void testProducerIdResetAfterLastInFlightBatchFails() {
+        TransactionManager manager = new TransactionManager(logContext, null, transactionTimeoutMs,
+                DEFAULT_RETRY_BACKOFF_MS);
+        long producerId = 15L;
+        short epoch = 5;
+        ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId, epoch);
+        manager.setProducerIdAndEpoch(producerIdAndEpoch);
+
+        TopicPartition tp0 = new TopicPartition("foo", 0);
+        ProducerBatch b1 = writeIdempotentBatchWithValue(manager, tp0, "1");
+        ProducerBatch b2 = writeIdempotentBatchWithValue(manager, tp0, "2");
+        ProducerBatch b3 = writeIdempotentBatchWithValue(manager, tp0, "3");
+        assertEquals(Integer.valueOf(3), manager.sequenceNumber(tp0));
 
-        manager.maybeUpdateLastAckedSequence(tp0, 5);
-        manager.incrementSequenceNumber(tp0, 1);
+        // Mark tp0 unresolved and fail the first of the three inflight batches with a timeout
         manager.markSequenceUnresolved(tp0);
-        manager.markSequenceUnresolved(tp1);
-        assertTrue(manager.shouldResetProducerStateAfterResolvingSequences());
+        manager.handleFailedBatch(b1, new TimeoutException(), false);
+        assertTrue(manager.hasUnresolvedSequences());
+
+        // The second batch succeeds, but with b3 still inflight tp0 stays unresolved (no reset)
+        manager.handleCompletedBatch(b2, new ProduceResponse.PartitionResponse(
+                Errors.NONE, 500L, time.milliseconds(), 0L));
+        manager.resetProducerIdIfNeeded();
+        assertEquals(producerIdAndEpoch, manager.producerIdAndEpoch());
+        assertTrue(manager.hasUnresolvedSequences());
+
+        // The last inflight batch fails, forcing a producerId reset that restarts tp0 at sequence 0
+        manager.handleFailedBatch(b3, new TimeoutException(), false);
+        manager.resetProducerIdIfNeeded();
+        assertFalse(manager.hasProducerId());
+        assertFalse(manager.hasUnresolvedSequences());
+        assertEquals(0, manager.sequenceNumber(tp0).intValue());
     }
 
     @Test