Posted to commits@kafka.apache.org by jo...@apache.org on 2023/07/14 22:18:18 UTC

[kafka] branch trunk updated: KAFKA-14884: Include check transaction is still ongoing right before append (take 2) (#13787)

This is an automated email from the ASF dual-hosted git repository.

jolshan pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new ea0bb001262 KAFKA-14884: Include check transaction is still ongoing right before append (take 2) (#13787)
ea0bb001262 is described below

commit ea0bb001262320bc9233221955a2be31c85993b9
Author: Justine Olshan <jo...@confluent.io>
AuthorDate: Fri Jul 14 15:18:11 2023 -0700

    KAFKA-14884: Include check transaction is still ongoing right before append (take 2) (#13787)
    
    Introduced an extra mapping to track verification state.
    
    When verifying, there is a race condition: the add-partitions verification response may report that the partition is part of the ongoing transaction, yet an abort marker is written before we get to the append. Therefore, we track each transaction we are verifying with an object unique to that transaction.
    
    We check this unique state upon the first append to the log. After that, we can rely on currentTransactionFirstOffset. We remove the verification state when a transactional data record or marker is appended to the log.
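    
    As a rough illustration of the guard idea (a minimal sketch, not the actual UnifiedLog/ProducerStateManager code -- the class and method names below are hypothetical and the state is reduced to in-memory maps):
    
    import scala.collection.mutable
    
    class VerificationSketch {
      private val guards = mutable.Map.empty[Long, Object]   // producerId -> unique verification guard
      private val ongoing = mutable.Set.empty[Long]          // producerIds with an ongoing transaction (first offset written)
    
      // Called before sending the verification request to the transaction coordinator.
      // Returns null if the transaction is already ongoing and no verification is needed.
      def maybeStartVerification(producerId: Long): Object =
        if (ongoing.contains(producerId)) null
        else guards.getOrElseUpdate(producerId, new Object)
    
      // Called on the first transactional append. A stale guard from an attempt that was
      // already aborted (the ABA case) no longer matches the current entry, so the append is rejected.
      def appendTransactional(producerId: Long, guard: Object): Boolean = {
        val verified = ongoing.contains(producerId) || (guard != null && guards.get(producerId).contains(guard))
        if (verified) {
          ongoing += producerId     // after the first append we can rely on the ongoing-transaction state
          guards.remove(producerId) // the guard is removed once the transaction has a first offset
        }
        verified
      }
    }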
    
    We will also clean up lingering verification state entries via the producer state entry expiration mechanism. We do not update the timestamp when retrying verification for a transaction, so each transaction must complete verification within producer.id.expiration.ms.
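    
    A simplified sketch of that expiration behavior (again hypothetical names, not the real ProducerStateManager API; it assumes the sweep runs with the same producer.id.expiration.ms value used for producer IDs):
    
    import scala.collection.mutable
    
    class VerificationExpirySketch(producerIdExpirationMs: Long) {
      private case class Entry(guard: Object, timestampMs: Long)
      private val entries = mutable.Map.empty[Long, Entry]
    
      // Retries reuse the existing entry; the creation timestamp is deliberately not refreshed.
      def getOrCreate(producerId: Long, nowMs: Long): Object =
        entries.getOrElseUpdate(producerId, Entry(new Object, nowMs)).guard
    
      // Piggybacks on the sweep that expires producer state entries.
      def removeExpired(nowMs: Long): Unit = {
        val expired = entries.collect { case (pid, e) if nowMs - e.timestampMs >= producerIdExpirationMs => pid }
        entries --= expired
      }
    }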
    
    There were a few other fixes:
    - Moved the transaction manager's handling of a failed batch into the block that runs when the batch's future completes exceptionally, so the batch is not processed twice (double processing caused issues in unit tests)
    - Handle InterruptedExceptions encountered on the callback thread
    - Throw an error if we try to set verification state and leaderLogIfLocal is None.
    
    Reviewers: David Jacot <dj...@confluent.io>, Artem Livshits <al...@confluent.io>, Jason Gustafson <ja...@confluent.io>
---
 .../kafka/clients/producer/internals/Sender.java   |  19 +-
 .../producer/internals/TransactionManager.java     |   5 +
 .../java/org/apache/kafka/clients/MockClient.java  |   7 +-
 .../clients/producer/internals/SenderTest.java     |  50 ++++
 core/src/main/scala/kafka/cluster/Partition.scala  |  12 +-
 core/src/main/scala/kafka/log/UnifiedLog.scala     |  63 ++++-
 .../main/scala/kafka/server/ReplicaManager.scala   |  78 ++++--
 .../unit/kafka/cluster/AbstractPartitionTest.scala |   2 +-
 .../unit/kafka/cluster/PartitionLockTest.scala     |   4 +-
 .../scala/unit/kafka/cluster/PartitionTest.scala   |  37 ++-
 .../group/GroupMetadataManagerTest.scala           |  20 +-
 .../unit/kafka/log/ProducerStateManagerTest.scala  |  37 ++-
 .../test/scala/unit/kafka/log/UnifiedLogTest.scala | 112 ++++++++
 .../unit/kafka/server/ReplicaManagerTest.scala     | 306 +++++++++++++++------
 .../test/scala/unit/kafka/utils/TestUtils.scala    |   5 +-
 .../apache/kafka/jmh/server/CheckpointBench.java   |   2 +-
 .../kafka/server/util/InterBrokerSendThread.java   |   4 +
 .../server/util/InterBrokerSendThreadTest.java     |  37 +++
 .../internals/log/ProducerStateManager.java        |  29 ++
 .../internals/log/VerificationStateEntry.java      |  44 +++
 20 files changed, 719 insertions(+), 154 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
index 15c4be8c7b9..75968bfba8b 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java
@@ -793,19 +793,18 @@ public class Sender implements Runnable {
         Function<Integer, RuntimeException> recordExceptions,
         boolean adjustSequenceNumbers
     ) {
-        if (transactionManager != null) {
-            try {
-                // This call can throw an exception in the rare case that there's an invalid state transition
-                // attempted. Catch these so as not to interfere with the rest of the logic.
-                transactionManager.handleFailedBatch(batch, topLevelException, adjustSequenceNumbers);
-            } catch (Exception e) {
-                log.debug("Encountered error when handling a failed batch", e);
-            }
-        }
-
         this.sensors.recordErrors(batch.topicPartition.topic(), batch.recordCount);
 
         if (batch.completeExceptionally(topLevelException, recordExceptions)) {
+            if (transactionManager != null) {
+                try {
+                    // This call can throw an exception in the rare case that there's an invalid state transition
+                    // attempted. Catch these so as not to interfere with the rest of the logic.
+                    transactionManager.handleFailedBatch(batch, topLevelException, adjustSequenceNumbers);
+                } catch (Exception e) {
+                    log.debug("Encountered error when transaction manager was handling a failed batch", e);
+                }
+            }
             maybeRemoveAndDeallocateBatch(batch);
         }
     }
diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
index e8ffb97fa71..3eaaca04aa6 100644
--- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
+++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
@@ -1030,6 +1030,11 @@ public class TransactionManager {
         return isTransactional() && currentState == State.READY;
     }
 
+    // visible for testing
+    synchronized boolean isInitializing() {
+        return isTransactional() && currentState == State.INITIALIZING;
+    }
+
     void handleCoordinatorReady() {
         NodeApiVersions nodeApiVersions = transactionCoordinator != null ?
                 apiVersions.get(transactionCoordinator.idString()) :
diff --git a/clients/src/test/java/org/apache/kafka/clients/MockClient.java b/clients/src/test/java/org/apache/kafka/clients/MockClient.java
index 6d404a906aa..d6cdd14f365 100644
--- a/clients/src/test/java/org/apache/kafka/clients/MockClient.java
+++ b/clients/src/test/java/org/apache/kafka/clients/MockClient.java
@@ -189,6 +189,10 @@ public class MockClient implements KafkaClient {
 
     @Override
     public void disconnect(String node) {
+        disconnect(node, false);
+    }
+
+    public void disconnect(String node, boolean allowLateResponses) {
         long now = time.milliseconds();
         Iterator<ClientRequest> iter = requests.iterator();
         while (iter.hasNext()) {
@@ -197,7 +201,8 @@ public class MockClient implements KafkaClient {
                 short version = request.requestBuilder().latestAllowedVersion();
                 responses.add(new ClientResponse(request.makeHeader(version), request.callback(), request.destination(),
                         request.createdTimeMs(), now, true, null, null, null));
-                iter.remove();
+                if (!allowLateResponses)
+                    iter.remove();
             }
         }
         CompletableFuture<String> curDisconnectFuture = disconnectFuture;
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
index b80817465a5..364ca181719 100644
--- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
@@ -3057,6 +3057,56 @@ public class SenderTest {
         assertEquals(RETRY_BACKOFF_MS, time.milliseconds() - request2);
     }
 
+    @Test
+    public void testReceiveFailedBatchTwiceWithTransactions() throws Exception {
+        ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
+        apiVersions.update("0", NodeApiVersions.create(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 3));
+        TransactionManager txnManager = new TransactionManager(logContext, "testFailTwice", 60000, 100, apiVersions);
+
+        setupWithTransactionState(txnManager);
+        doInitTransactions(txnManager, producerIdAndEpoch);
+
+        txnManager.beginTransaction();
+        txnManager.maybeAddPartition(tp0);
+        client.prepareResponse(buildAddPartitionsToTxnResponseData(0, Collections.singletonMap(tp0, Errors.NONE)));
+        sender.runOnce();
+
+        // Send first ProduceRequest
+        Future<RecordMetadata> request1 = appendToAccumulator(tp0);
+        sender.runOnce();  // send request
+
+        Node node = metadata.fetch().nodes().get(0);
+        time.sleep(2000L);
+        client.disconnect(node.idString(), true);
+        client.backoff(node, 10);
+
+        sender.runOnce(); // now expire the batch.
+        assertFutureFailure(request1, TimeoutException.class);
+
+        time.sleep(20);
+
+        sendIdempotentProducerResponse(0, tp0, Errors.INVALID_RECORD, -1);
+        sender.runOnce(); // receive late response
+
+        // Loop once and confirm that the transaction manager does not enter a fatal error state
+        sender.runOnce();
+        assertTrue(txnManager.hasAbortableError());
+        TransactionalRequestResult result = txnManager.beginAbort();
+        sender.runOnce();
+
+        respondToEndTxn(Errors.NONE);
+        sender.runOnce();
+        assertTrue(txnManager::isInitializing);
+        prepareInitProducerResponse(Errors.NONE, producerIdAndEpoch.producerId, producerIdAndEpoch.epoch);
+        sender.runOnce();
+        assertTrue(txnManager::isReady);
+
+        assertTrue(result.isSuccessful());
+        result.await();
+
+        txnManager.beginTransaction();
+    }
+
     private void verifyErrorMessage(ProduceResponse response, String expectedMessage) throws Exception {
         Future<RecordMetadata> future = appendToAccumulator(tp0, 0L, "key", "value");
         sender.runOnce(); // connect
diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala
index 4913a787ef3..e1478c62d51 100755
--- a/core/src/main/scala/kafka/cluster/Partition.scala
+++ b/core/src/main/scala/kafka/cluster/Partition.scala
@@ -576,8 +576,12 @@ class Partition(val topicPartition: TopicPartition,
     }
   }
 
-  def hasOngoingTransaction(producerId: Long): Boolean = {
-    leaderLogIfLocal.exists(leaderLog => leaderLog.hasOngoingTransaction(producerId))
+  // Returns a verification guard object if we need to verify. This starts or continues the verification process. Otherwise return null.
+  def maybeStartTransactionVerification(producerId: Long): Object = {
+    leaderLogIfLocal match {
+      case Some(log) => log.maybeStartTransactionVerification(producerId)
+      case None => throw new NotLeaderOrFollowerException();
+    }
   }
 
   // Return true if the future replica exists and it has caught up with the current replica for this partition
@@ -1279,7 +1283,7 @@ class Partition(val topicPartition: TopicPartition,
   }
 
   def appendRecordsToLeader(records: MemoryRecords, origin: AppendOrigin, requiredAcks: Int,
-                            requestLocal: RequestLocal): LogAppendInfo = {
+                            requestLocal: RequestLocal, verificationGuard: Object = null): LogAppendInfo = {
     val (info, leaderHWIncremented) = inReadLock(leaderIsrUpdateLock) {
       leaderLogIfLocal match {
         case Some(leaderLog) =>
@@ -1293,7 +1297,7 @@ class Partition(val topicPartition: TopicPartition,
           }
 
           val info = leaderLog.appendAsLeader(records, leaderEpoch = this.leaderEpoch, origin,
-            interBrokerProtocolVersion, requestLocal)
+            interBrokerProtocolVersion, requestLocal, verificationGuard)
 
           // we may need to increment high watermark since ISR could be down to 1
           (info, maybeIncrementLeaderHW(leaderLog))
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index a2aedf0a290..9eb1af6d56f 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -577,6 +577,28 @@ class UnifiedLog(@volatile var logStartOffset: Long,
     result
   }
 
+  /**
+   * Maybe create and return the verification guard object for the given producer ID if the transaction is not yet ongoing.
+   * Creation starts the verification process. Otherwise return null.
+   */
+  def maybeStartTransactionVerification(producerId: Long): Object = lock synchronized {
+    if (hasOngoingTransaction(producerId))
+      null
+    else
+      getOrMaybeCreateVerificationGuard(producerId, true)
+  }
+
+  /**
+   * Maybe create the VerificationStateEntry for the given producer ID -- if an entry is present, return its verification guard, otherwise, return null.
+   */
+  def getOrMaybeCreateVerificationGuard(producerId: Long, createIfAbsent: Boolean = false): Object = lock synchronized {
+    val entry = producerStateManager.verificationStateEntry(producerId, createIfAbsent)
+    if (entry != null) entry.verificationGuard else null
+  }
+
+  /**
+   * Return true if the given producer ID has a transaction ongoing.
+   */
   def hasOngoingTransaction(producerId: Long): Boolean = lock synchronized {
     val entry = producerStateManager.activeProducers.get(producerId)
     entry != null && entry.currentTxnFirstOffset.isPresent
@@ -662,9 +684,10 @@ class UnifiedLog(@volatile var logStartOffset: Long,
                      leaderEpoch: Int,
                      origin: AppendOrigin = AppendOrigin.CLIENT,
                      interBrokerProtocolVersion: MetadataVersion = MetadataVersion.latest,
-                     requestLocal: RequestLocal = RequestLocal.NoCaching): LogAppendInfo = {
+                     requestLocal: RequestLocal = RequestLocal.NoCaching,
+                     verificationGuard: Object = null): LogAppendInfo = {
     val validateAndAssignOffsets = origin != AppendOrigin.RAFT_LEADER
-    append(records, origin, interBrokerProtocolVersion, validateAndAssignOffsets, leaderEpoch, Some(requestLocal), ignoreRecordSize = false)
+    append(records, origin, interBrokerProtocolVersion, validateAndAssignOffsets, leaderEpoch, Some(requestLocal), verificationGuard, ignoreRecordSize = false)
   }
 
   /**
@@ -681,6 +704,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
       validateAndAssignOffsets = false,
       leaderEpoch = -1,
       requestLocal = None,
+      verificationGuard = null,
       // disable to check the validation of record size since the record is already accepted by leader.
       ignoreRecordSize = true)
   }
@@ -709,6 +733,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
                      validateAndAssignOffsets: Boolean,
                      leaderEpoch: Int,
                      requestLocal: Option[RequestLocal],
+                     verificationGuard: Object,
                      ignoreRecordSize: Boolean): LogAppendInfo = {
     // We want to ensure the partition metadata file is written to the log dir before any log data is written to disk.
     // This will ensure that any log data can be recovered with the correct topic ID in the case of failure.
@@ -833,7 +858,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
           // now that we have valid records, offsets assigned, and timestamps updated, we need to
           // validate the idempotent/transactional state of the producers and collect some metadata
           val (updatedProducers, completedTxns, maybeDuplicate) = analyzeAndValidateProducerState(
-            logOffsetMetadata, validRecords, origin)
+            logOffsetMetadata, validRecords, origin, verificationGuard)
 
           maybeDuplicate match {
             case Some(duplicate) =>
@@ -961,7 +986,8 @@ class UnifiedLog(@volatile var logStartOffset: Long,
 
   private def analyzeAndValidateProducerState(appendOffsetMetadata: LogOffsetMetadata,
                                               records: MemoryRecords,
-                                              origin: AppendOrigin):
+                                              origin: AppendOrigin,
+                                              requestVerificationGuard: Object):
   (mutable.Map[Long, ProducerAppendInfo], List[CompletedTxn], Option[BatchMetadata]) = {
     val updatedProducers = mutable.Map.empty[Long, ProducerAppendInfo]
     val completedTxns = ListBuffer.empty[CompletedTxn]
@@ -978,6 +1004,25 @@ class UnifiedLog(@volatile var logStartOffset: Long,
           if (duplicateBatch.isPresent) {
             return (updatedProducers, completedTxns.toList, Some(duplicateBatch.get()))
           }
+
+          // Verify that if the record is transactional & the append origin is client, that we either have an ongoing transaction or verified transaction state.
+          // This guarantees that transactional records are never written to the log outside of the transaction coordinator's knowledge of an open transaction on
+          // the partition. If we do not have an ongoing transaction or correct guard, return an error and do not append.
+          // There are two phases -- the first append to the log and subsequent appends.
+          //
+          // 1. First append: Verification starts with creating a verification guard object, sending a verification request to the transaction coordinator, and
+          // given a "verified" response, continuing the append path. (A non-verified response throws an error.) We create the unique verification guard for the transaction
+          // to ensure there is no race between the transaction coordinator response and an abort marker getting written to the log. We need a unique guard because we could
+          // have a sequence of events where we start a transaction verification, have the transaction coordinator send a verified response, write an abort marker,
+          // start a new transaction not aware of the partition, and receive the stale verification (ABA problem). With a unique verification guard object, this sequence would not
+          // result in appending to the log and would return an error. The guard is removed after the first append to the transaction and from then, we can rely on phase 2.
+          //
+          // 2. Subsequent appends: Once we write to the transaction, the in-memory state currentTxnFirstOffset is populated. This field remains until the
+          // transaction is completed or aborted. We can guarantee the transaction coordinator knows about the transaction given step 1 and that the transaction is still
+          // ongoing. If the transaction is expected to be ongoing, we will not set a verification guard. If the transaction is aborted, hasOngoingTransaction is false and
+          // requestVerificationGuard is null, so we will throw an error. A subsequent produce request (retry) should create verification state and return to phase 1.
+          if (batch.isTransactional && !hasOngoingTransaction(batch.producerId) && batchMissingRequiredVerification(batch, requestVerificationGuard))
+            throw new InvalidRecordException("Record was not part of an ongoing transaction")
         }
 
         // We cache offset metadata for the start of each transaction. This allows us to
@@ -996,6 +1041,10 @@ class UnifiedLog(@volatile var logStartOffset: Long,
     (updatedProducers, completedTxns.toList, None)
   }
 
+  private def batchMissingRequiredVerification(batch: MutableRecordBatch, requestVerificationGuard: Object): Boolean = {
+    producerStateManager.producerStateManagerConfig().transactionVerificationEnabled() && (requestVerificationGuard != getOrMaybeCreateVerificationGuard(batch.producerId) || requestVerificationGuard == null)
+  }
+
   /**
    * Validate the following:
    * <ol>
@@ -1872,7 +1921,11 @@ object UnifiedLog extends Logging {
                               origin: AppendOrigin): Option[CompletedTxn] = {
     val producerId = batch.producerId
     val appendInfo = producers.getOrElseUpdate(producerId, producerStateManager.prepareUpdate(producerId, origin))
-    appendInfo.append(batch, firstOffsetMetadata.asJava).asScala
+    val completedTxn = appendInfo.append(batch, firstOffsetMetadata.asJava).asScala
+    // Whether we wrote a control marker or a data batch, we can remove verification guard since either the transaction is complete or we have a first offset.
+    if (batch.isTransactional)
+      producerStateManager.clearVerificationStateEntry(producerId)
+    completedTxn
   }
 
   /**
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index a0c050ce377..aea9adb7212 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -705,28 +705,18 @@ class ReplicaManager(val config: KafkaConfig,
                     actionQueue: ActionQueue = this.actionQueue): Unit = {
     if (isValidRequiredAcks(requiredAcks)) {
       val sTime = time.milliseconds
-      
-      val transactionalProducerIds = mutable.HashSet[Long]()
-      val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition) =
+
+      val verificationGuards: mutable.Map[TopicPartition, Object] = mutable.Map[TopicPartition, Object]()
+      val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition, errorsPerPartition) =
         if (transactionStatePartition.isEmpty || !config.transactionPartitionVerificationEnable)
-          (entriesPerPartition, Map.empty)
+          (entriesPerPartition, Map.empty[TopicPartition, MemoryRecords], Map.empty[TopicPartition, Errors])
         else {
-          entriesPerPartition.partition { case (topicPartition, records) =>
-            // Produce requests (only requests that require verification) should only have one batch per partition in "batches" but check all just to be safe.
-            val transactionalBatches = records.batches.asScala.filter(batch => batch.hasProducerId && batch.isTransactional)
-            transactionalBatches.foreach(batch => transactionalProducerIds.add(batch.producerId))
-            if (transactionalBatches.nonEmpty) {
-              getPartitionOrException(topicPartition).hasOngoingTransaction(transactionalBatches.head.producerId)
-            } else {
-              // If there is no producer ID in the batches, no need to verify.
-              true
-            }
-          }
+          val verifiedEntries = mutable.Map[TopicPartition, MemoryRecords]()
+          val unverifiedEntries = mutable.Map[TopicPartition, MemoryRecords]()
+          val errorEntries = mutable.Map[TopicPartition, Errors]()
+          partitionEntriesForVerification(verificationGuards, entriesPerPartition, verifiedEntries, unverifiedEntries, errorEntries)
+          (verifiedEntries.toMap, unverifiedEntries.toMap, errorEntries.toMap)
         }
-      // We should have exactly one producer ID for transactional records
-      if (transactionalProducerIds.size > 1) {
-        throw new InvalidPidMappingException("Transactional records contained more than one producer ID")
-      }
 
       def appendEntries(allEntries: Map[TopicPartition, MemoryRecords])(unverifiedEntries: Map[TopicPartition, Errors]): Unit = {
         val verifiedEntries = 
@@ -738,7 +728,7 @@ class ReplicaManager(val config: KafkaConfig,
             }
         
         val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed,
-          origin, verifiedEntries, requiredAcks, requestLocal)
+          origin, verifiedEntries, requiredAcks, requestLocal, verificationGuards.toMap)
         debug("Produce to local log in %d ms".format(time.milliseconds - sTime))
         
         val unverifiedResults = unverifiedEntries.map { case (topicPartition, error) =>
@@ -749,8 +739,15 @@ class ReplicaManager(val config: KafkaConfig,
             Some(error.exception(message))
           )
         }
+
+        val errorResults = errorsPerPartition.map { case (topicPartition, error) =>
+          topicPartition -> LogAppendResult(
+            LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
+            Some(error.exception())
+          )
+        }
         
-        val allResults = localProduceResults ++ unverifiedResults
+        val allResults = localProduceResults ++ unverifiedResults ++ errorResults
 
         val produceStatus = allResults.map { case (topicPartition, result) =>
           topicPartition -> ProducePartitionStatus(
@@ -851,6 +848,40 @@ class ReplicaManager(val config: KafkaConfig,
     }
   }
 
+  private def partitionEntriesForVerification(verificationGuards: mutable.Map[TopicPartition, Object],
+                                              entriesPerPartition: Map[TopicPartition, MemoryRecords],
+                                              verifiedEntries: mutable.Map[TopicPartition, MemoryRecords],
+                                              unverifiedEntries: mutable.Map[TopicPartition, MemoryRecords],
+                                              errorEntries: mutable.Map[TopicPartition, Errors]): Unit= {
+    val transactionalProducerIds = mutable.HashSet[Long]()
+    entriesPerPartition.foreach { case (topicPartition, records) =>
+      try {
+        // Produce requests (only requests that require verification) should only have one batch per partition in "batches" but check all just to be safe.
+        val transactionalBatches = records.batches.asScala.filter(batch => batch.hasProducerId && batch.isTransactional)
+        transactionalBatches.foreach(batch => transactionalProducerIds.add(batch.producerId))
+
+        if (transactionalBatches.nonEmpty) {
+          // We return verification guard if the partition needs to be verified. If no state is present, no need to verify.
+          val verificationGuard = getPartitionOrException(topicPartition).maybeStartTransactionVerification(records.firstBatch.producerId)
+          if (verificationGuard != null) {
+            verificationGuards.put(topicPartition, verificationGuard)
+            unverifiedEntries.put(topicPartition, records)
+          } else
+            verifiedEntries.put(topicPartition, records)
+        } else {
+          // If there is no producer ID or transactional records in the batches, no need to verify.
+          verifiedEntries.put(topicPartition, records)
+        }
+      } catch {
+        case e: Exception => errorEntries.put(topicPartition, Errors.forException(e))
+      }
+    }
+    // We should have exactly one producer ID for transactional records
+    if (transactionalProducerIds.size > 1) {
+      throw new InvalidPidMappingException("Transactional records contained more than one producer ID")
+    }
+  }
+
   /**
    * Delete records on leader replicas of the partition, and wait for delete records operation be propagated to other replicas;
    * the callback function will be triggered either when timeout or logStartOffset of all live replicas have reached the specified offset
@@ -1107,7 +1138,8 @@ class ReplicaManager(val config: KafkaConfig,
                                origin: AppendOrigin,
                                entriesPerPartition: Map[TopicPartition, MemoryRecords],
                                requiredAcks: Short,
-                               requestLocal: RequestLocal): Map[TopicPartition, LogAppendResult] = {
+                               requestLocal: RequestLocal,
+                               verificationGuards: Map[TopicPartition, Object]): Map[TopicPartition, LogAppendResult] = {
     val traceEnabled = isTraceEnabled
     def processFailedRecord(topicPartition: TopicPartition, t: Throwable) = {
       val logStartOffset = onlinePartition(topicPartition).map(_.logStartOffset).getOrElse(-1L)
@@ -1133,7 +1165,7 @@ class ReplicaManager(val config: KafkaConfig,
       } else {
         try {
           val partition = getPartitionOrException(topicPartition)
-          val info = partition.appendRecordsToLeader(records, origin, requiredAcks, requestLocal)
+          val info = partition.appendRecordsToLeader(records, origin, requiredAcks, requestLocal, verificationGuards.getOrElse(topicPartition, null))
           val numAppendedMessages = info.numMessages
 
           // update stats for successfully appended bytes and messages as bytesInRate and messageInRate
diff --git a/core/src/test/scala/unit/kafka/cluster/AbstractPartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/AbstractPartitionTest.scala
index bf360597511..1d0577cc977 100644
--- a/core/src/test/scala/unit/kafka/cluster/AbstractPartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/AbstractPartitionTest.scala
@@ -74,7 +74,7 @@ class AbstractPartitionTest {
     logDir1 = TestUtils.randomPartitionLogDir(tmpDir)
     logDir2 = TestUtils.randomPartitionLogDir(tmpDir)
     logManager = TestUtils.createLogManager(Seq(logDir1, logDir2), logConfig, configRepository,
-      new CleanerConfig(false), time, interBrokerProtocolVersion)
+      new CleanerConfig(false), time, interBrokerProtocolVersion, transactionVerificationEnabled = true)
     logManager.startup(Set.empty)
 
     alterPartitionManager = TestUtils.createAlterIsrManager()
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
index 614a9d6ffce..98c2a60bcfc 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
@@ -449,8 +449,8 @@ class PartitionLockTest extends Logging {
     keepPartitionMetadataFile = true) {
 
     override def appendAsLeader(records: MemoryRecords, leaderEpoch: Int, origin: AppendOrigin,
-                                interBrokerProtocolVersion: MetadataVersion, requestLocal: RequestLocal): LogAppendInfo = {
-      val appendInfo = super.appendAsLeader(records, leaderEpoch, origin, interBrokerProtocolVersion, requestLocal)
+                                interBrokerProtocolVersion: MetadataVersion, requestLocal: RequestLocal, verificationGuard: Object): LogAppendInfo = {
+      val appendInfo = super.appendAsLeader(records, leaderEpoch, origin, interBrokerProtocolVersion, requestLocal, verificationGuard)
       appendSemaphore.acquire()
       appendInfo
     }
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
index a6ad25e5ed9..ab3d06ef653 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
@@ -32,7 +32,7 @@ import org.apache.kafka.common.record.FileRecords.TimestampAndOffset
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.requests.{AlterPartitionResponse, FetchRequest, ListOffsetsRequest, RequestHeader}
 import org.apache.kafka.common.utils.SystemTime
-import org.apache.kafka.common.{IsolationLevel, TopicPartition, Uuid}
+import org.apache.kafka.common.{InvalidRecordException, IsolationLevel, TopicPartition, Uuid}
 import org.apache.kafka.metadata.LeaderRecoveryState
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
@@ -996,8 +996,10 @@ class PartitionTest extends AbstractPartitionTest {
       new SimpleRecord("k1".getBytes, "v1".getBytes),
       new SimpleRecord("k2".getBytes, "v2".getBytes),
       new SimpleRecord("k3".getBytes, "v3".getBytes)),
-      baseOffset = 0L)
-    partition.appendRecordsToLeader(records, origin = AppendOrigin.CLIENT, requiredAcks = 0, RequestLocal.withThreadConfinedCaching)
+      baseOffset = 0L,
+      producerId = 2L)
+    val verificationGuard = partition.maybeStartTransactionVerification(2L)
+    partition.appendRecordsToLeader(records, origin = AppendOrigin.CLIENT, requiredAcks = 0, RequestLocal.withThreadConfinedCaching, verificationGuard)
 
     def fetchOffset(isolationLevel: Option[IsolationLevel], timestamp: Long): TimestampAndOffset = {
       val res = partition.fetchOffsetForTimestamp(timestamp,
@@ -3349,7 +3351,7 @@ class PartitionTest extends AbstractPartitionTest {
   }
 
   @Test
-  def testHasOngoingTransaction(): Unit = {
+  def testMaybeStartTransactionVerification(): Unit = {
     val controllerEpoch = 0
     val leaderEpoch = 5
     val replicas = List[Integer](brokerId, brokerId + 1).asJava
@@ -3367,7 +3369,6 @@ class PartitionTest extends AbstractPartitionTest {
       .setReplicas(replicas)
       .setIsNew(true), offsetCheckpoints, None), "Expected become leader transition to succeed")
     assertEquals(leaderEpoch, partition.getLeaderEpoch)
-    assertFalse(partition.hasOngoingTransaction(producerId))
 
     val idempotentRecords = createIdempotentRecords(List(
       new SimpleRecord("k1".getBytes, "v1".getBytes),
@@ -3376,17 +3377,35 @@ class PartitionTest extends AbstractPartitionTest {
       baseOffset = 0L,
       producerId = producerId)
     partition.appendRecordsToLeader(idempotentRecords, origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching)
-    assertFalse(partition.hasOngoingTransaction(producerId))
 
-    val transactionRecords = createTransactionalRecords(List(
+    def transactionRecords() = createTransactionalRecords(List(
       new SimpleRecord("k1".getBytes, "v1".getBytes),
       new SimpleRecord("k2".getBytes, "v2".getBytes),
       new SimpleRecord("k3".getBytes, "v3".getBytes)),
       baseOffset = 0L,
       baseSequence = 3,
       producerId = producerId)
-    partition.appendRecordsToLeader(transactionRecords, origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching)
-    assertTrue(partition.hasOngoingTransaction(producerId))
+
+    // When verification guard is not there, we should not be able to append.
+    assertThrows(classOf[InvalidRecordException], () => partition.appendRecordsToLeader(transactionRecords(), origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching))
+
+    // Before appendRecordsToLeader is called, ReplicaManager will call maybeStartTransactionVerification. We should get a non-null verification object.
+    val verificationGuard = partition.maybeStartTransactionVerification(producerId)
+    assertNotNull(verificationGuard)
+
+    // With the wrong verification guard, append should fail.
+    assertThrows(classOf[InvalidRecordException], () => partition.appendRecordsToLeader(transactionRecords(),
+      origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching, Optional.of(new Object)))
+
+    // We should return the same verification object when we still need to verify. Append should proceed.
+    val verificationGuard2 = partition.maybeStartTransactionVerification(producerId)
+    assertEquals(verificationGuard, verificationGuard2)
+    partition.appendRecordsToLeader(transactionRecords(), origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching, verificationGuard)
+
+    // We should no longer need a verification object. Future appends without verification guard will also succeed.
+    val verificationGuard3 = partition.maybeStartTransactionVerification(producerId)
+    assertNull(verificationGuard3)
+    partition.appendRecordsToLeader(transactionRecords(), origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching)
   }
 
   private def makeLeader(
diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala
index 479fc61b8f0..4304841453b 100644
--- a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala
@@ -1692,7 +1692,7 @@ class GroupMetadataManagerTest {
 
     when(partition.appendRecordsToLeader(any[MemoryRecords],
       origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
-      any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
+      any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
     groupMetadataManager.cleanupGroupMetadata()
 
     assertEquals(Some(group), groupMetadataManager.getGroup(groupId))
@@ -1740,7 +1740,7 @@ class GroupMetadataManagerTest {
     mockGetPartition()
     when(partition.appendRecordsToLeader(recordsCapture.capture(),
       origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
-      any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
+      any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
     groupMetadataManager.cleanupGroupMetadata()
 
     val records = recordsCapture.getValue.records.asScala.toList
@@ -1783,7 +1783,7 @@ class GroupMetadataManagerTest {
     mockGetPartition()
     when(partition.appendRecordsToLeader(recordsCapture.capture(),
       origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
-      any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
+      any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
     groupMetadataManager.cleanupGroupMetadata()
 
     val records = recordsCapture.getValue.records.asScala.toList
@@ -1851,7 +1851,7 @@ class GroupMetadataManagerTest {
 
     when(partition.appendRecordsToLeader(recordsCapture.capture(),
       origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
-      any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
+      any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
     groupMetadataManager.cleanupGroupMetadata()
 
     // verify the tombstones are correct and only for the expired offsets
@@ -1959,7 +1959,7 @@ class GroupMetadataManagerTest {
     // expect the offset tombstone
     when(partition.appendRecordsToLeader(any[MemoryRecords],
       origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
-      any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
+      any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
     groupMetadataManager.cleanupGroupMetadata()
 
     // group is empty now, only one offset should expire
@@ -1984,7 +1984,7 @@ class GroupMetadataManagerTest {
     // expect the offset tombstone
     when(partition.appendRecordsToLeader(any[MemoryRecords],
       origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
-      any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
+      any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
     groupMetadataManager.cleanupGroupMetadata()
 
     // one more offset should expire
@@ -2041,7 +2041,7 @@ class GroupMetadataManagerTest {
     // expect the offset tombstone
     when(partition.appendRecordsToLeader(any[MemoryRecords],
       origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
-      any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
+      any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
     groupMetadataManager.cleanupGroupMetadata()
 
     // group and all its offsets should be gone now
@@ -2131,7 +2131,7 @@ class GroupMetadataManagerTest {
     // expect the offset tombstone
     when(partition.appendRecordsToLeader(any[MemoryRecords],
       origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
-      any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
+      any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
     groupMetadataManager.cleanupGroupMetadata()
 
     // group and all its offsets should be gone now
@@ -2283,13 +2283,13 @@ class GroupMetadataManagerTest {
     // expect the offset tombstone
     when(partition.appendRecordsToLeader(any[MemoryRecords],
       origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
-      any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
+      any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
 
     groupMetadataManager.cleanupGroupMetadata()
 
     verify(partition).appendRecordsToLeader(any[MemoryRecords],
       origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
-      any())
+      any(), any())
     verify(replicaManager, times(2)).onlinePartition(groupTopicPartition)
 
     assertEquals(Some(group), groupMetadataManager.getGroup(groupId))
diff --git a/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
index b98d0b3acfc..7f8d48071a6 100644
--- a/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
@@ -29,7 +29,7 @@ import org.apache.kafka.common.errors._
 import org.apache.kafka.common.internals.Topic
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.{MockTime, Utils}
-import org.apache.kafka.storage.internals.log.{AppendOrigin, CompletedTxn, LogFileUtils, LogOffsetMetadata, ProducerAppendInfo, ProducerStateEntry, ProducerStateManager, ProducerStateManagerConfig, TxnMetadata}
+import org.apache.kafka.storage.internals.log.{AppendOrigin, CompletedTxn, LogFileUtils, LogOffsetMetadata, ProducerAppendInfo, ProducerStateEntry, ProducerStateManager, ProducerStateManagerConfig, TxnMetadata, VerificationStateEntry}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.Mockito.{mock, when}
@@ -1087,6 +1087,41 @@ class ProducerStateManagerTest {
     assertTrue(!manager.latestSnapshotOffset.isPresent)
   }
 
+  @Test
+  def testEntryForVerification(): Unit = {
+    val originalEntry = stateManager.verificationStateEntry(producerId, true)
+    val originalEntryVerificationGuard = originalEntry.verificationGuard()
+
+    def verifyEntry(producerId: Long, newEntry: VerificationStateEntry): Unit = {
+      val entry = stateManager.verificationStateEntry(producerId, false)
+      assertEquals(originalEntryVerificationGuard, entry.verificationGuard)
+      assertEquals(entry.verificationGuard, newEntry.verificationGuard)
+    }
+
+    // If we already have an entry, reuse it.
+    val updatedEntry = stateManager.verificationStateEntry(producerId, true)
+    verifyEntry(producerId, updatedEntry)
+
+    // Add the transactional data and clear the entry.
+    append(stateManager, producerId, 0, 0, offset = 0, isTransactional = true)
+    stateManager.clearVerificationStateEntry(producerId)
+    assertNull(stateManager.verificationStateEntry(producerId, false))
+  }
+
+  @Test
+  def testVerificationStateEntryExpiration(): Unit = {
+    val originalEntry = stateManager.verificationStateEntry(producerId, true)
+
+    // Before timeout we do not remove. Note: Accessing the verification entry does not update the time.
+    time.sleep(producerStateManagerConfig.producerIdExpirationMs / 2)
+    stateManager.removeExpiredProducers(time.milliseconds())
+    assertEquals(originalEntry, stateManager.verificationStateEntry(producerId, false))
+
+    time.sleep((producerStateManagerConfig.producerIdExpirationMs / 2) + 1)
+    stateManager.removeExpiredProducers(time.milliseconds())
+    assertNull(stateManager.verificationStateEntry(producerId, false))
+  }
+
   private def testLoadFromCorruptSnapshot(makeFileCorrupt: FileChannel => Unit): Unit = {
     val epoch = 0.toShort
     val producerId = 1L
diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
index 0444f940fc5..a1e22d8ba0f 100755
--- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
@@ -3667,6 +3667,118 @@ class UnifiedLogTest {
     listener.verify(expectedHighWatermark = 4)
   }
 
+  @Test
+  def testTransactionIsOngoingAndVerificationGuard(): Unit = {
+    val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, true)
+
+    val producerId = 23L
+    val producerEpoch = 1.toShort
+    val sequence = 3
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, producerStateManagerConfig = producerStateManagerConfig)
+    assertFalse(log.hasOngoingTransaction(producerId))
+    assertNull(log.getOrMaybeCreateVerificationGuard(producerId))
+
+    val idempotentRecords = MemoryRecords.withIdempotentRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    val verificationGuard = log.maybeStartTransactionVerification(producerId)
+    assertNotNull(verificationGuard)
+
+    log.appendAsLeader(idempotentRecords, leaderEpoch = 0)
+    assertFalse(log.hasOngoingTransaction(producerId))
+
+    // Since we wrote idempotent records, we keep verification guard.
+    assertEquals(verificationGuard, log.getOrMaybeCreateVerificationGuard(producerId))
+
+    val transactionalRecords = MemoryRecords.withTransactionalRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence + 2,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    log.appendAsLeader(transactionalRecords, leaderEpoch = 0, verificationGuard = verificationGuard)
+    assertTrue(log.hasOngoingTransaction(producerId))
+    // Verification guard should be cleared now.
+    assertNull(log.getOrMaybeCreateVerificationGuard(producerId))
+
+    // A subsequent maybeStartTransactionVerification will be empty since we are already verified.
+    assertNull(log.maybeStartTransactionVerification(producerId))
+
+    val endTransactionMarkerRecord = MemoryRecords.withEndTransactionMarker(
+      producerId,
+      producerEpoch,
+      new EndTransactionMarker(ControlRecordType.COMMIT, 0)
+    )
+
+    log.appendAsLeader(endTransactionMarkerRecord, origin = AppendOrigin.COORDINATOR, leaderEpoch = 0)
+    assertFalse(log.hasOngoingTransaction(producerId))
+    assertNull(log.getOrMaybeCreateVerificationGuard(producerId))
+
+    // A new maybeStartTransactionVerification will not be empty, as we need to verify the next transaction.
+    val newVerificationGuard = log.maybeStartTransactionVerification(producerId)
+    assertNotNull(newVerificationGuard)
+    assertNotEquals(verificationGuard, newVerificationGuard)
+  }
+
+  @Test
+  def testEmptyTransactionStillClearsVerificationGuard(): Unit = {
+    val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, true)
+
+    val producerId = 23L
+    val producerEpoch = 1.toShort
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, producerStateManagerConfig = producerStateManagerConfig)
+
+    val verificationGuard = log.maybeStartTransactionVerification(producerId)
+    assertNotNull(verificationGuard)
+
+    val endTransactionMarkerRecord = MemoryRecords.withEndTransactionMarker(
+      producerId,
+      producerEpoch,
+      new EndTransactionMarker(ControlRecordType.COMMIT, 0)
+    )
+
+    log.appendAsLeader(endTransactionMarkerRecord, origin = AppendOrigin.COORDINATOR, leaderEpoch = 0)
+    assertFalse(log.hasOngoingTransaction(producerId))
+    assertNull(log.getOrMaybeCreateVerificationGuard(producerId))
+  }
+
+  @Test
+  def testAllowNonZeroSequenceOnFirstAppendNonZeroEpoch(): Unit = {
+    val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, true)
+
+    val producerId = 23L
+    val producerEpoch = 1.toShort
+    val sequence = 3
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, producerStateManagerConfig = producerStateManagerConfig)
+    assertFalse(log.hasOngoingTransaction(producerId))
+    assertNull(log.getOrMaybeCreateVerificationGuard(producerId))
+
+    val transactionalRecords = MemoryRecords.withTransactionalRecords(
+      CompressionType.NONE,
+      producerId,
+      producerEpoch,
+      sequence,
+      new SimpleRecord("1".getBytes),
+      new SimpleRecord("2".getBytes)
+    )
+
+    val verificationGuard = log.maybeStartTransactionVerification(producerId)
+    // Append should not throw error.
+    log.appendAsLeader(transactionalRecords, leaderEpoch = 0, verificationGuard = verificationGuard)
+  }
+
   private def appendTransactionalToBuffer(buffer: ByteBuffer,
                                           producerId: Long,
                                           producerEpoch: Short,
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 8371fc7d317..2bbda923529 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -87,6 +87,7 @@ class ReplicaManagerTest {
   private val topicId = Uuid.randomUuid()
   private val topicIds = scala.Predef.Map("test-topic" -> topicId)
   private val topicNames = scala.Predef.Map(topicId -> "test-topic")
+  private val transactionalId = "txn"
   private val time = new MockTime
   private val scheduler = new MockScheduler(time)
   private val metrics = new Metrics
@@ -94,6 +95,7 @@ class ReplicaManagerTest {
   private var config: KafkaConfig = _
   private var quotaManager: QuotaManagers = _
   private var mockRemoteLogManager: RemoteLogManager = _
+  private var addPartitionsToTxnManager: AddPartitionsToTxnManager = _
 
   // Constants defined for readability
   private val zkVersion = 0
@@ -108,6 +110,14 @@ class ReplicaManagerTest {
     alterPartitionManager = mock(classOf[AlterPartitionManager])
     quotaManager = QuotaFactory.instantiate(config, metrics, time, "")
     mockRemoteLogManager = mock(classOf[RemoteLogManager])
+    addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
+
+    // Anytime we try to verify, just automatically run the callback as though the transaction was verified.
+    when(addPartitionsToTxnManager.addTxnData(any(), any(), any())).thenAnswer {
+      invocationOnMock =>
+        val callback = invocationOnMock.getArgument(2, classOf[AddPartitionsToTxnManager.AppendCallback])
+        callback(Map.empty[TopicPartition, Errors].toMap)
+    }
   }
 
   @AfterEach
@@ -596,7 +606,7 @@ class ReplicaManagerTest {
       // Simulate producer id expiration.
       // We use -1 because the timestamp in this test is set to -1, so when
       // the expiration check subtracts timestamp, we get max value.
-      partition0.removeExpiredProducers(Long.MaxValue - 1);
+      partition0.removeExpiredProducers(Long.MaxValue - 1)
       assertEquals(1, replicaManagerMetricValue())
     } finally {
       replicaManager.shutdown(checkpointHW = false)
@@ -643,7 +653,7 @@ class ReplicaManagerTest {
       val sequence = 9
       val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence,
         new SimpleRecord(time.milliseconds(), s"message $sequence".getBytes))
-      appendRecords(replicaManager, new TopicPartition(topic, 0), records).onFire { response =>
+      appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
         assertEquals(Errors.NONE, response.error)
       }
       assertLateTransactionCount(Some(0))
@@ -707,7 +717,7 @@ class ReplicaManagerTest {
       for (sequence <- 0 until numRecords) {
         val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence,
           new SimpleRecord(s"message $sequence".getBytes))
-        appendRecords(replicaManager, new TopicPartition(topic, 0), records).onFire { response =>
+        appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
           assertEquals(Errors.NONE, response.error)
         }
       }
@@ -828,7 +838,7 @@ class ReplicaManagerTest {
       for (sequence <- 0 until numRecords) {
         val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence,
           new SimpleRecord(s"message $sequence".getBytes))
-        appendRecords(replicaManager, new TopicPartition(topic, 0), records).onFire { response =>
+        appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
           assertEquals(Errors.NONE, response.error)
         }
       }
@@ -2130,54 +2140,74 @@ class ReplicaManagerTest {
   }
 
   @Test
-  def testVerificationForTransactionalPartitions(): Unit = {
-    val tp = new TopicPartition(topic, 0)
-    val transactionalId = "txn1"
+  def testVerificationForTransactionalPartitionsOnly(): Unit = {
+    val tp0 = new TopicPartition(topic, 0)
+    val tp1 = new TopicPartition(topic, 1)
     val producerId = 24L
     val producerEpoch = 0.toShort
     val sequence = 0
-
-    val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)))
-    val metadataCache = mock(classOf[MetadataCache])
+    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = new ReplicaManager(
-      metrics = metrics,
-      config = config,
-      time = time,
-      scheduler = new MockScheduler(time),
-      logManager = mockLogMgr,
-      quotaManagers = quotaManager,
-      metadataCache = metadataCache,
-      logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
-      alterPartitionManager = alterPartitionManager,
-      addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
-
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0, tp1), node)
     try {
-      val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp.topic), tp, Seq(0, 1), LeaderAndIsr(1,  List(0, 1)))
-      replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ())
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
-      // We must set up the metadata cache to handle the append and verification.
-      val metadataResponseTopic = Seq(new MetadataResponseTopic()
-        .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
-        .setPartitions(Seq(
-          new MetadataResponsePartition()
-            .setPartitionIndex(0)
-            .setLeaderId(0)).asJava))
-      val node = new Node(0, "host1", 0)
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp1.topic), tp1, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
-      when(metadataCache.contains(tp)).thenReturn(true)
-      when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
-      when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName)).thenReturn(Some(node))
-      when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName)).thenReturn(None)
+      // If we supply no transactional ID and idempotent records, we do not verify.
+      val idempotentRecords = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
+        new SimpleRecord("message".getBytes))
+      appendRecords(replicaManager, tp0, idempotentRecords)
+      verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any[AddPartitionsToTxnManager.AppendCallback]())
+      assertNull(getVerificationGuard(replicaManager, tp0, producerId))
 
-      // We will attempt to schedule to the request handler thread using a non request handler thread. Set this to avoid error.
-      KafkaRequestHandler.setBypassThreadCheck(true)
+      // If we supply a transactional ID and some transactional and some idempotent records, we should only verify the topic partition with transactional records.
+      val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1,
+        new SimpleRecord("message".getBytes))
+
+      val transactionToAdd = new AddPartitionsToTxnTransaction()
+        .setTransactionalId(transactionalId)
+        .setProducerId(producerId)
+        .setProducerEpoch(producerEpoch)
+        .setVerifyOnly(true)
+        .setTopics(new AddPartitionsToTxnTopicCollection(
+          Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
+        ))
+
+      val idempotentRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
+        new SimpleRecord("message".getBytes))
+      appendRecordsToMultipleTopics(replicaManager, Map(tp0 -> transactionalRecords, tp1 -> idempotentRecords2), transactionalId, Some(0))
+      verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]())
+      assertNotNull(getVerificationGuard(replicaManager, tp0, producerId))
+      assertNull(getVerificationGuard(replicaManager, tp1, producerId))
+    } finally {
+      replicaManager.shutdown(checkpointHW = false)
+    }
+  }
+
+  @Test
+  def testTransactionVerificationFlow(): Unit = {
+    val tp0 = new TopicPartition(topic, 0)
+    val producerId = 24L
+    val producerEpoch = 0.toShort
+    val sequence = 6
+    val node = new Node(0, "host1", 0)
+    val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
+
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0), node)
+    try {
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
+        (_, _) => ())
 
       // Append some transactional records.
       val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
-        new SimpleRecord(s"message $sequence".getBytes))
-      val result = appendRecords(replicaManager, tp, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
+        new SimpleRecord("message".getBytes))
 
       val transactionToAdd = new AddPartitionsToTxnTransaction()
         .setTransactionalId(transactionalId)
@@ -2185,27 +2215,64 @@ class ReplicaManagerTest {
         .setProducerEpoch(producerEpoch)
         .setVerifyOnly(true)
         .setTopics(new AddPartitionsToTxnTopicCollection(
-          Seq(new AddPartitionsToTxnTopic().setName(tp.topic).setPartitions(Collections.singletonList(tp.partition))).iterator.asJava
+          Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
         ))
 
-      val appendCallback = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
       // We should add these partitions to the manager to verify.
+      val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
+      val appendCallback = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
       verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback.capture())
+      val verificationGuard = getVerificationGuard(replicaManager, tp0, producerId)
+      assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
 
       // Confirm we did not write to the log and instead returned an error.
       val callback: AddPartitionsToTxnManager.AppendCallback = appendCallback.getValue()
-      callback(Map(tp -> Errors.INVALID_RECORD).toMap)
+      callback(Map(tp0 -> Errors.INVALID_RECORD).toMap)
       assertEquals(Errors.INVALID_RECORD, result.assertFired.error)
+      assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
+
+      // This time verification is successful.
+      appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
+      val appendCallback2 = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
+      verify(addPartitionsToTxnManager, times(2)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback2.capture())
+      assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
+
+      val callback2: AddPartitionsToTxnManager.AppendCallback = appendCallback2.getValue()
+      callback2(Map.empty[TopicPartition, Errors].toMap)
+      assertNull(getVerificationGuard(replicaManager, tp0, producerId))
+      assertTrue(replicaManager.localLog(tp0).get.hasOngoingTransaction(producerId))
+    } finally {
+      replicaManager.shutdown(checkpointHW = false)
+    }
+  }
+
+  @Test
+  def testTransactionVerificationGuardOnMultiplePartitions(): Unit = {
+    val mockTimer = new MockTimer(time)
+    val tp0 = new TopicPartition(topic, 0)
+    val tp1 = new TopicPartition(topic, 1)
+    val producerId = 24L
+    val producerEpoch = 0.toShort
+    val sequence = 0
+
+    val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer)
+    try {
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(0, List(0, 1))),
+        (_, _) => ())
+
+      replicaManager.becomeLeaderOrFollower(1,
+        makeLeaderAndIsrRequest(topicIds(tp1.topic), tp1, Seq(0, 1), LeaderAndIsr(0, List(0, 1))),
+        (_, _) => ())
 
-      // If we supply no transactional ID and idempotent records, we do not verify, so counter stays the same.
-      val idempotentRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1,
+      val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
         new SimpleRecord(s"message $sequence".getBytes))
-      appendRecords(replicaManager, tp, idempotentRecords2)
-      verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]())
 
-      // If we supply a transactional ID and some transactional and some idempotent records, we should only verify the topic partition with transactional records.
-      appendRecordsToMultipleTopics(replicaManager, Map(tp -> transactionalRecords, new TopicPartition(topic, 1) -> idempotentRecords2), transactionalId, Some(0))
-      verify(addPartitionsToTxnManager, times(2)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]())
+      appendRecordsToMultipleTopics(replicaManager, Map(tp0 -> transactionalRecords, tp1 -> transactionalRecords), transactionalId, Some(0)).onFire { responses =>
+        responses.foreach {
+          entry => assertEquals(Errors.NONE, entry._2)
+        }
+      }
     } finally {
       replicaManager.shutdown(checkpointHW = false)
     }
@@ -2220,21 +2287,10 @@ class ReplicaManagerTest {
     val producerEpoch = 0.toShort
     val sequence = 0
 
-    val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)))
-    val metadataCache = mock(classOf[MetadataCache])
+    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = new ReplicaManager(
-      metrics = metrics,
-      config = config,
-      time = time,
-      scheduler = new MockScheduler(time),
-      logManager = mockLogMgr,
-      quotaManagers = quotaManager,
-      metadataCache = metadataCache,
-      logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
-      alterPartitionManager = alterPartitionManager,
-      addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0, tp1), node)
 
     try {
       replicaManager.becomeLeaderOrFollower(1,
@@ -2254,13 +2310,49 @@ class ReplicaManagerTest {
 
       assertThrows(classOf[InvalidPidMappingException],
         () => appendRecordsToMultipleTopics(replicaManager, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0)))
+      // We should not add these partitions to the manager to verify.
+      verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any())
     } finally {
       replicaManager.shutdown(checkpointHW = false)
     }
   }
 
   @Test
-  def testDisabledVerification(): Unit = {
+  def testTransactionVerificationWhenNotLeader(): Unit = {
+    val tp0 = new TopicPartition(topic, 0)
+    val producerId = 24L
+    val producerEpoch = 0.toShort
+    val sequence = 6
+    val node = new Node(0, "host1", 0)
+    val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
+
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0), node)
+    try {
+      // Append some transactional records.
+      val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
+        new SimpleRecord("message".getBytes))
+
+      val transactionToAdd = new AddPartitionsToTxnTransaction()
+        .setTransactionalId(transactionalId)
+        .setProducerId(producerId)
+        .setProducerEpoch(producerEpoch)
+        .setVerifyOnly(true)
+        .setTopics(new AddPartitionsToTxnTopicCollection(
+          Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
+        ))
+
+      // We should not add these partitions to the manager to verify; instead the append should return an error.
+      appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
+        assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, response.error)
+      }
+      verify(addPartitionsToTxnManager, times(0)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]())
+    } finally {
+      replicaManager.shutdown(checkpointHW = false)
+    }
+  }
+
+  @Test
+  def testDisabledTransactionVerification(): Unit = {
     val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect)
     props.put("transaction.partition.verification.enable", "false")
     val config = KafkaConfig.fromProps(props)
@@ -2271,36 +2363,21 @@ class ReplicaManagerTest {
     val producerEpoch = 0.toShort
     val sequence = 0
 
-    val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)))
-    val metadataCache = mock(classOf[MetadataCache])
+    val node = new Node(0, "host1", 0)
     val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
 
-    val replicaManager = new ReplicaManager(
-      metrics = metrics,
-      config = config,
-      time = time,
-      scheduler = new MockScheduler(time),
-      logManager = mockLogMgr,
-      quotaManagers = quotaManager,
-      metadataCache = metadataCache,
-      logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
-      alterPartitionManager = alterPartitionManager,
-      addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
+    val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp), node, config = config)
 
     try {
       val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp.topic), tp, Seq(0, 1), LeaderAndIsr(0, List(0, 1)))
       replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ())
 
-      when(metadataCache.contains(tp)).thenReturn(true)
-
       val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
         new SimpleRecord(s"message $sequence".getBytes))
       appendRecords(replicaManager, tp, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
+      assertNull(getVerificationGuard(replicaManager, tp, producerId))
 
       // We should not add these partitions to the manager to verify.
-      verify(metadataCache,  times(0)).getTopicMetadata(any(), any(), any(), any())
-      verify(metadataCache, times(0)).getAliveBrokerNode(any(), any())
-      verify(metadataCache, times(0)).getAliveBrokerNode(any(), any())
       verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any())
     } finally {
       replicaManager.shutdown(checkpointHW = false)
@@ -2631,9 +2708,11 @@ class ReplicaManagerTest {
                                             transactionalId: String,
                                             transactionStatePartition: Option[Int],
                                             origin: AppendOrigin = AppendOrigin.CLIENT,
-                                            requiredAcks: Short = -1): Unit = {
+                                            requiredAcks: Short = -1): CallbackResult[Map[TopicPartition, PartitionResponse]] = {
+    val result = new CallbackResult[Map[TopicPartition, PartitionResponse]]()
     def appendCallback(responses: Map[TopicPartition, PartitionResponse]): Unit = {
-      responses.foreach( response => responses.get(response._1).isDefined)
+      responses.foreach( response => assertTrue(responses.get(response._1).isDefined))
+      result.fire(responses)
     }
 
     replicaManager.appendRecords(
@@ -2645,6 +2724,8 @@ class ReplicaManagerTest {
       responseCallback = appendCallback,
       transactionalId = transactionalId,
       transactionStatePartition = transactionStatePartition)
+
+    result
   }
 
   private def fetchPartitionAsConsumer(
@@ -2763,6 +2844,48 @@ class ReplicaManagerTest {
     )
   }
 
+  private def getVerificationGuard(replicaManager: ReplicaManager,
+                                   tp: TopicPartition,
+                                   producerId: Long): Object = {
+    replicaManager.getPartitionOrException(tp).log.get.getOrMaybeCreateVerificationGuard(producerId)
+  }
+
+  private def setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager: AddPartitionsToTxnManager,
+                                                                     transactionalTopicPartitions: List[TopicPartition],
+                                                                     node: Node,
+                                                                     config: KafkaConfig = config): ReplicaManager = {
+    val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)))
+    val metadataCache = mock(classOf[MetadataCache])
+
+    val replicaManager = new ReplicaManager(
+      metrics = metrics,
+      config = config,
+      time = time,
+      scheduler = new MockScheduler(time),
+      logManager = mockLogMgr,
+      quotaManagers = quotaManager,
+      metadataCache = metadataCache,
+      logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
+      alterPartitionManager = alterPartitionManager,
+      addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
+
+    val metadataResponseTopic = Seq(new MetadataResponseTopic()
+      .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
+      .setPartitions(Seq(
+        new MetadataResponsePartition()
+          .setPartitionIndex(0)
+          .setLeaderId(0)).asJava))
+
+    transactionalTopicPartitions.foreach(tp => when(metadataCache.contains(tp)).thenReturn(true))
+    when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
+    when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName)).thenReturn(Some(node))
+    when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName)).thenReturn(None)
+
+    // We will attempt to schedule to the request handler thread from a non-request-handler thread. Set this to avoid the error.
+    KafkaRequestHandler.setBypassThreadCheck(true)
+    replicaManager
+  }
+
   private def setupReplicaManagerWithMockedPurgatories(
     timer: MockTimer,
     brokerId: Int = 0,
@@ -2796,6 +2919,18 @@ class ReplicaManagerTest {
     val mockDelayedElectLeaderPurgatory = new DelayedOperationPurgatory[DelayedElectLeader](
       purgatoryName = "DelayedElectLeader", timer, reaperEnabled = false)
 
+    // Set up transactions
+    val metadataResponseTopic = Seq(new MetadataResponseTopic()
+      .setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
+      .setPartitions(Seq(
+        new MetadataResponsePartition()
+          .setPartitionIndex(0)
+          .setLeaderId(0)).asJava))
+    when(metadataCache.contains(new TopicPartition(topic, 0))).thenReturn(true)
+    when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
+    // Transactional appends attempt to schedule to the request handler thread from a non-request-handler thread. Set this to avoid the error.
+    KafkaRequestHandler.setBypassThreadCheck(true)
+
     new ReplicaManager(
       metrics = metrics,
       config = config,
@@ -2812,7 +2947,8 @@ class ReplicaManagerTest {
       delayedDeleteRecordsPurgatoryParam = Some(mockDeleteRecordsPurgatory),
       delayedElectLeaderPurgatoryParam = Some(mockDelayedElectLeaderPurgatory),
       threadNamePrefix = Option(this.getClass.getName),
-      remoteLogManager = if (enableRemoteStorage) Some(mockRemoteLogManager) else None) {
+      remoteLogManager = if (enableRemoteStorage) Some(mockRemoteLogManager) else None,
+      addPartitionsToTxnManager = Some(addPartitionsToTxnManager)) {
 
       override protected def createReplicaFetcherManager(
         metrics: Metrics,
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index 396ca66ae64..9826cfee794 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -1407,7 +1407,8 @@ object TestUtils extends Logging {
                        cleanerConfig: CleanerConfig = new CleanerConfig(false),
                        time: MockTime = new MockTime(),
                        interBrokerProtocolVersion: MetadataVersion = MetadataVersion.latest,
-                       recoveryThreadsPerDataDir: Int = 4): LogManager = {
+                       recoveryThreadsPerDataDir: Int = 4,
+                       transactionVerificationEnabled: Boolean = false): LogManager = {
     new LogManager(logDirs = logDirs.map(_.getAbsoluteFile),
                    initialOfflineDirs = Array.empty[File],
                    configRepository = configRepository,
@@ -1419,7 +1420,7 @@ object TestUtils extends Logging {
                    flushStartOffsetCheckpointMs = 10000L,
                    retentionCheckMs = 1000L,
                    maxTransactionTimeoutMs = 5 * 60 * 1000,
-                   producerStateManagerConfig = new ProducerStateManagerConfig(kafka.server.Defaults.ProducerIdExpirationMs, false),
+                   producerStateManagerConfig = new ProducerStateManagerConfig(kafka.server.Defaults.ProducerIdExpirationMs, transactionVerificationEnabled),
                    producerIdExpirationCheckIntervalMs = kafka.server.Defaults.ProducerIdExpirationCheckIntervalMs,
                    scheduler = time.scheduler,
                    time = time,
diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
index 7f3bc14af85..71b28d9e2c4 100644
--- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
@@ -108,7 +108,7 @@ public class CheckpointBench {
             JavaConverters.seqAsJavaList(brokerProperties.logDirs()).stream().map(File::new).collect(Collectors.toList());
         this.logManager = TestUtils.createLogManager(JavaConverters.asScalaBuffer(files),
                 new LogConfig(new Properties()), new MockConfigRepository(), new CleanerConfig(1, 4 * 1024 * 1024L, 0.9d,
-                        1024 * 1024, 32 * 1024 * 1024, Double.MAX_VALUE, 15 * 1000, true), time, MetadataVersion.latest(), 4);
+                        1024 * 1024, 32 * 1024 * 1024, Double.MAX_VALUE, 15 * 1000, true), time, MetadataVersion.latest(), 4, false);
         scheduler.startup();
         final BrokerTopicStats brokerTopicStats = new BrokerTopicStats();
         final MetadataCache metadataCache =
diff --git a/server-common/src/main/java/org/apache/kafka/server/util/InterBrokerSendThread.java b/server-common/src/main/java/org/apache/kafka/server/util/InterBrokerSendThread.java
index 227c7e0064a..7b27d1c255c 100644
--- a/server-common/src/main/java/org/apache/kafka/server/util/InterBrokerSendThread.java
+++ b/server-common/src/main/java/org/apache/kafka/server/util/InterBrokerSendThread.java
@@ -118,6 +118,10 @@ public abstract class InterBrokerSendThread extends ShutdownableThread {
                 // DisconnectException is expected when NetworkClient#initiateClose is called
                 return;
             }
+            if (t instanceof InterruptedException && !isRunning()) {
+                // InterruptedException is expected when shutting down. Rethrow it so ShutdownableThread can handle it.
+                throw t;
+            }
             log.error("unhandled exception caught in InterBrokerSendThread", t);
             // rethrow any unhandled exceptions as FatalExitError so the JVM will be terminated
             // as we will be in an unknown state with potentially some requests dropped and not
diff --git a/server-common/src/test/java/org/apache/kafka/server/util/InterBrokerSendThreadTest.java b/server-common/src/test/java/org/apache/kafka/server/util/InterBrokerSendThreadTest.java
index c9b638d88f4..f810cc6e639 100644
--- a/server-common/src/test/java/org/apache/kafka/server/util/InterBrokerSendThreadTest.java
+++ b/server-common/src/test/java/org/apache/kafka/server/util/InterBrokerSendThreadTest.java
@@ -35,6 +35,7 @@ import java.util.Collections;
 import java.util.Queue;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
+
 import org.apache.kafka.clients.ClientRequest;
 import org.apache.kafka.clients.ClientResponse;
 import org.apache.kafka.clients.KafkaClient;
@@ -47,6 +48,8 @@ import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.requests.AbstractRequest;
 import org.apache.kafka.common.utils.Time;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 import org.mockito.ArgumentMatchers;
 
 public class InterBrokerSendThreadTest {
@@ -299,6 +302,40 @@ public class InterBrokerSendThreadTest {
         assertTrue(completionHandler.executedWithDisconnectedResponse);
     }
 
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    public void testInterruption(boolean isShuttingDown) throws InterruptedException, IOException {
+        Exception interrupted = new InterruptedException();
+
+        // InterBrokerSendThread#shutdown calls NetworkClient#initiateClose first, so NetworkClient#poll
+        // can throw InterruptedException if it handles a callback request that throws one.
+        when(networkClient.poll(anyLong(), anyLong())).thenAnswer(t -> {
+            throw interrupted;
+        });
+
+        AtomicReference<Throwable> exception = new AtomicReference<>();
+        final InterBrokerSendThread thread =
+            new TestInterBrokerSendThread(networkClient, t -> {
+                if (isShuttingDown)
+                    assertTrue(t instanceof InterruptedException);
+                else
+                    assertTrue(t instanceof FatalExitError);
+                exception.getAndSet(t);
+            });
+
+        if (isShuttingDown)
+            thread.shutdown();
+        thread.pollOnce(100);
+
+        verify(networkClient).poll(anyLong(), anyLong());
+        if (isShuttingDown) {
+            verify(networkClient).initiateClose();
+            verify(networkClient).close();
+        }
+        verifyNoMoreInteractions(networkClient);
+        assertNotNull(exception.get());
+    }
+
     private static class StubRequestBuilder<T extends AbstractRequest>
         extends AbstractRequest.Builder<T> {
 
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManager.java b/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManager.java
index 6212711cb27..6b88f734189 100644
--- a/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManager.java
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManager.java
@@ -113,6 +113,8 @@ public class ProducerStateManager {
 
     private final Map<Long, ProducerStateEntry> producers = new HashMap<>();
 
+    private final Map<Long, VerificationStateEntry> verificationStates = new HashMap<>();
+
     // ongoing transactions sorted by the first offset of the transaction
     private final TreeMap<Long, TxnMetadata> ongoingTxns = new TreeMap<>();
 
@@ -184,6 +186,26 @@ public class ProducerStateManager {
         producerIdCount = 0;
     }
 
+    /**
+     * Return the VerificationStateEntry for the given producer ID if it exists, creating it first when createIfAbsent is true; otherwise return null.
+     */
+    public VerificationStateEntry verificationStateEntry(long producerId, boolean createIfAbsent) {
+        return verificationStates.computeIfAbsent(producerId, pid -> {
+            if (createIfAbsent)
+                return new VerificationStateEntry(time.milliseconds());
+            else {
+                return null;
+            }
+        });
+    }
+
+    /**
+     * Clear the verificationStateEntry for the given producer ID.
+     */
+    public void clearVerificationStateEntry(long producerId) {
+        verificationStates.remove(producerId);
+    }
+
     /**
      * Load producer state snapshots by scanning the logDir.
      */
@@ -338,6 +360,7 @@ public class ProducerStateManager {
 
     /**
      * Expire any producer ids which have been idle longer than the configured maximum expiration timeout.
+     * Also expire any verification state entries that have lingered unverified past the same timeout.
      */
     public void removeExpiredProducers(long currentTimeMs) {
         List<Long> keys = producers.entrySet().stream()
@@ -345,6 +368,12 @@ public class ProducerStateManager {
                 .map(Map.Entry::getKey)
                 .collect(Collectors.toList());
         removeProducerIds(keys);
+
+        List<Long> verificationKeys = verificationStates.entrySet().stream()
+                .filter(entry -> currentTimeMs - entry.getValue().timestamp() >= producerStateManagerConfig.producerIdExpirationMs())
+                .map(Map.Entry::getKey)
+                .collect(Collectors.toList());
+        verificationKeys.forEach(verificationStates::remove);
     }
 
     /**
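
The new verificationStates map above behaves as a get-or-create cache keyed by producer ID: an existing entry is always returned as-is, a new entry is created only when createIfAbsent is true, and entries that never complete verification age out alongside producer state. The following standalone Java sketch illustrates those semantics with simplified names; it is not the production ProducerStateManager API.

    import java.util.HashMap;
    import java.util.Map;

    // Minimal sketch of the verificationStates bookkeeping added to ProducerStateManager.
    // Names are simplified for illustration only.
    public class VerificationStateSketch {

        static final class Entry {
            final long timestampMs;
            final Object guard = new Object();

            Entry(long timestampMs) {
                this.timestampMs = timestampMs;
            }
        }

        private final Map<Long, Entry> verificationStates = new HashMap<>();
        private final long expirationMs;

        VerificationStateSketch(long expirationMs) {
            this.expirationMs = expirationMs;
        }

        // Get-or-create: an existing entry is returned unchanged (its timestamp is not refreshed),
        // and a new entry is only created when createIfAbsent is true.
        Entry verificationStateEntry(long producerId, boolean createIfAbsent, long nowMs) {
            Entry entry = verificationStates.get(producerId);
            if (entry == null && createIfAbsent) {
                entry = new Entry(nowMs);
                verificationStates.put(producerId, entry);
            }
            return entry;
        }

        // Called once transactional data or an end marker is written for the producer.
        void clearVerificationStateEntry(long producerId) {
            verificationStates.remove(producerId);
        }

        // Entries that never complete verification expire on the same timer as producer state.
        void removeExpired(long nowMs) {
            verificationStates.values().removeIf(e -> nowMs - e.timestampMs >= expirationMs);
        }

        public static void main(String[] args) {
            VerificationStateSketch sketch = new VerificationStateSketch(60_000L);
            Entry first = sketch.verificationStateEntry(24L, true, 0L);
            Entry retried = sketch.verificationStateEntry(24L, true, 1_000L);
            System.out.println(first == retried);                                    // true: a retry reuses the same guard
            sketch.removeExpired(120_000L);
            System.out.println(sketch.verificationStateEntry(24L, false, 120_000L)); // null: expired and not recreated
        }
    }
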
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/log/VerificationStateEntry.java b/storage/src/main/java/org/apache/kafka/storage/internals/log/VerificationStateEntry.java
new file mode 100644
index 00000000000..2464581fbe0
--- /dev/null
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/log/VerificationStateEntry.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.storage.internals.log;
+
+/**
+ * This class represents the verification state of a specific producer id.
+ * It contains a verification guard object that is used to uniquely identify the transaction we want to verify.
+ * After verifying, we retain this object until we append to the log. This prevents any race conditions where the transaction
+ * may end via a control marker before we write to the log. This mechanism is used to prevent hanging transactions.
+ * We remove the verification guard object whenever we write data to the transaction or write an end marker for the transaction.
+ * Any lingering entries that are never verified are removed via the producer state entry cleanup mechanism.
+ */
+public class VerificationStateEntry {
+
+    private final long timestamp;
+    private final Object verificationGuard;
+
+    public VerificationStateEntry(long timestamp) {
+        this.timestamp = timestamp;
+        this.verificationGuard = new Object();
+    }
+
+    public long timestamp() {
+        return timestamp;
+    }
+
+    public Object verificationGuard() {
+        return verificationGuard;
+    }
+}
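
Taken together with the ReplicaManager and UnifiedLog changes, the guard's lifecycle is: created (or reused) while verification is pending, compared on the first transactional append, and cleared once transactional data or an end marker is written. The hypothetical Java driver below only illustrates that behaviour; the method names startVerification, appendTransactional, and completeTransaction are invented for the example and do not exist in the Kafka codebase.

    import java.util.Objects;

    // Hypothetical illustration of the verification guard lifecycle on one partition;
    // none of these method names exist in Kafka, they only mirror the behaviour described above.
    public class VerificationGuardLifecycle {

        private Object currentGuard;         // stand-in for VerificationStateEntry#verificationGuard
        private boolean ongoingTransaction;  // stand-in for an ongoing transaction on the partition

        // Invoked while handling the produce request, before verification completes.
        Object startVerification() {
            if (currentGuard == null)
                currentGuard = new Object();
            return currentGuard;
        }

        // The first transactional append must present the guard that was verified;
        // once the transaction is ongoing the guard is no longer needed.
        boolean appendTransactional(Object guardFromCaller) {
            if (!ongoingTransaction && !Objects.equals(guardFromCaller, currentGuard))
                return false;                // stale verification, reject the append
            ongoingTransaction = true;
            currentGuard = null;             // cleared on writing transactional data
            return true;
        }

        // Writing an end marker also clears any guard.
        void completeTransaction() {
            ongoingTransaction = false;
            currentGuard = null;
        }

        public static void main(String[] args) {
            VerificationGuardLifecycle partition = new VerificationGuardLifecycle();
            Object staleGuard = partition.startVerification();
            partition.completeTransaction();                               // marker written before the append lands
            Object freshGuard = partition.startVerification();
            System.out.println(partition.appendTransactional(staleGuard)); // false: stale guard is rejected
            System.out.println(partition.appendTransactional(freshGuard)); // true: freshly verified append succeeds
        }
    }
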