You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this version.
Posted to commits@kafka.apache.org by ju...@apache.org on 2017/04/20 20:01:50 UTC

kafka git commit: MINOR: Improvements to PID snapshot management

Repository: kafka
Updated Branches:
  refs/heads/trunk 6af876b94 -> 588ed4644


MINOR: Improvements to PID snapshot management

Author: Jason Gustafson <ja...@confluent.io>

Reviewers: Ismael Juma <is...@juma.me.uk>, Jun Rao <ju...@gmail.com>

Closes #2866 from hachikuji/improve-snapshot-management


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/588ed464
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/588ed464
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/588ed464

Branch: refs/heads/trunk
Commit: 588ed4644fe8842e4217c8e2add39f6e08e07db9
Parents: 6af876b
Author: Jason Gustafson <ja...@confluent.io>
Authored: Thu Apr 20 13:01:46 2017 -0700
Committer: Jun Rao <ju...@gmail.com>
Committed: Thu Apr 20 13:01:46 2017 -0700

----------------------------------------------------------------------
 .../kafka/common/record/DefaultRecordBatch.java |   3 +
 .../kafka/common/record/MemoryRecords.java      |   9 +-
 .../common/record/MemoryRecordsBuilder.java     |   6 +
 .../common/record/DefaultRecordBatchTest.java   |  39 +++-
 .../kafka/common/record/MemoryRecordsTest.java  |  81 ++++++-
 core/src/main/scala/kafka/log/Log.scala         |  96 +++++---
 .../scala/kafka/log/ProducerIdMapping.scala     | 227 +++++++++----------
 .../scala/unit/kafka/log/LogManagerTest.scala   |   6 +-
 .../src/test/scala/unit/kafka/log/LogTest.scala | 221 +++++++++++++++++-
 .../unit/kafka/log/ProducerIdMappingTest.scala  | 112 +++++++--
 .../test/scala/unit/kafka/utils/TestUtils.scala |   5 +-
 11 files changed, 609 insertions(+), 196 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/588ed464/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
index 3eeea36..2680f30 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java
@@ -166,6 +166,9 @@ public class DefaultRecordBatch extends AbstractRecordBatch implements MutableRe
 
     @Override
     public int lastSequence() {
+        int baseSequence = baseSequence();
+        if (baseSequence == RecordBatch.NO_SEQUENCE)
+            return RecordBatch.NO_SEQUENCE;
         return baseSequence() + lastOffsetDelta();
     }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/588ed464/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
index a6e804b..e7e155f 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
@@ -176,11 +176,18 @@ public class MemoryRecords extends AbstractRecords {
                 TimestampType timestampType = batch.timestampType();
                 long logAppendTime = timestampType == TimestampType.LOG_APPEND_TIME ? batch.maxTimestamp() : RecordBatch.NO_TIMESTAMP;
                 MemoryRecordsBuilder builder = builder(slice, batch.magic(), batch.compressionType(), timestampType,
-                        firstOffset, logAppendTime);
+                        firstOffset, logAppendTime, batch.producerId(), batch.producerEpoch(), batch.baseSequence(),
+                        batch.partitionLeaderEpoch());
 
                 for (Record record : retainedRecords)
                     builder.append(record);
 
+                if (batch.magic() >= RecordBatch.MAGIC_VALUE_V2)
+                    // we must preserve the last offset from the initial batch in order to ensure that the
+                    // last sequence number from the batch remains even after compaction. Otherwise, the producer
+                    // could incorrectly see an out of sequence error.
+                    builder.overrideLastOffset(batch.lastOffset());
+
                 MemoryRecords records = builder.build();
                 destinationBuffer.position(destinationBuffer.position() + slice.position());
                 messagesRetained += retainedRecords.size();

http://git-wip-us.apache.org/repos/asf/kafka/blob/588ed464/clients/src/main/java/org/apache/kafka/common/record/MemoryRecordsBuilder.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecordsBuilder.java b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecordsBuilder.java
index 208db5b..549804a 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecordsBuilder.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecordsBuilder.java
@@ -234,6 +234,12 @@ public class MemoryRecordsBuilder {
         this.baseSequence = baseSequence;
     }
 
