Posted to commits@kafka.apache.org by jg...@apache.org on 2017/06/17 17:20:05 UTC

[1/2] kafka git commit: KAFKA-5435; Improve producer state loading after failure

Repository: kafka
Updated Branches:
  refs/heads/trunk d68f9e2fe -> bcaee7fe1


http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
index ac1d623..3cc68ad 100644
--- a/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
@@ -393,12 +393,12 @@ class ProducerStateManagerTest extends JUnitSuite {
     append(stateManager, producerId, epoch, sequence, offset = 99, isTransactional = true)
     assertEquals(Some(99), stateManager.firstUnstableOffset.map(_.messageOffset))
     append(stateManager, 2L, epoch, 0, offset = 106, isTransactional = true)
-    stateManager.evictUnretainedProducers(100)
+    stateManager.truncateHead(100)
     assertEquals(Some(106), stateManager.firstUnstableOffset.map(_.messageOffset))
   }
 
   @Test
-  def testEvictUnretainedPids(): Unit = {
+  def testTruncateHead(): Unit = {
     val epoch = 0.toShort
 
     append(stateManager, producerId, epoch, 0, 0L)
@@ -411,8 +411,8 @@ class ProducerStateManagerTest extends JUnitSuite {
     stateManager.takeSnapshot()
     assertEquals(Set(2, 4), currentSnapshotOffsets)
 
-    stateManager.evictUnretainedProducers(2)
-    assertEquals(Set(4), currentSnapshotOffsets)
+    stateManager.truncateHead(2)
+    assertEquals(Set(2, 4), currentSnapshotOffsets)
     assertEquals(Set(anotherPid), stateManager.activeProducers.keySet)
     assertEquals(None, stateManager.lastEntry(producerId))
 
@@ -420,18 +420,39 @@ class ProducerStateManagerTest extends JUnitSuite {
     assertTrue(maybeEntry.isDefined)
     assertEquals(3L, maybeEntry.get.lastOffset)
 
-    stateManager.evictUnretainedProducers(3)
+    stateManager.truncateHead(3)
     assertEquals(Set(anotherPid), stateManager.activeProducers.keySet)
     assertEquals(Set(4), currentSnapshotOffsets)
     assertEquals(4, stateManager.mapEndOffset)
 
-    stateManager.evictUnretainedProducers(5)
+    stateManager.truncateHead(5)
     assertEquals(Set(), stateManager.activeProducers.keySet)
     assertEquals(Set(), currentSnapshotOffsets)
     assertEquals(5, stateManager.mapEndOffset)
   }
 
   @Test
+  def testLoadFromSnapshotRemovesNonRetainedProducers(): Unit = {
+    val epoch = 0.toShort
+    val pid1 = 1L
+    val pid2 = 2L
+
+    append(stateManager, pid1, epoch, 0, 0L)
+    append(stateManager, pid2, epoch, 0, 1L)
+    stateManager.takeSnapshot()
+    assertEquals(2, stateManager.activeProducers.size)
+
+    stateManager.truncateAndReload(1L, 2L, time.milliseconds())
+    assertEquals(1, stateManager.activeProducers.size)
+    assertEquals(None, stateManager.lastEntry(pid1))
+
+    val entry = stateManager.lastEntry(pid2)
+    assertTrue(entry.isDefined)
+    assertEquals(0, entry.get.lastSeq)
+    assertEquals(1L, entry.get.lastOffset)
+  }
+
+  @Test
   def testSkipSnapshotIfOffsetUnchanged(): Unit = {
     val epoch = 0.toShort
     append(stateManager, producerId, epoch, 0, 0L, 0L)

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/test/scala/unit/kafka/utils/TestUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index a2c9b05..14f0114 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -352,13 +352,13 @@ object TestUtils extends Logging {
   def records(records: Iterable[SimpleRecord],
               magicValue: Byte = RecordBatch.CURRENT_MAGIC_VALUE,
               codec: CompressionType = CompressionType.NONE,
-              pid: Long = RecordBatch.NO_PRODUCER_ID,
-              epoch: Short = RecordBatch.NO_PRODUCER_EPOCH,
+              producerId: Long = RecordBatch.NO_PRODUCER_ID,
+              producerEpoch: Short = RecordBatch.NO_PRODUCER_EPOCH,
               sequence: Int = RecordBatch.NO_SEQUENCE,
               baseOffset: Long = 0L): MemoryRecords = {
     val buf = ByteBuffer.allocate(DefaultRecordBatch.sizeInBytes(records.asJava))
     val builder = MemoryRecords.builder(buf, magicValue, codec, TimestampType.CREATE_TIME, baseOffset,
-      System.currentTimeMillis, pid, epoch, sequence)
+      System.currentTimeMillis, producerId, producerEpoch, sequence)
     records.foreach(builder.append)
     builder.build()
   }


[2/2] kafka git commit: KAFKA-5435; Improve producer state loading after failure

Posted by jg...@apache.org.
KAFKA-5435; Improve producer state loading after failure

Author: Jason Gustafson <ja...@confluent.io>

Reviewers: Apurva Mehta <ap...@confluent.io>, Jun Rao <ju...@gmail.com>

Closes #3361 from hachikuji/KAFKA-5435-ALT


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/bcaee7fe
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/bcaee7fe
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/bcaee7fe

Branch: refs/heads/trunk
Commit: bcaee7fe19b3f2e75f6de0dfc5487deebfef446e
Parents: d68f9e2
Author: Jason Gustafson <ja...@confluent.io>
Authored: Sat Jun 17 10:20:00 2017 -0700
Committer: Jason Gustafson <ja...@confluent.io>
Committed: Sat Jun 17 10:20:00 2017 -0700

----------------------------------------------------------------------
 .../org/apache/kafka/common/utils/Utils.java    |  20 +-
 core/src/main/scala/kafka/log/Log.scala         | 140 +++--
 core/src/main/scala/kafka/log/LogCleaner.scala  |   2 +-
 core/src/main/scala/kafka/log/LogManager.scala  |   4 +-
 .../scala/kafka/log/ProducerStateManager.scala  |  58 +-
 .../test/scala/other/kafka/StressTestLog.scala  |  14 +-
 .../other/kafka/TestLinearWriteSpeed.scala      |   2 +-
 .../log/AbstractLogCleanerIntegrationTest.scala |   2 +-
 .../unit/kafka/log/BrokerCompressionTest.scala  |   2 +-
 .../unit/kafka/log/LogCleanerManagerTest.scala  |   4 +-
 .../scala/unit/kafka/log/LogCleanerTest.scala   |   2 +-
 .../scala/unit/kafka/log/LogManagerTest.scala   |  16 +-
 .../src/test/scala/unit/kafka/log/LogTest.scala | 602 ++++++++++++++-----
 .../kafka/log/ProducerStateManagerTest.scala    |  33 +-
 .../test/scala/unit/kafka/utils/TestUtils.scala |   6 +-
 15 files changed, 663 insertions(+), 244 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
index 21fbaf4..e997fef 100755
--- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
@@ -137,8 +137,8 @@ public class Utils {
     /**
      * Get the minimum of some long values.
      * @param first Used to ensure at least one value
-     * @param rest The rest of longs to compare
-     * @return The minimum of all passed argument.
+     * @param rest The remaining values to compare
+     * @return The minimum of all passed values
      */
     public static long min(long first, long ... rest) {
         long min = first;
@@ -149,6 +149,22 @@ public class Utils {
         return min;
     }
 
+    /**
+     * Get the maximum of some long values.
+     * @param first Used to ensure at least one value
+     * @param rest The remaining values to compare
+     * @return The maximum of all passed values
+     */
+    public static long max(long first, long ... rest) {
+        long max = first;
+        for (long r : rest) {
+            if (r > max)
+                max = r;
+        }
+        return max;
+    }
+
+
     public static short min(short first, short second) {
         return (short) Math.min(first, second);
     }
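
A side note on the new helper: requiring a separate first argument makes an empty varargs call a compile error, so the result is always well-defined. A minimal standalone sketch of the same idea in Scala (not the Kafka class itself):

    object MinMax {
      // The mandatory first value guarantees at least one argument at compile time.
      def max(first: Long, rest: Long*): Long = rest.foldLeft(first)(math.max)

      def main(args: Array[String]): Unit = {
        assert(max(3L) == 3L)          // a single value is its own maximum
        assert(max(3L, 9L, 1L) == 9L)  // varargs are folded against the first value
      }
    }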

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/main/scala/kafka/log/Log.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/Log.scala b/core/src/main/scala/kafka/log/Log.scala
index 7a3bc94..176a268 100644
--- a/core/src/main/scala/kafka/log/Log.scala
+++ b/core/src/main/scala/kafka/log/Log.scala
@@ -46,8 +46,6 @@ import java.util.Map.{Entry => JEntry}
 import java.lang.{Long => JLong}
 import java.util.regex.Pattern
 
-import org.apache.kafka.common.internals.Topic
-
 object LogAppendInfo {
   val UnknownLogAppendInfo = LogAppendInfo(-1, -1, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP,
     NoCompressionCodec, NoCompressionCodec, -1, -1, offsetsMonotonic = false)
@@ -130,13 +128,15 @@ case class CompletedTxn(producerId: Long, firstOffset: Long, lastOffset: Long, i
 @threadsafe
 class Log(@volatile var dir: File,
           @volatile var config: LogConfig,
-          @volatile var logStartOffset: Long = 0L,
-          @volatile var recoveryPoint: Long = 0L,
+          @volatile var logStartOffset: Long,
+          @volatile var recoveryPoint: Long,
           scheduler: Scheduler,
           brokerTopicStats: BrokerTopicStats,
-          time: Time = Time.SYSTEM,
-          val maxProducerIdExpirationMs: Int = 60 * 60 * 1000,
-          val producerIdExpirationCheckIntervalMs: Int = 10 * 60 * 1000) extends Logging with KafkaMetricsGroup {
+          time: Time,
+          val maxProducerIdExpirationMs: Int,
+          val producerIdExpirationCheckIntervalMs: Int,
+          val topicPartition: TopicPartition,
+          val producerStateManager: ProducerStateManager) extends Logging with KafkaMetricsGroup {
 
   import kafka.log.Log._
 
@@ -153,15 +153,21 @@ class Log(@volatile var dir: File,
       0
   }
 
-  val topicPartition: TopicPartition = Log.parseTopicPartitionName(dir)
-
   @volatile private var nextOffsetMetadata: LogOffsetMetadata = _
 
-  /* The earliest offset which is part of an incomplete transaction. This is used to compute the LSO. */
+  /* The earliest offset which is part of an incomplete transaction. This is used to compute the
+   * last stable offset (LSO) in ReplicaManager. Note that it is possible that the "true" first unstable offset
+   * gets removed from the log (through record or segment deletion). In this case, the first unstable offset
+   * will point to the log start offset, which may actually be either part of a completed transaction or not
+   * part of a transaction at all. However, since we only use the LSO for the purpose of restricting the
+   * read_committed consumer to fetching decided data (i.e. committed, aborted, or non-transactional), this
+   * temporary abuse seems justifiable and saves us from scanning the log after deletion to find the first offsets
+   * of each ongoing transaction in order to compute a new first unstable offset. It is possible, however,
+   * that this could result in disagreement between replicas depending on when they began replicating the log.
+   * In the worst case, the LSO could be seen by a consumer to go backwards. 
+   */
   @volatile var firstUnstableOffset: Option[LogOffsetMetadata] = None
 
-  private val producerStateManager = new ProducerStateManager(topicPartition, dir, maxProducerIdExpirationMs)
-
   /* the actual segments of the log */
   private val segments: ConcurrentNavigableMap[java.lang.Long, LogSegment] = new ConcurrentSkipListMap[java.lang.Long, LogSegment]
 
@@ -182,7 +188,7 @@ class Log(@volatile var dir: File,
     // The earliest leader epoch may not be flushed during a hard failure. Recover it here.
     leaderEpochCache.clearAndFlushEarliest(logStartOffset)
 
-    loadProducerState(logEndOffset)
+    loadProducerState(logEndOffset, reloadFromCleanShutdown = hasCleanShutdownFile)
 
     info("Completed load of log %s with %d log segments, log start offset %d and log end offset %d in %d ms"
       .format(name, segments.size(), logStartOffset, logEndOffset, time.milliseconds - startMs))
@@ -218,7 +224,7 @@ class Log(@volatile var dir: File,
     lock synchronized {
       producerStateManager.removeExpiredProducers(time.milliseconds)
     }
-  }, period = producerIdExpirationCheckIntervalMs, unit = TimeUnit.MILLISECONDS)
+  }, period = producerIdExpirationCheckIntervalMs, delay = producerIdExpirationCheckIntervalMs, unit = TimeUnit.MILLISECONDS)
 
   /** The name of this log */
   def name  = dir.getName()
@@ -394,7 +400,7 @@ class Log(@volatile var dir: File,
   }
 
   private def updateLogEndOffset(messageOffset: Long) {
-    nextOffsetMetadata = new LogOffsetMetadata(messageOffset, activeSegment.baseOffset, activeSegment.size.toInt)
+    nextOffsetMetadata = new LogOffsetMetadata(messageOffset, activeSegment.baseOffset, activeSegment.size)
   }
 
   private def recoverLog() {
@@ -428,15 +434,27 @@ class Log(@volatile var dir: File,
     }
   }
 
-  private def loadProducerState(lastOffset: Long): Unit = lock synchronized {
-    info(s"Loading producer state from offset $lastOffset for partition $topicPartition")
-
-    if (producerStateManager.latestSnapshotOffset.isEmpty) {
-      // if there are no snapshots to load producer state from, we assume that the brokers are
-      // being upgraded, which means there would be no previous idempotent/transactional producers
-      // to load state for. To avoid an expensive scan through all of the segments, we take
-      // empty snapshots from the start of the last two segments and the last offset. The purpose
-      // of taking the segment snapshots is to avoid the full scan in the case that the log needs
+  private def loadProducerState(lastOffset: Long, reloadFromCleanShutdown: Boolean): Unit = lock synchronized {
+    val messageFormatVersion = config.messageFormatVersion.messageFormatVersion
+    info(s"Loading producer state from offset $lastOffset for partition $topicPartition with message " +
+      s"format version $messageFormatVersion")
+
+    // We want to avoid unnecessary scanning of the log to build the producer state when the broker is being
+    // upgraded. The basic idea is to use the absence of producer snapshot files to detect the upgrade case,
+    // but we have to be careful not to assume too much in the presence of broker failures. The two most common
+    // upgrade cases in which we expect to find no snapshots are the following:
+    //
+    // 1. The broker has been upgraded, but the topic is still on the old message format.
+    // 2. The broker has been upgraded, the topic is on the new message format, and we had a clean shutdown.
+    //
+    // If we hit either of these cases, we skip producer state loading and write a new snapshot at the log end
+    // offset (see below). The next time the log is reloaded, we will load producer state using this snapshot
+    // (or later snapshots). Otherwise, if there is no snapshot file, then we have to rebuild producer state
+    // from the first segment.
+
+    if (producerStateManager.latestSnapshotOffset.isEmpty && (messageFormatVersion < RecordBatch.MAGIC_VALUE_V2 || reloadFromCleanShutdown)) {
+      // To avoid an expensive scan through all of the segments, we take empty snapshots from the start of the
+      // last two segments and the last offset. This should avoid the full scan in the case that the log needs
       // truncation.
       val nextLatestSegmentBaseOffset = Option(segments.lowerEntry(activeSegment.baseOffset)).map(_.getValue.baseOffset)
       val offsetsToSnapshot = Seq(nextLatestSegmentBaseOffset, Some(activeSegment.baseOffset), Some(lastOffset))
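
The comment above encodes a small decision procedure; here is a hedged sketch of that logic in Scala, using illustrative names rather than the actual Log internals:

    // Skip the potentially expensive rebuild of producer state only when the
    // absence of snapshots is explained by an upgrade: either the topic is still
    // on a pre-v2 message format, or the broker shut down cleanly.
    def canSkipProducerStateRebuild(hasSnapshot: Boolean,
                                    messageFormatVersion: Byte,
                                    hadCleanShutdown: Boolean): Boolean = {
      val magicV2: Byte = 2 // first message format carrying producer ids
      !hasSnapshot && (messageFormatVersion < magicV2 || hadCleanShutdown)
    }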
@@ -445,17 +463,21 @@ class Log(@volatile var dir: File,
         producerStateManager.takeSnapshot()
       }
     } else {
-      val currentTimeMs = time.milliseconds
-      producerStateManager.truncateAndReload(logStartOffset, lastOffset, currentTimeMs)
-
-      // only do the potentially expensive reloading of the last snapshot offset is lower than the
-      // log end offset (which would be the case on first startup) and there are active producers.
-      // if there are no active producers, then truncating shouldn't change that fact (although it
-      // could cause a producerId to expire earlier than expected), so we can skip the loading.
-      // This is an optimization for users which are not yet using idempotent/transactional features yet.
-      if (lastOffset > producerStateManager.mapEndOffset || !producerStateManager.isEmpty) {
+      val isEmptyBeforeTruncation = producerStateManager.isEmpty && producerStateManager.mapEndOffset >= lastOffset
+      producerStateManager.truncateAndReload(logStartOffset, lastOffset, time.milliseconds())
+
+      // Only do the potentially expensive reloading if the last snapshot offset is lower than the log end
+      // offset (which would be the case on first startup) and there were active producers prior to truncation
+      // (which could be the case if truncating after initial loading). If there weren't, then truncating
+      // shouldn't change that fact (although it could cause a producerId to expire earlier than expected),
+      // and we can skip the loading. This is an optimization for users who are not yet using
+      // idempotent/transactional features.
+      if (lastOffset > producerStateManager.mapEndOffset && !isEmptyBeforeTruncation) {
         logSegments(producerStateManager.mapEndOffset, lastOffset).foreach { segment =>
-          val startOffset = math.max(segment.baseOffset, producerStateManager.mapEndOffset)
+          val startOffset = Utils.max(segment.baseOffset, producerStateManager.mapEndOffset, logStartOffset)
+          producerStateManager.updateMapEndOffset(startOffset)
+          producerStateManager.takeSnapshot()
+
           val fetchDataInfo = segment.read(startOffset, Some(lastOffset), Int.MaxValue)
           if (fetchDataInfo != null)
             loadProducersFromLog(producerStateManager, fetchDataInfo.records)
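
A notable detail in the loop above is the three-way maximum: the rebuild scan resumes from wherever the map already ends but never starts below the log start offset, which is part of this fix. A simplified model with a hypothetical Segment type:

    case class Segment(baseOffset: Long)

    // Resume from the later of the segment base and the map end, clamped so the
    // scan never begins below the log start offset (simplified model).
    def rebuildStartOffset(segment: Segment, mapEndOffset: Long, logStartOffset: Long): Long =
      List(segment.baseOffset, mapEndOffset, logStartOffset).max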
@@ -471,14 +493,16 @@ class Log(@volatile var dir: File,
     val loadedProducers = mutable.Map.empty[Long, ProducerAppendInfo]
     val completedTxns = ListBuffer.empty[CompletedTxn]
     records.batches.asScala.foreach { batch =>
-      if (batch.hasProducerId)
-        updateProducers(batch, loadedProducers, completedTxns, loadingFromLog = true)
+      if (batch.hasProducerId) {
+        val maybeCompletedTxn = updateProducers(batch, loadedProducers, loadingFromLog = true)
+        maybeCompletedTxn.foreach(completedTxns += _)
+      }
     }
     loadedProducers.values.foreach(producerStateManager.update)
     completedTxns.foreach(producerStateManager.completeTxn)
   }
 
-  private[log] def activePids: Map[Long, ProducerIdEntry] = lock synchronized {
+  private[log] def activeProducers: Map[Long, ProducerIdEntry] = lock synchronized {
     producerStateManager.activeProducers
   }
 
@@ -499,6 +523,9 @@ class Log(@volatile var dir: File,
   def close() {
     debug(s"Closing log $name")
     lock synchronized {
+      // We take a snapshot at the last written offset to hopefully avoid the need to scan the log
+      // after restarting and to ensure that we cannot inadvertently hit the upgrade optimization
+      // (the clean shutdown file is written after the logs are all closed).
       producerStateManager.takeSnapshot()
       logSegments.foreach(_.close())
     }
@@ -682,8 +709,8 @@ class Log(@volatile var dir: File,
 
   private def updateFirstUnstableOffset(): Unit = lock synchronized {
     val updatedFirstStableOffset = producerStateManager.firstUnstableOffset match {
-      case Some(logOffsetMetadata) if logOffsetMetadata.messageOffsetOnly =>
-        val offset = logOffsetMetadata.messageOffset
+      case Some(logOffsetMetadata) if logOffsetMetadata.messageOffsetOnly || logOffsetMetadata.messageOffset < logStartOffset =>
+        val offset = math.max(logOffsetMetadata.messageOffset, logStartOffset)
         val segment = segments.floorEntry(offset).getValue
         val position  = segment.translateOffset(offset)
         Some(LogOffsetMetadata(offset, segment.baseOffset, position.position))
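
This mirrors the caveat in the firstUnstableOffset comment earlier in the file: if the true first unstable offset has been deleted from the log, the exposed value is clamped to the log start offset. A minimal sketch of the clamping, reduced to plain offsets:

    // Never expose an unstable offset below the retained head of the log.
    def clampedFirstUnstableOffset(firstUnstable: Option[Long],
                                   logStartOffset: Long): Option[Long] =
      firstUnstable.map(offset => math.max(offset, logStartOffset))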
@@ -706,6 +733,9 @@ class Log(@volatile var dir: File,
     lock synchronized {
       if (offset > logStartOffset) {
         logStartOffset = offset
+        leaderEpochCache.clearAndFlushEarliest(logStartOffset)
+        producerStateManager.truncateHead(logStartOffset)
+        updateFirstUnstableOffset()
       }
     }
   }
@@ -722,7 +752,9 @@ class Log(@volatile var dir: File,
       // the last appended entry to the client.
       if (isFromClient && maybeLastEntry.exists(_.isDuplicate(batch)))
         return (updatedProducers, completedTxns.toList, maybeLastEntry)
-      updateProducers(batch, updatedProducers, completedTxns, loadingFromLog = false)
+
+      val maybeCompletedTxn = updateProducers(batch, updatedProducers, loadingFromLog = false)
+      maybeCompletedTxn.foreach(completedTxns += _)
     }
     (updatedProducers, completedTxns.toList, None)
   }
@@ -808,12 +840,10 @@ class Log(@volatile var dir: File,
 
   private def updateProducers(batch: RecordBatch,
                               producers: mutable.Map[Long, ProducerAppendInfo],
-                              completedTxns: ListBuffer[CompletedTxn],
-                              loadingFromLog: Boolean): Unit = {
+                              loadingFromLog: Boolean): Option[CompletedTxn] = {
     val producerId = batch.producerId
     val appendInfo = producers.getOrElseUpdate(producerId, producerStateManager.prepareUpdate(producerId, loadingFromLog))
-    val maybeCompletedTxn = appendInfo.append(batch)
-    maybeCompletedTxn.foreach(completedTxns += _)
+    appendInfo.append(batch)
   }
 
   /**
@@ -1048,10 +1078,7 @@ class Log(@volatile var dir: File,
       lock synchronized {
         // remove the segments for lookups
         deletable.foreach(deleteSegment)
-        logStartOffset = math.max(logStartOffset, segments.firstEntry().getValue.baseOffset)
-        leaderEpochCache.clearAndFlushEarliest(logStartOffset)
-        producerStateManager.evictUnretainedProducers(logStartOffset)
-        updateFirstUnstableOffset()
+        maybeIncrementLogStartOffset(segments.firstEntry.getValue.baseOffset)
       }
     }
     numToDelete
@@ -1335,7 +1362,7 @@ class Log(@volatile var dir: File,
         this.recoveryPoint = math.min(targetOffset, this.recoveryPoint)
         this.logStartOffset = math.min(targetOffset, this.logStartOffset)
         leaderEpochCache.clearAndFlushLatest(targetOffset)
-        loadProducerState(targetOffset)
+        loadProducerState(targetOffset, reloadFromCleanShutdown = false)
       }
     }
   }
@@ -1539,6 +1566,21 @@ object Log {
 
   val UnknownLogStartOffset = -1L
 
+  def apply(dir: File,
+            config: LogConfig,
+            logStartOffset: Long = 0L,
+            recoveryPoint: Long = 0L,
+            scheduler: Scheduler,
+            brokerTopicStats: BrokerTopicStats,
+            time: Time = Time.SYSTEM,
+            maxProducerIdExpirationMs: Int = 60 * 60 * 1000,
+            producerIdExpirationCheckIntervalMs: Int = 10 * 60 * 1000): Log = {
+    val topicPartition = Log.parseTopicPartitionName(dir)
+    val producerStateManager = new ProducerStateManager(topicPartition, dir, maxProducerIdExpirationMs)
+    new Log(dir, config, logStartOffset, recoveryPoint, scheduler, brokerTopicStats, time, maxProducerIdExpirationMs,
+      producerIdExpirationCheckIntervalMs, topicPartition, producerStateManager)
+  }
+
   /**
    * Make log segment file name from offset bytes. All this does is pad out the offset number with zeros
    * so that ls sorts the files numerically.
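
A note on the refactoring above: the primary constructor of Log now takes every dependency explicitly, including the ProducerStateManager, while this companion apply() preserves the old defaults. That is what allows the new EasyMock-based tests further down to inject a mocked state manager. A generic sketch of the pattern, with illustrative types:

    // Explicit primary constructor for dependency injection in tests;
    // companion apply() supplies convenient production defaults.
    trait ProducerStore // stand-in for an injectable dependency

    class Service(val checkIntervalMs: Int, val store: ProducerStore)

    object Service {
      def apply(checkIntervalMs: Int = 10 * 60 * 1000): Service =
        new Service(checkIntervalMs, new ProducerStore {})
    }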

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/main/scala/kafka/log/LogCleaner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/LogCleaner.scala b/core/src/main/scala/kafka/log/LogCleaner.scala
index 5aa8672..623586f 100644
--- a/core/src/main/scala/kafka/log/LogCleaner.scala
+++ b/core/src/main/scala/kafka/log/LogCleaner.scala
@@ -424,7 +424,7 @@ private[log] class Cleaner(val id: Int,
         info("Cleaning segment %s in log %s (largest timestamp %s) into %s, %s deletes."
           .format(startOffset, log.name, new Date(oldSegmentOpt.largestTimestamp), cleaned.baseOffset, if(retainDeletes) "retaining" else "discarding"))
         cleanInto(log.topicPartition, oldSegmentOpt, cleaned, map, retainDeletes, log.config.maxMessageSize, transactionMetadata,
-          log.activePids, stats)
+          log.activeProducers, stats)
 
         currentSegmentOpt = nextSegmentOpt
       }

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/main/scala/kafka/log/LogManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/LogManager.scala b/core/src/main/scala/kafka/log/LogManager.scala
index 61879be..2df5241 100755
--- a/core/src/main/scala/kafka/log/LogManager.scala
+++ b/core/src/main/scala/kafka/log/LogManager.scala
@@ -169,7 +169,7 @@ class LogManager(val logDirs: Array[File],
           val logRecoveryPoint = recoveryPoints.getOrElse(topicPartition, 0L)
           val logStartOffset = logStartOffsets.getOrElse(topicPartition, 0L)
 
-          val current = new Log(
+          val current = Log(
             dir = logDir,
             config = config,
             logStartOffset = logStartOffset,
@@ -414,7 +414,7 @@ class LogManager(val logDirs: Array[File],
         val dir = new File(dataDir, topicPartition.topic + "-" + topicPartition.partition)
         Files.createDirectories(dir.toPath)
 
-        val log = new Log(
+        val log = Log(
           dir = dir,
           config = config,
           logStartOffset = 0L,

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/main/scala/kafka/log/ProducerStateManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/ProducerStateManager.scala b/core/src/main/scala/kafka/log/ProducerStateManager.scala
index e791052..7cc8e8e 100644
--- a/core/src/main/scala/kafka/log/ProducerStateManager.scala
+++ b/core/src/main/scala/kafka/log/ProducerStateManager.scala
@@ -70,6 +70,9 @@ private[log] case class ProducerIdEntry(producerId: Long, producerEpoch: Short,
       s"producerEpoch=$producerEpoch, " +
       s"firstSequence=$firstSeq, " +
       s"lastSequence=$lastSeq, " +
+      s"firstOffset=$firstOffset, " +
+      s"lastOffset=$lastOffset, " +
+      s"timestamp=$timestamp, " +
       s"currentTxnFirstOffset=$currentTxnFirstOffset, " +
       s"coordinatorEpoch=$coordinatorEpoch)"
   }
@@ -369,7 +372,7 @@ object ProducerStateManager {
 @nonthreadsafe
 class ProducerStateManager(val topicPartition: TopicPartition,
                            val logDir: File,
-                           val maxPidExpirationMs: Int = 60 * 60 * 1000) extends Logging {
+                           val maxProducerIdExpirationMs: Int = 60 * 60 * 1000) extends Logging {
   import ProducerStateManager._
   import java.util
 
@@ -434,7 +437,10 @@ class ProducerStateManager(val topicPartition: TopicPartition,
         case Some(file) =>
           try {
             info(s"Loading producer state from snapshot file ${file.getName} for partition $topicPartition")
-            readSnapshot(file).filter(!isExpired(currentTime, _)).foreach(loadProducerEntry)
+            val loadedProducers = readSnapshot(file).filter { producerEntry =>
+              isProducerRetained(producerEntry, logStartOffset) && !isProducerExpired(currentTime, producerEntry)
+            }
+            loadedProducers.foreach(loadProducerEntry)
             lastSnapOffset = offsetFromFilename(file.getName)
             lastMapOffset = lastSnapOffset
             return
@@ -460,15 +466,15 @@ class ProducerStateManager(val topicPartition: TopicPartition,
     }
   }
 
-  private def isExpired(currentTimeMs: Long, producerIdEntry: ProducerIdEntry): Boolean =
-    producerIdEntry.currentTxnFirstOffset.isEmpty && currentTimeMs - producerIdEntry.timestamp >= maxPidExpirationMs
+  private def isProducerExpired(currentTimeMs: Long, producerIdEntry: ProducerIdEntry): Boolean =
+    producerIdEntry.currentTxnFirstOffset.isEmpty && currentTimeMs - producerIdEntry.timestamp >= maxProducerIdExpirationMs
 
   /**
    * Expire any producer ids which have been idle longer than the configured maximum expiration timeout.
    */
   def removeExpiredProducers(currentTimeMs: Long) {
     producers.retain { case (producerId, lastEntry) =>
-      !isExpired(currentTimeMs, lastEntry)
+      !isProducerExpired(currentTimeMs, lastEntry)
     }
   }
 
@@ -479,9 +485,8 @@ class ProducerStateManager(val topicPartition: TopicPartition,
    */
   def truncateAndReload(logStartOffset: Long, logEndOffset: Long, currentTimeMs: Long) {
     // remove all out of range snapshots
-    deleteSnapshotFiles { file =>
-      val offset = offsetFromFilename(file.getName)
-      offset > logEndOffset || offset <= logStartOffset
+    deleteSnapshotFiles { snapOffset =>
+      snapOffset > logEndOffset || snapOffset <= logStartOffset
     }
 
     if (logEndOffset != mapEndOffset) {
@@ -493,7 +498,7 @@ class ProducerStateManager(val topicPartition: TopicPartition,
       unreplicatedTxns.clear()
       loadFromSnapshot(logStartOffset, currentTimeMs)
     } else {
-      evictUnretainedProducers(logStartOffset)
+      truncateHead(logStartOffset)
     }
   }
 
@@ -551,21 +556,33 @@ class ProducerStateManager(val topicPartition: TopicPartition,
    */
   def oldestSnapshotOffset: Option[Long] = oldestSnapshotFile.map(file => offsetFromFilename(file.getName))
 
+  private def isProducerRetained(producerIdEntry: ProducerIdEntry, logStartOffset: Long): Boolean = {
+    producerIdEntry.lastOffset >= logStartOffset
+  }
+
   /**
    * When we remove the head of the log due to retention, we need to clean up the id map. This method takes
-   * the new start offset and removes all producerIds which have a smaller last written offset.
+   * the new start offset and removes all producerIds which have a smaller last written offset. Additionally,
+   * we remove snapshots older than the new log start offset.
+   *
+   * Note that snapshots from offsets greater than the log start offset may have producers included which
+   * should no longer be retained: these producers will be removed if and when we need to load state from
+   * the snapshot.
    */
-  def evictUnretainedProducers(logStartOffset: Long) {
-    val evictedProducerEntries = producers.filter(_._2.lastOffset < logStartOffset)
+  def truncateHead(logStartOffset: Long) {
+    val evictedProducerEntries = producers.filter { case (_, producerIdEntry) =>
+      !isProducerRetained(producerIdEntry, logStartOffset)
+    }
     val evictedProducerIds = evictedProducerEntries.keySet
 
     producers --= evictedProducerIds
     removeEvictedOngoingTransactions(evictedProducerIds)
     removeUnreplicatedTransactions(logStartOffset)
 
-    deleteSnapshotFiles(file => offsetFromFilename(file.getName) <= logStartOffset)
     if (lastMapOffset < logStartOffset)
       lastMapOffset = logStartOffset
+
+    deleteSnapshotsBefore(logStartOffset)
     lastSnapOffset = latestSnapshotOffset.getOrElse(logStartOffset)
   }
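
A compact model of what truncateHead does to the in-memory producer map and the snapshot files, using simplified stand-in types rather than the real ProducerStateManager:

    case class ProducerEntry(producerId: Long, lastOffset: Long)

    def truncateHead(producers: Map[Long, ProducerEntry],
                     snapshotOffsets: Set[Long],
                     logStartOffset: Long): (Map[Long, ProducerEntry], Set[Long]) = {
      // A producer is retained iff its last written offset is still in the log.
      val retained = producers.filter { case (_, entry) => entry.lastOffset >= logStartOffset }
      // Snapshots strictly below the new start offset can never be loaded again;
      // later snapshots may still hold stale producers, which are filtered on load.
      val keptSnapshots = snapshotOffsets.filter(_ >= logStartOffset)
      (retained, keptSnapshots)
    }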
 
@@ -604,12 +621,12 @@ class ProducerStateManager(val topicPartition: TopicPartition,
    * Complete the transaction and return the last stable offset.
    */
   def completeTxn(completedTxn: CompletedTxn): Long = {
-    val txnMetdata = ongoingTxns.remove(completedTxn.firstOffset)
-    if (txnMetdata == null)
+    val txnMetadata = ongoingTxns.remove(completedTxn.firstOffset)
+    if (txnMetadata == null)
       throw new IllegalArgumentException("Attempted to complete a transaction which was not started")
 
-    txnMetdata.lastOffset = Some(completedTxn.lastOffset)
-    unreplicatedTxns.put(completedTxn.firstOffset, txnMetdata)
+    txnMetadata.lastOffset = Some(completedTxn.lastOffset)
+    unreplicatedTxns.put(completedTxn.firstOffset, txnMetadata)
 
     val lastStableOffset = firstUndecidedOffset.getOrElse(completedTxn.lastOffset + 1)
     lastStableOffset
@@ -617,7 +634,7 @@ class ProducerStateManager(val topicPartition: TopicPartition,
 
   @threadsafe
   def deleteSnapshotsBefore(offset: Long): Unit = {
-    deleteSnapshotFiles(file => offsetFromFilename(file.getName) < offset)
+    deleteSnapshotFiles(_ < offset)
   }
 
   private def listSnapshotFiles: List[File] = {
@@ -643,8 +660,9 @@ class ProducerStateManager(val topicPartition: TopicPartition,
       None
   }
 
-  private def deleteSnapshotFiles(predicate: File => Boolean = _ => true) {
-    listSnapshotFiles.filter(predicate).foreach(file => Files.deleteIfExists(file.toPath))
+  private def deleteSnapshotFiles(predicate: Long => Boolean = _ => true) {
+    listSnapshotFiles.filter(file => predicate(offsetFromFilename(file.getName)))
+      .foreach(file => Files.deleteIfExists(file.toPath))
   }
 
 }
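
The deleteSnapshotFiles change is a small API cleanup: the predicate now receives the parsed snapshot offset instead of the raw File, so call sites read as deleteSnapshotFiles(_ < offset). A self-contained sketch, under the assumption that snapshot files are named "<zero-padded offset>.snapshot":

    import java.io.File
    import java.nio.file.Files

    // Assumes "<offset>.snapshot" file names (an assumption in this sketch).
    def offsetFromFilename(name: String): Long =
      name.stripSuffix(".snapshot").toLong

    def deleteSnapshotFiles(dir: File)(predicate: Long => Boolean): Unit =
      Option(dir.listFiles).getOrElse(Array.empty[File])
        .filter(f => f.isFile && f.getName.endsWith(".snapshot"))
        .filter(f => predicate(offsetFromFilename(f.getName)))
        .foreach(f => Files.deleteIfExists(f.toPath))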

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/test/scala/other/kafka/StressTestLog.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/other/kafka/StressTestLog.scala b/core/src/test/scala/other/kafka/StressTestLog.scala
index 6c134ac..5355ca2 100755
--- a/core/src/test/scala/other/kafka/StressTestLog.scala
+++ b/core/src/test/scala/other/kafka/StressTestLog.scala
@@ -43,13 +43,13 @@ object StressTestLog {
     logProperties.put(LogConfig.MaxMessageBytesProp, Int.MaxValue: java.lang.Integer)
     logProperties.put(LogConfig.SegmentIndexBytesProp, 1024*1024: java.lang.Integer)
 
-    val log = new Log(dir = dir,
-                      config = LogConfig(logProperties),
-                      logStartOffset = 0L,
-                      recoveryPoint = 0L,
-                      scheduler = time.scheduler,
-                      time = time,
-                      brokerTopicStats = new BrokerTopicStats)
+    val log = Log(dir = dir,
+      config = LogConfig(logProperties),
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      time = time,
+      brokerTopicStats = new BrokerTopicStats)
     val writer = new WriterThread(log)
     writer.start()
     val reader = new ReaderThread(log)

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala b/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala
index bd66d25..f211c4c 100755
--- a/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala
+++ b/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala
@@ -207,7 +207,7 @@ object TestLinearWriteSpeed {
   
   class LogWritable(val dir: File, config: LogConfig, scheduler: Scheduler, val messages: MemoryRecords) extends Writable {
     Utils.delete(dir)
-    val log = new Log(dir, config, 0L, 0L, scheduler, new BrokerTopicStats, Time.SYSTEM)
+    val log = Log(dir, config, 0L, 0L, scheduler, new BrokerTopicStats, Time.SYSTEM)
     def write(): Int = {
       log.appendAsLeader(messages, leaderEpoch = 0)
       messages.sizeInBytes

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala b/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
index f796042..bf36199 100644
--- a/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
+++ b/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala
@@ -91,7 +91,7 @@ abstract class AbstractLogCleanerIntegrationTest {
         compactionLag = compactionLag,
         deleteDelay = deleteDelay,
         segmentSize = segmentSize))
-      val log = new Log(dir,
+      val log = Log(dir,
         logConfig,
         logStartOffset = 0L,
         recoveryPoint = 0L,

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala b/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala
index 3d19442..9c727c6 100755
--- a/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala
+++ b/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala
@@ -55,7 +55,7 @@ class BrokerCompressionTest(messageCompression: String, brokerCompression: Strin
     val logProps = new Properties()
     logProps.put(LogConfig.CompressionTypeProp, brokerCompression)
     /*configure broker-side compression  */
-    val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       time = time, brokerTopicStats = new BrokerTopicStats)
 
     /* append two messages */

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
index d577f02..b4c1790 100644
--- a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
@@ -230,7 +230,7 @@ class LogCleanerManagerTest extends JUnitSuite with Logging {
 
     val config = LogConfig(logProps)
     val partitionDir = new File(logDir, "log-0")
-    val log = new Log(partitionDir,
+    val log = Log(partitionDir,
       config,
       logStartOffset = 0L,
       recoveryPoint = 0L,
@@ -241,7 +241,7 @@ class LogCleanerManagerTest extends JUnitSuite with Logging {
   }
 
   private def makeLog(dir: File = logDir, config: LogConfig = logConfig) =
-    new Log(dir = dir, config = config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    Log(dir = dir, config = config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       time = time, brokerTopicStats = new BrokerTopicStats)
 
   private def records(key: Int, value: Int, timestamp: Long) =

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
index a280679..5e95dc2 100755
--- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
@@ -1035,7 +1035,7 @@ class LogCleanerTest extends JUnitSuite {
     messageWithOffset(key.toString.getBytes, value.toString.getBytes, offset)
 
   private def makeLog(dir: File = dir, config: LogConfig = logConfig, recoveryPoint: Long = 0L) =
-    new Log(dir = dir, config = config, logStartOffset = 0L, recoveryPoint = recoveryPoint, scheduler = time.scheduler,
+    Log(dir = dir, config = config, logStartOffset = 0L, recoveryPoint = recoveryPoint, scheduler = time.scheduler,
       time = time, brokerTopicStats = new BrokerTopicStats)
 
   private def noOpCheckDone(topicPartition: TopicPartition) { /* do nothing */  }

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
index 5b29471..8b7819f 100755
--- a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
@@ -35,7 +35,7 @@ class LogManagerTest {
 
   val time: MockTime = new MockTime()
   val maxRollInterval = 100
-  val maxLogAgeMs = 10*60*60*1000
+  val maxLogAgeMs = 10*60*1000
   val logProps = new Properties()
   logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer)
   logProps.put(LogConfig.SegmentIndexBytesProp, 4096: java.lang.Integer)
@@ -51,7 +51,7 @@ class LogManagerTest {
   def setUp() {
     logDir = TestUtils.tempDir()
     logManager = createLogManager()
-    logManager.startup
+    logManager.startup()
     logDir = logManager.logDirs(0)
   }
 
@@ -105,8 +105,8 @@ class LogManagerTest {
     assertEquals("Now there should only be only one segment in the index.", 1, log.numberOfSegments)
     time.sleep(log.config.fileDeleteDelayMs + 1)
 
-    // there should be a log file, two indexes, and the leader epoch checkpoint
-    assertEquals("Files should have been deleted", log.numberOfSegments * 3 + 1, log.dir.list.length)
+    // there should be a log file, two indexes, one producer snapshot, and the leader epoch checkpoint
+    assertEquals("Files should have been deleted", log.numberOfSegments * 4 + 1, log.dir.list.length)
     assertEquals("Should get empty fetch off new log.", 0, log.readUncommitted(offset+1, 1024).records.sizeInBytes)
 
     try {
@@ -132,7 +132,7 @@ class LogManagerTest {
     val config = LogConfig.fromProps(logConfig.originals, logProps)
 
     logManager = createLogManager()
-    logManager.startup
+    logManager.startup()
 
     // create a log
     val log = logManager.createLog(new TopicPartition(name, 0), config)
@@ -203,7 +203,7 @@ class LogManagerTest {
     val config = LogConfig.fromProps(logConfig.originals, logProps)
 
     logManager = createLogManager()
-    logManager.startup
+    logManager.startup()
     val log = logManager.createLog(new TopicPartition(name, 0), config)
     val lastFlush = log.lastFlushTime
     for (_ <- 0 until 200) {
@@ -265,7 +265,7 @@ class LogManagerTest {
     logDir = TestUtils.tempDir()
     logManager = TestUtils.createLogManager(
       logDirs = Array(new File(logDir.getAbsolutePath + File.separator)))
-    logManager.startup
+    logManager.startup()
     verifyCheckpointRecovery(Seq(new TopicPartition("test-a", 1)), logManager)
   }
 
@@ -279,7 +279,7 @@ class LogManagerTest {
     logDir.mkdirs()
     logDir.deleteOnExit()
     logManager = createLogManager()
-    logManager.startup
+    logManager.startup()
     verifyCheckpointRecovery(Seq(new TopicPartition("test-a", 1)), logManager)
   }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/bcaee7fe/core/src/test/scala/unit/kafka/log/LogTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/LogTest.scala b/core/src/test/scala/unit/kafka/log/LogTest.scala
index 7b67857..630cfcf 100755
--- a/core/src/test/scala/unit/kafka/log/LogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTest.scala
@@ -36,6 +36,7 @@ import org.apache.kafka.common.record._
 import org.apache.kafka.common.requests.FetchResponse.AbortedTransaction
 import org.apache.kafka.common.requests.IsolationLevel
 import org.apache.kafka.common.utils.Utils
+import org.easymock.EasyMock
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
@@ -98,14 +99,14 @@ class LogTest {
     logProps.put(LogConfig.MessageTimestampDifferenceMaxMsProp, Long.MaxValue.toString)
 
     // create a log
-    val log = new Log(logDir,
-                      LogConfig(logProps),
-                      logStartOffset = 0L,
-                      recoveryPoint = 0L,
-                      maxProducerIdExpirationMs = 24 * 60,
-                      scheduler = time.scheduler,
-                      brokerTopicStats = brokerTopicStats,
-                      time = time)
+    val log = Log(logDir,
+      LogConfig(logProps),
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      maxProducerIdExpirationMs = 24 * 60,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time)
     assertEquals("Log begins with a single empty segment.", 1, log.numberOfSegments)
     // Test the segment rolling behavior when messages do not have a timestamp.
     time.sleep(log.config.segmentMs + 1)
@@ -152,7 +153,7 @@ class LogTest {
     val logProps = new Properties()
 
     // create a log
-    val log = new Log(logDir,
+    val log = Log(logDir,
       LogConfig(logProps),
       recoveryPoint = 0L,
       scheduler = time.scheduler,
@@ -162,10 +163,10 @@ class LogTest {
     val pid = 1L
     val epoch: Short = 0
 
-    val records = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)), pid = pid, epoch = epoch, sequence = 0)
+    val records = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)), producerId = pid, producerEpoch = epoch, sequence = 0)
     log.appendAsLeader(records, leaderEpoch = 0)
 
-    val nextRecords = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)), pid = pid, epoch = epoch, sequence = 2)
+    val nextRecords = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)), producerId = pid, producerEpoch = epoch, sequence = 2)
     log.appendAsLeader(nextRecords, leaderEpoch = 0)
   }
 
@@ -206,6 +207,183 @@ class LogTest {
   }
 
   @Test
+  def testSkipLoadingIfEmptyProducerStateBeforeTruncation(): Unit = {
+    val stateManager = EasyMock.mock(classOf[ProducerStateManager])
+
+    // Load the log
+    EasyMock.expect(stateManager.latestSnapshotOffset).andReturn(None)
+
+    stateManager.updateMapEndOffset(0L)
+    EasyMock.expectLastCall().anyTimes()
+
+    EasyMock.expect(stateManager.mapEndOffset).andStubReturn(0L)
+    EasyMock.expect(stateManager.isEmpty).andStubReturn(true)
+
+    stateManager.takeSnapshot()
+    EasyMock.expectLastCall().anyTimes()
+
+    stateManager.truncateAndReload(EasyMock.eq(0L), EasyMock.eq(0L), EasyMock.anyLong)
+    EasyMock.expectLastCall()
+
+    EasyMock.expect(stateManager.firstUnstableOffset).andStubReturn(None)
+
+    EasyMock.replay(stateManager)
+
+    val config = LogConfig(new Properties())
+    val log = new Log(logDir,
+      config,
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time,
+      maxProducerIdExpirationMs = 300000,
+      producerIdExpirationCheckIntervalMs = 30000,
+      topicPartition = Log.parseTopicPartitionName(logDir),
+      stateManager)
+
+    EasyMock.verify(stateManager)
+
+    // Append some messages
+    EasyMock.reset(stateManager)
+    EasyMock.expect(stateManager.firstUnstableOffset).andStubReturn(None)
+
+    stateManager.updateMapEndOffset(1L)
+    EasyMock.expectLastCall()
+    stateManager.updateMapEndOffset(2L)
+    EasyMock.expectLastCall()
+
+    EasyMock.replay(stateManager)
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes))), leaderEpoch = 0)
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes))), leaderEpoch = 0)
+
+    EasyMock.verify(stateManager)
+
+    // Now truncate
+    EasyMock.reset(stateManager)
+    EasyMock.expect(stateManager.firstUnstableOffset).andStubReturn(None)
+    EasyMock.expect(stateManager.latestSnapshotOffset).andReturn(None)
+    EasyMock.expect(stateManager.isEmpty).andStubReturn(true)
+    EasyMock.expect(stateManager.mapEndOffset).andReturn(2L)
+    stateManager.truncateAndReload(EasyMock.eq(0L), EasyMock.eq(1L), EasyMock.anyLong)
+    EasyMock.expectLastCall()
+    // Truncation causes the map end offset to reset to 0
+    EasyMock.expect(stateManager.mapEndOffset).andReturn(0L)
+    // We skip directly to updating the map end offset
+    stateManager.updateMapEndOffset(1L)
+    EasyMock.expectLastCall()
+
+    EasyMock.replay(stateManager)
+
+    log.truncateTo(1L)
+
+    EasyMock.verify(stateManager)
+  }
+
+  @Test
+  def testSkipTruncateAndReloadIfOldMessageFormatAndNoCleanShutdown(): Unit = {
+    val stateManager = EasyMock.mock(classOf[ProducerStateManager])
+
+    EasyMock.expect(stateManager.latestSnapshotOffset).andReturn(None)
+
+    stateManager.updateMapEndOffset(0L)
+    EasyMock.expectLastCall().anyTimes()
+
+    stateManager.takeSnapshot()
+    EasyMock.expectLastCall().anyTimes()
+
+    EasyMock.replay(stateManager)
+
+    val logProps = new Properties()
+    logProps.put(LogConfig.MessageFormatVersionProp, "0.10.2")
+    val config = LogConfig(logProps)
+    new Log(logDir,
+      config,
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time,
+      maxProducerIdExpirationMs = 300000,
+      producerIdExpirationCheckIntervalMs = 30000,
+      topicPartition = Log.parseTopicPartitionName(logDir),
+      stateManager)
+
+    EasyMock.verify(stateManager)
+  }
+
+  @Test
+  def testSkipTruncateAndReloadIfOldMessageFormatAndCleanShutdown(): Unit = {
+    val stateManager = EasyMock.mock(classOf[ProducerStateManager])
+
+    EasyMock.expect(stateManager.latestSnapshotOffset).andReturn(None)
+
+    stateManager.updateMapEndOffset(0L)
+    EasyMock.expectLastCall().anyTimes()
+
+    stateManager.takeSnapshot()
+    EasyMock.expectLastCall().anyTimes()
+
+    EasyMock.replay(stateManager)
+
+    val cleanShutdownFile = createCleanShutdownFile()
+
+    val logProps = new Properties()
+    logProps.put(LogConfig.MessageFormatVersionProp, "0.10.2")
+    val config = LogConfig(logProps)
+    new Log(logDir,
+      config,
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time,
+      maxProducerIdExpirationMs = 300000,
+      producerIdExpirationCheckIntervalMs = 30000,
+      topicPartition = Log.parseTopicPartitionName(logDir),
+      stateManager)
+
+    EasyMock.verify(stateManager)
+    cleanShutdownFile.delete()
+  }
+
+  @Test
+  def testSkipTruncateAndReloadIfNewMessageFormatAndCleanShutdown(): Unit = {
+    val stateManager = EasyMock.mock(classOf[ProducerStateManager])
+
+    EasyMock.expect(stateManager.latestSnapshotOffset).andReturn(None)
+
+    stateManager.updateMapEndOffset(0L)
+    EasyMock.expectLastCall().anyTimes()
+
+    stateManager.takeSnapshot()
+    EasyMock.expectLastCall().anyTimes()
+
+    EasyMock.replay(stateManager)
+
+    val cleanShutdownFile = createCleanShutdownFile()
+
+    val logProps = new Properties()
+    logProps.put(LogConfig.MessageFormatVersionProp, "0.11.0")
+    val config = LogConfig(logProps)
+    new Log(logDir,
+      config,
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time,
+      maxProducerIdExpirationMs = 300000,
+      producerIdExpirationCheckIntervalMs = 30000,
+      topicPartition = Log.parseTopicPartitionName(logDir),
+      stateManager)
+
+    EasyMock.verify(stateManager)
+    cleanShutdownFile.delete()
+  }
+
+  @Test
   def testRebuildPidMapWithCompactedData() {
     val log = createLog(2048)
     val pid = 1L
@@ -214,7 +392,7 @@ class LogTest {
     val baseOffset = 23L
 
     // create a batch with a couple gaps to simulate compaction
-    val records = TestUtils.records(pid = pid, epoch = epoch, sequence = seq, baseOffset = baseOffset, records = List(
+    val records = TestUtils.records(producerId = pid, producerEpoch = epoch, sequence = seq, baseOffset = baseOffset, records = List(
       new SimpleRecord(System.currentTimeMillis(), "a".getBytes),
       new SimpleRecord(System.currentTimeMillis(), "key".getBytes, "b".getBytes),
       new SimpleRecord(System.currentTimeMillis(), "c".getBytes),
@@ -239,10 +417,10 @@ class LogTest {
 
     log.truncateTo(baseOffset + 4)
 
-    val activePids = log.activePids
-    assertTrue(activePids.contains(pid))
+    val activeProducers = log.activeProducers
+    assertTrue(activeProducers.contains(pid))
 
-    val entry = activePids(pid)
+    val entry = activeProducers(pid)
     assertEquals(0, entry.firstSeq)
     assertEquals(baseOffset, entry.firstOffset)
     assertEquals(3, entry.lastSeq)
@@ -258,7 +436,7 @@ class LogTest {
     val baseOffset = 23L
 
     // create a batch with a couple gaps to simulate compaction
-    val records = TestUtils.records(pid = pid, epoch = epoch, sequence = seq, baseOffset = baseOffset, records = List(
+    val records = TestUtils.records(producerId = pid, producerEpoch = epoch, sequence = seq, baseOffset = baseOffset, records = List(
       new SimpleRecord(System.currentTimeMillis(), "a".getBytes),
       new SimpleRecord(System.currentTimeMillis(), "key".getBytes, "b".getBytes),
       new SimpleRecord(System.currentTimeMillis(), "c".getBytes),
@@ -273,10 +451,10 @@ class LogTest {
     val filteredRecords = MemoryRecords.readableRecords(filtered)
 
     log.appendAsFollower(filteredRecords)
-    val activePids = log.activePids
-    assertTrue(activePids.contains(pid))
+    val activeProducers = log.activeProducers
+    assertTrue(activeProducers.contains(pid))
 
-    val entry = activePids(pid)
+    val entry = activeProducers(pid)
     assertEquals(0, entry.firstSeq)
     assertEquals(baseOffset, entry.firstOffset)
     assertEquals(3, entry.lastSeq)
@@ -303,6 +481,95 @@ class LogTest {
   }
 
   @Test
+  def testPidMapTruncateToWithNoSnapshots() {
+    // This ensures that the upgrade optimization path cannot be hit after initial loading
+
+    val log = createLog(2048)
+    val pid = 1L
+    val epoch = 0.toShort
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid,
+      producerEpoch = epoch, sequence = 1), leaderEpoch = 0)
+
+    // Delete all snapshots prior to truncating
+    logDir.listFiles.filter(f => f.isFile && f.getName.endsWith(Log.PidSnapshotFileSuffix)).foreach { file =>
+      Files.delete(file.toPath)
+    }
+
+    log.truncateTo(1L)
+    assertEquals(1, log.activeProducers.size)
+
+    val pidEntryOpt = log.activeProducers.get(pid)
+    assertTrue(pidEntryOpt.isDefined)
+
+    val pidEntry = pidEntryOpt.get
+    assertEquals(0, pidEntry.lastSeq)
+  }
+
+  @Test
+  def testLoadProducersAfterDeleteRecordsMidSegment(): Unit = {
+    val log = createLog(2048)
+    val pid1 = 1L
+    val pid2 = 2L
+    val epoch = 0.toShort
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord(time.milliseconds(), "a".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord(time.milliseconds(), "b".getBytes)), producerId = pid2,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    assertEquals(2, log.activeProducers.size)
+
+    log.maybeIncrementLogStartOffset(1L)
+
+    assertEquals(1, log.activeProducers.size)
+    val retainedEntryOpt = log.activeProducers.get(pid2)
+    assertTrue(retainedEntryOpt.isDefined)
+    assertEquals(0, retainedEntryOpt.get.lastSeq)
+
+    log.close()
+
+    val reloadedLog = createLog(2048, logStartOffset = 1L)
+    assertEquals(1, reloadedLog.activeProducers.size)
+    val reloadedEntryOpt = log.activeProducers.get(pid2)
+    assertEquals(retainedEntryOpt, reloadedEntryOpt)
+  }
+
+  @Test
+  def testLoadProducersAfterDeleteRecordsOnSegment(): Unit = {
+    val log = createLog(2048)
+    val pid1 = 1L
+    val pid2 = 2L
+    val epoch = 0.toShort
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord(time.milliseconds(), "a".getBytes)), producerId = pid1,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+    log.roll()
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord(time.milliseconds(), "b".getBytes)), producerId = pid2,
+      producerEpoch = epoch, sequence = 0), leaderEpoch = 0)
+
+    assertEquals(2, log.logSegments.size)
+    assertEquals(2, log.activeProducers.size)
+
+    log.maybeIncrementLogStartOffset(1L)
+    log.deleteOldSegments()
+
+    assertEquals(1, log.logSegments.size)
+    assertEquals(1, log.activeProducers.size)
+    val retainedEntryOpt = log.activeProducers.get(pid2)
+    assertTrue(retainedEntryOpt.isDefined)
+    assertEquals(0, retainedEntryOpt.get.lastSeq)
+
+    log.close()
+
+    val reloadedLog = createLog(2048, logStartOffset = 1L)
+    assertEquals(1, reloadedLog.activeProducers.size)
+    val reloadedEntryOpt = reloadedLog.activeProducers.get(pid2)
+    assertEquals(retainedEntryOpt, reloadedEntryOpt)
+  }
+
+  @Test
   def testPidMapTruncateFullyAndStartAt() {
     val records = TestUtils.singletonRecords("foo".getBytes)
     val log = createLog(records.sizeInBytes, messagesPerSegment = 1, retentionBytes = records.sizeInBytes * 2)
@@ -326,25 +593,25 @@ class LogTest {
   @Test
   def testPidExpirationOnSegmentDeletion() {
     val pid1 = 1L
-    val records = TestUtils.records(Seq(new SimpleRecord("foo".getBytes)), pid = pid1, epoch = 0, sequence = 0)
+    val records = TestUtils.records(Seq(new SimpleRecord("foo".getBytes)), producerId = pid1, producerEpoch = 0, sequence = 0)
     val log = createLog(records.sizeInBytes, messagesPerSegment = 1, retentionBytes = records.sizeInBytes * 2)
     log.appendAsLeader(records, leaderEpoch = 0)
     log.takeProducerSnapshot()
 
     val pid2 = 2L
-    log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord("bar".getBytes)), pid = pid2, epoch = 0, sequence = 0),
+    log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord("bar".getBytes)), producerId = pid2, producerEpoch = 0, sequence = 0),
       leaderEpoch = 0)
-    log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord("baz".getBytes)), pid = pid2, epoch = 0, sequence = 1),
+    log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord("baz".getBytes)), producerId = pid2, producerEpoch = 0, sequence = 1),
       leaderEpoch = 0)
     log.takeProducerSnapshot()
 
     assertEquals(3, log.logSegments.size)
-    assertEquals(Set(pid1, pid2), log.activePids.keySet)
+    assertEquals(Set(pid1, pid2), log.activeProducers.keySet)
 
     log.deleteOldSegments()
 
     assertEquals(2, log.logSegments.size)
-    assertEquals(Set(pid2), log.activePids.keySet)
+    assertEquals(Set(pid2), log.activeProducers.keySet)
   }
 
   @Test
@@ -423,15 +690,15 @@ class LogTest {
     val log = createLog(2048, maxPidExpirationMs = maxPidExpirationMs,
       pidExpirationCheckIntervalMs = expirationCheckInterval)
     val records = Seq(new SimpleRecord(time.milliseconds(), "foo".getBytes))
-    log.appendAsLeader(TestUtils.records(records, pid = pid, epoch = 0, sequence = 0), leaderEpoch = 0)
+    log.appendAsLeader(TestUtils.records(records, producerId = pid, producerEpoch = 0, sequence = 0), leaderEpoch = 0)
 
-    assertEquals(Set(pid), log.activePids.keySet)
+    assertEquals(Set(pid), log.activeProducers.keySet)
 
     time.sleep(expirationCheckInterval)
-    assertEquals(Set(pid), log.activePids.keySet)
+    assertEquals(Set(pid), log.activeProducers.keySet)
 
     time.sleep(expirationCheckInterval)
-    assertEquals(Set(), log.activePids.keySet)
+    assertEquals(Set(), log.activeProducers.keySet)
   }
 
   @Test
@@ -439,7 +706,7 @@ class LogTest {
     val logProps = new Properties()
 
     // create a log
-    val log = new Log(logDir,
+    val log = Log(logDir,
       LogConfig(logProps),
       recoveryPoint = 0L,
       scheduler = time.scheduler,
@@ -453,7 +720,7 @@ class LogTest {
     // Pad the beginning of the log.
     for (_ <- 0 to 5) {
       val record = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)),
-        pid = pid, epoch = epoch, sequence = seq)
+        producerId = pid, producerEpoch = epoch, sequence = seq)
       log.appendAsLeader(record, leaderEpoch = 0)
       seq = seq + 1
     }
@@ -462,7 +729,7 @@ class LogTest {
       new SimpleRecord(time.milliseconds, s"key-$seq".getBytes, s"value-$seq".getBytes),
       new SimpleRecord(time.milliseconds, s"key-$seq".getBytes, s"value-$seq".getBytes),
       new SimpleRecord(time.milliseconds, s"key-$seq".getBytes, s"value-$seq".getBytes)
-    ), pid = pid, epoch = epoch, sequence = seq)
+    ), producerId = pid, producerEpoch = epoch, sequence = seq)
     val multiEntryAppendInfo = log.appendAsLeader(createRecords, leaderEpoch = 0)
     assertEquals("should have appended 3 entries", multiEntryAppendInfo.lastOffset - multiEntryAppendInfo.firstOffset + 1, 3)
 
@@ -481,7 +748,7 @@ class LogTest {
         List(
           new SimpleRecord(time.milliseconds, s"key-$seq".getBytes, s"value-$seq".getBytes),
           new SimpleRecord(time.milliseconds, s"key-$seq".getBytes, s"value-$seq".getBytes)),
-        pid = pid, epoch = epoch, sequence = seq - 2)
+        producerId = pid, producerEpoch = epoch, sequence = seq - 2)
       log.appendAsLeader(records, leaderEpoch = 0)
       fail ("Should have received an OutOfOrderSequenceException since we attempted to append a duplicate of a records " +
         "in the middle of the log.")
@@ -493,7 +760,7 @@ class LogTest {
      try {
       val records = TestUtils.records(
         List(new SimpleRecord(time.milliseconds, s"key-1".getBytes, s"value-1".getBytes)),
-        pid = pid, epoch = epoch, sequence = 1)
+        producerId = pid, producerEpoch = epoch, sequence = 1)
       log.appendAsLeader(records, leaderEpoch = 0)
       fail ("Should have received an OutOfOrderSequenceException since we attempted to append a duplicate of a records " +
         "in the middle of the log.")
@@ -503,7 +770,7 @@ class LogTest {
 
    // Append a duplicate entry with a single record at the tail of the log. This should return the appendInfo of the original entry.
     def createRecordsWithDuplicate = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)),
-      pid = pid, epoch = epoch, sequence = seq)
+      producerId = pid, producerEpoch = epoch, sequence = seq)
     val origAppendInfo = log.appendAsLeader(createRecordsWithDuplicate, leaderEpoch = 0)
     val newAppendInfo = log.appendAsLeader(createRecordsWithDuplicate, leaderEpoch = 0)
     assertEquals("Inserted a duplicate records into the log", origAppendInfo.firstOffset, newAppendInfo.firstOffset)
@@ -515,7 +782,7 @@ class LogTest {
     val logProps = new Properties()
 
     // create a log
-    val log = new Log(logDir,
+    val log = Log(logDir,
       LogConfig(logProps),
       recoveryPoint = 0L,
       scheduler = time.scheduler,
@@ -632,7 +899,7 @@ class LogTest {
     val logProps = new Properties()
 
     // create a log
-    val log = new Log(logDir,
+    val log = Log(logDir,
       LogConfig(logProps),
       recoveryPoint = 0L,
       scheduler = time.scheduler,
@@ -643,10 +910,10 @@ class LogTest {
     val newEpoch: Short = 1
     val oldEpoch: Short = 0
 
-    val records = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)), pid = pid, epoch = newEpoch, sequence = 0)
+    val records = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)), producerId = pid, producerEpoch = newEpoch, sequence = 0)
     log.appendAsLeader(records, leaderEpoch = 0)
 
-    val nextRecords = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)), pid = pid, epoch = oldEpoch, sequence = 0)
+    val nextRecords = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)), producerId = pid, producerEpoch = oldEpoch, sequence = 0)
     log.appendAsLeader(nextRecords, leaderEpoch = 0)
   }
 
@@ -663,7 +930,7 @@ class LogTest {
     logProps.put(LogConfig.SegmentMsProp, 1 * 60 * 60L: java.lang.Long)
     logProps.put(LogConfig.SegmentJitterMsProp, maxJitter: java.lang.Long)
     // create a log
-    val log = new Log(logDir,
+    val log = Log(logDir,
       LogConfig(logProps),
       logStartOffset = 0L,
       recoveryPoint = 0L,
@@ -698,7 +965,7 @@ class LogTest {
    // We need to use magic value 1 here because the test is message size sensitive.
     logProps.put(LogConfig.MessageFormatVersionProp, ApiVersion.latestVersion.toString)
     // create a log
-    val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     assertEquals("There should be exactly 1 segment.", 1, log.numberOfSegments)
 
@@ -714,7 +981,7 @@ class LogTest {
   @Test
   def testLoadEmptyLog() {
     createEmptyLogs(logDir, 0)
-    val log = new Log(logDir, logConfig, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, logConfig, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     log.appendAsLeader(TestUtils.singletonRecords(value = "test".getBytes, timestamp = time.milliseconds), leaderEpoch = 0)
   }
@@ -728,7 +995,7 @@ class LogTest {
     logProps.put(LogConfig.SegmentBytesProp, 71: java.lang.Integer)
    // We need to use magic value 1 here because the test is message size sensitive.
     logProps.put(LogConfig.MessageFormatVersionProp, ApiVersion.latestVersion.toString)
-    val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     val values = (0 until 100 by 2).map(id => id.toString.getBytes).toArray
 
@@ -754,7 +1021,7 @@ class LogTest {
   def testAppendAndReadWithNonSequentialOffsets() {
     val logProps = new Properties()
     logProps.put(LogConfig.SegmentBytesProp, 72: java.lang.Integer)
-    val log = new Log(logDir,  LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     val messageIds = ((0 until 50) ++ (50 until 200 by 7)).toArray
     val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes))
@@ -780,7 +1047,7 @@ class LogTest {
   def testReadAtLogGap() {
     val logProps = new Properties()
     logProps.put(LogConfig.SegmentBytesProp, 300: java.lang.Integer)
-    val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
 
     // keep appending until we have two segments with only a single message in the second segment
@@ -798,7 +1065,7 @@ class LogTest {
   def testReadWithMinMessage() {
     val logProps = new Properties()
     logProps.put(LogConfig.SegmentBytesProp, 72: java.lang.Integer)
-    val log = new Log(logDir,  LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     val messageIds = ((0 until 50) ++ (50 until 200 by 7)).toArray
     val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes))
@@ -827,7 +1094,7 @@ class LogTest {
   def testReadWithTooSmallMaxLength() {
     val logProps = new Properties()
     logProps.put(LogConfig.SegmentBytesProp, 72: java.lang.Integer)
-    val log = new Log(logDir,  LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     val messageIds = ((0 until 50) ++ (50 until 200 by 7)).toArray
     val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes))
@@ -864,7 +1131,7 @@ class LogTest {
 
     // set up replica log starting with offset 1024 and with one message (at offset 1024)
     logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer)
-    val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     log.appendAsLeader(TestUtils.singletonRecords(value = "42".getBytes), leaderEpoch = 0)
 
@@ -898,7 +1165,7 @@ class LogTest {
    /* create a multi-segment log with 100 messages */
     val logProps = new Properties()
     logProps.put(LogConfig.SegmentBytesProp, 100: java.lang.Integer)
-    val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     val numMessages = 100
     val messageSets = (0 until numMessages).map(i => TestUtils.singletonRecords(value = i.toString.getBytes,
@@ -938,7 +1205,7 @@ class LogTest {
     /* this log should roll after every messageset */
     val logProps = new Properties()
     logProps.put(LogConfig.SegmentBytesProp, 110: java.lang.Integer)
-    val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
 
     /* append 2 compressed message sets, each with two messages giving offsets 0, 1, 2, 3 */
@@ -965,7 +1232,7 @@ class LogTest {
       val logProps = new Properties()
       logProps.put(LogConfig.SegmentBytesProp, 100: java.lang.Integer)
       logProps.put(LogConfig.RetentionMsProp, 0: java.lang.Integer)
-      val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+      val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
         brokerTopicStats = brokerTopicStats, time = time)
       for(i <- 0 until messagesToAppend)
         log.appendAsLeader(TestUtils.singletonRecords(value = i.toString.getBytes, timestamp = time.milliseconds - 10), leaderEpoch = 0)
@@ -1002,7 +1269,7 @@ class LogTest {
     logProps.put(LogConfig.SegmentBytesProp, configSegmentSize: java.lang.Integer)
    // We need to use magic value 1 here because the test is message size sensitive.
     logProps.put(LogConfig.MessageFormatVersionProp, ApiVersion.latestVersion.toString)
-    val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
 
     try {
@@ -1030,7 +1297,7 @@ class LogTest {
     val logProps = new Properties()
     logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact)
 
-    val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
 
     try {
@@ -1073,7 +1340,7 @@ class LogTest {
     val maxMessageSize = second.sizeInBytes - 1
     val logProps = new Properties()
     logProps.put(LogConfig.MaxMessageBytesProp, maxMessageSize: java.lang.Integer)
-    val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
 
     // should be able to append the small message
@@ -1100,7 +1367,7 @@ class LogTest {
     logProps.put(LogConfig.IndexIntervalBytesProp, indexInterval: java.lang.Integer)
     logProps.put(LogConfig.SegmentIndexBytesProp, 4096: java.lang.Integer)
     val config = LogConfig(logProps)
-    var log = new Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    var log = Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     for(i <- 0 until numMessages)
       log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(messageSize),
@@ -1127,13 +1394,13 @@ class LogTest {
       assertEquals("Should have same number of time index entries as before.", numTimeIndexEntries, log.activeSegment.timeIndex.entries)
     }
 
-    log = new Log(logDir, config, logStartOffset = 0L, recoveryPoint = lastOffset, scheduler = time.scheduler,
+    log = Log(logDir, config, logStartOffset = 0L, recoveryPoint = lastOffset, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     verifyRecoveredLog(log)
     log.close()
 
     // test recovery case
-    log = new Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    log = Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     verifyRecoveredLog(log)
     log.close()
@@ -1150,7 +1417,7 @@ class LogTest {
     logProps.put(LogConfig.IndexIntervalBytesProp, 1: java.lang.Integer)
 
     val config = LogConfig(logProps)
-    val log = new Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
 
     val messages = (0 until numMessages).map { i =>
@@ -1175,7 +1442,7 @@ class LogTest {
     logProps.put(LogConfig.IndexIntervalBytesProp, 1: java.lang.Integer)
 
     val config = LogConfig(logProps)
-    var log = new Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    var log = Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     for(i <- 0 until numMessages)
       log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(10), timestamp = time.milliseconds + i * 10), leaderEpoch = 0)
@@ -1188,7 +1455,7 @@ class LogTest {
     timeIndexFiles.foreach(_.delete())
 
     // reopen the log
-    log = new Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    log = Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     assertEquals("Should have %d messages when log is reopened".format(numMessages), numMessages, log.logEndOffset)
     assertTrue("The index should have been rebuilt", log.logSegments.head.index.entries > 0)
@@ -1216,7 +1483,7 @@ class LogTest {
     logProps.put(LogConfig.MessageFormatVersionProp, "0.9.0")
 
     val config = LogConfig(logProps)
-    var log = new Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    var log = Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     for(i <- 0 until numMessages)
       log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(10),
@@ -1228,7 +1495,7 @@ class LogTest {
     timeIndexFiles.foreach(_.delete())
 
     // The rebuilt time index should be empty
-    log = new Log(logDir, config, logStartOffset = 0L, recoveryPoint = numMessages + 1, scheduler = time.scheduler,
+    log = Log(logDir, config, logStartOffset = 0L, recoveryPoint = numMessages + 1, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     val segArray = log.logSegments.toArray
     for (i <- segArray.indices.init) {
@@ -1249,7 +1516,7 @@ class LogTest {
     logProps.put(LogConfig.IndexIntervalBytesProp, 1: java.lang.Integer)
 
     val config = LogConfig(logProps)
-    var log = new Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    var log = Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     for(i <- 0 until numMessages)
       log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(10), timestamp = time.milliseconds + i * 10), leaderEpoch = 0)
@@ -1272,7 +1539,7 @@ class LogTest {
     }
 
     // reopen the log
-    log = new Log(logDir, config, logStartOffset = 0L, recoveryPoint = 200L, scheduler = time.scheduler,
+    log = Log(logDir, config, logStartOffset = 0L, recoveryPoint = 200L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     assertEquals("Should have %d messages when log is reopened".format(numMessages), numMessages, log.logEndOffset)
     for(i <- 0 until numMessages) {
@@ -1299,7 +1566,7 @@ class LogTest {
     logProps.put(LogConfig.SegmentBytesProp, segmentSize: java.lang.Integer)
 
     // create a log
-    val log = new Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     assertEquals("There should be exactly 1 segment.", 1, log.numberOfSegments)
 
@@ -1355,7 +1622,7 @@ class LogTest {
     logProps.put(LogConfig.SegmentBytesProp, segmentSize: java.lang.Integer)
     logProps.put(LogConfig.IndexIntervalBytesProp, setSize - 1: java.lang.Integer)
     val config = LogConfig(logProps)
-    val log = new Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     assertEquals("There should be exactly 1 segment.", 1, log.numberOfSegments)
 
@@ -1398,13 +1665,13 @@ class LogTest {
     logProps.put(LogConfig.SegmentBytesProp, createRecords.sizeInBytes * 5: java.lang.Integer)
     logProps.put(LogConfig.SegmentIndexBytesProp, 1000: java.lang.Integer)
     logProps.put(LogConfig.IndexIntervalBytesProp, 1: java.lang.Integer)
-    val log = new Log(logDir,
-                      LogConfig(logProps),
-                      logStartOffset = 0L,
-                      recoveryPoint = 0L,
-                      scheduler = time.scheduler,
-                      brokerTopicStats = brokerTopicStats,
-                      time = time)
+    val log = Log(logDir,
+      LogConfig(logProps),
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time)
 
     assertTrue("The first index file should have been replaced with a larger file", bogusIndex1.length > 0)
     assertTrue("The first time index file should have been replaced with a larger file", bogusTimeIndex1.length > 0)
@@ -1431,25 +1698,25 @@ class LogTest {
     val config = LogConfig(logProps)
 
     // create a log
-    var log = new Log(logDir,
-                      config,
-                      logStartOffset = 0L,
-                      recoveryPoint = 0L,
-                      scheduler = time.scheduler,
-                      brokerTopicStats = brokerTopicStats,
-                      time = time)
+    var log = Log(logDir,
+      config,
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time)
 
    // add enough messages to roll over several segments, then close, re-open, and attempt to truncate
     for (_ <- 0 until 100)
       log.appendAsLeader(createRecords, leaderEpoch = 0)
     log.close()
-    log = new Log(logDir,
-                  config,
-                  logStartOffset = 0L,
-                  recoveryPoint = 0L,
-                  scheduler = time.scheduler,
-                  brokerTopicStats = brokerTopicStats,
-                  time = time)
+    log = Log(logDir,
+      config,
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time)
     log.truncateTo(3)
     assertEquals("All but one segment should be deleted.", 1, log.numberOfSegments)
     assertEquals("Log end offset should be 3.", 3, log.logEndOffset)
@@ -1470,13 +1737,13 @@ class LogTest {
     logProps.put(LogConfig.RetentionMsProp, 999: java.lang.Integer)
     val config = LogConfig(logProps)
 
-    val log = new Log(logDir,
-                      config,
-                      logStartOffset = 0L,
-                      recoveryPoint = 0L,
-                      scheduler = time.scheduler,
-                      brokerTopicStats = brokerTopicStats,
-                      time = time)
+    val log = Log(logDir,
+      config,
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time)
 
     // append some messages to create some segments
     for (_ <- 0 until 100)
@@ -1512,13 +1779,13 @@ class LogTest {
     logProps.put(LogConfig.SegmentIndexBytesProp, 1000: java.lang.Integer)
     logProps.put(LogConfig.RetentionMsProp, 999: java.lang.Integer)
     val config = LogConfig(logProps)
-    var log = new Log(logDir,
-                      config,
-                      logStartOffset = 0L,
-                      recoveryPoint = 0L,
-                      scheduler = time.scheduler,
-                      brokerTopicStats = brokerTopicStats,
-                      time = time)
+    var log = Log(logDir,
+      config,
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time)
 
     // append some messages to create some segments
     for (_ <- 0 until 100)
@@ -1528,25 +1795,25 @@ class LogTest {
     log.deleteOldSegments()
     log.close()
 
-    log = new Log(logDir,
-                  config,
-                  logStartOffset = 0L,
-                  recoveryPoint = 0L,
-                  scheduler = time.scheduler,
-                  brokerTopicStats = brokerTopicStats,
-                  time = time)
+    log = Log(logDir,
+      config,
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time)
     assertEquals("The deleted segments should be gone.", 1, log.numberOfSegments)
   }
 
   @Test
   def testAppendMessageWithNullPayload() {
-    val log = new Log(logDir,
-                      LogConfig(),
-                      logStartOffset = 0L,
-                      recoveryPoint = 0L,
-                      scheduler = time.scheduler,
-                      brokerTopicStats = brokerTopicStats,
-                      time = time)
+    val log = Log(logDir,
+      LogConfig(),
+      logStartOffset = 0L,
+      recoveryPoint = 0L,
+      scheduler = time.scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time)
     log.appendAsLeader(TestUtils.singletonRecords(value = null), leaderEpoch = 0)
     val head = log.readUncommitted(0, 4096, None).records.records.iterator.next()
     assertEquals(0, head.offset)
@@ -1555,7 +1822,7 @@ class LogTest {
 
   @Test(expected = classOf[IllegalArgumentException])
   def testAppendWithOutOfOrderOffsetsThrowsException() {
-    val log = new Log(logDir,
+    val log = Log(logDir,
       LogConfig(),
       logStartOffset = 0L,
       recoveryPoint = 0L,
@@ -1570,7 +1837,7 @@ class LogTest {
 
   @Test
   def testAppendWithNoTimestamp(): Unit = {
-    val log = new Log(logDir,
+    val log = Log(logDir,
       LogConfig(),
       logStartOffset = 0L,
       recoveryPoint = 0L,
@@ -1594,13 +1861,13 @@ class LogTest {
     for (_ <- 0 until 10) {
       // create a log and write some messages to it
       logDir.mkdirs()
-      var log = new Log(logDir,
-                        config,
-                        logStartOffset = 0L,
-                        recoveryPoint = 0L,
-                        scheduler = time.scheduler,
-                        brokerTopicStats = brokerTopicStats,
-                        time = time)
+      var log = Log(logDir,
+        config,
+        logStartOffset = 0L,
+        recoveryPoint = 0L,
+        scheduler = time.scheduler,
+        brokerTopicStats = brokerTopicStats,
+        time = time)
       val numMessages = 50 + TestUtils.random.nextInt(50)
       for (_ <- 0 until numMessages)
         log.appendAsLeader(createRecords, leaderEpoch = 0)
@@ -1612,7 +1879,7 @@ class LogTest {
       TestUtils.appendNonsenseToFile(log.activeSegment.log.file, TestUtils.random.nextInt(1024) + 1)
 
       // attempt recovery
-      log = new Log(logDir, config, 0L, recoveryPoint, time.scheduler, brokerTopicStats, time)
+      log = Log(logDir, config, 0L, recoveryPoint, time.scheduler, brokerTopicStats, time)
       assertEquals(numMessages, log.logEndOffset)
 
       val recovered = log.logSegments.flatMap(_.log.records.asScala.toList).toList
@@ -1638,7 +1905,7 @@ class LogTest {
     logProps.put(LogConfig.MaxMessageBytesProp, 64*1024: java.lang.Integer)
     logProps.put(LogConfig.IndexIntervalBytesProp, 1: java.lang.Integer)
     val config = LogConfig(logProps)
-    val log = new Log(logDir,
+    val log = Log(logDir,
       config,
       logStartOffset = 0L,
       recoveryPoint = 0L,
@@ -1683,14 +1950,12 @@ class LogTest {
     logProps.put(LogConfig.IndexIntervalBytesProp, 1: java.lang.Integer)
     val config = LogConfig(logProps)
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = time.milliseconds)
-    val parentLogDir = logDir.getParentFile
-    assertTrue("Data directory %s must exist", parentLogDir.isDirectory)
-    val cleanShutdownFile = new File(parentLogDir, Log.CleanShutdownFile)
-    cleanShutdownFile.createNewFile()
+
+    val cleanShutdownFile = createCleanShutdownFile()
     assertTrue(".kafka_cleanshutdown must exist", cleanShutdownFile.exists())
     var recoveryPoint = 0L
     // create a log and write some messages to it
-    var log = new Log(logDir,
+    var log = Log(logDir,
       config,
       logStartOffset = 0L,
       recoveryPoint = 0L,
@@ -1704,7 +1969,7 @@ class LogTest {
     // check if recovery was attempted. Even if the recovery point is 0L, recovery should not be attempted as the
     // clean shutdown file exists.
     recoveryPoint = log.logEndOffset
-    log = new Log(logDir, config, 0L, 0L, time.scheduler, brokerTopicStats, time)
+    log = Log(logDir, config, 0L, 0L, time.scheduler, brokerTopicStats, time)
     assertEquals(recoveryPoint, log.logEndOffset)
     cleanShutdownFile.delete()
   }
@@ -1869,7 +2134,7 @@ class LogTest {
     logProps.put(LogConfig.SegmentIndexBytesProp, 1000: java.lang.Integer)
     logProps.put(LogConfig.RetentionMsProp, 999: java.lang.Integer)
     val config = LogConfig(logProps)
-    val log = new Log(logDir,
+    val log = Log(logDir,
       config,
       logStartOffset = 0L,
       recoveryPoint = 0L,
@@ -2023,7 +2288,7 @@ class LogTest {
 
     //Given this partition is on leader epoch 72
     val epoch = 72
-    val log = new Log(logDir, LogConfig(), recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(), recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     log.leaderEpochCache.assign(epoch, records.size)
 
@@ -2056,7 +2321,7 @@ class LogTest {
       recs
     }
 
-    val log = new Log(logDir, LogConfig(), recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(), recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
 
     //When appending as follower (assignOffsets = false)
@@ -2116,7 +2381,7 @@ class LogTest {
   def shouldTruncateLeaderEpochFileWhenTruncatingLog() {
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = time.milliseconds)
     val logProps = CoreUtils.propsWith(LogConfig.SegmentBytesProp, (10 * createRecords.sizeInBytes).toString)
-    val log = new Log(logDir, LogConfig( logProps), recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig( logProps), recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     val cache = epochCache(log)
 
@@ -2162,7 +2427,7 @@ class LogTest {
     */
   @Test
   def testLogRecoversForLeaderEpoch() {
-    val log = new Log(logDir, LogConfig(new Properties()), recoveryPoint = 0L, scheduler = time.scheduler,
+    val log = Log(logDir, LogConfig(new Properties()), recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     val leaderEpochCache = epochCache(log)
     val firstBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 1, offset = 0)
@@ -2185,7 +2450,7 @@ class LogTest {
     log.close()
 
     // reopen the log and recover from the beginning
-    val recoveredLog = new Log(logDir, LogConfig(new Properties()), recoveryPoint = 0L, scheduler = time.scheduler,
+    val recoveredLog = Log(logDir, LogConfig(new Properties()), recoveryPoint = 0L, scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats, time = time)
     val recoveredLeaderEpochCache = epochCache(recoveredLog)
 
@@ -2517,6 +2782,54 @@ class LogTest {
   }
 
   @Test
+  def testFirstUnstableOffsetDoesNotExceedLogStartOffsetMidSegment(): Unit = {
+    val log = createLog(1024 * 1024)
+    val epoch = 0.toShort
+    val pid = 1L
+    val appendPid = appendTransactionalAsLeader(log, pid, epoch)
+
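+    // begin an open (uncommitted) transaction with 5 records, then append 3 non-transactional records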
+    appendPid(5)
+    appendNonTransactionalAsLeader(log, 3)
+    assertEquals(8L, log.logEndOffset)
+
+    log.roll()
+    assertEquals(2, log.logSegments.size)
+    appendPid(5)
+
+    assertEquals(Some(0L), log.firstUnstableOffset.map(_.messageOffset))
+
+    log.maybeIncrementLogStartOffset(5L)
+
+    // the first unstable offset should be lower bounded by the log start offset
+    assertEquals(Some(5L), log.firstUnstableOffset.map(_.messageOffset))
+  }
+
+  @Test
+  def testFirstUnstableOffsetDoesNotExceedLogStartOffsetAfterSegmentDeletion(): Unit = {
+    val log = createLog(1024 * 1024)
+    val epoch = 0.toShort
+    val pid = 1L
+    val appendPid = appendTransactionalAsLeader(log, pid, epoch)
+
+    appendPid(5)
+    appendNonTransactionalAsLeader(log, 3)
+    assertEquals(8L, log.logEndOffset)
+
+    log.roll()
+    assertEquals(2, log.logSegments.size)
+    appendPid(5)
+
+    assertEquals(Some(0L), log.firstUnstableOffset.map(_.messageOffset))
+
+    log.maybeIncrementLogStartOffset(8L)
+    log.deleteOldSegments()
+    assertEquals(1, log.logSegments.size)
+
+    // the first unstable offset should be lower bounded by the log start offset
+    assertEquals(Some(8L), log.firstUnstableOffset.map(_.messageOffset))
+  }
+
+  @Test
   def testLastStableOffsetWithMixedProducerData() {
     val log = createLog(1024 * 1024)
 
@@ -2611,7 +2924,7 @@ class LogTest {
   private def createLog(messageSizeInBytes: Int, retentionMs: Int = -1, retentionBytes: Int = -1,
                         cleanupPolicy: String = "delete", messagesPerSegment: Int = 5,
                         maxPidExpirationMs: Int = 300000, pidExpirationCheckIntervalMs: Int = 30000,
-                        recoveryPoint: Long = 0L): Log = {
+                        recoveryPoint: Long = 0L, logStartOffset: Long = 0L): Log = {
     val logProps = new Properties()
     logProps.put(LogConfig.SegmentBytesProp, messageSizeInBytes * messagesPerSegment: Integer)
     logProps.put(LogConfig.RetentionMsProp, retentionMs: Integer)
@@ -2619,9 +2932,9 @@ class LogTest {
     logProps.put(LogConfig.CleanupPolicyProp, cleanupPolicy)
     logProps.put(LogConfig.MessageTimestampDifferenceMaxMsProp, Long.MaxValue.toString)
     val config = LogConfig(logProps)
-    new Log(logDir,
+    Log(logDir,
       config,
-      logStartOffset = 0L,
+      logStartOffset = logStartOffset,
       recoveryPoint = recoveryPoint,
       scheduler = time.scheduler,
       brokerTopicStats = brokerTopicStats,
@@ -2692,4 +3005,13 @@ class LogTest {
     log.appendAsFollower(records)
   }
 
+  private def createCleanShutdownFile(): File = {
+    val parentLogDir = logDir.getParentFile
+    assertTrue("Data directory %s must exist", parentLogDir.isDirectory)
+    val cleanShutdownFile = new File(parentLogDir, Log.CleanShutdownFile)
+    cleanShutdownFile.createNewFile()
+    assertTrue(".kafka_cleanshutdown must exist", cleanShutdownFile.exists())
+    cleanShutdownFile
+  }
+
 }