Posted to commits@kafka.apache.org by sa...@apache.org on 2023/02/12 07:31:22 UTC

[kafka] branch trunk updated: KAFKA-14480 Move/Rewrite ProducerStateManager to storage module. (#13040)

This is an automated email from the ASF dual-hosted git repository.

satishd pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new c576e02849b KAFKA-14480 Move/Rewrite ProducerStateManager to storage module. (#13040)
c576e02849b is described below

commit c576e02849b9e7ada80198de41499bcc2480bc93
Author: Satish Duggana <sa...@apache.org>
AuthorDate: Sun Feb 12 13:00:51 2023 +0530

    KAFKA-14480 Move/Rewrite ProducerStateManager to storage module. (#13040)
    
    KAFKA-14480 Move/Rewrite of ProducerStateManager to the storage module.
    
    Replaced `File.listFiles` with `Files.list` in ProducerStateManager.listSnapshotFiles.
    Used `Path` instead of `File` in ProducerStateManager.isSnapshotFile to check whether the given path is a regular file with the '.snapshot' suffix.
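
For context on the two changes called out above, the following is a minimal, hypothetical sketch (not the committed code) of a `Files.list`-based snapshot listing combined with a `Path`-based suffix check; the class name SnapshotListingSketch and both method names are illustrative assumptions, not identifiers from this commit.

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class SnapshotListingSketch {
    private static final String SNAPSHOT_SUFFIX = ".snapshot";

    // True if the path points to a regular file whose name ends with ".snapshot".
    static boolean isSnapshotFile(Path path) {
        return Files.isRegularFile(path) && path.getFileName().toString().endsWith(SNAPSHOT_SUFFIX);
    }

    // Lists snapshot files in the given log directory with Files.list; the stream is
    // closed by try-with-resources. A missing or non-directory path yields an empty list.
    static List<Path> listSnapshotFiles(Path logDir) throws IOException {
        if (!Files.isDirectory(logDir))
            return Collections.emptyList();
        try (Stream<Path> paths = Files.list(logDir)) {
            return paths.filter(SnapshotListingSketch::isSnapshotFile).collect(Collectors.toList());
        }
    }
}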
---
 .../kafka/server/builders/LogManagerBuilder.java   |   2 +-
 core/src/main/scala/kafka/log/LocalLog.scala       |   8 +-
 core/src/main/scala/kafka/log/LogCleaner.scala     |   2 +-
 core/src/main/scala/kafka/log/LogLoader.scala      |  13 +-
 core/src/main/scala/kafka/log/LogManager.scala     |   2 +-
 core/src/main/scala/kafka/log/LogSegment.scala     |   4 +-
 .../scala/kafka/log/ProducerStateManager.scala     | 620 -------------------
 core/src/main/scala/kafka/log/UnifiedLog.scala     |  56 +-
 .../scala/kafka/log/remote/RemoteIndexCache.scala  |   8 +-
 .../main/scala/kafka/raft/KafkaMetadataLog.scala   |   4 +-
 .../scala/kafka/server/DynamicBrokerConfig.scala   |  27 +-
 core/src/main/scala/kafka/server/KafkaConfig.scala |   4 +-
 .../scala/kafka/server/ReplicaFetcherThread.scala  |   6 +-
 .../main/scala/kafka/tools/DumpLogSegments.scala   |   6 +-
 .../integration/kafka/api/TransactionsTest.scala   |  11 +-
 .../src/test/scala/other/kafka/StressTestLog.scala |   2 +-
 .../scala/other/kafka/TestLinearWriteSpeed.scala   |   2 +-
 .../unit/kafka/cluster/PartitionLockTest.scala     |   2 +-
 .../scala/unit/kafka/cluster/PartitionTest.scala   |   2 +-
 .../log/AbstractLogCleanerIntegrationTest.scala    |   2 +-
 .../unit/kafka/log/BrokerCompressionTest.scala     |   2 +-
 .../unit/kafka/log/LogCleanerManagerTest.scala     |   2 +-
 .../test/scala/unit/kafka/log/LogCleanerTest.scala |  16 +-
 .../scala/unit/kafka/log/LogConcurrencyTest.scala  |   2 +-
 .../test/scala/unit/kafka/log/LogLoaderTest.scala  |  75 +--
 .../test/scala/unit/kafka/log/LogManagerTest.scala |   2 +-
 .../test/scala/unit/kafka/log/LogSegmentTest.scala |   8 +-
 .../test/scala/unit/kafka/log/LogTestUtils.scala   |  11 +-
 .../unit/kafka/log/ProducerStateManagerTest.scala  | 190 +++---
 .../test/scala/unit/kafka/log/UnifiedLogTest.scala |  61 +-
 .../unit/kafka/server/ReplicaManagerTest.scala     |   4 +-
 .../unit/kafka/tools/DumpLogSegmentsTest.scala     |   4 +-
 .../scala/unit/kafka/utils/SchedulerTest.scala     |   4 +-
 .../test/scala/unit/kafka/utils/TestUtils.scala    |   2 +-
 .../kafka/storage/internals/log/LogFileUtils.java  |  40 ++
 .../internals/log/ProducerStateManager.java        | 678 +++++++++++++++++++++
 ...eUtils.java => ProducerStateManagerConfig.java} |  25 +-
 .../kafka/storage/internals/log/SnapshotFile.java  |   8 +
 38 files changed, 1018 insertions(+), 899 deletions(-)

diff --git a/core/src/main/java/kafka/server/builders/LogManagerBuilder.java b/core/src/main/java/kafka/server/builders/LogManagerBuilder.java
index cc7c254ce08..2bbd0e9bb44 100644
--- a/core/src/main/java/kafka/server/builders/LogManagerBuilder.java
+++ b/core/src/main/java/kafka/server/builders/LogManagerBuilder.java
@@ -18,7 +18,6 @@
 package kafka.server.builders;
 
 import kafka.log.LogManager;
-import kafka.log.ProducerStateManagerConfig;
 import kafka.server.BrokerTopicStats;
 import kafka.server.metadata.ConfigRepository;
 import org.apache.kafka.common.utils.Time;
@@ -27,6 +26,7 @@ import org.apache.kafka.storage.internals.log.CleanerConfig;
 import org.apache.kafka.storage.internals.log.LogConfig;
 import org.apache.kafka.storage.internals.log.LogDirFailureChannel;
 import org.apache.kafka.server.util.Scheduler;
+import org.apache.kafka.storage.internals.log.ProducerStateManagerConfig;
 import scala.collection.JavaConverters;
 
 import java.io.File;
diff --git a/core/src/main/scala/kafka/log/LocalLog.scala b/core/src/main/scala/kafka/log/LocalLog.scala
index 5b5f309295a..bfb14d9ad18 100644
--- a/core/src/main/scala/kafka/log/LocalLog.scala
+++ b/core/src/main/scala/kafka/log/LocalLog.scala
@@ -31,7 +31,7 @@ import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.utils.{Time, Utils}
 import org.apache.kafka.storage.internals.log.LogFileUtils.offsetFromFileName
 import org.apache.kafka.server.util.Scheduler
-import org.apache.kafka.storage.internals.log.{AbortedTxn, FetchDataInfo, LogConfig, LogDirFailureChannel, LogOffsetMetadata, OffsetPosition}
+import org.apache.kafka.storage.internals.log.{AbortedTxn, FetchDataInfo, LogConfig, LogDirFailureChannel, LogFileUtils, LogOffsetMetadata, OffsetPosition}
 
 import java.util.{Collections, Optional}
 import scala.jdk.CollectionConverters._
@@ -338,7 +338,7 @@ class LocalLog(@volatile private var _dir: File,
                                           asyncDelete: Boolean,
                                           reason: SegmentDeletionReason): LogSegment = {
     if (newOffset == segmentToDelete.baseOffset)
-      segmentToDelete.changeFileSuffixes("", DeletedFileSuffix)
+      segmentToDelete.changeFileSuffixes("", LogFileUtils.DELETED_FILE_SUFFIX)
 
     val newSegment = LogSegment.open(dir,
       baseOffset = newOffset,
@@ -984,8 +984,8 @@ object LocalLog extends Logging {
                                       logDirFailureChannel: LogDirFailureChannel,
                                       logPrefix: String): Unit = {
     segmentsToDelete.foreach { segment =>
-      if (!segment.hasSuffix(DeletedFileSuffix))
-        segment.changeFileSuffixes("", DeletedFileSuffix)
+      if (!segment.hasSuffix(LogFileUtils.DELETED_FILE_SUFFIX))
+        segment.changeFileSuffixes("", LogFileUtils.DELETED_FILE_SUFFIX)
     }
 
     def deleteSegments(): Unit = {
diff --git a/core/src/main/scala/kafka/log/LogCleaner.scala b/core/src/main/scala/kafka/log/LogCleaner.scala
index 187b03fec64..392a3a0da1b 100644
--- a/core/src/main/scala/kafka/log/LogCleaner.scala
+++ b/core/src/main/scala/kafka/log/LogCleaner.scala
@@ -656,7 +656,7 @@ private[log] class Cleaner(val id: Int,
                              deleteRetentionMs: Long,
                              maxLogMessageSize: Int,
                              transactionMetadata: CleanedTransactionMetadata,
-                             lastRecordsOfActiveProducers: Map[Long, LastRecord],
+                             lastRecordsOfActiveProducers: mutable.Map[Long, LastRecord],
                              stats: CleanerStats,
                              currentTime: Long): Unit = {
     val logCleanerFilter: RecordFilter = new RecordFilter(currentTime, deleteRetentionMs) {
diff --git a/core/src/main/scala/kafka/log/LogLoader.scala b/core/src/main/scala/kafka/log/LogLoader.scala
index a16ad04d133..cc0232aef28 100644
--- a/core/src/main/scala/kafka/log/LogLoader.scala
+++ b/core/src/main/scala/kafka/log/LogLoader.scala
@@ -20,7 +20,7 @@ package kafka.log
 import java.io.{File, IOException}
 import java.nio.file.{Files, NoSuchFileException}
 import kafka.common.LogSegmentOffsetOverflowException
-import kafka.log.UnifiedLog.{CleanedFileSuffix, DeletedFileSuffix, SwapFileSuffix, isIndexFile, isLogFile, offsetFromFile}
+import kafka.log.UnifiedLog.{CleanedFileSuffix, SwapFileSuffix, isIndexFile, isLogFile, offsetFromFile}
 import kafka.utils.Logging
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.InvalidOffsetException
@@ -28,10 +28,11 @@ import org.apache.kafka.common.utils.{Time, Utils}
 import org.apache.kafka.snapshot.Snapshots
 import org.apache.kafka.server.util.Scheduler
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-import org.apache.kafka.storage.internals.log.{CorruptIndexException, LoadedLogOffsets, LogConfig, LogDirFailureChannel, LogOffsetMetadata}
+import org.apache.kafka.storage.internals.log.{CorruptIndexException, LoadedLogOffsets, LogConfig, LogDirFailureChannel, LogFileUtils, LogOffsetMetadata, ProducerStateManager}
 
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import scala.collection.{Set, mutable}
+import scala.jdk.CollectionConverters._
 
 object LogLoader extends Logging {
 
@@ -191,7 +192,7 @@ class LogLoader(
     // Reload all snapshots into the ProducerStateManager cache, the intermediate ProducerStateManager used
     // during log recovery may have deleted some files without the LogLoader.producerStateManager instance witnessing the
     // deletion.
-    producerStateManager.removeStraySnapshots(segments.baseOffsets.toSeq)
+    producerStateManager.removeStraySnapshots(segments.baseOffsets.map(x => Long.box(x)).asJavaCollection)
     UnifiedLog.rebuildProducerState(
       producerStateManager,
       segments,
@@ -229,7 +230,7 @@ class LogLoader(
 
       // Delete stray files marked for deletion, but skip KRaft snapshots.
       // These are handled in the recovery logic in `KafkaMetadataLog`.
-      if (filename.endsWith(DeletedFileSuffix) && !filename.endsWith(Snapshots.DELETE_SUFFIX)) {
+      if (filename.endsWith(LogFileUtils.DELETED_FILE_SUFFIX) && !filename.endsWith(Snapshots.DELETE_SUFFIX)) {
         debug(s"Deleting stray temporary file ${file.getAbsolutePath}")
         Files.deleteIfExists(file.toPath)
       } else if (filename.endsWith(CleanedFileSuffix)) {
@@ -354,8 +355,8 @@ class LogLoader(
     val producerStateManager = new ProducerStateManager(
       topicPartition,
       dir,
-      this.producerStateManager.maxTransactionTimeoutMs,
-      this.producerStateManager.producerStateManagerConfig,
+      this.producerStateManager.maxTransactionTimeoutMs(),
+      this.producerStateManager.producerStateManagerConfig(),
       time)
     UnifiedLog.rebuildProducerState(
       producerStateManager,
diff --git a/core/src/main/scala/kafka/log/LogManager.scala b/core/src/main/scala/kafka/log/LogManager.scala
index f3aaeabee89..6b2704d0fc1 100755
--- a/core/src/main/scala/kafka/log/LogManager.scala
+++ b/core/src/main/scala/kafka/log/LogManager.scala
@@ -43,7 +43,7 @@ import java.util.Properties
 import org.apache.kafka.server.common.MetadataVersion
 import org.apache.kafka.storage.internals.log.LogConfig.MessageFormatVersion
 import org.apache.kafka.server.util.Scheduler
-import org.apache.kafka.storage.internals.log.{CleanerConfig, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{CleanerConfig, LogConfig, LogDirFailureChannel, ProducerStateManagerConfig}
 
 import scala.annotation.nowarn
 
diff --git a/core/src/main/scala/kafka/log/LogSegment.scala b/core/src/main/scala/kafka/log/LogSegment.scala
index a9abc5466c1..e5fe0dfd585 100644
--- a/core/src/main/scala/kafka/log/LogSegment.scala
+++ b/core/src/main/scala/kafka/log/LogSegment.scala
@@ -27,7 +27,7 @@ import org.apache.kafka.common.record.FileRecords.{LogOffsetPosition, TimestampA
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.{BufferSupplier, Time, Utils}
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, CompletedTxn, FetchDataInfo, LazyIndex, LogConfig, LogOffsetMetadata, OffsetIndex, OffsetPosition, TimeIndex, TimestampOffset, TransactionIndex, TxnIndexSearchResult}
+import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, CompletedTxn, FetchDataInfo, LazyIndex, LogConfig, LogOffsetMetadata, OffsetIndex, OffsetPosition, ProducerStateManager, TimeIndex, TimestampOffset, TransactionIndex, TxnIndexSearchResult}
 
 import java.io.{File, IOException}
 import java.nio.file.attribute.FileTime
@@ -248,7 +248,7 @@ class LogSegment private[log] (val log: FileRecords,
   private def updateProducerState(producerStateManager: ProducerStateManager, batch: RecordBatch): Unit = {
     if (batch.hasProducerId) {
       val producerId = batch.producerId
-      val appendInfo = producerStateManager.prepareUpdate(producerId, origin = AppendOrigin.REPLICATION)
+      val appendInfo = producerStateManager.prepareUpdate(producerId, AppendOrigin.REPLICATION)
       val maybeCompletedTxn = appendInfo.append(batch, Optional.empty())
       producerStateManager.update(appendInfo)
       maybeCompletedTxn.ifPresent(completedTxn => {
diff --git a/core/src/main/scala/kafka/log/ProducerStateManager.scala b/core/src/main/scala/kafka/log/ProducerStateManager.scala
deleted file mode 100644
index 6a79f0e7166..00000000000
--- a/core/src/main/scala/kafka/log/ProducerStateManager.scala
+++ /dev/null
@@ -1,620 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package kafka.log
-
-import kafka.server.{BrokerReconfigurable, KafkaConfig}
-import kafka.utils.{Logging, nonthreadsafe, threadsafe}
-import org.apache.kafka.common.TopicPartition
-import org.apache.kafka.common.config.ConfigException
-import org.apache.kafka.common.protocol.types._
-import org.apache.kafka.common.record.RecordBatch
-import org.apache.kafka.common.utils.{ByteUtils, Crc32C, Time}
-import org.apache.kafka.storage.internals.log.{AppendOrigin, BatchMetadata, CompletedTxn, CorruptSnapshotException, LogOffsetMetadata, ProducerAppendInfo, ProducerStateEntry, SnapshotFile, TxnMetadata}
-
-import java.io.File
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-import java.nio.file.{Files, NoSuchFileException, StandardOpenOption}
-import java.util.{Optional, OptionalLong}
-import java.util.concurrent.ConcurrentSkipListMap
-import scala.collection.{immutable, mutable}
-import scala.jdk.CollectionConverters._
-
-object ProducerStateManager {
-  val LateTransactionBufferMs = 5 * 60 * 1000
-
-  private val ProducerSnapshotVersion: Short = 1
-  private val VersionField = "version"
-  private val CrcField = "crc"
-  private val ProducerIdField = "producer_id"
-  private val LastSequenceField = "last_sequence"
-  private val ProducerEpochField = "epoch"
-  private val LastOffsetField = "last_offset"
-  private val OffsetDeltaField = "offset_delta"
-  private val TimestampField = "timestamp"
-  private val ProducerEntriesField = "producer_entries"
-  private val CoordinatorEpochField = "coordinator_epoch"
-  private val CurrentTxnFirstOffsetField = "current_txn_first_offset"
-
-  private val VersionOffset = 0
-  private val CrcOffset = VersionOffset + 2
-  private val ProducerEntriesOffset = CrcOffset + 4
-
-  val ProducerSnapshotEntrySchema = new Schema(
-    new Field(ProducerIdField, Type.INT64, "The producer ID"),
-    new Field(ProducerEpochField, Type.INT16, "Current epoch of the producer"),
-    new Field(LastSequenceField, Type.INT32, "Last written sequence of the producer"),
-    new Field(LastOffsetField, Type.INT64, "Last written offset of the producer"),
-    new Field(OffsetDeltaField, Type.INT32, "The difference of the last sequence and first sequence in the last written batch"),
-    new Field(TimestampField, Type.INT64, "Max timestamp from the last written entry"),
-    new Field(CoordinatorEpochField, Type.INT32, "The epoch of the last transaction coordinator to send an end transaction marker"),
-    new Field(CurrentTxnFirstOffsetField, Type.INT64, "The first offset of the on-going transaction (-1 if there is none)"))
-  val PidSnapshotMapSchema = new Schema(
-    new Field(VersionField, Type.INT16, "Version of the snapshot file"),
-    new Field(CrcField, Type.UNSIGNED_INT32, "CRC of the snapshot data"),
-    new Field(ProducerEntriesField, new ArrayOf(ProducerSnapshotEntrySchema), "The entries in the producer table"))
-
-  def readSnapshot(file: File): Iterable[ProducerStateEntry] = {
-    try {
-      val buffer = Files.readAllBytes(file.toPath)
-      val struct = PidSnapshotMapSchema.read(ByteBuffer.wrap(buffer))
-
-      val version = struct.getShort(VersionField)
-      if (version != ProducerSnapshotVersion)
-        throw new CorruptSnapshotException(s"Snapshot contained an unknown file version $version")
-
-      val crc = struct.getUnsignedInt(CrcField)
-      val computedCrc =  Crc32C.compute(buffer, ProducerEntriesOffset, buffer.length - ProducerEntriesOffset)
-      if (crc != computedCrc)
-        throw new CorruptSnapshotException(s"Snapshot is corrupt (CRC is no longer valid). " +
-          s"Stored crc: $crc. Computed crc: $computedCrc")
-
-      struct.getArray(ProducerEntriesField).map { producerEntryObj =>
-        val producerEntryStruct = producerEntryObj.asInstanceOf[Struct]
-        val producerId = producerEntryStruct.getLong(ProducerIdField)
-        val producerEpoch = producerEntryStruct.getShort(ProducerEpochField)
-        val seq = producerEntryStruct.getInt(LastSequenceField)
-        val offset = producerEntryStruct.getLong(LastOffsetField)
-        val timestamp = producerEntryStruct.getLong(TimestampField)
-        val offsetDelta = producerEntryStruct.getInt(OffsetDeltaField)
-        val coordinatorEpoch = producerEntryStruct.getInt(CoordinatorEpochField)
-        val currentTxnFirstOffset = producerEntryStruct.getLong(CurrentTxnFirstOffsetField)
-        val batchMetadata =
-          if (offset >= 0) Optional.of(new BatchMetadata(seq, offset, offsetDelta, timestamp))
-          else Optional.empty[BatchMetadata]()
-        val currentTxnFirstOffsetValue = if (currentTxnFirstOffset >= 0) OptionalLong.of(currentTxnFirstOffset) else OptionalLong.empty()
-        new ProducerStateEntry(producerId, producerEpoch, coordinatorEpoch, timestamp, currentTxnFirstOffsetValue, batchMetadata)
-      }
-    } catch {
-      case e: SchemaException =>
-        throw new CorruptSnapshotException(s"Snapshot failed schema validation: ${e.getMessage}")
-    }
-  }
-
-  private def writeSnapshot(file: File, entries: mutable.Map[Long, ProducerStateEntry]): Unit = {
-    val struct = new Struct(PidSnapshotMapSchema)
-    struct.set(VersionField, ProducerSnapshotVersion)
-    struct.set(CrcField, 0L) // we'll fill this after writing the entries
-    val entriesArray = entries.map {
-      case (producerId, entry) =>
-        val producerEntryStruct = struct.instance(ProducerEntriesField)
-        producerEntryStruct.set(ProducerIdField, producerId)
-          .set(ProducerEpochField, entry.producerEpoch)
-          .set(LastSequenceField, entry.lastSeq)
-          .set(LastOffsetField, entry.lastDataOffset)
-          .set(OffsetDeltaField, entry.lastOffsetDelta)
-          .set(TimestampField, entry.lastTimestamp)
-          .set(CoordinatorEpochField, entry.coordinatorEpoch)
-          .set(CurrentTxnFirstOffsetField, entry.currentTxnFirstOffset.orElse(-1L))
-        producerEntryStruct
-    }.toArray
-    struct.set(ProducerEntriesField, entriesArray)
-
-    val buffer = ByteBuffer.allocate(struct.sizeOf)
-    struct.writeTo(buffer)
-    buffer.flip()
-
-    // now fill in the CRC
-    val crc = Crc32C.compute(buffer, ProducerEntriesOffset, buffer.limit() - ProducerEntriesOffset)
-    ByteUtils.writeUnsignedInt(buffer, CrcOffset, crc)
-
-    val fileChannel = FileChannel.open(file.toPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE)
-    try {
-      fileChannel.write(buffer)
-      fileChannel.force(true)
-    } finally {
-      fileChannel.close()
-    }
-  }
-
-  private def isSnapshotFile(file: File): Boolean = file.getName.endsWith(UnifiedLog.ProducerSnapshotFileSuffix)
-
-  // visible for testing
-  private[log] def listSnapshotFiles(dir: File): Seq[SnapshotFile] = {
-    if (dir.exists && dir.isDirectory) {
-      Option(dir.listFiles).map { files =>
-        files.filter(f => f.isFile && isSnapshotFile(f)).map(new SnapshotFile(_)).toSeq
-      }.getOrElse(Seq.empty)
-    } else Seq.empty
-  }
-}
-
-/**
- * Maintains a mapping from ProducerIds to metadata about the last appended entries (e.g.
- * epoch, sequence number, last offset, etc.)
- *
- * The sequence number is the last number successfully appended to the partition for the given identifier.
- * The epoch is used for fencing against zombie writers. The offset is the one of the last successful message
- * appended to the partition.
- *
- * As long as a producer id is contained in the map, the corresponding producer can continue to write data.
- * However, producer ids can be expired due to lack of recent use or if the last written entry has been deleted from
- * the log (e.g. if the retention policy is "delete"). For compacted topics, the log cleaner will ensure
- * that the most recent entry from a given producer id is retained in the log provided it hasn't expired due to
- * age. This ensures that producer ids will not be expired until either the max expiration time has been reached,
- * or if the topic also is configured for deletion, the segment containing the last written offset has
- * been deleted.
- */
-@nonthreadsafe
-class ProducerStateManager(
-  val topicPartition: TopicPartition,
-  @volatile var _logDir: File,
-  val maxTransactionTimeoutMs: Int,
-  val producerStateManagerConfig: ProducerStateManagerConfig,
-  val time: Time
-) extends Logging {
-  import ProducerStateManager._
-  import java.util
-
-  this.logIdent = s"[ProducerStateManager partition=$topicPartition] "
-
-  private var snapshots: ConcurrentSkipListMap[java.lang.Long, SnapshotFile] = locally {
-    loadSnapshots()
-  }
-
-  private val producers = mutable.Map.empty[Long, ProducerStateEntry]
-  private var lastMapOffset = 0L
-  private var lastSnapOffset = 0L
-
-  // Keep track of the last timestamp from the oldest transaction. This is used
-  // to detect (approximately) when a transaction has been left hanging on a partition.
-  // We make the field volatile so that it can be safely accessed without a lock.
-  @volatile private var oldestTxnLastTimestamp: Long = -1L
-
-  // ongoing transactions sorted by the first offset of the transaction
-  private val ongoingTxns = new util.TreeMap[Long, TxnMetadata]
-
-  // completed transactions whose markers are at offsets above the high watermark
-  private val unreplicatedTxns = new util.TreeMap[Long, TxnMetadata]
-
-  @threadsafe
-  def hasLateTransaction(currentTimeMs: Long): Boolean = {
-    val lastTimestamp = oldestTxnLastTimestamp
-    lastTimestamp > 0 && (currentTimeMs - lastTimestamp) > maxTransactionTimeoutMs + ProducerStateManager.LateTransactionBufferMs
-  }
-
-  def truncateFullyAndReloadSnapshots(): Unit = {
-    info("Reloading the producer state snapshots")
-    truncateFullyAndStartAt(0L)
-    snapshots = loadSnapshots()
-  }
-
-  /**
-   * Load producer state snapshots by scanning the _logDir.
-   */
-  private def loadSnapshots(): ConcurrentSkipListMap[java.lang.Long, SnapshotFile] = {
-    val offsetToSnapshots = new ConcurrentSkipListMap[java.lang.Long, SnapshotFile]()
-    for (snapshotFile <- listSnapshotFiles(_logDir)) {
-      offsetToSnapshots.put(snapshotFile.offset, snapshotFile)
-    }
-    offsetToSnapshots
-  }
-
-  /**
-   * Scans the log directory, gathering all producer state snapshot files. Snapshot files which do not have an offset
-   * corresponding to one of the provided offsets in segmentBaseOffsets will be removed, except in the case that there
-   * is a snapshot file at a higher offset than any offset in segmentBaseOffsets.
-   *
-   * The goal here is to remove any snapshot files which do not have an associated segment file, but not to remove the
-   * largest stray snapshot file which was emitted during clean shutdown.
-   */
-  private[log] def removeStraySnapshots(segmentBaseOffsets: Seq[Long]): Unit = {
-    val maxSegmentBaseOffset = if (segmentBaseOffsets.isEmpty) None else Some(segmentBaseOffsets.max)
-    val baseOffsets = segmentBaseOffsets.toSet
-    var latestStraySnapshot: Option[SnapshotFile] = None
-
-    val ss = loadSnapshots()
-    for (snapshot <- ss.values().asScala) {
-      val key = snapshot.offset
-      latestStraySnapshot match {
-        case Some(prev) =>
-          if (!baseOffsets.contains(key)) {
-            // this snapshot is now the largest stray snapshot.
-            prev.deleteIfExists()
-            ss.remove(prev.offset)
-            latestStraySnapshot = Some(snapshot)
-          }
-        case None =>
-          if (!baseOffsets.contains(key)) {
-            latestStraySnapshot = Some(snapshot)
-          }
-      }
-    }
-
-    // Check to see if the latestStraySnapshot is larger than the largest segment base offset, if it is not,
-    // delete the largestStraySnapshot.
-    for (strayOffset <- latestStraySnapshot.map(_.offset); maxOffset <- maxSegmentBaseOffset) {
-      if (strayOffset < maxOffset) {
-        Option(ss.remove(strayOffset)).foreach(_.deleteIfExists())
-      }
-    }
-
-    this.snapshots = ss
-  }
-
-  /**
-   * An unstable offset is one which is either undecided (i.e. its ultimate outcome is not yet known),
-   * or one that is decided, but may not have been replicated (i.e. any transaction which has a COMMIT/ABORT
-   * marker written at a higher offset than the current high watermark).
-   */
-  def firstUnstableOffset: Option[LogOffsetMetadata] = {
-    val unreplicatedFirstOffset = Option(unreplicatedTxns.firstEntry).map(_.getValue.firstOffset)
-    val undecidedFirstOffset = Option(ongoingTxns.firstEntry).map(_.getValue.firstOffset)
-    if (unreplicatedFirstOffset.isEmpty)
-      undecidedFirstOffset
-    else if (undecidedFirstOffset.isEmpty)
-      unreplicatedFirstOffset
-    else if (undecidedFirstOffset.get.messageOffset < unreplicatedFirstOffset.get.messageOffset)
-      undecidedFirstOffset
-    else
-      unreplicatedFirstOffset
-  }
-
-  /**
-   * Acknowledge all transactions which have been completed before a given offset. This allows the LSO
-   * to advance to the next unstable offset.
-   */
-  def onHighWatermarkUpdated(highWatermark: Long): Unit = {
-    removeUnreplicatedTransactions(highWatermark)
-  }
-
-  /**
-   * The first undecided offset is the earliest transactional message which has not yet been committed
-   * or aborted. Unlike [[firstUnstableOffset]], this does not reflect the state of replication (i.e.
-   * whether a completed transaction marker is beyond the high watermark).
-   */
-  private[log] def firstUndecidedOffset: Option[Long] = Option(ongoingTxns.firstEntry).map(_.getValue.firstOffset.messageOffset)
-
-  /**
-   * Returns the last offset of this map
-   */
-  def mapEndOffset: Long = lastMapOffset
-
-  /**
-   * Get a copy of the active producers
-   */
-  def activeProducers: immutable.Map[Long, ProducerStateEntry] = producers.toMap
-
-  def isEmpty: Boolean = producers.isEmpty && unreplicatedTxns.isEmpty
-
-  private def loadFromSnapshot(logStartOffset: Long, currentTime: Long): Unit = {
-    while (true) {
-      latestSnapshotFile match {
-        case Some(snapshot) =>
-          try {
-            info(s"Loading producer state from snapshot file '$snapshot'")
-            val loadedProducers = readSnapshot(snapshot.file).filter { producerEntry => !isProducerExpired(currentTime, producerEntry) }
-            loadedProducers.foreach(loadProducerEntry)
-            lastSnapOffset = snapshot.offset
-            lastMapOffset = lastSnapOffset
-            updateOldestTxnTimestamp()
-            return
-          } catch {
-            case e: CorruptSnapshotException =>
-              warn(s"Failed to load producer snapshot from '${snapshot.file}': ${e.getMessage}")
-              removeAndDeleteSnapshot(snapshot.offset)
-          }
-        case None =>
-          lastSnapOffset = logStartOffset
-          lastMapOffset = logStartOffset
-          return
-      }
-    }
-  }
-
-  // visible for testing
-  private[log] def loadProducerEntry(entry: ProducerStateEntry): Unit = {
-    val producerId = entry.producerId
-    producers.put(producerId, entry)
-    entry.currentTxnFirstOffset.ifPresent((offset: Long) => ongoingTxns.put(offset, new TxnMetadata(producerId, offset)))
-  }
-
-  private def isProducerExpired(currentTimeMs: Long, producerState: ProducerStateEntry): Boolean =
-    !producerState.currentTxnFirstOffset.isPresent && currentTimeMs - producerState.lastTimestamp >= producerStateManagerConfig.producerIdExpirationMs
-
-  /**
-   * Expire any producer ids which have been idle longer than the configured maximum expiration timeout.
-   */
-  def removeExpiredProducers(currentTimeMs: Long): Unit = {
-    producers --= producers.filter { case (_, lastEntry) => isProducerExpired(currentTimeMs, lastEntry) }.keySet
-  }
-
-  /**
-   * Truncate the producer id mapping to the given offset range and reload the entries from the most recent
-   * snapshot in range (if there is one). We delete snapshot files prior to the logStartOffset but do not remove
-   * producer state from the map. This means that in-memory and on-disk state can diverge, and in the case of
-   * broker failover or unclean shutdown, any in-memory state not persisted in the snapshots will be lost, which
-   * would lead to UNKNOWN_PRODUCER_ID errors. Note that the log end offset is assumed to be less than or equal
-   * to the high watermark.
-   */
-  def truncateAndReload(logStartOffset: Long, logEndOffset: Long, currentTimeMs: Long): Unit = {
-    // remove all out of range snapshots
-    snapshots.values().asScala.foreach { snapshot =>
-      if (snapshot.offset > logEndOffset || snapshot.offset <= logStartOffset) {
-        removeAndDeleteSnapshot(snapshot.offset)
-      }
-    }
-
-    if (logEndOffset != mapEndOffset) {
-      producers.clear()
-      ongoingTxns.clear()
-      updateOldestTxnTimestamp()
-
-      // since we assume that the offset is less than or equal to the high watermark, it is
-      // safe to clear the unreplicated transactions
-      unreplicatedTxns.clear()
-      loadFromSnapshot(logStartOffset, currentTimeMs)
-    } else {
-      onLogStartOffsetIncremented(logStartOffset)
-    }
-  }
-
-  def prepareUpdate(producerId: Long, origin: AppendOrigin): ProducerAppendInfo = {
-    val currentEntry = lastEntry(producerId).getOrElse(ProducerStateEntry.empty(producerId))
-    new ProducerAppendInfo(topicPartition, producerId, currentEntry, origin)
-  }
-
-  /**
-   * Update the mapping with the given append information
-   */
-  def update(appendInfo: ProducerAppendInfo): Unit = {
-    if (appendInfo.producerId() == RecordBatch.NO_PRODUCER_ID)
-      throw new IllegalArgumentException(s"Invalid producer id ${appendInfo.producerId()} passed to update " +
-        s"for partition $topicPartition")
-
-    trace(s"Updated producer ${appendInfo.producerId} state to $appendInfo")
-    val updatedEntry = appendInfo.toEntry
-    producers.get(appendInfo.producerId) match {
-      case Some(currentEntry) =>
-        currentEntry.update(updatedEntry)
-
-      case None =>
-        producers.put(appendInfo.producerId, updatedEntry)
-    }
-
-    appendInfo.startedTransactions.asScala.foreach { txn =>
-      ongoingTxns.put(txn.firstOffset.messageOffset, txn)
-    }
-
-    updateOldestTxnTimestamp()
-  }
-
-  private def updateOldestTxnTimestamp(): Unit = {
-    val firstEntry = ongoingTxns.firstEntry()
-    if (firstEntry == null) {
-      oldestTxnLastTimestamp = -1
-    } else {
-      val oldestTxnMetadata = firstEntry.getValue
-      oldestTxnLastTimestamp = producers.get(oldestTxnMetadata.producerId)
-        .map(_.lastTimestamp)
-        .getOrElse(-1L)
-    }
-  }
-
-  def updateMapEndOffset(lastOffset: Long): Unit = {
-    lastMapOffset = lastOffset
-  }
-
-  /**
-   * Get the last written entry for the given producer id.
-   */
-  def lastEntry(producerId: Long): Option[ProducerStateEntry] = producers.get(producerId)
-
-  /**
-   * Take a snapshot at the current end offset if one does not already exist.
-   */
-  def takeSnapshot(): Unit = {
-    // If not a new offset, then it is not worth taking another snapshot
-    if (lastMapOffset > lastSnapOffset) {
-      val snapshotFile = new SnapshotFile(UnifiedLog.producerSnapshotFile(_logDir, lastMapOffset))
-      val start = time.hiResClockMs()
-      writeSnapshot(snapshotFile.file, producers)
-      info(s"Wrote producer snapshot at offset $lastMapOffset with ${producers.size} producer ids in ${time.hiResClockMs() - start} ms.")
-
-      snapshots.put(snapshotFile.offset, snapshotFile)
-
-      // Update the last snap offset according to the serialized map
-      lastSnapOffset = lastMapOffset
-    }
-  }
-
-  /**
-   * Update the parentDir for this ProducerStateManager and all of the snapshot files which it manages.
-   */
-  def updateParentDir(parentDir: File): Unit = {
-    _logDir = parentDir
-    snapshots.forEach((_, s) => s.updateParentDir(parentDir))
-  }
-
-  /**
-   * Get the last offset (exclusive) of the latest snapshot file.
-   */
-  def latestSnapshotOffset: Option[Long] = latestSnapshotFile.map(_.offset)
-
-  /**
-   * Get the last offset (exclusive) of the oldest snapshot file.
-   */
-  def oldestSnapshotOffset: Option[Long] = oldestSnapshotFile.map(_.offset)
-
-  /**
-   * Visible for testing
-   */
-  private[log] def snapshotFileForOffset(offset: Long): Option[SnapshotFile] = {
-    Option(snapshots.get(offset))
-  }
-
-  /**
-   * Remove any unreplicated transactions lower than the provided logStartOffset and bring the lastMapOffset forward
-   * if necessary.
-   */
-  def onLogStartOffsetIncremented(logStartOffset: Long): Unit = {
-    removeUnreplicatedTransactions(logStartOffset)
-
-    if (lastMapOffset < logStartOffset)
-      lastMapOffset = logStartOffset
-
-    lastSnapOffset = latestSnapshotOffset.getOrElse(logStartOffset)
-  }
-
-  private def removeUnreplicatedTransactions(offset: Long): Unit = {
-    val iterator = unreplicatedTxns.entrySet.iterator
-    while (iterator.hasNext) {
-      val txnEntry = iterator.next()
-      val lastOffset = txnEntry.getValue.lastOffset
-      if (lastOffset.isPresent && lastOffset.getAsLong < offset)
-        iterator.remove()
-    }
-  }
-
-  /**
-   * Truncate the producer id mapping and remove all snapshots. This resets the state of the mapping.
-   */
-  def truncateFullyAndStartAt(offset: Long): Unit = {
-    producers.clear()
-    ongoingTxns.clear()
-    unreplicatedTxns.clear()
-    snapshots.values().asScala.foreach { snapshot =>
-      removeAndDeleteSnapshot(snapshot.offset)
-    }
-    lastSnapOffset = 0L
-    lastMapOffset = offset
-    updateOldestTxnTimestamp()
-  }
-
-  /**
-   * Compute the last stable offset of a completed transaction, but do not yet mark the transaction complete.
-   * That will be done in `completeTxn` below. This is used to compute the LSO that will be appended to the
-   * transaction index, but the completion must be done only after successfully appending to the index.
-   */
-  def lastStableOffset(completedTxn: CompletedTxn): Long = {
-    val nextIncompleteTxn = ongoingTxns.values.asScala.find(_.producerId != completedTxn.producerId)
-    nextIncompleteTxn.map(_.firstOffset.messageOffset).getOrElse(completedTxn.lastOffset  + 1)
-  }
-
-  /**
-   * Mark a transaction as completed. We will still await advancement of the high watermark before
-   * advancing the first unstable offset.
-   */
-  def completeTxn(completedTxn: CompletedTxn): Unit = {
-    val txnMetadata = ongoingTxns.remove(completedTxn.firstOffset)
-    if (txnMetadata == null)
-      throw new IllegalArgumentException(s"Attempted to complete transaction $completedTxn on partition $topicPartition " +
-        s"which was not started")
-
-    txnMetadata.lastOffset = OptionalLong.of(completedTxn.lastOffset)
-    unreplicatedTxns.put(completedTxn.firstOffset, txnMetadata)
-    updateOldestTxnTimestamp()
-  }
-
-  @threadsafe
-  def deleteSnapshotsBefore(offset: Long): Unit = {
-    snapshots.subMap(0, offset).values().asScala.foreach { snapshot =>
-      removeAndDeleteSnapshot(snapshot.offset)
-    }
-  }
-
-  private def oldestSnapshotFile: Option[SnapshotFile] = {
-    Option(snapshots.firstEntry()).map(_.getValue)
-  }
-
-  private def latestSnapshotFile: Option[SnapshotFile] = {
-    Option(snapshots.lastEntry()).map(_.getValue)
-  }
-
-  /**
-   * Removes the producer state snapshot file metadata corresponding to the provided offset if it exists from this
-   * ProducerStateManager, and deletes the backing snapshot file.
-   */
-  private def removeAndDeleteSnapshot(snapshotOffset: Long): Unit = {
-    Option(snapshots.remove(snapshotOffset)).foreach(_.deleteIfExists())
-  }
-
-  /**
-   * Removes the producer state snapshot file metadata corresponding to the provided offset if it exists from this
-   * ProducerStateManager, and renames the backing snapshot file to have the Log.DeletionSuffix.
-   *
-   * Note: This method is safe to use with async deletes. If a race occurs and the snapshot file
-   *       is deleted without this ProducerStateManager instance knowing, the resulting exception on
-   *       SnapshotFile rename will be ignored and None will be returned.
-   */
-  private[log] def removeAndMarkSnapshotForDeletion(snapshotOffset: Long): Option[SnapshotFile] = {
-    Option(snapshots.remove(snapshotOffset)).flatMap { snapshot => {
-      // If the file cannot be renamed, it likely means that the file was deleted already.
-      // This can happen due to the way we construct an intermediate producer state manager
-      // during log recovery, and use it to issue deletions prior to creating the "real"
-      // producer state manager.
-      //
-      // In any case, removeAndMarkSnapshotForDeletion is intended to be used for snapshot file
-      // deletion, so ignoring the exception here just means that the intended operation was
-      // already completed.
-      try {
-        snapshot.renameTo(UnifiedLog.DeletedFileSuffix)
-        Some(snapshot)
-      } catch {
-        case _: NoSuchFileException =>
-          info(s"Failed to rename producer state snapshot ${snapshot.file.getAbsoluteFile} with deletion suffix because it was already deleted")
-          None
-      }
-    }
-    }
-  }
-}
-
-
-
-
-
-class ProducerStateManagerConfig(@volatile var producerIdExpirationMs: Int) extends Logging with BrokerReconfigurable {
-
-  override def reconfigurableConfigs: Set[String] = ProducerStateManagerConfig.ReconfigurableConfigs
-
-  override def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit = {
-    if (producerIdExpirationMs != newConfig.producerIdExpirationMs) {
-      info(s"Reconfigure ${KafkaConfig.ProducerIdExpirationMsProp} from $producerIdExpirationMs to ${newConfig.producerIdExpirationMs}")
-      producerIdExpirationMs = newConfig.producerIdExpirationMs
-    }
-  }
-
-  override def validateReconfiguration(newConfig: KafkaConfig): Unit = {
-    if (newConfig.producerIdExpirationMs < 0)
-      throw new ConfigException(s"${KafkaConfig.ProducerIdExpirationMsProp} cannot be less than 0, current value is $producerIdExpirationMs")
-  }
-}
-
-object ProducerStateManagerConfig {
-  val ReconfigurableConfigs = Set(KafkaConfig.ProducerIdExpirationMsProp)
-}
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index 95329cd0795..712774cc05a 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -45,7 +45,7 @@ import org.apache.kafka.server.record.BrokerCompressionType
 import org.apache.kafka.server.util.Scheduler
 import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, BatchMetadata, CompletedTxn, EpochEntry, FetchDataInfo, FetchIsolation, LastRecord, LogConfig, LogDirFailureChannel, LogOffsetMetadata, LogValidator, ProducerAppendInfo}
+import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, BatchMetadata, CompletedTxn, EpochEntry, FetchDataInfo, FetchIsolation, LastRecord, LogConfig, LogDirFailureChannel, LogOffsetMetadata, LogValidator, ProducerAppendInfo, ProducerStateManager, ProducerStateManagerConfig}
 
 import scala.annotation.nowarn
 import scala.collection.mutable.ListBuffer
@@ -667,7 +667,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
 
   def activeProducers: Seq[DescribeProducersResponseData.ProducerState] = {
     lock synchronized {
-      producerStateManager.activeProducers.map { case (producerId, state) =>
+      producerStateManager.activeProducers.asScala.map { case (producerId, state) =>
         new DescribeProducersResponseData.ProducerState()
           .setProducerId(producerId)
           .setProducerEpoch(state.producerEpoch)
@@ -679,20 +679,24 @@ class UnifiedLog(@volatile var logStartOffset: Long,
     }.toSeq
   }
 
-  private[log] def activeProducersWithLastSequence: Map[Long, Int] = lock synchronized {
-    producerStateManager.activeProducers.map { case (producerId, producerIdEntry) =>
-      (producerId, producerIdEntry.lastSeq)
+  private[log] def activeProducersWithLastSequence: mutable.Map[Long, Int] = lock synchronized {
+    val result = mutable.Map[Long, Int]()
+    producerStateManager.activeProducers.forEach { case (producerId, producerIdEntry) =>
+      result.put(producerId.toLong, producerIdEntry.lastSeq)
     }
+    result
   }
 
-  private[log] def lastRecordsOfActiveProducers: Map[Long, LastRecord] = lock synchronized {
-    producerStateManager.activeProducers.map { case (producerId, producerIdEntry) =>
-      val lastDataOffset =
-        if (producerIdEntry.lastDataOffset >= 0) OptionalLong.of(producerIdEntry.lastDataOffset)
-        else OptionalLong.empty()
-      val lastRecord = new LastRecord(lastDataOffset, producerIdEntry.producerEpoch)
-      producerId -> lastRecord
+  private[log] def lastRecordsOfActiveProducers: mutable.Map[Long, LastRecord] = lock synchronized {
+    val result = mutable.Map[Long, LastRecord]()
+    producerStateManager.activeProducers.forEach { case (producerId, producerIdEntry) =>
+      val lastDataOffset = if (producerIdEntry.lastDataOffset >= 0) Some(producerIdEntry.lastDataOffset) else None
+      val lastRecord = new LastRecord(
+        if (lastDataOffset.isEmpty) OptionalLong.empty() else OptionalLong.of(lastDataOffset.get),
+        producerIdEntry.producerEpoch)
+      result.put(producerId.toLong, lastRecord)
     }
+    result
   }
 
   /**
@@ -1023,7 +1027,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
   private def maybeIncrementFirstUnstableOffset(): Unit = lock synchronized {
     localLog.checkIfMemoryMappedBufferClosed()
 
-    val updatedFirstUnstableOffset = producerStateManager.firstUnstableOffset match {
+    val updatedFirstUnstableOffset = producerStateManager.firstUnstableOffset.asScala match {
       case Some(logOffsetMetadata) if logOffsetMetadata.messageOffsetOnly || logOffsetMetadata.messageOffset < logStartOffset =>
         val offset = math.max(logOffsetMetadata.messageOffset, logStartOffset)
         Some(convertToOffsetMetadataOrThrow(offset))
@@ -1088,8 +1092,9 @@ class UnifiedLog(@volatile var logStartOffset: Long,
         if (origin == AppendOrigin.CLIENT) {
           val maybeLastEntry = producerStateManager.lastEntry(batch.producerId)
 
-          maybeLastEntry.flatMap(_.findDuplicateBatch(batch).asScala).foreach { duplicate =>
-            return (updatedProducers, completedTxns.toList, Some(duplicate))
+          val duplicateBatch = maybeLastEntry.flatMap(_.findDuplicateBatch(batch))
+          if (duplicateBatch.isPresent) {
+            return (updatedProducers, completedTxns.toList, Some(duplicateBatch.get()))
           }
         }
 
@@ -1678,12 +1683,12 @@ class UnifiedLog(@volatile var logStartOffset: Long,
   }
 
   // visible for testing
-  private[log] def latestProducerSnapshotOffset: Option[Long] = lock synchronized {
+  private[log] def latestProducerSnapshotOffset: OptionalLong = lock synchronized {
     producerStateManager.latestSnapshotOffset
   }
 
   // visible for testing
-  private[log] def oldestProducerSnapshotOffset: Option[Long] = lock synchronized {
+  private[log] def oldestProducerSnapshotOffset: OptionalLong = lock synchronized {
     producerStateManager.oldestSnapshotOffset
   }
 
@@ -1856,12 +1861,8 @@ object UnifiedLog extends Logging {
 
   val TimeIndexFileSuffix = LocalLog.TimeIndexFileSuffix
 
-  val ProducerSnapshotFileSuffix = ".snapshot"
-
   val TxnIndexFileSuffix = LocalLog.TxnIndexFileSuffix
 
-  val DeletedFileSuffix = LocalLog.DeletedFileSuffix
-
   val CleanedFileSuffix = LocalLog.CleanedFileSuffix
 
   val SwapFileSuffix = LocalLog.SwapFileSuffix
@@ -1948,15 +1949,6 @@ object UnifiedLog extends Logging {
   def deleteFileIfExists(file: File, suffix: String = ""): Unit =
     Files.deleteIfExists(new File(file.getPath + suffix).toPath)
 
-  /**
-   * Construct a producer id snapshot file using the given offset.
-   *
-   * @param dir    The directory in which the log will reside
-   * @param offset The last offset (exclusive) included in the snapshot
-   */
-  def producerSnapshotFile(dir: File, offset: Long): File =
-    new File(dir, LocalLog.filenamePrefixFromOffset(offset) + ProducerSnapshotFileSuffix)
-
   def transactionIndexFile(dir: File, offset: Long, suffix: String = ""): File = LocalLog.transactionIndexFile(dir, offset, suffix)
 
   def offsetFromFile(file: File): Long = LocalLog.offsetFromFile(file)
@@ -2116,7 +2108,7 @@ object UnifiedLog extends Logging {
     // (or later snapshots). Otherwise, if there is no snapshot file, then we have to rebuild producer state
     // from the first segment.
     if (recordVersion.value < RecordBatch.MAGIC_VALUE_V2 ||
-      (producerStateManager.latestSnapshotOffset.isEmpty && reloadFromCleanShutdown)) {
+      (!producerStateManager.latestSnapshotOffset.isPresent && reloadFromCleanShutdown)) {
       // To avoid an expensive scan through all of the segments, we take empty snapshots from the start of the
       // last two segments and the last offset. This should avoid the full scan in the case that the log needs
       // truncation.
@@ -2189,7 +2181,7 @@ object UnifiedLog extends Logging {
                                            parentDir: String,
                                            topicPartition: TopicPartition): Unit = {
     val snapshotsToDelete = segments.flatMap { segment =>
-      producerStateManager.removeAndMarkSnapshotForDeletion(segment.baseOffset)
+      producerStateManager.removeAndMarkSnapshotForDeletion(segment.baseOffset).asScala
     }
 
     def deleteProducerSnapshots(): Unit = {
diff --git a/core/src/main/scala/kafka/log/remote/RemoteIndexCache.scala b/core/src/main/scala/kafka/log/remote/RemoteIndexCache.scala
index 9e8bddec9a1..8fa07cf2ae8 100644
--- a/core/src/main/scala/kafka/log/remote/RemoteIndexCache.scala
+++ b/core/src/main/scala/kafka/log/remote/RemoteIndexCache.scala
@@ -24,7 +24,7 @@ import org.apache.kafka.common.errors.CorruptRecordException
 import org.apache.kafka.common.utils.Utils
 import org.apache.kafka.server.log.remote.storage.RemoteStorageManager.IndexType
 import org.apache.kafka.server.log.remote.storage.{RemoteLogSegmentMetadata, RemoteStorageManager}
-import org.apache.kafka.storage.internals.log.{LazyIndex, OffsetIndex, OffsetPosition, TimeIndex, TransactionIndex}
+import org.apache.kafka.storage.internals.log.{LazyIndex, LogFileUtils, OffsetIndex, OffsetPosition, TimeIndex, TransactionIndex}
 
 import java.io.{Closeable, File, InputStream}
 import java.nio.file.{Files, Path}
@@ -62,9 +62,9 @@ class Entry(val offsetIndex: LazyIndex[OffsetIndex], val timeIndex: LazyIndex[Ti
       if (!markedForCleanup) {
         markedForCleanup = true
         Array(offsetIndex, timeIndex).foreach(index =>
-          index.renameTo(new File(Utils.replaceSuffix(index.file.getPath, "", UnifiedLog.DeletedFileSuffix))))
+          index.renameTo(new File(Utils.replaceSuffix(index.file.getPath, "", LogFileUtils.DELETED_FILE_SUFFIX))))
         txnIndex.renameTo(new File(Utils.replaceSuffix(txnIndex.file.getPath, "",
-          UnifiedLog.DeletedFileSuffix)))
+          LogFileUtils.DELETED_FILE_SUFFIX)))
       }
     }
   }
@@ -122,7 +122,7 @@ class RemoteIndexCache(maxSize: Int = 1024, remoteStorageManager: RemoteStorageM
 
     // Delete any .deleted files remained from the earlier run of the broker.
     Files.list(cacheDir.toPath).forEach((path: Path) => {
-      if (path.endsWith(UnifiedLog.DeletedFileSuffix)) {
+      if (path.endsWith(LogFileUtils.DELETED_FILE_SUFFIX)) {
         Files.deleteIfExists(path)
       }
     })
diff --git a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
index 81ace2489a9..2ed286c3d21 100644
--- a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
+++ b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
@@ -16,7 +16,7 @@
  */
 package kafka.raft
 
-import kafka.log.{LogOffsetSnapshot, ProducerStateManagerConfig, SnapshotGenerated, UnifiedLog}
+import kafka.log.{LogOffsetSnapshot, SnapshotGenerated, UnifiedLog}
 import kafka.server.KafkaConfig.{MetadataLogSegmentBytesProp, MetadataLogSegmentMinBytesProp}
 import kafka.server.{BrokerTopicStats, KafkaConfig, RequestLocal}
 import kafka.utils.{CoreUtils, Logging}
@@ -29,7 +29,7 @@ import org.apache.kafka.raft.{Isolation, KafkaRaftClient, LogAppendInfo, LogFetc
 import org.apache.kafka.server.util.Scheduler
 import org.apache.kafka.snapshot.{FileRawSnapshotReader, FileRawSnapshotWriter, RawSnapshotReader, RawSnapshotWriter, SnapshotPath, Snapshots}
 import org.apache.kafka.storage.internals
-import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchIsolation, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchIsolation, LogConfig, LogDirFailureChannel, ProducerStateManagerConfig}
 
 import java.io.File
 import java.nio.file.{Files, NoSuchFileException, Path}
diff --git a/core/src/main/scala/kafka/server/DynamicBrokerConfig.scala b/core/src/main/scala/kafka/server/DynamicBrokerConfig.scala
index e43754e72fc..b924648c691 100755
--- a/core/src/main/scala/kafka/server/DynamicBrokerConfig.scala
+++ b/core/src/main/scala/kafka/server/DynamicBrokerConfig.scala
@@ -22,7 +22,7 @@ import java.util.{Collections, Properties}
 import java.util.concurrent.CopyOnWriteArrayList
 import java.util.concurrent.locks.ReentrantReadWriteLock
 import kafka.cluster.EndPoint
-import kafka.log.{LogCleaner, LogManager, ProducerStateManagerConfig}
+import kafka.log.{LogCleaner, LogManager}
 import kafka.network.{DataPlaneAcceptor, SocketServer}
 import kafka.server.DynamicBrokerConfig._
 import kafka.utils.{CoreUtils, Logging, PasswordEncoder}
@@ -36,7 +36,7 @@ import org.apache.kafka.common.network.{ListenerName, ListenerReconfigurable}
 import org.apache.kafka.common.security.authenticator.LoginManager
 import org.apache.kafka.common.utils.{ConfigUtils, Utils}
 import org.apache.kafka.server.config.ServerTopicConfigSynonyms
-import org.apache.kafka.storage.internals.log.LogConfig
+import org.apache.kafka.storage.internals.log.{LogConfig, ProducerStateManagerConfig}
 
 import scala.annotation.nowarn
 import scala.collection._
@@ -59,7 +59,7 @@ import scala.jdk.CollectionConverters._
   * </ol>
   * Log configs use topic config overrides if defined and fallback to broker defaults using the order of precedence above.
   * Topic config overrides may use a different config name from the default broker config.
-  * See [[kafka.log.LogConfig#TopicConfigSynonyms]] for the mapping.
+  * See [[org.apache.kafka.storage.internals.log.LogConfig#TopicConfigSynonyms]] for the mapping.
   * <p>
   * AdminClient returns all config synonyms in the order of precedence when configs are described with
   * <code>includeSynonyms</code>. In addition to configs that may be defined with the same name at different levels,
@@ -87,7 +87,7 @@ object DynamicBrokerConfig {
     Set(KafkaConfig.MetricReporterClassesProp) ++
     DynamicListenerConfig.ReconfigurableConfigs ++
     SocketServer.ReconfigurableConfigs ++
-    ProducerStateManagerConfig.ReconfigurableConfigs
+    ProducerStateManagerConfig.RECONFIGURABLE_CONFIGS.asScala
 
   private val ClusterLevelListenerConfigs = Set(KafkaConfig.MaxConnectionsProp, KafkaConfig.MaxConnectionCreationRateProp, KafkaConfig.NumNetworkThreadsProp)
   private val PerBrokerConfigs = (DynamicSecurityConfigs ++ DynamicListenerConfig.ReconfigurableConfigs).diff(
@@ -268,7 +268,7 @@ class DynamicBrokerConfig(private val kafkaConfig: KafkaConfig) extends Logging
     addBrokerReconfigurable(new DynamicLogConfig(kafkaServer.logManager, kafkaServer))
     addBrokerReconfigurable(new DynamicListenerConfig(kafkaServer))
     addBrokerReconfigurable(kafkaServer.socketServer)
-    addBrokerReconfigurable(kafkaServer.logManager.producerStateManagerConfig)
+    addBrokerReconfigurable(new DynamicProducerStateManagerConfig(kafkaServer.logManager.producerStateManagerConfig))
   }
 
   def addReconfigurable(reconfigurable: Reconfigurable): Unit = CoreUtils.inWriteLock(lock) {
@@ -1031,3 +1031,20 @@ class DynamicListenerConfig(server: KafkaBroker) extends BrokerReconfigurable wi
     listeners.map(e => (e.listenerName, e)).toMap
 
 }
+
+class DynamicProducerStateManagerConfig(val producerStateManagerConfig: ProducerStateManagerConfig) extends BrokerReconfigurable with Logging {
+  def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit = {
+    if (producerStateManagerConfig.producerIdExpirationMs() != newConfig.producerIdExpirationMs) {
+      info(s"Reconfigure ${KafkaConfig.ProducerIdExpirationMsProp} from ${producerStateManagerConfig.producerIdExpirationMs()} to ${newConfig.producerIdExpirationMs}")
+      producerStateManagerConfig.setProducerIdExpirationMs(newConfig.producerIdExpirationMs)
+    }
+  }
+
+  def validateReconfiguration(newConfig: KafkaConfig): Unit = {
+    if (newConfig.producerIdExpirationMs < 0)
+      throw new ConfigException(s"${KafkaConfig.ProducerIdExpirationMsProp} cannot be less than 0, current value is ${producerStateManagerConfig.producerIdExpirationMs}, and new value is ${newConfig.producerIdExpirationMs}")
+  }
+
+  override def reconfigurableConfigs: Set[String] = ProducerStateManagerConfig.RECONFIGURABLE_CONFIGS.asScala
+
+}
diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala
index a37868675fe..1130cb5a12a 100755
--- a/core/src/main/scala/kafka/server/KafkaConfig.scala
+++ b/core/src/main/scala/kafka/server/KafkaConfig.scala
@@ -47,10 +47,10 @@ import org.apache.kafka.server.authorizer.Authorizer
 import org.apache.kafka.server.common.{MetadataVersion, MetadataVersionValidator}
 import org.apache.kafka.server.common.MetadataVersion._
 import org.apache.kafka.server.config.ServerTopicConfigSynonyms
+import org.apache.kafka.storage.internals.log.{LogConfig, ProducerStateManagerConfig}
 import org.apache.kafka.storage.internals.log.LogConfig.MessageFormatVersion
 import org.apache.kafka.server.log.remote.storage.RemoteLogManagerConfig
 import org.apache.kafka.server.record.BrokerCompressionType
-import org.apache.kafka.storage.internals.log.LogConfig
 import org.apache.zookeeper.client.ZKClientConfig
 
 import scala.annotation.nowarn
@@ -510,7 +510,7 @@ object KafkaConfig {
   val TransactionsAbortTimedOutTransactionCleanupIntervalMsProp = "transaction.abort.timed.out.transaction.cleanup.interval.ms"
   val TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp = "transaction.remove.expired.transaction.cleanup.interval.ms"
 
-  val ProducerIdExpirationMsProp = "producer.id.expiration.ms"
+  val ProducerIdExpirationMsProp = ProducerStateManagerConfig.PRODUCER_ID_EXPIRATION_MS
   val ProducerIdExpirationCheckIntervalMsProp = "producer.id.expiration.check.interval.ms"
 
   /** ********* Fetch Configuration **************/
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
index f68be09f8f8..832ecd42c99 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
@@ -18,7 +18,7 @@
 package kafka.server
 
 import kafka.log.remote.RemoteLogManager
-import kafka.log.{LeaderOffsetIncremented, LogAppendInfo, UnifiedLog}
+import kafka.log.{LeaderOffsetIncremented, LogAppendInfo}
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.MemoryRecords
@@ -29,7 +29,7 @@ import org.apache.kafka.server.common.CheckpointFile.CheckpointReadBuffer
 import org.apache.kafka.server.common.MetadataVersion
 import org.apache.kafka.server.log.remote.storage.{RemoteLogSegmentMetadata, RemoteStorageException, RemoteStorageManager}
 import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
-import org.apache.kafka.storage.internals.log.EpochEntry
+import org.apache.kafka.storage.internals.log.{EpochEntry, LogFileUtils}
 
 import java.io.{BufferedReader, File, InputStreamReader}
 import java.nio.charset.StandardCharsets
@@ -299,7 +299,7 @@ class ReplicaFetcherThread(name: String,
             s"with size: ${epochs.size} for $partition")
 
           // Restore producer snapshot
-          val snapshotFile = UnifiedLog.producerSnapshotFile(log.dir, nextOffset)
+          val snapshotFile = LogFileUtils.producerSnapshotFile(log.dir, nextOffset)
           buildProducerSnapshotFile(snapshotFile, remoteLogSegmentMetadata, rlm)
 
           // Reload producer snapshots.
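
With the change above, the follower resolves the snapshot path through the storage-module helper instead of UnifiedLog. A small sketch of the call, using a made-up log directory and offset:

    import java.io.File
    import org.apache.kafka.storage.internals.log.LogFileUtils

    val partitionDir = new File("/tmp/kafka-logs/my-topic-0") // illustrative path
    val nextOffset = 42L
    // Resolve the producer snapshot file for this offset inside the partition directory.
    val snapshotFile: File = LogFileUtils.producerSnapshotFile(partitionDir, nextOffset)
    // The name is the offset zero-padded to 20 digits plus the ".snapshot" suffix,
    // e.g. 00000000000000000042.snapshot
    println(snapshotFile.getName)
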
diff --git a/core/src/main/scala/kafka/tools/DumpLogSegments.scala b/core/src/main/scala/kafka/tools/DumpLogSegments.scala
index 9319b72dc9e..bb773d86ade 100755
--- a/core/src/main/scala/kafka/tools/DumpLogSegments.scala
+++ b/core/src/main/scala/kafka/tools/DumpLogSegments.scala
@@ -34,7 +34,7 @@ import org.apache.kafka.metadata.MetadataRecordSerde
 import org.apache.kafka.metadata.bootstrap.BootstrapDirectory
 import org.apache.kafka.snapshot.Snapshots
 import org.apache.kafka.server.util.{CommandDefaultOptions, CommandLineUtils}
-import org.apache.kafka.storage.internals.log.{CorruptSnapshotException, OffsetIndex, TimeIndex, TransactionIndex}
+import org.apache.kafka.storage.internals.log.{CorruptSnapshotException, LogFileUtils, OffsetIndex, ProducerStateManager, TimeIndex, TransactionIndex}
 
 import scala.jdk.CollectionConverters._
 import scala.collection.mutable
@@ -68,7 +68,7 @@ object DumpLogSegments {
           dumpIndex(file, opts.indexSanityOnly, opts.verifyOnly, misMatchesForIndexFilesMap, opts.maxMessageSize)
         case UnifiedLog.TimeIndexFileSuffix =>
           dumpTimeIndex(file, opts.indexSanityOnly, opts.verifyOnly, timeIndexDumpErrors)
-        case UnifiedLog.ProducerSnapshotFileSuffix =>
+        case LogFileUtils.PRODUCER_SNAPSHOT_FILE_SUFFIX =>
           dumpProducerIdSnapshot(file)
         case UnifiedLog.TxnIndexFileSuffix =>
           dumpTxnIndex(file)
@@ -104,7 +104,7 @@ object DumpLogSegments {
 
   private def dumpProducerIdSnapshot(file: File): Unit = {
     try {
-      ProducerStateManager.readSnapshot(file).foreach { entry =>
+      ProducerStateManager.readSnapshot(file).forEach { entry =>
         print(s"producerId: ${entry.producerId} producerEpoch: ${entry.producerEpoch} " +
           s"coordinatorEpoch: ${entry.coordinatorEpoch} currentTxnFirstOffset: ${entry.currentTxnFirstOffset} " +
           s"lastTimestamp: ${entry.lastTimestamp} ")
diff --git a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
index c4d676a3ce7..0a1f07dbc13 100644
--- a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
+++ b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala
@@ -651,8 +651,8 @@ class TransactionsTest extends IntegrationTestHarness {
       producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(testTopic, 0, "4", "4", willBeCommitted = true))
       producer.commitTransaction()
 
-      var producerStateEntry =
-        brokers(partitionLeader).logManager.getLog(new TopicPartition(testTopic, 0)).get.producerStateManager.activeProducers.head._2
+      var producerStateEntry = brokers(partitionLeader).logManager.getLog(new TopicPartition(testTopic, 0)).get
+        .producerStateManager.activeProducers.entrySet().iterator().next().getValue
       val producerId = producerStateEntry.producerId
       val initialProducerEpoch = producerStateEntry.producerEpoch
 
@@ -685,7 +685,7 @@ class TransactionsTest extends IntegrationTestHarness {
       // get here without having bumped the epoch. If bumping the epoch is possible, the producer will attempt to, so
       // check there that the epoch has actually increased
       producerStateEntry =
-        brokers(partitionLeader).logManager.getLog(new TopicPartition(testTopic, 0)).get.producerStateManager.activeProducers(producerId)
+        brokers(partitionLeader).logManager.getLog(new TopicPartition(testTopic, 0)).get.producerStateManager.activeProducers.get(producerId)
       assertTrue(producerStateEntry.producerEpoch > initialProducerEpoch)
     } finally {
       producer.close(Duration.ZERO)
@@ -706,7 +706,8 @@ class TransactionsTest extends IntegrationTestHarness {
 
     val partitionLeader = TestUtils.waitUntilLeaderIsKnown(brokers, new TopicPartition(topic1, 0))
     var producerStateEntry =
-      brokers(partitionLeader).logManager.getLog(new TopicPartition(topic1, 0)).get.producerStateManager.activeProducers.head._2
+      brokers(partitionLeader).logManager.getLog(new TopicPartition(topic1, 0)).get.producerStateManager
+        .activeProducers.entrySet().iterator().next().getValue
     val producerId = producerStateEntry.producerId
     val initialProducerEpoch = producerStateEntry.producerEpoch
 
@@ -748,7 +749,7 @@ class TransactionsTest extends IntegrationTestHarness {
 
     // Check that the epoch only increased by 1
     producerStateEntry =
-      brokers(partitionLeader).logManager.getLog(new TopicPartition(topic1, 0)).get.producerStateManager.activeProducers(producerId)
+      brokers(partitionLeader).logManager.getLog(new TopicPartition(topic1, 0)).get.producerStateManager.activeProducers.get(producerId)
     assertEquals((initialProducerEpoch + 1).toShort, producerStateEntry.producerEpoch)
   }
 
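
The test updates above reflect that activeProducers is now a java.util.Map keyed by producer id instead of a Scala Map. A short sketch of the Scala-side access pattern, assuming an already-constructed ProducerStateManager:

    import scala.jdk.CollectionConverters._
    import org.apache.kafka.storage.internals.log.ProducerStateManager

    def describeActiveProducers(stateManager: ProducerStateManager): Unit = {
      // activeProducers returns a java.util.Map; asScala bridges it for idiomatic iteration.
      stateManager.activeProducers.asScala.foreach { case (producerId, entry) =>
        println(s"producerId=$producerId epoch=${entry.producerEpoch} lastSeq=${entry.lastSeq}")
      }
      // Point lookups go through the Java map API, as in the updated tests above.
      Option(stateManager.activeProducers.get(1L)).foreach(entry => println(entry.firstSeq))
    }
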
diff --git a/core/src/test/scala/other/kafka/StressTestLog.scala b/core/src/test/scala/other/kafka/StressTestLog.scala
index 139276f8c86..0eaabc4a56a 100755
--- a/core/src/test/scala/other/kafka/StressTestLog.scala
+++ b/core/src/test/scala/other/kafka/StressTestLog.scala
@@ -26,7 +26,7 @@ import org.apache.kafka.clients.consumer.OffsetOutOfRangeException
 import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.record.FileRecords
 import org.apache.kafka.common.utils.Utils
-import org.apache.kafka.storage.internals.log.{FetchIsolation, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{FetchIsolation, LogConfig, LogDirFailureChannel, ProducerStateManagerConfig}
 
 /**
  * A stress test that instantiates a log and then runs continual appends against it from one thread and continual reads against it
diff --git a/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala b/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala
index b7b2375af23..315dc75809c 100755
--- a/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala
+++ b/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala
@@ -31,7 +31,7 @@ import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.{Time, Utils}
 import org.apache.kafka.server.util.{KafkaScheduler, Scheduler}
 import org.apache.kafka.server.util.CommandLineUtils
-import org.apache.kafka.storage.internals.log.{LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{LogConfig, LogDirFailureChannel, ProducerStateManagerConfig}
 
 import scala.math._
 
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
index 895a3d3a0c3..7b47f6ddcb8 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
@@ -36,7 +36,7 @@ import org.apache.kafka.common.utils.Utils
 import org.apache.kafka.common.{TopicPartition, Uuid}
 import org.apache.kafka.server.common.MetadataVersion
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-import org.apache.kafka.storage.internals.log.{AppendOrigin, CleanerConfig, FetchIsolation, FetchParams, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{AppendOrigin, CleanerConfig, FetchIsolation, FetchParams, LogConfig, LogDirFailureChannel, ProducerStateManager, ProducerStateManagerConfig}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.ArgumentMatchers
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
index 18d9022db48..aa6397c8819 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala
@@ -55,7 +55,7 @@ import org.apache.kafka.server.common.MetadataVersion.IBP_2_6_IV0
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
 import org.apache.kafka.server.util.KafkaScheduler
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-import org.apache.kafka.storage.internals.log.{AppendOrigin, CleanerConfig, EpochEntry, FetchIsolation, FetchParams, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{AppendOrigin, CleanerConfig, EpochEntry, FetchIsolation, FetchParams, LogDirFailureChannel, ProducerStateManager, ProducerStateManagerConfig}
 import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.ValueSource
 
diff --git a/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala b/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
index 187388d308c..c411f6d37b6 100644
--- a/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
+++ b/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
@@ -26,7 +26,7 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.record.{CompressionType, MemoryRecords, RecordBatch}
 import org.apache.kafka.common.utils.Utils
-import org.apache.kafka.storage.internals.log.{CleanerConfig, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{CleanerConfig, LogConfig, LogDirFailureChannel, ProducerStateManagerConfig}
 import org.junit.jupiter.api.{AfterEach, Tag}
 
 import scala.collection.Seq
diff --git a/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala b/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala
index 84a5d1260dc..15d24224c71 100755
--- a/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala
+++ b/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala
@@ -23,7 +23,7 @@ import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.record.{CompressionType, MemoryRecords, RecordBatch, SimpleRecord}
 import org.apache.kafka.common.utils.Utils
 import org.apache.kafka.server.record.BrokerCompressionType
-import org.apache.kafka.storage.internals.log.{FetchIsolation, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{FetchIsolation, LogConfig, LogDirFailureChannel, ProducerStateManagerConfig}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api._
 import org.junit.jupiter.params.ParameterizedTest
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
index df2c0af7c9c..7bcc7da25a8 100644
--- a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
@@ -26,7 +26,7 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.Utils
-import org.apache.kafka.storage.internals.log.{AppendOrigin, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{AppendOrigin, LogConfig, LogDirFailureChannel, ProducerStateManager, ProducerStateManagerConfig}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, Test}
 
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
index d31d82d1775..ceff738559a 100644
--- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
@@ -31,7 +31,7 @@ import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.errors.CorruptRecordException
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.Utils
-import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, CleanerConfig, LogConfig, LogDirFailureChannel, OffsetMap}
+import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, CleanerConfig, LogConfig, LogDirFailureChannel, LogFileUtils, OffsetMap, ProducerStateManager, ProducerStateManagerConfig}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, Test}
 
@@ -168,7 +168,7 @@ class LogCleanerTest {
 
     // Remember reference to the first log and determine its file name expected for async deletion
     val firstLogFile = log.logSegments.head.log
-    val expectedFileName = Utils.replaceSuffix(firstLogFile.file.getPath, "", UnifiedLog.DeletedFileSuffix)
+    val expectedFileName = Utils.replaceSuffix(firstLogFile.file.getPath, "", LogFileUtils.DELETED_FILE_SUFFIX)
 
     // Clean the log. This should trigger replaceSegments() and deleteOldSegments();
     val offsetMap = new FakeOffsetMap(Int.MaxValue)
@@ -1597,8 +1597,8 @@ class LogCleanerTest {
     // 1) Simulate recovery just after .cleaned file is created, before rename to .swap
     //    On recovery, clean operation is aborted. All messages should be present in the log
     log.logSegments.head.changeFileSuffixes("", UnifiedLog.CleanedFileSuffix)
-    for (file <- dir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) {
-      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, "")), false)
+    for (file <- dir.listFiles if file.getName.endsWith(LogFileUtils.DELETED_FILE_SUFFIX)) {
+      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, LogFileUtils.DELETED_FILE_SUFFIX, "")), false)
     }
     log = recoverAndCheck(config, allKeys)
 
@@ -1614,8 +1614,8 @@ class LogCleanerTest {
     //    On recovery, clean operation is aborted. All messages should be present in the log
     log.logSegments.head.changeFileSuffixes("", UnifiedLog.CleanedFileSuffix)
     log.logSegments.head.log.renameTo(new File(Utils.replaceSuffix(log.logSegments.head.log.file.getPath, UnifiedLog.CleanedFileSuffix, UnifiedLog.SwapFileSuffix)))
-    for (file <- dir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) {
-      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, "")), false)
+    for (file <- dir.listFiles if file.getName.endsWith(LogFileUtils.DELETED_FILE_SUFFIX)) {
+      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, LogFileUtils.DELETED_FILE_SUFFIX, "")), false)
     }
     log = recoverAndCheck(config, allKeys)
 
@@ -1630,8 +1630,8 @@ class LogCleanerTest {
     // 3) Simulate recovery just after swap file is created, before old segment files are
     //    renamed to .deleted. Clean operation is resumed during recovery.
     log.logSegments.head.changeFileSuffixes("", UnifiedLog.SwapFileSuffix)
-    for (file <- dir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) {
-      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, "")), false)
+    for (file <- dir.listFiles if file.getName.endsWith(LogFileUtils.DELETED_FILE_SUFFIX)) {
+      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, LogFileUtils.DELETED_FILE_SUFFIX, "")), false)
     }
     log = recoverAndCheck(config, cleanedKeys)
 
diff --git a/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala
index 326bd4fe6ea..487b31d438a 100644
--- a/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala
@@ -25,7 +25,7 @@ import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.record.SimpleRecord
 import org.apache.kafka.common.utils.{Time, Utils}
 import org.apache.kafka.server.util.KafkaScheduler
-import org.apache.kafka.storage.internals.log.{FetchIsolation, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{FetchIsolation, LogConfig, LogDirFailureChannel, ProducerStateManagerConfig}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 
diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
index 76b3eebcae3..7cb5275640f 100644
--- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
@@ -20,7 +20,7 @@ package kafka.log
 import java.io.{BufferedWriter, File, FileWriter, IOException}
 import java.nio.ByteBuffer
 import java.nio.file.{Files, NoSuchFileException, Paths}
-import java.util.Properties
+import java.util.{Optional, OptionalLong, Properties}
 import kafka.server.{BrokerTopicStats, KafkaConfig}
 import kafka.server.metadata.MockConfigRepository
 import kafka.utils.{MockTime, TestUtils}
@@ -33,7 +33,7 @@ import org.apache.kafka.server.common.MetadataVersion
 import org.apache.kafka.server.common.MetadataVersion.IBP_0_11_0_IV0
 import org.apache.kafka.server.util.Scheduler
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-import org.apache.kafka.storage.internals.log.{AbortedTxn, CleanerConfig, EpochEntry, FetchDataInfo, LogConfig, LogDirFailureChannel, OffsetIndex, SnapshotFile}
+import org.apache.kafka.storage.internals.log.{AbortedTxn, CleanerConfig, EpochEntry, FetchDataInfo, LogConfig, LogDirFailureChannel, LogFileUtils, LogOffsetMetadata, OffsetIndex, ProducerStateManager, ProducerStateManagerConfig, SnapshotFile}
 import org.junit.jupiter.api.Assertions.{assertDoesNotThrow, assertEquals, assertFalse, assertNotEquals, assertThrows, assertTrue}
 import org.junit.jupiter.api.function.Executable
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
@@ -45,6 +45,7 @@ import java.util.concurrent.ConcurrentMap
 import scala.annotation.nowarn
 import scala.collection.mutable.ListBuffer
 import scala.collection.{Iterable, Map, mutable}
+import scala.compat.java8.OptionConverters._
 import scala.jdk.CollectionConverters._
 
 class LogLoaderTest {
@@ -299,7 +300,7 @@ class LogLoaderTest {
     logProps.put(TopicConfig.MESSAGE_FORMAT_VERSION_CONFIG, messageFormatVersion)
     val logConfig = new LogConfig(logProps)
     var log = createLog(logDir, logConfig)
-    assertEquals(None, log.oldestProducerSnapshotOffset)
+    assertEquals(OptionalLong.empty(), log.oldestProducerSnapshotOffset)
 
     for (i <- 0 to 100) {
       val record = new SimpleRecord(mockTime.milliseconds, i.toString.getBytes)
@@ -412,10 +413,10 @@ class LogLoaderTest {
     val stateManager: ProducerStateManager = mock(classOf[ProducerStateManager])
     when(stateManager.producerStateManagerConfig).thenReturn(producerStateManagerConfig)
     when(stateManager.maxTransactionTimeoutMs).thenReturn(maxTransactionTimeoutMs)
-    when(stateManager.latestSnapshotOffset).thenReturn(None)
+    when(stateManager.latestSnapshotOffset).thenReturn(OptionalLong.empty())
     when(stateManager.mapEndOffset).thenReturn(0L)
     when(stateManager.isEmpty).thenReturn(true)
-    when(stateManager.firstUnstableOffset).thenReturn(None)
+    when(stateManager.firstUnstableOffset).thenReturn(Optional.empty[LogOffsetMetadata]())
 
     val topicPartition = UnifiedLog.parseTopicPartitionName(logDir)
     val logDirFailureChannel: LogDirFailureChannel = new LogDirFailureChannel(1)
@@ -455,7 +456,7 @@ class LogLoaderTest {
 
     // Append some messages
     reset(stateManager)
-    when(stateManager.firstUnstableOffset).thenReturn(None)
+    when(stateManager.firstUnstableOffset).thenReturn(Optional.empty[LogOffsetMetadata]())
 
     log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes))), leaderEpoch = 0)
     log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes))), leaderEpoch = 0)
@@ -465,8 +466,8 @@ class LogLoaderTest {
 
     // Now truncate
     reset(stateManager)
-    when(stateManager.firstUnstableOffset).thenReturn(None)
-    when(stateManager.latestSnapshotOffset).thenReturn(None)
+    when(stateManager.firstUnstableOffset).thenReturn(Optional.empty[LogOffsetMetadata]())
+    when(stateManager.latestSnapshotOffset).thenReturn(OptionalLong.empty())
     when(stateManager.isEmpty).thenReturn(true)
     when(stateManager.mapEndOffset).thenReturn(2L)
     // Truncation causes the map end offset to reset to 0
@@ -498,7 +499,7 @@ class LogLoaderTest {
 
     val maxProducerIdExpirationMs = kafka.server.Defaults.ProducerIdExpirationMs
     mockTime.sleep(maxProducerIdExpirationMs)
-    assertEquals(None, log.producerStateManager.lastEntry(producerId))
+    assertEquals(Optional.empty(), log.producerStateManager.lastEntry(producerId))
 
     val secondAppendTimestamp = mockTime.milliseconds()
     LogTestUtils.appendEndTxnMarkerAsLeader(log, producerId, epoch, ControlRecordType.ABORT,
@@ -520,7 +521,7 @@ class LogLoaderTest {
 
     val stateManager: ProducerStateManager = mock(classOf[ProducerStateManager])
     when(stateManager.isEmpty).thenReturn(true)
-    when(stateManager.firstUnstableOffset).thenReturn(None)
+    when(stateManager.firstUnstableOffset).thenReturn(Optional.empty[LogOffsetMetadata]())
     when(stateManager.producerStateManagerConfig).thenReturn(producerStateManagerConfig)
     when(stateManager.maxTransactionTimeoutMs).thenReturn(maxTransactionTimeoutMs)
 
@@ -557,7 +558,7 @@ class LogLoaderTest {
       _topicId = None,
       keepPartitionMetadataFile = true)
 
-    verify(stateManager).removeStraySnapshots(any[Seq[Long]])
+    verify(stateManager).removeStraySnapshots(any[java.util.List[java.lang.Long]])
     verify(stateManager, times(2)).updateMapEndOffset(0L)
     verify(stateManager, times(2)).takeSnapshot()
     verify(stateManager).isEmpty
@@ -574,7 +575,7 @@ class LogLoaderTest {
 
     val stateManager: ProducerStateManager = mock(classOf[ProducerStateManager])
     when(stateManager.isEmpty).thenReturn(true)
-    when(stateManager.firstUnstableOffset).thenReturn(None)
+    when(stateManager.firstUnstableOffset).thenReturn(Optional.empty[LogOffsetMetadata]())
     when(stateManager.producerStateManagerConfig).thenReturn(producerStateManagerConfig)
     when(stateManager.maxTransactionTimeoutMs).thenReturn(maxTransactionTimeoutMs)
 
@@ -611,7 +612,7 @@ class LogLoaderTest {
       _topicId = None,
       keepPartitionMetadataFile = true)
 
-    verify(stateManager).removeStraySnapshots(any[Seq[Long]])
+    verify(stateManager).removeStraySnapshots(any[java.util.List[java.lang.Long]])
     verify(stateManager, times(2)).updateMapEndOffset(0L)
     verify(stateManager, times(2)).takeSnapshot()
     verify(stateManager).isEmpty
@@ -625,9 +626,9 @@ class LogLoaderTest {
     val producerStateManagerConfig = new ProducerStateManagerConfig(300000)
 
     val stateManager: ProducerStateManager = mock(classOf[ProducerStateManager])
-    when(stateManager.latestSnapshotOffset).thenReturn(None)
+    when(stateManager.latestSnapshotOffset).thenReturn(OptionalLong.empty())
     when(stateManager.isEmpty).thenReturn(true)
-    when(stateManager.firstUnstableOffset).thenReturn(None)
+    when(stateManager.firstUnstableOffset).thenReturn(Optional.empty[LogOffsetMetadata]())
     when(stateManager.producerStateManagerConfig).thenReturn(producerStateManagerConfig)
     when(stateManager.maxTransactionTimeoutMs).thenReturn(maxTransactionTimeoutMs)
 
@@ -664,7 +665,7 @@ class LogLoaderTest {
       _topicId = None,
       keepPartitionMetadataFile = true)
 
-    verify(stateManager).removeStraySnapshots(any[Seq[Long]])
+    verify(stateManager).removeStraySnapshots(any[java.util.List[java.lang.Long]])
     verify(stateManager, times(2)).updateMapEndOffset(0L)
     verify(stateManager, times(2)).takeSnapshot()
     verify(stateManager).isEmpty
@@ -722,15 +723,15 @@ class LogLoaderTest {
     log.close()
     assertEquals(log.logSegments.size, 3)
     // We expect 3 snapshot files, two of which are for the first two segments, the last was written out during log closing.
-    assertEquals(Seq(1, 2, 4), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+    assertEquals(Seq(1L, 2L, 4L), ProducerStateManager.listSnapshotFiles(logDir).asScala.map(_.offset).sorted)
     // Inject a stray snapshot file within the bounds of the log at offset 3, it should be cleaned up after loading the log
-    val straySnapshotFile = UnifiedLog.producerSnapshotFile(logDir, 3).toPath
+    val straySnapshotFile = LogFileUtils.producerSnapshotFile(logDir, 3).toPath
     Files.createFile(straySnapshotFile)
-    assertEquals(Seq(1, 2, 3, 4), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+    assertEquals(Seq(1L, 2L, 3L, 4L), ProducerStateManager.listSnapshotFiles(logDir).asScala.map(_.offset).sorted)
 
     createLog(logDir, logConfig, lastShutdownClean = false)
     // We should clean up the stray producer state snapshot file, but keep the largest snapshot file (4)
-    assertEquals(Seq(1, 2, 4), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+    assertEquals(Seq(1L, 2L, 4L), ProducerStateManager.listSnapshotFiles(logDir).asScala.map(_.offset).sorted)
   }
 
   @Test
@@ -1064,7 +1065,7 @@ class LogLoaderTest {
     //This write will roll the segment, yielding a new segment with base offset = max(1, Integer.MAX_VALUE+2) = Integer.MAX_VALUE+2
     log.appendAsFollower(set2)
     assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset)
-    assertTrue(UnifiedLog.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 2).exists)
+    assertTrue(LogFileUtils.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 2).exists)
     //This will go into the existing log
     log.appendAsFollower(set3)
     assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset)
@@ -1127,7 +1128,7 @@ class LogLoaderTest {
     //This write will roll the segment, yielding a new segment with base offset = max(1, Integer.MAX_VALUE+2) = Integer.MAX_VALUE+2
     log.appendAsFollower(set2)
     assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset)
-    assertTrue(UnifiedLog.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 2).exists)
+    assertTrue(LogFileUtils.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 2).exists)
     //This will go into the existing log
     log.appendAsFollower(set3)
     assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset)
@@ -1167,11 +1168,11 @@ class LogLoaderTest {
     //This write will roll the segment, yielding a new segment with base offset = max(1, 3) = 3
     log.appendAsFollower(set2)
     assertEquals(3, log.activeSegment.baseOffset)
-    assertTrue(UnifiedLog.producerSnapshotFile(logDir, 3).exists)
+    assertTrue(LogFileUtils.producerSnapshotFile(logDir, 3).exists)
     //This will also roll the segment, yielding a new segment with base offset = max(5, Integer.MAX_VALUE+4) = Integer.MAX_VALUE+4
     log.appendAsFollower(set3)
     assertEquals(Integer.MAX_VALUE.toLong + 4, log.activeSegment.baseOffset)
-    assertTrue(UnifiedLog.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 4).exists)
+    assertTrue(LogFileUtils.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 4).exists)
     //This will go into the existing log
     log.appendAsFollower(set4)
     assertEquals(Integer.MAX_VALUE.toLong + 4, log.activeSegment.baseOffset)
@@ -1220,8 +1221,8 @@ class LogLoaderTest {
       segment.changeFileSuffixes("", UnifiedLog.CleanedFileSuffix)
       segment.truncateTo(0)
     })
-    for (file <- logDir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix))
-      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, "")))
+    for (file <- logDir.listFiles if file.getName.endsWith(LogFileUtils.DELETED_FILE_SUFFIX))
+      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, LogFileUtils.DELETED_FILE_SUFFIX, "")))
 
     val recoveredLog = recoverAndCheck(logConfig, expectedKeys)
     assertEquals(expectedKeys, LogTestUtils.keysInLog(recoveredLog))
@@ -1248,8 +1249,8 @@ class LogLoaderTest {
         segment.changeFileSuffixes("", UnifiedLog.SwapFileSuffix)
       segment.truncateTo(0)
     }
-    for (file <- logDir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix))
-      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, "")))
+    for (file <- logDir.listFiles if file.getName.endsWith(LogFileUtils.DELETED_FILE_SUFFIX))
+      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, LogFileUtils.DELETED_FILE_SUFFIX, "")))
 
     val recoveredLog = recoverAndCheck(logConfig, expectedKeys)
     assertEquals(expectedKeys, LogTestUtils.keysInLog(recoveredLog))
@@ -1272,8 +1273,8 @@ class LogLoaderTest {
     newSegments.reverse.foreach(segment => {
       segment.changeFileSuffixes("", UnifiedLog.SwapFileSuffix)
     })
-    for (file <- logDir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix))
-      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, "")))
+    for (file <- logDir.listFiles if file.getName.endsWith(LogFileUtils.DELETED_FILE_SUFFIX))
+      Utils.atomicMoveWithFallback(file.toPath, Paths.get(Utils.replaceSuffix(file.getPath, LogFileUtils.DELETED_FILE_SUFFIX, "")))
 
     // Truncate the old segment
     segmentWithOverflow.truncateTo(0)
@@ -1298,7 +1299,7 @@ class LogLoaderTest {
     // recovery, existing split operation is completed.
     newSegments.reverse.foreach(_.changeFileSuffixes("", UnifiedLog.SwapFileSuffix))
 
-    for (file <- logDir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix))
+    for (file <- logDir.listFiles if file.getName.endsWith(LogFileUtils.DELETED_FILE_SUFFIX))
       Utils.delete(file)
 
     // Truncate the old segment
@@ -1604,7 +1605,7 @@ class LogLoaderTest {
     assertEquals(9, log.activeSegment.baseOffset)
     assertEquals(9, log.logEndOffset)
     for (offset <- 1 until 10) {
-      val snapshotFileBeforeDeletion = log.producerStateManager.snapshotFileForOffset(offset)
+      val snapshotFileBeforeDeletion = log.producerStateManager.snapshotFileForOffset(offset).asScala
       assertTrue(snapshotFileBeforeDeletion.isDefined)
       assertTrue(snapshotFileBeforeDeletion.get.file.exists)
     }
@@ -1655,11 +1656,11 @@ class LogLoaderTest {
     assertEquals(4, log.logEndOffset)
 
     val offsetsWithSnapshotFiles = (1 until 5)
-        .map(offset => new SnapshotFile(UnifiedLog.producerSnapshotFile(logDir, offset)))
+        .map(offset => new SnapshotFile(LogFileUtils.producerSnapshotFile(logDir, offset)))
         .filter(snapshotFile => snapshotFile.file.exists())
         .map(_.offset)
     val inMemorySnapshotFiles = (1 until 5)
-        .flatMap(offset => log.producerStateManager.snapshotFileForOffset(offset))
+        .flatMap(offset => log.producerStateManager.snapshotFileForOffset(offset).asScala)
 
     assertTrue(offsetsWithSnapshotFiles.isEmpty, s"Found offsets with producer state snapshot files: $offsetsWithSnapshotFiles while none were expected.")
     assertTrue(inMemorySnapshotFiles.isEmpty, s"Found in-memory producer state snapshot files: $inMemorySnapshotFiles while none were expected.")
@@ -1682,7 +1683,7 @@ class LogLoaderTest {
     assertEquals(9, log.logEndOffset)
     for (offset <- 5 until 10) {
       val snapshotFileBeforeDeletion = log.producerStateManager.snapshotFileForOffset(offset)
-      assertTrue(snapshotFileBeforeDeletion.isDefined)
+      assertTrue(snapshotFileBeforeDeletion.isPresent)
       assertTrue(snapshotFileBeforeDeletion.get.file.exists)
     }
 
@@ -1700,13 +1701,13 @@ class LogLoaderTest {
     val offsetsWithMissingSnapshotFiles = ListBuffer[Long]()
     for (offset <- 5 until 10) {
       val snapshotFile = log.producerStateManager.snapshotFileForOffset(offset)
-      if (snapshotFile.isEmpty || !snapshotFile.get.file.exists) {
+      if (!snapshotFile.isPresent || !snapshotFile.get.file.exists) {
         offsetsWithMissingSnapshotFiles.append(offset)
       }
     }
     assertTrue(offsetsWithMissingSnapshotFiles.isEmpty,
       s"Found offsets with missing producer state snapshot files: $offsetsWithMissingSnapshotFiles")
-    assertFalse(logDir.list().exists(_.endsWith(UnifiedLog.DeletedFileSuffix)), "Expected no files to be present with the deleted file suffix")
+    assertFalse(logDir.list().exists(_.endsWith(LogFileUtils.DELETED_FILE_SUFFIX)), "Expected no files to be present with the deleted file suffix")
   }
 
   @Test
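
Several of the LogLoaderTest changes above swap Option for the Java Optional/OptionalLong types returned by the rewritten manager. A minimal sketch of constructing the Java ProducerStateManager directly and probing those accessors over an empty directory (the directory and config values are illustrative):

    import java.nio.file.Files
    import org.apache.kafka.common.TopicPartition
    import org.apache.kafka.common.utils.MockTime
    import org.apache.kafka.storage.internals.log.{ProducerStateManager, ProducerStateManagerConfig}
    import scala.compat.java8.OptionConverters._

    val logDir = Files.createTempDirectory("psm-example").toFile
    val manager = new ProducerStateManager(
      new TopicPartition("my-topic", 0),
      logDir,
      5 * 60 * 1000,                            // maxTransactionTimeoutMs
      new ProducerStateManagerConfig(86400000), // producer id expiration (24h)
      new MockTime()
    )
    // Snapshot offsets are now OptionalLong rather than Option[Long] ...
    assert(!manager.latestSnapshotOffset.isPresent)
    // ... and snapshotFileForOffset returns Optional[SnapshotFile]; asScala bridges back to Option.
    assert(manager.snapshotFileForOffset(0L).asScala.isEmpty)
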
diff --git a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
index 592228fd25f..a1f842c6c67 100755
--- a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
@@ -39,7 +39,7 @@ import java.nio.file.Files
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Future}
 import java.util.{Collections, Properties}
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
-import org.apache.kafka.storage.internals.log.{FetchDataInfo, FetchIsolation, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{FetchDataInfo, FetchIsolation, LogConfig, LogDirFailureChannel, ProducerStateManagerConfig}
 
 import scala.collection.{Map, mutable}
 import scala.collection.mutable.ArrayBuffer
diff --git a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
index 585df56c744..337b9231829 100644
--- a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
@@ -26,7 +26,7 @@ import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.{MockTime, Time, Utils}
 import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpoint
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-import org.apache.kafka.storage.internals.log.{BatchMetadata, EpochEntry, LogConfig, ProducerStateEntry}
+import org.apache.kafka.storage.internals.log.{BatchMetadata, EpochEntry, LogConfig, ProducerStateEntry, ProducerStateManager, ProducerStateManagerConfig}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 
@@ -588,9 +588,9 @@ class LogSegmentTest {
     new ProducerStateManager(
       topicPartition,
       logDir,
-      maxTransactionTimeoutMs = 5 * 60 * 1000,
-      producerStateManagerConfig = new ProducerStateManagerConfig(kafka.server.Defaults.ProducerIdExpirationMs),
-      time = new MockTime()
+      5 * 60 * 1000,
+      new ProducerStateManagerConfig(kafka.server.Defaults.ProducerIdExpirationMs),
+      new MockTime()
     )
   }
 
diff --git a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
index 7f2867d28df..79aab9cd486 100644
--- a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
@@ -30,11 +30,10 @@ import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse}
 
 import java.nio.file.Files
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
-import kafka.log
 import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.server.util.Scheduler
 import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
-import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, FetchDataInfo, FetchIsolation, LazyIndex, LogConfig, LogDirFailureChannel, TransactionIndex}
+import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, FetchDataInfo, FetchIsolation, LazyIndex, LogConfig, LogDirFailureChannel, LogFileUtils, ProducerStateManager, ProducerStateManagerConfig, TransactionIndex}
 
 import scala.jdk.CollectionConverters._
 
@@ -88,7 +87,7 @@ object LogTestUtils {
                 logStartOffset: Long = 0L,
                 recoveryPoint: Long = 0L,
                 maxTransactionTimeoutMs: Int = 5 * 60 * 1000,
-                producerStateManagerConfig: ProducerStateManagerConfig = new log.ProducerStateManagerConfig(kafka.server.Defaults.ProducerIdExpirationMs),
+                producerStateManagerConfig: ProducerStateManagerConfig = new ProducerStateManagerConfig(kafka.server.Defaults.ProducerIdExpirationMs),
                 producerIdExpirationCheckIntervalMs: Int = kafka.server.Defaults.ProducerIdExpirationCheckIntervalMs,
                 lastShutdownClean: Boolean = true,
                 topicId: Option[Uuid] = None,
@@ -198,7 +197,7 @@ object LogTestUtils {
     val recoveredLog = createLog(logDir, config, brokerTopicStats, scheduler, time, lastShutdownClean = false)
     time.sleep(config.fileDeleteDelayMs + 1)
     for (file <- logDir.listFiles) {
-      assertFalse(file.getName.endsWith(UnifiedLog.DeletedFileSuffix), "Unexpected .deleted file after recovery")
+      assertFalse(file.getName.endsWith(LogFileUtils.DELETED_FILE_SUFFIX), "Unexpected .deleted file after recovery")
       assertFalse(file.getName.endsWith(UnifiedLog.CleanedFileSuffix), "Unexpected .cleaned file after recovery")
       assertFalse(file.getName.endsWith(UnifiedLog.SwapFileSuffix), "Unexpected .swap file after recovery")
     }
@@ -241,12 +240,12 @@ object LogTestUtils {
   def allAbortedTransactions(log: UnifiedLog): Iterable[AbortedTxn] = log.logSegments.flatMap(_.txnIndex.allAbortedTxns.asScala)
 
   def deleteProducerSnapshotFiles(logDir: File): Unit = {
-    val files = logDir.listFiles.filter(f => f.isFile && f.getName.endsWith(UnifiedLog.ProducerSnapshotFileSuffix))
+    val files = logDir.listFiles.filter(f => f.isFile && f.getName.endsWith(LogFileUtils.PRODUCER_SNAPSHOT_FILE_SUFFIX))
     files.foreach(Utils.delete)
   }
 
   def listProducerSnapshotOffsets(logDir: File): Seq[Long] =
-    ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted
+    ProducerStateManager.listSnapshotFiles(logDir).asScala.map(_.offset).sorted.toSeq
 
   def assertLeaderEpochCacheEmpty(log: UnifiedLog): Unit = {
     assertEquals(None, log.leaderEpochCache)
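
listProducerSnapshotOffsets above now converts the java.util.List returned by listSnapshotFiles back into a Scala Seq. The same pattern in isolation, with a placeholder directory:

    import java.io.File
    import org.apache.kafka.storage.internals.log.ProducerStateManager
    import scala.jdk.CollectionConverters._

    val logDir = new File("/tmp/kafka-logs/my-topic-0") // illustrative path
    // listSnapshotFiles returns java.util.List[SnapshotFile]; asScala + toSeq restores the old Scala shape.
    val snapshotOffsets: Seq[Long] =
      ProducerStateManager.listSnapshotFiles(logDir).asScala.map(_.offset).sorted.toSeq
    println(snapshotOffsets.mkString(", "))
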
diff --git a/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
index 458303ae925..3c482b074c5 100644
--- a/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
@@ -29,7 +29,7 @@ import org.apache.kafka.common.errors._
 import org.apache.kafka.common.internals.Topic
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.{MockTime, Utils}
-import org.apache.kafka.storage.internals.log.{AppendOrigin, CompletedTxn, LogOffsetMetadata, ProducerAppendInfo, ProducerStateEntry, TxnMetadata}
+import org.apache.kafka.storage.internals.log.{AppendOrigin, CompletedTxn, LogFileUtils, LogOffsetMetadata, ProducerAppendInfo, ProducerStateEntry, ProducerStateManager, ProducerStateManagerConfig, TxnMetadata}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.Mockito.{mock, when}
@@ -45,7 +45,7 @@ class ProducerStateManagerTest {
   private val producerId = 1L
   private val maxTransactionTimeoutMs = 5 * 60 * 1000
   private val producerStateManagerConfig = new ProducerStateManagerConfig(kafka.server.Defaults.ProducerIdExpirationMs)
-  private val lateTransactionTimeoutMs = maxTransactionTimeoutMs + ProducerStateManager.LateTransactionBufferMs
+  private val lateTransactionTimeoutMs = maxTransactionTimeoutMs + ProducerStateManager.LATE_TRANSACTION_BUFFER_MS
   private val time = new MockTime
 
   @BeforeEach
@@ -88,7 +88,7 @@ class ProducerStateManagerTest {
     val producerEpoch = 2.toShort
     appendEndTxnMarker(stateManager, producerId, producerEpoch, ControlRecordType.COMMIT, offset = 27L)
 
-    val firstEntry = stateManager.lastEntry(producerId).getOrElse(throw new RuntimeException("Expected last entry to be defined"))
+    val firstEntry = stateManager.lastEntry(producerId).orElseThrow(() => new RuntimeException("Expected last entry to be defined"))
     assertEquals(producerEpoch, firstEntry.producerEpoch)
     assertEquals(producerId, firstEntry.producerId)
     assertEquals(RecordBatch.NO_SEQUENCE, firstEntry.lastSeq)
@@ -102,7 +102,7 @@ class ProducerStateManagerTest {
 
     // The broker should accept the request if the sequence number is reset to 0
     append(stateManager, producerId, producerEpoch, 0, 39L, 4L)
-    val secondEntry = stateManager.lastEntry(producerId).getOrElse(throw new RuntimeException("Expected last entry to be defined"))
+    val secondEntry = stateManager.lastEntry(producerId).orElseThrow(() => new RuntimeException("Expected last entry to be defined"))
     assertEquals(producerEpoch, secondEntry.producerEpoch)
     assertEquals(producerId, secondEntry.producerId)
     assertEquals(0, secondEntry.lastSeq)
@@ -118,7 +118,7 @@ class ProducerStateManagerTest {
     append(stateManager, producerId, epoch, 0, offset + 500)
 
     val maybeLastEntry = stateManager.lastEntry(producerId)
-    assertTrue(maybeLastEntry.isDefined)
+    assertTrue(maybeLastEntry.isPresent)
 
     val lastEntry = maybeLastEntry.get
     assertEquals(epoch, lastEntry.producerEpoch)
@@ -131,13 +131,13 @@ class ProducerStateManagerTest {
   def testProducerSequenceWithWrapAroundBatchRecord(): Unit = {
     val epoch = 15.toShort
 
-    val appendInfo = stateManager.prepareUpdate(producerId, origin = AppendOrigin.REPLICATION)
+    val appendInfo = stateManager.prepareUpdate(producerId, AppendOrigin.REPLICATION)
     // Sequence number wrap around
     appendInfo.appendDataBatch(epoch, Int.MaxValue - 10, 9, time.milliseconds(),
       new LogOffsetMetadata(2000L), 2020L, false)
-    assertEquals(None, stateManager.lastEntry(producerId))
+    assertEquals(Optional.empty(), stateManager.lastEntry(producerId))
     stateManager.update(appendInfo)
-    assertTrue(stateManager.lastEntry(producerId).isDefined)
+    assertTrue(stateManager.lastEntry(producerId).isPresent)
 
     val lastEntry = stateManager.lastEntry(producerId).get
     assertEquals(Int.MaxValue-10, lastEntry.firstSeq)
@@ -163,7 +163,7 @@ class ProducerStateManagerTest {
     append(stateManager, producerId, epoch, sequence, offset, origin = AppendOrigin.REPLICATION)
 
     val maybeLastEntry = stateManager.lastEntry(producerId)
-    assertTrue(maybeLastEntry.isDefined)
+    assertTrue(maybeLastEntry.isPresent)
 
     val lastEntry = maybeLastEntry.get
     assertEquals(epoch, lastEntry.producerEpoch)
@@ -182,7 +182,7 @@ class ProducerStateManagerTest {
     appendEndTxnMarker(stateManager, producerId, bumpedProducerEpoch, ControlRecordType.ABORT, 1L)
 
     val maybeLastEntry = stateManager.lastEntry(producerId)
-    assertTrue(maybeLastEntry.isDefined)
+    assertTrue(maybeLastEntry.isPresent())
 
     val lastEntry = maybeLastEntry.get
     assertEquals(bumpedProducerEpoch, lastEntry.producerEpoch)
@@ -192,7 +192,7 @@ class ProducerStateManagerTest {
 
     // should be able to append with the new epoch if we start at sequence 0
     append(stateManager, producerId, bumpedProducerEpoch, 0, 2L)
-    assertEquals(Some(0), stateManager.lastEntry(producerId).map(_.firstSeq))
+    assertEquals(Optional.of(0L), stateManager.lastEntry(producerId).map[Long](_.firstSeq))
   }
 
   @Test
@@ -207,7 +207,7 @@ class ProducerStateManagerTest {
       firstOffsetMetadata, offset, true)
     stateManager.update(producerAppendInfo)
 
-    assertEquals(Some(firstOffsetMetadata), stateManager.firstUnstableOffset)
+    assertEquals(Optional.of(firstOffsetMetadata), stateManager.firstUnstableOffset())
   }
 
   @Test
@@ -237,17 +237,17 @@ class ProducerStateManagerTest {
     }
 
     // Start one transaction in a separate append
-    val firstAppend = stateManager.prepareUpdate(producerId, origin = AppendOrigin.CLIENT)
+    val firstAppend = stateManager.prepareUpdate(producerId, AppendOrigin.CLIENT)
     appendData(16L, 20L, firstAppend)
-    assertTxnMetadataEquals(new TxnMetadata(producerId, 16L), firstAppend.startedTransactions.asScala.head)
+    assertTxnMetadataEquals(new TxnMetadata(producerId, 16L), firstAppend.startedTransactions.get(0))
     stateManager.update(firstAppend)
     stateManager.onHighWatermarkUpdated(21L)
-    assertEquals(Some(new LogOffsetMetadata(16L)), stateManager.firstUnstableOffset)
+    assertEquals(Optional.of(new LogOffsetMetadata(16L)), stateManager.firstUnstableOffset)
 
     // Now do a single append which completes the old transaction, mixes in
     // some empty transactions, one non-empty complete transaction, and one
     // incomplete transaction
-    val secondAppend = stateManager.prepareUpdate(producerId, origin = AppendOrigin.CLIENT)
+    val secondAppend = stateManager.prepareUpdate(producerId, AppendOrigin.CLIENT)
     val firstCompletedTxn = appendEndTxn(ControlRecordType.COMMIT, 21, secondAppend)
     assertEquals(Some(new CompletedTxn(producerId, 16L, 21, false)), firstCompletedTxn)
     assertEquals(None, appendEndTxn(ControlRecordType.COMMIT, 22, secondAppend))
@@ -258,14 +258,15 @@ class ProducerStateManagerTest {
     assertEquals(None, appendEndTxn(ControlRecordType.ABORT, 29L, secondAppend))
     appendData(30L, 31L, secondAppend)
 
-    assertEquals(2, secondAppend.startedTransactions.size)
-    assertTxnMetadataEquals(new TxnMetadata(producerId, new LogOffsetMetadata(24L)), secondAppend.startedTransactions.asScala.head)
-    assertTxnMetadataEquals(new TxnMetadata(producerId, new LogOffsetMetadata(30L)), secondAppend.startedTransactions.asScala.last)
+    val size = secondAppend.startedTransactions.size
+    assertEquals(2, size)
+    assertTxnMetadataEquals(new TxnMetadata(producerId, new LogOffsetMetadata(24L)), secondAppend.startedTransactions.get(0))
+    assertTxnMetadataEquals(new TxnMetadata(producerId, new LogOffsetMetadata(30L)), secondAppend.startedTransactions.get(size - 1))
     stateManager.update(secondAppend)
     stateManager.completeTxn(firstCompletedTxn.get)
     stateManager.completeTxn(secondCompletedTxn.get)
     stateManager.onHighWatermarkUpdated(32L)
-    assertEquals(Some(new LogOffsetMetadata(30L)), stateManager.firstUnstableOffset)
+    assertEquals(Optional.of(new LogOffsetMetadata(30L)), stateManager.firstUnstableOffset)
   }
 
   def assertTxnMetadataEquals(expected: java.util.List[TxnMetadata], actual: java.util.List[TxnMetadata]): Unit = {
@@ -340,8 +341,7 @@ class ProducerStateManagerTest {
     // After reloading from the snapshot, the transaction should still be considered late
     val reloadedStateManager = new ProducerStateManager(partition, logDir, maxTransactionTimeoutMs,
       producerStateManagerConfig, time)
-    reloadedStateManager.truncateAndReload(logStartOffset = 0L,
-      logEndOffset = stateManager.mapEndOffset, currentTimeMs = time.milliseconds())
+    reloadedStateManager.truncateAndReload(0L, stateManager.mapEndOffset, time.milliseconds())
     assertTrue(reloadedStateManager.hasLateTransaction(time.milliseconds()))
   }
 
@@ -357,7 +357,7 @@ class ProducerStateManagerTest {
     assertTrue(stateManager.hasLateTransaction(time.milliseconds()))
 
     // After truncation, the ongoing transaction will be cleared
-    stateManager.truncateAndReload(logStartOffset = 0, logEndOffset = 80, currentTimeMs = time.milliseconds())
+    stateManager.truncateAndReload(0, 80, time.milliseconds())
     assertFalse(stateManager.hasLateTransaction(time.milliseconds()))
   }
 
@@ -373,7 +373,7 @@ class ProducerStateManagerTest {
     assertTrue(stateManager.hasLateTransaction(time.milliseconds()))
 
     // After truncation, the ongoing transaction will be cleared
-    stateManager.truncateFullyAndStartAt(offset = 150L)
+    stateManager.truncateFullyAndStartAt(150L)
     assertFalse(stateManager.hasLateTransaction(time.milliseconds()))
   }
 
@@ -413,38 +413,39 @@ class ProducerStateManagerTest {
     assertEquals(startOffset2, stateManager.lastStableOffset(completedTxn1))
     stateManager.completeTxn(completedTxn1)
     stateManager.onHighWatermarkUpdated(lastOffset1 + 1)
-    assertEquals(Some(startOffset2), stateManager.firstUnstableOffset.map(_.messageOffset))
+
+    assertEquals(Optional.of(startOffset2), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
 
     val lastOffset3 = lastOffset1 + 20
     val completedTxn3 = new CompletedTxn(producerId3, startOffset3, lastOffset3, false)
     assertEquals(startOffset2, stateManager.lastStableOffset(completedTxn3))
     stateManager.completeTxn(completedTxn3)
     stateManager.onHighWatermarkUpdated(lastOffset3 + 1)
-    assertEquals(Some(startOffset2), stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(Optional.of(startOffset2), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
 
     val lastOffset2 = lastOffset3 + 78
     val completedTxn2 = new CompletedTxn(producerId2, startOffset2, lastOffset2, false)
     assertEquals(lastOffset2 + 1, stateManager.lastStableOffset(completedTxn2))
     stateManager.completeTxn(completedTxn2)
     stateManager.onHighWatermarkUpdated(lastOffset2 + 1)
-    assertEquals(None, stateManager.firstUnstableOffset)
+    assertEquals(Optional.empty(), stateManager.firstUnstableOffset)
   }
 
   @Test
   def testPrepareUpdateDoesNotMutate(): Unit = {
     val producerEpoch = 0.toShort
 
-    val appendInfo = stateManager.prepareUpdate(producerId, origin = AppendOrigin.CLIENT)
+    val appendInfo = stateManager.prepareUpdate(producerId, AppendOrigin.CLIENT)
     appendInfo.appendDataBatch(producerEpoch, 0, 5, time.milliseconds(),
       new LogOffsetMetadata(15L), 20L, false)
-    assertEquals(None, stateManager.lastEntry(producerId))
+    assertEquals(Optional.empty(), stateManager.lastEntry(producerId))
     stateManager.update(appendInfo)
-    assertTrue(stateManager.lastEntry(producerId).isDefined)
+    assertTrue(stateManager.lastEntry(producerId).isPresent())
 
-    val nextAppendInfo = stateManager.prepareUpdate(producerId, origin = AppendOrigin.CLIENT)
+    val nextAppendInfo = stateManager.prepareUpdate(producerId, AppendOrigin.CLIENT)
     nextAppendInfo.appendDataBatch(producerEpoch, 6, 10, time.milliseconds(),
       new LogOffsetMetadata(26L), 30L, false)
-    assertTrue(stateManager.lastEntry(producerId).isDefined)
+    assertTrue(stateManager.lastEntry(producerId).isPresent())
 
     var lastEntry = stateManager.lastEntry(producerId).get
     assertEquals(0, lastEntry.firstSeq)
@@ -465,7 +466,7 @@ class ProducerStateManagerTest {
     val offset = 9L
     append(stateManager, producerId, producerEpoch, 0, offset)
 
-    val appendInfo = stateManager.prepareUpdate(producerId, origin = AppendOrigin.CLIENT)
+    val appendInfo = stateManager.prepareUpdate(producerId, AppendOrigin.CLIENT)
     appendInfo.appendDataBatch(producerEpoch, 1, 5, time.milliseconds(),
       new LogOffsetMetadata(16L), 20L, true)
     var lastEntry = appendInfo.toEntry
@@ -555,8 +556,8 @@ class ProducerStateManagerTest {
 
     stateManager.truncateAndReload(1L, 3L, time.milliseconds())
 
-    assertEquals(Some(2L), stateManager.oldestSnapshotOffset)
-    assertEquals(Some(3L), stateManager.latestSnapshotOffset)
+    assertEquals(OptionalLong.of(2L), stateManager.oldestSnapshotOffset)
+    assertEquals(OptionalLong.of(3L), stateManager.latestSnapshotOffset)
   }
 
   @Test
@@ -627,10 +628,10 @@ class ProducerStateManagerTest {
 
     val recoveredMapping = new ProducerStateManager(partition, logDir,
       maxTransactionTimeoutMs, producerStateManagerConfig, time)
-    recoveredMapping.truncateAndReload(logStartOffset = 0L, logEndOffset = 1L, time.milliseconds)
+    recoveredMapping.truncateAndReload(0L, 1L, time.milliseconds)
 
     val lastEntry = recoveredMapping.lastEntry(producerId)
-    assertTrue(lastEntry.isDefined)
+    assertTrue(lastEntry.isPresent())
     assertEquals(appendTimestamp, lastEntry.get.lastTimestamp)
     assertEquals(OptionalLong.empty(), lastEntry.get.currentTxnFirstOffset)
   }
@@ -648,7 +649,7 @@ class ProducerStateManagerTest {
 
     // The producer should not be expired because we want to preserve fencing epochs
     stateManager.removeExpiredProducers(time.milliseconds())
-    assertTrue(stateManager.lastEntry(producerId).isDefined)
+    assertTrue(stateManager.lastEntry(producerId).isPresent())
   }
 
   @Test
@@ -667,7 +668,7 @@ class ProducerStateManagerTest {
     append(recoveredMapping, producerId, epoch, 2, 2L, 70001)
 
     assertEquals(1, recoveredMapping.activeProducers.size)
-    assertEquals(2, recoveredMapping.activeProducers.head._2.lastSeq)
+    assertEquals(2, recoveredMapping.activeProducers.values().iterator().next().lastSeq)
     assertEquals(3L, recoveredMapping.mapEndOffset)
   }
 
@@ -685,10 +686,10 @@ class ProducerStateManagerTest {
     val sequence = 2
     // entry added after recovery. The pid should be expired now, and would not exist in the pid mapping. Nonetheless
     // the append on a replica should be accepted with the local producer state updated to the appended value.
-    assertFalse(recoveredMapping.activeProducers.contains(producerId))
+    assertFalse(recoveredMapping.activeProducers.containsKey(producerId))
     append(recoveredMapping, producerId, epoch, sequence, 2L, 70001, origin = AppendOrigin.REPLICATION)
-    assertTrue(recoveredMapping.activeProducers.contains(producerId))
-    val producerStateEntry = recoveredMapping.activeProducers.get(producerId).head
+    assertTrue(recoveredMapping.activeProducers.containsKey(producerId))
+    val producerStateEntry = recoveredMapping.activeProducers.get(producerId)
     assertEquals(epoch, producerStateEntry.producerEpoch)
     assertEquals(sequence, producerStateEntry.firstSeq)
     assertEquals(sequence, producerStateEntry.lastSeq)
@@ -703,9 +704,9 @@ class ProducerStateManagerTest {
     // First we ensure that we raise an OutOfOrderSequenceException is raised when the append comes from a client.
     assertThrows(classOf[OutOfOrderSequenceException], () => append(stateManager, producerId, epoch, outOfOrderSequence, 1L, 1, origin = AppendOrigin.CLIENT))
 
-    assertEquals(0L, stateManager.activeProducers(producerId).lastSeq)
+    assertEquals(0L, stateManager.activeProducers.get(producerId).lastSeq)
     append(stateManager, producerId, epoch, outOfOrderSequence, 1L, 1, origin = AppendOrigin.REPLICATION)
-    assertEquals(outOfOrderSequence, stateManager.activeProducers(producerId).lastSeq)
+    assertEquals(outOfOrderSequence, stateManager.activeProducers.get(producerId).lastSeq)
   }
 
   @Test
@@ -782,7 +783,7 @@ class ProducerStateManagerTest {
     // It loads the earlier written snapshot files from log dir.
     stateManager.truncateFullyAndReloadSnapshots()
 
-    assertEquals(Some(3), stateManager.latestSnapshotOffset)
+    assertEquals(OptionalLong.of(3), stateManager.latestSnapshotOffset)
     assertEquals(Set(3), currentSnapshotOffsets)
   }
 
@@ -792,20 +793,20 @@ class ProducerStateManagerTest {
     val sequence = 0
 
     append(stateManager, producerId, epoch, sequence, offset = 99, isTransactional = true)
-    assertEquals(Some(99), stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(Optional.of(99L), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
     stateManager.takeSnapshot()
 
     appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.COMMIT, offset = 105)
     stateManager.onHighWatermarkUpdated(106)
-    assertEquals(None, stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(Optional.empty(), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
     stateManager.takeSnapshot()
 
     append(stateManager, producerId, epoch, sequence + 1, offset = 106)
     stateManager.truncateAndReload(0L, 106, time.milliseconds())
-    assertEquals(None, stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(Optional.empty(), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
 
     stateManager.truncateAndReload(0L, 100L, time.milliseconds())
-    assertEquals(Some(99), stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(Optional.of(99L), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
   }
 
   @Test
@@ -823,12 +824,12 @@ class ProducerStateManagerTest {
     assertEquals(2, stateManager.activeProducers.size)
 
     val entry1 = stateManager.lastEntry(pid1)
-    assertTrue(entry1.isDefined)
+    assertTrue(entry1.isPresent)
     assertEquals(0, entry1.get.lastSeq)
     assertEquals(0L, entry1.get.lastDataOffset)
 
     val entry2 = stateManager.lastEntry(pid2)
-    assertTrue(entry2.isDefined)
+    assertTrue(entry2.isPresent)
     assertEquals(0, entry2.get.lastSeq)
     assertEquals(1L, entry2.get.lastDataOffset)
   }
@@ -857,7 +858,7 @@ class ProducerStateManagerTest {
     stateManager.removeExpiredProducers(time.milliseconds)
     append(stateManager, producerId, epoch, sequence + 1, 2L)
     assertEquals(1, stateManager.activeProducers.size)
-    assertEquals(sequence + 1, stateManager.activeProducers.head._2.lastSeq)
+    assertEquals(sequence + 1, stateManager.activeProducers.values().iterator().next().lastSeq)
     assertEquals(3L, stateManager.mapEndOffset)
   }
 
@@ -866,33 +867,33 @@ class ProducerStateManagerTest {
     val epoch = 5.toShort
     val sequence = 0
 
-    assertEquals(None, stateManager.firstUndecidedOffset)
+    assertEquals(OptionalLong.empty(), stateManager.firstUndecidedOffset)
 
     append(stateManager, producerId, epoch, sequence, offset = 99, isTransactional = true)
-    assertEquals(Some(99L), stateManager.firstUndecidedOffset)
-    assertEquals(Some(99L), stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(OptionalLong.of(99L), stateManager.firstUndecidedOffset)
+    assertEquals(Optional.of(99L), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
 
     val anotherPid = 2L
     append(stateManager, anotherPid, epoch, sequence, offset = 105, isTransactional = true)
-    assertEquals(Some(99L), stateManager.firstUndecidedOffset)
-    assertEquals(Some(99L), stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(OptionalLong.of(99L), stateManager.firstUndecidedOffset)
+    assertEquals(Optional.of(99L), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
 
     appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.COMMIT, offset = 109)
-    assertEquals(Some(105L), stateManager.firstUndecidedOffset)
-    assertEquals(Some(99L), stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(OptionalLong.of(105L), stateManager.firstUndecidedOffset)
+    assertEquals(Optional.of(99L), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
 
     stateManager.onHighWatermarkUpdated(100L)
-    assertEquals(Some(99L), stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(Optional.of(99L), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
 
     stateManager.onHighWatermarkUpdated(110L)
-    assertEquals(Some(105L), stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(Optional.of(105L), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
 
     appendEndTxnMarker(stateManager, anotherPid, epoch, ControlRecordType.ABORT, offset = 112)
-    assertEquals(None, stateManager.firstUndecidedOffset)
-    assertEquals(Some(105L), stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(OptionalLong.empty(), stateManager.firstUndecidedOffset)
+    assertEquals(Optional.of(105L), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
 
     stateManager.onHighWatermarkUpdated(113L)
-    assertEquals(None, stateManager.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(Optional.empty(), stateManager.firstUnstableOffset.map[Long](x => x.messageOffset))
   }
 
   @Test
@@ -901,16 +902,16 @@ class ProducerStateManagerTest {
     val sequence = 0
 
     append(stateManager, producerId, epoch, sequence, offset = 99, isTransactional = true)
-    assertEquals(Some(99L), stateManager.firstUndecidedOffset)
+    assertEquals(OptionalLong.of(99L), stateManager.firstUndecidedOffset)
 
     time.sleep(producerStateManagerConfig.producerIdExpirationMs + 1)
     stateManager.removeExpiredProducers(time.milliseconds)
 
-    assertTrue(stateManager.lastEntry(producerId).isDefined)
-    assertEquals(Some(99L), stateManager.firstUndecidedOffset)
+    assertTrue(stateManager.lastEntry(producerId).isPresent())
+    assertEquals(OptionalLong.of(99L), stateManager.firstUndecidedOffset)
 
     stateManager.removeExpiredProducers(time.milliseconds)
-    assertTrue(stateManager.lastEntry(producerId).isDefined)
+    assertTrue(stateManager.lastEntry(producerId).isPresent)
   }
 
   @Test
@@ -931,7 +932,7 @@ class ProducerStateManagerTest {
     val epoch = 5.toShort
     val sequence = 0
 
-    assertEquals(None, stateManager.firstUndecidedOffset)
+    assertEquals(OptionalLong.empty(), stateManager.firstUndecidedOffset)
 
     append(stateManager, producerId, epoch, sequence, offset = 99, isTransactional = true)
     assertThrows(classOf[InvalidProducerEpochException], () => appendEndTxnMarker(stateManager, producerId, 3.toShort,
@@ -947,7 +948,7 @@ class ProducerStateManagerTest {
     appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.COMMIT, offset = 100, coordinatorEpoch = 1)
 
     val lastEntry = stateManager.lastEntry(producerId)
-    assertEquals(Some(1), lastEntry.map(_.coordinatorEpoch))
+    assertEquals(Optional.of(1), lastEntry.map[Int](x => x.coordinatorEpoch))
 
     // writing with the current epoch is allowed
     appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.COMMIT, offset = 101, coordinatorEpoch = 1)
@@ -1020,18 +1021,18 @@ class ProducerStateManagerTest {
     // the broker shutdown cleanly and emitted a snapshot file larger than the base offset of the active segment.
 
     // Create 3 snapshot files at different offsets.
-    Files.createFile(UnifiedLog.producerSnapshotFile(logDir, 5).toPath) // not stray
-    Files.createFile(UnifiedLog.producerSnapshotFile(logDir, 2).toPath) // stray
-    Files.createFile(UnifiedLog.producerSnapshotFile(logDir, 42).toPath) // not stray
+    Files.createFile(LogFileUtils.producerSnapshotFile(logDir, 5).toPath) // not stray
+    Files.createFile(LogFileUtils.producerSnapshotFile(logDir, 2).toPath) // stray
+    Files.createFile(LogFileUtils.producerSnapshotFile(logDir, 42).toPath) // not stray
 
     // claim that we only have one segment with a base offset of 5
-    stateManager.removeStraySnapshots(Seq(5))
+    stateManager.removeStraySnapshots(Collections.singletonList(5))
 
     // The snapshot file at offset 2 should be considered a stray, but the snapshot at 42 should be kept
     // around because it is the largest snapshot.
-    assertEquals(Some(42), stateManager.latestSnapshotOffset)
-    assertEquals(Some(5), stateManager.oldestSnapshotOffset)
-    assertEquals(Seq(5, 42), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
+    assertEquals(OptionalLong.of(42), stateManager.latestSnapshotOffset)
+    assertEquals(OptionalLong.of(5), stateManager.oldestSnapshotOffset)
+    assertEquals(Seq(5L, 42L), ProducerStateManager.listSnapshotFiles(logDir).asScala.map(_.offset).sorted)
   }
 
   @Test
@@ -1040,12 +1041,13 @@ class ProducerStateManagerTest {
     // Snapshots associated with an offset in the list of segment base offsets should remain.
 
     // Create 3 snapshot files at different offsets.
-    Files.createFile(UnifiedLog.producerSnapshotFile(logDir, 5).toPath) // stray
-    Files.createFile(UnifiedLog.producerSnapshotFile(logDir, 2).toPath) // stray
-    Files.createFile(UnifiedLog.producerSnapshotFile(logDir, 42).toPath) // not stray
+    Files.createFile(LogFileUtils.producerSnapshotFile(logDir, 5).toPath) // stray
+    Files.createFile(LogFileUtils.producerSnapshotFile(logDir, 2).toPath) // stray
+    Files.createFile(LogFileUtils.producerSnapshotFile(logDir, 42).toPath) // not stray
+
+    stateManager.removeStraySnapshots(Collections.singletonList(42))
+    assertEquals(Seq(42L), ProducerStateManager.listSnapshotFiles(logDir).asScala.map(_.offset).sorted)
 
-    stateManager.removeStraySnapshots(Seq(42))
-    assertEquals(Seq(42), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted)
   }
 
   /**
@@ -1054,12 +1056,12 @@ class ProducerStateManagerTest {
    */
   @Test
   def testRemoveAndMarkSnapshotForDeletion(): Unit = {
-    Files.createFile(UnifiedLog.producerSnapshotFile(logDir, 5).toPath)
+    Files.createFile(LogFileUtils.producerSnapshotFile(logDir, 5).toPath)
     val manager = new ProducerStateManager(partition, logDir, maxTransactionTimeoutMs, producerStateManagerConfig, time)
-    assertTrue(manager.latestSnapshotOffset.isDefined)
+    assertTrue(manager.latestSnapshotOffset.isPresent)
     val snapshot = manager.removeAndMarkSnapshotForDeletion(5).get
-    assertTrue(snapshot.file.toPath.toString.endsWith(UnifiedLog.DeletedFileSuffix))
-    assertTrue(manager.latestSnapshotOffset.isEmpty)
+    assertTrue(snapshot.file.toPath.toString.endsWith(LogFileUtils.DELETED_FILE_SUFFIX))
+    assertFalse(manager.latestSnapshotOffset.isPresent)
   }
 
   /**
@@ -1071,13 +1073,13 @@ class ProducerStateManagerTest {
    */
   @Test
   def testRemoveAndMarkSnapshotForDeletionAlreadyDeleted(): Unit = {
-    val file = UnifiedLog.producerSnapshotFile(logDir, 5)
+    val file = LogFileUtils.producerSnapshotFile(logDir, 5)
     Files.createFile(file.toPath)
     val manager = new ProducerStateManager(partition, logDir, maxTransactionTimeoutMs, producerStateManagerConfig, time)
-    assertTrue(manager.latestSnapshotOffset.isDefined)
+    assertTrue(manager.latestSnapshotOffset.isPresent)
     Files.delete(file.toPath)
-    assertTrue(manager.removeAndMarkSnapshotForDeletion(5).isEmpty)
-    assertTrue(manager.latestSnapshotOffset.isEmpty)
+    assertFalse(manager.removeAndMarkSnapshotForDeletion(5).isPresent)
+    assertFalse(manager.latestSnapshotOffset.isPresent)
   }
 
   private def testLoadFromCorruptSnapshot(makeFileCorrupt: FileChannel => Unit): Unit = {
@@ -1092,8 +1094,8 @@ class ProducerStateManagerTest {
 
     // Truncate the last snapshot
     val latestSnapshotOffset = stateManager.latestSnapshotOffset
-    assertEquals(Some(2L), latestSnapshotOffset)
-    val snapshotToTruncate = UnifiedLog.producerSnapshotFile(logDir, latestSnapshotOffset.get)
+    assertEquals(OptionalLong.of(2L), latestSnapshotOffset)
+    val snapshotToTruncate = LogFileUtils.producerSnapshotFile(logDir, latestSnapshotOffset.getAsLong)
     val channel = FileChannel.open(snapshotToTruncate.toPath, StandardOpenOption.WRITE)
     try {
       makeFileCorrupt(channel)
@@ -1107,7 +1109,7 @@ class ProducerStateManagerTest {
     reloadedStateManager.truncateAndReload(0L, 20L, time.milliseconds())
     assertFalse(snapshotToTruncate.exists())
 
-    val loadedProducerState = reloadedStateManager.activeProducers(producerId)
+    val loadedProducerState = reloadedStateManager.activeProducers.get(producerId)
     assertEquals(0L, loadedProducerState.lastDataOffset)
   }
 
@@ -1118,7 +1120,7 @@ class ProducerStateManagerTest {
                                  offset: Long,
                                  coordinatorEpoch: Int = 0,
                                  timestamp: Long = time.milliseconds()): Option[CompletedTxn] = {
-    val producerAppendInfo = stateManager.prepareUpdate(producerId, origin = AppendOrigin.COORDINATOR)
+    val producerAppendInfo = stateManager.prepareUpdate(producerId, AppendOrigin.COORDINATOR)
     val endTxnMarker = new EndTransactionMarker(controlType, coordinatorEpoch)
     val completedTxnOpt = producerAppendInfo.appendEndTxnMarker(endTxnMarker, producerEpoch, offset, timestamp).asScala
     mapping.update(producerAppendInfo)
diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
index 857a1218328..5be226eaf0c 100755
--- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
@@ -36,7 +36,7 @@ import org.apache.kafka.server.metrics.KafkaYammerMetrics
 import org.apache.kafka.server.util.{KafkaScheduler, Scheduler}
 import org.apache.kafka.storage.internals.checkpoint.LeaderEpochCheckpointFile
 import org.apache.kafka.storage.internals.epoch.LeaderEpochFileCache
-import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, EpochEntry, FetchIsolation, LogConfig, LogOffsetMetadata, RecordValidationException}
+import org.apache.kafka.storage.internals.log.{AbortedTxn, AppendOrigin, EpochEntry, FetchIsolation, LogConfig, LogFileUtils, LogOffsetMetadata, ProducerStateManager, ProducerStateManagerConfig, RecordValidationException}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.ArgumentMatchers
@@ -47,9 +47,8 @@ import java.io._
 import java.nio.ByteBuffer
 import java.nio.file.Files
 import java.util.concurrent.{Callable, ConcurrentHashMap, Executors}
-import java.util.{Optional, Properties}
+import java.util.{Optional, OptionalLong, Properties}
 import scala.annotation.nowarn
-import scala.collection.Map
 import scala.collection.mutable.ListBuffer
 import scala.compat.java8.OptionConverters._
 import scala.jdk.CollectionConverters._
@@ -180,9 +179,7 @@ class UnifiedLogTest {
     testTruncateBelowFirstUnstableOffset(_.truncateFullyAndStartAt)
   }
 
-  private def testTruncateBelowFirstUnstableOffset(
-    truncateFunc: UnifiedLog => (Long => Unit)
-  ): Unit = {
+  private def testTruncateBelowFirstUnstableOffset(truncateFunc: UnifiedLog => (Long => Unit)): Unit = {
     // Verify that truncation below the first unstable offset correctly
     // resets the producer state. Specifically we are testing the case when
     // the segment position of the first unstable offset is unknown.
@@ -216,11 +213,11 @@ class UnifiedLogTest {
     log.close()
 
     val reopened = createLog(logDir, logConfig)
-    assertEquals(Some(new LogOffsetMetadata(3L)), reopened.producerStateManager.firstUnstableOffset)
+    assertEquals(Optional.of(new LogOffsetMetadata(3L)), reopened.producerStateManager.firstUnstableOffset)
 
     truncateFunc(reopened)(0L)
     assertEquals(None, reopened.firstUnstableOffset)
-    assertEquals(Map.empty, reopened.producerStateManager.activeProducers)
+    assertEquals(java.util.Collections.emptyMap(), reopened.producerStateManager.activeProducers)
   }
 
   @Test
@@ -473,7 +470,7 @@ class UnifiedLogTest {
   @Test
   def testOffsetFromProducerSnapshotFile(): Unit = {
     val offset = 23423423L
-    val snapshotFile = UnifiedLog.producerSnapshotFile(tmpDir, offset)
+    val snapshotFile = LogFileUtils.producerSnapshotFile(tmpDir, offset)
     assertEquals(offset, UnifiedLog.offsetFromFile(snapshotFile))
   }
 
@@ -664,7 +661,7 @@ class UnifiedLogTest {
     // snapshot files, and then reloading the log
     val logConfig = LogTestUtils.createLogConfig(segmentBytes = 64 * 10)
     var log = createLog(logDir, logConfig)
-    assertEquals(None, log.oldestProducerSnapshotOffset)
+    assertEquals(OptionalLong.empty(), log.oldestProducerSnapshotOffset)
 
     for (i <- 0 to 100) {
       val record = new SimpleRecord(mockTime.milliseconds, i.toString.getBytes)
@@ -738,7 +735,7 @@ class UnifiedLogTest {
     val records = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes)))
     log.appendAsLeader(records, leaderEpoch = 0)
     log.takeProducerSnapshot()
-    assertEquals(Some(1), log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(1), log.latestProducerSnapshotOffset)
   }
 
   @Test
@@ -873,15 +870,15 @@ class UnifiedLogTest {
     log.takeProducerSnapshot()
 
     log.truncateTo(2)
-    assertEquals(Some(2), log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(2), log.latestProducerSnapshotOffset)
     assertEquals(2, log.latestProducerStateEndOffset)
 
     log.truncateTo(1)
-    assertEquals(Some(1), log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(1), log.latestProducerSnapshotOffset)
     assertEquals(1, log.latestProducerStateEndOffset)
 
     log.truncateTo(0)
-    assertEquals(None, log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.empty(), log.latestProducerSnapshotOffset)
     assertEquals(0, log.latestProducerStateEndOffset)
   }
 
@@ -1010,7 +1007,7 @@ class UnifiedLogTest {
     log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "c".getBytes())), producerId = pid1,
       producerEpoch = epoch, sequence = 2), leaderEpoch = 0)
     log.updateHighWatermark(log.logEndOffset)
-    assertEquals(log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted,
+    assertEquals(log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).asScala.map(_.offset).sorted,
       "expected a snapshot file per segment base offset, except the first segment")
     assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size)
 
@@ -1020,7 +1017,7 @@ class UnifiedLogTest {
     log.deleteOldSegments()
     // Sleep to breach the file delete delay and run scheduled file deletion tasks
     mockTime.sleep(1)
-    assertEquals(log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted,
+    assertEquals(log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).asScala.map(_.offset).sorted,
       "expected a snapshot file per segment base offset, excluding the first")
   }
 
@@ -1030,7 +1027,7 @@ class UnifiedLogTest {
    */
   @Test
   def testLoadingLogDeletesProducerStateSnapshotsPastLogEndOffset(): Unit = {
-    val straySnapshotFile = UnifiedLog.producerSnapshotFile(logDir, 42).toPath
+    val straySnapshotFile = LogFileUtils.producerSnapshotFile(logDir, 42).toPath
     Files.createFile(straySnapshotFile)
     val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = -1, fileDeleteDelayMs = 0)
     createLog(logDir, logConfig)
@@ -1052,11 +1049,11 @@ class UnifiedLogTest {
 
     assertEquals(3, log.logSegments.size)
     assertEquals(3, log.latestProducerStateEndOffset)
-    assertEquals(Some(3), log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(3), log.latestProducerSnapshotOffset)
 
     log.truncateFullyAndStartAt(29)
     assertEquals(1, log.logSegments.size)
-    assertEquals(None, log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.empty(), log.latestProducerSnapshotOffset)
     assertEquals(29, log.latestProducerStateEndOffset)
   }
 
@@ -1093,25 +1090,25 @@ class UnifiedLogTest {
     val log = createLog(logDir, logConfig)
     log.appendAsLeader(TestUtils.singletonRecords("a".getBytes), leaderEpoch = 0)
     log.roll(Some(1L))
-    assertEquals(Some(1L), log.latestProducerSnapshotOffset)
-    assertEquals(Some(1L), log.oldestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(1L), log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(1L), log.oldestProducerSnapshotOffset)
 
     log.appendAsLeader(TestUtils.singletonRecords("b".getBytes), leaderEpoch = 0)
     log.roll(Some(2L))
-    assertEquals(Some(2L), log.latestProducerSnapshotOffset)
-    assertEquals(Some(1L), log.oldestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(2L), log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(1L), log.oldestProducerSnapshotOffset)
 
     log.appendAsLeader(TestUtils.singletonRecords("c".getBytes), leaderEpoch = 0)
     log.roll(Some(3L))
-    assertEquals(Some(3L), log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(3L), log.latestProducerSnapshotOffset)
 
     // roll triggers a flush at the starting offset of the new segment, we should retain all snapshots
-    assertEquals(Some(1L), log.oldestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(1L), log.oldestProducerSnapshotOffset)
 
     // even if we flush within the active segment, the snapshot should remain
     log.appendAsLeader(TestUtils.singletonRecords("baz".getBytes), leaderEpoch = 0)
     log.flushUptoOffsetExclusive(4L)
-    assertEquals(Some(3L), log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(3L), log.latestProducerSnapshotOffset)
   }
 
   @Test
@@ -1131,7 +1128,7 @@ class UnifiedLogTest {
 
     assertEquals(2, log.logSegments.size)
     assertEquals(1L, log.activeSegment.baseOffset)
-    assertEquals(Some(1L), log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(1L), log.latestProducerSnapshotOffset)
 
     // Force a reload from the snapshot to check its consistency
     log.truncateTo(1L)
@@ -1139,10 +1136,10 @@ class UnifiedLogTest {
     assertEquals(2, log.logSegments.size)
     assertEquals(1L, log.activeSegment.baseOffset)
     assertTrue(log.activeSegment.log.batches.asScala.isEmpty)
-    assertEquals(Some(1L), log.latestProducerSnapshotOffset)
+    assertEquals(OptionalLong.of(1L), log.latestProducerSnapshotOffset)
 
     val lastEntry = log.producerStateManager.lastEntry(producerId)
-    assertTrue(lastEntry.isDefined)
+    assertTrue(lastEntry.isPresent)
     assertEquals(0L, lastEntry.get.firstDataOffset)
     assertEquals(0L, lastEntry.get.lastDataOffset)
   }
@@ -2204,8 +2201,8 @@ class UnifiedLogTest {
     log.deleteOldSegments()
 
     assertEquals(1, log.numberOfSegments, "Only one segment should remain.")
-    assertTrue(segments.forall(_.log.file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) &&
-      segments.forall(_.lazyOffsetIndex.file.getName.endsWith(UnifiedLog.DeletedFileSuffix)),
+    assertTrue(segments.forall(_.log.file.getName.endsWith(LogFileUtils.DELETED_FILE_SUFFIX)) &&
+      segments.forall(_.lazyOffsetIndex.file.getName.endsWith(LogFileUtils.DELETED_FILE_SUFFIX)),
       "All log and index files should end in .deleted")
     assertTrue(segments.forall(_.log.file.exists) && segments.forall(_.lazyOffsetIndex.file.exists),
       "The .deleted files should still be there.")
@@ -3089,7 +3086,7 @@ class UnifiedLogTest {
   }
 
   private def assertCachedFirstUnstableOffset(log: UnifiedLog, expectedOffset: Long): Unit = {
-    assertTrue(log.producerStateManager.firstUnstableOffset.isDefined)
+    assertTrue(log.producerStateManager.firstUnstableOffset.isPresent)
     val firstUnstableOffset = log.producerStateManager.firstUnstableOffset.get
     assertEquals(expectedOffset, firstUnstableOffset.messageOffset)
     assertFalse(firstUnstableOffset.messageOffsetOnly)
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index b145ad592a5..42812c9f6a3 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -58,7 +58,7 @@ import org.apache.kafka.metadata.LeaderConstants.NO_LEADER
 import org.apache.kafka.metadata.LeaderRecoveryState
 import org.apache.kafka.server.common.MetadataVersion.IBP_2_6_IV0
 import org.apache.kafka.server.util.MockScheduler
-import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchIsolation, FetchParams, FetchPartitionData, LogConfig, LogDirFailureChannel, LogOffsetMetadata}
+import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchIsolation, FetchParams, FetchPartitionData, LogConfig, LogDirFailureChannel, LogOffsetMetadata, ProducerStateManager, ProducerStateManagerConfig}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.junit.jupiter.params.ParameterizedTest
@@ -506,7 +506,7 @@ class ReplicaManagerTest {
       assertLateTransactionCount(Some(0))
 
       // The transaction becomes late if not finished before the max transaction timeout passes
-      time.sleep(replicaManager.logManager.maxTransactionTimeoutMs + ProducerStateManager.LateTransactionBufferMs)
+      time.sleep(replicaManager.logManager.maxTransactionTimeoutMs + ProducerStateManager.LATE_TRANSACTION_BUFFER_MS)
       assertLateTransactionCount(Some(0))
       time.sleep(1)
       assertLateTransactionCount(Some(1))
diff --git a/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
index 94b4aad8899..ea6d591cf33 100644
--- a/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
+++ b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
@@ -21,7 +21,7 @@ import java.io.{ByteArrayOutputStream, File, PrintWriter}
 import java.nio.ByteBuffer
 import java.util
 import java.util.Properties
-import kafka.log.{LogTestUtils, ProducerStateManagerConfig, UnifiedLog}
+import kafka.log.{LogTestUtils, UnifiedLog}
 import kafka.raft.{KafkaMetadataLog, MetadataLogConfig}
 import kafka.server.{BrokerTopicStats, KafkaRaftServer}
 import kafka.tools.DumpLogSegments.TimeIndexDumpErrors
@@ -37,7 +37,7 @@ import org.apache.kafka.metadata.MetadataRecordSerde
 import org.apache.kafka.raft.{KafkaRaftClient, OffsetAndEpoch}
 import org.apache.kafka.server.common.ApiMessageAndVersion
 import org.apache.kafka.snapshot.RecordsSnapshotWriter
-import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchIsolation, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchIsolation, LogConfig, LogDirFailureChannel, ProducerStateManagerConfig}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 
diff --git a/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala b/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
index 2d152a2df85..73bd36b4ecf 100644
--- a/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
+++ b/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala
@@ -19,11 +19,11 @@ package kafka.utils
 import java.util.Properties
 import java.util.concurrent.atomic._
 import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
-import kafka.log.{LocalLog, LogLoader, LogSegments, ProducerStateManager, ProducerStateManagerConfig, UnifiedLog}
+import kafka.log.{LocalLog, LogLoader, LogSegments, UnifiedLog}
 import kafka.server.BrokerTopicStats
 import kafka.utils.TestUtils.retry
 import org.apache.kafka.server.util.KafkaScheduler
-import org.apache.kafka.storage.internals.log.{LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{LogConfig, LogDirFailureChannel, ProducerStateManager, ProducerStateManagerConfig}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, Timeout}
 
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index 78371b78226..2c9c4ae6690 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -71,7 +71,7 @@ import org.apache.kafka.controller.QuorumController
 import org.apache.kafka.server.authorizer.{AuthorizableRequestContext, Authorizer => JAuthorizer}
 import org.apache.kafka.server.common.MetadataVersion
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
-import org.apache.kafka.storage.internals.log.{CleanerConfig, LogConfig, LogDirFailureChannel}
+import org.apache.kafka.storage.internals.log.{CleanerConfig, LogConfig, LogDirFailureChannel, ProducerStateManagerConfig}
 import org.apache.kafka.test.{TestSslUtils, TestUtils => JTestUtils}
 import org.apache.zookeeper.KeeperException.SessionExpiredException
 import org.apache.zookeeper.ZooDefs._
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/log/LogFileUtils.java b/storage/src/main/java/org/apache/kafka/storage/internals/log/LogFileUtils.java
index 3b8b41f91fd..c6d191624cd 100644
--- a/storage/src/main/java/org/apache/kafka/storage/internals/log/LogFileUtils.java
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/log/LogFileUtils.java
@@ -16,8 +16,21 @@
  */
 package org.apache.kafka.storage.internals.log;
 
+import java.io.File;
+import java.text.NumberFormat;
+
 public final class LogFileUtils {
 
+    /**
+     * Suffix of a producer snapshot file
+     */
+    public static final String PRODUCER_SNAPSHOT_FILE_SUFFIX = ".snapshot";
+
+    /**
+     * Suffix for a file that is scheduled to be deleted
+     */
+    public static final String DELETED_FILE_SUFFIX = ".deleted";
+
     private LogFileUtils() {
     }
 
@@ -32,4 +45,31 @@ public final class LogFileUtils {
         return Long.parseLong(fileName.substring(0, fileName.indexOf('.')));
     }
 
+    /**
+     * Returns a File instance for the producer snapshot at the given offset, with logDir as its parent directory.
+     *
+     * @param logDir The directory in which the log will reside
+     * @param offset The last offset (exclusive) included in the snapshot
+     * @return a File instance for producer snapshot.
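+     *         For example, offset 5 yields a file named "00000000000000000005.snapshot" in logDir.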
+     */
+    public static File producerSnapshotFile(File logDir, long offset) {
+        return new File(logDir, filenamePrefixFromOffset(offset) + PRODUCER_SNAPSHOT_FILE_SUFFIX);
+    }
+
+    /**
+     * Make log segment file name from offset bytes. All this does is pad out the offset number with zeros
+     * so that ls sorts the files numerically.
+     *
+     * @param offset The offset to use in the file name
+     * @return The filename
+     */
+    private static String filenamePrefixFromOffset(long offset) {
+        NumberFormat nf = NumberFormat.getInstance();
+        nf.setMinimumIntegerDigits(20);
+        nf.setMaximumFractionDigits(0);
+        nf.setGroupingUsed(false);
+        return nf.format(offset);
+    }
+
 }
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManager.java b/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManager.java
new file mode 100644
index 00000000000..efa1a6e63dc
--- /dev/null
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManager.java
@@ -0,0 +1,678 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.storage.internals.log;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.protocol.types.ArrayOf;
+import org.apache.kafka.common.protocol.types.Field;
+import org.apache.kafka.common.protocol.types.Schema;
+import org.apache.kafka.common.protocol.types.SchemaException;
+import org.apache.kafka.common.protocol.types.Struct;
+import org.apache.kafka.common.protocol.types.Type;
+import org.apache.kafka.common.record.RecordBatch;
+import org.apache.kafka.common.utils.ByteUtils;
+import org.apache.kafka.common.utils.Crc32C;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.slf4j.Logger;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.NoSuchFileException;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.OptionalLong;
+import java.util.TreeMap;
+import java.util.concurrent.ConcurrentSkipListMap;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * Maintains a mapping from ProducerIds to metadata about the last appended entries (e.g.
+ * epoch, sequence number, last offset, etc.)
+ * <p>
+ * The sequence number is the last number successfully appended to the partition for the given identifier.
+ * The epoch is used for fencing against zombie writers. The offset is the one of the last successful message
+ * appended to the partition.
+ * <p>
+ * As long as a producer id is contained in the map, the corresponding producer can continue to write data.
+ * However, producer ids can be expired due to lack of recent use or if the last written entry has been deleted from
+ * the log (e.g. if the retention policy is "delete"). For compacted topics, the log cleaner will ensure
+ * that the most recent entry from a given producer id is retained in the log provided it hasn't expired due to
+ * age. This ensures that producer ids will not be expired until either the max expiration time has been reached,
+ * or if the topic also is configured for deletion, the segment containing the last written offset has
+ * been deleted.
+ */
+public class ProducerStateManager {
+
+    public static final long LATE_TRANSACTION_BUFFER_MS = 5 * 60 * 1000;
+
+    private static final short PRODUCER_SNAPSHOT_VERSION = 1;
+    private static final String VERSION_FIELD = "version";
+    private static final String CRC_FIELD = "crc";
+    private static final String PRODUCER_ID_FIELD = "producer_id";
+    private static final String LAST_SEQUENCE_FIELD = "last_sequence";
+    private static final String PRODUCER_EPOCH_FIELD = "epoch";
+    private static final String LAST_OFFSET_FIELD = "last_offset";
+    private static final String OFFSET_DELTA_FIELD = "offset_delta";
+    private static final String TIMESTAMP_FIELD = "timestamp";
+    private static final String PRODUCER_ENTRIES_FIELD = "producer_entries";
+    private static final String COORDINATOR_EPOCH_FIELD = "coordinator_epoch";
+    private static final String CURRENT_TXN_FIRST_OFFSET_FIELD = "current_txn_first_offset";
+
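+    // Snapshot file layout: a 2-byte version, a 4-byte CRC computed over the producer entries section, followed by the entries themselves.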
+    private static final int VERSION_OFFSET = 0;
+    private static final int CRC_OFFSET = VERSION_OFFSET + 2;
+    private static final int PRODUCER_ENTRIES_OFFSET = CRC_OFFSET + 4;
+
+    private static final Schema PRODUCER_SNAPSHOT_ENTRY_SCHEMA =
+            new Schema(new Field(PRODUCER_ID_FIELD, Type.INT64, "The producer ID"),
+                    new Field(PRODUCER_EPOCH_FIELD, Type.INT16, "Current epoch of the producer"),
+                    new Field(LAST_SEQUENCE_FIELD, Type.INT32, "Last written sequence of the producer"),
+                    new Field(LAST_OFFSET_FIELD, Type.INT64, "Last written offset of the producer"),
+                    new Field(OFFSET_DELTA_FIELD, Type.INT32, "The difference of the last sequence and first sequence in the last written batch"),
+                    new Field(TIMESTAMP_FIELD, Type.INT64, "Max timestamp from the last written entry"),
+                    new Field(COORDINATOR_EPOCH_FIELD, Type.INT32, "The epoch of the last transaction coordinator to send an end transaction marker"),
+                    new Field(CURRENT_TXN_FIRST_OFFSET_FIELD, Type.INT64, "The first offset of the on-going transaction (-1 if there is none)"));
+    private static final Schema PID_SNAPSHOT_MAP_SCHEMA =
+            new Schema(new Field(VERSION_FIELD, Type.INT16, "Version of the snapshot file"),
+                    new Field(CRC_FIELD, Type.UNSIGNED_INT32, "CRC of the snapshot data"),
+                    new Field(PRODUCER_ENTRIES_FIELD, new ArrayOf(PRODUCER_SNAPSHOT_ENTRY_SCHEMA), "The entries in the producer table"));
+
+    private final Logger log;
+
+    private final TopicPartition topicPartition;
+    private final int maxTransactionTimeoutMs;
+    private final ProducerStateManagerConfig producerStateManagerConfig;
+    private final Time time;
+
+    private final Map<Long, ProducerStateEntry> producers = new HashMap<>();
+
+    // ongoing transactions sorted by the first offset of the transaction
+    private final TreeMap<Long, TxnMetadata> ongoingTxns = new TreeMap<>();
+
+    // completed transactions whose markers are at offsets above the high watermark
+    private final TreeMap<Long, TxnMetadata> unreplicatedTxns = new TreeMap<>();
+
+    private volatile File logDir;
+
+    // Keep track of the last timestamp from the oldest transaction. This is used
+    // to detect (approximately) when a transaction has been left hanging on a partition.
+    // We make the field volatile so that it can be safely accessed without a lock.
+    private volatile long oldestTxnLastTimestamp = -1L;
+
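+    // Snapshot files on disk, keyed by the end offset encoded in each file name.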
+    private ConcurrentSkipListMap<Long, SnapshotFile> snapshots;
+    private long lastMapOffset = 0L;
+    private long lastSnapOffset = 0L;
+
+    public ProducerStateManager(TopicPartition topicPartition, File logDir, int maxTransactionTimeoutMs, ProducerStateManagerConfig producerStateManagerConfig, Time time) throws IOException {
+        this.topicPartition = topicPartition;
+        this.logDir = logDir;
+        this.maxTransactionTimeoutMs = maxTransactionTimeoutMs;
+        this.producerStateManagerConfig = producerStateManagerConfig;
+        this.time = time;
+        log = new LogContext("[ProducerStateManager partition=" + topicPartition + "]").logger(ProducerStateManager.class);
+        snapshots = loadSnapshots();
+    }
+
+    public int maxTransactionTimeoutMs() {
+        return maxTransactionTimeoutMs;
+    }
+
+    public ProducerStateManagerConfig producerStateManagerConfig() {
+        return producerStateManagerConfig;
+    }
+
+    /**
+     * This method checks whether there is a late transaction in a thread safe manner.
+     */
+    public boolean hasLateTransaction(long currentTimeMs) {
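+        // A transaction is considered late once it has stayed open longer than the max transaction timeout plus LATE_TRANSACTION_BUFFER_MS.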
+        long lastTimestamp = oldestTxnLastTimestamp;
+        return lastTimestamp > 0 && (currentTimeMs - lastTimestamp) > maxTransactionTimeoutMs + ProducerStateManager.LATE_TRANSACTION_BUFFER_MS;
+    }
+
+    public void truncateFullyAndReloadSnapshots() throws IOException {
+        log.info("Reloading the producer state snapshots");
+        truncateFullyAndStartAt(0L);
+        snapshots = loadSnapshots();
+    }
+
+    /**
+     * Load producer state snapshots by scanning the logDir.
+     */
+    private ConcurrentSkipListMap<Long, SnapshotFile> loadSnapshots() throws IOException {
+        ConcurrentSkipListMap<Long, SnapshotFile> offsetToSnapshots = new ConcurrentSkipListMap<>();
+        List<SnapshotFile> snapshotFiles = listSnapshotFiles(logDir);
+        for (SnapshotFile snapshotFile : snapshotFiles) {
+            offsetToSnapshots.put(snapshotFile.offset, snapshotFile);
+        }
+        return offsetToSnapshots;
+    }
+
+    /**
+     * Scans the log directory, gathering all producer state snapshot files. Snapshot files which do not have an offset
+     * corresponding to one of the provided offsets in segmentBaseOffsets will be removed, except in the case that there
+     * is a snapshot file at a higher offset than any offset in segmentBaseOffsets.
+     * <p>
+     * The goal here is to remove any snapshot files which do not have an associated segment file, but not to remove the
+     * largest stray snapshot file which was emitted during clean shutdown.
+     */
+    public void removeStraySnapshots(Collection<Long> segmentBaseOffsets) throws IOException {
+        OptionalLong maxSegmentBaseOffset = segmentBaseOffsets.isEmpty() ? OptionalLong.empty() : OptionalLong.of(segmentBaseOffsets.stream().max(Long::compare).get());
+
+        HashSet<Long> baseOffsets = new HashSet<>(segmentBaseOffsets);
+        Optional<SnapshotFile> latestStraySnapshot = Optional.empty();
+
+        ConcurrentSkipListMap<Long, SnapshotFile> snapshots = loadSnapshots();
+        for (SnapshotFile snapshot : snapshots.values()) {
+            long key = snapshot.offset;
+            if (latestStraySnapshot.isPresent()) {
+                SnapshotFile prev = latestStraySnapshot.get();
+                if (!baseOffsets.contains(key)) {
+                    // this snapshot is now the largest stray snapshot.
+                    prev.deleteIfExists();
+                    snapshots.remove(prev.offset);
+                    latestStraySnapshot = Optional.of(snapshot);
+                }
+            } else {
+                if (!baseOffsets.contains(key)) {
+                    latestStraySnapshot = Optional.of(snapshot);
+                }
+            }
+        }
+
+        // Check whether the latestStraySnapshot is at a higher offset than the largest segment base offset; if it is not,
+        // delete the latestStraySnapshot.
+        if (latestStraySnapshot.isPresent() && maxSegmentBaseOffset.isPresent()) {
+            long strayOffset = latestStraySnapshot.get().offset;
+            long maxOffset = maxSegmentBaseOffset.getAsLong();
+            if (strayOffset < maxOffset) {
+                SnapshotFile removedSnapshot = snapshots.remove(strayOffset);
+                if (removedSnapshot != null) {
+                    removedSnapshot.deleteIfExists();
+                }
+            }
+        }
+
+        this.snapshots = snapshots;
+    }
+
+    /**
+     * An unstable offset is one which is either undecided (i.e. its ultimate outcome is not yet known),
+     * or one that is decided, but may not have been replicated (i.e. any transaction which has a COMMIT/ABORT
+     * marker written at a higher offset than the current high watermark).
+     */
+    public Optional<LogOffsetMetadata> firstUnstableOffset() {
+        Optional<LogOffsetMetadata> unreplicatedFirstOffset = Optional.ofNullable(unreplicatedTxns.firstEntry()).map(e -> e.getValue().firstOffset);
+        Optional<LogOffsetMetadata> undecidedFirstOffset = Optional.ofNullable(ongoingTxns.firstEntry()).map(e -> e.getValue().firstOffset);
+
+        if (!unreplicatedFirstOffset.isPresent())
+            return undecidedFirstOffset;
+        else if (!undecidedFirstOffset.isPresent())
+            return unreplicatedFirstOffset;
+        else if (undecidedFirstOffset.get().messageOffset < unreplicatedFirstOffset.get().messageOffset)
+            return undecidedFirstOffset;
+        else
+            return unreplicatedFirstOffset;
+    }
+
+    /**
+     * Acknowledge all transactions which have been completed before a given offset. This allows the LSO
+     * to advance to the next unstable offset.
+     */
+    public void onHighWatermarkUpdated(long highWatermark) {
+        removeUnreplicatedTransactions(highWatermark);
+    }
+
+    /**
+     * The first undecided offset is the earliest transactional message which has not yet been committed
+     * or aborted. Unlike {@link #firstUnstableOffset()}, this does not reflect the state of replication (i.e.
+     * whether a completed transaction marker is beyond the high watermark).
+     */
+    public OptionalLong firstUndecidedOffset() {
+        Map.Entry<Long, TxnMetadata> firstEntry = ongoingTxns.firstEntry();
+        return firstEntry != null ? OptionalLong.of(firstEntry.getValue().firstOffset.messageOffset) : OptionalLong.empty();
+    }
+
+    /**
+     * Returns the last offset of this map
+     */
+    public long mapEndOffset() {
+        return lastMapOffset;
+    }
+
+    /**
+     * Get an unmodifiable map of active producers.
+     */
+    public Map<Long, ProducerStateEntry> activeProducers() {
+        return Collections.unmodifiableMap(producers);
+    }
+
+    public boolean isEmpty() {
+        return producers.isEmpty() && unreplicatedTxns.isEmpty();
+    }
+
+    private void loadFromSnapshot(long logStartOffset, long currentTime) throws IOException {
+        while (true) {
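+            // Try the latest snapshot first; if it is corrupt, delete it and fall back to the next most recent one.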
+            Optional<SnapshotFile> latestSnapshotFileOptional = latestSnapshotFile();
+            if (latestSnapshotFileOptional.isPresent()) {
+                SnapshotFile snapshot = latestSnapshotFileOptional.get();
+                try {
+                    log.info("Loading producer state from snapshot file '{}'", snapshot);
+                    Stream<ProducerStateEntry> loadedProducers = readSnapshot(snapshot.file()).stream().filter(producerEntry -> !isProducerExpired(currentTime, producerEntry));
+                    loadedProducers.forEach(this::loadProducerEntry);
+                    lastSnapOffset = snapshot.offset;
+                    lastMapOffset = lastSnapOffset;
+                    updateOldestTxnTimestamp();
+                    return;
+                } catch (CorruptSnapshotException e) {
+                    log.warn("Failed to load producer snapshot from '{}': {}", snapshot.file(), e.getMessage());
+                    removeAndDeleteSnapshot(snapshot.offset);
+                }
+            } else {
+                lastSnapOffset = logStartOffset;
+                lastMapOffset = logStartOffset;
+                return;
+            }
+        }
+    }
+
+    // Visible for testing
+    public void loadProducerEntry(ProducerStateEntry entry) {
+        long producerId = entry.producerId();
+        producers.put(producerId, entry);
+        entry.currentTxnFirstOffset().ifPresent(offset -> ongoingTxns.put(offset, new TxnMetadata(producerId, offset)));
+    }
+
+    private boolean isProducerExpired(long currentTimeMs, ProducerStateEntry producerState) {
+        return !producerState.currentTxnFirstOffset().isPresent() && currentTimeMs - producerState.lastTimestamp() >= producerStateManagerConfig.producerIdExpirationMs();
+    }
+
+    /**
+     * Expire any producer ids which have been idle longer than the configured maximum expiration timeout.
+     */
+    public void removeExpiredProducers(long currentTimeMs) {
+        List<Long> keys = producers.entrySet().stream()
+                .filter(entry -> isProducerExpired(currentTimeMs, entry.getValue()))
+                .map(Map.Entry::getKey)
+                .collect(Collectors.toList());
+        producers.keySet().removeAll(keys);
+    }
+
+    /**
+     * Truncate the producer id mapping to the given offset range and reload the entries from the most recent
+     * snapshot in range (if there is one). We delete snapshot files prior to the logStartOffset but do not remove
+     * producer state from the map. This means that in-memory and on-disk state can diverge, and in the case of
+     * broker failover or unclean shutdown, any in-memory state not persisted in the snapshots will be lost, which
+     * would lead to UNKNOWN_PRODUCER_ID errors. Note that the log end offset is assumed to be less than or equal
+     * to the high watermark.
+     */
+    public void truncateAndReload(long logStartOffset, long logEndOffset, long currentTimeMs) throws IOException {
+        // remove all out of range snapshots
+        for (SnapshotFile snapshot : snapshots.values()) {
+            if (snapshot.offset > logEndOffset || snapshot.offset <= logStartOffset) {
+                removeAndDeleteSnapshot(snapshot.offset);
+            }
+        }
+
+        if (logEndOffset != mapEndOffset()) {
+            producers.clear();
+            ongoingTxns.clear();
+            updateOldestTxnTimestamp();
+
+            // since we assume that the offset is less than or equal to the high watermark, it is
+            // safe to clear the unreplicated transactions
+            unreplicatedTxns.clear();
+            loadFromSnapshot(logStartOffset, currentTimeMs);
+        } else {
+            onLogStartOffsetIncremented(logStartOffset);
+        }
+    }
+
+    public ProducerAppendInfo prepareUpdate(long producerId, AppendOrigin origin) {
+        ProducerStateEntry currentEntry = lastEntry(producerId).orElse(ProducerStateEntry.empty(producerId));
+        return new ProducerAppendInfo(topicPartition, producerId, currentEntry, origin);
+    }
+
+    /**
+     * Update the mapping with the given append information
+     */
+    public void update(ProducerAppendInfo appendInfo) {
+        if (appendInfo.producerId() == RecordBatch.NO_PRODUCER_ID)
+            throw new IllegalArgumentException("Invalid producer id " + appendInfo.producerId() + " passed to update "
+                    + "for partition" + topicPartition);
+
+        log.trace("Updated producer {} state to {}", appendInfo.producerId(), appendInfo);
+        ProducerStateEntry updatedEntry = appendInfo.toEntry();
+        ProducerStateEntry currentEntry = producers.get(appendInfo.producerId());
+        if (currentEntry != null) {
+            currentEntry.update(updatedEntry);
+        } else {
+            producers.put(appendInfo.producerId(), updatedEntry);
+        }
+
+        appendInfo.startedTransactions().forEach(txn -> ongoingTxns.put(txn.firstOffset.messageOffset, txn));
+
+        updateOldestTxnTimestamp();
+    }
+
+    private void updateOldestTxnTimestamp() {
+        Map.Entry<Long, TxnMetadata> firstEntry = ongoingTxns.firstEntry();
+        if (firstEntry == null) {
+            oldestTxnLastTimestamp = -1;
+        } else {
+            TxnMetadata oldestTxnMetadata = firstEntry.getValue();
+            ProducerStateEntry entry = producers.get(oldestTxnMetadata.producerId);
+            oldestTxnLastTimestamp = entry != null ? entry.lastTimestamp() : -1L;
+        }
+    }
+
+    public void updateMapEndOffset(long lastOffset) {
+        lastMapOffset = lastOffset;
+    }
+
+    /**
+     * Get the last written entry for the given producer id.
+     */
+    public Optional<ProducerStateEntry> lastEntry(long producerId) {
+        return Optional.ofNullable(producers.get(producerId));
+    }
+
+    /**
+     * Take a snapshot at the current end offset if one does not already exist.
+     */
+    public void takeSnapshot() throws IOException {
+        // If not a new offset, then it is not worth taking another snapshot
+        if (lastMapOffset > lastSnapOffset) {
+            SnapshotFile snapshotFile = new SnapshotFile(LogFileUtils.producerSnapshotFile(logDir, lastMapOffset));
+            long start = time.hiResClockMs();
+            writeSnapshot(snapshotFile.file(), producers);
+            log.info("Wrote producer snapshot at offset {} with {} producer ids in {} ms.", lastMapOffset,
+                    producers.size(), time.hiResClockMs() - start);
+
+            snapshots.put(snapshotFile.offset, snapshotFile);
+
+            // Update the last snap offset according to the serialized map
+            lastSnapOffset = lastMapOffset;
+        }
+    }
+
+    /**
+     * Update the parentDir for this ProducerStateManager and all of the snapshot files which it manages.
+     */
+    public void updateParentDir(File parentDir) {
+        logDir = parentDir;
+        snapshots.forEach((k, v) -> v.updateParentDir(parentDir));
+    }
+
+    /**
+     * Get the last offset (exclusive) of the latest snapshot file.
+     */
+    public OptionalLong latestSnapshotOffset() {
+        Optional<SnapshotFile> snapshotFileOptional = latestSnapshotFile();
+        return snapshotFileOptional.map(snapshotFile -> OptionalLong.of(snapshotFile.offset)).orElseGet(OptionalLong::empty);
+    }
+
+    /**
+     * Get the last offset (exclusive) of the oldest snapshot file.
+     */
+    public OptionalLong oldestSnapshotOffset() {
+        Optional<SnapshotFile> snapshotFileOptional = oldestSnapshotFile();
+        return snapshotFileOptional.map(snapshotFile -> OptionalLong.of(snapshotFile.offset)).orElseGet(OptionalLong::empty);
+    }
+
+    /**
+     * Visible for testing
+     */
+    public Optional<SnapshotFile> snapshotFileForOffset(long offset) {
+        return Optional.ofNullable(snapshots.get(offset));
+    }
+
+    /**
+     * Remove any unreplicated transactions lower than the provided logStartOffset and bring the lastMapOffset forward
+     * if necessary.
+     */
+    public void onLogStartOffsetIncremented(long logStartOffset) {
+        removeUnreplicatedTransactions(logStartOffset);
+
+        if (lastMapOffset < logStartOffset) lastMapOffset = logStartOffset;
+
+        lastSnapOffset = latestSnapshotOffset().orElse(logStartOffset);
+    }
+
+    private void removeUnreplicatedTransactions(long offset) {
+        Iterator<Map.Entry<Long, TxnMetadata>> iterator = unreplicatedTxns.entrySet().iterator();
+        while (iterator.hasNext()) {
+            Map.Entry<Long, TxnMetadata> txnEntry = iterator.next();
+            OptionalLong lastOffset = txnEntry.getValue().lastOffset;
+            if (lastOffset.isPresent() && lastOffset.getAsLong() < offset) iterator.remove();
+        }
+    }
+
+    /**
+     * Truncate the producer id mapping and remove all snapshots. This resets the state of the mapping.
+     */
+    public void truncateFullyAndStartAt(long offset) throws IOException {
+        producers.clear();
+        ongoingTxns.clear();
+        unreplicatedTxns.clear();
+        for (SnapshotFile snapshotFile : snapshots.values()) {
+            removeAndDeleteSnapshot(snapshotFile.offset);
+        }
+        lastSnapOffset = 0L;
+        lastMapOffset = offset;
+        updateOldestTxnTimestamp();
+    }
+
+    /**
+     * Compute the last stable offset of a completed transaction, but do not yet mark the transaction complete.
+     * That will be done in `completeTxn` below. This is used to compute the LSO that will be appended to the
+     * transaction index, but the completion must be done only after successfully appending to the index.
+     */
+    public long lastStableOffset(CompletedTxn completedTxn) {
+        return findNextIncompleteTxn(completedTxn.producerId)
+                .map(x -> x.firstOffset.messageOffset)
+                .orElse(completedTxn.lastOffset + 1);
+    }
+
+    private Optional<TxnMetadata> findNextIncompleteTxn(long producerId) {
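+        // ongoingTxns is sorted by first offset, so the first entry from a different producer is the next incomplete transaction.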
+        for (TxnMetadata txnMetadata : ongoingTxns.values()) {
+            if (txnMetadata.producerId != producerId) {
+                return Optional.of(txnMetadata);
+            }
+        }
+        return Optional.empty();
+    }
+
+    /**
+     * Mark a transaction as completed. We will still await advancement of the high watermark before
+     * advancing the first unstable offset.
+     */
+    public void completeTxn(CompletedTxn completedTxn) {
+        TxnMetadata txnMetadata = ongoingTxns.remove(completedTxn.firstOffset);
+        if (txnMetadata == null)
+            throw new IllegalArgumentException("Attempted to complete transaction " + completedTxn + " on partition "
+                    + topicPartition + " which was not started");
+
+        txnMetadata.lastOffset = OptionalLong.of(completedTxn.lastOffset);
+        unreplicatedTxns.put(completedTxn.firstOffset, txnMetadata);
+        updateOldestTxnTimestamp();
+    }
+
+    /**
+     * Deletes the producer snapshot files up to the given offset (exclusive) in a thread safe manner.
+     *
+     * @param offset offset number
+     * @throws IOException if any IOException occurs while deleting the files.
+     */
+    public void deleteSnapshotsBefore(long offset) throws IOException {
+        for (SnapshotFile snapshot : snapshots.subMap(0L, offset).values()) {
+            removeAndDeleteSnapshot(snapshot.offset);
+        }
+    }
+
+    private Optional<SnapshotFile> oldestSnapshotFile() {
+        return Optional.ofNullable(snapshots.firstEntry()).map(Map.Entry::getValue);
+    }
+
+    private Optional<SnapshotFile> latestSnapshotFile() {
+        return Optional.ofNullable(snapshots.lastEntry()).map(Map.Entry::getValue);
+    }
+
+    /**
+     * Removes from this ProducerStateManager the producer state snapshot file metadata corresponding to the provided
+     * offset, if it exists, and deletes the backing snapshot file.
+     */
+    private void removeAndDeleteSnapshot(long snapshotOffset) throws IOException {
+        SnapshotFile snapshotFile = snapshots.remove(snapshotOffset);
+        if (snapshotFile != null) snapshotFile.deleteIfExists();
+    }
+
+    /**
+     * Removes from this ProducerStateManager the producer state snapshot file metadata corresponding to the provided
+     * offset, if it exists, and renames the backing snapshot file to have the {@link LogFileUtils#DELETED_FILE_SUFFIX}.
+     * <p>
+     * Note: This method is safe to use with async deletes. If a race occurs and the snapshot file
+     * is deleted without this ProducerStateManager instance knowing, the resulting exception on
+     * SnapshotFile rename will be ignored and {@link Optional#empty()} will be returned.
+     */
+    public Optional<SnapshotFile> removeAndMarkSnapshotForDeletion(long snapshotOffset) throws IOException {
+        SnapshotFile snapshotFile = snapshots.remove(snapshotOffset);
+        if (snapshotFile != null) {
+            // If the file cannot be renamed, it likely means that the file was deleted already.
+            // This can happen due to the way we construct an intermediate producer state manager
+            // during log recovery, and use it to issue deletions prior to creating the "real"
+            // producer state manager.
+            //
+            // In any case, removeAndMarkSnapshotForDeletion is intended to be used for snapshot file
+            // deletion, so ignoring the exception here just means that the intended operation was
+            // already completed.
+            try {
+                snapshotFile.renameTo(LogFileUtils.DELETED_FILE_SUFFIX);
+                return Optional.of(snapshotFile);
+            } catch (NoSuchFileException ex) {
+                log.info("Failed to rename producer state snapshot {} with deletion suffix because it was already deleted", snapshotFile.file().getAbsoluteFile());
+            }
+        }
+        return Optional.empty();
+    }
+
+    public static List<ProducerStateEntry> readSnapshot(File file) throws IOException {
+        try {
+            byte[] buffer = Files.readAllBytes(file.toPath());
+            Struct struct = PID_SNAPSHOT_MAP_SCHEMA.read(ByteBuffer.wrap(buffer));
+
+            short version = struct.getShort(VERSION_FIELD);
+            if (version != PRODUCER_SNAPSHOT_VERSION)
+                throw new CorruptSnapshotException("Snapshot contained an unknown file version " + version);
+
+            long crc = struct.getUnsignedInt(CRC_FIELD);
+            long computedCrc = Crc32C.compute(buffer, PRODUCER_ENTRIES_OFFSET, buffer.length - PRODUCER_ENTRIES_OFFSET);
+            if (crc != computedCrc)
+                throw new CorruptSnapshotException("Snapshot is corrupt (CRC is no longer valid). Stored crc: " + crc
+                        + ". Computed crc: " + computedCrc);
+
+            Object[] producerEntryFields = struct.getArray(PRODUCER_ENTRIES_FIELD);
+            List<ProducerStateEntry> entries = new ArrayList<>(producerEntryFields.length);
+            for (Object producerEntryObj : producerEntryFields) {
+                Struct producerEntryStruct = (Struct) producerEntryObj;
+                long producerId = producerEntryStruct.getLong(PRODUCER_ID_FIELD);
+                short producerEpoch = producerEntryStruct.getShort(PRODUCER_EPOCH_FIELD);
+                int seq = producerEntryStruct.getInt(LAST_SEQUENCE_FIELD);
+                long offset = producerEntryStruct.getLong(LAST_OFFSET_FIELD);
+                long timestamp = producerEntryStruct.getLong(TIMESTAMP_FIELD);
+                int offsetDelta = producerEntryStruct.getInt(OFFSET_DELTA_FIELD);
+                int coordinatorEpoch = producerEntryStruct.getInt(COORDINATOR_EPOCH_FIELD);
+                long currentTxnFirstOffset = producerEntryStruct.getLong(CURRENT_TXN_FIRST_OFFSET_FIELD);
+
+                OptionalLong currentTxnFirstOffsetVal = currentTxnFirstOffset >= 0 ? OptionalLong.of(currentTxnFirstOffset) : OptionalLong.empty();
+                Optional<BatchMetadata> batchMetadata =
+                        (offset >= 0) ? Optional.of(new BatchMetadata(seq, offset, offsetDelta, timestamp)) : Optional.empty();
+                entries.add(new ProducerStateEntry(producerId, producerEpoch, coordinatorEpoch, timestamp, currentTxnFirstOffsetVal, batchMetadata));
+            }
+
+            return entries;
+        } catch (SchemaException e) {
+            throw new CorruptSnapshotException("Snapshot failed schema validation: " + e.getMessage());
+        }
+    }
+
+    private static void writeSnapshot(File file, Map<Long, ProducerStateEntry> entries) throws IOException {
+        Struct struct = new Struct(PID_SNAPSHOT_MAP_SCHEMA);
+        struct.set(VERSION_FIELD, PRODUCER_SNAPSHOT_VERSION);
+        struct.set(CRC_FIELD, 0L); // we'll fill this after writing the entries
+        Struct[] structEntries = new Struct[entries.size()];
+        int i = 0;
+        for (Map.Entry<Long, ProducerStateEntry> producerIdEntry : entries.entrySet()) {
+            Long producerId = producerIdEntry.getKey();
+            ProducerStateEntry entry = producerIdEntry.getValue();
+            Struct producerEntryStruct = struct.instance(PRODUCER_ENTRIES_FIELD);
+            producerEntryStruct.set(PRODUCER_ID_FIELD, producerId)
+                    .set(PRODUCER_EPOCH_FIELD, entry.producerEpoch())
+                    .set(LAST_SEQUENCE_FIELD, entry.lastSeq())
+                    .set(LAST_OFFSET_FIELD, entry.lastDataOffset())
+                    .set(OFFSET_DELTA_FIELD, entry.lastOffsetDelta())
+                    .set(TIMESTAMP_FIELD, entry.lastTimestamp())
+                    .set(COORDINATOR_EPOCH_FIELD, entry.coordinatorEpoch())
+                    .set(CURRENT_TXN_FIRST_OFFSET_FIELD, entry.currentTxnFirstOffset().orElse(-1L));
+            structEntries[i++] = producerEntryStruct;
+        }
+        struct.set(PRODUCER_ENTRIES_FIELD, structEntries);
+
+        ByteBuffer buffer = ByteBuffer.allocate(struct.sizeOf());
+        struct.writeTo(buffer);
+        buffer.flip();
+
+        // now fill in the CRC
+        long crc = Crc32C.compute(buffer, PRODUCER_ENTRIES_OFFSET, buffer.limit() - PRODUCER_ENTRIES_OFFSET);
+        ByteUtils.writeUnsignedInt(buffer, CRC_OFFSET, crc);
+
+        try (FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.CREATE, StandardOpenOption.WRITE)) {
+            fileChannel.write(buffer);
+            fileChannel.force(true);
+        }
+    }
+
+    private static boolean isSnapshotFile(Path path) {
+        return Files.isRegularFile(path) && path.getFileName().toString().endsWith(LogFileUtils.PRODUCER_SNAPSHOT_FILE_SUFFIX);
+    }
+
+    // visible for testing
+    public static List<SnapshotFile> listSnapshotFiles(File dir) throws IOException {
+        if (dir.exists() && dir.isDirectory()) {
+            try (Stream<Path> paths = Files.list(dir.toPath())) {
+                return paths.filter(ProducerStateManager::isSnapshotFile)
+                        .map(path -> new SnapshotFile(path.toFile())).collect(Collectors.toList());
+            }
+        } else {
+            return Collections.emptyList();
+        }
+    }
+
+}
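
For illustration only (not part of this commit): a minimal sketch of how the public helpers listSnapshotFiles and readSnapshot above could be used to inspect the .snapshot files in a partition directory. It assumes the ProducerStateEntry accessors referenced by writeSnapshot (producerEpoch, lastSeq, lastDataOffset, coordinatorEpoch) are public; the tool name and output format are hypothetical. The supported way to dump these files remains the DumpLogSegments tool, which this commit also touches.

    import java.io.File;
    import java.util.List;

    import org.apache.kafka.storage.internals.log.ProducerStateEntry;
    import org.apache.kafka.storage.internals.log.ProducerStateManager;
    import org.apache.kafka.storage.internals.log.SnapshotFile;

    // Hypothetical inspection tool: dump the producer state snapshots found in a log directory.
    public class ProducerSnapshotDump {
        public static void main(String[] args) throws Exception {
            File partitionDir = new File(args[0]); // e.g. a topic-partition log directory
            List<SnapshotFile> snapshots = ProducerStateManager.listSnapshotFiles(partitionDir);
            for (SnapshotFile snapshot : snapshots) {
                // SnapshotFile.offset is the base offset encoded in the file name.
                System.out.println("snapshot offset=" + snapshot.offset + " file=" + snapshot.file());
                List<ProducerStateEntry> entries = ProducerStateManager.readSnapshot(snapshot.file());
                for (ProducerStateEntry entry : entries) {
                    System.out.println("  epoch=" + entry.producerEpoch()
                            + " lastSeq=" + entry.lastSeq()
                            + " lastOffset=" + entry.lastDataOffset()
                            + " coordinatorEpoch=" + entry.coordinatorEpoch());
                }
            }
        }
    }
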
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/log/LogFileUtils.java b/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManagerConfig.java
similarity index 55%
copy from storage/src/main/java/org/apache/kafka/storage/internals/log/LogFileUtils.java
copy to storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManagerConfig.java
index 3b8b41f91fd..2f7946f95f5 100644
--- a/storage/src/main/java/org/apache/kafka/storage/internals/log/LogFileUtils.java
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/log/ProducerStateManagerConfig.java
@@ -16,20 +16,23 @@
  */
 package org.apache.kafka.storage.internals.log;
 
-public final class LogFileUtils {
+import java.util.Collections;
+import java.util.Set;
 
-    private LogFileUtils() {
+public class ProducerStateManagerConfig {
+    public static final String PRODUCER_ID_EXPIRATION_MS = "producer.id.expiration.ms";
+    public static final Set<String> RECONFIGURABLE_CONFIGS = Collections.singleton(PRODUCER_ID_EXPIRATION_MS);
+    private volatile int producerIdExpirationMs;
+
+    public ProducerStateManagerConfig(int producerIdExpirationMs) {
+        this.producerIdExpirationMs = producerIdExpirationMs;
     }
 
-    /**
-     * Returns the offset for the given file name. The file name is of the form: {number}.{suffix}. This method extracts
-     * the number from the given file name.
-     *
-     * @param fileName name of the file
-     * @return offset of the given file name
-     */
-    public static long offsetFromFileName(String fileName) {
-        return Long.parseLong(fileName.substring(0, fileName.indexOf('.')));
+    public void setProducerIdExpirationMs(int producerIdExpirationMs) {
+        this.producerIdExpirationMs = producerIdExpirationMs;
     }
 
+    public int producerIdExpirationMs() {
+        return producerIdExpirationMs;
+    }
 }
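
For illustration only (not part of this commit): a sketch of the intent behind making producerIdExpirationMs a volatile field with a setter. A single ProducerStateManagerConfig instance can be shared by every ProducerStateManager, so a dynamic update of producer.id.expiration.ms becomes visible to all of them without recreating any state. The wiring below is hypothetical; in the broker the update is driven by the dynamic config machinery rather than application code.

    import org.apache.kafka.storage.internals.log.ProducerStateManagerConfig;

    // Hypothetical wiring: one mutable config object shared by all producer state managers,
    // so a dynamic change to producer.id.expiration.ms takes effect without a restart.
    public class ProducerIdExpirationReconfigExample {
        public static void main(String[] args) {
            ProducerStateManagerConfig config = new ProducerStateManagerConfig(86_400_000); // example: 24 hours
            // ... hand the same `config` instance to each ProducerStateManager as it is created ...

            // Later, when a dynamic config update for producer.id.expiration.ms arrives:
            if (ProducerStateManagerConfig.RECONFIGURABLE_CONFIGS.contains(
                    ProducerStateManagerConfig.PRODUCER_ID_EXPIRATION_MS)) {
                config.setProducerIdExpirationMs(3_600_000); // example: 1 hour
            }
            // Readers observe the new value through the volatile field.
            System.out.println("producer.id.expiration.ms = " + config.producerIdExpirationMs());
        }
    }
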
diff --git a/storage/src/main/java/org/apache/kafka/storage/internals/log/SnapshotFile.java b/storage/src/main/java/org/apache/kafka/storage/internals/log/SnapshotFile.java
index 82c8803a643..be496ab2998 100644
--- a/storage/src/main/java/org/apache/kafka/storage/internals/log/SnapshotFile.java
+++ b/storage/src/main/java/org/apache/kafka/storage/internals/log/SnapshotFile.java
@@ -68,4 +68,12 @@ public class SnapshotFile {
             file = renamed;
         }
     }
+
+    @Override
+    public String toString() {
+        return "SnapshotFile(" +
+                "offset=" + offset +
+                ", file=" + file +
+                ')';
+    }
 }
\ No newline at end of file