You are viewing a plain text version of this content. The canonical link for it is here.
Posted to by on 2017/05/06 18:51:11 UTC

[2/6] kafka git commit: KAFKA-5121; Implement transaction index for KIP-98
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerLagIntegrationTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerLagIntegrationTest.scala
index eb3f50c..bf634d7 100644
--- a/core/src/test/scala/unit/kafka/log/LogCleanerLagIntegrationTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerLagIntegrationTest.scala
@@ -109,7 +109,8 @@ class LogCleanerLagIntegrationTest(compressionCodecName: String) extends Abstrac
   private def writeDups(numKeys: Int, numDups: Int, log: Log, codec: CompressionType, timestamp: Long): Seq[(Int, Int)] = {
     for (_ <- 0 until numDups; key <- 0 until numKeys) yield {
       val count = counter
-      log.appendAsLeader(TestUtils.singletonRecords(value = counter.toString.getBytes, codec = codec, key = key.toString.getBytes, timestamp = timestamp), leaderEpoch = 0)
+      log.appendAsLeader(TestUtils.singletonRecords(value = counter.toString.getBytes, codec = codec,
+              key = key.toString.getBytes, timestamp = timestamp), leaderEpoch = 0)
       counter += 1
       (key, count)
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
index 44d47c9..fe07fdd 100755
--- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
@@ -821,7 +821,7 @@ class LogCleanerTest extends JUnitSuite {
   def record(key: Int, value: Int, pid: Long = RecordBatch.NO_PRODUCER_ID, epoch: Short = RecordBatch.NO_PRODUCER_EPOCH,
              sequence: Int = RecordBatch.NO_SEQUENCE,
              partitionLeaderEpoch: Int = RecordBatch.NO_PARTITION_LEADER_EPOCH): MemoryRecords = {
-    MemoryRecords.withRecords(RecordBatch.CURRENT_MAGIC_VALUE, 0L, CompressionType.NONE, pid, epoch, sequence,
+    MemoryRecords.withIdempotentRecords(RecordBatch.CURRENT_MAGIC_VALUE, 0L, CompressionType.NONE, pid, epoch, sequence,
       partitionLeaderEpoch, new SimpleRecord(key.toString.getBytes, value.toString.getBytes))
diff --git a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
index 2f9396f..a6fe2e4 100755
--- a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
@@ -151,8 +151,9 @@ class LogManagerTest {
     assertEquals("Now there should be exactly 6 segments", 6, log.numberOfSegments)
     time.sleep(log.config.fileDeleteDelayMs + 1)
-    //There should be a log file, two indexes, the leader epoch checkpoint and the pid snapshot dir
-    assertEquals("Files should have been deleted", log.numberOfSegments * 3 + 2, log.dir.list.length)
+    // there should be a log file, two indexes (the txn index is created lazily),
+    // the leader epoch checkpoint and two pid mapping files (one each for the active and previous segments)
+    assertEquals("Files should have been deleted", log.numberOfSegments * 3 + 3, log.dir.list.length)
     assertEquals("Should get empty fetch off new log.", 0, + 1, 1024).records.sizeInBytes)
     try {, 1024)
diff --git a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
index 3f531d9..4709b77 100644
--- a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
@@ -16,31 +16,40 @@
  package kafka.log
 import kafka.utils.TestUtils
 import kafka.utils.TestUtils.checkEquals
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.record.MemoryRecords.withEndTransactionMarker
 import org.apache.kafka.common.record.{RecordBatch, _}
-import org.apache.kafka.common.utils.Time
+import org.apache.kafka.common.utils.{Time, Utils}
 import org.junit.Assert._
-import org.junit.{After, Test}
+import org.junit.{After, Before, Test}
 import scala.collection.JavaConverters._
 import scala.collection._
 class LogSegmentTest {
+  val topicPartition = new TopicPartition("topic", 0)
   val segments = mutable.ArrayBuffer[LogSegment]()
+  var logDir: File = _
   /* create a segment with the given base offset */
   def createSegment(offset: Long, indexIntervalBytes: Int = 10): LogSegment = {
     val msFile = TestUtils.tempFile()
     val ms =
     val idxFile = TestUtils.tempFile()
     val timeIdxFile = TestUtils.tempFile()
+    val txnIdxFile = TestUtils.tempFile()
+    txnIdxFile.delete()
     val idx = new OffsetIndex(idxFile, offset, 1000)
     val timeIdx = new TimeIndex(timeIdxFile, offset, 1500)
-    val seg = new LogSegment(ms, idx, timeIdx, offset, indexIntervalBytes, 0, Time.SYSTEM)
+    val txnIndex = new TransactionIndex(offset, txnIdxFile)
+    val seg = new LogSegment(ms, idx, timeIdx, txnIndex, offset, indexIntervalBytes, 0, Time.SYSTEM)
     segments += seg
@@ -51,12 +60,20 @@ class LogSegmentTest { { s => new SimpleRecord(offset * 10, s.getBytes) }: _*)
+  /** Points `logDir` at a fresh temporary directory before each test. */
+  @Before
+  def setup(): Unit =
+    logDir = TestUtils.tempDir()
   def teardown() {
     for(seg <- segments) {
+      seg.timeIndex.delete()
+      seg.txnIndex.delete()
+    Utils.delete(logDir)
@@ -153,7 +170,7 @@ class LogSegmentTest {
-  def testReloadLargestTimestampAfterTruncation() {
+  def testReloadLargestTimestampAndNextOffsetAfterTruncation() {
     val numMessages = 30
     val seg = createSegment(40, 2 * records(0, "hello").sizeInBytes - 1)
     var offset = 40
@@ -161,13 +178,15 @@ class LogSegmentTest {
       seg.append(offset, offset, offset, offset, records(offset, "hello"))
       offset += 1
+    assertEquals(offset, seg.nextOffset)
     val expectedNumEntries = numMessages / 2 - 1
     assertEquals(s"Should have $expectedNumEntries time indexes", expectedNumEntries, seg.timeIndex.entries)
     assertEquals(s"Should have 0 time indexes", 0, seg.timeIndex.entries)
     assertEquals(s"Largest timestamp should be 400", 400L, seg.largestTimestamp)
+    assertEquals(41, seg.nextOffset)
@@ -217,7 +236,7 @@ class LogSegmentTest {
     val seg = createSegment(40)
     assertEquals(40, seg.nextOffset)
     seg.append(50, 52, RecordBatch.NO_TIMESTAMP, -1L, records(50, "hello", "there", "you"))
-    assertEquals(53, seg.nextOffset())
+    assertEquals(53, seg.nextOffset)
@@ -246,11 +265,76 @@ class LogSegmentTest {
       seg.append(i, i, RecordBatch.NO_TIMESTAMP, -1L, records(i, i.toString))
     val indexFile = seg.index.file
     TestUtils.writeNonsenseToFile(indexFile, 5, indexFile.length.toInt)
-    seg.recover(64*1024)
+    seg.recover(64*1024,   new ProducerStateManager(topicPartition, logDir))
     for(i <- 0 until 100)
       assertEquals(i,, Some(i + 1), 1024)
+  /**
+   * Appends interleaved transactional batches from two producers (pid2 aborts, pid1 commits),
+   * recovers the segment, and checks that the transaction index holds exactly one aborted
+   * transaction. The recovery is then repeated with producer state indicating that pid2's
+   * transaction started at offset 75, and the indexed first offset is expected to follow.
+   */
+  @Test
+  def testRecoverTransactionIndex(): Unit = {
+    val segment = createSegment(100)
+    val epoch = 0.toShort
+    val sequence = 0
+    val pid1 = 5L
+    val pid2 = 10L
+    // after each recovery the index must contain only pid2's aborted transaction, ending at
+    // offset 106 with a last stable offset of 100; only the first offset differs between runs
+    def assertOnlyPid2Aborted(expectedFirstOffset: Long): Unit = {
+      val abortedTxns = segment.txnIndex.allAbortedTxns
+      assertEquals(1, abortedTxns.size)
+      val abortedTxn = abortedTxns.head
+      assertEquals(pid2, abortedTxn.producerId)
+      assertEquals(expectedFirstOffset, abortedTxn.firstOffset)
+      assertEquals(106L, abortedTxn.lastOffset)
+      assertEquals(100L, abortedTxn.lastStableOffset)
+    }
+    // append transactional records from pid1
+    segment.append(firstOffset = 100L, largestOffset = 101L, largestTimestamp = RecordBatch.NO_TIMESTAMP,
+      shallowOffsetOfMaxTimestamp = 100L, MemoryRecords.withTransactionalRecords(100L, CompressionType.NONE,
+        pid1, epoch, sequence, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
+    // append transactional records from pid2
+    segment.append(firstOffset = 102L, largestOffset = 103L, largestTimestamp = RecordBatch.NO_TIMESTAMP,
+      shallowOffsetOfMaxTimestamp = 102L, MemoryRecords.withTransactionalRecords(102L, CompressionType.NONE,
+        pid2, epoch, sequence, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
+    // append non-transactional records
+    segment.append(firstOffset = 104L, largestOffset = 105L, largestTimestamp = RecordBatch.NO_TIMESTAMP,
+      shallowOffsetOfMaxTimestamp = 104L, MemoryRecords.withRecords(104L, CompressionType.NONE,
+        new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
+    // abort the transaction from pid2 (note LSO should be 100L since the txn from pid1 has not completed)
+    segment.append(firstOffset = 106L, largestOffset = 106L, largestTimestamp = RecordBatch.NO_TIMESTAMP,
+      shallowOffsetOfMaxTimestamp = 106L, endTxnRecords(ControlRecordType.ABORT, pid2, epoch, offset = 106L))
+    // commit the transaction from pid1
+    segment.append(firstOffset = 107L, largestOffset = 107L, largestTimestamp = RecordBatch.NO_TIMESTAMP,
+      shallowOffsetOfMaxTimestamp = 107L, endTxnRecords(ControlRecordType.COMMIT, pid1, epoch, offset = 107L))
+    // first recovery starts from empty producer state, so pid2's txn begins inside this segment
+    segment.recover(64 * 1024, new ProducerStateManager(topicPartition, logDir))
+    assertOnlyPid2Aborted(expectedFirstOffset = 102L)
+    // recover again, but this time assuming the transaction from pid2 began on a previous segment
+    val stateManager = new ProducerStateManager(topicPartition, logDir)
+    stateManager.loadProducerEntry(ProducerIdEntry(pid2, epoch, 10, 90L, 5, RecordBatch.NO_TIMESTAMP, 0, Some(75L)))
+    segment.recover(64 * 1024, stateManager)
+    assertOnlyPid2Aborted(expectedFirstOffset = 75L)
+  }
+  /**
+   * Wraps an [[EndTransactionMarker]] of the given control type and coordinator epoch in
+   * `MemoryRecords`, using `MemoryRecords.withEndTransactionMarker` with the supplied offset,
+   * producer id and producer epoch.
+   */
+  private def endTxnRecords(controlRecordType: ControlRecordType,
+                            producerId: Long,
+                            epoch: Short,
+                            offset: Long = 0L,
+                            coordinatorEpoch: Int = 0): MemoryRecords =
+    withEndTransactionMarker(offset, producerId, epoch,
+      new EndTransactionMarker(controlRecordType, coordinatorEpoch))
    * Create a segment with some data and an index. Then corrupt the index,
    * and recover the segment, the entries should all be readable.
@@ -262,7 +346,7 @@ class LogSegmentTest {
       seg.append(i, i, i * 10, i, records(i, i.toString))
     val timeIndexFile = seg.timeIndex.file
     TestUtils.writeNonsenseToFile(timeIndexFile, 5, timeIndexFile.length.toInt)
-    seg.recover(64*1024)
+    seg.recover(64*1024, new ProducerStateManager(topicPartition, logDir))
     for(i <- 0 until 100) {
       assertEquals(i, seg.findOffsetByTimestamp(i * 10).get.offset)
       if (i < 99)
@@ -286,7 +370,7 @@ class LogSegmentTest {
       val recordPosition = seg.log.searchForOffsetWithSize(offsetToBeginCorruption, 0)
       val position = recordPosition.position + TestUtils.random.nextInt(15)
       TestUtils.writeNonsenseToFile(seg.log.file, position, (seg.log.file.length - position).toInt)
-      seg.recover(64*1024)
+      seg.recover(64*1024, new ProducerStateManager(topicPartition, logDir))
       assertEquals("Should have truncated off bad messages.", (0 until offsetToBeginCorruption).toList,
diff --git a/core/src/test/scala/unit/kafka/log/LogTest.scala b/core/src/test/scala/unit/kafka/log/LogTest.scala
index 0f82cd3..b11c94b 100755
--- a/core/src/test/scala/unit/kafka/log/LogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTest.scala
@@ -31,6 +31,8 @@ import kafka.server.KafkaConfig
 import kafka.server.epoch.{EpochEntry, LeaderEpochFileCache}
 import org.apache.kafka.common.record.MemoryRecords.RecordFilter
 import org.apache.kafka.common.record._
+import org.apache.kafka.common.requests.FetchResponse.AbortedTransaction
+import org.apache.kafka.common.requests.IsolationLevel
 import org.apache.kafka.common.utils.Utils
 import scala.collection.JavaConverters._
@@ -58,7 +60,7 @@ class LogTest {
   def createEmptyLogs(dir: File, offsets: Int*) {
     for(offset <- offsets) {
       Log.logFilename(dir, offset).createNewFile()
-      Log.indexFilename(dir, offset).createNewFile()
+      Log.offsetIndexFile(dir, offset).createNewFile()
@@ -69,13 +71,13 @@ class LogTest {
     val logFile = Log.logFilename(tmpDir, offset)
     assertEquals(offset, Log.offsetFromFilename(logFile.getName))
-    val offsetIndexFile = Log.indexFilename(tmpDir, offset)
+    val offsetIndexFile = Log.offsetIndexFile(tmpDir, offset)
     assertEquals(offset, Log.offsetFromFilename(offsetIndexFile.getName))
-    val timeIndexFile = Log.timeIndexFilename(tmpDir, offset)
+    val timeIndexFile = Log.timeIndexFile(tmpDir, offset)
     assertEquals(offset, Log.offsetFromFilename(timeIndexFile.getName))
-    val snapshotFile = Log.pidSnapshotFilename(tmpDir, offset)
+    val snapshotFile = Log.producerSnapshotFile(tmpDir, offset)
     assertEquals(offset, Log.offsetFromFilename(snapshotFile.getName))
@@ -166,8 +168,8 @@ class LogTest {
     val log = createLog(2048)
     val records = TestUtils.records(List(new SimpleRecord(time.milliseconds, "key".getBytes, "value".getBytes)))
     log.appendAsLeader(records, leaderEpoch = 0)
-    log.maybeTakePidSnapshot()
-    assertEquals(Some(1), log.latestPidSnapshotOffset)
+    log.takeProducerSnapshot()
+    assertEquals(Some(1), log.latestProducerSnapshotOffset)
@@ -253,18 +255,18 @@ class LogTest {
     val log = createLog(2048)
     log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes))), leaderEpoch = 0)
     log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes))), leaderEpoch = 0)
-    log.maybeTakePidSnapshot()
+    log.takeProducerSnapshot()
     log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes))), leaderEpoch = 0)
-    log.maybeTakePidSnapshot()
+    log.takeProducerSnapshot()
-    assertEquals(Some(2), log.latestPidSnapshotOffset)
-    assertEquals(2, log.latestPidMapOffset)
+    assertEquals(Some(2), log.latestProducerSnapshotOffset)
+    assertEquals(2, log.latestProducerStateEndOffset)
-    assertEquals(None, log.latestPidSnapshotOffset)
-    assertEquals(1, log.latestPidMapOffset)
+    assertEquals(None, log.latestProducerSnapshotOffset)
+    assertEquals(1, log.latestProducerStateEndOffset)
@@ -272,20 +274,20 @@ class LogTest {
     val records = TestUtils.singletonRecords("foo".getBytes)
     val log = createLog(records.sizeInBytes, messagesPerSegment = 1, retentionBytes = records.sizeInBytes * 2)
     log.appendAsLeader(records, leaderEpoch = 0)
-    log.maybeTakePidSnapshot()
+    log.takeProducerSnapshot()
     log.appendAsLeader(TestUtils.singletonRecords("bar".getBytes), leaderEpoch = 0)
     log.appendAsLeader(TestUtils.singletonRecords("baz".getBytes), leaderEpoch = 0)
-    log.maybeTakePidSnapshot()
+    log.takeProducerSnapshot()
     assertEquals(3, log.logSegments.size)
-    assertEquals(3, log.latestPidMapOffset)
-    assertEquals(Some(3), log.latestPidSnapshotOffset)
+    assertEquals(3, log.latestProducerStateEndOffset)
+    assertEquals(Some(3), log.latestProducerSnapshotOffset)
     assertEquals(1, log.logSegments.size)
-    assertEquals(None, log.latestPidSnapshotOffset)
-    assertEquals(29, log.latestPidMapOffset)
+    assertEquals(None, log.latestProducerSnapshotOffset)
+    assertEquals(29, log.latestProducerStateEndOffset)
@@ -294,14 +296,14 @@ class LogTest {
     val records = TestUtils.records(Seq(new SimpleRecord("foo".getBytes)), pid = pid1, epoch = 0, sequence = 0)
     val log = createLog(records.sizeInBytes, messagesPerSegment = 1, retentionBytes = records.sizeInBytes * 2)
     log.appendAsLeader(records, leaderEpoch = 0)
-    log.maybeTakePidSnapshot()
+    log.takeProducerSnapshot()
     val pid2 = 2L
     log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord("bar".getBytes)), pid = pid2, epoch = 0, sequence = 0),
       leaderEpoch = 0)
     log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord("baz".getBytes)), pid = pid2, epoch = 0, sequence = 1),
       leaderEpoch = 0)
-    log.maybeTakePidSnapshot()
+    log.takeProducerSnapshot()
     assertEquals(3, log.logSegments.size)
     assertEquals(Set(pid1, pid2), log.activePids.keySet)
@@ -313,16 +315,69 @@ class LogTest {
-  def testPeriodicPidSnapshot() {
-    val snapshotInterval = 100
-    val log = createLog(2048, pidSnapshotIntervalMs = snapshotInterval)
+  // Rolls the log three times and checks, after each roll, which producer snapshot offsets
+  // the log reports as latest and oldest; then flushes inside the active segment and checks
+  // that both values are unchanged.
+  def testTakeSnapshotOnRollAndDeleteSnapshotOnFlush(): Unit = {
+    val log = createLog(2048)
+    log.appendAsLeader(TestUtils.singletonRecords("a".getBytes), leaderEpoch = 0)
+    log.roll(1L)
+    assertEquals(Some(1L), log.latestProducerSnapshotOffset)
+    assertEquals(Some(1L), log.oldestProducerSnapshotOffset)
+    log.appendAsLeader(TestUtils.singletonRecords("b".getBytes), leaderEpoch = 0)
+    log.roll(2L)
+    assertEquals(Some(2L), log.latestProducerSnapshotOffset)
+    assertEquals(Some(1L), log.oldestProducerSnapshotOffset)
+    log.appendAsLeader(TestUtils.singletonRecords("c".getBytes), leaderEpoch = 0)
+    log.roll(3L)
+    assertEquals(Some(3L), log.latestProducerSnapshotOffset)
+    // roll triggers a flush at the starting offset of the new segment. we should
+    // retain the snapshots from the active segment and the previous segment, but
+    // the oldest one should be gone
+    assertEquals(Some(2L), log.oldestProducerSnapshotOffset)
+    // even if we flush within the active segment, the snapshot should remain
+    log.appendAsLeader(TestUtils.singletonRecords("baz".getBytes), leaderEpoch = 0)
+    log.flush(4L)
+    assertEquals(Some(3L), log.latestProducerSnapshotOffset)
+    assertEquals(Some(2L), log.oldestProducerSnapshotOffset)
+  }
-    log.appendAsLeader(TestUtils.singletonRecords("foo".getBytes), leaderEpoch = 0)
-    log.appendAsLeader(TestUtils.singletonRecords("bar".getBytes), leaderEpoch = 0)
-    assertEquals(None, log.latestPidSnapshotOffset)
+  @Test
+  def testRebuildTransactionalState(): Unit = {
+    val log = createLog(1024 * 1024)
+    val pid = 137L
+    val epoch = 5.toShort
+    val seq = 0
+    // add some transactional records
+    val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid, epoch, seq,
+      new SimpleRecord("foo".getBytes),
+      new SimpleRecord("bar".getBytes),
+      new SimpleRecord("baz".getBytes))
+    log.appendAsLeader(records, leaderEpoch = 0)
+    // end the transaction with an ABORT marker (appended with isFromClient = false)
+    val abortAppendInfo = log.appendAsLeader(endTxnRecords(ControlRecordType.ABORT, pid, epoch),
+      isFromClient = false, leaderEpoch = 0)
+    log.onHighWatermarkIncremented(abortAppendInfo.lastOffset + 1)
+    // now there should be no first unstable offset
+    assertEquals(None, log.firstUnstableOffset)
+    log.close()
-    time.sleep(snapshotInterval)
-    assertEquals(Some(2), log.latestPidSnapshotOffset)
+    // a second log created after close() should likewise report no first unstable offset
+    val reopenedLog = createLog(1024 * 1024)
+    reopenedLog.onHighWatermarkIncremented(abortAppendInfo.lastOffset + 1)
+    assertEquals(None, reopenedLog.firstUnstableOffset)
+  }
+  private def endTxnRecords(controlRecordType: ControlRecordType,
+                            producerId: Long,
+                            epoch: Short,
+                            offset: Long = 0L,
+                            coordinatorEpoch: Int = 0): MemoryRecords = {
+    val marker = new EndTransactionMarker(controlRecordType, coordinatorEpoch)
+    MemoryRecords.withEndTransactionMarker(offset, producerId, epoch, marker)
@@ -432,25 +487,25 @@ class LogTest {
       time = time)
     val epoch: Short = 0
     val buffer = ByteBuffer.allocate(512)
-    var builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, TimestampType.LOG_APPEND_TIME, 0L, time.milliseconds(), 1L, epoch, 0, false, 0)
+    var builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE,
+      TimestampType.LOG_APPEND_TIME, 0L, time.milliseconds(), 1L, epoch, 0, false, 0)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
-    // Append a record with other pids.
-    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, TimestampType.LOG_APPEND_TIME, 1L, time.milliseconds(), 2L, epoch, 0, false, 0)
+    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE,
+      TimestampType.LOG_APPEND_TIME, 1L, time.milliseconds(), 2L, epoch, 0, false, 0)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
-    // Append a record with other pids.
-    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, TimestampType.LOG_APPEND_TIME, 2L, time.milliseconds(), 3L, epoch, 0, false, 0)
+    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE,
+      TimestampType.LOG_APPEND_TIME, 2L, time.milliseconds(), 3L, epoch, 0, false, 0)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
-    // Append a record with other pids.
-    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, TimestampType.LOG_APPEND_TIME, 3L, time.milliseconds(), 4L, epoch, 0, false, 0)
+    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE,
+      TimestampType.LOG_APPEND_TIME, 3L, time.milliseconds(), 4L, epoch, 0, false, 0)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
@@ -473,46 +528,66 @@ class LogTest {
   @Test(expected = classOf[DuplicateSequenceNumberException])
-  def testMultiplePidsWithDuplicates() : Unit = {
-    val logProps = new Properties()
+  def testDuplicateAppendToFollower() : Unit = {
+    val log = createLog(1024*1024)
+    val producerId = 1L
+    val producerEpoch: Short = 0
+    val firstSequence = 0
+    val leaderEpoch = 0
+    // builds a two-record idempotent batch that always starts at the same sequence number;
+    // only the base offset varies between calls
+    def batchAt(baseOffset: Long): MemoryRecords =
+      MemoryRecords.withIdempotentRecords(baseOffset, CompressionType.NONE, producerId, producerEpoch,
+        firstSequence, leaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
+    // this is a bit contrived. to trigger the duplicate case for a follower append, we have to append
+    // a batch with matching sequence numbers, but valid increasing offsets
+    log.appendAsFollower(batchAt(0L))
+    log.appendAsFollower(batchAt(2L))
+  }
-    // create a log
-    val log = new Log(logDir,
-      LogConfig(logProps),
-      recoveryPoint = 0L,
-      scheduler = time.scheduler,
-      time = time)
+  @Test(expected = classOf[DuplicateSequenceNumberException])
+  def testMultipleProducersWithDuplicatesInSingleAppend() : Unit = {
+    val log = createLog(1024*1024)
+    val pid1 = 1L
+    val pid2 = 2L
     val epoch: Short = 0
     val buffer = ByteBuffer.allocate(512)
-    var builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, TimestampType.LOG_APPEND_TIME, 0L, time.milliseconds(), 1L, epoch, 0)
+    // pid1 seq = 0
+    var builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
+      TimestampType.LOG_APPEND_TIME, 0L, time.milliseconds(), pid1, epoch, 0)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
-    // Append a record with other pids.
-    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, TimestampType.LOG_APPEND_TIME, 1L, time.milliseconds(), 2L, epoch, 0)
+    // pid2 seq = 0
+    builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
+      TimestampType.LOG_APPEND_TIME, 1L, time.milliseconds(), pid2, epoch, 0)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
-    // Append a record with other pids.
-    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, TimestampType.LOG_APPEND_TIME, 2L, time.milliseconds(), 1L, epoch, 1)
+    // pid1 seq = 1
+    builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
+      TimestampType.LOG_APPEND_TIME, 2L, time.milliseconds(), pid1, epoch, 1)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
-    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, TimestampType.LOG_APPEND_TIME, 3L, time.milliseconds(), 2L, epoch, 1)
+    // pid2 seq = 1
+    builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
+      TimestampType.LOG_APPEND_TIME, 3L, time.milliseconds(), pid2, epoch, 1)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
-    builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, TimestampType.LOG_APPEND_TIME, 4L, time.milliseconds(), 1L, epoch, 1)
+    // pid1 seq = 1 (duplicate)
+    builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
+      TimestampType.LOG_APPEND_TIME, 4L, time.milliseconds(), pid1, epoch, 1)
     builder.append(new SimpleRecord("key".getBytes, "value".getBytes))
-    log.appendAsFollower(MemoryRecords.readableRecords(buffer))
-    // Should throw a duplicate seqeuence exception here.
+    val records = MemoryRecords.readableRecords(buffer)
+    records.batches.asScala.foreach(_.setPartitionLeaderEpoch(0))
+    log.appendAsFollower(records)
+    // Should throw a duplicate sequence exception here.
     fail("should have thrown a DuplicateSequenceNumberException.")
@@ -1245,10 +1320,10 @@ class LogTest {
   def testBogusIndexSegmentsAreRemoved() {
-    val bogusIndex1 = Log.indexFilename(logDir, 0)
-    val bogusTimeIndex1 = Log.timeIndexFilename(logDir, 0)
-    val bogusIndex2 = Log.indexFilename(logDir, 5)
-    val bogusTimeIndex2 = Log.timeIndexFilename(logDir, 5)
+    val bogusIndex1 = Log.offsetIndexFile(logDir, 0)
+    val bogusTimeIndex1 = Log.timeIndexFile(logDir, 0)
+    val bogusIndex2 = Log.offsetIndexFile(logDir, 5)
+    val bogusTimeIndex2 = Log.timeIndexFile(logDir, 5)
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = time.milliseconds)
     val logProps = new Properties()
@@ -1501,9 +1576,11 @@ class LogTest {
     //This write will roll the segment, yielding a new segment with base offset = max(2, 1) = 2
     assertEquals(2L, log.activeSegment.baseOffset)
+    assertTrue(Log.producerSnapshotFile(logDir, 2L).exists)
     //This will also roll the segment, yielding a new segment with base offset = max(3, Integer.MAX_VALUE+3) = Integer.MAX_VALUE+3
     assertEquals(Integer.MAX_VALUE.toLong + 3, log.activeSegment.baseOffset)
+    assertTrue(Log.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 3).exists)
     //This will go into the existing log
     assertEquals(Integer.MAX_VALUE.toLong + 3, log.activeSegment.baseOffset)
@@ -1990,11 +2067,301 @@ class LogTest {
+  def testFirstUnstableOffsetNoTransactionalData() {
+    val log = createLog(1024 * 1024)
+    val records = MemoryRecords.withRecords(CompressionType.NONE,
+      new SimpleRecord("foo".getBytes),
+      new SimpleRecord("bar".getBytes),
+      new SimpleRecord("baz".getBytes))
+    log.appendAsLeader(records, leaderEpoch = 0)
+    assertEquals(None, log.firstUnstableOffset)
+  }
+  @Test
+  def testFirstUnstableOffsetWithTransactionalData() {
+    val log = createLog(1024 * 1024)
+    val pid = 137L
+    val epoch = 5.toShort
+    var seq = 0
+    // add some transactional records
+    val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid, epoch, seq,
+      new SimpleRecord("foo".getBytes),
+      new SimpleRecord("bar".getBytes),
+      new SimpleRecord("baz".getBytes))
+    val firstAppendInfo = log.appendAsLeader(records, leaderEpoch = 0)
+    assertEquals(Some(firstAppendInfo.firstOffset), log.firstUnstableOffset.map(_.messageOffset))
+    // add more transactional records
+    seq += 3
+    log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid, epoch, seq,
+      new SimpleRecord("blah".getBytes)), leaderEpoch = 0)
+    // LSO should not have changed
+    assertEquals(Some(firstAppendInfo.firstOffset), log.firstUnstableOffset.map(_.messageOffset))
+    // now transaction is committed
+    val commitAppendInfo = log.appendAsLeader(endTxnRecords(ControlRecordType.COMMIT, pid, epoch),
+      isFromClient = false, leaderEpoch = 0)
+    // first unstable offset is not updated until the high watermark is advanced
+    assertEquals(Some(firstAppendInfo.firstOffset), log.firstUnstableOffset.map(_.messageOffset))
+    log.onHighWatermarkIncremented(commitAppendInfo.lastOffset + 1)
+    // now there should be no first unstable offset
+    assertEquals(None, log.firstUnstableOffset)
+  }
+  @Test
+  def testTransactionIndexUpdated(): Unit = {
+    val log = createLog(1024 * 1024)
+    val epoch = 0.toShort
+    val pid1 = 1L
+    val pid2 = 2L
+    val pid3 = 3L
+    val pid4 = 4L
+    val appendPid1 = appendTransactionalAsLeader(log, pid1, epoch)
+    val appendPid2 = appendTransactionalAsLeader(log, pid2, epoch)
+    val appendPid3 = appendTransactionalAsLeader(log, pid3, epoch)
+    val appendPid4 = appendTransactionalAsLeader(log, pid4, epoch)
+    // mix transactional and non-transactional data
+    appendPid1(5) // nextOffset: 5
+    appendNonTransactionalAsLeader(log, 3) // 8
+    appendPid2(2) // 10
+    appendPid1(4) // 14
+    appendPid3(3) // 17
+    appendNonTransactionalAsLeader(log, 2) // 19
+    appendPid1(10) // 29
+    appendEndTxnMarkerAsLeader(log, pid1, epoch, ControlRecordType.ABORT) // 30
+    appendPid2(6) // 36
+    appendPid4(3) // 39
+    appendNonTransactionalAsLeader(log, 10) // 49
+    appendPid3(9) // 58
+    appendEndTxnMarkerAsLeader(log, pid3, epoch, ControlRecordType.COMMIT) // 59
+    appendPid4(8) // 67
+    appendPid2(7) // 74
+    appendEndTxnMarkerAsLeader(log, pid2, epoch, ControlRecordType.ABORT) // 75
+    appendNonTransactionalAsLeader(log, 10) // 85
+    appendPid4(4) // 89
+    appendEndTxnMarkerAsLeader(log, pid4, epoch, ControlRecordType.COMMIT) // 90
+    val abortedTransactions = allAbortedTransactions(log)
+    assertEquals(List(new AbortedTxn(pid1, 0L, 29L, 8L), new AbortedTxn(pid2, 8L, 74L, 36L)), abortedTransactions)
+  }
+  @Test
+  def testRecoverTransactionIndex(): Unit = {
+    val log = createLog(128)
+    val epoch = 0.toShort
+    val pid1 = 1L
+    val pid2 = 2L
+    val pid3 = 3L
+    val pid4 = 4L
+    val appendPid1 = appendTransactionalAsLeader(log, pid1, epoch)
+    val appendPid2 = appendTransactionalAsLeader(log, pid2, epoch)
+    val appendPid3 = appendTransactionalAsLeader(log, pid3, epoch)
+    val appendPid4 = appendTransactionalAsLeader(log, pid4, epoch)
+    // mix transactional and non-transactional data
+    appendPid1(5) // nextOffset: 5
+    appendNonTransactionalAsLeader(log, 3) // 8
+    appendPid2(2) // 10
+    appendPid1(4) // 14
+    appendPid3(3) // 17
+    appendNonTransactionalAsLeader(log, 2) // 19
+    appendPid1(10) // 29
+    appendEndTxnMarkerAsLeader(log, pid1, epoch, ControlRecordType.ABORT) // 30
+    appendPid2(6) // 36
+    appendPid4(3) // 39
+    appendNonTransactionalAsLeader(log, 10) // 49
+    appendPid3(9) // 58
+    appendEndTxnMarkerAsLeader(log, pid3, epoch, ControlRecordType.COMMIT) // 59
+    appendPid4(8) // 67
+    appendPid2(7) // 74
+    appendEndTxnMarkerAsLeader(log, pid2, epoch, ControlRecordType.ABORT) // 75
+    appendNonTransactionalAsLeader(log, 10) // 85
+    appendPid4(4) // 89
+    appendEndTxnMarkerAsLeader(log, pid4, epoch, ControlRecordType.COMMIT) // 90
+    // delete all the offset and transaction index files to force recovery
+    log.logSegments.foreach { segment =>
+      segment.index.delete()
+      segment.txnIndex.delete()
+    }
+    log.close()
+    val reloadedLog = createLog(1024)
+    val abortedTransactions = allAbortedTransactions(reloadedLog)
+    assertEquals(List(new AbortedTxn(pid1, 0L, 29L, 8L), new AbortedTxn(pid2, 8L, 74L, 36L)), abortedTransactions)
+  }
+  @Test
+  def testTransactionIndexUpdatedThroughReplication(): Unit = {
+    val epoch = 0.toShort
+    val log = createLog(1024 * 1024)
+    val buffer = ByteBuffer.allocate(2048)
+    val pid1 = 1L
+    val pid2 = 2L
+    val pid3 = 3L
+    val pid4 = 4L
+    val appendPid1 = appendTransactionalToBuffer(buffer, pid1, epoch)
+    val appendPid2 = appendTransactionalToBuffer(buffer, pid2, epoch)
+    val appendPid3 = appendTransactionalToBuffer(buffer, pid3, epoch)
+    val appendPid4 = appendTransactionalToBuffer(buffer, pid4, epoch)
+    appendPid1(0L, 5)
+    appendNonTransactionalToBuffer(buffer, 5L, 3)
+    appendPid2(8L, 2)
+    appendPid1(10L, 4)
+    appendPid3(14L, 3)
+    appendNonTransactionalToBuffer(buffer, 17L, 2)
+    appendPid1(19L, 10)
+    appendEndTxnMarkerToBuffer(buffer, pid1, epoch, 29L, ControlRecordType.ABORT)
+    appendPid2(30L, 6)
+    appendPid4(36L, 3)
+    appendNonTransactionalToBuffer(buffer, 39L, 10)
+    appendPid3(49L, 9)
+    appendEndTxnMarkerToBuffer(buffer, pid3, epoch, 58L, ControlRecordType.COMMIT)
+    appendPid4(59L, 8)
+    appendPid2(67L, 7)
+    appendEndTxnMarkerToBuffer(buffer, pid2, epoch, 74L, ControlRecordType.ABORT)
+    appendNonTransactionalToBuffer(buffer, 75L, 10)
+    appendPid4(85L, 4)
+    appendEndTxnMarkerToBuffer(buffer, pid4, epoch, 89L, ControlRecordType.COMMIT)
+    buffer.flip()
+    appendAsFollower(log, MemoryRecords.readableRecords(buffer))
+    val abortedTransactions = allAbortedTransactions(log)
+    assertEquals(List(new AbortedTxn(pid1, 0L, 29L, 8L), new AbortedTxn(pid2, 8L, 74L, 36L)), abortedTransactions)
+  }
+  @Test(expected = classOf[TransactionCoordinatorFencedException])
+  def testZombieCoordinatorFenced(): Unit = {
+    val pid = 1L
+    val epoch = 0.toShort
+    val log = createLog(1024 * 1024)
+    val append = appendTransactionalAsLeader(log, pid, epoch)
+    append(10)
+    appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, coordinatorEpoch = 1)
+    append(5)
+    appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.COMMIT, coordinatorEpoch = 2)
-  def createLog(messageSizeInBytes: Int, retentionMs: Int = -1, retentionBytes: Int = -1,
-                cleanupPolicy: String = "delete", messagesPerSegment: Int = 5,
-                maxPidExpirationMs: Int = 300000, pidExpirationCheckIntervalMs: Int = 30000,
-                pidSnapshotIntervalMs: Int = 60000): Log = {
+    appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, coordinatorEpoch = 1)
+  }
+  @Test
+  def testLastStableOffsetWithMixedProducerData() {
+    val log = createLog(1024 * 1024)
+    // for convenience, both producers share the same epoch
+    val epoch = 5.toShort
+    val pid1 = 137L
+    val seq1 = 0
+    val pid2 = 983L
+    val seq2 = 0
+    // add some transactional records
+    val firstAppendInfo = log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid1, epoch, seq1,
+      new SimpleRecord("a".getBytes),
+      new SimpleRecord("b".getBytes),
+      new SimpleRecord("c".getBytes)), leaderEpoch = 0)
+    assertEquals(Some(firstAppendInfo.firstOffset), log.firstUnstableOffset.map(_.messageOffset))
+    // mix in some non-transactional data
+    log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE,
+      new SimpleRecord("g".getBytes),
+      new SimpleRecord("h".getBytes),
+      new SimpleRecord("i".getBytes)), leaderEpoch = 0)
+    // append data from a second transactional producer
+    val secondAppendInfo = log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid2, epoch, seq2,
+      new SimpleRecord("d".getBytes),
+      new SimpleRecord("e".getBytes),
+      new SimpleRecord("f".getBytes)), leaderEpoch = 0)
+    // LSO should not have changed
+    assertEquals(Some(firstAppendInfo.firstOffset), log.firstUnstableOffset.map(_.messageOffset))
+    // now first producer's transaction is aborted
+    val abortAppendInfo = log.appendAsLeader(endTxnRecords(ControlRecordType.ABORT, pid1, epoch),
+      isFromClient = false, leaderEpoch = 0)
+    log.onHighWatermarkIncremented(abortAppendInfo.lastOffset + 1)
+    // LSO should now point to the first offset of the second transaction
+    assertEquals(Some(secondAppendInfo.firstOffset), log.firstUnstableOffset.map(_.messageOffset))
+    // commit the second transaction
+    val commitAppendInfo = log.appendAsLeader(endTxnRecords(ControlRecordType.COMMIT, pid2, epoch),
+      isFromClient = false, leaderEpoch = 0)
+    log.onHighWatermarkIncremented(commitAppendInfo.lastOffset + 1)
+    // now there should be no first unstable offset
+    assertEquals(None, log.firstUnstableOffset)
+  }
+  @Test
+  def testAbortedTransactionSpanningMultipleSegments() {
+    val pid = 137L
+    val epoch = 5.toShort
+    var seq = 0
+    val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid, epoch, seq,
+      new SimpleRecord("a".getBytes),
+      new SimpleRecord("b".getBytes),
+      new SimpleRecord("c".getBytes))
+    val log = createLog(messageSizeInBytes = records.sizeInBytes, messagesPerSegment = 1)
+    val firstAppendInfo = log.appendAsLeader(records, leaderEpoch = 0)
+    assertEquals(Some(firstAppendInfo.firstOffset), log.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(Some(0L), log.firstUnstableOffset.map(_.segmentBaseOffset))
+    // this write should spill to the second segment
+    seq = 3
+    log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid, epoch, seq,
+      new SimpleRecord("d".getBytes),
+      new SimpleRecord("e".getBytes),
+      new SimpleRecord("f".getBytes)), leaderEpoch = 0)
+    assertEquals(Some(firstAppendInfo.firstOffset), log.firstUnstableOffset.map(_.messageOffset))
+    assertEquals(Some(0L), log.firstUnstableOffset.map(_.segmentBaseOffset))
+    assertEquals(3L, log.logEndOffsetMetadata.segmentBaseOffset)
+    // now abort the transaction
+    val appendInfo = log.appendAsLeader(endTxnRecords(ControlRecordType.ABORT, pid, epoch),
+      isFromClient = false, leaderEpoch = 0)
+    log.onHighWatermarkIncremented(appendInfo.lastOffset + 1)
+    assertEquals(None, log.firstUnstableOffset.map(_.messageOffset))
+    // now check that a fetch includes the aborted transaction
+    val fetchDataInfo = log.read(0L, 2048, isolationLevel = IsolationLevel.READ_COMMITTED)
+    assertEquals(1, fetchDataInfo.abortedTransactions.size)
+    assertTrue(fetchDataInfo.abortedTransactions.isDefined)
+    assertEquals(new AbortedTransaction(pid, 0), fetchDataInfo.abortedTransactions.get.head)
+  }
+  private def createLog(messageSizeInBytes: Int, retentionMs: Int = -1, retentionBytes: Int = -1,
+                        cleanupPolicy: String = "delete", messagesPerSegment: Int = 5,
+                        maxPidExpirationMs: Int = 300000, pidExpirationCheckIntervalMs: Int = 30000,
+                        pidSnapshotIntervalMs: Int = 60000): Log = {
     val logProps = new Properties()
     logProps.put(LogConfig.SegmentBytesProp, messageSizeInBytes * messagesPerSegment: Integer)
     logProps.put(LogConfig.RetentionMsProp, retentionMs: Integer)
@@ -2009,8 +2376,70 @@ class LogTest {
       scheduler = time.scheduler,
       time = time,
       maxPidExpirationMs = maxPidExpirationMs,
-      pidExpirationCheckIntervalMs = pidExpirationCheckIntervalMs,
-      pidSnapshotIntervalMs = pidSnapshotIntervalMs)
+      pidExpirationCheckIntervalMs = pidExpirationCheckIntervalMs)
+  private def allAbortedTransactions(log: Log) = log.logSegments.flatMap(_.txnIndex.allAbortedTxns)
+  private def appendTransactionalAsLeader(log: Log, pid: Long, producerEpoch: Short): Int => Unit = {
+    var sequence = 0
+    numRecords: Int => {
+      val simpleRecords = (sequence until sequence + numRecords).map { seq =>
+        new SimpleRecord(s"$seq".getBytes)
+      }
+      val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid,
+        producerEpoch, sequence, simpleRecords: _*)
+      log.appendAsLeader(records, leaderEpoch = 0)
+      sequence += numRecords
+    }
+  }
+  private def appendEndTxnMarkerAsLeader(log: Log, pid: Long, producerEpoch: Short,
+                                         controlType: ControlRecordType, coordinatorEpoch: Int = 0): Unit = {
+    val records = endTxnRecords(controlType, pid, producerEpoch, coordinatorEpoch = coordinatorEpoch)
+    log.appendAsLeader(records, isFromClient = false, leaderEpoch = 0)
+  }
+  private def appendNonTransactionalAsLeader(log: Log, numRecords: Int): Unit = {
+    val simpleRecords = (0 until numRecords).map { seq =>
+      new SimpleRecord(s"$seq".getBytes)
+    }
+    val records = MemoryRecords.withRecords(CompressionType.NONE, simpleRecords: _*)
+    log.appendAsLeader(records, leaderEpoch = 0)
+  }
+  private def appendTransactionalToBuffer(buffer: ByteBuffer, pid: Long, epoch: Short): (Long, Int) => Unit = {
+    var sequence = 0
+    (offset: Long, numRecords: Int) => {
+      val builder = MemoryRecords.builder(buffer, CompressionType.NONE, offset, pid, epoch, sequence, true)
+      for (seq <- sequence until sequence + numRecords) {
+        val record = new SimpleRecord(s"$seq".getBytes)
+        builder.append(record)
+      }
+      sequence += numRecords
+      builder.close()
+    }
+  }
+  private def appendEndTxnMarkerToBuffer(buffer: ByteBuffer, producerId: Long, producerEpoch: Short, offset: Long,
+                                    controlType: ControlRecordType, coordinatorEpoch: Int = 0): Unit = {
+    val marker = new EndTransactionMarker(controlType, coordinatorEpoch)
+    MemoryRecords.writeEndTransactionalMarker(buffer, offset, producerId, producerEpoch, marker)
+  }
+  private def appendNonTransactionalToBuffer(buffer: ByteBuffer, offset: Long, numRecords: Int): Unit = {
+    val builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, offset)
+    (0 until numRecords).foreach { seq =>
+      builder.append(new SimpleRecord(s"$seq".getBytes))
+    }
+    builder.close()
+  }
+  private def appendAsFollower(log: Log, records: MemoryRecords, leaderEpoch: Int = 0): Unit = {
+    records.batches.asScala.foreach(_.setPartitionLeaderEpoch(leaderEpoch))
+    log.appendAsFollower(records)
+  }
diff --git a/core/src/test/scala/unit/kafka/log/LogValidatorTest.scala b/core/src/test/scala/unit/kafka/log/LogValidatorTest.scala
index 5b2c660..61fae80 100644
--- a/core/src/test/scala/unit/kafka/log/LogValidatorTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogValidatorTest.scala
@@ -22,6 +22,7 @@ import kafka.common.LongRef
 import kafka.message.{DefaultCompressionCodec, GZIPCompressionCodec, NoCompressionCodec, SnappyCompressionCodec}
 import org.apache.kafka.common.errors.InvalidTimestampException
 import org.apache.kafka.common.record._
