You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2022/01/21 17:36:09 UTC

[kafka] branch 3.1 updated: KAFKA-13412; Ensure initTransactions() safe for retry after timeout (#11452)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 3.1
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.1 by this push:
     new cde0046  KAFKA-13412; Ensure initTransactions() safe for retry after timeout (#11452)
cde0046 is described below

commit cde0046d15a6f5f13e6700ab7e83ef38074ff072
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Wed Jan 19 13:20:41 2022 -0800

    KAFKA-13412; Ensure initTransactions() safe for retry after timeout (#11452)
    
    If the user's `initTransactions` call times out, the user is expected to retry. However, the producer will continue retrying the `InitProducerId` request in the background. If the response happens to arrive before the user retries `initTransactions`, the producer will raise an exception about an invalid state transition.
    
    The patch fixes the issue by tracking the pending state transition until the user has acknowledged the operation's result. In the case of `initTransactions`, even if the `InitProducerId` response arrives in the background and the state changes, the user can still retry the `initTransactions` call to obtain the result.
    
    Reviewers: David Jacot <dj...@confluent.io>
---
 .../kafka/clients/producer/KafkaProducer.java      |   8 +-
 .../producer/internals/TransactionManager.java     | 133 ++++---
 .../internals/TransactionalRequestResult.java      |  34 +-
 .../kafka/common/utils/ProducerIdAndEpoch.java     |   2 +-
 .../kafka/clients/producer/KafkaProducerTest.java  |  44 ++-
 .../clients/producer/internals/SenderTest.java     |  32 +-
 .../producer/internals/TransactionManagerTest.java | 417 +++++++++++----------
 .../kafka/api/AuthorizerIntegrationTest.scala      |  65 ++--
 .../integration/kafka/api/TransactionsTest.scala   |   4 +-
 9 files changed, 421 insertions(+), 318 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
index dbb908d..a7f9389 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java
@@ -964,9 +964,10 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
             // producer callback will make sure to call both 'callback' and interceptor callback
             Callback interceptCallback = new InterceptorCallback<>(callback, this.interceptors, tp);
 
-            if (transactionManager != null && transactionManager.isTransactional()) {
-                transactionManager.failIfNotReadyForSend();
+            if (transactionManager != null) {
+                transactionManager.maybeAddPartition(tp);
             }
+
             RecordAccumulator.RecordAppendResult result = accumulator.append(tp, timestamp, serializedKey,
                     serializedValue, headers, interceptCallback, remainingWaitMs, true, nowMs);
 
@@ -985,9 +986,6 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
                     serializedValue, headers, interceptCallback, remainingWaitMs, false, nowMs);
             }
 
-            if (transactionManager != null && transactionManager.isTransactional())
-                transactionManager.maybeAddPartitionToTransaction(tp);
-
             if (result.batchIsFull || result.newBatchCreated) {
                 log.trace("Waking up the sender since topic {} partition {} is either full or getting a new batch", record.topic(), partition);
                 this.sender.wakeup();
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
index 2de31a0..521d5da 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
@@ -225,7 +225,7 @@ public class TransactionManager {
     private final Set<TopicPartition> newPartitionsInTransaction;
     private final Set<TopicPartition> pendingPartitionsInTransaction;
     private final Set<TopicPartition> partitionsInTransaction;
-    private TransactionalRequestResult pendingResult;
+    private PendingStateTransition pendingTransition;
 
     // This is used by the TxnRequestHandlers to control how long to back off before a given request is retried.
     // For instance, this value is lowered by the AddPartitionsToTxnHandler when it receives a CONCURRENT_TRANSACTIONS
@@ -329,6 +329,8 @@ public class TransactionManager {
     }
 
     synchronized TransactionalRequestResult initializeTransactions(ProducerIdAndEpoch producerIdAndEpoch) {
+        maybeFailWithError();
+
         boolean isEpochBump = producerIdAndEpoch != ProducerIdAndEpoch.NONE;
         return handleCachedTransactionRequestResult(() -> {
             // If this is an epoch bump, we will transition the state as part of handling the EndTxnRequest
@@ -347,11 +349,12 @@ public class TransactionManager {
                     isEpochBump);
             enqueueRequest(handler);
             return handler.result;
-        }, State.INITIALIZING);
+        }, State.INITIALIZING, "initTransactions");
     }
 
     public synchronized void beginTransaction() {
         ensureTransactional();
+        throwIfPendingState("beginTransaction");
         maybeFailWithError();
         transitionTo(State.IN_TRANSACTION);
     }
@@ -361,7 +364,7 @@ public class TransactionManager {
             maybeFailWithError();
             transitionTo(State.COMMITTING_TRANSACTION);
             return beginCompletingTransaction(TransactionResult.COMMIT);
-        }, State.COMMITTING_TRANSACTION);
+        }, State.COMMITTING_TRANSACTION, "commitTransaction");
     }
 
     public synchronized TransactionalRequestResult beginAbort() {
@@ -373,7 +376,7 @@ public class TransactionManager {
             // We're aborting the transaction, so there should be no need to add new partitions
             newPartitionsInTransaction.clear();
             return beginCompletingTransaction(TransactionResult.ABORT);
-        }, State.ABORTING_TRANSACTION);
+        }, State.ABORTING_TRANSACTION, "abortTransaction");
     }
 
     private TransactionalRequestResult beginCompletingTransaction(TransactionResult transactionResult) {
@@ -404,10 +407,13 @@ public class TransactionManager {
     public synchronized TransactionalRequestResult sendOffsetsToTransaction(final Map<TopicPartition, OffsetAndMetadata> offsets,
                                                                             final ConsumerGroupMetadata groupMetadata) {
         ensureTransactional();
+        throwIfPendingState("sendOffsetsToTransaction");
         maybeFailWithError();
-        if (currentState != State.IN_TRANSACTION)
-            throw new KafkaException("Cannot send offsets to transaction either because the producer is not in an " +
-                    "active transaction");
+
+        if (currentState != State.IN_TRANSACTION) {
+            throw new IllegalStateException("Cannot send offsets if a transaction is not in progress " +
+                "(currentState= " + currentState + ")");
+        }
 
         log.debug("Begin adding offsets {} for consumer group {} to transaction", offsets, groupMetadata);
         AddOffsetsToTxnRequest.Builder builder = new AddOffsetsToTxnRequest.Builder(
@@ -423,34 +429,31 @@ public class TransactionManager {
         return handler.result;
     }
 
-    public synchronized void maybeAddPartitionToTransaction(TopicPartition topicPartition) {
-        if (isPartitionAdded(topicPartition) || isPartitionPendingAdd(topicPartition))
-            return;
+    public synchronized void maybeAddPartition(TopicPartition topicPartition) {
+        maybeFailWithError();
+        throwIfPendingState("send");
 
-        log.debug("Begin adding new partition {} to transaction", topicPartition);
-        topicPartitionBookkeeper.addPartition(topicPartition);
-        newPartitionsInTransaction.add(topicPartition);
+        if (isTransactional()) {
+            if (!hasProducerId()) {
+                throw new IllegalStateException("Cannot add partition " + topicPartition +
+                    " to transaction before completing a call to initTransactions");
+            } else if (currentState != State.IN_TRANSACTION) {
+                throw new IllegalStateException("Cannot add partition " + topicPartition +
+                    " to transaction while in state  " + currentState);
+            } else if (isPartitionAdded(topicPartition) || isPartitionPendingAdd(topicPartition)) {
+                return;
+            } else {
+                log.debug("Begin adding new partition {} to transaction", topicPartition);
+                topicPartitionBookkeeper.addPartition(topicPartition);
+                newPartitionsInTransaction.add(topicPartition);
+            }
+        }
     }
 
     RuntimeException lastError() {
         return lastError;
     }
 
-    public synchronized void failIfNotReadyForSend() {
-        if (hasError())
-            throw new KafkaException("Cannot perform send because at least one previous transactional or " +
-                    "idempotent request has failed with errors.", lastError);
-
-        if (isTransactional()) {
-            if (!hasProducerId())
-                throw new IllegalStateException("Cannot perform a 'send' before completing a call to initTransactions " +
-                        "when transactions are enabled.");
-
-            if (currentState != State.IN_TRANSACTION)
-                throw new IllegalStateException("Cannot call send in state " + currentState);
-        }
-    }
-
     synchronized boolean isSendToPartitionAllowed(TopicPartition tp) {
         if (hasFatalError())
             return false;
@@ -500,8 +503,8 @@ public class TransactionManager {
         log.info("Transiting to fatal error state due to {}", exception.toString());
         transitionTo(State.FATAL_ERROR, exception);
 
-        if (pendingResult != null) {
-            pendingResult.fail(exception);
+        if (pendingTransition != null) {
+            pendingTransition.result.fail(exception);
         }
     }
 
@@ -919,8 +922,8 @@ public class TransactionManager {
         KafkaException shutdownException = new KafkaException("The producer closed forcefully");
         pendingRequests.forEach(handler ->
                 handler.fatalError(shutdownException));
-        if (pendingResult != null) {
-            pendingResult.fail(shutdownException);
+        if (pendingTransition != null) {
+            pendingTransition.result.fail(shutdownException);
         }
     }
 
@@ -1073,7 +1076,7 @@ public class TransactionManager {
     private void transitionTo(State target, RuntimeException error) {
         if (!currentState.isTransitionValid(currentState, target)) {
             String idString = transactionalId == null ?  "" : "TransactionalId " + transactionalId + ": ";
-            throw new KafkaException(idString + "Invalid transition attempted from state "
+            throw new IllegalStateException(idString + "Invalid transition attempted from state "
                     + currentState.name() + " to state " + target.name());
         }
 
@@ -1103,10 +1106,12 @@ public class TransactionManager {
             // for ProducerFencedException, do not wrap it as a KafkaException
             // but create a new instance without the call trace since it was not thrown because of the current call
             if (lastError instanceof ProducerFencedException) {
-                throw new ProducerFencedException("The producer has been rejected from the broker because " +
-                    "it tried to use an old epoch with the transactionalId");
+                throw new ProducerFencedException("Producer with transactionalId '" + transactionalId
+                    + "' and " + producerIdAndEpoch + " has been fenced by another producer " +
+                    "with the same transactionalId");
             } else if (lastError instanceof InvalidProducerEpochException) {
-                throw new InvalidProducerEpochException("Producer attempted to produce with an old epoch " + producerIdAndEpoch);
+                throw new InvalidProducerEpochException("Producer with transactionalId '" + transactionalId
+                    + "' and " + producerIdAndEpoch + " attempted to produce with an old epoch");
             } else {
                 throw new KafkaException("Cannot execute transactional method because we are in an error state", lastError);
             }
@@ -1183,20 +1188,40 @@ public class TransactionManager {
         return new TxnOffsetCommitHandler(result, builder);
     }
 
+    private void throwIfPendingState(String operation) {
+        if (pendingTransition != null) {
+            if (pendingTransition.result.isAcked()) {
+                pendingTransition = null;
+            } else {
+                throw new IllegalStateException("Cannot attempt operation `" + operation + "` "
+                    + "because the previous call to `" + pendingTransition.operation + "` "
+                    + "timed out and must be retried");
+            }
+        }
+    }
+
     private TransactionalRequestResult handleCachedTransactionRequestResult(
-            Supplier<TransactionalRequestResult> transactionalRequestResultSupplier,
-            State targetState) {
+        Supplier<TransactionalRequestResult> transactionalRequestResultSupplier,
+        State nextState,
+        String operation
+    ) {
         ensureTransactional();
 
-        if (pendingResult != null && currentState == targetState) {
-            TransactionalRequestResult result = pendingResult;
-            if (result.isCompleted())
-                pendingResult = null;
-            return result;
+        if (pendingTransition != null) {
+            if (pendingTransition.result.isAcked()) {
+                pendingTransition = null;
+            } else if (nextState != pendingTransition.state) {
+                throw new IllegalStateException("Cannot attempt operation `" + operation + "` "
+                    + "because the previous call to `" + pendingTransition.operation + "` "
+                    + "timed out and must be retried");
+            } else {
+                return pendingTransition.result;
+            }
         }
 
-        pendingResult = transactionalRequestResultSupplier.get();
-        return pendingResult;
+        TransactionalRequestResult result = transactionalRequestResultSupplier.get();
+        pendingTransition = new PendingStateTransition(result, nextState, operation);
+        return result;
     }
 
     // package-private for testing
@@ -1762,4 +1787,22 @@ public class TransactionManager {
                    || error == Errors.PRODUCER_FENCED
                    || error == Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT;
     }
+
+    private static final class PendingStateTransition {
+        private final TransactionalRequestResult result;
+        private final State state;
+        private final String operation;
+
+        private PendingStateTransition(
+            TransactionalRequestResult result,
+            State state,
+            String operation
+        ) {
+            this.result = result;
+            this.state = state;
+            this.operation = operation;
+        }
+    }
+
+
 }
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionalRequestResult.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionalRequestResult.java
index d442b18..6739da8 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionalRequestResult.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionalRequestResult.java
@@ -20,15 +20,14 @@ package org.apache.kafka.clients.producer.internals;
 import org.apache.kafka.common.errors.InterruptException;
 import org.apache.kafka.common.errors.TimeoutException;
 
-import java.util.Locale;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
 public final class TransactionalRequestResult {
-
     private final CountDownLatch latch;
     private volatile RuntimeException error = null;
     private final String operation;
+    private volatile boolean isAcked = false;
 
     public TransactionalRequestResult(String operation) {
         this(new CountDownLatch(1), operation);
@@ -49,29 +48,20 @@ public final class TransactionalRequestResult {
     }
 
     public void await() {
-        boolean completed = false;
-
-        while (!completed) {
-            try {
-                latch.await();
-                completed = true;
-            } catch (InterruptedException e) {
-                // Keep waiting until done, we have no other option for these transactional requests.
-            }
-        }
-
-        if (!isSuccessful())
-            throw error();
+        this.await(Long.MAX_VALUE, TimeUnit.MILLISECONDS);
     }
 
     public void await(long timeout, TimeUnit unit) {
         try {
             boolean success = latch.await(timeout, unit);
-            if (!isSuccessful()) {
-                throw error();
-            }
             if (!success) {
-                throw new TimeoutException("Timeout expired after " + timeout + " " + unit.name().toLowerCase(Locale.ROOT) + " while awaiting " + operation);
+                throw new TimeoutException("Timeout expired after " + unit.toMillis(timeout) +
+                    "ms while awaiting " + operation);
+            }
+
+            isAcked = true;
+            if (error != null) {
+                throw error;
             }
         } catch (InterruptedException e) {
             throw new InterruptException("Received interrupt while awaiting " + operation, e);
@@ -83,11 +73,15 @@ public final class TransactionalRequestResult {
     }
 
     public boolean isSuccessful() {
-        return error == null;
+        return isCompleted() && error == null;
     }
 
     public boolean isCompleted() {
         return latch.getCount() == 0L;
     }
 
+    public boolean isAcked() {
+        return isAcked;
+    }
+
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ProducerIdAndEpoch.java b/clients/src/main/java/org/apache/kafka/common/utils/ProducerIdAndEpoch.java
index 674b423..5061da1 100644
--- a/clients/src/main/java/org/apache/kafka/common/utils/ProducerIdAndEpoch.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/ProducerIdAndEpoch.java
@@ -35,7 +35,7 @@ public class ProducerIdAndEpoch {
 
     @Override
     public String toString() {
-        return "(producerId=" + producerId + ", epoch=" + epoch + ")";
+        return "ProducerIdAndEpoch(producerId=" + producerId + ", epoch=" + epoch + ")";
     }
 
     @Override
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
index 1e45a58..96a3034 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java
@@ -904,6 +904,48 @@ public class KafkaProducerTest {
     }
 
     @Test
+    public void testInitTransactionsResponseAfterTimeout() throws Exception {
+        int maxBlockMs = 500;
+
+        Map<String, Object> configs = new HashMap<>();
+        configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "bad-transaction");
+        configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, maxBlockMs);
+        configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000");
+
+        Time time = new MockTime();
+        MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1));
+        ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE);
+        metadata.updateWithCurrentRequestVersion(initialUpdateResponse, false, time.milliseconds());
+
+        MockClient client = new MockClient(time, metadata);
+
+        ExecutorService executor = Executors.newFixedThreadPool(1);
+
+        Producer<String, String> producer = kafkaProducer(configs, new StringSerializer(),
+            new StringSerializer(), metadata, client, null, time);
+        try {
+            client.prepareResponse(
+                request -> request instanceof FindCoordinatorRequest &&
+                    ((FindCoordinatorRequest) request).data().keyType() == FindCoordinatorRequest.CoordinatorType.TRANSACTION.id(),
+                FindCoordinatorResponse.prepareResponse(Errors.NONE, "bad-transaction", host1));
+
+            Future<?> future = executor.submit(producer::initTransactions);
+            TestUtils.waitForCondition(client::hasInFlightRequests,
+                "Timed out while waiting for expected `InitProducerId` request to be sent");
+
+            time.sleep(maxBlockMs);
+            TestUtils.assertFutureThrows(future, TimeoutException.class);
+
+            client.respond(initProducerIdResponse(1L, (short) 5, Errors.NONE));
+
+            Thread.sleep(1000);
+            producer.initTransactions();
+        } finally {
+            producer.close(Duration.ZERO);
+        }
+    }
+
+    @Test
     public void testInitTransactionTimeout() {
         Map<String, Object> configs = new HashMap<>();
         configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "bad-transaction");
@@ -1249,7 +1291,7 @@ public class KafkaProducerTest {
         assertThrows(TimeoutException.class, producer::initTransactions);
         // other transactional operations should not be allowed if we catch the error after initTransactions failed
         try {
-            assertThrows(KafkaException.class, producer::beginTransaction);
+            assertThrows(IllegalStateException.class, producer::beginTransaction);
         } finally {
             producer.close(Duration.ofMillis(0));
         }
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
index 34e4f18..60e9f06 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
@@ -1465,8 +1465,7 @@ public class SenderTest {
         doInitTransactions(txnManager, producerIdAndEpoch);
 
         txnManager.beginTransaction();
-        txnManager.failIfNotReadyForSend();
-        txnManager.maybeAddPartitionToTransaction(tp0);
+        txnManager.maybeAddPartition(tp0);
         client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp0, Errors.NONE)));
         sender.runOnce();
 
@@ -1751,7 +1750,8 @@ public class SenderTest {
         doInitTransactions(transactionManager, new ProducerIdAndEpoch(producerId, (short) 0));
         assertTrue(transactionManager.hasProducerId());
 
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.beginTransaction();
+        transactionManager.maybeAddPartition(tp0);
         client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp0, Errors.NONE)));
         sender.runOnce(); // Receive AddPartitions response
 
@@ -2307,8 +2307,7 @@ public class SenderTest {
         doInitTransactions(txnManager, producerIdAndEpoch);
 
         txnManager.beginTransaction();
-        txnManager.failIfNotReadyForSend();
-        txnManager.maybeAddPartitionToTransaction(tp);
+        txnManager.maybeAddPartition(tp);
         client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE)));
         sender.runOnce();
 
@@ -2654,8 +2653,7 @@ public class SenderTest {
             doInitTransactions(txnManager, producerIdAndEpoch);
 
             txnManager.beginTransaction();
-            txnManager.failIfNotReadyForSend();
-            txnManager.maybeAddPartitionToTransaction(tp);
+            txnManager.maybeAddPartition(tp);
             client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE)));
             sender.runOnce();
             sender.initiateClose();
@@ -2697,7 +2695,7 @@ public class SenderTest {
 
             // Now begin the commit and assert that the Produce request is sent immediately
             // without waiting for the linger.
-            txnManager.beginCommit();
+            TransactionalRequestResult commitResult = txnManager.beginCommit();
             runUntil(sender, client::hasInFlightRequests);
 
             // Respond to the produce request and wait for the EndTxn request to be sent.
@@ -2708,6 +2706,9 @@ public class SenderTest {
             respondToEndTxn(Errors.NONE);
             runUntil(sender, txnManager::isReady);
 
+            assertTrue(commitResult.isSuccessful());
+            commitResult.await();
+
             // Finally, we want to assert that the linger time is still effective
             // when the new transaction begins.
             txnManager.beginTransaction();
@@ -2772,7 +2773,7 @@ public class SenderTest {
     }
 
     private void addPartitionToTxn(Sender sender, TransactionManager txnManager, TopicPartition tp) {
-        txnManager.maybeAddPartitionToTransaction(tp);
+        txnManager.maybeAddPartition(tp);
         client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE)));
         runUntil(sender, () -> txnManager.isPartitionAdded(tp));
         assertFalse(txnManager.hasInFlightRequest());
@@ -2813,8 +2814,7 @@ public class SenderTest {
             doInitTransactions(txnManager, producerIdAndEpoch);
 
             txnManager.beginTransaction();
-            txnManager.failIfNotReadyForSend();
-            txnManager.maybeAddPartitionToTransaction(tp);
+            txnManager.maybeAddPartition(tp);
             client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE)));
             sender.runOnce();
             sender.initiateClose();
@@ -2848,8 +2848,7 @@ public class SenderTest {
             doInitTransactions(txnManager, producerIdAndEpoch);
 
             txnManager.beginTransaction();
-            txnManager.failIfNotReadyForSend();
-            txnManager.maybeAddPartitionToTransaction(tp);
+            txnManager.maybeAddPartition(tp);
             client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE)));
             sender.runOnce();
 
@@ -2874,7 +2873,7 @@ public class SenderTest {
         doInitTransactions(txnManager, producerIdAndEpoch);
         // Begin the transaction
         txnManager.beginTransaction();
-        txnManager.maybeAddPartitionToTransaction(tp0);
+        txnManager.maybeAddPartition(tp0);
         client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp0, Errors.NONE)));
         // Run it once so that the partition is added to the transaction.
         sender.runOnce();