+    public void overrideLastOffset(long lastOffset) {
+        if (builtRecords != null)
+            throw new IllegalStateException("Cannot override the last offset after the records have been built");
+        this.lastOffset = lastOffset;
+    }
+
     /**
      * Release resources required for record appends (e.g. compression buffers). Once this method is called, it's only
      * possible to update the RecordBatch header.

http://git-wip-us.apache.org/repos/asf/kafka/blob/588ed464/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
index 8466c83..a50c5b2 100644
--- a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java
@@ -37,14 +37,49 @@ public class DefaultRecordBatchTest {
 
         MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE,
                 TimestampType.CREATE_TIME, 1234567L);
-        builder.appendWithOffset(1234567, System.currentTimeMillis(), "a".getBytes(), "v".getBytes());
-        builder.appendWithOffset(1234568, System.currentTimeMillis(), "b".getBytes(), "v".getBytes());
+        builder.appendWithOffset(1234567, 1L, "a".getBytes(), "v".getBytes());
+        builder.appendWithOffset(1234568, 2L, "b".getBytes(), "v".getBytes());
 
         MemoryRecords records = builder.build();
         for (MutableRecordBatch batch : records.batches()) {
+            assertTrue(batch.isValid());
             assertEquals(1234567, batch.baseOffset());
             assertEquals(1234568, batch.lastOffset());
+            assertEquals(2L, batch.maxTimestamp());
+            assertEquals(RecordBatch.NO_PRODUCER_ID, batch.producerId());
+            assertEquals(RecordBatch.NO_PRODUCER_EPOCH, batch.producerEpoch());
+            assertEquals(RecordBatch.NO_SEQUENCE, batch.baseSequence());
+            assertEquals(RecordBatch.NO_SEQUENCE, batch.lastSequence());
+
+            for (Record record : batch) {
+                assertTrue(record.isValid());
+            }
+        }
+    }
+
+    @Test
+    public void buildDefaultRecordBatchWithProducerId() {
+        long pid = 23423L;
+        short epoch = 145;
+        int baseSequence = 983;
+
+        ByteBuffer buffer = ByteBuffer.allocate(2048);
+
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE,
+                TimestampType.CREATE_TIME, 1234567L, RecordBatch.NO_TIMESTAMP, pid, epoch, baseSequence);
+        builder.appendWithOffset(1234567, 1L, "a".getBytes(), "v".getBytes());
+        builder.appendWithOffset(1234568, 2L, "b".getBytes(), "v".getBytes());
+
+        MemoryRecords records = builder.build();
+        for (MutableRecordBatch batch : records.batches()) {
             assertTrue(batch.isValid());
+            assertEquals(1234567, batch.baseOffset());
+            assertEquals(1234568, batch.lastOffset());
+            assertEquals(2L, batch.maxTimestamp());
+            assertEquals(pid, batch.producerId());
+            assertEquals(epoch, batch.producerEpoch());
+            assertEquals(baseSequence, batch.baseSequence());
+            assertEquals(baseSequence + 1, batch.lastSequence());
 
             for (Record record : batch) {
                 assertTrue(record.isValid());

http://git-wip-us.apache.org/repos/asf/kafka/blob/588ed464/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
index ea430b1..49e1429 100644
--- a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
@@ -190,6 +190,85 @@ public class MemoryRecordsTest {
     }
 
     @Test
+    public void testFilterToPreservesPartitionLeaderEpoch() {
+        if (magic >= RecordBatch.MAGIC_VALUE_V2) {
+            int partitionLeaderEpoch = 67;
+
+            ByteBuffer buffer = ByteBuffer.allocate(2048);
+            MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME,
+                    0L, RecordBatch.NO_TIMESTAMP, partitionLeaderEpoch);
+            builder.append(10L, null, "a".getBytes());
+            builder.append(11L, "1".getBytes(), "b".getBytes());
+            builder.append(12L, null, "c".getBytes());
+
+            ByteBuffer filtered = ByteBuffer.allocate(2048);
+            builder.build().filterTo(new RetainNonNullKeysFilter(), filtered);
+
+            filtered.flip();
+            MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered);
+
+            List<MutableRecordBatch> batches = TestUtils.toList(filteredRecords.batches());
+            assertEquals(1, batches.size());
+
+            MutableRecordBatch firstBatch = batches.get(0);
+            assertEquals(partitionLeaderEpoch, firstBatch.partitionLeaderEpoch());
+        }
+    }
+
+    @Test
+    public void testFilterToPreservesProducerInfo() {
+        if (magic >= RecordBatch.MAGIC_VALUE_V2) {
+            ByteBuffer buffer = ByteBuffer.allocate(2048);
+            MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 0L);
+            builder.append(10L, null, "a".getBytes());
+            builder.append(11L, "1".getBytes(), "b".getBytes());
+            builder.append(12L, null, "c".getBytes());
+
+            builder.close();
+
+            long pid = 23L;
+            short epoch = 5;
+            int baseSequence = 10;
+
+            builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 3L,
+                    RecordBatch.NO_TIMESTAMP, pid, epoch, baseSequence);
+            builder.append(13L, null, "d".getBytes());
+            builder.append(14L, "4".getBytes(), "e".getBytes());
+            builder.append(15L, "5".getBytes(), "f".getBytes());
+            builder.close();
+
+            buffer.flip();
+
+            ByteBuffer filtered = ByteBuffer.allocate(2048);
+            MemoryRecords.readableRecords(buffer).filterTo(new RetainNonNullKeysFilter(), filtered);
+
+            filtered.flip();
+            MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered);
+
+            List<MutableRecordBatch> batches = TestUtils.toList(filteredRecords.batches());
+            assertEquals(2, batches.size());
+
+            MutableRecordBatch firstBatch = batches.get(0);
+            assertEquals(1, firstBatch.countOrNull().intValue());
+            assertEquals(0L, firstBatch.baseOffset());
+            assertEquals(2L, firstBatch.lastOffset());
+            assertEquals(RecordBatch.NO_PRODUCER_ID, firstBatch.producerId());
+            assertEquals(RecordBatch.NO_PRODUCER_EPOCH, firstBatch.producerEpoch());
+            assertEquals(RecordBatch.NO_SEQUENCE, firstBatch.baseSequence());
+            assertEquals(RecordBatch.NO_SEQUENCE, firstBatch.lastSequence());
+
+            MutableRecordBatch secondBatch = batches.get(1);
+            assertEquals(2, secondBatch.countOrNull().intValue());
+            assertEquals(3L, secondBatch.baseOffset());
+            assertEquals(5L, secondBatch.lastOffset());
+            assertEquals(pid, secondBatch.producerId());
+            assertEquals(epoch, secondBatch.producerEpoch());
+            assertEquals(baseSequence, secondBatch.baseSequence());
+            assertEquals(baseSequence + 2, secondBatch.lastSequence());
+        }
+    }
+
+    @Test
     public void testFilterTo() {
         ByteBuffer buffer = ByteBuffer.allocate(2048);
         MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 0L);
@@ -246,7 +325,7 @@ public class MemoryRecordsTest {
             expectedStartOffsets = asList(1L, 4L, 6L);
             expectedMaxTimestamps = asList(11L, 20L, 16L);
         } else {
-            expectedEndOffsets = asList(1L, 5L, 6L);
+            expectedEndOffsets = asList(2L, 5L, 6L);
             expectedStartOffsets = asList(1L, 3L, 6L);
             expectedMaxTimestamps = asList(11L, 20L, 16L);
         }

http://git-wip-us.apache.org/repos/asf/kafka/blob/588ed464/core/src/main/scala/kafka/log/Log.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/Log.scala b/core/src/main/scala/kafka/log/Log.scala
index 450e9f6..485525f 100644
--- a/core/src/main/scala/kafka/log/Log.scala
+++ b/core/src/main/scala/kafka/log/Log.scala
@@ -112,7 +112,7 @@ class Log(@volatile var dir: File,
           time: Time = Time.SYSTEM,
           val maxPidExpirationMs: Int = 60 * 60 * 1000,
           val pidExpirationCheckIntervalMs: Int = 10 * 60 * 1000,
-          val pidSnapshotCreationIntervalMs: Int = 60 * 1000) extends Logging with KafkaMetricsGroup {
+          val pidSnapshotIntervalMs: Int = 60 * 1000) extends Logging with KafkaMetricsGroup {
 
   import kafka.log.Log._
 
@@ -191,7 +191,7 @@ class Log(@volatile var dir: File,
 
   scheduler.schedule(name = "PeriodicPidExpirationCheck", fun = () => {
     lock synchronized {
-      pidMap.checkForExpiredPids(time.milliseconds)
+      pidMap.removeExpiredPids(time.milliseconds)
     }
   }, period = pidExpirationCheckIntervalMs, unit = TimeUnit.MILLISECONDS)
 
@@ -199,7 +199,7 @@ class Log(@volatile var dir: File,
     lock synchronized {
       pidMap.maybeTakeSnapshot()
     }
-  }, period = pidSnapshotCreationIntervalMs, unit = TimeUnit.MILLISECONDS)
+  }, period = pidSnapshotIntervalMs, unit = TimeUnit.MILLISECONDS)
 
 
   /** The name of this log */
@@ -258,14 +258,14 @@ class Log(@volatile var dir: File,
         }
       } else if(filename.endsWith(LogFileSuffix)) {
         // if its a log file, load the corresponding log segment
-        val start = filename.substring(0, filename.length - LogFileSuffix.length).toLong
-        val indexFile = Log.indexFilename(dir, start)
-        val timeIndexFile = Log.timeIndexFilename(dir, start)
+        val startOffset = offsetFromFilename(filename)
+        val indexFile = Log.indexFilename(dir, startOffset)
+        val timeIndexFile = Log.timeIndexFilename(dir, startOffset)
 
         val indexFileExists = indexFile.exists()
         val timeIndexFileExists = timeIndexFile.exists()
         val segment = new LogSegment(dir = dir,
-                                     startOffset = start,
+                                     startOffset = startOffset,
                                      indexIntervalBytes = config.indexInterval,
                                      maxIndexSize = config.maxIndexSize,
                                      rollJitterMs = config.randomSegmentJitter,
@@ -291,7 +291,7 @@ class Log(@volatile var dir: File,
           error("Could not find index file corresponding to log file %s, rebuilding index...".format(segment.log.file.getAbsolutePath))
           segment.recover(config.maxMessageSize)
         }
-        segments.put(start, segment)
+        segments.put(startOffset, segment)
       }
     }
 
