Posted to commits@kafka.apache.org by jg...@apache.org on 2018/03/23 06:51:11 UTC

[kafka] branch 1.0 updated: KAFKA-6683; Ensure producer state not mutated prior to append (#4755)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 1.0
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/1.0 by this push:
     new 62ee1fd  KAFKA-6683; Ensure producer state not mutated prior to append (#4755)
62ee1fd is described below

commit 62ee1fda2f1c38fd8ef82e8f69d285429a0a054b
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Thu Mar 22 21:42:49 2018 -0700

    KAFKA-6683; Ensure producer state not mutated prior to append (#4755)
    
    We were unintentionally mutating the cached queue of batches prior to appending to the log. This could have several bad consequences if the append ultimately failed or was truncated. In the reporter's case, it caused the snapshot to be invalid after a segment roll. The snapshot contained producer state at offsets higher than the snapshot offset. If we ever had to load from that snapshot, the state was left inconsistent, which led to an error that ultimately crashed the replica fetcher.
    
    The fix required some refactoring to avoid sharing the same underlying queue inside ProducerAppendInfo. I have added test cases which reproduce the invalid snapshot state. I have also made an effort to clean up logging since it was not easy to track this problem down.
    
    One final note: I have removed the duplicate check inside ProducerStateManager since it was both redundant and incorrect. The redundancy was in the checking of the cached batches: we already check these in Log.analyzeAndValidateProducerState. The incorrectness was in the handling of sequence number overflow: we were only handling one very specific case of overflow, and other cases would have resulted in an assertion failure. Instead, we now throw OutOfOrderSequenceException. A simplified sketch of the resulting append/update flow follows below.
    
    Reviewers: Apurva Mehta <ap...@confluent.io>, Jun Rao <ju...@gmail.com>
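    
    Below is a minimal, self-contained Scala sketch of the copy-then-merge pattern the fix follows. The names mirror ProducerAppendInfo, ProducerStateEntry and ProducerStateManager from the patch, but the types are trimmed-down illustrative stand-ins (epochs, transactions and most validation are elided), so treat it as a sketch of the idea rather than the actual implementation.
    
        import scala.collection.mutable
    
        // Trimmed-down stand-ins for the classes touched by the patch (illustrative only).
        final case class BatchMetadata(lastSeq: Int, lastOffset: Long)
    
        final class ProducerStateEntry(val producerId: Long,
                                       val batchMetadata: mutable.Queue[BatchMetadata] = mutable.Queue.empty) {
          def lastSeq: Int = if (batchMetadata.isEmpty) -1 else batchMetadata.last.lastSeq
    
          // Keep at most five batches, dropping the oldest when at capacity.
          def addBatchMetadata(batch: BatchMetadata): Unit = {
            if (batchMetadata.size == 5)
              batchMetadata.dequeue()
            batchMetadata.enqueue(batch)
          }
    
          // Fold another entry's batches into this one; only called after the append has been committed.
          def update(nextEntry: ProducerStateEntry): Unit =
            while (nextEntry.batchMetadata.nonEmpty)
              addBatchMetadata(nextEntry.batchMetadata.dequeue())
        }
    
        // The append path validates against the cached entry but accumulates all changes in a
        // private entry, so the cache is never mutated before the log append succeeds.
        final class ProducerAppendInfo(val producerId: Long, val currentEntry: ProducerStateEntry) {
          private val updatedEntry = new ProducerStateEntry(producerId)
    
          // Sequence continuity check, allowing wrap-around at Int.MaxValue.
          private def inSequence(lastSeq: Int, nextSeq: Int): Boolean =
            nextSeq == lastSeq + 1L || (nextSeq == 0 && lastSeq == Int.MaxValue)
    
          def append(firstSeq: Int, lastSeq: Int, lastOffset: Long): Unit = {
            val priorLastSeq = if (updatedEntry.batchMetadata.nonEmpty) updatedEntry.lastSeq else currentEntry.lastSeq
            // The real code throws OutOfOrderSequenceException here.
            if (priorLastSeq != -1 && !inSequence(priorLastSeq, firstSeq))
              throw new IllegalStateException(s"Out of order sequence for producer $producerId: $firstSeq after $priorLastSeq")
            updatedEntry.addBatchMetadata(BatchMetadata(lastSeq, lastOffset))
          }
    
          def toEntry: ProducerStateEntry = updatedEntry
        }
    
        // Only after the append has succeeded does the manager merge the pending entry into its
        // cache, so a failed or truncated append leaves the cached state (and any snapshot of it) untouched.
        final class ProducerStateManager {
          private val producers = mutable.Map.empty[Long, ProducerStateEntry]
    
          def lastEntry(producerId: Long): Option[ProducerStateEntry] = producers.get(producerId)
    
          def update(appendInfo: ProducerAppendInfo): Unit =
            producers.get(appendInfo.producerId) match {
              case Some(entry) => entry.update(appendInfo.toEntry)
              case None        => producers.put(appendInfo.producerId, appendInfo.toEntry)
            }
        }
    
    Roughly, the new flow prepares a ProducerAppendInfo up front, appends the batches to the segment, and only then hands the accumulated entries to ProducerStateManager.update; this is the behavior the new testPrepareUpdateDoesNotMutate test pins down.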
---
 core/src/main/scala/kafka/log/Log.scala            |  72 ++++---
 .../scala/kafka/log/ProducerStateManager.scala     | 239 ++++++++++++---------
 .../main/scala/kafka/tools/DumpLogSegments.scala   |  10 +-
 .../test/scala/unit/kafka/log/LogSegmentTest.scala |   2 +-
 core/src/test/scala/unit/kafka/log/LogTest.scala   |  43 +++-
 .../unit/kafka/log/ProducerStateManagerTest.scala  |  81 +++++--
 6 files changed, 279 insertions(+), 168 deletions(-)

diff --git a/core/src/main/scala/kafka/log/Log.scala b/core/src/main/scala/kafka/log/Log.scala
index eb455ef..caa7bf5 100644
--- a/core/src/main/scala/kafka/log/Log.scala
+++ b/core/src/main/scala/kafka/log/Log.scala
@@ -149,6 +149,8 @@ class Log(@volatile var dir: File,
 
   import kafka.log.Log._
 
+  this.logIdent = s"[Log partition=$topicPartition, dir=${dir.getParent}] "
+
   /* A lock that guards all modifications to the log */
   private val lock = new Object
   // The memory mapped buffer for index files of this log will be closed with either delete() or closeHandlers()
@@ -214,8 +216,8 @@ class Log(@volatile var dir: File,
 
     loadProducerState(logEndOffset, reloadFromCleanShutdown = hasCleanShutdownFile)
 
-    info("Completed load of log %s with %d log segments, log start offset %d and log end offset %d in %d ms"
-      .format(name, segments.size(), logStartOffset, logEndOffset, time.milliseconds - startMs))
+    info(s"Completed load of log with ${segments.size} segments, log start offset $logStartOffset and " +
+      s"log end offset $logEndOffset in ${time.milliseconds() - startMs} ms")
   }
 
   private val tags = Map("topic" -> topicPartition.topic, "partition" -> topicPartition.partition.toString)
@@ -448,21 +450,20 @@ class Log(@volatile var dir: File,
       val unflushed = logSegments(this.recoveryPoint, Long.MaxValue).iterator
       while (unflushed.hasNext) {
         val segment = unflushed.next
-        info("Recovering unflushed segment %d in log %s.".format(segment.baseOffset, name))
+        info(s"Recovering unflushed segment ${segment.baseOffset}")
         val truncatedBytes =
           try {
             recoverSegment(segment, Some(leaderEpochCache))
           } catch {
             case _: InvalidOffsetException =>
               val startOffset = segment.baseOffset
-              warn("Found invalid offset during recovery for log " + dir.getName + ". Deleting the corrupt segment and " +
-                "creating an empty one with starting offset " + startOffset)
+              warn("Found invalid offset during recovery. Deleting the corrupt segment and " +
+                s"creating an empty one with starting offset $startOffset")
               segment.truncateTo(startOffset)
           }
         if (truncatedBytes > 0) {
           // we had an invalid message, delete all remaining log
-          warn("Corruption found in segment %d of log %s, truncating to offset %d.".format(segment.baseOffset, name,
-            segment.nextOffset()))
+          warn(s"Corruption found in segment ${segment.baseOffset}, truncating to offset ${segment.nextOffset}")
           unflushed.foreach(deleteSegment)
         }
       }
@@ -474,8 +475,7 @@ class Log(@volatile var dir: File,
   private def loadProducerState(lastOffset: Long, reloadFromCleanShutdown: Boolean): Unit = lock synchronized {
     checkIfMemoryMappedBufferClosed()
     val messageFormatVersion = config.messageFormatVersion.messageFormatVersion.value
-    info(s"Loading producer state from offset $lastOffset for partition $topicPartition with message " +
-      s"format version $messageFormatVersion")
+    info(s"Loading producer state from offset $lastOffset with message format version $messageFormatVersion")
 
     // We want to avoid unnecessary scanning of the log to build the producer state when the broker is being
     // upgraded. The basic idea is to use the absence of producer snapshot files to detect the upgrade case,
@@ -562,7 +562,7 @@ class Log(@volatile var dir: File,
    * The memory mapped buffer for index files of this log will be left open until the log is deleted.
    */
   def close() {
-    debug(s"Closing log $name")
+    debug("Closing log")
     lock synchronized {
       checkIfMemoryMappedBufferClosed()
       maybeHandleIOException(s"Error while renaming dir for $topicPartition in dir ${dir.getParent}") {
@@ -579,7 +579,7 @@ class Log(@volatile var dir: File,
    * Close file handlers used by log but don't write to disk. This is used when the disk may have failed
    */
   def closeHandlers() {
-    debug(s"Closing handlers of log $name")
+    debug("Closing handlers")
     lock synchronized {
       logSegments.foreach(_.closeHandlers())
       isMemoryMappedBufferClosed = true
@@ -745,8 +745,10 @@ class Log(@volatile var dir: File,
         // update the first unstable offset (which is used to compute LSO)
         updateFirstUnstableOffset()
 
-        trace("Appended message set to log %s with first offset: %d, next offset: %d, and messages: %s"
-          .format(this.name, appendInfo.firstOffset, nextOffsetMetadata.messageOffset, validRecords))
+        trace(s"Appended message set to log with last offset ${appendInfo.lastOffset} " +
+          s"first offset: ${appendInfo.firstOffset}, " +
+          s"next offset: ${nextOffsetMetadata.messageOffset}, " +
+          s"and messages: $validRecords")
 
         if (unflushedMessages >= config.flushInterval)
           flush()
@@ -776,7 +778,7 @@ class Log(@volatile var dir: File,
     }
 
     if (updatedFirstStableOffset != this.firstUnstableOffset) {
-      debug(s"First unstable offset for ${this.name} updated to $updatedFirstStableOffset")
+      debug(s"First unstable offset updated to $updatedFirstStableOffset")
       this.firstUnstableOffset = updatedFirstStableOffset
     }
   }
@@ -792,7 +794,7 @@ class Log(@volatile var dir: File,
       lock synchronized {
         checkIfMemoryMappedBufferClosed()
         if (newLogStartOffset > logStartOffset) {
-          info(s"Incrementing log start offset of partition $topicPartition to $newLogStartOffset in dir ${dir.getParent}")
+          info(s"Incrementing log start offset to $newLogStartOffset")
           logStartOffset = newLogStartOffset
           leaderEpochCache.clearAndFlushEarliest(logStartOffset)
           producerStateManager.truncateHead(logStartOffset)
@@ -809,12 +811,13 @@ class Log(@volatile var dir: File,
     for (batch <- records.batches.asScala if batch.hasProducerId) {
       val maybeLastEntry = producerStateManager.lastEntry(batch.producerId)
 
-      // if this is a client produce request, there will be upto 5 batches which could have been duplicated.
+      // if this is a client produce request, there will be up to 5 batches which could have been duplicated.
       // If we find a duplicate, we return the metadata of the appended batch to the client.
-      if (isFromClient)
-        maybeLastEntry.flatMap(_.duplicateOf(batch)).foreach { duplicate =>
+      if (isFromClient) {
+        maybeLastEntry.flatMap(_.findDuplicateBatch(batch)).foreach { duplicate =>
           return (updatedProducers, completedTxns.toList, Some(duplicate))
         }
+      }
 
       val maybeCompletedTxn = updateProducers(batch, updatedProducers, isFromClient = isFromClient)
       maybeCompletedTxn.foreach(completedTxns += _)
@@ -956,7 +959,7 @@ class Log(@volatile var dir: File,
   def read(startOffset: Long, maxLength: Int, maxOffset: Option[Long] = None, minOneMessage: Boolean = false,
            isolationLevel: IsolationLevel): FetchDataInfo = {
     maybeHandleIOException(s"Exception while reading from $topicPartition in dir ${dir.getParent}") {
-      trace("Reading %d bytes from offset %d in log %s of length %d bytes".format(maxLength, startOffset, name, size))
+      trace(s"Reading $maxLength bytes from offset $startOffset of length $size bytes")
 
       // Because we don't use lock for reading, the synchronization is a little bit tricky.
       // We create the local variables to avoid race conditions with updates to the log.
@@ -1267,10 +1270,11 @@ class Log(@volatile var dir: File,
     if (segment.size > config.segmentSize - messagesSize ||
         (segment.size > 0 && reachedRollMs) ||
         segment.index.isFull || segment.timeIndex.isFull || !segment.canConvertToRelativeOffset(maxOffsetInMessages)) {
-      debug(s"Rolling new log segment in $name (log_size = ${segment.size}/${config.segmentSize}}, " +
-          s"index_size = ${segment.index.entries}/${segment.index.maxEntries}, " +
-          s"time_index_size = ${segment.timeIndex.entries}/${segment.timeIndex.maxEntries}, " +
-          s"inactive_time_ms = ${segment.timeWaitedForRoll(now, maxTimestampInMessages)}/${config.segmentMs - segment.rollJitterMs}).")
+      debug(s"Rolling new log segment (log_size = ${segment.size}/${config.segmentSize}}, " +
+        s"offset_index_size = ${segment.index.entries}/${segment.index.maxEntries}, " +
+        s"time_index_size = ${segment.timeIndex.entries}/${segment.timeIndex.maxEntries}, " +
+        s"inactive_time_ms = ${segment.timeWaitedForRoll(now, maxTimestampInMessages)}/${config.segmentMs - segment.rollJitterMs}).")
+
       /*
         maxOffsetInMessages - Integer.MAX_VALUE is a heuristic value for the first offset in the set of messages.
         Since the offset in messages will not differ by more than Integer.MAX_VALUE, this is guaranteed <= the real
@@ -1295,7 +1299,7 @@ class Log(@volatile var dir: File,
    */
   def roll(expectedNextOffset: Long = 0): LogSegment = {
     maybeHandleIOException(s"Error while rolling log segment for $topicPartition in dir ${dir.getParent}") {
-      val start = time.nanoseconds
+      val start = time.hiResClockMs()
       lock synchronized {
         checkIfMemoryMappedBufferClosed()
         val newOffset = math.max(expectedNextOffset, logEndOffset)
@@ -1342,7 +1346,7 @@ class Log(@volatile var dir: File,
         // schedule an asynchronous flush of the old segment
         scheduler.schedule("flush-log", () => flush(newOffset), delay = 0L)
 
-        info("Rolled new log segment for '" + name + "' in %.0f ms.".format((System.nanoTime - start) / (1000.0 * 1000.0)))
+        info(s"Rolled new log segment at offset $newOffset in ${time.hiResClockMs() - start} ms.")
 
         segment
       }
@@ -1352,7 +1356,7 @@ class Log(@volatile var dir: File,
   /**
    * The number of messages appended to the log since the last flush
    */
-  def unflushedMessages() = this.logEndOffset - this.recoveryPoint
+  def unflushedMessages: Long = this.logEndOffset - this.recoveryPoint
 
   /**
    * Flush all log segments
@@ -1368,8 +1372,8 @@ class Log(@volatile var dir: File,
     maybeHandleIOException(s"Error while flushing log for $topicPartition in dir ${dir.getParent} with offset $offset") {
       if (offset <= this.recoveryPoint)
         return
-      debug("Flushing log '" + name + " up to offset " + offset + ", last flushed: " + lastFlushTime + " current time: " +
-        time.milliseconds + " unflushed = " + unflushedMessages)
+      debug(s"Flushing log up to offset $offset, last flushed: $lastFlushTime,  current time: ${time.milliseconds()}, " +
+        s"unflushed: $unflushedMessages")
       for (segment <- logSegments(this.recoveryPoint, offset))
         segment.flush()
 
@@ -1466,12 +1470,12 @@ class Log(@volatile var dir: File,
   private[log] def truncateTo(targetOffset: Long): Boolean = {
     maybeHandleIOException(s"Error while truncating log to offset $targetOffset for $topicPartition in dir ${dir.getParent}") {
       if (targetOffset < 0)
-        throw new IllegalArgumentException("Cannot truncate to a negative offset (%d).".format(targetOffset))
+        throw new IllegalArgumentException(s"Cannot truncate partition $topicPartition to a negative offset ($targetOffset).")
       if (targetOffset >= logEndOffset) {
-        info("Truncating %s to %d has no effect as the largest offset in the log is %d.".format(name, targetOffset, logEndOffset - 1))
+        info(s"Truncating to $targetOffset has no effect as the largest offset in the log is ${logEndOffset - 1}")
         false
       } else {
-        info("Truncating log %s to offset %d.".format(name, targetOffset))
+        info(s"Truncating to offset $targetOffset")
         lock synchronized {
           checkIfMemoryMappedBufferClosed()
           if (segments.firstEntry.getValue.baseOffset > targetOffset) {
@@ -1499,7 +1503,7 @@ class Log(@volatile var dir: File,
    */
   private[log] def truncateFullyAndStartAt(newOffset: Long) {
     maybeHandleIOException(s"Error while truncating the entire log for $topicPartition in dir ${dir.getParent}") {
-      debug(s"Truncate and start log '$name' at offset $newOffset")
+      debug(s"Truncate and start at offset $newOffset")
       lock synchronized {
         checkIfMemoryMappedBufferClosed()
         val segmentsToDelete = logSegments.toList
@@ -1573,7 +1577,7 @@ class Log(@volatile var dir: File,
    * @param segment The log segment to schedule for deletion
    */
   private def deleteSegment(segment: LogSegment) {
-    info("Scheduling log segment %d for log %s for deletion.".format(segment.baseOffset, name))
+    info(s"Scheduling log segment [baseOffset ${segment.baseOffset}, size ${segment.size}] for deletion.")
     lock synchronized {
       segments.remove(segment.baseOffset)
       asyncDeleteSegment(segment)
@@ -1591,7 +1595,7 @@ class Log(@volatile var dir: File,
   private def asyncDeleteSegment(segment: LogSegment) {
     segment.changeFileSuffixes("", Log.DeletedFileSuffix)
     def deleteSeg() {
-      info("Deleting segment %d from log %s.".format(segment.baseOffset, name))
+      info(s"Deleting segment ${segment.baseOffset}")
       maybeHandleIOException(s"Error while deleting segments for $topicPartition in dir ${dir.getParent}") {
         segment.delete()
       }
diff --git a/core/src/main/scala/kafka/log/ProducerStateManager.scala b/core/src/main/scala/kafka/log/ProducerStateManager.scala
index 63c1f56..d4ce104 100644
--- a/core/src/main/scala/kafka/log/ProducerStateManager.scala
+++ b/core/src/main/scala/kafka/log/ProducerStateManager.scala
@@ -70,9 +70,9 @@ private[log] case class TxnMetadata(producerId: Long, var firstOffset: LogOffset
   }
 }
 
-private[log] object ProducerIdEntry {
+private[log] object ProducerStateEntry {
   private[log] val NumBatchesToRetain = 5
-  def empty(producerId: Long) = new ProducerIdEntry(producerId, mutable.Queue[BatchMetadata](), RecordBatch.NO_PRODUCER_EPOCH, -1, None)
+  def empty(producerId: Long) = new ProducerStateEntry(producerId, mutable.Queue[BatchMetadata](), RecordBatch.NO_PRODUCER_EPOCH, -1, None)
 }
 
 private[log] case class BatchMetadata(lastSeq: Int, lastOffset: Long, offsetDelta: Int, timestamp: Long) {
@@ -90,27 +90,31 @@ private[log] case class BatchMetadata(lastSeq: Int, lastOffset: Long, offsetDelt
 }
 
 // the batchMetadata is ordered such that the batch with the lowest sequence is at the head of the queue while the
-// batch with the highest sequence is at the tail of the queue. We will retain at most ProducerIdEntry.NumBatchesToRetain
+// batch with the highest sequence is at the tail of the queue. We will retain at most ProducerStateEntry.NumBatchesToRetain
 // elements in the queue. When the queue is at capacity, we remove the first element to make space for the incoming batch.
-private[log] class ProducerIdEntry(val producerId: Long, val batchMetadata: mutable.Queue[BatchMetadata],
-                                   var producerEpoch: Short, var coordinatorEpoch: Int,
-                                   var currentTxnFirstOffset: Option[Long]) {
+private[log] class ProducerStateEntry(val producerId: Long,
+                                      val batchMetadata: mutable.Queue[BatchMetadata],
+                                      var producerEpoch: Short,
+                                      var coordinatorEpoch: Int,
+                                      var currentTxnFirstOffset: Option[Long]) {
 
-  def firstSeq: Int = if (batchMetadata.isEmpty) RecordBatch.NO_SEQUENCE else batchMetadata.front.firstSeq
-  def firstOffset: Long = if (batchMetadata.isEmpty) -1L else batchMetadata.front.firstOffset
+  def firstSeq: Int = if (isEmpty) RecordBatch.NO_SEQUENCE else batchMetadata.front.firstSeq
 
-  def lastSeq: Int = if (batchMetadata.isEmpty) RecordBatch.NO_SEQUENCE else batchMetadata.last.lastSeq
-  def lastDataOffset: Long = if (batchMetadata.isEmpty) -1L else batchMetadata.last.lastOffset
-  def lastTimestamp = if (batchMetadata.isEmpty) RecordBatch.NO_TIMESTAMP else batchMetadata.last.timestamp
-  def lastOffsetDelta : Int = if (batchMetadata.isEmpty) 0 else batchMetadata.last.offsetDelta
+  def firstOffset: Long = if (isEmpty) -1L else batchMetadata.front.firstOffset
 
-  def addBatchMetadata(producerEpoch: Short, lastSeq: Int, lastOffset: Long, offsetDelta: Int, timestamp: Long) = {
-    maybeUpdateEpoch(producerEpoch)
+  def lastSeq: Int = if (isEmpty) RecordBatch.NO_SEQUENCE else batchMetadata.last.lastSeq
 
-    if (batchMetadata.size == ProducerIdEntry.NumBatchesToRetain)
-      batchMetadata.dequeue()
+  def lastDataOffset: Long = if (isEmpty) -1L else batchMetadata.last.lastOffset
+
+  def lastTimestamp = if (isEmpty) RecordBatch.NO_TIMESTAMP else batchMetadata.last.timestamp
 
-    batchMetadata.enqueue(BatchMetadata(lastSeq, lastOffset, offsetDelta, timestamp))
+  def lastOffsetDelta : Int = if (isEmpty) 0 else batchMetadata.last.offsetDelta
+
+  def isEmpty: Boolean = batchMetadata.isEmpty
+
+  def addBatch(producerEpoch: Short, lastSeq: Int, lastOffset: Long, offsetDelta: Int, timestamp: Long): Unit = {
+    maybeUpdateEpoch(producerEpoch)
+    addBatchMetadata(BatchMetadata(lastSeq, lastOffset, offsetDelta, timestamp))
   }
 
   def maybeUpdateEpoch(producerEpoch: Short): Boolean = {
@@ -123,25 +127,39 @@ private[log] class ProducerIdEntry(val producerId: Long, val batchMetadata: muta
     }
   }
 
-  def removeBatchesOlderThan(offset: Long) = batchMetadata.dropWhile(_.lastOffset < offset)
+  private def addBatchMetadata(batch: BatchMetadata): Unit = {
+    if (batchMetadata.size == ProducerStateEntry.NumBatchesToRetain)
+      batchMetadata.dequeue()
+    batchMetadata.enqueue(batch)
+  }
+
+  def update(nextEntry: ProducerStateEntry): Unit = {
+    maybeUpdateEpoch(nextEntry.producerEpoch)
+    while (nextEntry.batchMetadata.nonEmpty)
+      addBatchMetadata(nextEntry.batchMetadata.dequeue())
+    this.coordinatorEpoch = nextEntry.coordinatorEpoch
+    this.currentTxnFirstOffset = nextEntry.currentTxnFirstOffset
+  }
+
+  def removeBatchesOlderThan(offset: Long): Unit = batchMetadata.dropWhile(_.lastOffset < offset)
 
-  def duplicateOf(batch: RecordBatch): Option[BatchMetadata] = {
-    if (batch.producerEpoch() != producerEpoch)
+  def findDuplicateBatch(batch: RecordBatch): Option[BatchMetadata] = {
+    if (batch.producerEpoch != producerEpoch)
        None
     else
-      batchWithSequenceRange(batch.baseSequence(), batch.lastSequence())
+      batchWithSequenceRange(batch.baseSequence, batch.lastSequence)
   }
 
   // Return the batch metadata of the cached batch having the exact sequence range, if any.
   def batchWithSequenceRange(firstSeq: Int, lastSeq: Int): Option[BatchMetadata] = {
-    val duplicate = batchMetadata.filter { case(metadata) =>
+    val duplicate = batchMetadata.filter { metadata =>
       firstSeq == metadata.firstSeq && lastSeq == metadata.lastSeq
     }
     duplicate.headOption
   }
 
   override def toString: String = {
-    "ProducerIdEntry(" +
+    "ProducerStateEntry(" +
       s"producerId=$producerId, " +
       s"producerEpoch=$producerEpoch, " +
       s"currentTxnFirstOffset=$currentTxnFirstOffset, " +
@@ -159,7 +177,7 @@ private[log] class ProducerIdEntry(val producerId: Long, val batchMetadata: muta
  * @param producerId The id of the producer appending to the log
  * @param currentEntry  The current entry associated with the producer id which contains metadata for a fixed number of
  *                      the most recent appends made by the producer. Validation of the first incoming append will
- *                      be made against the lastest append in the current entry. New appends will replace older appends
+ *                      be made against the latest append in the current entry. New appends will replace older appends
  *                      in the current entry so that the space overhead is constant.
  * @param validationType Indicates the extent of validation to perform on the appends on this instance. Offset commits
  *                       coming from the producer should have ValidationType.EpochOnly. Appends which aren't from a client
@@ -167,66 +185,67 @@ private[log] class ProducerIdEntry(val producerId: Long, val batchMetadata: muta
  *                       ValidationType.Full.
  */
 private[log] class ProducerAppendInfo(val producerId: Long,
-                                      currentEntry: ProducerIdEntry,
-                                      validationType: ValidationType) {
-
+                                      val currentEntry: ProducerStateEntry,
+                                      val validationType: ValidationType) {
   private val transactions = ListBuffer.empty[TxnMetadata]
+  private val updatedEntry = ProducerStateEntry.empty(producerId)
 
-  private def maybeValidateAppend(producerEpoch: Short, firstSeq: Int, lastSeq: Int) = {
+  updatedEntry.producerEpoch = currentEntry.producerEpoch
+  updatedEntry.coordinatorEpoch = currentEntry.coordinatorEpoch
+  updatedEntry.currentTxnFirstOffset = currentEntry.currentTxnFirstOffset
+
+  private def maybeValidateAppend(producerEpoch: Short, firstSeq: Int) = {
     validationType match {
       case ValidationType.None =>
 
       case ValidationType.EpochOnly =>
-        checkEpoch(producerEpoch)
+        checkProducerEpoch(producerEpoch)
 
       case ValidationType.Full =>
-        checkEpoch(producerEpoch)
-        checkSequence(producerEpoch, firstSeq, lastSeq)
+        checkProducerEpoch(producerEpoch)
+        checkSequence(producerEpoch, firstSeq)
     }
   }
 
-  private def checkEpoch(producerEpoch: Short): Unit = {
-    if (isFenced(producerEpoch)) {
+  private def checkProducerEpoch(producerEpoch: Short): Unit = {
+    if (producerEpoch < updatedEntry.producerEpoch) {
       throw new ProducerFencedException(s"Producer's epoch is no longer valid. There is probably another producer " +
-        s"with a newer epoch. $producerEpoch (request epoch), ${currentEntry.producerEpoch} (server epoch)")
+        s"with a newer epoch. $producerEpoch (request epoch), ${updatedEntry.producerEpoch} (server epoch)")
     }
   }
 
-  private def checkSequence(producerEpoch: Short, firstSeq: Int, lastSeq: Int): Unit = {
-    if (producerEpoch != currentEntry.producerEpoch) {
-      if (firstSeq != 0) {
-        if (currentEntry.producerEpoch != RecordBatch.NO_PRODUCER_EPOCH) {
+  private def checkSequence(producerEpoch: Short, appendFirstSeq: Int): Unit = {
+    if (producerEpoch != updatedEntry.producerEpoch) {
+      if (appendFirstSeq != 0) {
+        if (updatedEntry.producerEpoch != RecordBatch.NO_PRODUCER_EPOCH) {
           throw new OutOfOrderSequenceException(s"Invalid sequence number for new epoch: $producerEpoch " +
-            s"(request epoch), $firstSeq (seq. number)")
+            s"(request epoch), $appendFirstSeq (seq. number)")
         } else {
           throw new UnknownProducerIdException(s"Found no record of producerId=$producerId on the broker. It is possible " +
             s"that the last message with the producerId=$producerId has been removed due to hitting the retention limit.")
         }
       }
-    } else if (currentEntry.lastSeq == RecordBatch.NO_SEQUENCE && firstSeq != 0) {
-      // the epoch was bumped by a control record, so we expect the sequence number to be reset
-      throw new OutOfOrderSequenceException(s"Out of order sequence number for producerId $producerId: found $firstSeq " +
-        s"(incoming seq. number), but expected 0")
-    } else if (isDuplicate(firstSeq, lastSeq)) {
-      throw new DuplicateSequenceException(s"Duplicate sequence number for producerId $producerId: (incomingBatch.firstSeq, " +
-        s"incomingBatch.lastSeq): ($firstSeq, $lastSeq).")
-    } else if (!inSequence(firstSeq, lastSeq)) {
-      throw new OutOfOrderSequenceException(s"Out of order sequence number for producerId $producerId: $firstSeq " +
-        s"(incoming seq. number), ${currentEntry.lastSeq} (current end sequence number)")
+    } else {
+      val currentLastSeq = if (!updatedEntry.isEmpty)
+        updatedEntry.lastSeq
+      else if (producerEpoch == currentEntry.producerEpoch)
+        currentEntry.lastSeq
+      else
+        RecordBatch.NO_SEQUENCE
+
+      if (currentLastSeq == RecordBatch.NO_SEQUENCE && appendFirstSeq != 0) {
+        // the epoch was bumped by a control record, so we expect the sequence number to be reset
+        throw new OutOfOrderSequenceException(s"Out of order sequence number for producerId $producerId: found $appendFirstSeq " +
+          s"(incoming seq. number), but expected 0")
+      } else if (!inSequence(currentLastSeq, appendFirstSeq)) {
+        throw new OutOfOrderSequenceException(s"Out of order sequence number for producerId $producerId: $appendFirstSeq " +
+          s"(incoming seq. number), $currentLastSeq (current end sequence number)")
+      }
     }
   }
 
-  private def isDuplicate(firstSeq: Int, lastSeq: Int): Boolean = {
-    ((lastSeq != 0 && currentEntry.firstSeq != Int.MaxValue && lastSeq < currentEntry.firstSeq)
-      || currentEntry.batchWithSequenceRange(firstSeq, lastSeq).isDefined)
-  }
-
-  private def inSequence(firstSeq: Int, lastSeq: Int): Boolean = {
-    firstSeq == currentEntry.lastSeq + 1L || (firstSeq == 0 && currentEntry.lastSeq == Int.MaxValue)
-  }
-
-  private def isFenced(producerEpoch: Short): Boolean = {
-    producerEpoch < currentEntry.producerEpoch
+  private def inSequence(lastSeq: Int, nextSeq: Int): Boolean = {
+    nextSeq == lastSeq + 1L || (nextSeq == 0 && lastSeq == Int.MaxValue)
   }
 
   def append(batch: RecordBatch): Option[CompletedTxn] = {
@@ -248,17 +267,21 @@ private[log] class ProducerAppendInfo(val producerId: Long,
              lastTimestamp: Long,
              lastOffset: Long,
              isTransactional: Boolean): Unit = {
-    maybeValidateAppend(epoch, firstSeq, lastSeq)
+    maybeValidateAppend(epoch, firstSeq)
+    updatedEntry.addBatch(epoch, lastSeq, lastOffset, lastSeq - firstSeq, lastTimestamp)
 
-    currentEntry.addBatchMetadata(epoch, lastSeq, lastOffset, lastSeq - firstSeq, lastTimestamp)
+    updatedEntry.currentTxnFirstOffset match {
+      case Some(_) if !isTransactional =>
+        // Received a non-transactional message while a transaction is active
+        throw new InvalidTxnStateException(s"Expected transactional write from producer $producerId")
 
-    if (currentEntry.currentTxnFirstOffset.isDefined && !isTransactional)
-      throw new InvalidTxnStateException(s"Expected transactional write from producer $producerId")
+      case None if isTransactional =>
+        // Began a new transaction
+        val firstOffset = lastOffset - (lastSeq - firstSeq)
+        updatedEntry.currentTxnFirstOffset = Some(firstOffset)
+        transactions += new TxnMetadata(producerId, firstOffset)
 
-    if (isTransactional && currentEntry.currentTxnFirstOffset.isEmpty) {
-      val firstOffset = lastOffset - (lastSeq - firstSeq)
-      currentEntry.currentTxnFirstOffset = Some(firstOffset)
-      transactions += new TxnMetadata(producerId, firstOffset)
+      case _ => // nothing to do
     }
   }
 
@@ -266,28 +289,27 @@ private[log] class ProducerAppendInfo(val producerId: Long,
                          producerEpoch: Short,
                          offset: Long,
                          timestamp: Long): CompletedTxn = {
-    if (isFenced(producerEpoch))
-      throw new ProducerFencedException(s"Invalid producer epoch: $producerEpoch (zombie): ${currentEntry.producerEpoch} (current)")
+    checkProducerEpoch(producerEpoch)
 
-    if (currentEntry.coordinatorEpoch > endTxnMarker.coordinatorEpoch)
+    if (updatedEntry.coordinatorEpoch > endTxnMarker.coordinatorEpoch)
       throw new TransactionCoordinatorFencedException(s"Invalid coordinator epoch: ${endTxnMarker.coordinatorEpoch} " +
-        s"(zombie), ${currentEntry.coordinatorEpoch} (current)")
+        s"(zombie), ${updatedEntry.coordinatorEpoch} (current)")
 
-    currentEntry.maybeUpdateEpoch(producerEpoch)
+    updatedEntry.maybeUpdateEpoch(producerEpoch)
 
-    val firstOffset = currentEntry.currentTxnFirstOffset match {
+    val firstOffset = updatedEntry.currentTxnFirstOffset match {
       case Some(txnFirstOffset) => txnFirstOffset
       case None =>
         transactions += new TxnMetadata(producerId, offset)
         offset
     }
 
-    currentEntry.currentTxnFirstOffset = None
-    currentEntry.coordinatorEpoch = endTxnMarker.coordinatorEpoch
+    updatedEntry.currentTxnFirstOffset = None
+    updatedEntry.coordinatorEpoch = endTxnMarker.coordinatorEpoch
     CompletedTxn(producerId, firstOffset, offset, endTxnMarker.controlType == ControlRecordType.ABORT)
   }
 
-  def latestEntry: ProducerIdEntry = currentEntry
+  def toEntry: ProducerStateEntry = updatedEntry
 
   def startedTransactions: List[TxnMetadata] = transactions.toList
 
@@ -306,11 +328,11 @@ private[log] class ProducerAppendInfo(val producerId: Long,
   override def toString: String = {
     "ProducerAppendInfo(" +
       s"producerId=$producerId, " +
-      s"producerEpoch=${currentEntry.producerEpoch}, " +
-      s"firstSequence=${currentEntry.firstSeq}, " +
-      s"lastSequence=${currentEntry.lastSeq}, " +
-      s"currentTxnFirstOffset=${currentEntry.currentTxnFirstOffset}, " +
-      s"coordinatorEpoch=${currentEntry.coordinatorEpoch}, " +
+      s"producerEpoch=${updatedEntry.producerEpoch}, " +
+      s"firstSequence=${updatedEntry.firstSeq}, " +
+      s"lastSequence=${updatedEntry.lastSeq}, " +
+      s"currentTxnFirstOffset=${updatedEntry.currentTxnFirstOffset}, " +
+      s"coordinatorEpoch=${updatedEntry.coordinatorEpoch}, " +
       s"startedTransactions=$transactions)"
   }
 }
@@ -347,7 +369,7 @@ object ProducerStateManager {
     new Field(CrcField, Type.UNSIGNED_INT32, "CRC of the snapshot data"),
     new Field(ProducerEntriesField, new ArrayOf(ProducerSnapshotEntrySchema), "The entries in the producer table"))
 
-  def readSnapshot(file: File): Iterable[ProducerIdEntry] = {
+  def readSnapshot(file: File): Iterable[ProducerStateEntry] = {
     try {
       val buffer = Files.readAllBytes(file.toPath)
       val struct = PidSnapshotMapSchema.read(ByteBuffer.wrap(buffer))
@@ -372,7 +394,7 @@ object ProducerStateManager {
         val offsetDelta = producerEntryStruct.getInt(OffsetDeltaField)
         val coordinatorEpoch = producerEntryStruct.getInt(CoordinatorEpochField)
         val currentTxnFirstOffset = producerEntryStruct.getLong(CurrentTxnFirstOffsetField)
-        val newEntry = new ProducerIdEntry(producerId, mutable.Queue[BatchMetadata](BatchMetadata(seq, offset, offsetDelta, timestamp)), producerEpoch,
+        val newEntry = new ProducerStateEntry(producerId, mutable.Queue[BatchMetadata](BatchMetadata(seq, offset, offsetDelta, timestamp)), producerEpoch,
           coordinatorEpoch, if (currentTxnFirstOffset >= 0) Some(currentTxnFirstOffset) else None)
         newEntry
       }
@@ -382,7 +404,7 @@ object ProducerStateManager {
     }
   }
 
-  private def writeSnapshot(file: File, entries: mutable.Map[Long, ProducerIdEntry]) {
+  private def writeSnapshot(file: File, entries: mutable.Map[Long, ProducerStateEntry]) {
     val struct = new Struct(PidSnapshotMapSchema)
     struct.set(VersionField, ProducerSnapshotVersion)
     struct.set(CrcField, 0L) // we'll fill this after writing the entries
@@ -462,7 +484,9 @@ class ProducerStateManager(val topicPartition: TopicPartition,
   import ProducerStateManager._
   import java.util
 
-  private val producers = mutable.Map.empty[Long, ProducerIdEntry]
+  this.logIdent = s"[ProducerStateManager partition=$topicPartition] "
+
+  private val producers = mutable.Map.empty[Long, ProducerStateEntry]
   private var lastMapOffset = 0L
   private var lastSnapOffset = 0L
 
@@ -512,7 +536,7 @@ class ProducerStateManager(val topicPartition: TopicPartition,
   /**
    * Get a copy of the active producers
    */
-  def activeProducers: immutable.Map[Long, ProducerIdEntry] = producers.toMap
+  def activeProducers: immutable.Map[Long, ProducerStateEntry] = producers.toMap
 
   def isEmpty: Boolean = producers.isEmpty && unreplicatedTxns.isEmpty
 
@@ -521,7 +545,7 @@ class ProducerStateManager(val topicPartition: TopicPartition,
       latestSnapshotFile match {
         case Some(file) =>
           try {
-            info(s"Loading producer state from snapshot file '$file' for partition $topicPartition")
+            info(s"Loading producer state from snapshot file '$file'")
             val loadedProducers = readSnapshot(file).filter { producerEntry =>
               isProducerRetained(producerEntry, logStartOffset) && !isProducerExpired(currentTime, producerEntry)
             }
@@ -543,7 +567,7 @@ class ProducerStateManager(val topicPartition: TopicPartition,
   }
 
   // visible for testing
-  private[log] def loadProducerEntry(entry: ProducerIdEntry): Unit = {
+  private[log] def loadProducerEntry(entry: ProducerStateEntry): Unit = {
     val producerId = entry.producerId
     producers.put(producerId, entry)
     entry.currentTxnFirstOffset.foreach { offset =>
@@ -551,8 +575,8 @@ class ProducerStateManager(val topicPartition: TopicPartition,
     }
   }
 
-  private def isProducerExpired(currentTimeMs: Long, producerIdEntry: ProducerIdEntry): Boolean =
-    producerIdEntry.currentTxnFirstOffset.isEmpty && currentTimeMs - producerIdEntry.lastTimestamp >= maxProducerIdExpirationMs
+  private def isProducerExpired(currentTimeMs: Long, producerState: ProducerStateEntry): Boolean =
+    producerState.currentTxnFirstOffset.isEmpty && currentTimeMs - producerState.lastTimestamp >= maxProducerIdExpirationMs
 
   /**
    * Expire any producer ids which have been idle longer than the configured maximum expiration timeout.
@@ -596,7 +620,8 @@ class ProducerStateManager(val topicPartition: TopicPartition,
       else
         ValidationType.Full
 
-    new ProducerAppendInfo(producerId, lastEntry(producerId).getOrElse(ProducerIdEntry.empty(producerId)), validationToPerform)
+    val currentEntry = lastEntry(producerId).getOrElse(ProducerStateEntry.empty(producerId))
+    new ProducerAppendInfo(producerId, currentEntry, validationToPerform)
   }
 
   /**
@@ -604,12 +629,19 @@ class ProducerStateManager(val topicPartition: TopicPartition,
    */
   def update(appendInfo: ProducerAppendInfo): Unit = {
     if (appendInfo.producerId == RecordBatch.NO_PRODUCER_ID)
-      throw new IllegalArgumentException(s"Invalid producer id ${appendInfo.producerId} passed to update")
+      throw new IllegalArgumentException(s"Invalid producer id ${appendInfo.producerId} passed to update " +
+        s"for partition $topicPartition")
 
     trace(s"Updated producer ${appendInfo.producerId} state to $appendInfo")
+    val updatedEntry = appendInfo.toEntry
+    producers.get(appendInfo.producerId) match {
+      case Some(currentEntry) =>
+        currentEntry.update(updatedEntry)
+
+      case None =>
+        producers.put(appendInfo.producerId, updatedEntry)
+    }
 
-    val entry = appendInfo.latestEntry
-    producers.put(appendInfo.producerId, entry)
     appendInfo.startedTransactions.foreach { txn =>
       ongoingTxns.put(txn.firstOffset.messageOffset, txn)
     }
@@ -622,7 +654,7 @@ class ProducerStateManager(val topicPartition: TopicPartition,
   /**
    * Get the last written entry for the given producer id.
    */
-  def lastEntry(producerId: Long): Option[ProducerIdEntry] = producers.get(producerId)
+  def lastEntry(producerId: Long): Option[ProducerStateEntry] = producers.get(producerId)
 
   /**
    * Take a snapshot at the current end offset if one does not already exist.
@@ -631,7 +663,7 @@ class ProducerStateManager(val topicPartition: TopicPartition,
     // If not a new offset, then it is not worth taking another snapshot
     if (lastMapOffset > lastSnapOffset) {
       val snapshotFile = Log.producerSnapshotFile(logDir, lastMapOffset)
-      debug(s"Writing producer snapshot for partition $topicPartition at offset $lastMapOffset")
+      info(s"Writing producer snapshot at offset $lastMapOffset")
       writeSnapshot(snapshotFile, producers)
 
       // Update the last snap offset according to the serialized map
@@ -649,9 +681,9 @@ class ProducerStateManager(val topicPartition: TopicPartition,
    */
   def oldestSnapshotOffset: Option[Long] = oldestSnapshotFile.map(file => offsetFromFile(file))
 
-  private def isProducerRetained(producerIdEntry: ProducerIdEntry, logStartOffset: Long): Boolean = {
-    producerIdEntry.removeBatchesOlderThan(logStartOffset)
-    producerIdEntry.lastDataOffset >= logStartOffset
+  private def isProducerRetained(producerStateEntry: ProducerStateEntry, logStartOffset: Long): Boolean = {
+    producerStateEntry.removeBatchesOlderThan(logStartOffset)
+    producerStateEntry.lastDataOffset >= logStartOffset
   }
 
   /**
@@ -664,8 +696,8 @@ class ProducerStateManager(val topicPartition: TopicPartition,
    * the snapshot.
    */
   def truncateHead(logStartOffset: Long) {
-    val evictedProducerEntries = producers.filter { case (_, producerIdEntry) =>
-      !isProducerRetained(producerIdEntry, logStartOffset)
+    val evictedProducerEntries = producers.filter { case (_, producerState) =>
+      !isProducerRetained(producerState, logStartOffset)
     }
     val evictedProducerIds = evictedProducerEntries.keySet
 
@@ -717,7 +749,8 @@ class ProducerStateManager(val topicPartition: TopicPartition,
   def completeTxn(completedTxn: CompletedTxn): Long = {
     val txnMetadata = ongoingTxns.remove(completedTxn.firstOffset)
     if (txnMetadata == null)
-      throw new IllegalArgumentException("Attempted to complete a transaction which was not started")
+      throw new IllegalArgumentException(s"Attempted to complete transaction $completedTxn on partition $topicPartition " +
+        s"which was not started")
 
     txnMetadata.lastOffset = Some(completedTxn.lastOffset)
     unreplicatedTxns.put(completedTxn.firstOffset, txnMetadata)
diff --git a/core/src/main/scala/kafka/tools/DumpLogSegments.scala b/core/src/main/scala/kafka/tools/DumpLogSegments.scala
index 855ca75..3261906 100755
--- a/core/src/main/scala/kafka/tools/DumpLogSegments.scala
+++ b/core/src/main/scala/kafka/tools/DumpLogSegments.scala
@@ -155,9 +155,13 @@ object DumpLogSegments {
   private def dumpProducerIdSnapshot(file: File): Unit = {
     try {
       ProducerStateManager.readSnapshot(file).foreach { entry =>
-        println(s"producerId: ${entry.producerId} producerEpoch: ${entry.producerEpoch} " +
-          s"coordinatorEpoch: ${entry.coordinatorEpoch} currentTxnFirstOffset: ${entry.currentTxnFirstOffset} " +
-          s"cachedMetadata: ${entry.batchMetadata}")
+        print(s"producerId: ${entry.producerId} producerEpoch: ${entry.producerEpoch} " +
+          s"coordinatorEpoch: ${entry.coordinatorEpoch} currentTxnFirstOffset: ${entry.currentTxnFirstOffset} ")
+        entry.batchMetadata.headOption.foreach { metadata =>
+          print(s"firstSequence: ${metadata.firstSeq} lastSequence: ${metadata.lastSeq} " +
+            s"lastOffset: ${metadata.lastOffset} offsetDelta: ${metadata.offsetDelta} timestamp: ${metadata.timestamp}")
+        }
+        println()
       }
     } catch {
       case e: CorruptSnapshotException =>
diff --git a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
index cef2bca..7836d3a 100644
--- a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
@@ -317,7 +317,7 @@ class LogSegmentTest {
 
     // recover again, but this time assuming the transaction from pid2 began on a previous segment
     stateManager = new ProducerStateManager(topicPartition, logDir)
-    stateManager.loadProducerEntry(new ProducerIdEntry(pid2,
+    stateManager.loadProducerEntry(new ProducerStateEntry(pid2,
       mutable.Queue[BatchMetadata](BatchMetadata(10, 10L, 5, RecordBatch.NO_TIMESTAMP)), producerEpoch, 0, Some(75L)))
     segment.recover(stateManager)
     assertEquals(108L, stateManager.mapEndOffset)
diff --git a/core/src/test/scala/unit/kafka/log/LogTest.scala b/core/src/test/scala/unit/kafka/log/LogTest.scala
index 0a0ce19..3c8b01e 100755
--- a/core/src/test/scala/unit/kafka/log/LogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTest.scala
@@ -792,6 +792,39 @@ class LogTest {
   }
 
   @Test
+  def testProducerSnapshotAfterSegmentRollOnAppend(): Unit = {
+    val producerId = 1L
+    val logConfig = createLogConfig(segmentBytes = 1024)
+    val log = createLog(logDir, logConfig)
+
+    log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord(mockTime.milliseconds(), new Array[Byte](512))),
+      producerId = producerId, producerEpoch = 0, sequence = 0),
+      leaderEpoch = 0)
+
+    // The next append should overflow the segment and cause it to roll
+    log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord(mockTime.milliseconds(), new Array[Byte](512))),
+      producerId = producerId, producerEpoch = 0, sequence = 1),
+      leaderEpoch = 0)
+
+    assertEquals(2, log.logSegments.size)
+    assertEquals(1L, log.activeSegment.baseOffset)
+    assertEquals(Some(1L), log.latestProducerSnapshotOffset)
+
+    // Force a reload from the snapshot to check its consistency
+    log.truncateTo(1L)
+
+    assertEquals(2, log.logSegments.size)
+    assertEquals(1L, log.activeSegment.baseOffset)
+    assertTrue(log.activeSegment.log.batches.asScala.isEmpty)
+    assertEquals(Some(1L), log.latestProducerSnapshotOffset)
+
+    val lastEntry = log.producerStateManager.lastEntry(producerId)
+    assertTrue(lastEntry.isDefined)
+    assertEquals(0L, lastEntry.get.firstOffset)
+    assertEquals(0L, lastEntry.get.lastDataOffset)
+  }
+
+  @Test
   def testRebuildTransactionalState(): Unit = {
     val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
     val log = createLog(logDir, logConfig)
@@ -892,7 +925,7 @@ class LogTest {
           new SimpleRecord(mockTime.milliseconds, s"key-$seq".getBytes, s"value-$seq".getBytes)),
         producerId = pid, producerEpoch = epoch, sequence = seq - 2)
       log.appendAsLeader(records, leaderEpoch = 0)
-      fail ("Should have received an OutOfOrderSequenceException since we attempted to append a duplicate of a records " +
+      fail("Should have received an OutOfOrderSequenceException since we attempted to append a duplicate of a records " +
         "in the middle of the log.")
     } catch {
       case _: OutOfOrderSequenceException => // Good!
@@ -904,16 +937,16 @@ class LogTest {
       producerId = pid, producerEpoch = epoch, sequence = 2)
     log.appendAsLeader(duplicateOfFourth, leaderEpoch = 0)
 
-    // Append a Duplicate of an entry older than the last 5 appended batches. This should result in a DuplicateSequenceNumberException.
-     try {
+    // Duplicates at older entries are reported as OutOfOrderSequence errors
+    try {
       val records = TestUtils.records(
         List(new SimpleRecord(mockTime.milliseconds, s"key-1".getBytes, s"value-1".getBytes)),
         producerId = pid, producerEpoch = epoch, sequence = 1)
       log.appendAsLeader(records, leaderEpoch = 0)
-      fail ("Should have received an DuplicateSequenceNumberException since we attempted to append a duplicate of a batch" +
+      fail("Should have received an OutOfOrderSequenceException since we attempted to append a duplicate of a batch " +
         "which is older than the last 5 appended batches.")
     } catch {
-      case _: DuplicateSequenceException => // Good!
+      case _: OutOfOrderSequenceException => // Good!
     }
 
     // Append a duplicate entry with a single records at the tail of the log. This should return the appendInfo of the original entry.
diff --git a/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
index 67b1b15..053aed7 100644
--- a/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala
@@ -62,8 +62,8 @@ class ProducerStateManagerTest extends JUnitSuite {
     // Second entry for id 0 added
     append(stateManager, producerId, epoch, 1, 0L, 1L)
 
-    // Duplicate sequence number (matches previous sequence number)
-    assertThrows[DuplicateSequenceException] {
+    // Duplicates are checked separately and should result in OutOfOrderSequence if appended
+    assertThrows[OutOfOrderSequenceException] {
       append(stateManager, producerId, epoch, 1, 0L, 1L)
     }
 
@@ -159,7 +159,7 @@ class ProducerStateManagerTest extends JUnitSuite {
     val producerEpoch = 0.toShort
     val offset = 992342L
     val seq = 0
-    val producerAppendInfo = new ProducerAppendInfo(producerId, ProducerIdEntry.empty(producerId), ValidationType.Full)
+    val producerAppendInfo = new ProducerAppendInfo(producerId, ProducerStateEntry.empty(producerId), ValidationType.Full)
     producerAppendInfo.append(producerEpoch, seq, seq, time.milliseconds(), offset, isTransactional = true)
 
     val logOffsetMetadata = new LogOffsetMetadata(messageOffset = offset, segmentBaseOffset = 990000L,
@@ -175,7 +175,7 @@ class ProducerStateManagerTest extends JUnitSuite {
     val producerEpoch = 0.toShort
     val offset = 992342L
     val seq = 0
-    val producerAppendInfo = new ProducerAppendInfo(producerId, ProducerIdEntry.empty(producerId), ValidationType.Full)
+    val producerAppendInfo = new ProducerAppendInfo(producerId, ProducerStateEntry.empty(producerId), ValidationType.Full)
     producerAppendInfo.append(producerEpoch, seq, seq, time.milliseconds(), offset, isTransactional = true)
 
     // use some other offset to simulate a follower append where the log offset metadata won't typically
@@ -189,6 +189,32 @@ class ProducerStateManagerTest extends JUnitSuite {
   }
 
   @Test
+  def testPrepareUpdateDoesNotMutate(): Unit = {
+    val producerEpoch = 0.toShort
+
+    val appendInfo = stateManager.prepareUpdate(producerId, isFromClient = true)
+    appendInfo.append(producerEpoch, 0, 5, time.milliseconds(), 20L, isTransactional = false)
+    assertEquals(None, stateManager.lastEntry(producerId))
+    stateManager.update(appendInfo)
+    assertTrue(stateManager.lastEntry(producerId).isDefined)
+
+    val nextAppendInfo = stateManager.prepareUpdate(producerId, isFromClient = true)
+    nextAppendInfo.append(producerEpoch, 6, 10, time.milliseconds(), 30L, isTransactional = false)
+    assertTrue(stateManager.lastEntry(producerId).isDefined)
+
+    var lastEntry = stateManager.lastEntry(producerId).get
+    assertEquals(0, lastEntry.firstSeq)
+    assertEquals(5, lastEntry.lastSeq)
+    assertEquals(20L, lastEntry.lastDataOffset)
+
+    stateManager.update(nextAppendInfo)
+    lastEntry = stateManager.lastEntry(producerId).get
+    assertEquals(0, lastEntry.firstSeq)
+    assertEquals(10, lastEntry.lastSeq)
+    assertEquals(30L, lastEntry.lastDataOffset)
+  }
+
+  @Test
   def updateProducerTransactionState(): Unit = {
     val producerEpoch = 0.toShort
     val coordinatorEpoch = 15
@@ -197,21 +223,21 @@ class ProducerStateManagerTest extends JUnitSuite {
 
     val appendInfo = stateManager.prepareUpdate(producerId, isFromClient = true)
     appendInfo.append(producerEpoch, 1, 5, time.milliseconds(), 20L, isTransactional = true)
-    var lastEntry = appendInfo.latestEntry
+    var lastEntry = appendInfo.toEntry
     assertEquals(producerEpoch, lastEntry.producerEpoch)
-    assertEquals(0, lastEntry.firstSeq)
+    assertEquals(1, lastEntry.firstSeq)
     assertEquals(5, lastEntry.lastSeq)
-    assertEquals(9L, lastEntry.firstOffset)
+    assertEquals(16L, lastEntry.firstOffset)
     assertEquals(20L, lastEntry.lastDataOffset)
     assertEquals(Some(16L), lastEntry.currentTxnFirstOffset)
     assertEquals(List(new TxnMetadata(producerId, 16L)), appendInfo.startedTransactions)
 
     appendInfo.append(producerEpoch, 6, 10, time.milliseconds(), 30L, isTransactional = true)
-    lastEntry = appendInfo.latestEntry
+    lastEntry = appendInfo.toEntry
     assertEquals(producerEpoch, lastEntry.producerEpoch)
-    assertEquals(0, lastEntry.firstSeq)
+    assertEquals(1, lastEntry.firstSeq)
     assertEquals(10, lastEntry.lastSeq)
-    assertEquals(9L, lastEntry.firstOffset)
+    assertEquals(16L, lastEntry.firstOffset)
     assertEquals(30L, lastEntry.lastDataOffset)
     assertEquals(Some(16L), lastEntry.currentTxnFirstOffset)
     assertEquals(List(new TxnMetadata(producerId, 16L)), appendInfo.startedTransactions)
@@ -223,29 +249,40 @@ class ProducerStateManagerTest extends JUnitSuite {
     assertEquals(40L, completedTxn.lastOffset)
     assertFalse(completedTxn.isAborted)
 
-    lastEntry = appendInfo.latestEntry
+    lastEntry = appendInfo.toEntry
     assertEquals(producerEpoch, lastEntry.producerEpoch)
     // verify that appending the transaction marker doesn't affect the metadata of the cached record batches.
-    assertEquals(0, lastEntry.firstSeq)
+    assertEquals(1, lastEntry.firstSeq)
     assertEquals(10, lastEntry.lastSeq)
-    assertEquals(9L, lastEntry.firstOffset)
+    assertEquals(16L, lastEntry.firstOffset)
     assertEquals(30L, lastEntry.lastDataOffset)
     assertEquals(coordinatorEpoch, lastEntry.coordinatorEpoch)
     assertEquals(None, lastEntry.currentTxnFirstOffset)
     assertEquals(List(new TxnMetadata(producerId, 16L)), appendInfo.startedTransactions)
   }
 
-  @Test(expected = classOf[OutOfOrderSequenceException])
+  @Test
   def testOutOfSequenceAfterControlRecordEpochBump(): Unit = {
     val epoch = 0.toShort
-    append(stateManager, producerId, epoch, 0, 0L)
-    append(stateManager, producerId, epoch, 1, 1L)
+    append(stateManager, producerId, epoch, 0, 0L, isTransactional = true)
+    append(stateManager, producerId, epoch, 1, 1L, isTransactional = true)
 
     val bumpedEpoch = 1.toShort
     appendEndTxnMarker(stateManager, producerId, bumpedEpoch, ControlRecordType.ABORT, 1L)
 
     // next append is invalid since we expect the sequence to be reset
-    append(stateManager, producerId, bumpedEpoch, 2, 2L)
+    assertThrows[OutOfOrderSequenceException] {
+      append(stateManager, producerId, bumpedEpoch, 2, 2L, isTransactional = true)
+    }
+
+    assertThrows[OutOfOrderSequenceException] {
+      append(stateManager, producerId, (bumpedEpoch + 1).toShort, 2, 2L, isTransactional = true)
+    }
+
+    // Append with the bumped epoch should be fine if starting from sequence 0
+    append(stateManager, producerId, bumpedEpoch, 0, 0L, isTransactional = true)
+    assertEquals(bumpedEpoch, stateManager.lastEntry(producerId).get.producerEpoch)
+    assertEquals(0, stateManager.lastEntry(producerId).get.lastSeq)
   }
 
   @Test(expected = classOf[InvalidTxnStateException])
@@ -334,10 +371,10 @@ class ProducerStateManagerTest extends JUnitSuite {
     assertFalse(recoveredMapping.activeProducers.contains(producerId))
     append(recoveredMapping, producerId, epoch, sequence, 2L, 70001, isFromClient = false)
     assertTrue(recoveredMapping.activeProducers.contains(producerId))
-    val producerIdEntry = recoveredMapping.activeProducers.get(producerId).head
-    assertEquals(epoch, producerIdEntry.producerEpoch)
-    assertEquals(sequence, producerIdEntry.firstSeq)
-    assertEquals(sequence, producerIdEntry.lastSeq)
+    val producerStateEntry = recoveredMapping.activeProducers.get(producerId).head
+    assertEquals(epoch, producerStateEntry.producerEpoch)
+    assertEquals(sequence, producerStateEntry.firstSeq)
+    assertEquals(sequence, producerStateEntry.lastSeq)
   }
 
   @Test
@@ -524,7 +561,7 @@ class ProducerStateManagerTest extends JUnitSuite {
     append(stateManager, producerId, epoch, 2, 3L, 4L)
     stateManager.takeSnapshot()
 
-    intercept[UnknownProducerIdException] {
+    assertThrows[UnknownProducerIdException] {
       val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs)
       recoveredMapping.truncateAndReload(0L, 1L, time.milliseconds)
       append(recoveredMapping, pid2, epoch, 1, 4L, 5L)

-- 
To stop receiving notification emails like this one, please contact
jgus@apache.org.