@@ -2912,7 +2911,7 @@ public class SenderTest {
         doInitTransactions(txnManager, producerIdAndEpoch);
 
         txnManager.beginTransaction();
-        txnManager.maybeAddPartitionToTransaction(tp0);
+        txnManager.maybeAddPartition(tp0);
         client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp0, Errors.NONE)));
         sender.runOnce();
 
@@ -3191,7 +3190,7 @@ public class SenderTest {
     }
 
     private void doInitTransactions(TransactionManager transactionManager, ProducerIdAndEpoch producerIdAndEpoch) {
-        transactionManager.initializeTransactions();
+        TransactionalRequestResult result = transactionManager.initializeTransactions();
         prepareFindCoordinatorResponse(Errors.NONE, transactionManager.transactionalId());
         sender.runOnce();
         sender.runOnce();
@@ -3199,6 +3198,7 @@ public class SenderTest {
         prepareInitProducerResponse(Errors.NONE, producerIdAndEpoch.producerId, producerIdAndEpoch.epoch);
         sender.runOnce();
         assertTrue(transactionManager.hasProducerId());
+        result.await();
     }
 
     private void prepareFindCoordinatorResponse(Errors error, String txnid) {
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
index 6c1e2fd..4227db5 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
@@ -103,6 +103,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertSame;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
@@ -184,8 +185,7 @@ public class TransactionManagerTest {
         doInitTransactions();
         transactionManager.beginTransaction();
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         FutureRecordMetadata sendFuture = appendToAccumulator(tp0);
 
         prepareAddPartitionsToTxn(tp0, Errors.NONE);
@@ -206,8 +206,7 @@ public class TransactionManagerTest {
         doInitTransactions();
         transactionManager.beginTransaction();
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         prepareAddPartitionsToTxn(tp0, Errors.NONE);
         runUntil(() -> transactionManager.isPartitionAdded(tp0));
 
@@ -218,26 +217,26 @@ public class TransactionManagerTest {
 
     @Test
     public void testFailIfNotReadyForSendNoProducerId() {
-        assertThrows(IllegalStateException.class, () -> transactionManager.failIfNotReadyForSend());
+        assertThrows(IllegalStateException.class, () -> transactionManager.maybeAddPartition(tp0));
     }
 
     @Test
     public void testFailIfNotReadyForSendIdempotentProducer() {
         initializeTransactionManager(Optional.empty());
-        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartition(tp0);
     }
 
     @Test
     public void testFailIfNotReadyForSendIdempotentProducerFatalError() {
         initializeTransactionManager(Optional.empty());
         transactionManager.transitionToFatalError(new KafkaException());
-        assertThrows(KafkaException.class, () -> transactionManager.failIfNotReadyForSend());
+        assertThrows(KafkaException.class, () -> transactionManager.maybeAddPartition(tp0));
     }
 
     @Test
     public void testFailIfNotReadyForSendNoOngoingTransaction() {
         doInitTransactions();
-        assertThrows(IllegalStateException.class, () -> transactionManager.failIfNotReadyForSend());
+        assertThrows(IllegalStateException.class, () -> transactionManager.maybeAddPartition(tp0));
     }
 
     @Test
@@ -245,14 +244,14 @@ public class TransactionManagerTest {
         doInitTransactions();
         transactionManager.beginTransaction();
         transactionManager.transitionToAbortableError(new KafkaException());
-        assertThrows(KafkaException.class, transactionManager::failIfNotReadyForSend);
+        assertThrows(KafkaException.class, () -> transactionManager.maybeAddPartition(tp0));
     }
 
     @Test
     public void testFailIfNotReadyForSendAfterFatalError() {
         doInitTransactions();
         transactionManager.transitionToFatalError(new KafkaException());
-        assertThrows(KafkaException.class, transactionManager::failIfNotReadyForSend);
+        assertThrows(KafkaException.class, () -> transactionManager.maybeAddPartition(tp0));
     }
 
     @Test
@@ -266,8 +265,7 @@ public class TransactionManagerTest {
         transactionManager.beginTransaction();
         assertTrue(transactionManager.hasOngoingTransaction());
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(partition);
+        transactionManager.maybeAddPartition(partition);
         runUntil(transactionManager::hasOngoingTransaction);
 
         prepareAddPartitionsToTxn(partition, Errors.NONE);
@@ -291,8 +289,7 @@ public class TransactionManagerTest {
         transactionManager.beginTransaction();
         assertTrue(transactionManager.hasOngoingTransaction());
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(partition);
+        transactionManager.maybeAddPartition(partition);
         assertTrue(transactionManager.hasOngoingTransaction());
 
         prepareAddPartitionsToTxn(partition, Errors.NONE);
@@ -316,8 +313,7 @@ public class TransactionManagerTest {
         transactionManager.beginTransaction();
         assertTrue(transactionManager.hasOngoingTransaction());
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(partition);
+        transactionManager.maybeAddPartition(partition);
         assertTrue(transactionManager.hasOngoingTransaction());
 
         prepareAddPartitionsToTxn(partition, Errors.NONE);
@@ -344,8 +340,7 @@ public class TransactionManagerTest {
         transactionManager.beginTransaction();
         assertTrue(transactionManager.hasOngoingTransaction());
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(partition);
+        transactionManager.maybeAddPartition(partition);
         assertTrue(transactionManager.hasOngoingTransaction());
 
         prepareAddPartitionsToTxn(partition, Errors.NONE);
@@ -361,8 +356,7 @@ public class TransactionManagerTest {
         doInitTransactions();
         transactionManager.beginTransaction();
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(partition);
+        transactionManager.maybeAddPartition(partition);
         assertTrue(transactionManager.hasPartitionsToAdd());
         assertFalse(transactionManager.isPartitionAdded(partition));
         assertTrue(transactionManager.isPartitionPendingAdd(partition));
@@ -375,8 +369,7 @@ public class TransactionManagerTest {
         assertFalse(transactionManager.isPartitionPendingAdd(partition));
 
         // adding the partition again should not have any effect
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(partition);
+        transactionManager.maybeAddPartition(partition);
         assertFalse(transactionManager.hasPartitionsToAdd());
         assertTrue(transactionManager.isPartitionAdded(partition));
         assertFalse(transactionManager.isPartitionPendingAdd(partition));
@@ -388,8 +381,7 @@ public class TransactionManagerTest {
         doInitTransactions();
         transactionManager.beginTransaction();
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(partition);
+        transactionManager.maybeAddPartition(partition);
         assertTrue(transactionManager.hasPartitionsToAdd());
         assertFalse(transactionManager.isPartitionAdded(partition));
         assertTrue(transactionManager.isPartitionPendingAdd(partition));
@@ -408,8 +400,7 @@ public class TransactionManagerTest {
         doInitTransactions();
         transactionManager.beginTransaction();
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(partition);
+        transactionManager.maybeAddPartition(partition);
         assertTrue(transactionManager.hasPartitionsToAdd());
         assertFalse(transactionManager.isPartitionAdded(partition));
         assertTrue(transactionManager.isPartitionPendingAdd(partition));
@@ -428,8 +419,7 @@ public class TransactionManagerTest {
         doInitTransactions();
         transactionManager.beginTransaction();
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(partition);
+        transactionManager.maybeAddPartition(partition);
         assertTrue(transactionManager.hasPartitionsToAdd());
         assertFalse(transactionManager.isPartitionAdded(partition));
         assertTrue(transactionManager.isPartitionPendingAdd(partition));
@@ -438,8 +428,7 @@ public class TransactionManagerTest {
         runUntil(() -> transactionManager.isPartitionAdded(partition));
 
         TopicPartition otherPartition = new TopicPartition("foo", 1);
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(otherPartition);
+        transactionManager.maybeAddPartition(otherPartition);
 
         prepareAddPartitionsToTxn(otherPartition, Errors.CONCURRENT_TRANSACTIONS);
         TransactionManager.TxnRequestHandler handler = transactionManager.nextRequest(false);
@@ -449,13 +438,13 @@ public class TransactionManagerTest {
 
     @Test
     public void testNotReadyForSendBeforeInitTransactions() {
-        assertThrows(IllegalStateException.class, () -> transactionManager.failIfNotReadyForSend());
+        assertThrows(IllegalStateException.class, () -> transactionManager.maybeAddPartition(tp0));
     }
 
     @Test
     public void testNotReadyForSendBeforeBeginTransaction() {
         doInitTransactions();
-        assertThrows(IllegalStateException.class, () -> transactionManager.failIfNotReadyForSend());
+        assertThrows(IllegalStateException.class, () -> transactionManager.maybeAddPartition(tp0));
     }
 
     @Test
@@ -463,14 +452,14 @@ public class TransactionManagerTest {
         doInitTransactions();
         transactionManager.beginTransaction();
         transactionManager.transitionToAbortableError(new KafkaException());
-        assertThrows(KafkaException.class, () -> transactionManager.failIfNotReadyForSend());
+        assertThrows(KafkaException.class, () -> transactionManager.maybeAddPartition(tp0));
     }
 
     @Test
     public void testNotReadyForSendAfterFatalError() {
         doInitTransactions();
         transactionManager.transitionToFatalError(new KafkaException());
-        assertThrows(KafkaException.class, () -> transactionManager.failIfNotReadyForSend());
+        assertThrows(KafkaException.class, () -> transactionManager.maybeAddPartition(tp0));
     }
 
     @Test
@@ -478,8 +467,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         transactionManager.transitionToAbortableError(new KafkaException());
 
         assertFalse(transactionManager.isSendToPartitionAllowed(tp0));
@@ -491,8 +479,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         // Send the AddPartitionsToTxn request and leave it in-flight
         runUntil(transactionManager::hasInFlightRequest);
@@ -507,8 +494,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         transactionManager.transitionToFatalError(new KafkaException());
 
         assertFalse(transactionManager.isSendToPartitionAllowed(tp0));
@@ -520,8 +506,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         // Send the AddPartitionsToTxn request and leave it in-flight
         runUntil(transactionManager::hasInFlightRequest);
@@ -537,8 +522,7 @@ public class TransactionManagerTest {
 
         transactionManager.beginTransaction();
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId);
 
         runUntil(() -> !transactionManager.hasPartitionsToAdd());
@@ -553,8 +537,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId);
 
         runUntil(() -> !transactionManager.hasPartitionsToAdd());
@@ -747,8 +730,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -813,7 +795,7 @@ public class TransactionManagerTest {
     public void testInitializeTransactionsTwiceRaisesError() {
         doInitTransactions(producerId, epoch);
         assertTrue(transactionManager.hasProducerId());
-        assertThrows(KafkaException.class, () -> transactionManager.initializeTransactions());
+        assertThrows(IllegalStateException.class, () -> transactionManager.initializeTransactions());
     }
 
     @Test
@@ -1075,7 +1057,7 @@ public class TransactionManagerTest {
         assertTrue(transactionManager.hasFatalError());
         assertTrue(transactionManager.lastError() instanceof TransactionalIdAuthorizationException);
         assertFalse(initPidResult.isSuccessful());
-        assertTrue(initPidResult.error() instanceof TransactionalIdAuthorizationException);
+        assertThrows(TransactionalIdAuthorizationException.class, initPidResult::await);
         assertFatalError(TransactionalIdAuthorizationException.class);
     }
 
@@ -1090,8 +1072,7 @@ public class TransactionManagerTest {
         runUntil(transactionManager::hasError);
         assertTrue(initPidResult.isCompleted());
         assertFalse(initPidResult.isSuccessful());
-        assertTrue(initPidResult.error() instanceof TransactionalIdAuthorizationException);
-
+        assertThrows(TransactionalIdAuthorizationException.class, initPidResult::await);
         assertFatalError(TransactionalIdAuthorizationException.class);
     }
 
@@ -1202,10 +1183,8 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp1);
+        transactionManager.maybeAddPartition(tp0);
+        transactionManager.maybeAddPartition(tp1);
 
         FutureRecordMetadata firstPartitionAppend = appendToAccumulator(tp0);
         FutureRecordMetadata secondPartitionAppend = appendToAccumulator(tp1);
@@ -1242,8 +1221,8 @@ public class TransactionManagerTest {
 
         // Begin a transaction, send two records, and begin commit
         transactionManager.beginTransaction();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
-        transactionManager.maybeAddPartitionToTransaction(tp1);
+        transactionManager.maybeAddPartition(tp0);
+        transactionManager.maybeAddPartition(tp1);
         FutureRecordMetadata firstPartitionAppend = appendToAccumulator(tp0);
         FutureRecordMetadata secondPartitionAppend = appendToAccumulator(tp1);
         TransactionalRequestResult commitResult = transactionManager.beginCommit();
@@ -1287,8 +1266,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(unauthorizedPartition);
+        transactionManager.maybeAddPartition(unauthorizedPartition);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(unauthorizedPartition);
 
@@ -1296,7 +1274,7 @@ public class TransactionManagerTest {
         runUntil(() -> !client.hasPendingResponses());
 
         assertTrue(transactionManager.hasAbortableError());
-        transactionManager.beginAbort();
+        TransactionalRequestResult abortResult = transactionManager.beginAbort();
         runUntil(responseFuture::isDone);
         assertProduceFutureFailed(responseFuture);
 
@@ -1304,12 +1282,13 @@ public class TransactionManagerTest {
         runUntil(transactionManager::isReady);
         assertFalse(transactionManager.hasPartitionsToAdd());
         assertFalse(accumulator.hasIncomplete());
+        assertTrue(abortResult.isSuccessful());
+        abortResult.await();
 
         // ensure we can now start a new transaction
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         responseFuture = appendToAccumulator(tp0);
 
@@ -1327,21 +1306,117 @@ public class TransactionManagerTest {
     }
 
     @Test
+    public void testRetryAbortTransactionAfterTimeout() throws Exception {
+        doInitTransactions();
+
+        transactionManager.beginTransaction();
+        transactionManager.maybeAddPartition(tp0);
+
+        prepareAddPartitionsToTxn(tp0, Errors.NONE);
+        appendToAccumulator(tp0);
+        runUntil(() -> transactionManager.isPartitionAdded(tp0));
+
+        TransactionalRequestResult result = transactionManager.beginAbort();
+        assertThrows(TimeoutException.class, () -> result.await(0, TimeUnit.MILLISECONDS));
+
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch);
+        runUntil(transactionManager::isReady);
+        assertTrue(result.isSuccessful());
+        assertFalse(result.isAcked());
+        assertFalse(transactionManager.hasOngoingTransaction());
+
+        assertThrows(IllegalStateException.class, transactionManager::initializeTransactions);
+        assertThrows(IllegalStateException.class, transactionManager::beginTransaction);
+        assertThrows(IllegalStateException.class, transactionManager::beginCommit);
+        assertThrows(IllegalStateException.class, () -> transactionManager.maybeAddPartition(tp0));
+
+        assertSame(result, transactionManager.beginAbort());
+        result.await();
+
+        transactionManager.beginTransaction();
+        assertTrue(transactionManager.hasOngoingTransaction());
+    }
+
+    @Test
+    public void testRetryCommitTransactionAfterTimeout() throws Exception {
+        doInitTransactions();
+
+        transactionManager.beginTransaction();
+        transactionManager.maybeAddPartition(tp0);
+
+        prepareAddPartitionsToTxn(tp0, Errors.NONE);
+        prepareProduceResponse(Errors.NONE, producerId, epoch);
+
+        appendToAccumulator(tp0);
+        runUntil(() -> transactionManager.isPartitionAdded(tp0));
+
+        TransactionalRequestResult result = transactionManager.beginCommit();
+        assertThrows(TimeoutException.class, () -> result.await(0, TimeUnit.MILLISECONDS));
+
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.COMMIT, producerId, epoch);
+        runUntil(transactionManager::isReady);
+        assertTrue(result.isSuccessful());
+        assertFalse(result.isAcked());
+        assertFalse(transactionManager.hasOngoingTransaction());
+
+        assertThrows(IllegalStateException.class, transactionManager::initializeTransactions);
+        assertThrows(IllegalStateException.class, transactionManager::beginTransaction);
+        assertThrows(IllegalStateException.class, transactionManager::beginAbort);
+        assertThrows(IllegalStateException.class, () -> transactionManager.maybeAddPartition(tp0));
+
+        assertSame(result, transactionManager.beginCommit());
+        result.await();
+
+        transactionManager.beginTransaction();
+        assertTrue(transactionManager.hasOngoingTransaction());
+    }
+
+    @Test
+    public void testRetryInitTransactionsAfterTimeout() {
+        TransactionalRequestResult result = transactionManager.initializeTransactions();
+        prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId);
+        runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null);
+        assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION));
+
+        assertThrows(TimeoutException.class, () -> result.await(0, TimeUnit.MILLISECONDS));
+
+        prepareInitPidResponse(Errors.NONE, false, producerId, epoch);
+        runUntil(transactionManager::hasProducerId);
+        assertTrue(result.isSuccessful());
+        assertFalse(result.isAcked());
+
+        // At this point, the InitProducerId call has returned, but the user has yet
+        // to complete the call to `initTransactions`. Other transitions should be
+        // rejected until they do.
+
+        assertThrows(IllegalStateException.class, transactionManager::beginTransaction);
+        assertThrows(IllegalStateException.class, transactionManager::beginAbort);
+        assertThrows(IllegalStateException.class, transactionManager::beginCommit);
+        assertThrows(IllegalStateException.class, () -> transactionManager.maybeAddPartition(tp0));
+
+        assertSame(result, transactionManager.initializeTransactions());
+        result.await();
+        assertTrue(result.isAcked());
+        assertThrows(IllegalStateException.class, transactionManager::initializeTransactions);
+
+        transactionManager.beginTransaction();
+        assertTrue(transactionManager.hasOngoingTransaction());
+    }
+
+    @Test
     public void testRecoveryFromAbortableErrorTransactionStarted() throws Exception {
         final TopicPartition unauthorizedPartition = new TopicPartition("foo", 0);
 
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         prepareAddPartitionsToTxn(tp0, Errors.NONE);
 
         Future<RecordMetadata> authorizedTopicProduceFuture = appendToAccumulator(unauthorizedPartition);
         runUntil(() -> transactionManager.isPartitionAdded(tp0));
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(unauthorizedPartition);
+        transactionManager.maybeAddPartition(unauthorizedPartition);
         Future<RecordMetadata> unauthorizedTopicProduceFuture = appendToAccumulator(unauthorizedPartition);
         prepareAddPartitionsToTxn(singletonMap(unauthorizedPartition, Errors.TOPIC_AUTHORIZATION_FAILED));
         runUntil(transactionManager::hasAbortableError);
@@ -1351,19 +1426,20 @@ public class TransactionManagerTest {
         assertFalse(unauthorizedTopicProduceFuture.isDone());
 
         prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch);
-        transactionManager.beginAbort();
+        TransactionalRequestResult result = transactionManager.beginAbort();
         runUntil(transactionManager::isReady);
         // neither produce request has been sent, so they should both be failed immediately
         assertProduceFutureFailed(authorizedTopicProduceFuture);
         assertProduceFutureFailed(unauthorizedTopicProduceFuture);
         assertFalse(transactionManager.hasPartitionsToAdd());
         assertFalse(accumulator.hasIncomplete());
+        assertTrue(result.isSuccessful());
+        result.await();
 
         // ensure we can now start a new transaction
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         FutureRecordMetadata nextTransactionFuture = appendToAccumulator(tp0);
 
@@ -1387,8 +1463,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         prepareAddPartitionsToTxn(tp0, Errors.NONE);
 
         Future<RecordMetadata> authorizedTopicProduceFuture = appendToAccumulator(tp0);
@@ -1400,8 +1475,7 @@ public class TransactionManagerTest {
         assertFalse(authorizedTopicProduceFuture.isDone());
         assertTrue(accumulator.hasIncomplete());
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(unauthorizedPartition);
+        transactionManager.maybeAddPartition(unauthorizedPartition);
         Future<RecordMetadata> unauthorizedTopicProduceFuture = appendToAccumulator(unauthorizedPartition);
         prepareAddPartitionsToTxn(singletonMap(unauthorizedPartition, Errors.TOPIC_AUTHORIZATION_FAILED));
         runUntil(transactionManager::hasAbortableError);
@@ -1417,18 +1491,19 @@ public class TransactionManagerTest {
         assertTrue(authorizedTopicProduceFuture.isDone());
 
         prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch);
-        transactionManager.beginAbort();
+        TransactionalRequestResult abortResult = transactionManager.beginAbort();
         runUntil(transactionManager::isReady);
         // neither produce request has been sent, so they should both be failed immediately
         assertTrue(transactionManager.isReady());
         assertFalse(transactionManager.hasPartitionsToAdd());
         assertFalse(accumulator.hasIncomplete());
+        assertTrue(abortResult.isSuccessful());
+        abortResult.await();
 
         // ensure we can now start a new transaction
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         FutureRecordMetadata nextTransactionFuture = appendToAccumulator(tp0);
 
@@ -1452,8 +1527,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp);
+        transactionManager.maybeAddPartition(tp);
 
         prepareAddPartitionsToTxn(tp, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED);
         runUntil(transactionManager::hasError);
@@ -1467,8 +1541,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -1506,8 +1579,7 @@ public class TransactionManagerTest {
 
         transactionManager.beginTransaction();
         // User does one producer.send
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -1520,8 +1592,7 @@ public class TransactionManagerTest {
         runUntil(() -> transactionManager.transactionContainsPartition(tp0));
 
         // In the mean time, the user does a second produce to a different partition
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp1);
+        transactionManager.maybeAddPartition(tp1);
         Future<RecordMetadata> secondResponseFuture = appendToAccumulator(tp0);
 
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp1, epoch, producerId);
@@ -1563,7 +1634,7 @@ public class TransactionManagerTest {
 
         runUntil(transactionManager::hasError);
 
-        assertEquals(ProducerFencedException.class, result.error().getClass());
+        assertThrows(ProducerFencedException.class, result::await);
 
         assertThrows(ProducerFencedException.class, () -> transactionManager.beginTransaction());
         assertThrows(ProducerFencedException.class, () -> transactionManager.beginCommit());
@@ -1586,8 +1657,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -1611,7 +1681,6 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
         transactionManager.sendOffsetsToTransaction(Collections.emptyMap(), new ConsumerGroupMetadata(consumerGroupId));
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
@@ -1647,8 +1716,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         TransactionalRequestResult commitResult = transactionManager.beginCommit();
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
@@ -1661,6 +1729,10 @@ public class TransactionManagerTest {
         runUntil(commitResult::isCompleted);
         runUntil(responseFuture::isDone);
 
+        assertThrows(KafkaException.class, commitResult::await);
+        assertFalse(commitResult.isSuccessful());
+        assertTrue(commitResult.isAcked());
+
         // make sure the exception was thrown directly from the follow-up calls.
         assertThrows(KafkaException.class, () -> transactionManager.beginTransaction());
         assertThrows(KafkaException.class, () -> transactionManager.beginCommit());
@@ -1674,8 +1746,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -1709,8 +1780,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -1721,19 +1791,8 @@ public class TransactionManagerTest {
 
         runUntil(commitResult::isCompleted);  // commit should be cancelled with exception without being sent.
 
-        try {
-            commitResult.await();
-            fail();  // the get() must throw an exception.
-        } catch (KafkaException e) {
-            // Expected
-        }
-
-        try {
-            responseFuture.get();
-            fail("Expected produce future to raise an exception");
-        } catch (ExecutionException e) {
-            assertTrue(e.getCause() instanceof OutOfOrderSequenceException);
-        }
+        assertThrows(KafkaException.class, commitResult::await);
+        TestUtils.assertFutureThrows(responseFuture, OutOfOrderSequenceException.class);
 
         // Commit is not allowed, so let's abort and try again.
         TransactionalRequestResult abortResult = transactionManager.beginAbort();
@@ -1749,8 +1808,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -1773,8 +1831,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -1804,8 +1861,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -1848,8 +1904,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -1891,8 +1946,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -1914,8 +1968,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -1928,12 +1981,7 @@ public class TransactionManagerTest {
         assertTrue(abortResult.isSuccessful());
         assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
 
-        try {
-            responseFuture.get();
-            fail("Expected produce future to raise an exception");
-        } catch (ExecutionException e) {
-            assertTrue(e.getCause() instanceof KafkaException);
-        }
+        TestUtils.assertFutureThrows(responseFuture, KafkaException.class);
     }
 
     @Test
@@ -1941,8 +1989,7 @@ public class TransactionManagerTest {
         doInitTransactions(producerId, epoch);
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         prepareAddPartitionsToTxnResponse(Errors.UNKNOWN_TOPIC_OR_PARTITION, tp0, epoch, producerId);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
@@ -1960,12 +2007,7 @@ public class TransactionManagerTest {
         assertTrue(abortResult.isSuccessful());
         assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
 
-        try {
-            responseFuture.get();
-            fail("Expected produce future to raise an exception");
-        } catch (ExecutionException e) {
-            assertTrue(e.getCause() instanceof KafkaException);
-        }
+        TestUtils.assertFutureThrows(responseFuture, KafkaException.class);
     }
 
     @Test
@@ -1973,8 +2015,7 @@ public class TransactionManagerTest {
         doInitTransactions(producerId, epoch);
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId);
         prepareProduceResponse(Errors.REQUEST_TIMED_OUT, producerId, epoch);
 
@@ -2002,8 +2043,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -2116,8 +2156,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
         assertFalse(responseFuture.isDone());
@@ -2131,8 +2170,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         prepareAddPartitionsToTxnResponse(Errors.TOPIC_AUTHORIZATION_FAILED, tp0, epoch, producerId);
         runUntil(() -> !client.hasPendingResponses());
@@ -2241,7 +2279,6 @@ public class TransactionManagerTest {
         assertFalse(addOffsetsResult.isCompleted());  // The request should complete only after the TxnOffsetCommit completes
 
         prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.GROUP, consumerGroupId);
-//        prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, groupInstanceId, memberId, generationId, txnOffsetCommitResponse);
         prepareTxnCommitResponse.run();
 
         assertNull(transactionManager.coordinator(CoordinatorType.GROUP));
@@ -2256,11 +2293,9 @@ public class TransactionManagerTest {
     public void testNoDrainWhenPartitionsPending() throws InterruptedException {
         doInitTransactions();
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         appendToAccumulator(tp0);
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp1);
+        transactionManager.maybeAddPartition(tp1);
         appendToAccumulator(tp1);
 
         assertFalse(transactionManager.isSendToPartitionAllowed(tp0));
@@ -2291,13 +2326,11 @@ public class TransactionManagerTest {
     public void testAllowDrainInAbortableErrorState() throws InterruptedException {
         doInitTransactions();
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp1);
+        transactionManager.maybeAddPartition(tp1);
         prepareAddPartitionsToTxn(tp1, Errors.NONE);
         runUntil(() -> transactionManager.transactionContainsPartition(tp1));
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         prepareAddPartitionsToTxn(tp0, Errors.TOPIC_AUTHORIZATION_FAILED);
         runUntil(transactionManager::hasAbortableError);
         assertTrue(transactionManager.isSendToPartitionAllowed(tp1));
@@ -2345,8 +2378,7 @@ public class TransactionManagerTest {
         doInitTransactions();
         transactionManager.beginTransaction();
 
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -2367,8 +2399,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -2408,10 +2439,8 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp1);
+        transactionManager.maybeAddPartition(tp0);
+        transactionManager.maybeAddPartition(tp1);
 
         Future<RecordMetadata> firstBatchResponse = appendToAccumulator(tp0);
         Future<RecordMetadata> secondBatchResponse = appendToAccumulator(tp1);
@@ -2468,8 +2497,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -2504,6 +2532,7 @@ public class TransactionManagerTest {
         }
         runUntil(commitResult::isCompleted);  // the commit shouldn't be completed without being sent since the produce request failed.
         assertFalse(commitResult.isSuccessful());  // the commit shouldn't succeed since the produce request failed.
+        assertThrows(TimeoutException.class, commitResult::await);
 
         assertTrue(transactionManager.hasAbortableError());
         assertTrue(transactionManager.hasOngoingTransaction());
@@ -2535,8 +2564,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -2735,8 +2763,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture0 = appendToAccumulator(tp0);
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId);
@@ -2758,11 +2785,11 @@ public class TransactionManagerTest {
         prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch);
         runUntil(abortResult::isCompleted);
         assertTrue(abortResult.isSuccessful());
+        abortResult.await();
         assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId);
         runUntil(() -> transactionManager.isPartitionAdded(tp0));  // Send AddPartitionsRequest
@@ -2788,16 +2815,15 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
 
-        transactionManager.maybeAddPartitionToTransaction(tp1);
+        transactionManager.maybeAddPartition(tp1);
         Future<RecordMetadata> successPartitionResponseFuture = appendToAccumulator(tp1);
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp1, epoch, producerId);
         prepareProduceResponse(Errors.NONE, producerId, epoch, tp1);
         runUntil(successPartitionResponseFuture::isDone);
         assertTrue(transactionManager.isPartitionAdded(tp1));
 
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         Future<RecordMetadata> responseFuture0 = appendToAccumulator(tp0);
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId);
         prepareProduceResponse(Errors.NONE, producerId, epoch);
@@ -2819,11 +2845,11 @@ public class TransactionManagerTest {
         prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch);
         runUntil(abortResult::isCompleted);
         assertTrue(abortResult.isSuccessful());
+        abortResult.await();
         assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId);
         runUntil(() -> transactionManager.isPartitionAdded(tp0));
@@ -2840,8 +2866,7 @@ public class TransactionManagerTest {
         doInitTransactions(producerId, initialEpoch);
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, initialEpoch, producerId);
         runUntil(() -> transactionManager.isPartitionAdded(tp0));
@@ -2867,11 +2892,11 @@ public class TransactionManagerTest {
 
         assertTrue(abortResult.isCompleted());
         assertTrue(abortResult.isSuccessful());
+        abortResult.await();
         assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, bumpedEpoch, producerId);
         runUntil(() -> transactionManager.isPartitionAdded(tp0));
@@ -2887,8 +2912,7 @@ public class TransactionManagerTest {
         doInitTransactions(producerId, initialEpoch);
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, initialEpoch, producerId);
         runUntil(() -> transactionManager.isPartitionAdded(tp0));
@@ -2915,11 +2939,11 @@ public class TransactionManagerTest {
 
         assertTrue(abortResult.isCompleted());
         assertTrue(abortResult.isSuccessful());
+        abortResult.await();
         assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, bumpedEpoch, producerId);
         runUntil(() -> transactionManager.isPartitionAdded(tp0));
@@ -2935,8 +2959,7 @@ public class TransactionManagerTest {
         doInitTransactions(producerId, initialEpoch);
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, initialEpoch, producerId);
         runUntil(() -> transactionManager.isPartitionAdded(tp0));
@@ -2975,11 +2998,11 @@ public class TransactionManagerTest {
 
         assertTrue(abortResult.isCompleted());
         assertTrue(abortResult.isSuccessful());
+        abortResult.await();
         assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, bumpedEpoch, producerId);
         runUntil(() -> transactionManager.isPartitionAdded(tp0));
@@ -2995,8 +3018,7 @@ public class TransactionManagerTest {
         doInitTransactions(producerId, initialEpoch);
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
         prepareAddPartitionsToTxnResponse(Errors.INVALID_PRODUCER_ID_MAPPING, tp0, initialEpoch, producerId);
         runUntil(transactionManager::hasAbortableError);
         TransactionalRequestResult abortResult = transactionManager.beginAbort();
@@ -3016,8 +3038,7 @@ public class TransactionManagerTest {
         doInitTransactions(producerId, initialEpoch);
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
 
@@ -3148,12 +3169,12 @@ public class TransactionManagerTest {
 
     @Test
     public void testRetryAbortTransactionAfterCommitTimeout() {
-        assertThrows(KafkaException.class, () -> verifyCommitOrAbortTransactionRetriable(TransactionResult.COMMIT, TransactionResult.ABORT));
+        assertThrows(IllegalStateException.class, () -> verifyCommitOrAbortTransactionRetriable(TransactionResult.COMMIT, TransactionResult.ABORT));
     }
 
     @Test
     public void testRetryCommitTransactionAfterAbortTimeout() {
-        assertThrows(KafkaException.class, () -> verifyCommitOrAbortTransactionRetriable(TransactionResult.ABORT, TransactionResult.COMMIT));
+        assertThrows(IllegalStateException.class, () -> verifyCommitOrAbortTransactionRetriable(TransactionResult.ABORT, TransactionResult.COMMIT));
     }
 
     @Test
@@ -3270,8 +3291,7 @@ public class TransactionManagerTest {
         doInitTransactions();
 
         transactionManager.beginTransaction();
-        transactionManager.failIfNotReadyForSend();
-        transactionManager.maybeAddPartitionToTransaction(tp0);
+        transactionManager.maybeAddPartition(tp0);
 
         appendToAccumulator(tp0);
 
@@ -3284,12 +3304,7 @@ public class TransactionManagerTest {
         prepareEndTxnResponse(Errors.NONE, firstTransactionResult, producerId, epoch, true);
         runUntil(() -> !client.hasPendingResponses());
         assertFalse(result.isCompleted());
-
-        try {
-            result.await(MAX_BLOCK_TIMEOUT, TimeUnit.MILLISECONDS);
-            fail("Should have raised TimeoutException");
-        } catch (TimeoutException ignored) {
-        }
+        assertThrows(TimeoutException.class, () -> result.await(MAX_BLOCK_TIMEOUT, TimeUnit.MILLISECONDS));
 
         prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId);
         runUntil(() -> !client.hasPendingResponses());
@@ -3522,13 +3537,17 @@ public class TransactionManagerTest {
     }
 
     private void doInitTransactions(long producerId, short epoch) {
-        transactionManager.initializeTransactions();
+        TransactionalRequestResult result = transactionManager.initializeTransactions();
         prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId);
         runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null);
         assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION));
 
         prepareInitPidResponse(Errors.NONE, false, producerId, epoch);
         runUntil(transactionManager::hasProducerId);
+
+        result.await();
+        assertTrue(result.isSuccessful());
+        assertTrue(result.isAcked());
     }
 
     private void assertAbortableError(Class<? extends RuntimeException> cause) {
diff --git a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
index 6efb860..1da4fcd 100644
--- a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
@@ -58,7 +58,7 @@ import org.apache.kafka.common.resource.{PatternType, Resource, ResourcePattern,
 import org.apache.kafka.common.security.auth.{AuthenticationContext, KafkaPrincipal, SecurityProtocol}
 import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder
 import org.apache.kafka.common.utils.Utils
-import org.apache.kafka.common.{ElectionType, IsolationLevel, Node, TopicPartition, Uuid, requests}
+import org.apache.kafka.common.{ElectionType, IsolationLevel, KafkaException, Node, TopicPartition, Uuid, requests}
 import org.apache.kafka.test.{TestUtils => JTestUtils}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo}
@@ -1883,31 +1883,38 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
   def testIdempotentProducerNoIdempotentWriteAclInInitProducerId(): Unit = {
     createTopic(topic)
     addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource)
-    shouldIdempotentProducerFailInInitProducerId(true)
+    assertIdempotentSendAuthorizationFailure()
   }
 
-  def shouldIdempotentProducerFailInInitProducerId(expectAuthException: Boolean): Unit = {
+  private def assertIdempotentSendSuccess(): Unit = {
     val producer = buildIdempotentProducer()
-    try {
+    producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, "hi".getBytes)).get()
+  }
+
+  private def assertIdempotentSendAuthorizationFailure(): Unit = {
+    val producer = buildIdempotentProducer()
+
+    def assertClusterAuthFailure(): Unit = {
       // the InitProducerId is sent asynchronously, so we expect the error either in the callback
       // or raised from send itself
-      producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, "hi".getBytes)).get()
-      if (expectAuthException)
-        fail("Should have raised ClusterAuthorizationException")
-    } catch {
-      case e: ExecutionException =>
-        assertTrue(e.getCause.isInstanceOf[ClusterAuthorizationException])
-    }
-    try {
-      // the second time, the call to send itself should fail (the producer becomes unusable
-      // if no producerId can be obtained)
-      producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, "hi".getBytes)).get()
-      if (expectAuthException)
-        fail("Should have raised ClusterAuthorizationException")
-    } catch {
-      case e: ExecutionException =>
-        assertTrue(e.getCause.isInstanceOf[ClusterAuthorizationException])
+      val exception = assertThrows(classOf[Exception], () => {
+        val future = producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, "hi".getBytes))
+        future.get()
+      })
+
+      exception match {
+        case _: KafkaException | _: ExecutionException =>
+          assertTrue(exception.getCause.isInstanceOf[ClusterAuthorizationException])
+        case _ =>
+          fail(s"Unexpected exception type raised from send: ${exception.getClass}")
+      }
     }
+
+    assertClusterAuthFailure()
+
+    // the second time, the call to send itself should fail (the producer becomes unusable
+    // if no producerId can be obtained)
+    assertClusterAuthFailure()
   }
 
   @Test
@@ -2119,14 +2126,14 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
 
     for (_ <- 1 to 3) {
       addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource)
-      shouldIdempotentProducerFailInInitProducerId(true)
+      assertIdempotentSendAuthorizationFailure()
 
       addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource)
-      shouldIdempotentProducerFailInInitProducerId(false)
+      assertIdempotentSendSuccess()
 
       removeAllClientAcls()
       addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource)
-      shouldIdempotentProducerFailInInitProducerId(true)
+      assertIdempotentSendAuthorizationFailure()
     }
   }
 
@@ -2149,7 +2156,7 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
     addAndVerifyAcls(Set(acl1, acl4, acl5), topicResource)
     addAndVerifyAcls(Set(acl2, acl3), unrelatedTopicResource)
     addAndVerifyAcls(Set(acl2, acl3), unrelatedGroupResource)
-    shouldIdempotentProducerFailInInitProducerId(false)
+    assertIdempotentSendSuccess()
   }
 
   @Test
@@ -2157,11 +2164,11 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
     createTopic(topic)
     val allowWriteAce = new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)
     addAndVerifyAcls(Set(allowWriteAce), topicResource)
-    shouldIdempotentProducerFailInInitProducerId(false)
+    assertIdempotentSendSuccess()
 
     val denyWriteAce = new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, DENY)
     addAndVerifyAcls(Set(denyWriteAce), topicResource)
-    shouldIdempotentProducerFailInInitProducerId(true)
+    assertIdempotentSendAuthorizationFailure()
   }
 
   @Test
