Posted to commits@kafka.apache.org by jg...@apache.org on 2022/02/02 19:59:38 UTC

[kafka] branch trunk updated: KAFKA-13221; Implement `PartitionsWithLateTransactionsCount` metric (#11725)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 9159914  KAFKA-13221; Implement `PartitionsWithLateTransactionsCount` metric (#11725)
9159914 is described below

commit 915991445fde106d02e61a70425ae2601c813db0
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Wed Feb 2 11:57:13 2022 -0800

    KAFKA-13221; Implement `PartitionsWithLateTransactionsCount` metric (#11725)
    
    This patch implements a new metric, `PartitionsWithLateTransactionsCount`, which tracks the number of partitions with late transactions in the cluster (see the sketch below the commit message). The metric is documented in KIP-664: https://cwiki.apache.org/confluence/display/KAFKA/KIP-664%3A+Provide+tooling+to+detect+and+abort+hanging+transactions.
    
    Reviewers: David Jacot <dj...@confluent.io>
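
A minimal sketch of the lateness check this patch adds to ProducerStateManager (the object and parameter names here are illustrative, not the committed API; the 5-minute buffer and the timeout comparison mirror the hunks below):

    object LateTransactionSketch {
      // Fixed buffer added on top of the max transaction timeout before a
      // transaction is reported as late (mirrors ProducerStateManager.LateTransactionBufferMs).
      val LateTransactionBufferMs: Int = 5 * 60 * 1000

      // A partition has a late transaction if the last timestamp of its oldest
      // ongoing transaction is older than maxTransactionTimeoutMs plus the buffer.
      def hasLateTransaction(oldestTxnLastTimestamp: Long,
                             maxTransactionTimeoutMs: Int,
                             currentTimeMs: Long): Boolean = {
        oldestTxnLastTimestamp > 0 &&
          (currentTimeMs - oldestTxnLastTimestamp) > maxTransactionTimeoutMs + LateTransactionBufferMs
      }
    }

For example, with the default transaction.max.timeout.ms of 15 minutes used in LogManagerBuilder, a transaction left open on a leader partition for more than roughly 20 minutes is counted by the new ReplicaManager gauge.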
---
 .../kafka/server/builders/LogManagerBuilder.java   |   9 +-
 core/src/main/scala/kafka/cluster/Partition.scala  |   2 +
 .../transaction/TransactionCoordinator.scala       |   5 +-
 core/src/main/scala/kafka/log/LogLoader.scala      |   6 +-
 core/src/main/scala/kafka/log/LogManager.scala     |   4 +
 .../scala/kafka/log/ProducerStateManager.scala     |  32 +++++
 core/src/main/scala/kafka/log/UnifiedLog.scala     |  10 +-
 .../main/scala/kafka/raft/KafkaMetadataLog.scala   |   1 +
 .../main/scala/kafka/server/ReplicaManager.scala   |   7 +
 .../src/test/scala/other/kafka/StressTestLog.scala |   1 +
 .../scala/other/kafka/TestLinearWriteSpeed.scala   |  17 ++-
 .../unit/kafka/cluster/PartitionLockTest.scala     |   4 +-
 .../scala/unit/kafka/cluster/PartitionTest.scala   |  11 +-
 .../scala/unit/kafka/cluster/ReplicaTest.scala     |   9 +-
 .../TransactionCoordinatorConcurrencyTest.scala    |   2 +-
 .../transaction/TransactionCoordinatorTest.scala   |   1 -
 .../log/AbstractLogCleanerIntegrationTest.scala    |   6 +-
 .../unit/kafka/log/BrokerCompressionTest.scala     |  17 ++-
 .../unit/kafka/log/LogCleanerManagerTest.scala     |  30 +++--
 .../test/scala/unit/kafka/log/LogCleanerTest.scala |  25 +++-
 .../scala/unit/kafka/log/LogConcurrencyTest.scala  |   4 +-
 .../test/scala/unit/kafka/log/LogLoaderTest.scala  |  57 +++++---
 .../test/scala/unit/kafka/log/LogSegmentTest.scala |  19 +--
 .../test/scala/unit/kafka/log/LogTestUtils.scala   |   8 +-
 .../unit/kafka/log/ProducerStateManagerTest.scala  | 143 ++++++++++++++++++---
 .../test/scala/unit/kafka/log/UnifiedLogTest.scala |   4 +-
 .../unit/kafka/server/ReplicaManagerTest.scala     |  70 +++++++++-
 .../unit/kafka/tools/DumpLogSegmentsTest.scala     |  17 ++-
 .../scala/unit/kafka/utils/SchedulerTest.scala     |   5 +-
 .../test/scala/unit/kafka/utils/TestUtils.scala    |  12 +-
 30 files changed, 437 insertions(+), 101 deletions(-)

diff --git a/core/src/main/java/kafka/server/builders/LogManagerBuilder.java b/core/src/main/java/kafka/server/builders/LogManagerBuilder.java
index 0082040..3ebe7fa 100644
--- a/core/src/main/java/kafka/server/builders/LogManagerBuilder.java
+++ b/core/src/main/java/kafka/server/builders/LogManagerBuilder.java
@@ -26,11 +26,11 @@ import kafka.server.LogDirFailureChannel;
 import kafka.server.metadata.ConfigRepository;
 import kafka.utils.Scheduler;
 import org.apache.kafka.common.utils.Time;
+import scala.collection.JavaConverters;
 
 import java.io.File;
 import java.util.Collections;
 import java.util.List;
-import scala.collection.JavaConverters;
 
 
 public class LogManagerBuilder {
@@ -44,6 +44,7 @@ public class LogManagerBuilder {
     private long flushRecoveryOffsetCheckpointMs = 10000L;
     private long flushStartOffsetCheckpointMs = 10000L;
     private long retentionCheckMs = 1000L;
+    private int maxTransactionTimeoutMs = 15 * 60 * 1000;
     private int maxPidExpirationMs = 60000;
     private ApiVersion interBrokerProtocolVersion = ApiVersion.latestVersion();
     private Scheduler scheduler = null;
@@ -102,6 +103,11 @@ public class LogManagerBuilder {
         return this;
     }
 
+    public LogManagerBuilder setMaxTransactionTimeoutMs(int maxTransactionTimeoutMs) {
+        this.maxTransactionTimeoutMs = maxTransactionTimeoutMs;
+        return this;
+    }
+
     public LogManagerBuilder setMaxPidExpirationMs(int maxPidExpirationMs) {
         this.maxPidExpirationMs = maxPidExpirationMs;
         return this;
@@ -156,6 +162,7 @@ public class LogManagerBuilder {
                               flushRecoveryOffsetCheckpointMs,
                               flushStartOffsetCheckpointMs,
                               retentionCheckMs,
+                              maxTransactionTimeoutMs,
                               maxPidExpirationMs,
                               interBrokerProtocolVersion,
                               scheduler,
diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala
index 371e895..150432d 100755
--- a/core/src/main/scala/kafka/cluster/Partition.scala
+++ b/core/src/main/scala/kafka/cluster/Partition.scala
@@ -274,6 +274,8 @@ class Partition(val topicPartition: TopicPartition,
   newGauge("ReplicasCount", () => if (isLeader) assignmentState.replicationFactor else 0, tags)
   newGauge("LastStableOffsetLag", () => log.map(_.lastStableOffsetLag).getOrElse(0), tags)
 
+  def hasLateTransaction(currentTimeMs: Long): Boolean = leaderLogIfLocal.exists(_.hasLateTransaction(currentTimeMs))
+
   def isUnderReplicated: Boolean = isLeader && (assignmentState.replicationFactor - isrState.isr.size) > 0
 
   def isUnderMinIsr: Boolean = leaderLogIfLocal.exists { isrState.isr.size < _.config.minInSyncReplicas }
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index 78983c1..ac90388 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -57,7 +57,7 @@ object TransactionCoordinator {
     val txnMarkerChannelManager = TransactionMarkerChannelManager(config, metrics, metadataCache, txnStateManager,
       time, logContext)
 
-    new TransactionCoordinator(config.brokerId, txnConfig, scheduler, createProducerIdGenerator, txnStateManager, txnMarkerChannelManager,
+    new TransactionCoordinator(txnConfig, scheduler, createProducerIdGenerator, txnStateManager, txnMarkerChannelManager,
       time, logContext)
   }
 
@@ -78,8 +78,7 @@ object TransactionCoordinator {
  * producers. Producers with specific transactional ids are assigned to their corresponding coordinators;
  * Producers with no specific transactional id may talk to a random broker as their coordinators.
  */
-class TransactionCoordinator(brokerId: Int,
-                             txnConfig: TransactionConfig,
+class TransactionCoordinator(txnConfig: TransactionConfig,
                              scheduler: Scheduler,
                              createProducerIdManager: () => ProducerIdManager,
                              txnManager: TransactionStateManager,
diff --git a/core/src/main/scala/kafka/log/LogLoader.scala b/core/src/main/scala/kafka/log/LogLoader.scala
index e22dcd4..eb9dec7 100644
--- a/core/src/main/scala/kafka/log/LogLoader.scala
+++ b/core/src/main/scala/kafka/log/LogLoader.scala
@@ -62,8 +62,6 @@ object LogLoader extends Logging {
  *                 populated
  * @param logStartOffsetCheckpoint The checkpoint of the log start offset
  * @param recoveryPointCheckpoint The checkpoint of the offset at which to begin the recovery
- * @param maxProducerIdExpirationMs The maximum amount of time to wait before a producer id is
- *                                  considered expired
  * @param leaderEpochCache An optional LeaderEpochFileCache instance to be updated during recovery
  * @param producerStateManager The ProducerStateManager instance to be updated during recovery
  */
@@ -78,7 +76,6 @@ class LogLoader(
   segments: LogSegments,
   logStartOffsetCheckpoint: Long,
   recoveryPointCheckpoint: Long,
-  maxProducerIdExpirationMs: Int,
   leaderEpochCache: Option[LeaderEpochFileCache],
   producerStateManager: ProducerStateManager
 ) extends Logging {
@@ -354,7 +351,8 @@ class LogLoader(
     val producerStateManager = new ProducerStateManager(
       topicPartition,
       dir,
-      maxProducerIdExpirationMs,
+      this.producerStateManager.maxTransactionTimeoutMs,
+      this.producerStateManager.maxProducerIdExpirationMs,
       time)
     UnifiedLog.rebuildProducerState(
       producerStateManager,
diff --git a/core/src/main/scala/kafka/log/LogManager.scala b/core/src/main/scala/kafka/log/LogManager.scala
index 66dc581..b81f6a9 100755
--- a/core/src/main/scala/kafka/log/LogManager.scala
+++ b/core/src/main/scala/kafka/log/LogManager.scala
@@ -63,6 +63,7 @@ class LogManager(logDirs: Seq[File],
                  val flushRecoveryOffsetCheckpointMs: Long,
                  val flushStartOffsetCheckpointMs: Long,
                  val retentionCheckMs: Long,
+                 val maxTransactionTimeoutMs: Int,
                  val maxPidExpirationMs: Int,
                  interBrokerProtocolVersion: ApiVersion,
                  scheduler: Scheduler,
@@ -271,6 +272,7 @@ class LogManager(logDirs: Seq[File],
       config = config,
       logStartOffset = logStartOffset,
       recoveryPoint = logRecoveryPoint,
+      maxTransactionTimeoutMs = maxTransactionTimeoutMs,
       maxProducerIdExpirationMs = maxPidExpirationMs,
       producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
       scheduler = scheduler,
@@ -882,6 +884,7 @@ class LogManager(logDirs: Seq[File],
           config = config,
           logStartOffset = 0L,
           recoveryPoint = 0L,
+          maxTransactionTimeoutMs = maxTransactionTimeoutMs,
           maxProducerIdExpirationMs = maxPidExpirationMs,
           producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
           scheduler = scheduler,
@@ -1307,6 +1310,7 @@ object LogManager {
       flushRecoveryOffsetCheckpointMs = config.logFlushOffsetCheckpointIntervalMs,
       flushStartOffsetCheckpointMs = config.logFlushStartOffsetCheckpointIntervalMs,
       retentionCheckMs = config.logCleanupIntervalMs,
+      maxTransactionTimeoutMs = config.transactionMaxTimeoutMs,
       maxPidExpirationMs = config.transactionalIdExpirationMs,
       scheduler = kafkaScheduler,
       brokerTopicStats = brokerTopicStats,
diff --git a/core/src/main/scala/kafka/log/ProducerStateManager.scala b/core/src/main/scala/kafka/log/ProducerStateManager.scala
index 5f5c225..8313f24 100644
--- a/core/src/main/scala/kafka/log/ProducerStateManager.scala
+++ b/core/src/main/scala/kafka/log/ProducerStateManager.scala
@@ -346,6 +346,8 @@ private[log] class ProducerAppendInfo(val topicPartition: TopicPartition,
 }
 
 object ProducerStateManager {
+  val LateTransactionBufferMs = 5 * 60 * 1000
+
   private val ProducerSnapshotVersion: Short = 1
   private val VersionField = "version"
   private val CrcField = "crc"
@@ -483,6 +485,7 @@ object ProducerStateManager {
 @nonthreadsafe
 class ProducerStateManager(val topicPartition: TopicPartition,
                            @volatile var _logDir: File,
+                           val maxTransactionTimeoutMs: Int,
                            val maxProducerIdExpirationMs: Int = 60 * 60 * 1000,
                            val time: Time = Time.SYSTEM) extends Logging {
   import ProducerStateManager._
@@ -498,12 +501,23 @@ class ProducerStateManager(val topicPartition: TopicPartition,
   private var lastMapOffset = 0L
   private var lastSnapOffset = 0L
 
+  // Keep track of the last timestamp from the oldest transaction. This is used
+  // to detect (approximately) when a transaction has been left hanging on a partition.
+  // We make the field volatile so that it can be safely accessed without a lock.
+  @volatile private var oldestTxnLastTimestamp: Long = -1L
+
   // ongoing transactions sorted by the first offset of the transaction
   private val ongoingTxns = new util.TreeMap[Long, TxnMetadata]
 
   // completed transactions whose markers are at offsets above the high watermark
   private val unreplicatedTxns = new util.TreeMap[Long, TxnMetadata]
 
+  @threadsafe
+  def hasLateTransaction(currentTimeMs: Long): Boolean = {
+    val lastTimestamp = oldestTxnLastTimestamp
+    lastTimestamp > 0 && (currentTimeMs - lastTimestamp) > maxTransactionTimeoutMs + ProducerStateManager.LateTransactionBufferMs
+  }
+
   /**
    * Load producer state snapshots by scanning the _logDir.
    */
@@ -612,6 +626,7 @@ class ProducerStateManager(val topicPartition: TopicPartition,
             loadedProducers.foreach(loadProducerEntry)
             lastSnapOffset = snapshot.offset
             lastMapOffset = lastSnapOffset
+            updateOldestTxnTimestamp()
             return
           } catch {
             case e: CorruptSnapshotException =>
@@ -664,6 +679,7 @@ class ProducerStateManager(val topicPartition: TopicPartition,
     if (logEndOffset != mapEndOffset) {
       producers.clear()
       ongoingTxns.clear()
+      updateOldestTxnTimestamp()
 
       // since we assume that the offset is less than or equal to the high watermark, it is
       // safe to clear the unreplicated transactions
@@ -700,6 +716,20 @@ class ProducerStateManager(val topicPartition: TopicPartition,
     appendInfo.startedTransactions.foreach { txn =>
       ongoingTxns.put(txn.firstOffset.messageOffset, txn)
     }
+
+    updateOldestTxnTimestamp()
+  }
+
+  private def updateOldestTxnTimestamp(): Unit = {
+    val firstEntry = ongoingTxns.firstEntry()
+    if (firstEntry == null) {
+      oldestTxnLastTimestamp = -1
+    } else {
+      val oldestTxnMetadata = firstEntry.getValue
+      oldestTxnLastTimestamp = producers.get(oldestTxnMetadata.producerId)
+        .map(_.lastTimestamp)
+        .getOrElse(-1L)
+    }
   }
 
   def updateMapEndOffset(lastOffset: Long): Unit = {
@@ -789,6 +819,7 @@ class ProducerStateManager(val topicPartition: TopicPartition,
     }
     lastSnapOffset = 0L
     lastMapOffset = offset
+    updateOldestTxnTimestamp()
   }
 
   /**
@@ -813,6 +844,7 @@ class ProducerStateManager(val topicPartition: TopicPartition,
 
     txnMetadata.lastOffset = Some(completedTxn.lastOffset)
     unreplicatedTxns.put(completedTxn.firstOffset, txnMetadata)
+    updateOldestTxnTimestamp()
   }
 
   @threadsafe
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index a69273b..11f66aa 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -614,6 +614,11 @@ class UnifiedLog(@volatile var logStartOffset: Long,
       reloadFromCleanShutdown = false, logIdent)
   }
 
+  @threadsafe
+  def hasLateTransaction(currentTimeMs: Long): Boolean = {
+    producerStateManager.hasLateTransaction(currentTimeMs)
+  }
+
   def activeProducers: Seq[DescribeProducersResponseData.ProducerState] = {
     lock synchronized {
       producerStateManager.activeProducers.map { case (producerId, state) =>
@@ -1771,6 +1776,7 @@ object UnifiedLog extends Logging {
             scheduler: Scheduler,
             brokerTopicStats: BrokerTopicStats,
             time: Time = Time.SYSTEM,
+            maxTransactionTimeoutMs: Int,
             maxProducerIdExpirationMs: Int,
             producerIdExpirationCheckIntervalMs: Int,
             logDirFailureChannel: LogDirFailureChannel,
@@ -1787,7 +1793,8 @@ object UnifiedLog extends Logging {
       logDirFailureChannel,
       config.recordVersion,
       s"[UnifiedLog partition=$topicPartition, dir=${dir.getParent}] ")
-    val producerStateManager = new ProducerStateManager(topicPartition, dir, maxProducerIdExpirationMs)
+    val producerStateManager = new ProducerStateManager(topicPartition, dir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs)
     val offsets = new LogLoader(
       dir,
       topicPartition,
@@ -1799,7 +1806,6 @@ object UnifiedLog extends Logging {
       segments,
       logStartOffset,
       recoveryPoint,
-      maxProducerIdExpirationMs,
       leaderEpochCache,
       producerStateManager
     ).load()
diff --git a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
index 8c9132b..1b0aef3 100644
--- a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
+++ b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
@@ -571,6 +571,7 @@ object KafkaMetadataLog {
       scheduler = scheduler,
       brokerTopicStats = new BrokerTopicStats,
       time = time,
+      maxTransactionTimeoutMs = Int.MaxValue,
       maxProducerIdExpirationMs = Int.MaxValue,
       producerIdExpirationCheckIntervalMs = Int.MaxValue,
       logDirFailureChannel = new LogDirFailureChannel(5),
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 0dadeed..22f2755 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -257,9 +257,15 @@ class ReplicaManager(val config: KafkaConfig,
   newGauge("UnderMinIsrPartitionCount", () => leaderPartitionsIterator.count(_.isUnderMinIsr))
   newGauge("AtMinIsrPartitionCount", () => leaderPartitionsIterator.count(_.isAtMinIsr))
   newGauge("ReassigningPartitions", () => reassigningPartitionsCount)
+  newGauge("PartitionsWithLateTransactionsCount", () => lateTransactionsCount)
 
   def reassigningPartitionsCount: Int = leaderPartitionsIterator.count(_.isReassigning)
 
+  private def lateTransactionsCount: Int = {
+    val currentTimeMs = time.milliseconds()
+    leaderPartitionsIterator.count(_.hasLateTransaction(currentTimeMs))
+  }
+
   val isrExpandRate: Meter = newMeter("IsrExpandsPerSec", "expands", TimeUnit.SECONDS)
   val isrShrinkRate: Meter = newMeter("IsrShrinksPerSec", "shrinks", TimeUnit.SECONDS)
   val failedIsrUpdatesRate: Meter = newMeter("FailedIsrUpdatesPerSec", "failedUpdates", TimeUnit.SECONDS)
@@ -1939,6 +1945,7 @@ class ReplicaManager(val config: KafkaConfig,
     removeMetric("UnderMinIsrPartitionCount")
     removeMetric("AtMinIsrPartitionCount")
     removeMetric("ReassigningPartitions")
+    removeMetric("PartitionsWithLateTransactionsCount")
   }
 
   // High watermark do not need to be checkpointed only when under unit tests
diff --git a/core/src/test/scala/other/kafka/StressTestLog.scala b/core/src/test/scala/other/kafka/StressTestLog.scala
index 2422dcc..a90690a 100755
--- a/core/src/test/scala/other/kafka/StressTestLog.scala
+++ b/core/src/test/scala/other/kafka/StressTestLog.scala
@@ -48,6 +48,7 @@ object StressTestLog {
       recoveryPoint = 0L,
       scheduler = time.scheduler,
       time = time,
+      maxTransactionTimeoutMs = 5 * 60 * 1000,
       maxProducerIdExpirationMs = 60 * 60 * 1000,
       producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
       brokerTopicStats = new BrokerTopicStats,
diff --git a/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala b/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala
index f274954..c342e71 100755
--- a/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala
+++ b/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala
@@ -210,8 +210,21 @@ object TestLinearWriteSpeed {
 
   class LogWritable(val dir: File, config: LogConfig, scheduler: Scheduler, val messages: MemoryRecords) extends Writable {
     Utils.delete(dir)
-    val log = UnifiedLog(dir, config, 0L, 0L, scheduler, new BrokerTopicStats, Time.SYSTEM, 60 * 60 * 1000,
-      LogManager.ProducerIdExpirationCheckIntervalMs, new LogDirFailureChannel(10), topicId = None, keepPartitionMetadataFile = true)
+    val log = UnifiedLog(
+      dir = dir,
+      config = config,
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = scheduler,
+      brokerTopicStats = new BrokerTopicStats,
+      time = Time.SYSTEM,
+      maxTransactionTimeoutMs = 5 * 60 * 1000,
+      maxProducerIdExpirationMs = 60 * 60 * 1000,
+      producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
+      logDirFailureChannel = new LogDirFailureChannel(10),
+      topicId = None,
+      keepPartitionMetadataFile = true
+    )
     def write(): Int = {
       log.appendAsLeader(messages, leaderEpoch = 0)
       messages.sizeInBytes
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
index b22c459..8c13035 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
@@ -285,8 +285,9 @@ class PartitionLockTest extends Logging {
         val logDirFailureChannel = new LogDirFailureChannel(1)
         val segments = new LogSegments(log.topicPartition)
         val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(log.dir, log.topicPartition, logDirFailureChannel, log.config.recordVersion, "")
+        val maxTransactionTimeout = 5 * 60 * 1000
         val maxProducerIdExpirationMs = 60 * 60 * 1000
-        val producerStateManager = new ProducerStateManager(log.topicPartition, log.dir, maxProducerIdExpirationMs)
+        val producerStateManager = new ProducerStateManager(log.topicPartition, log.dir, maxTransactionTimeout, maxProducerIdExpirationMs)
         val offsets = new LogLoader(
           log.dir,
           log.topicPartition,
@@ -298,7 +299,6 @@ class PartitionLockTest extends Logging {
           segments,
           0L,
           0L,
-          maxProducerIdExpirationMs,
           leaderEpochCache,
           producerStateManager
         ).load()
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
index 82147c0..bcab6ec 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
@@ -212,8 +212,10 @@ class PartitionTest extends AbstractPartitionTest {
         val logDirFailureChannel = new LogDirFailureChannel(1)
         val segments = new LogSegments(log.topicPartition)
         val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(log.dir, log.topicPartition, logDirFailureChannel, log.config.recordVersion, "")
+        val maxTransactionTimeoutMs = 5 * 60 * 1000
         val maxProducerIdExpirationMs = 60 * 60 * 1000
-        val producerStateManager = new ProducerStateManager(log.topicPartition, log.dir, maxProducerIdExpirationMs)
+        val producerStateManager = new ProducerStateManager(log.topicPartition, log.dir,
+          maxTransactionTimeoutMs, maxProducerIdExpirationMs)
         val offsets = new LogLoader(
           log.dir,
           log.topicPartition,
@@ -222,10 +224,9 @@ class PartitionTest extends AbstractPartitionTest {
           mockTime,
           logDirFailureChannel,
           hadCleanShutdown = true,
-          segments,
-          0L,
-          0L,
-          maxProducerIdExpirationMs,
+          segments = segments,
+          logStartOffsetCheckpoint = 0L,
+          recoveryPointCheckpoint = 0L,
           leaderEpochCache,
           producerStateManager
         ).load()
diff --git a/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala b/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala
index 08d0950..201ec1d 100644
--- a/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala
@@ -41,18 +41,21 @@ class ReplicaTest {
     logProps.put(LogConfig.SegmentIndexBytesProp, 1000: java.lang.Integer)
     logProps.put(LogConfig.RetentionMsProp, 999: java.lang.Integer)
     val config = LogConfig(logProps)
-    log = UnifiedLog(logDir,
-      config,
+    log = UnifiedLog(
+      dir = logDir,
+      config = config,
       logStartOffset = 0L,
       recoveryPoint = 0L,
       scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats,
       time = time,
+      maxTransactionTimeoutMs = 5 * 60 * 1000,
       maxProducerIdExpirationMs = 60 * 60 * 1000,
       producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
       logDirFailureChannel = new LogDirFailureChannel(10),
       topicId = None,
-      keepPartitionMetadataFile = true)
+      keepPartitionMetadataFile = true
+    )
   }
 
   @AfterEach
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index 85778d5..11451bf 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -98,7 +98,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
       txnStateManager,
       time)
 
-    transactionCoordinator = new TransactionCoordinator(brokerId = 0,
+    transactionCoordinator = new TransactionCoordinator(
       txnConfig,
       scheduler,
       () => pidGenerator,
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index 38e8e71..f6583ef 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -51,7 +51,6 @@ class TransactionCoordinatorTest {
   private val scheduler = new MockScheduler(time)
 
   val coordinator = new TransactionCoordinator(
-    brokerId,
     TransactionConfig(),
     scheduler,
     () => pidGenerator,
diff --git a/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala b/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
index 8cfc8da..381ec93 100644
--- a/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
+++ b/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
@@ -101,13 +101,15 @@ abstract class AbstractLogCleanerIntegrationTest {
         deleteDelay = deleteDelay,
         segmentSize = segmentSize,
         maxCompactionLagMs = maxCompactionLagMs))
-      val log = UnifiedLog(dir,
-        logConfig,
+      val log = UnifiedLog(
+        dir = dir,
+        config = logConfig,
         logStartOffset = 0L,
         recoveryPoint = 0L,
         scheduler = time.scheduler,
         time = time,
         brokerTopicStats = new BrokerTopicStats,
+        maxTransactionTimeoutMs = 5 * 60 * 1000,
         maxProducerIdExpirationMs = 60 * 60 * 1000,
         producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
         logDirFailureChannel = new LogDirFailureChannel(10),
diff --git a/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala b/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala
index f308b54..85745bf 100755
--- a/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala
+++ b/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala
@@ -52,10 +52,21 @@ class BrokerCompressionTest {
     val logProps = new Properties()
     logProps.put(LogConfig.CompressionTypeProp, brokerCompression)
     /*configure broker-side compression  */
-    val log = UnifiedLog(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
-      time = time, brokerTopicStats = new BrokerTopicStats, maxProducerIdExpirationMs = 60 * 60 * 1000,
+    val log = UnifiedLog(
+      dir = logDir,
+      config = LogConfig(logProps),
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      time = time,
+      brokerTopicStats = new BrokerTopicStats,
+      maxTransactionTimeoutMs = 5 * 60 * 1000,
+      maxProducerIdExpirationMs = 60 * 60 * 1000,
       producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
-      logDirFailureChannel = new LogDirFailureChannel(10), topicId = None, keepPartitionMetadataFile = true)
+      logDirFailureChannel = new LogDirFailureChannel(10),
+      topicId = None,
+      keepPartitionMetadataFile = true
+    )
 
     /* append two messages */
     log.appendAsLeader(MemoryRecords.withRecords(CompressionType.forId(messageCompressionCode.codec), 0,
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
index f596987..0cdafed 100644
--- a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
@@ -98,10 +98,11 @@ class LogCleanerManagerTest extends Logging {
     Files.createDirectories(tpDir.toPath)
     val logDirFailureChannel = new LogDirFailureChannel(10)
     val config = createLowRetentionLogConfig(logSegmentSize, LogConfig.Compact)
+    val maxTransactionTimeoutMs = 5 * 60 * 1000
     val maxProducerIdExpirationMs = 60 * 60 * 1000
     val segments = new LogSegments(tp)
     val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(tpDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
-    val producerStateManager = new ProducerStateManager(topicPartition, tpDir, maxProducerIdExpirationMs, time)
+    val producerStateManager = new ProducerStateManager(topicPartition, tpDir, maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
     val offsets = new LogLoader(
       tpDir,
       tp,
@@ -113,7 +114,6 @@ class LogCleanerManagerTest extends Logging {
       segments,
       0L,
       0L,
-      maxProducerIdExpirationMs,
       leaderEpochCache,
       producerStateManager
     ).load()
@@ -796,13 +796,15 @@ class LogCleanerManagerTest extends Logging {
     val config = createLowRetentionLogConfig(segmentSize, cleanupPolicy)
     val partitionDir = new File(logDir, UnifiedLog.logDirName(topicPartition))
 
-    UnifiedLog(partitionDir,
-      config,
+    UnifiedLog(
+      dir = partitionDir,
+      config = config,
       logStartOffset = 0L,
       recoveryPoint = 0L,
       scheduler = time.scheduler,
       time = time,
       brokerTopicStats = new BrokerTopicStats,
+      maxTransactionTimeoutMs = 5 * 60 * 1000,
       maxProducerIdExpirationMs = 60 * 60 * 1000,
       producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
       logDirFailureChannel = new LogDirFailureChannel(10),
@@ -847,11 +849,23 @@ class LogCleanerManagerTest extends Logging {
     log.maybeIncrementHighWatermark(log.logEndOffsetMetadata)
   }
 
-  private def makeLog(dir: File = logDir, config: LogConfig) =
-    UnifiedLog(dir = dir, config = config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
-      time = time, brokerTopicStats = new BrokerTopicStats, maxProducerIdExpirationMs = 60 * 60 * 1000,
+  private def makeLog(dir: File = logDir, config: LogConfig) = {
+    UnifiedLog(
+      dir = dir,
+      config = config,
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      time = time,
+      brokerTopicStats = new BrokerTopicStats,
+      maxTransactionTimeoutMs = 5 * 60 * 1000,
+      maxProducerIdExpirationMs = 60 * 60 * 1000,
       producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
-      logDirFailureChannel = new LogDirFailureChannel(10), topicId = None, keepPartitionMetadataFile = true)
+      logDirFailureChannel = new LogDirFailureChannel(10),
+      topicId = None,
+      keepPartitionMetadataFile = true
+    )
+  }
 
   private def records(key: Int, value: Int, timestamp: Long) =
     MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord(timestamp, key.toString.getBytes, value.toString.getBytes))
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
index 3a32ac2..30f3bd4 100755
--- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
@@ -105,10 +105,12 @@ class LogCleanerTest {
     val config = LogConfig.fromProps(logConfig.originals, logProps)
     val topicPartition = UnifiedLog.parseTopicPartitionName(dir)
     val logDirFailureChannel = new LogDirFailureChannel(10)
+    val maxTransactionTimeoutMs = 5 * 60 * 1000
     val maxProducerIdExpirationMs = 60 * 60 * 1000
     val logSegments = new LogSegments(topicPartition)
     val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(dir, topicPartition, logDirFailureChannel, config.recordVersion, "")
-    val producerStateManager = new ProducerStateManager(topicPartition, dir, maxProducerIdExpirationMs, time)
+    val producerStateManager = new ProducerStateManager(topicPartition, dir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
     val offsets = new LogLoader(
       dir,
       topicPartition,
@@ -120,7 +122,6 @@ class LogCleanerTest {
       logSegments,
       0L,
       0L,
-      maxProducerIdExpirationMs,
       leaderEpochCache,
       producerStateManager
     ).load()
@@ -1779,11 +1780,23 @@ class LogCleanerTest {
   private def messageWithOffset(key: Int, value: Int, offset: Long): MemoryRecords =
     messageWithOffset(key.toString.getBytes, value.toString.getBytes, offset)
 
-  private def makeLog(dir: File = dir, config: LogConfig = logConfig, recoveryPoint: Long = 0L) =
-    UnifiedLog(dir = dir, config = config, logStartOffset = 0L, recoveryPoint = recoveryPoint, scheduler = time.scheduler,
-      time = time, brokerTopicStats = new BrokerTopicStats, maxProducerIdExpirationMs = 60 * 60 * 1000,
+  private def makeLog(dir: File = dir, config: LogConfig = logConfig, recoveryPoint: Long = 0L) = {
+    UnifiedLog(
+      dir = dir,
+      config = config,
+      logStartOffset = 0L,
+      recoveryPoint = recoveryPoint,
+      scheduler = time.scheduler,
+      time = time,
+      brokerTopicStats = new BrokerTopicStats,
+      maxTransactionTimeoutMs = 5 * 60 * 1000,
+      maxProducerIdExpirationMs = 60 * 60 * 1000,
       producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
-      logDirFailureChannel = new LogDirFailureChannel(10), topicId = None, keepPartitionMetadataFile = true)
+      logDirFailureChannel = new LogDirFailureChannel(10),
+      topicId = None,
+      keepPartitionMetadataFile = true
+    )
+  }
 
   private def makeCleaner(capacity: Int, checkDone: TopicPartition => Unit = _ => (), maxMessageSize: Int = 64*1024) =
     new Cleaner(id = 0,
diff --git a/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala
index e10b5ab..db3222b 100644
--- a/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala
@@ -148,11 +148,13 @@ class LogConcurrencyTest {
       scheduler = scheduler,
       brokerTopicStats = brokerTopicStats,
       time = Time.SYSTEM,
+      maxTransactionTimeoutMs = 5 * 60 * 1000,
       maxProducerIdExpirationMs = 60 * 60 * 1000,
       producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
       logDirFailureChannel = new LogDirFailureChannel(10),
       topicId = None,
-      keepPartitionMetadataFile = true)
+      keepPartitionMetadataFile = true
+    )
   }
 
   private def validateConsumedData(log: UnifiedLog, consumedBatches: Iterable[FetchedBatch]): Unit = {
diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
index 496b1d1..abaa037 100644
--- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
@@ -40,7 +40,8 @@ import scala.jdk.CollectionConverters._
 
 class LogLoaderTest {
   var config: KafkaConfig = null
-  val brokerTopicStats = new BrokerTopicStats()
+  val brokerTopicStats = new BrokerTopicStats
+  val maxTransactionTimeoutMs: Int = 5 * 60 * 1000
   val maxProducerIdExpirationMs: Int = 60 * 60 * 1000
   val tmpDir = TestUtils.tempDir()
   val logDir = TestUtils.randomPartitionLogDir(tmpDir)
@@ -73,9 +74,12 @@ class LogLoaderTest {
     case class SimulateError(var hasError: Boolean = false)
     val simulateError = SimulateError()
 
+    val maxTransactionTimeoutMs = 5 * 60 * 1000
+    val maxProducerIdExpirationMs = 60 * 60 * 1000
+
     // Create a LogManager with some overridden methods to facilitate interception of clean shutdown
     // flag and to inject a runtime error
-    def interceptedLogManager(logConfig: LogConfig, logDirs: Seq[File], simulateError: SimulateError): LogManager =
+    def interceptedLogManager(logConfig: LogConfig, logDirs: Seq[File], simulateError: SimulateError): LogManager = {
       new LogManager(
         logDirs = logDirs.map(_.getAbsoluteFile),
         initialOfflineDirs = Array.empty[File],
@@ -87,7 +91,8 @@ class LogLoaderTest {
         flushRecoveryOffsetCheckpointMs = 10000L,
         flushStartOffsetCheckpointMs = 10000L,
         retentionCheckMs = 1000L,
-        maxPidExpirationMs = 60 * 60 * 1000,
+        maxTransactionTimeoutMs = maxTransactionTimeoutMs,
+        maxPidExpirationMs = maxProducerIdExpirationMs,
         interBrokerProtocolVersion = config.interBrokerProtocolVersion,
         scheduler = time.scheduler,
         brokerTopicStats = new BrokerTopicStats(),
@@ -107,13 +112,13 @@ class LogLoaderTest {
           val logRecoveryPoint = recoveryPoints.getOrElse(topicPartition, 0L)
           val logStartOffset = logStartOffsets.getOrElse(topicPartition, 0L)
           val logDirFailureChannel: LogDirFailureChannel = new LogDirFailureChannel(1)
-          val maxProducerIdExpirationMs = 60 * 60 * 1000
           val segments = new LogSegments(topicPartition)
           val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
-          val producerStateManager = new ProducerStateManager(topicPartition, logDir, maxProducerIdExpirationMs, time)
+          val producerStateManager = new ProducerStateManager(topicPartition, logDir,
+            maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
           val logLoader = new LogLoader(logDir, topicPartition, config, time.scheduler, time,
             logDirFailureChannel, hadCleanShutdown, segments, logStartOffset, logRecoveryPoint,
-            maxProducerIdExpirationMs, leaderEpochCache, producerStateManager)
+            leaderEpochCache, producerStateManager)
           val offsets = logLoader.load()
           val localLog = new LocalLog(logDir, logConfig, segments, offsets.recoveryPoint,
             offsets.nextOffsetMetadata, mockTime.scheduler, mockTime, topicPartition,
@@ -123,6 +128,7 @@ class LogLoaderTest {
             producerStateManager, None, true)
         }
       }
+    }
 
     val cleanShutdownFile = new File(logDir, LogLoader.CleanShutdownFile)
     locally {
@@ -184,11 +190,12 @@ class LogLoaderTest {
                         recoveryPoint: Long = 0L,
                         scheduler: Scheduler = mockTime.scheduler,
                         time: Time = mockTime,
+                        maxTransactionTimeoutMs: Int = maxTransactionTimeoutMs,
                         maxProducerIdExpirationMs: Int = maxProducerIdExpirationMs,
                         producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs,
                         lastShutdownClean: Boolean = true): UnifiedLog = {
     LogTestUtils.createLog(dir, config, brokerTopicStats, scheduler, time, logStartOffset, recoveryPoint,
-      maxProducerIdExpirationMs, producerIdExpirationCheckIntervalMs, lastShutdownClean)
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, producerIdExpirationCheckIntervalMs, lastShutdownClean)
   }
 
   private def createLogWithOffsetOverflow(logConfig: LogConfig): (UnifiedLog, LogSegment) = {
@@ -266,7 +273,8 @@ class LogLoaderTest {
       expectedSnapshotOffsets ++= log.logSegments.map(_.baseOffset).toVector.takeRight(4) :+ log.logEndOffset
     }
 
-    def createLogWithInterceptedReads(recoveryPoint: Long) = {
+    def createLogWithInterceptedReads(recoveryPoint: Long): UnifiedLog = {
+      val maxTransactionTimeoutMs = 5 * 60 * 1000
       val maxProducerIdExpirationMs = 60 * 60 * 1000
       val topicPartition = UnifiedLog.parseTopicPartitionName(logDir)
       val logDirFailureChannel = new LogDirFailureChannel(10)
@@ -291,7 +299,8 @@ class LogLoaderTest {
         }
       }
       val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, logConfig.recordVersion, "")
-      val producerStateManager = new ProducerStateManager(topicPartition, logDir, maxProducerIdExpirationMs, mockTime)
+      val producerStateManager = new ProducerStateManager(topicPartition, logDir,
+        maxTransactionTimeoutMs, maxProducerIdExpirationMs, mockTime)
       val logLoader = new LogLoader(
         logDir,
         topicPartition,
@@ -303,7 +312,6 @@ class LogLoaderTest {
         interceptedLogSegments,
         0L,
         recoveryPoint,
-        maxProducerIdExpirationMs,
         leaderEpochCache,
         producerStateManager)
       val offsets = logLoader.load()
@@ -339,8 +347,14 @@ class LogLoaderTest {
 
   @Test
   def testSkipLoadingIfEmptyProducerStateBeforeTruncation(): Unit = {
+    val maxTransactionTimeoutMs = 60000
+    val maxProducerIdExpirationMs = 300000
+
     val stateManager: ProducerStateManager = EasyMock.mock(classOf[ProducerStateManager])
+    EasyMock.expect(stateManager.maxProducerIdExpirationMs).andStubReturn(maxProducerIdExpirationMs)
+    EasyMock.expect(stateManager.maxTransactionTimeoutMs).andStubReturn(maxTransactionTimeoutMs)
     EasyMock.expect(stateManager.removeStraySnapshots(EasyMock.anyObject())).anyTimes()
+
     // Load the log
     EasyMock.expect(stateManager.latestSnapshotOffset).andReturn(None)
 
@@ -363,7 +377,6 @@ class LogLoaderTest {
     val topicPartition = UnifiedLog.parseTopicPartitionName(logDir)
     val logDirFailureChannel: LogDirFailureChannel = new LogDirFailureChannel(1)
     val config = LogConfig(new Properties())
-    val maxProducerIdExpirationMs = 300000
     val segments = new LogSegments(topicPartition)
     val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
     val offsets = new LogLoader(
@@ -377,7 +390,6 @@ class LogLoaderTest {
       segments,
       0L,
       0L,
-      maxProducerIdExpirationMs,
       leaderEpochCache,
       stateManager
     ).load()
@@ -471,7 +483,12 @@ class LogLoaderTest {
   @nowarn("cat=deprecation")
   @Test
   def testSkipTruncateAndReloadIfOldMessageFormatAndNoCleanShutdown(): Unit = {
+    val maxTransactionTimeoutMs = 60000
+    val maxProducerIdExpirationMs = 300000
+
     val stateManager: ProducerStateManager = EasyMock.mock(classOf[ProducerStateManager])
+    EasyMock.expect(stateManager.maxProducerIdExpirationMs).andStubReturn(maxProducerIdExpirationMs)
+    EasyMock.expect(stateManager.maxTransactionTimeoutMs).andStubReturn(maxTransactionTimeoutMs)
     EasyMock.expect(stateManager.removeStraySnapshots(EasyMock.anyObject())).anyTimes()
 
     stateManager.updateMapEndOffset(0L)
@@ -492,7 +509,6 @@ class LogLoaderTest {
     val logProps = new Properties()
     logProps.put(LogConfig.MessageFormatVersionProp, "0.10.2")
     val config = LogConfig(logProps)
-    val maxProducerIdExpirationMs = 300000
     val logDirFailureChannel = null
     val segments = new LogSegments(topicPartition)
     val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
@@ -507,7 +523,6 @@ class LogLoaderTest {
       segments,
       0L,
       0L,
-      maxProducerIdExpirationMs,
       leaderEpochCache,
       stateManager
     ).load()
@@ -529,7 +544,12 @@ class LogLoaderTest {
   @nowarn("cat=deprecation")
   @Test
   def testSkipTruncateAndReloadIfOldMessageFormatAndCleanShutdown(): Unit = {
+    val maxTransactionTimeoutMs = 60000
+    val maxProducerIdExpirationMs = 300000
+
     val stateManager: ProducerStateManager = EasyMock.mock(classOf[ProducerStateManager])
+    EasyMock.expect(stateManager.maxProducerIdExpirationMs).andStubReturn(maxProducerIdExpirationMs)
+    EasyMock.expect(stateManager.maxTransactionTimeoutMs).andStubReturn(maxTransactionTimeoutMs)
     EasyMock.expect(stateManager.removeStraySnapshots(EasyMock.anyObject())).anyTimes()
 
     stateManager.updateMapEndOffset(0L)
@@ -550,7 +570,6 @@ class LogLoaderTest {
     val logProps = new Properties()
     logProps.put(LogConfig.MessageFormatVersionProp, "0.10.2")
     val config = LogConfig(logProps)
-    val maxProducerIdExpirationMs = 300000
     val logDirFailureChannel = null
     val segments = new LogSegments(topicPartition)
     val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
@@ -565,7 +584,6 @@ class LogLoaderTest {
       segments,
       0L,
       0L,
-      maxProducerIdExpirationMs,
       leaderEpochCache,
       stateManager
     ).load()
@@ -587,7 +605,12 @@ class LogLoaderTest {
   @nowarn("cat=deprecation")
   @Test
   def testSkipTruncateAndReloadIfNewMessageFormatAndCleanShutdown(): Unit = {
+    val maxTransactionTimeoutMs = 60000
+    val maxProducerIdExpirationMs = 300000
+
     val stateManager: ProducerStateManager = EasyMock.mock(classOf[ProducerStateManager])
+    EasyMock.expect(stateManager.maxProducerIdExpirationMs).andStubReturn(maxProducerIdExpirationMs)
+    EasyMock.expect(stateManager.maxTransactionTimeoutMs).andStubReturn(maxTransactionTimeoutMs)
     EasyMock.expect(stateManager.removeStraySnapshots(EasyMock.anyObject())).anyTimes()
 
     EasyMock.expect(stateManager.latestSnapshotOffset).andReturn(None)
@@ -610,7 +633,6 @@ class LogLoaderTest {
     val logProps = new Properties()
     logProps.put(LogConfig.MessageFormatVersionProp, "0.11.0")
     val config = LogConfig(logProps)
-    val maxProducerIdExpirationMs = 300000
     val logDirFailureChannel = null
     val segments = new LogSegments(topicPartition)
     val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "")
@@ -625,7 +647,6 @@ class LogLoaderTest {
       segments,
       0L,
       0L,
-      maxProducerIdExpirationMs,
       leaderEpochCache,
       stateManager
     ).load()
diff --git a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
index 9884576..226330c 100644
--- a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
@@ -35,9 +35,10 @@ import scala.collection.mutable.ArrayBuffer
 
 class LogSegmentTest {
 
-  val topicPartition = new TopicPartition("topic", 0)
-  val segments = mutable.ArrayBuffer[LogSegment]()
-  var logDir: File = _
+  private val maxTransactionTimeoutMs = 5 * 60 * 1000
+  private val topicPartition = new TopicPartition("topic", 0)
+  private val segments = mutable.ArrayBuffer[LogSegment]()
+  private var logDir: File = _
 
   /* create a segment with the given base offset */
   def createSegment(offset: Long,
@@ -301,7 +302,7 @@ class LogSegmentTest {
       seg.append(i, RecordBatch.NO_TIMESTAMP, -1L, records(i, i.toString))
     val indexFile = seg.lazyOffsetIndex.file
     TestUtils.writeNonsenseToFile(indexFile, 5, indexFile.length.toInt)
-    seg.recover(new ProducerStateManager(topicPartition, logDir))
+    seg.recover(new ProducerStateManager(topicPartition, logDir, maxTransactionTimeoutMs))
     for(i <- 0 until 100) {
       val records = seg.read(i, 1, minOneMessage = true).records.records
       assertEquals(i, records.iterator.next().offset)
@@ -341,7 +342,7 @@ class LogSegmentTest {
     segment.append(largestOffset = 107L, largestTimestamp = RecordBatch.NO_TIMESTAMP,
       shallowOffsetOfMaxTimestamp = 107L, records = endTxnRecords(ControlRecordType.COMMIT, pid1, producerEpoch, offset = 107L))
 
-    var stateManager = new ProducerStateManager(topicPartition, logDir)
+    var stateManager = new ProducerStateManager(topicPartition, logDir, maxTransactionTimeoutMs)
     segment.recover(stateManager)
     assertEquals(108L, stateManager.mapEndOffset)
 
@@ -355,7 +356,7 @@ class LogSegmentTest {
     assertEquals(100L, abortedTxn.lastStableOffset)
 
     // recover again, but this time assuming the transaction from pid2 began on a previous segment
-    stateManager = new ProducerStateManager(topicPartition, logDir)
+    stateManager = new ProducerStateManager(topicPartition, logDir, maxTransactionTimeoutMs)
     stateManager.loadProducerEntry(new ProducerStateEntry(pid2,
       mutable.Queue[BatchMetadata](BatchMetadata(10, 10L, 5, RecordBatch.NO_TIMESTAMP)), producerEpoch,
       0, RecordBatch.NO_TIMESTAMP, Some(75L)))
@@ -406,7 +407,7 @@ class LogSegmentTest {
       shallowOffsetOfMaxTimestamp = 110, records = MemoryRecords.withRecords(110L, CompressionType.NONE, 2,
         new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
 
-    seg.recover(new ProducerStateManager(topicPartition, logDir), Some(cache))
+    seg.recover(new ProducerStateManager(topicPartition, logDir, maxTransactionTimeoutMs), Some(cache))
     assertEquals(ArrayBuffer(EpochEntry(epoch = 0, startOffset = 104L),
                              EpochEntry(epoch = 1, startOffset = 106),
                              EpochEntry(epoch = 2, startOffset = 110)),
@@ -435,7 +436,7 @@ class LogSegmentTest {
       seg.append(i, i * 10, i, records(i, i.toString))
     val timeIndexFile = seg.lazyTimeIndex.file
     TestUtils.writeNonsenseToFile(timeIndexFile, 5, timeIndexFile.length.toInt)
-    seg.recover(new ProducerStateManager(topicPartition, logDir))
+    seg.recover(new ProducerStateManager(topicPartition, logDir, maxTransactionTimeoutMs))
     for(i <- 0 until 100) {
       assertEquals(i, seg.findOffsetByTimestamp(i * 10).get.offset)
       if (i < 99)
@@ -459,7 +460,7 @@ class LogSegmentTest {
       val recordPosition = seg.log.searchForOffsetWithSize(offsetToBeginCorruption, 0)
       val position = recordPosition.position + TestUtils.random.nextInt(15)
       TestUtils.writeNonsenseToFile(seg.log.file, position, (seg.log.file.length - position).toInt)
-      seg.recover(new ProducerStateManager(topicPartition, logDir))
+      seg.recover(new ProducerStateManager(topicPartition, logDir, maxTransactionTimeoutMs))
       assertEquals((0 until offsetToBeginCorruption).toList, seg.log.batches.asScala.map(_.lastOffset).toList,
         "Should have truncated off bad messages.")
       seg.deleteIfExists()
diff --git a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
index 1f32ed8..e524bcb 100644
--- a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
@@ -77,24 +77,28 @@ object LogTestUtils {
                 time: Time,
                 logStartOffset: Long = 0L,
                 recoveryPoint: Long = 0L,
+                maxTransactionTimeoutMs: Int = 5 * 60 * 1000,
                 maxProducerIdExpirationMs: Int = 60 * 60 * 1000,
                 producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs,
                 lastShutdownClean: Boolean = true,
                 topicId: Option[Uuid] = None,
                 keepPartitionMetadataFile: Boolean = true): UnifiedLog = {
-    UnifiedLog(dir = dir,
+    UnifiedLog(
+      dir = dir,
       config = config,
       logStartOffset = logStartOffset,
       recoveryPoint = recoveryPoint,
       scheduler = scheduler,
       brokerTopicStats = brokerTopicStats,
       time = time,
+      maxTransactionTimeoutMs = maxTransactionTimeoutMs,
       maxProducerIdExpirationMs = maxProducerIdExpirationMs,
       producerIdExpirationCheckIntervalMs = producerIdExpirationCheckIntervalMs,
       logDirFailureChannel = new LogDirFailureChannel(10),
       lastShutdownClean = lastShutdownClean,
       topicId = topicId,
-      keepPartitionMetadataFile = keepPartitionMetadataFile)
+      keepPartitionMetadataFile = keepPartitionMetadataFile
+    )
   }
 
   /**
diff --git a/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
index 0c2fb6b..7140477 100644
--- a/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
@@ -36,17 +36,20 @@ import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 
 class ProducerStateManagerTest {
-  var logDir: File = null
-  var stateManager: ProducerStateManager = null
-  val partition = new TopicPartition("test", 0)
-  val producerId = 1L
-  val maxPidExpirationMs = 60 * 1000
-  val time = new MockTime
+  private var logDir: File = null
+  private var stateManager: ProducerStateManager = null
+  private val partition = new TopicPartition("test", 0)
+  private val producerId = 1L
+  private val maxTransactionTimeoutMs = 5 * 60 * 1000
+  private val maxProducerIdExpirationMs = 60 * 60 * 1000
+  private val lateTransactionTimeoutMs = maxTransactionTimeoutMs + ProducerStateManager.LateTransactionBufferMs
+  private val time = new MockTime
 
   @BeforeEach
   def setUp(): Unit = {
     logDir = TestUtils.tempDir()
-    stateManager = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time)
+    stateManager = new ProducerStateManager(partition, logDir, maxTransactionTimeoutMs,
+      maxProducerIdExpirationMs, time)
   }
 
   @AfterEach
@@ -264,6 +267,100 @@ class ProducerStateManagerTest {
   }
 
   @Test
+  def testHasLateTransaction(): Unit = {
+    val producerId1 = 39L
+    val epoch1 = 2.toShort
+
+    val producerId2 = 57L
+    val epoch2 = 9.toShort
+
+    // Start two transactions with a delay between them
+    append(stateManager, producerId1, epoch1, seq = 0, offset = 100, isTransactional = true)
+    assertFalse(stateManager.hasLateTransaction(time.milliseconds()))
+
+    time.sleep(500)
+    append(stateManager, producerId2, epoch2, seq = 0, offset = 150, isTransactional = true)
+    assertFalse(stateManager.hasLateTransaction(time.milliseconds()))
+
+    // Only the first transaction is late
+    time.sleep(lateTransactionTimeoutMs - 500 + 1)
+    assertTrue(stateManager.hasLateTransaction(time.milliseconds()))
+
+    // Both transactions are now late
+    time.sleep(500)
+    assertTrue(stateManager.hasLateTransaction(time.milliseconds()))
+
+    // Finish the first transaction
+    appendEndTxnMarker(stateManager, producerId1, epoch1, ControlRecordType.COMMIT, offset = 200)
+    assertTrue(stateManager.hasLateTransaction(time.milliseconds()))
+
+    // Now finish the second transaction
+    appendEndTxnMarker(stateManager, producerId2, epoch2, ControlRecordType.COMMIT, offset = 250)
+    assertFalse(stateManager.hasLateTransaction(time.milliseconds()))
+  }
+
+  @Test
+  def testHasLateTransactionInitializedAfterReload(): Unit = {
+    val producerId1 = 39L
+    val epoch1 = 2.toShort
+
+    val producerId2 = 57L
+    val epoch2 = 9.toShort
+
+    // Start two transactions with a delay between them
+    append(stateManager, producerId1, epoch1, seq = 0, offset = 100, isTransactional = true)
+    assertFalse(stateManager.hasLateTransaction(time.milliseconds()))
+
+    time.sleep(500)
+    append(stateManager, producerId2, epoch2, seq = 0, offset = 150, isTransactional = true)
+    assertFalse(stateManager.hasLateTransaction(time.milliseconds()))
+
+    // Take a snapshot and reload the state
+    stateManager.takeSnapshot()
+    time.sleep(lateTransactionTimeoutMs - 500 + 1)
+    assertTrue(stateManager.hasLateTransaction(time.milliseconds()))
+
+    // After reloading from the snapshot, the transaction should still be considered late
+    val reloadedStateManager = new ProducerStateManager(partition, logDir, maxTransactionTimeoutMs,
+      maxProducerIdExpirationMs, time)
+    reloadedStateManager.truncateAndReload(logStartOffset = 0L,
+      logEndOffset = stateManager.mapEndOffset, currentTimeMs = time.milliseconds())
+    assertTrue(reloadedStateManager.hasLateTransaction(time.milliseconds()))
+  }
+
+  @Test
+  def testHasLateTransactionUpdatedAfterPartialTruncation(): Unit = {
+    val producerId = 39L
+    val epoch = 2.toShort
+
+    // Start one transaction and sleep until it is late
+    append(stateManager, producerId, epoch, seq = 0, offset = 100, isTransactional = true)
+    assertFalse(stateManager.hasLateTransaction(time.milliseconds()))
+    time.sleep(lateTransactionTimeoutMs + 1)
+    assertTrue(stateManager.hasLateTransaction(time.milliseconds()))
+
+    // After truncation, the ongoing transaction will be cleared
+    stateManager.truncateAndReload(logStartOffset = 0, logEndOffset = 80, currentTimeMs = time.milliseconds())
+    assertFalse(stateManager.hasLateTransaction(time.milliseconds()))
+  }
+
+  @Test
+  def testHasLateTransactionUpdatedAfterFullTruncation(): Unit = {
+    val producerId = 39L
+    val epoch = 2.toShort
+
+    // Start one transaction and sleep until it is late
+    append(stateManager, producerId, epoch, seq = 0, offset = 100, isTransactional = true)
+    assertFalse(stateManager.hasLateTransaction(time.milliseconds()))
+    time.sleep(lateTransactionTimeoutMs + 1)
+    assertTrue(stateManager.hasLateTransaction(time.milliseconds()))
+
+    // After truncation, the ongoing transaction will be cleared
+    stateManager.truncateFullyAndStartAt(offset = 150L)
+    assertFalse(stateManager.hasLateTransaction(time.milliseconds()))
+  }
+
+  @Test
   def testLastStableOffsetCompletedTxn(): Unit = {
     val producerEpoch = 0.toShort
     val segmentBaseOffset = 990000L
@@ -467,7 +564,8 @@ class ProducerStateManagerTest {
     append(stateManager, producerId, epoch, 1, 1L, isTransactional = true)
 
     stateManager.takeSnapshot()
-    val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time)
+    val recoveredMapping = new ProducerStateManager(partition, logDir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
     recoveredMapping.truncateAndReload(0L, 3L, time.milliseconds)
 
     // The snapshot only persists the last appended batch metadata
@@ -490,7 +588,8 @@ class ProducerStateManagerTest {
     appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.ABORT, offset = 2L)
 
     stateManager.takeSnapshot()
-    val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time)
+    val recoveredMapping = new ProducerStateManager(partition, logDir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
     recoveredMapping.truncateAndReload(0L, 3L, time.milliseconds)
 
     // The snapshot only persists the last appended batch metadata
@@ -510,7 +609,8 @@ class ProducerStateManagerTest {
       offset = 0L, timestamp = appendTimestamp)
     stateManager.takeSnapshot()
 
-    val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time)
+    val recoveredMapping = new ProducerStateManager(partition, logDir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
     recoveredMapping.truncateAndReload(logStartOffset = 0L, logEndOffset = 1L, time.milliseconds)
 
     val lastEntry = recoveredMapping.lastEntry(producerId)
@@ -542,7 +642,8 @@ class ProducerStateManagerTest {
     append(stateManager, producerId, epoch, 1, 1L, 1)
 
     stateManager.takeSnapshot()
-    val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time)
+    val recoveredMapping = new ProducerStateManager(partition, logDir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
     recoveredMapping.truncateAndReload(0L, 1L, 70000)
 
     // entry added after recovery. The pid should be expired now, and would not exist in the pid mapping. Hence
@@ -561,7 +662,8 @@ class ProducerStateManagerTest {
     append(stateManager, producerId, epoch, 1, 1L, 1)
 
     stateManager.takeSnapshot()
-    val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time)
+    val recoveredMapping = new ProducerStateManager(partition, logDir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
     recoveredMapping.truncateAndReload(0L, 1L, 70000)
 
     val sequence = 2
@@ -706,7 +808,7 @@ class ProducerStateManagerTest {
     val epoch = 5.toShort
     val sequence = 37
     append(stateManager, producerId, epoch, sequence, 1L)
-    time.sleep(maxPidExpirationMs + 1)
+    time.sleep(maxProducerIdExpirationMs + 1)
     stateManager.removeExpiredProducers(time.milliseconds)
     append(stateManager, producerId, epoch, sequence + 1, 2L)
     assertEquals(1, stateManager.activeProducers.size)
@@ -756,7 +858,7 @@ class ProducerStateManagerTest {
     append(stateManager, producerId, epoch, sequence, offset = 99, isTransactional = true)
     assertEquals(Some(99L), stateManager.firstUndecidedOffset)
 
-    time.sleep(maxPidExpirationMs + 1)
+    time.sleep(maxProducerIdExpirationMs + 1)
     stateManager.removeExpiredProducers(time.milliseconds)
 
     assertTrue(stateManager.lastEntry(producerId).isDefined)
@@ -769,7 +871,8 @@ class ProducerStateManagerTest {
   @Test
   def testSequenceNotValidatedForGroupMetadataTopic(): Unit = {
     val partition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0)
-    val stateManager = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time)
+    val stateManager = new ProducerStateManager(partition, logDir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
 
     val epoch = 0.toShort
     append(stateManager, producerId, epoch, RecordBatch.NO_SEQUENCE, offset = 99,
@@ -818,7 +921,8 @@ class ProducerStateManagerTest {
     appendEndTxnMarker(stateManager, producerId, producerEpoch, ControlRecordType.COMMIT, offset = 100, coordinatorEpoch = 1)
     stateManager.takeSnapshot()
 
-    val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time)
+    val recoveredMapping = new ProducerStateManager(partition, logDir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
     recoveredMapping.truncateAndReload(0L, 2L, 70000)
 
     // append from old coordinator should be rejected
@@ -907,7 +1011,7 @@ class ProducerStateManagerTest {
   @Test
   def testRemoveAndMarkSnapshotForDeletion(): Unit = {
     UnifiedLog.producerSnapshotFile(logDir, 5).createNewFile()
-    val manager = new ProducerStateManager(partition, logDir, time = time)
+    val manager = new ProducerStateManager(partition, logDir, maxTransactionTimeoutMs, time = time)
     assertTrue(manager.latestSnapshotOffset.isDefined)
     val snapshot = manager.removeAndMarkSnapshotForDeletion(5).get
     assertTrue(snapshot.file.toPath.toString.endsWith(UnifiedLog.DeletedFileSuffix))
@@ -925,7 +1029,7 @@ class ProducerStateManagerTest {
   def testRemoveAndMarkSnapshotForDeletionAlreadyDeleted(): Unit = {
     val file = UnifiedLog.producerSnapshotFile(logDir, 5)
     file.createNewFile()
-    val manager = new ProducerStateManager(partition, logDir, time = time)
+    val manager = new ProducerStateManager(partition, logDir, maxTransactionTimeoutMs, time = time)
     assertTrue(manager.latestSnapshotOffset.isDefined)
     Files.delete(file.toPath)
     assertTrue(manager.removeAndMarkSnapshotForDeletion(5).isEmpty)
@@ -954,7 +1058,8 @@ class ProducerStateManagerTest {
     }
 
     // Ensure that the truncated snapshot is deleted and producer state is loaded from the previous snapshot
-    val reloadedStateManager = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time)
+    val reloadedStateManager = new ProducerStateManager(partition, logDir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
     reloadedStateManager.truncateAndReload(0L, 20L, time.milliseconds())
     assertFalse(snapshotToTruncate.exists())
 
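The tests above pin down when a transaction counts as late: strictly more than maxTransactionTimeoutMs plus ProducerStateManager.LateTransactionBufferMs must elapse after the transaction's first append before hasLateTransaction reports it. A minimal sketch of that predicate, for orientation only (the helper name and its parameters are hypothetical; only LateTransactionBufferMs and hasLateTransaction come from the patch):

    // Hypothetical illustration of the lateness check exercised above: a transaction
    // is late once strictly more than maxTransactionTimeoutMs + LateTransactionBufferMs
    // has elapsed since its first record was appended.
    def isTransactionLate(txnStartMs: Long, nowMs: Long,
                          maxTransactionTimeoutMs: Int, lateTransactionBufferMs: Int): Boolean =
      nowMs - txnStartMs > maxTransactionTimeoutMs.toLong + lateTransactionBufferMs
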
diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
index 6ad78da..79aa743 100755
--- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
@@ -3373,13 +3373,15 @@ class UnifiedLogTest {
                         recoveryPoint: Long = 0L,
                         scheduler: Scheduler = mockTime.scheduler,
                         time: Time = mockTime,
+                        maxTransactionTimeoutMs: Int = 60 * 60 * 1000,
                         maxProducerIdExpirationMs: Int = 60 * 60 * 1000,
                         producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs,
                         lastShutdownClean: Boolean = true,
                         topicId: Option[Uuid] = None,
                         keepPartitionMetadataFile: Boolean = true): UnifiedLog = {
     LogTestUtils.createLog(dir, config, brokerTopicStats, scheduler, time, logStartOffset, recoveryPoint,
-      maxProducerIdExpirationMs, producerIdExpirationCheckIntervalMs, lastShutdownClean, topicId = topicId, keepPartitionMetadataFile = keepPartitionMetadataFile)
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, producerIdExpirationCheckIntervalMs,
+      lastShutdownClean, topicId, keepPartitionMetadataFile)
   }
 
   private def createLogWithOffsetOverflow(logConfig: LogConfig): (UnifiedLog, LogSegment) = {
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index ebf4c2d..eb0f306 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -392,7 +392,72 @@ class ReplicaManagerTest {
     } finally {
       replicaManager.shutdown(checkpointHW = false)
     }
+  }
+
+  @Test
+  def testPartitionsWithLateTransactionsCount(): Unit = {
+    val timer = new MockTimer(time)
+    val replicaManager = setupReplicaManagerWithMockedPurgatories(timer)
+    val topicPartition = new TopicPartition(topic, 0)
+
+    def assertLateTransactionCount(expectedCount: Option[Int]): Unit = {
+      assertEquals(expectedCount, TestUtils.yammerGaugeValue[Int]("PartitionsWithLateTransactionsCount"))
+    }
+
+    try {
+      assertLateTransactionCount(Some(0))
 
+      val partition = replicaManager.createPartition(topicPartition)
+      partition.createLogIfNotExists(isNew = false, isFutureReplica = false,
+        new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints), None)
+
+      // Make this replica the leader.
+      val brokerList = Seq[Integer](0, 1, 2).asJava
+      val leaderAndIsrRequest1 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch,
+        Seq(new LeaderAndIsrPartitionState()
+          .setTopicName(topic)
+          .setPartitionIndex(0)
+          .setControllerEpoch(0)
+          .setLeader(0)
+          .setLeaderEpoch(0)
+          .setIsr(brokerList)
+          .setZkVersion(0)
+          .setReplicas(brokerList)
+          .setIsNew(true)).asJava,
+        topicIds.asJava,
+        Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build()
+      replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest1, (_, _) => ())
+
+      // Start a transaction
+      val producerId = 234L
+      val epoch = 5.toShort
+      val sequence = 9
+      val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence,
+        new SimpleRecord(time.milliseconds(), s"message $sequence".getBytes))
+      appendRecords(replicaManager, new TopicPartition(topic, 0), records).onFire { response =>
+        assertEquals(Errors.NONE, response.error)
+      }
+      assertLateTransactionCount(Some(0))
+
+      // The transaction becomes late only once strictly more than the max transaction timeout plus the late transaction buffer has passed
+      time.sleep(replicaManager.logManager.maxTransactionTimeoutMs + ProducerStateManager.LateTransactionBufferMs)
+      assertLateTransactionCount(Some(0))
+      time.sleep(1)
+      assertLateTransactionCount(Some(1))
+
+      // After the late transaction is aborted, we expect the count to return to 0
+      val abortTxnMarker = new EndTransactionMarker(ControlRecordType.ABORT, 0)
+      val abortRecordBatch = MemoryRecords.withEndTransactionMarker(producerId, epoch, abortTxnMarker)
+      appendRecords(replicaManager, new TopicPartition(topic, 0),
+        abortRecordBatch, origin = AppendOrigin.Coordinator).onFire { response =>
+        assertEquals(Errors.NONE, response.error)
+      }
+      assertLateTransactionCount(Some(0))
+    } finally {
+      // After shutdown, the metric should no longer be registered
+      replicaManager.shutdown(checkpointHW = false)
+      assertLateTransactionCount(None)
+    }
   }
 
   @Test
@@ -1750,10 +1815,12 @@ class ReplicaManagerTest {
     val mockBrokerTopicStats = new BrokerTopicStats
     val mockLogDirFailureChannel = new LogDirFailureChannel(config.logDirs.size)
     val tp = new TopicPartition(topic, topicPartition)
+    val maxTransactionTimeoutMs = 30000
     val maxProducerIdExpirationMs = 30000
     val segments = new LogSegments(tp)
     val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, tp, mockLogDirFailureChannel, logConfig.recordVersion, "")
-    val producerStateManager = new ProducerStateManager(tp, logDir, maxProducerIdExpirationMs, time)
+    val producerStateManager = new ProducerStateManager(tp, logDir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, time)
     val offsets = new LogLoader(
       logDir,
       tp,
@@ -1765,7 +1832,6 @@ class ReplicaManagerTest {
       segments,
       0L,
       0L,
-      maxProducerIdExpirationMs,
       leaderEpochCache,
       producerStateManager
     ).load()
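In essence, the new gauge counts partitions whose producer state reports a late transaction; the test above drives a partition into that state and reads the count back via TestUtils.yammerGaugeValue[Int]("PartitionsWithLateTransactionsCount"). A rough sketch of the computation, assuming only the ProducerStateManager.hasLateTransaction check exercised above (the helper below is hypothetical; the actual gauge registration lives in the ReplicaManager and LogManager changes earlier in this patch):

    import kafka.log.ProducerStateManager

    // Hypothetical sketch: the gauge value boils down to counting partitions whose
    // producer state currently has a late (hanging) transaction.
    def countPartitionsWithLateTransactions(stateManagers: Iterable[ProducerStateManager],
                                            nowMs: Long): Int =
      stateManagers.count(_.hasLateTransaction(nowMs))
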
diff --git a/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
index acc5c69..04556aa 100644
--- a/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
+++ b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
@@ -59,10 +59,21 @@ class DumpLogSegmentsTest {
   def setUp(): Unit = {
     val props = new Properties
     props.setProperty(LogConfig.IndexIntervalBytesProp, "128")
-    log = UnifiedLog(logDir, LogConfig(props), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
-      time = time, brokerTopicStats = new BrokerTopicStats, maxProducerIdExpirationMs = 60 * 60 * 1000,
+    log = UnifiedLog(
+      dir = logDir,
+      config = LogConfig(props),
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      time = time,
+      brokerTopicStats = new BrokerTopicStats,
+      maxTransactionTimeoutMs = 5 * 60 * 1000,
+      maxProducerIdExpirationMs = 60 * 60 * 1000,
       producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
-      logDirFailureChannel = new LogDirFailureChannel(10), topicId = None, keepPartitionMetadataFile = true)
+      logDirFailureChannel = new LogDirFailureChannel(10),
+      topicId = None,
+      keepPartitionMetadataFile = true
+    )
   }
 
   def addSimpleRecords(): Unit = {
diff --git a/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala b/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
index fee9a9d..91c8f27 100644
--- a/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
+++ b/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
@@ -118,12 +118,14 @@ class SchedulerTest {
     val logDir = TestUtils.randomPartitionLogDir(tmpDir)
     val logConfig = LogConfig(new Properties())
     val brokerTopicStats = new BrokerTopicStats
+    val maxTransactionTimeoutMs = 5 * 60 * 1000
     val maxProducerIdExpirationMs = 60 * 60 * 1000
     val topicPartition = UnifiedLog.parseTopicPartitionName(logDir)
     val logDirFailureChannel = new LogDirFailureChannel(10)
     val segments = new LogSegments(topicPartition)
     val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, logConfig.recordVersion, "")
-    val producerStateManager = new ProducerStateManager(topicPartition, logDir, maxProducerIdExpirationMs, mockTime)
+    val producerStateManager = new ProducerStateManager(topicPartition, logDir,
+      maxTransactionTimeoutMs, maxProducerIdExpirationMs, mockTime)
     val offsets = new LogLoader(
       logDir,
       topicPartition,
@@ -135,7 +137,6 @@ class SchedulerTest {
       segments,
       0L,
       0L,
-      maxProducerIdExpirationMs,
       leaderEpochCache,
       producerStateManager
     ).load()
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index 7bbb6e8..53bc8b0 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -29,7 +29,7 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
 import java.util.concurrent.{Callable, CompletableFuture, ExecutionException, Executors, TimeUnit}
 import java.util.{Arrays, Collections, Optional, Properties}
 
-import com.yammer.metrics.core.Meter
+import com.yammer.metrics.core.{Gauge, Meter}
 import javax.net.ssl.X509TrustManager
 import kafka.api._
 import kafka.cluster.{Broker, EndPoint, IsrChangeListener}
@@ -1247,6 +1247,7 @@ object TestUtils extends Logging {
                    flushRecoveryOffsetCheckpointMs = 10000L,
                    flushStartOffsetCheckpointMs = 10000L,
                    retentionCheckMs = 1000L,
+                   maxTransactionTimeoutMs = 5 * 60 * 1000,
                    maxPidExpirationMs = 60 * 60 * 1000,
                    scheduler = time.scheduler,
                    time = time,
@@ -1991,6 +1992,15 @@ object TestUtils extends Logging {
     total.toLong
   }
 
+  def yammerGaugeValue[T](metricName: String): Option[T] = {
+    KafkaYammerMetrics.defaultRegistry.allMetrics.asScala
+      .filter { case (k, _) => k.getMBeanName.endsWith(metricName) }
+      .values
+      .headOption
+      .map(_.asInstanceOf[Gauge[T]])
+      .map(_.value)
+  }
+
   def meterCount(metricName: String): Long = {
     KafkaYammerMetrics.defaultRegistry.allMetrics.asScala
       .filter { case (k, _) => k.getMBeanName.endsWith(metricName) }