+import org.apache.kafka.test.TestUtils
 import org.junit.Assert._
 import org.junit.Test
@@ -47,7 +48,8 @@ class LogValidatorTest {
       magic = magic,
       timestampType = TimestampType.LOG_APPEND_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true)
     val validatedRecords = validatedResults.validatedRecords
     assertEquals("message set size should not change", records.records.asScala.size, validatedRecords.records.asScala.size)
     validatedRecords.batches.asScala.foreach(batch => validateLogAppendTime(now, batch))
@@ -79,7 +81,8 @@ class LogValidatorTest {
       magic = targetMagic,
       timestampType = TimestampType.LOG_APPEND_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true)
     val validatedRecords = validatedResults.validatedRecords
     assertEquals("message set size should not change", records.records.asScala.size, validatedRecords.records.asScala.size)
@@ -115,7 +118,8 @@ class LogValidatorTest {
       magic = magic,
       timestampType = TimestampType.LOG_APPEND_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true)
     val validatedRecords = validatedResults.validatedRecords
     assertEquals("message set size should not change", records.records.asScala.size,
@@ -141,14 +145,19 @@ class LogValidatorTest {
   private def checkNonCompressed(magic: Byte) {
     val now = System.currentTimeMillis()
     val timestampSeq = Seq(now - 1, now + 1, now)
-    val producerId = if (magic >= RecordBatch.MAGIC_VALUE_V2) 1324L else RecordBatch.NO_PRODUCER_ID
-    val producerEpoch = if (magic >= RecordBatch.MAGIC_VALUE_V2) 10: Short else RecordBatch.NO_PRODUCER_EPOCH
-    val baseSequence = if (magic >= RecordBatch.MAGIC_VALUE_V2) 20 else RecordBatch.NO_SEQUENCE
-    val partitionLeaderEpoch = if (magic >= RecordBatch.MAGIC_VALUE_V2) 40 else RecordBatch.NO_PARTITION_LEADER_EPOCH
-    val records = MemoryRecords.withRecords(magic, 0L, CompressionType.NONE, producerId, producerEpoch, baseSequence,
-      partitionLeaderEpoch, new SimpleRecord(timestampSeq(0), "hello".getBytes),
-      new SimpleRecord(timestampSeq(1), "there".getBytes), new SimpleRecord(timestampSeq(2), "beautiful".getBytes))
+    val (producerId, producerEpoch, baseSequence, isTransactional, partitionLeaderEpoch) =
+      if (magic >= RecordBatch.MAGIC_VALUE_V2)
+        (1324L, 10.toShort, 984, true, 40)
+      else
+        (RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false,
+          RecordBatch.NO_PARTITION_LEADER_EPOCH)
+    val records = MemoryRecords.withRecords(magic, 0L, CompressionType.GZIP, TimestampType.CREATE_TIME, producerId,
+      producerEpoch, baseSequence, partitionLeaderEpoch, isTransactional,
+      new SimpleRecord(timestampSeq(0), "hello".getBytes),
+      new SimpleRecord(timestampSeq(1), "there".getBytes),
+      new SimpleRecord(timestampSeq(2), "beautiful".getBytes))
     val validatingResults = LogValidator.validateMessagesAndAssignOffsets(records,
       offsetCounter = new LongRef(0),
@@ -159,7 +168,8 @@ class LogValidatorTest {
       magic = magic,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = partitionLeaderEpoch)
+      partitionLeaderEpoch = partitionLeaderEpoch,
+      isFromClient = true)
     val validatedRecords = validatingResults.validatedRecords
     var i = 0
@@ -170,6 +180,7 @@ class LogValidatorTest {
       assertEquals(producerEpoch, batch.producerEpoch)
       assertEquals(producerId, batch.producerId)
       assertEquals(baseSequence, batch.baseSequence)
+      assertEquals(isTransactional, batch.isTransactional)
       assertEquals(partitionLeaderEpoch, batch.partitionLeaderEpoch)
       for (record <- batch.asScala) {
@@ -195,14 +206,19 @@ class LogValidatorTest {
   private def checkRecompression(magic: Byte): Unit = {
     val now = System.currentTimeMillis()
     val timestampSeq = Seq(now - 1, now + 1, now)
-    val producerId = if (magic >= RecordBatch.MAGIC_VALUE_V2) 1324L else RecordBatch.NO_PRODUCER_ID
-    val producerEpoch = if (magic >= RecordBatch.MAGIC_VALUE_V2) 10: Short else RecordBatch.NO_PRODUCER_EPOCH
-    val baseSequence = if (magic >= RecordBatch.MAGIC_VALUE_V2) 20 else RecordBatch.NO_SEQUENCE
-    val partitionLeaderEpoch = if (magic >= RecordBatch.MAGIC_VALUE_V2) 40 else RecordBatch.NO_PARTITION_LEADER_EPOCH
-    val records = MemoryRecords.withRecords(magic, 0L, CompressionType.NONE, producerId, producerEpoch, baseSequence,
-      partitionLeaderEpoch, new SimpleRecord(timestampSeq(0), "hello".getBytes),
-      new SimpleRecord(timestampSeq(1), "there".getBytes), new SimpleRecord(timestampSeq(2), "beautiful".getBytes))
+    val (producerId, producerEpoch, baseSequence, isTransactional, partitionLeaderEpoch) =
+      if (magic >= RecordBatch.MAGIC_VALUE_V2)
+        (1324L, 10.toShort, 984, true, 40)
+      else
+        (RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false,
+          RecordBatch.NO_PARTITION_LEADER_EPOCH)
+    val records = MemoryRecords.withRecords(magic, 0L, CompressionType.GZIP, TimestampType.CREATE_TIME, producerId,
+      producerEpoch, baseSequence, partitionLeaderEpoch, isTransactional,
+      new SimpleRecord(timestampSeq(0), "hello".getBytes),
+      new SimpleRecord(timestampSeq(1), "there".getBytes),
+      new SimpleRecord(timestampSeq(2), "beautiful".getBytes))
     val validatingResults = LogValidator.validateMessagesAndAssignOffsets(records,
       offsetCounter = new LongRef(0),
@@ -213,7 +229,8 @@ class LogValidatorTest {
       magic = magic,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = partitionLeaderEpoch)
+      partitionLeaderEpoch = partitionLeaderEpoch,
+      isFromClient = true)
     val validatedRecords = validatingResults.validatedRecords
     var i = 0
@@ -257,7 +274,8 @@ class LogValidatorTest {
       compactedTopic = false,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true)
     val validatedRecords = validatedResults.validatedRecords
     for (batch <- validatedRecords.batches.asScala) {
@@ -292,7 +310,8 @@ class LogValidatorTest {
       compactedTopic = false,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true)
     val validatedRecords = validatedResults.validatedRecords
     for (batch <- validatedRecords.batches.asScala) {
@@ -317,24 +336,31 @@ class LogValidatorTest {
   private def checkCompressed(magic: Byte) {
     val now = System.currentTimeMillis()
     val timestampSeq = Seq(now - 1, now + 1, now)
-    val producerId = if (magic >= RecordBatch.MAGIC_VALUE_V2) 1324L else RecordBatch.NO_PRODUCER_ID
-    val producerEpoch = if (magic >= RecordBatch.MAGIC_VALUE_V2) 10: Short else RecordBatch.NO_PRODUCER_EPOCH
-    val baseSequence = if (magic >= RecordBatch.MAGIC_VALUE_V2) 20 else RecordBatch.NO_SEQUENCE
-    val partitionLeaderEpoch = if (magic >= RecordBatch.MAGIC_VALUE_V2) 40 else RecordBatch.NO_PARTITION_LEADER_EPOCH
-    val records = MemoryRecords.withRecords(magic, 0L, CompressionType.GZIP, producerId, producerEpoch, baseSequence,
-      partitionLeaderEpoch, new SimpleRecord(timestampSeq(0), "hello".getBytes),
-      new SimpleRecord(timestampSeq(1), "there".getBytes), new SimpleRecord(timestampSeq(2), "beautiful".getBytes))
+    val (producerId, producerEpoch, baseSequence, isTransactional, partitionLeaderEpoch) =
+      if (magic >= RecordBatch.MAGIC_VALUE_V2)
+        (1324L, 10.toShort, 984, true, 40)
+      else
+        (RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false,
+          RecordBatch.NO_PARTITION_LEADER_EPOCH)
+    val records = MemoryRecords.withRecords(magic, 0L, CompressionType.GZIP, TimestampType.CREATE_TIME, producerId,
+      producerEpoch, baseSequence, partitionLeaderEpoch, isTransactional,
+      new SimpleRecord(timestampSeq(0), "hello".getBytes),
+      new SimpleRecord(timestampSeq(1), "there".getBytes),
+      new SimpleRecord(timestampSeq(2), "beautiful".getBytes))
     val validatedResults = LogValidator.validateMessagesAndAssignOffsets(records,
-        offsetCounter = new LongRef(0),
-        now = System.currentTimeMillis(),
-        sourceCodec = DefaultCompressionCodec,
-        targetCodec = DefaultCompressionCodec,
-        magic = magic,
-        compactedTopic = false,
-        timestampType = TimestampType.CREATE_TIME,
-        timestampDiffMaxMs = 1000L,
-        partitionLeaderEpoch = partitionLeaderEpoch)
+      offsetCounter = new LongRef(0),
+      now = System.currentTimeMillis(),
+      sourceCodec = DefaultCompressionCodec,
+      targetCodec = DefaultCompressionCodec,
+      magic = magic,
+      compactedTopic = false,
+      timestampType = TimestampType.CREATE_TIME,
+      timestampDiffMaxMs = 1000L,
+      partitionLeaderEpoch = partitionLeaderEpoch,
+      isFromClient = true)
     val validatedRecords = validatedResults.validatedRecords
     var i = 0
@@ -378,7 +404,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V1,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true)
   @Test(expected = classOf[InvalidTimestampException])
@@ -396,7 +423,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V2,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true)
   @Test(expected = classOf[InvalidTimestampException])
@@ -414,7 +442,8 @@ class LogValidatorTest {
       compactedTopic = false,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true)
   @Test(expected = classOf[InvalidTimestampException])
@@ -432,7 +461,8 @@ class LogValidatorTest {
       compactedTopic = false,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true)
@@ -449,7 +479,8 @@ class LogValidatorTest {
       compactedTopic = false,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -466,7 +497,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V0,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -484,7 +516,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V1,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 5000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords
     checkOffsets(messageWithOffset, offset)
@@ -503,7 +536,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V2,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 5000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords
     checkOffsets(messageWithOffset, offset)
@@ -523,7 +557,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V1,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 5000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords
     checkOffsets(compressedMessagesWithOffset, offset)
@@ -543,7 +578,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V2,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 5000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords
     checkOffsets(compressedMessagesWithOffset, offset)
@@ -561,7 +597,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V1,
       timestampType = TimestampType.LOG_APPEND_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -578,7 +615,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V2,
       timestampType = TimestampType.LOG_APPEND_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -595,7 +633,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V1,
       timestampType = TimestampType.LOG_APPEND_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -612,7 +651,48 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V2,
       timestampType = TimestampType.LOG_APPEND_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
+  }
+  @Test(expected = classOf[InvalidRecordException])
+  def testControlRecordsNotAllowedFromClients() {
+    val offset = 1234567
+    val endTxnMarker = new EndTransactionMarker(ControlRecordType.COMMIT, 0)
+    val records = MemoryRecords.withEndTransactionMarker(23423L, 5, endTxnMarker)
+    LogValidator.validateMessagesAndAssignOffsets(records,
+      offsetCounter = new LongRef(offset),
+      now = System.currentTimeMillis(),
+      sourceCodec = NoCompressionCodec,
+      targetCodec = NoCompressionCodec,
+      compactedTopic = false,
+      magic = RecordBatch.CURRENT_MAGIC_VALUE,
+      timestampType = TimestampType.CREATE_TIME,
+      timestampDiffMaxMs = 5000L,
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true)
+  }
+  @Test
+  def testControlRecordsNotCompressed() {
+    val offset = 1234567
+    val endTxnMarker = new EndTransactionMarker(ControlRecordType.COMMIT, 0)
+    val records = MemoryRecords.withEndTransactionMarker(23423L, 5, endTxnMarker)
+    val result = LogValidator.validateMessagesAndAssignOffsets(records,
+      offsetCounter = new LongRef(offset),
+      now = System.currentTimeMillis(),
+      sourceCodec = NoCompressionCodec,
+      targetCodec = SnappyCompressionCodec,
+      compactedTopic = false,
+      magic = RecordBatch.CURRENT_MAGIC_VALUE,
+      timestampType = TimestampType.CREATE_TIME,
+      timestampDiffMaxMs = 5000L,
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = false)
+    val batches = TestUtils.toList(result.validatedRecords.batches)
+    assertEquals(1, batches.size)
+    val batch = batches.get(0)
+    assertFalse(batch.isCompressed)
@@ -630,7 +710,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V0,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 5000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -648,7 +729,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V0,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 5000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -665,7 +747,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V2,
       timestampType = TimestampType.LOG_APPEND_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -682,7 +765,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V2,
       timestampType = TimestampType.LOG_APPEND_TIME,
       timestampDiffMaxMs = 1000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -700,7 +784,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V1,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 5000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -718,7 +803,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V1,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 5000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -736,7 +822,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V0,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 5000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
@@ -754,7 +841,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V0,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 5000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH).validatedRecords, offset)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true).validatedRecords, offset)
   @Test(expected = classOf[InvalidRecordException])
@@ -770,7 +858,8 @@ class LogValidatorTest {
       magic = RecordBatch.MAGIC_VALUE_V1,
       timestampType = TimestampType.CREATE_TIME,
       timestampDiffMaxMs = 5000L,
-      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH)
+      partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      isFromClient = true)
   private def createRecords(magicValue: Byte = RecordBatch.CURRENT_MAGIC_VALUE,
diff --git a/core/src/test/scala/unit/kafka/log/OffsetIndexTest.scala b/core/src/test/scala/unit/kafka/log/OffsetIndexTest.scala
index 7618cf7..506d99c 100644
--- a/core/src/test/scala/unit/kafka/log/OffsetIndexTest.scala
+++ b/core/src/test/scala/unit/kafka/log/OffsetIndexTest.scala
@@ -95,7 +95,29 @@ class OffsetIndexTest extends JUnitSuite {
     idx.append(51, 0)
     idx.append(50, 1)
+  @Test
+  def testFetchUpperBoundOffset() {
+    val first = OffsetPosition(0, 0)
+    val second = OffsetPosition(1, 10)
+    val third = OffsetPosition(2, 23)
+    val fourth = OffsetPosition(3, 37)
+    assertEquals(None, idx.fetchUpperBoundOffset(first, 5))
+    for (offsetPosition <- Seq(first, second, third, fourth))
+      idx.append(offsetPosition.offset, offsetPosition.position)
+    assertEquals(Some(second), idx.fetchUpperBoundOffset(first, 5))
+    assertEquals(Some(second), idx.fetchUpperBoundOffset(first, 10))
+    assertEquals(Some(third), idx.fetchUpperBoundOffset(first, 23))
+    assertEquals(Some(third), idx.fetchUpperBoundOffset(first, 22))
+    assertEquals(Some(fourth), idx.fetchUpperBoundOffset(second, 24))
+    assertEquals(None, idx.fetchUpperBoundOffset(fourth, 1))
+    assertEquals(None, idx.fetchUpperBoundOffset(first, 200))
+    assertEquals(None, idx.fetchUpperBoundOffset(second, 200))
+  }
   def testReopen() {
     val first = OffsetPosition(51, 0)
diff --git a/core/src/test/scala/unit/kafka/log/ProducerIdMappingTest.scala b/core/src/test/scala/unit/kafka/log/ProducerIdMappingTest.scala
deleted file mode 100644
index 1bf983c..0000000
--- a/core/src/test/scala/unit/kafka/log/ProducerIdMappingTest.scala
+++ /dev/null
@@ -1,291 +0,0 @@
-/**
-  * Licensed to the Apache Software Foundation (ASF) under one or more
-  * contributor license agreements.  See the NOTICE file distributed with
-  * this work for additional information regarding copyright ownership.
-  * The ASF licenses this file to You under the Apache License, Version 2.0
-  * (the "License"); you may not use this file except in compliance with
-  * the License.  You may obtain a copy of the License at
-  *
-  *    http://www.apache.org/licenses/LICENSE-2.0
-  *
-  * Unless required by applicable law or agreed to in writing, software
-  * distributed under the License is distributed on an "AS IS" BASIS,
-  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  * See the License for the specific language governing permissions and
-  * limitations under the License.
-  */
-package kafka.log
-import java.util.Properties
-import kafka.utils.TestUtils
-import org.apache.kafka.common.TopicPartition
-import org.apache.kafka.common.errors.{DuplicateSequenceNumberException, OutOfOrderSequenceException, ProducerFencedException}
-import org.apache.kafka.common.utils.{MockTime, Utils}
-import org.junit.Assert._
-import org.junit.{After, Before, Test}
-import org.scalatest.junit.JUnitSuite
-class ProducerIdMappingTest extends JUnitSuite {
-  var idMappingDir: File = null
-  var config: LogConfig = null
-  var idMapping: ProducerIdMapping = null
-  val partition = new TopicPartition("test", 0)
-  val pid = 1L
-  val maxPidExpirationMs = 60 * 1000
-  val time = new MockTime
-  @Before
-  def setUp(): Unit = {
-    config = LogConfig(new Properties)
-    idMappingDir = TestUtils.tempDir()
-    idMapping = new ProducerIdMapping(config, partition, idMappingDir, maxPidExpirationMs)
-  }
-  @After
-  def tearDown(): Unit = {
-    Utils.delete(idMappingDir)
-  }
-  @Test
-  def testBasicIdMapping(): Unit = {
-    val epoch = 0.toShort
-    // First entry for id 0 added
-    checkAndUpdate(idMapping, pid, 0, epoch, 0L, 0L)
-    // Second entry for id 0 added
-    checkAndUpdate(idMapping, pid, 1, epoch, 0L, 1L)
-    // Duplicate sequence number (matches previous sequence number)
-    assertThrows[DuplicateSequenceNumberException] {
-      checkAndUpdate(idMapping, pid, 1, epoch, 0L, 1L)
-    }
-    // Invalid sequence number (greater than next expected sequence number)
-    assertThrows[OutOfOrderSequenceException] {
-      checkAndUpdate(idMapping, pid, 5, epoch, 0L, 2L)
-    }
-    // Change epoch
-    checkAndUpdate(idMapping, pid, 0, (epoch + 1).toShort, 0L, 3L)
-    // Incorrect epoch
-    assertThrows[ProducerFencedException] {
-      checkAndUpdate(idMapping, pid, 0, epoch, 0L, 4L)
-    }
-  }
-  @Test
-  def testTakeSnapshot(): Unit = {
-    val epoch = 0.toShort
-    checkAndUpdate(idMapping, pid, 0, epoch, 0L, 0L)
-    checkAndUpdate(idMapping, pid, 1, epoch, 1L, 1L)
-    // Take snapshot
-    idMapping.maybeTakeSnapshot()
-    // Check that file exists and it is not empty
-    assertEquals("Directory doesn't contain a single file as expected", 1, idMappingDir.list().length)
-    assertTrue("Snapshot file is empty", idMappingDir.list().head.length > 0)
-  }
-  @Test
-  def testRecoverFromSnapshot(): Unit = {
-    val epoch = 0.toShort
-    checkAndUpdate(idMapping, pid, 0, epoch, 0L, time.milliseconds)
-    checkAndUpdate(idMapping, pid, 1, epoch, 1L, time.milliseconds)
-    idMapping.maybeTakeSnapshot()
-    val recoveredMapping = new ProducerIdMapping(config, partition, idMappingDir, maxPidExpirationMs)
-    recoveredMapping.truncateAndReload(0L, 3L, time.milliseconds)
-    // entry added after recovery
-    checkAndUpdate(recoveredMapping, pid, 2, epoch, 2L, time.milliseconds)
-  }
-  @Test(expected = classOf[OutOfOrderSequenceException])
-  def testRemoveExpiredPidsOnReload(): Unit = {
-    val epoch = 0.toShort
-    checkAndUpdate(idMapping, pid, 0, epoch, 0L, 0)
-    checkAndUpdate(idMapping, pid, 1, epoch, 1L, 1)
-    idMapping.maybeTakeSnapshot()
-    val recoveredMapping = new ProducerIdMapping(config, partition, idMappingDir, maxPidExpirationMs)
-    recoveredMapping.truncateAndReload(0L, 1L, 70000)
-    // entry added after recovery. The pid should be expired now, and would not exist in the pid mapping. Hence
-    // we should get an out of order sequence exception.
-    checkAndUpdate(recoveredMapping, pid, 2, epoch, 2L, 70001)
-  }
-  @Test
-  def testRemoveOldSnapshot(): Unit = {
-    val epoch = 0.toShort
-    checkAndUpdate(idMapping, pid, 0, epoch, 0L)
-    checkAndUpdate(idMapping, pid, 1, epoch, 1L)
-    idMapping.maybeTakeSnapshot()
-    assertEquals(1, idMappingDir.listFiles().length)
-    assertEquals(Set(2), currentSnapshotOffsets)
-    checkAndUpdate(idMapping, pid, 2, epoch, 2L)
-    idMapping.maybeTakeSnapshot()
-    assertEquals(2, idMappingDir.listFiles().length)
-    assertEquals(Set(2, 3), currentSnapshotOffsets)
-    // we only retain two snapshot files, so the next snapshot should cause the oldest to be deleted
-    checkAndUpdate(idMapping, pid, 3, epoch, 3L)
-    idMapping.maybeTakeSnapshot()
-    assertEquals(2, idMappingDir.listFiles().length)
-    assertEquals(Set(3, 4), currentSnapshotOffsets)
-  }
-  @Test
-  def testTruncate(): Unit = {
-    val epoch = 0.toShort
-    checkAndUpdate(idMapping, pid, 0, epoch, 0L)
-    checkAndUpdate(idMapping, pid, 1, epoch, 1L)
-    idMapping.maybeTakeSnapshot()
-    assertEquals(1, idMappingDir.listFiles().length)
-    assertEquals(Set(2), currentSnapshotOffsets)
-    checkAndUpdate(idMapping, pid, 2, epoch, 2L)
-    idMapping.maybeTakeSnapshot()
-    assertEquals(2, idMappingDir.listFiles().length)
-    assertEquals(Set(2, 3), currentSnapshotOffsets)
-    idMapping.truncate()
-    assertEquals(0, idMappingDir.listFiles().length)
-    assertEquals(Set(), currentSnapshotOffsets)
-    checkAndUpdate(idMapping, pid, 0, epoch, 0L)
-    idMapping.maybeTakeSnapshot()
-    assertEquals(1, idMappingDir.listFiles().length)
-    assertEquals(Set(1), currentSnapshotOffsets)
-  }
-  @Test
-  def testExpirePids(): Unit = {
-    val epoch = 0.toShort
-    checkAndUpdate(idMapping, pid, 0, epoch, 0L)
-    checkAndUpdate(idMapping, pid, 1, epoch, 1L)
-    idMapping.maybeTakeSnapshot()
-    val anotherPid = 2L
-    checkAndUpdate(idMapping, anotherPid, 0, epoch, 2L)
-    checkAndUpdate(idMapping, anotherPid, 1, epoch, 3L)
-    idMapping.maybeTakeSnapshot()
-    assertEquals(Set(2, 4), currentSnapshotOffsets)
-    idMapping.expirePids(2)
-    assertEquals(Set(4), currentSnapshotOffsets)
-    assertEquals(Set(anotherPid), idMapping.activePids.keySet)
-    assertEquals(None, idMapping.lastEntry(pid))
-    val maybeEntry = idMapping.lastEntry(anotherPid)
-    assertTrue(maybeEntry.isDefined)
-    assertEquals(3L, maybeEntry.get.lastOffset)
-    idMapping.expirePids(3)
-    assertEquals(Set(anotherPid), idMapping.activePids.keySet)
-    assertEquals(Set(4), currentSnapshotOffsets)
-    assertEquals(4, idMapping.mapEndOffset)
-    idMapping.expirePids(5)
-    assertEquals(Set(), idMapping.activePids.keySet)
-    assertEquals(Set(), currentSnapshotOffsets)
-    assertEquals(5, idMapping.mapEndOffset)
-    idMapping.maybeTakeSnapshot()
-    // shouldn't be any new snapshot because the log is empty
-    assertEquals(Set(), currentSnapshotOffsets)
-  }
-  @Test
-  def testSkipSnapshotIfOffsetUnchanged(): Unit = {
-    val epoch = 0.toShort
-    checkAndUpdate(idMapping, pid, 0, epoch, 0L, 0L)
-    idMapping.maybeTakeSnapshot()
-    assertEquals(1, idMappingDir.listFiles().length)
-    assertEquals(Set(1), currentSnapshotOffsets)
-    // nothing changed so there should be no new snapshot
-    idMapping.maybeTakeSnapshot()
-    assertEquals(1, idMappingDir.listFiles().length)
-    assertEquals(Set(1), currentSnapshotOffsets)
-  }
-  @Test
-  def testStartOffset(): Unit = {
-    val epoch = 0.toShort
-    val pid2 = 2L
-    checkAndUpdate(idMapping, pid2, 0, epoch, 0L, 1L)
-    checkAndUpdate(idMapping, pid, 0, epoch, 1L, 2L)
-    checkAndUpdate(idMapping, pid, 1, epoch, 2L, 3L)
-    checkAndUpdate(idMapping, pid, 2, epoch, 3L, 4L)
-    idMapping.maybeTakeSnapshot()
-    intercept[OutOfOrderSequenceException] {
-      val recoveredMapping = new ProducerIdMapping(config, partition, idMappingDir, maxPidExpirationMs)
-      recoveredMapping.truncateAndReload(0L, 1L, time.milliseconds)
-      checkAndUpdate(recoveredMapping, pid2, 1, epoch, 4L, 5L)
-    }
-  }
-  @Test(expected = classOf[OutOfOrderSequenceException])
-  def testPidExpirationTimeout() {
-    val epoch = 5.toShort
-    val sequence = 37
-    checkAndUpdate(idMapping, pid, sequence, epoch, 1L)
-    time.sleep(maxPidExpirationMs + 1)
-    idMapping.removeExpiredPids(time.milliseconds)
-    checkAndUpdate(idMapping, pid, sequence + 1, epoch, 1L)
-  }
-  @Test
-  def testLoadPid() {
-    val epoch = 5.toShort
-    val sequence = 37
-    val createTimeMs = time.milliseconds
-    idMapping.load(pid, ProducerIdEntry(epoch, sequence, 0L, 1, createTimeMs), time.milliseconds)
-    checkAndUpdate(idMapping, pid, sequence + 1, epoch, 2L)
-  }
-  @Test(expected = classOf[OutOfOrderSequenceException])
-  def testLoadIgnoresExpiredPids() {
-    val epoch = 5.toShort
-    val sequence = 37
-    val createTimeMs = time.milliseconds
-    time.sleep(maxPidExpirationMs + 1)
-    val loadTimeMs = time.milliseconds
-    idMapping.load(pid, ProducerIdEntry(epoch, sequence, 0L, 1, createTimeMs), loadTimeMs)
-    // entry wasn't loaded, so this should fail
-    checkAndUpdate(idMapping, pid, sequence + 1, epoch, 2L)
-  }
-  private def checkAndUpdate(mapping: ProducerIdMapping,
-                             pid: Long,
-                             seq: Int,
-                             epoch: Short,
-                             lastOffset: Long,
-                             timestamp: Long = time.milliseconds()): Unit = {
-    val offsetDelta = 0
-    val incomingPidEntry = ProducerIdEntry(epoch, seq, lastOffset, offsetDelta, timestamp)
-    val producerAppendInfo = new ProducerAppendInfo(pid, mapping.lastEntry(pid).getOrElse(ProducerIdEntry.Empty))
-    producerAppendInfo.append(incomingPidEntry)
-    mapping.update(producerAppendInfo)
-    mapping.updateMapEndOffset(lastOffset + 1)
-  }
-  private def currentSnapshotOffsets =
-    idMappingDir.listFiles().map(file => Log.offsetFromFilename(file.getName)).toSet