@@ -300,8 +300,8 @@ class Log(@volatile var dir: File,
     // before the swap file is restored as the new segment file.
     for (swapFile <- swapFiles) {
       val logFile = new File(CoreUtils.replaceSuffix(swapFile.getPath, SwapFileSuffix, ""))
-      val fileName = logFile.getName
-      val startOffset = fileName.substring(0, fileName.length - LogFileSuffix.length).toLong
+      val filename = logFile.getName
+      val startOffset = offsetFromFilename(filename)
       val indexFile = new File(CoreUtils.replaceSuffix(logFile.getPath, LogFileSuffix, IndexFileSuffix) + SwapFileSuffix)
       val index =  new OffsetIndex(indexFile, baseOffset = startOffset, maxIndexSize = config.maxIndexSize)
       val timeIndexFile = new File(CoreUtils.replaceSuffix(logFile.getPath, LogFileSuffix, TimeIndexFileSuffix) + SwapFileSuffix)
@@ -369,7 +369,7 @@ class Log(@volatile var dir: File,
       if(truncatedBytes > 0) {
         // we had an invalid message, delete all remaining log
         warn("Corruption found in segment %d of log %s, truncating to offset %d.".format(curr.baseOffset, name, curr.nextOffset))
-        unflushed.foreach(deleteSegment(_))
+        unflushed.foreach(deleteSegment)
       }
     }
   }
@@ -380,32 +380,26 @@ class Log(@volatile var dir: File,
     * starts from a snapshot that is taken strictly before the log end
     * offset. Consequently, we need to process the tail of the log to update
     * the mapping.
-    *
-    * @param lastOffset
-    *
-    * @return An instance of ProducerIdMapping
     */
   private def buildAndRecoverPidMap(lastOffset: Long) {
     lock synchronized {
+      info(s"Recovering PID mapping from offset $lastOffset for partition $topicPartition")
       val currentTimeMs = time.milliseconds
-      pidMap.truncateAndReload(lastOffset, currentTimeMs)
+      pidMap.truncateAndReload(logStartOffset, lastOffset, currentTimeMs)
       logSegments(pidMap.mapEndOffset, lastOffset).foreach { segment =>
         val startOffset = math.max(segment.baseOffset, pidMap.mapEndOffset)
         val fetchDataInfo = segment.read(startOffset, Some(lastOffset), Int.MaxValue)
-        val records = fetchDataInfo.records
-        records.batches.asScala.foreach { batch =>
-          if (batch.hasProducerId) {
-            // TODO: Currently accessing any of the batch-level headers other than the offset
-            // or magic causes us to load the full entry into memory. It would be better if we
-            // only loaded the header
-            val numRecords = (batch.lastOffset - batch.baseOffset + 1).toInt
-            val pidEntry = ProducerIdEntry(batch.producerEpoch, batch.lastSequence, batch.lastOffset,
-              numRecords, batch.maxTimestamp)
-            pidMap.load(batch.producerId, pidEntry, currentTimeMs)
+        if (fetchDataInfo != null) {
+          fetchDataInfo.records.batches.asScala.foreach { batch =>
+            if (batch.hasProducerId) {
+              val pidEntry = ProducerIdEntry(batch.producerEpoch, batch.lastSequence, batch.lastOffset,
+                batch.lastSequence - batch.baseSequence, batch.maxTimestamp)
+              pidMap.load(batch.producerId, pidEntry, currentTimeMs)
+            }
           }
         }
       }
-      pidMap.cleanFrom(logStartOffset)
+      pidMap.updateMapEndOffset(lastOffset)
     }
   }
 
@@ -436,7 +430,6 @@ class Log(@volatile var dir: File,
     }
   }
 
-
   /**
     * Append this message set to the active segment of the log, assigning offsets and Partition Leader Epochs
     * @param records The records to append
@@ -529,10 +522,10 @@ class Log(@volatile var dir: File,
             throw new IllegalArgumentException("Out of order offsets found in " + records.records.asScala.map(_.offset))
         }
 
-        //Update the epoch cache with the epoch stamped onto the message by the leader
-        validRecords.batches().asScala.map { batch =>
+        // update the epoch cache with the epoch stamped onto the message by the leader
+        validRecords.batches.asScala.foreach { batch =>
           if (batch.magic >= RecordBatch.MAGIC_VALUE_V2)
-            leaderEpochCache.assign(batch.partitionLeaderEpoch, batch.baseOffset())
+            leaderEpochCache.assign(batch.partitionLeaderEpoch, batch.baseOffset)
         }
 
         // check messages set size may be exceed config.segmentSize
@@ -562,6 +555,10 @@ class Log(@volatile var dir: File,
           pidMap.update(producerAppendInfo)
         }
 
+        // always update the last pid map offset so that the snapshot reflects the current offset
+        // even if there isn't any idempotent data being written
+        pidMap.updateMapEndOffset(appendInfo.lastOffset + 1)
+
         // increment the log end offset
         updateLogEndOffset(appendInfo.lastOffset + 1)
 
@@ -678,7 +675,7 @@ class Log(@volatile var dir: File,
               firstOffset = lastEntry.firstOffset
               lastOffset = lastEntry.lastOffset
               maxTimestamp = lastEntry.timestamp
-              info(s"Detected a duplicate for partition $topicPartition at (firstOffset, lastOffset): (${firstOffset}, ${lastOffset}). " +
+              debug(s"Detected a duplicate for partition $topicPartition at (firstOffset, lastOffset): ($firstOffset, $lastOffset). " +
                 "Ignoring the incoming record.")
             } else {
               val producerAppendInfo = new ProducerAppendInfo(pid, lastEntry)
@@ -860,9 +857,10 @@ class Log(@volatile var dir: File,
         roll()
       lock synchronized {
         // remove the segments for lookups
-        deletable.foreach(deleteSegment(_))
+        deletable.foreach(deleteSegment)
         logStartOffset = math.max(logStartOffset, segments.firstEntry().getValue.baseOffset)
         leaderEpochCache.clearEarliest(logStartOffset)
+        pidMap.expirePids(logStartOffset)
       }
     }
     numToDelete
@@ -1075,6 +1073,12 @@ class Log(@volatile var dir: File,
     }
   }
 
+  private[log] def maybeTakePidSnapshot(): Unit = pidMap.maybeTakeSnapshot()
+
+  private[log] def latestPidSnapshotOffset: Option[Long] = pidMap.latestSnapshotOffset
+
+  private[log] def latestPidMapOffset: Long = pidMap.mapEndOffset
+
   /**
    * Truncate this log so that it ends with the greatest offset < targetOffset.
    *
@@ -1093,14 +1097,14 @@ class Log(@volatile var dir: File,
         truncateFullyAndStartAt(targetOffset)
       } else {
         val deletable = logSegments.filter(segment => segment.baseOffset > targetOffset)
-        deletable.foreach(deleteSegment(_))
+        deletable.foreach(deleteSegment)
         activeSegment.truncateTo(targetOffset)
         updateLogEndOffset(targetOffset)
         this.recoveryPoint = math.min(targetOffset, this.recoveryPoint)
         this.logStartOffset = math.min(targetOffset, this.logStartOffset)
         leaderEpochCache.clearLatest(targetOffset)
+        buildAndRecoverPidMap(targetOffset)
       }
-      buildAndRecoverPidMap(targetOffset)
     }
   }
 
@@ -1110,10 +1114,10 @@ class Log(@volatile var dir: File,
    *  @param newOffset The new offset to start the log with
    */
   private[log] def truncateFullyAndStartAt(newOffset: Long) {
-    debug("Truncate and start log '" + name + "' to " + newOffset)
+    debug(s"Truncate and start log '$name' at offset $newOffset")
     lock synchronized {
       val segmentsToDelete = logSegments.toList
-      segmentsToDelete.foreach(deleteSegment(_))
+      segmentsToDelete.foreach(deleteSegment)
       addSegment(new LogSegment(dir,
                                 newOffset,
                                 indexIntervalBytes = config.indexInterval,
@@ -1125,6 +1129,10 @@ class Log(@volatile var dir: File,
                                 preallocate = config.preallocate))
       updateLogEndOffset(newOffset)
       leaderEpochCache.clear()
+
+      pidMap.truncate()
+      pidMap.updateMapEndOffset(newOffset)
+
       this.recoveryPoint = math.min(newOffset, this.recoveryPoint)
       this.logStartOffset = newOffset
     }
@@ -1272,6 +1280,8 @@ object Log {
   /** a time index file */
   val TimeIndexFileSuffix = ".timeindex"
 
+  val PidSnapshotFileSuffix = ".snapshot"
+
   /** a file that is scheduled to be deleted */
   val DeletedFileSuffix = ".deleted"
 
@@ -1334,6 +1344,18 @@ object Log {
     new File(dir, filenamePrefixFromOffset(offset) + TimeIndexFileSuffix)
 
   /**
+   * Construct a PID snapshot file using the given offset.
+   *
+   * @param dir The directory in which the log will reside
+   * @param offset The last offset (exclusive) included in the snapshot
+   */
+  def pidSnapshotFilename(dir: File, offset: Long) =
+    new File(dir, filenamePrefixFromOffset(offset) + PidSnapshotFileSuffix)
+
+  def offsetFromFilename(filename: String): Long =
+    filename.substring(0, filename.indexOf('.')).toLong
+
+  /**
    * Parse the topic and partition out of the directory name of a log
    */
   def parseTopicPartitionName(dir: File): TopicPartition = {

http://git-wip-us.apache.org/repos/asf/kafka/blob/588ed464/core/src/main/scala/kafka/log/ProducerIdMapping.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/ProducerIdMapping.scala b/core/src/main/scala/kafka/log/ProducerIdMapping.scala
index 4c2eb7f..fc6ff0b 100644
--- a/core/src/main/scala/kafka/log/ProducerIdMapping.scala
+++ b/core/src/main/scala/kafka/log/ProducerIdMapping.scala
@@ -35,9 +35,9 @@ private[log] object ProducerIdEntry {
     -1, 0, RecordBatch.NO_TIMESTAMP)
 }
 
-private[log] case class ProducerIdEntry(epoch: Short, lastSeq: Int, lastOffset: Long, numRecords: Int, timestamp: Long) {
-  def firstSeq: Int = lastSeq - numRecords + 1
-  def firstOffset: Long = lastOffset - numRecords + 1
+private[log] case class ProducerIdEntry(epoch: Short, lastSeq: Int, lastOffset: Long, offsetDelta: Int, timestamp: Long) {
+  def firstSeq: Int = lastSeq - offsetDelta
+  def firstOffset: Long = lastOffset - offsetDelta
 
   def isDuplicate(batch: RecordBatch): Boolean = {
     batch.producerEpoch == epoch &&
@@ -92,24 +92,20 @@ private[log] class ProducerAppendInfo(val pid: Long, initialEntry: ProducerIdEnt
     append(entry.epoch, entry.firstSeq, entry.lastSeq, entry.timestamp, entry.lastOffset)
 
   def lastEntry: ProducerIdEntry =
-    ProducerIdEntry(epoch, lastSeq, lastOffset, lastSeq - firstSeq + 1, maxTimestamp)
+    ProducerIdEntry(epoch, lastSeq, lastOffset, lastSeq - firstSeq, maxTimestamp)
 }
 
 private[log] class CorruptSnapshotException(msg: String) extends KafkaException(msg)
 
 object ProducerIdMapping {
-  private val DirnamePrefix = "pid-mapping"
-  private val FilenameSuffix = "snapshot"
-  private val FilenamePattern = s"^\\d{1,}.$FilenameSuffix".r
   private val PidSnapshotVersion: Short = 1
-
   private val VersionField = "version"
   private val CrcField = "crc"
   private val PidField = "pid"
   private val LastSequenceField = "last_sequence"
   private val EpochField = "epoch"
   private val LastOffsetField = "last_offset"
-  private val NumRecordsField = "num_records"
+  private val OffsetDeltaField = "offset_delta"
   private val TimestampField = "timestamp"
   private val PidEntriesField = "pid_entries"
 
@@ -124,15 +120,14 @@ object ProducerIdMapping {
     new Field(EpochField, Type.INT16, "Current epoch of the producer"),
     new Field(LastSequenceField, Type.INT32, "Last written sequence of the producer"),
     new Field(LastOffsetField, Type.INT64, "Last written offset of the producer"),
-    new Field(NumRecordsField, Type.INT32, "The number of records written in the last log entry"),
+    new Field(OffsetDeltaField, Type.INT32, "The difference of the last sequence and first sequence in the last written batch"),
     new Field(TimestampField, Type.INT64, "Max timestamp from the last written entry"))
   val PidSnapshotMapSchema = new Schema(
     new Field(VersionField, Type.INT16, "Version of the snapshot file"),
     new Field(CrcField, Type.UNSIGNED_INT32, "CRC of the snapshot data"),
     new Field(PidEntriesField, new ArrayOf(PidSnapshotEntrySchema), "The entries in the PID table"))
 
-  private def loadSnapshot(file: File, pidMap: mutable.Map[Long, ProducerIdEntry],
-                           checkNotExpired: (ProducerIdEntry) => Boolean) {
+  private def readSnapshot(file: File): Iterable[(Long, ProducerIdEntry)] = {
     val buffer = Files.readAllBytes(file.toPath)
     val struct = PidSnapshotMapSchema.read(ByteBuffer.wrap(buffer))
 
@@ -145,17 +140,16 @@ object ProducerIdMapping {
     if (crc != computedCrc)
       throw new CorruptSnapshotException(s"Snapshot file is corrupted (CRC is no longer valid). Stored crc: ${crc}. Computed crc: ${computedCrc}")
 
-    struct.getArray(PidEntriesField).foreach { pidEntryObj =>
+    struct.getArray(PidEntriesField).map { pidEntryObj =>
       val pidEntryStruct = pidEntryObj.asInstanceOf[Struct]
-      val pid = pidEntryStruct.getLong(PidField)
+      val pid: Long = pidEntryStruct.getLong(PidField)
       val epoch = pidEntryStruct.getShort(EpochField)
       val seq = pidEntryStruct.getInt(LastSequenceField)
       val offset = pidEntryStruct.getLong(LastOffsetField)
       val timestamp = pidEntryStruct.getLong(TimestampField)
-      val numRecords = pidEntryStruct.getInt(NumRecordsField)
-      val newEntry = ProducerIdEntry(epoch, seq, offset, numRecords, timestamp)
-      if (checkNotExpired(newEntry))
-        pidMap.put(pid, newEntry)
+      val offsetDelta = pidEntryStruct.getInt(OffsetDeltaField)
+      val newEntry = ProducerIdEntry(epoch, seq, offset, offsetDelta, timestamp)
+      pid -> newEntry
     }
   }
 
@@ -170,7 +164,7 @@ object ProducerIdMapping {
           .set(EpochField, entry.epoch)
           .set(LastSequenceField, entry.lastSeq)
           .set(LastOffsetField, entry.lastOffset)
-          .set(NumRecordsField, entry.numRecords)
+          .set(OffsetDeltaField, entry.offsetDelta)
           .set(TimestampField, entry.timestamp)
         pidEntryStruct
     }.toArray
@@ -192,16 +186,7 @@ object ProducerIdMapping {
     }
   }
 
-  private def verifyFileName(name: String): Boolean = FilenamePattern.findFirstIn(name).isDefined
-
-  private def offsetFromFile(file: File): Long = {
-    s"${file.getName.replace(s".$FilenameSuffix", "")}".toLong
-  }
-
-  private def formatFileName(lastOffset: Long): String = {
-    // The files will be named '$lastOffset.snapshot' and located in 'logDir/pid-mapping'
-    s"$lastOffset.$FilenameSuffix"
-  }
+  private def isSnapshotFile(name: String): Boolean = name.endsWith(Log.PidSnapshotFileSuffix)
 
 }
 
@@ -212,17 +197,22 @@ object ProducerIdMapping {
  * The sequence number is the last number successfully appended to the partition for the given identifier.
  * The epoch is used for fencing against zombie writers. The offset is the one of the last successful message
  * appended to the partition.
+ *
+ * As long as a PID is contained in the map, the corresponding producer can continue to write data.
+ * However, PIDs can be expired due to lack of recent use or if the last written entry has been deleted from
+ * the log (e.g. if the retention policy is "delete"). For compacted topics, the log cleaner will ensure
+ * that the most recent entry from a given PID is retained in the log provided it hasn't expired due to
+ * age. This ensures that PIDs will not be expired until either the max expiration time has been reached,
+ * or, if the topic is also configured for deletion, the segment containing the last written offset has
+ * been deleted.
  */
 @nonthreadsafe
 class ProducerIdMapping(val config: LogConfig,
                         val topicPartition: TopicPartition,
-                        val snapParentDir: File,
+                        val logDir: File,
                         val maxPidExpirationMs: Int) extends Logging {
   import ProducerIdMapping._
 
-  val snapDir: File = new File(snapParentDir, DirnamePrefix)
-  Files.createDirectories(snapDir.toPath)
-
   private val pidMap = mutable.Map[Long, ProducerIdEntry]()
   private var lastMapOffset = 0L
   private var lastSnapOffset = 0L
@@ -237,53 +227,59 @@ class ProducerIdMapping(val config: LogConfig,
    */
   def activePids: immutable.Map[Long, ProducerIdEntry] = pidMap.toMap
 
-  /**
-   * Load a snapshot of the id mapping or return empty maps
-   * in the case the snapshot doesn't exist (first time).
-   */
-  private def loadFromSnapshot(logEndOffset: Long, checkNotExpired:(ProducerIdEntry) => Boolean) {
+  private def loadFromSnapshot(logStartOffset: Long, currentTime: Long) {
     pidMap.clear()
 
-    var loaded = false
-    while (!loaded) {
-      lastSnapshotFile(logEndOffset) match {
+    while (true) {
+      latestSnapshotFile match {
         case Some(file) =>
           try {
-            loadSnapshot(file, pidMap, checkNotExpired)
-            lastSnapOffset = offsetFromFile(file)
+            info(s"Loading PID mapping from snapshot file ${file.getName} for partition $topicPartition")
+            readSnapshot(file).foreach { case (pid, entry) =>
+              if (!isExpired(currentTime, entry))
+                pidMap.put(pid, entry)
+            }
+
+            lastSnapOffset = Log.offsetFromFilename(file.getName)
             lastMapOffset = lastSnapOffset
-            loaded = true
+            return
           } catch {
             case e: CorruptSnapshotException =>
-              error(s"Snapshot file at $file is corrupt: ${e.getMessage}")
-              try Files.delete(file.toPath)
-              catch {
-                case e: IOException => error(s"Failed to delete corrupt snapshot file $file", e)
-              }
+              error(s"Snapshot file at ${file.getPath} is corrupt: ${e.getMessage}")
+              Files.deleteIfExists(file.toPath)
           }
         case None =>
-          lastSnapOffset = 0L
-          lastMapOffset = 0L
-          Files.createDirectories(snapDir.toPath)
-          loaded = true
+          lastSnapOffset = logStartOffset
+          lastMapOffset = logStartOffset
+          return
       }
     }
   }
 
-  def isEntryValid(currentTimeMs: Long, producerIdEntry: ProducerIdEntry) : Boolean = {
-    currentTimeMs - producerIdEntry.timestamp < maxPidExpirationMs
-  }
+  private def isExpired(currentTimeMs: Long, producerIdEntry: ProducerIdEntry) : Boolean =
+    currentTimeMs - producerIdEntry.timestamp >= maxPidExpirationMs
+
 
-  def checkForExpiredPids(currentTimeMs: Long) {
+  def removeExpiredPids(currentTimeMs: Long) {
     pidMap.retain { case (pid, lastEntry) =>
-      isEntryValid(currentTimeMs, lastEntry)
+      !isExpired(currentTimeMs, lastEntry)
     }
   }
 
-  def truncateAndReload(logEndOffset: Long, currentTime: Long) {
-    truncateSnapshotFiles(logEndOffset)
-    def checkNotExpired = (producerIdEntry: ProducerIdEntry) => { isEntryValid(currentTime, producerIdEntry) }
-    loadFromSnapshot(logEndOffset, checkNotExpired)
+  /**
+   * Truncate the PID mapping to the given offset range and reload the entries from the most recent
+   * snapshot in range (if there is one).
+   */
+  def truncateAndReload(logStartOffset: Long, logEndOffset: Long, currentTimeMs: Long) {
+    if (logEndOffset != mapEndOffset) {
+      deleteSnapshotFiles { file =>
+        val offset = Log.offsetFromFilename(file.getName)
+        offset > logEndOffset || offset <= logStartOffset
+      }
+      loadFromSnapshot(logStartOffset, currentTimeMs)
+    } else {
+      expirePids(logStartOffset)
+    }
   }
 
   /**
@@ -294,7 +290,10 @@ class ProducerIdMapping(val config: LogConfig,
       throw new IllegalArgumentException("Invalid PID passed to update")
     val entry = appendInfo.lastEntry
     pidMap.put(appendInfo.pid, entry)
-    lastMapOffset = entry.lastOffset + 1
+  }
+
+  def updateMapEndOffset(lastOffset: Long): Unit = {
+    lastMapOffset = lastOffset
   }
 
   /**
@@ -302,10 +301,8 @@ class ProducerIdMapping(val config: LogConfig,
    * than the current time minus the PID expiration time (i.e. if the PID has expired).
    */
   def load(pid: Long, entry: ProducerIdEntry, currentTimeMs: Long) {
-    if (pid != RecordBatch.NO_PRODUCER_ID && currentTimeMs - entry.timestamp < maxPidExpirationMs) {
+    if (pid != RecordBatch.NO_PRODUCER_ID && !isExpired(currentTimeMs, entry))
       pidMap.put(pid, entry)
-      lastMapOffset = entry.lastOffset + 1
-    }
   }
 
   /**
@@ -314,84 +311,74 @@ class ProducerIdMapping(val config: LogConfig,
   def lastEntry(pid: Long): Option[ProducerIdEntry] = pidMap.get(pid)
 
   /**
-    * Serialize and write the bytes to a file. The file name is a concatenation of:
-    *   - offset
-    *   - a ".snapshot" suffix
-    *
-    *  The snapshot files are located in the logDirectory, inside a 'pid-mapping' sub directory.
-    */
+   * Write a new snapshot if there have been updates since the last one.
+   */
   def maybeTakeSnapshot() {
     // If not a new offset, then it is not worth taking another snapshot
     if (lastMapOffset > lastSnapOffset) {
-      val file = new File(snapDir, formatFileName(lastMapOffset))
-      writeSnapshot(file, pidMap)
+      val snapshotFile = Log.pidSnapshotFilename(logDir, lastMapOffset)
+      debug(s"Writing producer snapshot for partition $topicPartition at offset $lastMapOffset")
+      writeSnapshot(snapshotFile, pidMap)
 
       // Update the last snap offset according to the serialized map
       lastSnapOffset = lastMapOffset
 
-      maybeRemove()
+      maybeRemoveOldestSnapshot()
     }
   }
 
   /**
-    * When we remove the head of the log due to retention, we need to
-    * clean up the id map. This method takes the new start offset and
-    * expires all ids that have a smaller offset.
-    *
-    * @param startOffset New start offset for the log associated to
-    *                    this id map instance
-    */
-  def cleanFrom(startOffset: Long) {
-    pidMap.retain((pid, entry) => entry.firstOffset >= startOffset)
-    if (pidMap.isEmpty)
-      lastMapOffset = -1L
+   * Get the last offset (exclusive) of the latest snapshot file.
+   */
+  def latestSnapshotOffset: Option[Long] = latestSnapshotFile.map(file => Log.offsetFromFilename(file.getName))
+
+  /**
+   * When we remove the head of the log due to retention, we need to clean up the PID map. This method takes the
+   * new start offset, expires all PIDs with a smaller last written offset, and deletes snapshots at or below it.
+   */
+  def expirePids(logStartOffset: Long) {
+    pidMap.retain((pid, entry) => entry.lastOffset >= logStartOffset)
+    deleteSnapshotFiles(file => Log.offsetFromFilename(file.getName) <= logStartOffset)
+    if (lastMapOffset < logStartOffset)
+      lastMapOffset = logStartOffset
+    lastSnapOffset = latestSnapshotOffset.getOrElse(logStartOffset)
   }
 
-  private def maybeRemove() {
-    val list = listSnapshotFiles()
+  /**
+   * Truncate the PID mapping and remove all snapshots. This resets the state of the mapping.
+   */
+  def truncate() {
+    pidMap.clear()
+    deleteSnapshotFiles()
+    lastSnapOffset = 0L
+    lastMapOffset = 0L
+  }
+
+  private def maybeRemoveOldestSnapshot() {
+    val list = listSnapshotFiles
     if (list.size > maxPidSnapshotsToRetain) {
-      // Get file with the smallest offset
-      val toDelete = list.minBy(offsetFromFile)
-      // Delete the last
+      val toDelete = list.minBy(file => Log.offsetFromFilename(file.getName))
       Files.deleteIfExists(toDelete.toPath)
     }
   }
 
-  private def listSnapshotFiles(): List[File] = {
-    if (snapDir.exists && snapDir.isDirectory)
-      snapDir.listFiles.filter(f => f.isFile && verifyFileName(f.getName)).toList
+  private def listSnapshotFiles: List[File] = {
+    if (logDir.exists && logDir.isDirectory)
+      logDir.listFiles.filter(f => f.isFile && isSnapshotFile(f.getName)).toList
     else
       List.empty[File]
   }
 
-  /**
-   * Returns the last valid snapshot with offset smaller than the base offset provided as
-   * a constructor parameter for loading.
-   */
-  private def lastSnapshotFile(maxOffset: Long): Option[File] = {
-    val files = listSnapshotFiles()
-    if (files != null && files.nonEmpty) {
-      val targetOffset = files.foldLeft(0L) { (accOffset, file) =>
-        val snapshotLastOffset = offsetFromFile(file)
-        if ((maxOffset >= snapshotLastOffset) && (snapshotLastOffset > accOffset))
-          snapshotLastOffset
-        else
-          accOffset
-      }
-      val snap = new File(snapDir, formatFileName(targetOffset))
-      if (snap.exists)
-        Some(snap)
-      else
-        None
-    } else
+  private def latestSnapshotFile: Option[File] = {
+    val files = listSnapshotFiles
+    if (files.nonEmpty)
+      Some(files.maxBy(file => Log.offsetFromFilename(file.getName)))
+    else
       None
   }
 
-  private def truncateSnapshotFiles(maxOffset: Long) {
-    listSnapshotFiles().foreach { file =>
-      val snapshotLastOffset = offsetFromFile(file)
-      if (snapshotLastOffset >= maxOffset)
-        file.delete()
-    }
+  private def deleteSnapshotFiles(predicate: File => Boolean = _ => true) {
+    listSnapshotFiles.filter(predicate).foreach(file => Files.deleteIfExists(file.toPath))
   }
+
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/588ed464/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
index b0eb29a..2f9396f 100755
--- a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
@@ -21,7 +21,7 @@ import java.io._
 import java.util.Properties
 
 import kafka.common._
-import kafka.server.checkpoints.{OffsetCheckpoint, OffsetCheckpointFile}
+import kafka.server.checkpoints.OffsetCheckpointFile
 import kafka.utils._
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.OffsetOutOfRangeException
@@ -103,8 +103,8 @@ class LogManagerTest {
     assertEquals("Now there should only be only one segment in the index.", 1, log.numberOfSegments)
     time.sleep(log.config.fileDeleteDelayMs + 1)
 
-    //There should be a log file, two indexes, the leader epoch checkpoint and the pid snapshot dir
-    assertEquals("Files should have been deleted", log.numberOfSegments * 3 + 2, log.dir.list.length)
+    // there should be a log file, two indexes, and the leader epoch checkpoint
+    assertEquals("Files should have been deleted", log.numberOfSegments * 3 + 1, log.dir.list.length)
     assertEquals("Should get empty fetch off new log.", 0, log.read(offset+1, 1024).records.sizeInBytes)
 
     try {

http://git-wip-us.apache.org/repos/asf/kafka/blob/588ed464/core/src/test/scala/unit/kafka/log/LogTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/LogTest.scala b/core/src/test/scala/unit/kafka/log/LogTest.scala
index d42abd4..0b1d299 100755
--- a/core/src/test/scala/unit/kafka/log/LogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTest.scala
@@ -27,11 +27,10 @@ import org.junit.Assert._
 import org.junit.{After, Before, Test}
 import kafka.utils._
 import kafka.server.KafkaConfig
-import kafka.server.epoch.{EpochEntry, LeaderEpochCache, LeaderEpochFileCache}
-import org.apache.kafka.common.record.{RecordBatch, _}
+import kafka.server.epoch.{EpochEntry, LeaderEpochFileCache}
+import org.apache.kafka.common.record.MemoryRecords.RecordFilter
+import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.Utils
-import org.easymock.EasyMock
-import org.easymock.EasyMock._
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
@@ -62,6 +61,23 @@ class LogTest {
     }
   }
 
+  @Test
+  def testOffsetFromFilename() {
+    val offset = 23423423L
+
+    val logFile = Log.logFilename(tmpDir, offset)
+    assertEquals(offset, Log.offsetFromFilename(logFile.getName))
+
+    val offsetIndexFile = Log.indexFilename(tmpDir, offset)
+    assertEquals(offset, Log.offsetFromFilename(offsetIndexFile.getName))
+
+    val timeIndexFile = Log.timeIndexFilename(tmpDir, offset)
+    assertEquals(offset, Log.offsetFromFilename(timeIndexFile.getName))
+
+    val snapshotFile = Log.pidSnapshotFilename(tmpDir, offset)
+    assertEquals(offset, Log.offsetFromFilename(snapshotFile.getName))
+  }
+
   /**
    * Tests for time based log roll. This test appends messages then changes the time
    * using the mock clock to force the log to roll and checks the number of segments.
@@ -145,6 +161,190 @@ class LogTest {
   }
 
   @Test
+  def testPidMapOffsetUpdatedForNonIdempotentData() {
+    val log = createLog(2048)
+    val records = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)))
+    log.appendAsLeader(records, leaderEpoch = 0)
+    log.maybeTakePidSnapshot()
+    assertEquals(Some(1), log.latestPidSnapshotOffset)
+  }
+
+  @Test
+  def testRebuildPidMapWithCompactedData() {
+    val log = createLog(2048, pidSnapshotIntervalMs = Int.MaxValue)
+    val pid = 1L
+    val epoch = 0.toShort
+    val seq = 0
+    val baseOffset = 23L
+
+    // create a batch with a couple of gaps to simulate compaction
+    val records = TestUtils.records(pid = pid, epoch = epoch, sequence = seq, baseOffset = baseOffset, records = List(
+      new SimpleRecord(System.currentTimeMillis(), "a".getBytes),
+      new SimpleRecord(System.currentTimeMillis(), "key".getBytes, "b".getBytes),
+      new SimpleRecord(System.currentTimeMillis(), "c".getBytes),
+      new SimpleRecord(System.currentTimeMillis(), "key".getBytes, "d".getBytes)))
+    records.batches.asScala.foreach(_.setPartitionLeaderEpoch(0))
+
+    val filtered = ByteBuffer.allocate(2048)
+    records.filterTo(new RecordFilter {
+      override def shouldRetain(recordBatch: RecordBatch, record: Record): Boolean = !record.hasKey
+    }, filtered)
+    filtered.flip()
+    val filteredRecords = MemoryRecords.readableRecords(filtered)
+
+    log.appendAsFollower(filteredRecords)
+
+    // append some more data and then truncate to force rebuilding of the PID map
+    val moreRecords = TestUtils.records(baseOffset = baseOffset + 4, records = List(
+      new SimpleRecord(System.currentTimeMillis(), "e".getBytes),
+      new SimpleRecord(System.currentTimeMillis(), "f".getBytes)))
+    moreRecords.batches.asScala.foreach(_.setPartitionLeaderEpoch(0))
+    log.appendAsFollower(moreRecords)
+
+    log.truncateTo(baseOffset + 4)
+
+    val activePids = log.activePids
+    assertTrue(activePids.contains(pid))
+
+    val entry = activePids(pid)
+    assertEquals(0, entry.firstSeq)
+    assertEquals(baseOffset, entry.firstOffset)
+    assertEquals(3, entry.lastSeq)
+    assertEquals(baseOffset + 3, entry.lastOffset)
+  }
+
+  @Test
+  def testUpdatePidMapWithCompactedData() {
+    val log = createLog(2048)
+    val pid = 1L
+    val epoch = 0.toShort
+    val seq = 0
+    val baseOffset = 23L
+
+    // create a batch with a couple of gaps to simulate compaction
+    val records = TestUtils.records(pid = pid, epoch = epoch, sequence = seq, baseOffset = baseOffset, records = List(
+      new SimpleRecord(System.currentTimeMillis(), "a".getBytes),
+      new SimpleRecord(System.currentTimeMillis(), "key".getBytes, "b".getBytes),
+      new SimpleRecord(System.currentTimeMillis(), "c".getBytes),
+      new SimpleRecord(System.currentTimeMillis(), "key".getBytes, "d".getBytes)))
+    records.batches.asScala.foreach(_.setPartitionLeaderEpoch(0))
+
+    val filtered = ByteBuffer.allocate(2048)
+    records.filterTo(new RecordFilter {
+      override def shouldRetain(recordBatch: RecordBatch, record: Record): Boolean = !record.hasKey
+    }, filtered)
+    filtered.flip()
+    val filteredRecords = MemoryRecords.readableRecords(filtered)
+
+    log.appendAsFollower(filteredRecords)
+    val activePids = log.activePids
+    assertTrue(activePids.contains(pid))
+
+    val entry = activePids(pid)
+    assertEquals(0, entry.firstSeq)
+    assertEquals(baseOffset, entry.firstOffset)
+    assertEquals(3, entry.lastSeq)
+    assertEquals(baseOffset + 3, entry.lastOffset)
+  }
+
+  @Test
+  def testPidMapTruncateTo() {
+    val log = createLog(2048)
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes))), leaderEpoch = 0)
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes))), leaderEpoch = 0)
+    log.maybeTakePidSnapshot()
+
+    log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes))), leaderEpoch = 0)
+    log.maybeTakePidSnapshot()
+
+    log.truncateTo(2)
+    assertEquals(Some(2), log.latestPidSnapshotOffset)
+    assertEquals(2, log.latestPidMapOffset)
+
+    log.truncateTo(1)
+    assertEquals(None, log.latestPidSnapshotOffset)
+    assertEquals(1, log.latestPidMapOffset)
+  }
+
+  @Test
+  def testPidMapTruncateFullyAndStartAt() {
+    val records = TestUtils.singletonRecords("foo".getBytes)
+    val log = createLog(records.sizeInBytes, messagesPerSegment = 1, retentionBytes = records.sizeInBytes * 2)
+    log.appendAsLeader(records, leaderEpoch = 0)
+    log.maybeTakePidSnapshot()
+
+    log.appendAsLeader(TestUtils.singletonRecords("bar".getBytes), leaderEpoch = 0)
+    log.appendAsLeader(TestUtils.singletonRecords("baz".getBytes), leaderEpoch = 0)
+    log.maybeTakePidSnapshot()
+
+    assertEquals(3, log.logSegments.size)
+    assertEquals(3, log.latestPidMapOffset)
+    assertEquals(Some(3), log.latestPidSnapshotOffset)
+
+    log.truncateFullyAndStartAt(29)
+    assertEquals(1, log.logSegments.size)
+    assertEquals(None, log.latestPidSnapshotOffset)
+    assertEquals(29, log.latestPidMapOffset)
+  }
+
+  @Test
+  def testPidExpirationOnSegmentDeletion() {
+    val pid1 = 1L
+    val records = TestUtils.records(Seq(new SimpleRecord("foo".getBytes)), pid = pid1, epoch = 0, sequence = 0)
+    val log = createLog(records.sizeInBytes, messagesPerSegment = 1, retentionBytes = records.sizeInBytes * 2)
+    log.appendAsLeader(records, leaderEpoch = 0)
+    log.maybeTakePidSnapshot()
+
+    val pid2 = 2L
+    log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord("bar".getBytes)), pid = pid2, epoch = 0, sequence = 0),
+      leaderEpoch = 0)
+    log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord("baz".getBytes)), pid = pid2, epoch = 0, sequence = 1),
+      leaderEpoch = 0)
+    log.maybeTakePidSnapshot()
+
+    assertEquals(3, log.logSegments.size)
+    assertEquals(Set(pid1, pid2), log.activePids.keySet)
+
+    log.deleteOldSegments()
+
+    assertEquals(2, log.logSegments.size)
+    assertEquals(Set(pid2), log.activePids.keySet)
+  }
+
+  @Test
+  def testPeriodicPidSnapshot() {
+    val snapshotInterval = 100
+    val log = createLog(2048, pidSnapshotIntervalMs = snapshotInterval)
+
+    log.appendAsLeader(TestUtils.singletonRecords("foo".getBytes), leaderEpoch = 0)
+    log.appendAsLeader(TestUtils.singletonRecords("bar".getBytes), leaderEpoch = 0)
+    assertEquals(None, log.latestPidSnapshotOffset)
+
+    time.sleep(snapshotInterval)
+    assertEquals(Some(2), log.latestPidSnapshotOffset)
+  }
+
+  @Test
+  def testPeriodicPidExpiration() {
+    val maxPidExpirationMs = 200
+    val expirationCheckInterval = 100
+
+    val pid = 23L
+    val log = createLog(2048, maxPidExpirationMs = maxPidExpirationMs,
+      pidExpirationCheckIntervalMs = expirationCheckInterval)
+    val records = Seq(new SimpleRecord(time.milliseconds(), "foo".getBytes))
+    log.appendAsLeader(TestUtils.records(records, pid = pid, epoch = 0, sequence = 0), leaderEpoch = 0)
+
+    assertEquals(Set(pid), log.activePids.keySet)
+
+    time.sleep(expirationCheckInterval)
+    assertEquals(Set(pid), log.activePids.keySet)
+
+    time.sleep(expirationCheckInterval)
+    assertEquals(Set(), log.activePids.keySet)
+  }
+
+  @Test
   def testDuplicateAppends(): Unit = {
     val logProps = new Properties()
 
@@ -1772,10 +1972,12 @@ class LogTest {
   }
 
 
-  def createLog(messageSizeInBytes: Int, retentionMs: Int = -1,
-                retentionBytes: Int = -1, cleanupPolicy: String = "delete"): Log = {
+  def createLog(messageSizeInBytes: Int, retentionMs: Int = -1, retentionBytes: Int = -1,
+                cleanupPolicy: String = "delete", messagesPerSegment: Int = 5,
+                maxPidExpirationMs: Int = 300000, pidExpirationCheckIntervalMs: Int = 30000,
+                pidSnapshotIntervalMs: Int = 60000): Log = {
     val logProps = new Properties()
-    logProps.put(LogConfig.SegmentBytesProp, messageSizeInBytes * 5: Integer)
+    logProps.put(LogConfig.SegmentBytesProp, messageSizeInBytes * messagesPerSegment: Integer)
     logProps.put(LogConfig.RetentionMsProp, retentionMs: Integer)
     logProps.put(LogConfig.RetentionBytesProp, retentionBytes: Integer)
     logProps.put(LogConfig.CleanupPolicyProp, cleanupPolicy)
@@ -1786,7 +1988,10 @@ class LogTest {
       logStartOffset = 0L,
       recoveryPoint = 0L,
       scheduler = time.scheduler,
-      time = time)
+      time = time,
+      maxPidExpirationMs = maxPidExpirationMs,
+      pidExpirationCheckIntervalMs = pidExpirationCheckIntervalMs,
+      pidSnapshotIntervalMs = pidSnapshotIntervalMs)
     log
   }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/588ed464/core/src/test/scala/unit/kafka/log/ProducerIdMappingTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/ProducerIdMappingTest.scala b/core/src/test/scala/unit/kafka/log/ProducerIdMappingTest.scala
index 3b78921..1bf983c 100644
--- a/core/src/test/scala/unit/kafka/log/ProducerIdMappingTest.scala
+++ b/core/src/test/scala/unit/kafka/log/ProducerIdMappingTest.scala
@@ -39,14 +39,8 @@ class ProducerIdMappingTest extends JUnitSuite {
 
   @Before
   def setUp(): Unit = {
-    // Create configuration including number of snapshots to hold
-    val props = new Properties()
-    config = LogConfig(props)
-
-    // Create temporary directory
+    config = LogConfig(new Properties)
     idMappingDir = TestUtils.tempDir()
-
-    // Instantiate IdMapping
     idMapping = new ProducerIdMapping(config, partition, idMappingDir, maxPidExpirationMs)
   }
 
@@ -105,7 +99,7 @@ class ProducerIdMappingTest extends JUnitSuite {
     checkAndUpdate(idMapping, pid, 1, epoch, 1L, time.milliseconds)
     idMapping.maybeTakeSnapshot()
     val recoveredMapping = new ProducerIdMapping(config, partition, idMappingDir, maxPidExpirationMs)
-    recoveredMapping.truncateAndReload(3L, time.milliseconds)
+    recoveredMapping.truncateAndReload(0L, 3L, time.milliseconds)
 
     // entry added after recovery
     checkAndUpdate(recoveredMapping, pid, 2, epoch, 2L, time.milliseconds)
@@ -119,28 +113,97 @@ class ProducerIdMappingTest extends JUnitSuite {
 
     idMapping.maybeTakeSnapshot()
     val recoveredMapping = new ProducerIdMapping(config, partition, idMappingDir, maxPidExpirationMs)
-    recoveredMapping.truncateAndReload(1L, 70000)
+    recoveredMapping.truncateAndReload(0L, 1L, 70000)
 
     // entry added after recovery. The pid should be expired now, and would not exist in the pid mapping. Hence
     // we should get an out of order sequence exception.
     checkAndUpdate(recoveredMapping, pid, 2, epoch, 2L, 70001)
   }
 
-
   @Test
   def testRemoveOldSnapshot(): Unit = {
     val epoch = 0.toShort
-    checkAndUpdate(idMapping, pid, 0, epoch, 0L, 1L)
-    checkAndUpdate(idMapping, pid, 1, epoch, 1L, 2L)
 
+    checkAndUpdate(idMapping, pid, 0, epoch, 0L)
+    checkAndUpdate(idMapping, pid, 1, epoch, 1L)
+    idMapping.maybeTakeSnapshot()
+    assertEquals(1, idMappingDir.listFiles().length)
+    assertEquals(Set(2), currentSnapshotOffsets)
+
+    checkAndUpdate(idMapping, pid, 2, epoch, 2L)
+    idMapping.maybeTakeSnapshot()
+    assertEquals(2, idMappingDir.listFiles().length)
+    assertEquals(Set(2, 3), currentSnapshotOffsets)
+
+    // we only retain two snapshot files, so the next snapshot should cause the oldest to be deleted
+    checkAndUpdate(idMapping, pid, 3, epoch, 3L)
     idMapping.maybeTakeSnapshot()
+    assertEquals(2, idMappingDir.listFiles().length)
+    assertEquals(Set(3, 4), currentSnapshotOffsets)
+  }
 
-    checkAndUpdate(idMapping, pid, 2, epoch, 2L, 3L)
+  @Test
+  def testTruncate(): Unit = {
+    val epoch = 0.toShort
+
+    checkAndUpdate(idMapping, pid, 0, epoch, 0L)
+    checkAndUpdate(idMapping, pid, 1, epoch, 1L)
+    idMapping.maybeTakeSnapshot()
+    assertEquals(1, idMappingDir.listFiles().length)
+    assertEquals(Set(2), currentSnapshotOffsets)
 
+    checkAndUpdate(idMapping, pid, 2, epoch, 2L)
     idMapping.maybeTakeSnapshot()
+    assertEquals(2, idMappingDir.listFiles().length)
+    assertEquals(Set(2, 3), currentSnapshotOffsets)
 
-    assertEquals(s"number of snapshot files is incorrect: ${idMappingDir.listFiles().length}",
-               1, idMappingDir.listFiles().length)
+    idMapping.truncate()
+
+    assertEquals(0, idMappingDir.listFiles().length)
+    assertEquals(Set(), currentSnapshotOffsets)
+
+    checkAndUpdate(idMapping, pid, 0, epoch, 0L)
+    idMapping.maybeTakeSnapshot()
+    assertEquals(1, idMappingDir.listFiles().length)
+    assertEquals(Set(1), currentSnapshotOffsets)
+  }
+
+  @Test
+  def testExpirePids(): Unit = {
+    val epoch = 0.toShort
+
+    checkAndUpdate(idMapping, pid, 0, epoch, 0L)
+    checkAndUpdate(idMapping, pid, 1, epoch, 1L)
+    idMapping.maybeTakeSnapshot()
+
+    val anotherPid = 2L
+    checkAndUpdate(idMapping, anotherPid, 0, epoch, 2L)
+    checkAndUpdate(idMapping, anotherPid, 1, epoch, 3L)
+    idMapping.maybeTakeSnapshot()
+    assertEquals(Set(2, 4), currentSnapshotOffsets)
+
+    idMapping.expirePids(2)
+    assertEquals(Set(4), currentSnapshotOffsets)
+    assertEquals(Set(anotherPid), idMapping.activePids.keySet)
+    assertEquals(None, idMapping.lastEntry(pid))
+
+    val maybeEntry = idMapping.lastEntry(anotherPid)
+    assertTrue(maybeEntry.isDefined)
+    assertEquals(3L, maybeEntry.get.lastOffset)
+
+    idMapping.expirePids(3)
+    assertEquals(Set(anotherPid), idMapping.activePids.keySet)
+    assertEquals(Set(4), currentSnapshotOffsets)
+    assertEquals(4, idMapping.mapEndOffset)
+
+    idMapping.expirePids(5)
+    assertEquals(Set(), idMapping.activePids.keySet)
+    assertEquals(Set(), currentSnapshotOffsets)
+    assertEquals(5, idMapping.mapEndOffset)
+
+    idMapping.maybeTakeSnapshot()
+    // shouldn't be any new snapshot because the log is empty
+    assertEquals(Set(), currentSnapshotOffsets)
   }
 
   @Test
@@ -149,12 +212,13 @@ class ProducerIdMappingTest extends JUnitSuite {
     checkAndUpdate(idMapping, pid, 0, epoch, 0L, 0L)
 
     idMapping.maybeTakeSnapshot()
+    assertEquals(1, idMappingDir.listFiles().length)
+    assertEquals(Set(1), currentSnapshotOffsets)
 
     // nothing changed so there should be no new snapshot
     idMapping.maybeTakeSnapshot()
-
-    assertEquals(s"number of snapshot files is incorrect: ${idMappingDir.listFiles().length}",
-      1, idMappingDir.listFiles().length)
+    assertEquals(1, idMappingDir.listFiles().length)
+    assertEquals(Set(1), currentSnapshotOffsets)
   }
 
   @Test
@@ -169,7 +233,7 @@ class ProducerIdMappingTest extends JUnitSuite {
 
     intercept[OutOfOrderSequenceException] {
       val recoveredMapping = new ProducerIdMapping(config, partition, idMappingDir, maxPidExpirationMs)
-      recoveredMapping.truncateAndReload(1L, time.milliseconds)
+      recoveredMapping.truncateAndReload(0L, 1L, time.milliseconds)
       checkAndUpdate(recoveredMapping, pid2, 1, epoch, 4L, 5L)
     }
   }
@@ -180,7 +244,7 @@ class ProducerIdMappingTest extends JUnitSuite {
     val sequence = 37
     checkAndUpdate(idMapping, pid, sequence, epoch, 1L)
     time.sleep(maxPidExpirationMs + 1)
-    idMapping.checkForExpiredPids(time.milliseconds)
+    idMapping.removeExpiredPids(time.milliseconds)
     checkAndUpdate(idMapping, pid, sequence + 1, epoch, 1L)
   }
 
@@ -213,11 +277,15 @@ class ProducerIdMappingTest extends JUnitSuite {
                              epoch: Short,
                              lastOffset: Long,
                              timestamp: Long = time.milliseconds()): Unit = {
-    val numRecords = 1
-    val incomingPidEntry = ProducerIdEntry(epoch, seq, lastOffset, numRecords, timestamp)
+    val offsetDelta = 0
+    val incomingPidEntry = ProducerIdEntry(epoch, seq, lastOffset, offsetDelta, timestamp)
     val producerAppendInfo = new ProducerAppendInfo(pid, mapping.lastEntry(pid).getOrElse(ProducerIdEntry.Empty))
     producerAppendInfo.append(incomingPidEntry)
     mapping.update(producerAppendInfo)
+    mapping.updateMapEndOffset(lastOffset + 1)
   }
 
+  private def currentSnapshotOffsets =
+    idMappingDir.listFiles().map(file => Log.offsetFromFilename(file.getName)).toSet
+
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/588ed464/core/src/test/scala/unit/kafka/utils/TestUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index eb681ec..02b5fe3 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -332,9 +332,10 @@ object TestUtils extends Logging {
               codec: CompressionType = CompressionType.NONE,
               pid: Long = RecordBatch.NO_PRODUCER_ID,
               epoch: Short = RecordBatch.NO_PRODUCER_EPOCH,
-              sequence: Int = RecordBatch.NO_SEQUENCE): MemoryRecords = {
+              sequence: Int = RecordBatch.NO_SEQUENCE,
+              baseOffset: Long = 0L): MemoryRecords = {
     val buf = ByteBuffer.allocate(DefaultRecordBatch.sizeInBytes(records.asJava))
-    val builder = MemoryRecords.builder(buf, magicValue, codec, TimestampType.CREATE_TIME, 0L,
+    val builder = MemoryRecords.builder(buf, magicValue, codec, TimestampType.CREATE_TIME, baseOffset,
       System.currentTimeMillis, pid, epoch, sequence)
     records.foreach(builder.append)
     builder.build()