@@ -2175,10 +2182,10 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
 
     addAndVerifyAcls(Set(allowWriteAce), prefixed)
     addAndVerifyAcls(Set(allowWriteAce), literal)
-    shouldIdempotentProducerFailInInitProducerId(false)
+    assertIdempotentSendSuccess()
 
     addAndVerifyAcls(Set(denyWriteAce), wildcard)
-    shouldIdempotentProducerFailInInitProducerId(true)
+    assertIdempotentSendAuthorizationFailure()
   }
 
   @Test
@@ -2191,7 +2198,7 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
 
     addAndVerifyAcls(Set(denyWriteAce), prefixed)
     addAndVerifyAcls(Set(allowWriteAce), literal)
-    shouldIdempotentProducerFailInInitProducerId(true)
+    assertIdempotentSendAuthorizationFailure()
   }
 
   @Test
diff --git a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
index 1fbba9e..2d8689f 100644
--- a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
+++ b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
@@ -30,7 +30,7 @@ import kafka.utils.TestUtils.consumeRecords
 import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerGroupMetadata, KafkaConsumer, OffsetAndMetadata}
 import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord}
 import org.apache.kafka.common.errors.{InvalidProducerEpochException, ProducerFencedException, TimeoutException}
-import org.apache.kafka.common.{KafkaException, TopicPartition}
+import org.apache.kafka.common.TopicPartition
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo}
 
@@ -604,7 +604,7 @@ class TransactionsTest extends KafkaServerTestHarness {
     val producer = createTransactionalProducer(transactionalId = "normalProducer")
 
     producer.initTransactions()
-    assertThrows(classOf[KafkaException], () => producer.initTransactions())
+    assertThrows(classOf[IllegalStateException], () => producer.initTransactions())
   }
 
   @Test