Posted to commits@kafka.apache.org by jg...@apache.org on 2017/05/12 19:32:53 UTC

[2/2] kafka git commit: KAFKA-5196; Make LogCleaner transaction-aware

KAFKA-5196; Make LogCleaner transaction-aware

Author: Jason Gustafson <ja...@confluent.io>

Reviewers: Jun Rao <ju...@gmail.com>

Closes #3008 from hachikuji/KAFKA-5196


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/7baa58d7
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/7baa58d7
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/7baa58d7

Branch: refs/heads/trunk
Commit: 7baa58d797126b6fb2b1de30e72428895d2bcb40
Parents: 1c2bbaa
Author: Jason Gustafson <ja...@confluent.io>
Authored: Fri May 12 12:07:22 2017 -0700
Committer: Jason Gustafson <ja...@confluent.io>
Committed: Fri May 12 12:07:22 2017 -0700

----------------------------------------------------------------------
 .../clients/consumer/internals/Fetcher.java     |  38 ++-
 .../kafka/common/record/MemoryRecords.java      |  51 +++-
 .../clients/consumer/internals/FetcherTest.java | 189 ++++++++++---
 .../common/record/MemoryRecordsBuilderTest.java |   4 +-
 .../kafka/common/record/MemoryRecordsTest.java  |  57 +++-
 core/src/main/scala/kafka/log/Log.scala         |  30 +-
 core/src/main/scala/kafka/log/LogCleaner.scala  | 257 +++++++++++++----
 .../scala/kafka/log/LogCleanerManager.scala     |  29 +-
 core/src/main/scala/kafka/log/OffsetMap.scala   |   5 +
 .../scala/kafka/log/ProducerStateManager.scala  |   2 +-
 .../main/scala/kafka/log/TransactionIndex.scala |  13 +-
 .../unit/kafka/log/LogCleanerManagerTest.scala  |  48 ++++
 .../scala/unit/kafka/log/LogCleanerTest.scala   | 280 ++++++++++++++++---
 .../scala/unit/kafka/log/LogSegmentTest.scala   |  26 +-
 .../src/test/scala/unit/kafka/log/LogTest.scala |  21 +-
 .../unit/kafka/log/TransactionIndexTest.scala   |  11 +-
 16 files changed, 840 insertions(+), 221 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index 66221c0..e4365da 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -1015,6 +1015,11 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
                     maybeEnsureValid(currentBatch);
 
                     if (isolationLevel == IsolationLevel.READ_COMMITTED && currentBatch.hasProducerId()) {
+                        // remove from the aborted transaction queue all aborted transactions which have begun
+                        // before the current batch's last offset and add the associated producerIds to the
+                        // aborted producer set
+                        consumeAbortedTransactionsUpTo(currentBatch.lastOffset());
+
                         long producerId = currentBatch.producerId();
                         if (containsAbortMarker(currentBatch)) {
                             abortedProducerIds.remove(producerId);
@@ -1072,29 +1077,18 @@ public class Fetcher<K, V> implements SubscriptionState.Listener, Closeable {
             return records;
         }
 
-        private boolean isBatchAborted(RecordBatch batch) {
-            /* When in READ_COMMITTED mode, we need to do the following for each incoming entry:
-            *   0. Check whether the pid is in the 'abortedProducerIds' set && the entry does not include an abort marker.
-            *      If so, skip the entry.
-            *   1. If the pid is in aborted pids and the entry contains an abort marker, remove the pid from
-            *      aborted pids and skip the entry.
-            *   2. Check lowest offset entry in the abort index. If the PID of the current entry matches the
-            *      pid of the abort index entry, and the incoming offset is no smaller than the abort index offset,
-            *      this means that the entry has been aborted. Add the pid to the aborted pids set, and remove
-            *      the entry from the abort index.
-            */
-            long producerId = batch.producerId();
-            if (abortedProducerIds.contains(producerId)) {
-                return true;
-            } else if (abortedTransactions != null && !abortedTransactions.isEmpty()) {
-                FetchResponse.AbortedTransaction nextAbortedTransaction = abortedTransactions.peek();
-                if (nextAbortedTransaction.producerId == producerId && nextAbortedTransaction.firstOffset <= batch.baseOffset()) {
-                    abortedProducerIds.add(producerId);
-                    abortedTransactions.poll();
-                    return true;
-                }
+        private void consumeAbortedTransactionsUpTo(long offset) {
+            if (abortedTransactions == null)
+                return;
+
+            while (!abortedTransactions.isEmpty() && abortedTransactions.peek().firstOffset <= offset) {
+                FetchResponse.AbortedTransaction abortedTransaction = abortedTransactions.poll();
+                abortedProducerIds.add(abortedTransaction.producerId);
             }
-            return false;
+        }
+
+        private boolean isBatchAborted(RecordBatch batch) {
+            return batch.isTransactional() && abortedProducerIds.contains(batch.producerId());
         }
 
         private PriorityQueue<FetchResponse.AbortedTransaction> abortedTransactions(FetchResponse.PartitionData partition) {
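
A note on the consumer-side change above: in READ_COMMITTED the fetcher now handles each batch in two steps. It first drains the aborted-transaction queue up to the batch's last offset into a set of aborted producer ids, and then drops the batch only if it is transactional and its producer id is in that set. A minimal standalone Scala sketch of the same bookkeeping (AbortedTx and Batch are illustrative stand-ins, not the Kafka classes):

    import scala.collection.mutable

    // Illustrative stand-ins for FetchResponse.AbortedTransaction and RecordBatch
    case class AbortedTx(producerId: Long, firstOffset: Long)
    case class Batch(producerId: Long, isTransactional: Boolean, lastOffset: Long)

    class ReadCommittedFilter(aborted: Seq[AbortedTx]) {
      // aborted transactions ordered by the offset at which each one begins
      private val abortedTransactions = mutable.Queue(aborted.sortBy(_.firstOffset): _*)
      private val abortedProducerIds = mutable.Set.empty[Long]

      // mirrors consumeAbortedTransactionsUpTo: any aborted transaction starting at or
      // before `offset` is now an open aborted transaction for its producer
      private def consumeAbortedTransactionsUpTo(offset: Long): Unit =
        while (abortedTransactions.headOption.exists(_.firstOffset <= offset))
          abortedProducerIds += abortedTransactions.dequeue().producerId

      // mirrors isBatchAborted: only transactional batches from aborted producers are dropped
      def shouldSkip(batch: Batch): Boolean = {
        consumeAbortedTransactionsUpTo(batch.lastOffset)
        batch.isTransactional && abortedProducerIds.contains(batch.producerId)
      }

      // an abort marker closes that producer's aborted transaction
      def onAbortMarker(producerId: Long): Unit = abortedProducerIds -= producerId
    }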

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
index c8754c7..a222cc3 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
@@ -131,6 +131,9 @@ public class MemoryRecords extends AbstractRecords {
         for (MutableRecordBatch batch : batches) {
             bytesRead += batch.sizeInBytes();
 
+            if (filter.shouldDiscard(batch))
+                continue;
+
             // We use the absolute offset to decide whether to retain the message or not. Due to KAFKA-4298, we have to
             // allow for the possibility that a previous version corrupted the log by writing a compressed record batch
             // with a magic value not matching the magic of the records (magic < 2). This will be fixed as we
@@ -251,8 +254,21 @@ public class MemoryRecords extends AbstractRecords {
         return buffer.hashCode();
     }
 
-    public interface RecordFilter {
-        boolean shouldRetain(RecordBatch recordBatch, Record record);
+    public static abstract class RecordFilter {
+        /**
+         * Check whether the full batch can be discarded (i.e. whether we even need to
+         * check the records individually).
+         */
+        protected boolean shouldDiscard(RecordBatch batch) {
+            return false;
+        }
+
+        /**
+         * Check whether a record should be retained in the log. Only records from
+         * batches which were not discarded with {@link #shouldDiscard(RecordBatch)}
+         * will be considered.
+         */
+        protected abstract boolean shouldRetain(RecordBatch recordBatch, Record record);
     }
 
     public static class FilterResult {
@@ -432,9 +448,10 @@ public class MemoryRecords extends AbstractRecords {
     }
 
     public static MemoryRecords withTransactionalRecords(long initialOffset, CompressionType compressionType, long producerId,
-                                                         short producerEpoch, int baseSequence, SimpleRecord... records) {
+                                                         short producerEpoch, int baseSequence, int partitionLeaderEpoch,
+                                                         SimpleRecord... records) {
         return withRecords(RecordBatch.CURRENT_MAGIC_VALUE, initialOffset, compressionType, TimestampType.CREATE_TIME,
-                producerId, producerEpoch, baseSequence, RecordBatch.NO_PARTITION_LEADER_EPOCH, true, records);
+                producerId, producerEpoch, baseSequence, partitionLeaderEpoch, true, records);
     }
 
     public static MemoryRecords withRecords(byte magic, long initialOffset, CompressionType compressionType,
@@ -464,28 +481,38 @@ public class MemoryRecords extends AbstractRecords {
     }
 
     public static MemoryRecords withEndTransactionMarker(long producerId, short producerEpoch, EndTransactionMarker marker) {
-        return withEndTransactionMarker(0L, producerId, producerEpoch, marker);
+        return withEndTransactionMarker(0L, System.currentTimeMillis(), RecordBatch.NO_PARTITION_LEADER_EPOCH,
+                producerId, producerEpoch, marker);
+    }
+
+    public static MemoryRecords withEndTransactionMarker(long timestamp, long producerId, short producerEpoch,
+                                                         EndTransactionMarker marker) {
+        return withEndTransactionMarker(0L, timestamp, RecordBatch.NO_PARTITION_LEADER_EPOCH, producerId,
+                producerEpoch, marker);
     }
 
-    public static MemoryRecords withEndTransactionMarker(long initialOffset, long producerId, short producerEpoch,
+    public static MemoryRecords withEndTransactionMarker(long initialOffset, long timestamp, int partitionLeaderEpoch,
+                                                         long producerId, short producerEpoch,
                                                          EndTransactionMarker marker) {
         int endTxnMarkerBatchSize = DefaultRecordBatch.RECORD_BATCH_OVERHEAD +
                 EndTransactionMarker.CURRENT_END_TXN_SCHEMA_RECORD_SIZE;
         ByteBuffer buffer = ByteBuffer.allocate(endTxnMarkerBatchSize);
-        writeEndTransactionalMarker(buffer, initialOffset, producerId, producerEpoch, marker);
+        writeEndTransactionalMarker(buffer, initialOffset, timestamp, partitionLeaderEpoch, producerId,
+                producerEpoch, marker);
         buffer.flip();
         return MemoryRecords.readableRecords(buffer);
     }
 
-    public static void writeEndTransactionalMarker(ByteBuffer buffer, long initialOffset, long producerId,
-                                                   short producerEpoch, EndTransactionMarker marker) {
+    public static void writeEndTransactionalMarker(ByteBuffer buffer, long initialOffset, long timestamp,
+                                                   int partitionLeaderEpoch, long producerId, short producerEpoch,
+                                                   EndTransactionMarker marker) {
         boolean isTransactional = true;
         boolean isControlBatch = true;
         MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
-                TimestampType.CREATE_TIME, initialOffset, RecordBatch.NO_TIMESTAMP, producerId, producerEpoch,
-                RecordBatch.NO_SEQUENCE, isTransactional, isControlBatch, RecordBatch.NO_PARTITION_LEADER_EPOCH,
+                TimestampType.CREATE_TIME, initialOffset, timestamp, producerId, producerEpoch,
+                RecordBatch.NO_SEQUENCE, isTransactional, isControlBatch, partitionLeaderEpoch,
                 buffer.capacity());
-        builder.appendEndTxnMarker(System.currentTimeMillis(), marker);
+        builder.appendEndTxnMarker(timestamp, marker);
         builder.close();
     }
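
RecordFilter changes from an interface to an abstract class so that a caller can veto a whole batch up front (shouldDiscard) before any per-record checks (shouldRetain) run. Usage follows the pattern in the new MemoryRecordsTest case; below is a rough Scala sketch against the clients API as of this commit, where the offset cutoff is only an illustrative predicate:

    import java.nio.ByteBuffer
    import org.apache.kafka.common.record.{MemoryRecords, Record, RecordBatch}

    object RecordFilterSketch {
      // Copy `input` into `output`, skipping whole batches that end below an illustrative
      // offset cutoff without inspecting their records, and keeping only keyed records
      // from the batches that survive.
      def filterBelow(cutoff: Long, input: MemoryRecords, output: ByteBuffer): Unit = {
        val filter = new MemoryRecords.RecordFilter {
          override def shouldDiscard(batch: RecordBatch): Boolean =
            batch.lastOffset < cutoff          // whole batch dropped up front

          override def shouldRetain(batch: RecordBatch, record: Record): Boolean =
            record.hasKey                      // only reached for surviving batches
        }
        input.filterTo(filter, output)
      }
    }

This is the hook the cleaner uses below: shouldDiscard drops batches from aborted transactions and obsolete markers wholesale, while shouldRetain keeps the usual per-record key/offset/tombstone checks.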
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index 743568d..4e46d57 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -53,6 +53,7 @@ import org.apache.kafka.common.record.Record;
 import org.apache.kafka.common.record.RecordBatch;
 import org.apache.kafka.common.record.SimpleRecord;
 import org.apache.kafka.common.requests.IsolationLevel;
+import org.apache.kafka.common.serialization.StringDeserializer;
 import org.apache.kafka.common.utils.ByteBufferOutputStream;
 import org.apache.kafka.common.record.TimestampType;
 import org.apache.kafka.common.requests.AbstractRequest;
@@ -184,17 +185,18 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
 
         long producerId = 1;
-        short epoch = 0;
+        short producerEpoch = 0;
         int baseSequence = 0;
+        int partitionLeaderEpoch = 0;
 
         ByteBuffer buffer = ByteBuffer.allocate(1024);
         MemoryRecordsBuilder builder = MemoryRecords.idempotentBuilder(buffer, CompressionType.NONE, 0L, producerId,
-                epoch, baseSequence);
+                producerEpoch, baseSequence);
         builder.append(0L, "key".getBytes(), null);
         builder.close();
 
-        MemoryRecords.writeEndTransactionalMarker(buffer, 1L, producerId, epoch, new EndTransactionMarker(ControlRecordType.ABORT, 0)
-        );
+        MemoryRecords.writeEndTransactionalMarker(buffer, 1L, time.milliseconds(), partitionLeaderEpoch, producerId, producerEpoch,
+                new EndTransactionMarker(ControlRecordType.ABORT, 0));
 
         buffer.flip();
 
@@ -1326,51 +1328,54 @@ public class FetcherTest {
     }
 
     @Test
-    public void testWithCommittedAndAbortedTransactions() {
+    public void testReadCommittedWithCommittedAndAbortedTransactions() {
         Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(), new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
         ByteBuffer buffer = ByteBuffer.allocate(1024);
 
         List<FetchResponse.AbortedTransaction> abortedTransactions = new ArrayList<>();
 
-        int currOffset = 0;
+        long pid1 = 1L;
+        long pid2 = 2L;
+
         // Appends for producer 1 (eventually committed)
-        currOffset += appendTransactionalRecords(buffer, 1L, currOffset,
-                new SimpleRecord(time.milliseconds(), "commit1-1".getBytes(), "value".getBytes()),
-                new SimpleRecord(time.milliseconds(), "commit1-2".getBytes(), "value".getBytes()));
+        appendTransactionalRecords(buffer, pid1, 0L,
+                new SimpleRecord("commit1-1".getBytes(), "value".getBytes()),
+                new SimpleRecord("commit1-2".getBytes(), "value".getBytes()));
 
         // Appends for producer 2 (eventually aborted)
-        currOffset += appendTransactionalRecords(buffer, 2L, currOffset,
-                new SimpleRecord(time.milliseconds(), "abort2-1".getBytes(), "value".getBytes()));
+        appendTransactionalRecords(buffer, pid2, 2L,
+                new SimpleRecord("abort2-1".getBytes(), "value".getBytes()));
 
         // commit producer 1
-        currOffset += commitTransaction(buffer, 1L, currOffset);
+        commitTransaction(buffer, pid1, 3L);
+
         // append more for producer 2 (eventually aborted)
-        currOffset += appendTransactionalRecords(buffer, 2L, currOffset,
-                new SimpleRecord(time.milliseconds(), "abort2-2".getBytes(), "value".getBytes()));
+        appendTransactionalRecords(buffer, pid2, 4L,
+                new SimpleRecord("abort2-2".getBytes(), "value".getBytes()));
 
         // abort producer 2
-        currOffset += abortTransaction(buffer, 2L, currOffset);
-        abortedTransactions.add(new FetchResponse.AbortedTransaction(2, 2));
+        abortTransaction(buffer, pid2, 5L);
+        abortedTransactions.add(new FetchResponse.AbortedTransaction(pid2, 2L));
 
         // New transaction for producer 1 (eventually aborted)
-        currOffset += appendTransactionalRecords(buffer, 1L, currOffset,
-                new SimpleRecord(time.milliseconds(), "abort1-1".getBytes(), "value".getBytes()));
+        appendTransactionalRecords(buffer, pid1, 6L,
+                new SimpleRecord("abort1-1".getBytes(), "value".getBytes()));
 
         // New transaction for producer 2 (eventually committed)
-        currOffset += appendTransactionalRecords(buffer, 2L, currOffset,
-                new SimpleRecord(time.milliseconds(), "commit2-1".getBytes(), "value".getBytes()));
+        appendTransactionalRecords(buffer, pid2, 7L,
+                new SimpleRecord("commit2-1".getBytes(), "value".getBytes()));
 
         // Add messages for producer 1 (eventually aborted)
-        currOffset += appendTransactionalRecords(buffer, 1L, currOffset,
-                new SimpleRecord(time.milliseconds(), "abort1-2".getBytes(), "value".getBytes()));
+        appendTransactionalRecords(buffer, pid1, 8L,
+                new SimpleRecord("abort1-2".getBytes(), "value".getBytes()));
 
         // abort producer 1
-        currOffset += abortTransaction(buffer, 1L, currOffset);
+        abortTransaction(buffer, pid1, 9L);
         abortedTransactions.add(new FetchResponse.AbortedTransaction(1, 6));
 
         // commit producer 2
-        currOffset += commitTransaction(buffer, 2L, currOffset);
+        commitTransaction(buffer, pid2, 10L);
 
         buffer.flip();
 
@@ -1416,7 +1421,7 @@ public class FetcherTest {
         currentOffset += appendTransactionalRecords(buffer, 1L, currentOffset,
                 new SimpleRecord(time.milliseconds(), "commit1-1".getBytes(), "value".getBytes()),
                 new SimpleRecord(time.milliseconds(), "commit1-2".getBytes(), "value".getBytes()));
-        currentOffset += commitTransaction(buffer, 1L, currentOffset);
+        commitTransaction(buffer, 1L, currentOffset);
         buffer.flip();
 
         List<FetchResponse.AbortedTransaction> abortedTransactions = new ArrayList<>();
@@ -1447,6 +1452,108 @@ public class FetcherTest {
     }
 
     @Test
+    public void testReadCommittedAbortMarkerWithNoData() {
+        Fetcher<String, String> fetcher = createFetcher(subscriptions, new Metrics(), new StringDeserializer(),
+                new StringDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
+        ByteBuffer buffer = ByteBuffer.allocate(1024);
+
+        long producerId = 1L;
+
+        abortTransaction(buffer, producerId, 5L);
+
+        appendTransactionalRecords(buffer, producerId, 6L,
+                new SimpleRecord("6".getBytes(), null),
+                new SimpleRecord("7".getBytes(), null),
+                new SimpleRecord("8".getBytes(), null));
+
+        commitTransaction(buffer, producerId, 9L);
+
+        buffer.flip();
+
+        // send the fetch
+        subscriptions.assignFromUser(singleton(tp1));
+        subscriptions.seek(tp1, 0);
+        assertEquals(1, fetcher.sendFetches());
+
+        // prepare the response. the aborted transactions begin at offsets which are no longer in the log
+        List<FetchResponse.AbortedTransaction> abortedTransactions = new ArrayList<>();
+        abortedTransactions.add(new FetchResponse.AbortedTransaction(producerId, 0L));
+
+        client.prepareResponse(fetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer),
+                abortedTransactions, Errors.NONE, 100L, 100L, 0));
+        consumerClient.poll(0);
+        assertTrue(fetcher.hasCompletedFetches());
+
+        Map<TopicPartition, List<ConsumerRecord<String, String>>> allFetchedRecords = fetcher.fetchedRecords();
+        assertTrue(allFetchedRecords.containsKey(tp1));
+        List<ConsumerRecord<String, String>> fetchedRecords = allFetchedRecords.get(tp1);
+        assertEquals(3, fetchedRecords.size());
+        assertEquals(Arrays.asList(6L, 7L, 8L), collectRecordOffsets(fetchedRecords));
+    }
+
+    @Test
+    public void testReadCommittedWithCompactedTopic() {
+        Fetcher<String, String> fetcher = createFetcher(subscriptions, new Metrics(), new StringDeserializer(),
+                new StringDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED);
+        ByteBuffer buffer = ByteBuffer.allocate(1024);
+
+        long pid1 = 1L;
+        long pid2 = 2L;
+        long pid3 = 3L;
+
+        appendTransactionalRecords(buffer, pid3, 3L,
+                new SimpleRecord("3".getBytes(), "value".getBytes()),
+                new SimpleRecord("4".getBytes(), "value".getBytes()));
+
+        appendTransactionalRecords(buffer, pid2, 15L,
+                new SimpleRecord("15".getBytes(), "value".getBytes()),
+                new SimpleRecord("16".getBytes(), "value".getBytes()),
+                new SimpleRecord("17".getBytes(), "value".getBytes()));
+
+        appendTransactionalRecords(buffer, pid1, 22L,
+                new SimpleRecord("22".getBytes(), "value".getBytes()),
+                new SimpleRecord("23".getBytes(), "value".getBytes()));
+
+        abortTransaction(buffer, pid2, 28L);
+
+        appendTransactionalRecords(buffer, pid3, 30L,
+                new SimpleRecord("30".getBytes(), "value".getBytes()),
+                new SimpleRecord("31".getBytes(), "value".getBytes()),
+                new SimpleRecord("32".getBytes(), "value".getBytes()));
+
+        commitTransaction(buffer, pid3, 35L);
+
+        appendTransactionalRecords(buffer, pid1, 39L,
+                new SimpleRecord("39".getBytes(), "value".getBytes()),
+                new SimpleRecord("40".getBytes(), "value".getBytes()));
+
+        // transaction from pid1 is aborted, but the marker is not included in the fetch
+
+        buffer.flip();
+
+        // send the fetch
+        subscriptions.assignFromUser(singleton(tp1));
+        subscriptions.seek(tp1, 0);
+        assertEquals(1, fetcher.sendFetches());
+
+        // prepare the response. the aborted transactions begin at offsets which are no longer in the log
+        List<FetchResponse.AbortedTransaction> abortedTransactions = new ArrayList<>();
+        abortedTransactions.add(new FetchResponse.AbortedTransaction(pid2, 6L));
+        abortedTransactions.add(new FetchResponse.AbortedTransaction(pid1, 0L));
+
+        client.prepareResponse(fetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer),
+                abortedTransactions, Errors.NONE, 100L, 100L, 0));
+        consumerClient.poll(0);
+        assertTrue(fetcher.hasCompletedFetches());
+
+        Map<TopicPartition, List<ConsumerRecord<String, String>>> allFetchedRecords = fetcher.fetchedRecords();
+        assertTrue(allFetchedRecords.containsKey(tp1));
+        List<ConsumerRecord<String, String>> fetchedRecords = allFetchedRecords.get(tp1);
+        assertEquals(5, fetchedRecords.size());
+        assertEquals(Arrays.asList(3L, 4L, 30L, 31L, 32L), collectRecordOffsets(fetchedRecords));
+    }
+
+    @Test
     public void testReturnAbortedTransactionsinUncommittedMode() {
         Fetcher<byte[], byte[]> fetcher = createFetcher(subscriptions, new Metrics(), new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED);
@@ -1457,7 +1564,7 @@ public class FetcherTest {
                 new SimpleRecord(time.milliseconds(), "key".getBytes(), "value".getBytes()),
                 new SimpleRecord(time.milliseconds(), "key".getBytes(), "value".getBytes()));
 
-        currentOffset += abortTransaction(buffer, 1L, currentOffset);
+        abortTransaction(buffer, 1L, currentOffset);
 
         buffer.flip();
 
@@ -1516,9 +1623,9 @@ public class FetcherTest {
         assertEquals(currentOffset, (long) subscriptions.position(tp1));
     }
 
-    private int appendTransactionalRecords(ByteBuffer buffer, long pid, long baseOffset, SimpleRecord... records) {
-        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE,
-                TimestampType.LOG_APPEND_TIME, baseOffset, time.milliseconds(), pid, (short) 0, (int) baseOffset, true,
+    private int appendTransactionalRecords(ByteBuffer buffer, long pid, long baseOffset, int baseSequence, SimpleRecord... records) {
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
+                TimestampType.CREATE_TIME, baseOffset, time.milliseconds(), pid, (short) 0, baseSequence, true,
                 RecordBatch.NO_PARTITION_LEADER_EPOCH);
 
         for (SimpleRecord record : records) {
@@ -1528,14 +1635,22 @@ public class FetcherTest {
         return records.length;
     }
 
-    private int commitTransaction(ByteBuffer buffer, long producerId, int baseOffset) {
-        MemoryRecords.writeEndTransactionalMarker(buffer, baseOffset, producerId, (short) 0,
+    private int appendTransactionalRecords(ByteBuffer buffer, long pid, long baseOffset, SimpleRecord... records) {
+        return appendTransactionalRecords(buffer, pid, baseOffset, (int) baseOffset, records);
+    }
+
+    private int commitTransaction(ByteBuffer buffer, long producerId, long baseOffset) {
+        short producerEpoch = 0;
+        int partitionLeaderEpoch = 0;
+        MemoryRecords.writeEndTransactionalMarker(buffer, baseOffset, time.milliseconds(), partitionLeaderEpoch, producerId, producerEpoch,
                 new EndTransactionMarker(ControlRecordType.COMMIT, 0));
         return 1;
     }
 
     private int abortTransaction(ByteBuffer buffer, long producerId, long baseOffset) {
-        MemoryRecords.writeEndTransactionalMarker(buffer, baseOffset, producerId, (short) 0,
+        short producerEpoch = 0;
+        int partitionLeaderEpoch = 0;
+        MemoryRecords.writeEndTransactionalMarker(buffer, baseOffset, time.milliseconds(), partitionLeaderEpoch, producerId, producerEpoch,
                 new EndTransactionMarker(ControlRecordType.ABORT, 0));
         return 1;
     }
@@ -1605,7 +1720,9 @@ public class FetcherTest {
     private FetchResponse fetchResponseWithAbortedTransactions(MemoryRecords records,
                                                                List<FetchResponse.AbortedTransaction> abortedTransactions,
                                                                Errors error,
-                                                               long lastStableOffset, long hw, int throttleTime) {
+                                                               long lastStableOffset,
+                                                               long hw,
+                                                               int throttleTime) {
         Map<TopicPartition, FetchResponse.PartitionData> partitions = Collections.singletonMap(tp1,
                 new FetchResponse.PartitionData(error, hw, lastStableOffset, 0L, abortedTransactions, records));
         return new FetchResponse(new LinkedHashMap<>(partitions), throttleTime);
@@ -1681,4 +1798,10 @@ public class FetcherTest {
                 isolationLevel);
     }
 
+    private <T> List<Long> collectRecordOffsets(List<ConsumerRecord<T, T>> records) {
+        List<Long> res = new ArrayList<>(records.size());
+        for (ConsumerRecord<?, ?> record : records)
+            res.add(record.offset());
+        return res;
+    }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsBuilderTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsBuilderTest.java b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsBuilderTest.java
index 0467522..a300a65 100644
--- a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsBuilderTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsBuilderTest.java
@@ -400,7 +400,7 @@ public class MemoryRecordsBuilderTest {
         builder.append(10L, "1".getBytes(), "a".getBytes());
         builder.close();
 
-        MemoryRecords.writeEndTransactionalMarker(buffer, 1L, 15L, (short) 0,
+        MemoryRecords.writeEndTransactionalMarker(buffer, 1L, System.currentTimeMillis(), 0, 15L, (short) 0,
                 new EndTransactionMarker(ControlRecordType.ABORT, 0));
 
         builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, compressionType,
@@ -409,7 +409,7 @@ public class MemoryRecordsBuilderTest {
         builder.append(13L, "3".getBytes(), "c".getBytes());
         builder.close();
 
-        MemoryRecords.writeEndTransactionalMarker(buffer, 14L, 1L, (short) 0,
+        MemoryRecords.writeEndTransactionalMarker(buffer, 14L, System.currentTimeMillis(), 0, 1L, (short) 0,
                 new EndTransactionMarker(ControlRecordType.COMMIT, 0));
 
         buffer.flip();

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
index 014a5bd..5a34f0f 100644
--- a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
@@ -222,8 +222,11 @@ public class MemoryRecordsTest {
             short producerEpoch = 13;
             long initialOffset = 983L;
             int coordinatorEpoch = 347;
+            int partitionLeaderEpoch = 29;
+
             EndTransactionMarker marker = new EndTransactionMarker(ControlRecordType.COMMIT, coordinatorEpoch);
-            MemoryRecords records = MemoryRecords.withEndTransactionMarker(initialOffset, producerId, producerEpoch, marker);
+            MemoryRecords records = MemoryRecords.withEndTransactionMarker(initialOffset, System.currentTimeMillis(),
+                    partitionLeaderEpoch, producerId, producerEpoch, marker);
             // verify that buffer allocation was precise
             assertEquals(records.buffer().remaining(), records.buffer().capacity());
 
@@ -235,6 +238,7 @@ public class MemoryRecordsTest {
             assertEquals(producerId, batch.producerId());
             assertEquals(producerEpoch, batch.producerEpoch());
             assertEquals(initialOffset, batch.baseOffset());
+            assertEquals(partitionLeaderEpoch, batch.partitionLeaderEpoch());
             assertTrue(batch.isValid());
 
             List<Record> createdRecords = TestUtils.toList(batch);
@@ -249,6 +253,55 @@ public class MemoryRecordsTest {
     }
 
     @Test
+    public void testFilterToBatchDiscard() {
+        if (compression != CompressionType.NONE || magic >= RecordBatch.MAGIC_VALUE_V2) {
+            ByteBuffer buffer = ByteBuffer.allocate(2048);
+            MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 0L);
+            builder.append(10L, "1".getBytes(), "a".getBytes());
+            builder.close();
+
+            builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 1L);
+            builder.append(11L, "2".getBytes(), "b".getBytes());
+            builder.append(12L, "3".getBytes(), "c".getBytes());
+            builder.close();
+
+            builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 3L);
+            builder.append(13L, "4".getBytes(), "d".getBytes());
+            builder.append(20L, "5".getBytes(), "e".getBytes());
+            builder.append(15L, "6".getBytes(), "f".getBytes());
+            builder.close();
+
+            builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 6L);
+            builder.append(16L, "7".getBytes(), "g".getBytes());
+            builder.close();
+
+            buffer.flip();
+
+            ByteBuffer filtered = ByteBuffer.allocate(2048);
+            MemoryRecords.readableRecords(buffer).filterTo(new MemoryRecords.RecordFilter() {
+                @Override
+                protected boolean shouldDiscard(RecordBatch batch) {
+                    // discard the second and fourth batches
+                    return batch.lastOffset() == 2L || batch.lastOffset() == 6L;
+                }
+
+                @Override
+                protected boolean shouldRetain(RecordBatch recordBatch, Record record) {
+                    return true;
+                }
+            }, filtered);
+
+            filtered.flip();
+            MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered);
+
+            List<MutableRecordBatch> batches = TestUtils.toList(filteredRecords.batches());
+            assertEquals(2, batches.size());
+            assertEquals(0L, batches.get(0).lastOffset());
+            assertEquals(5L, batches.get(1).lastOffset());
+        }
+    }
+
+    @Test
     public void testFilterToPreservesProducerInfo() {
         if (magic >= RecordBatch.MAGIC_VALUE_V2) {
             ByteBuffer buffer = ByteBuffer.allocate(2048);
@@ -490,7 +543,7 @@ public class MemoryRecordsTest {
         return values;
     }
 
-    private static class RetainNonNullKeysFilter implements MemoryRecords.RecordFilter {
+    private static class RetainNonNullKeysFilter extends MemoryRecords.RecordFilter {
         @Override
         public boolean shouldRetain(RecordBatch batch, Record record) {
             return record.hasKey();

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/core/src/main/scala/kafka/log/Log.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/Log.scala b/core/src/main/scala/kafka/log/Log.scala
index d3ea251..77ab25d 100644
--- a/core/src/main/scala/kafka/log/Log.scala
+++ b/core/src/main/scala/kafka/log/Log.scala
@@ -159,8 +159,7 @@ class Log(@volatile var dir: File,
     loadSegments()
 
     /* Calculate the offset of the next message */
-    nextOffsetMetadata = new LogOffsetMetadata(activeSegment.nextOffset, activeSegment.baseOffset,
-      activeSegment.size.toInt)
+    nextOffsetMetadata = new LogOffsetMetadata(activeSegment.nextOffset, activeSegment.baseOffset, activeSegment.size)
 
     leaderEpochCache.clearAndFlushLatest(nextOffsetMetadata.messageOffset)
 
@@ -879,6 +878,14 @@ class Log(@volatile var dir: File,
     FetchDataInfo(nextOffsetMetadata, MemoryRecords.EMPTY)
   }
 
+  private[log] def collectAbortedTransactions(startOffset: Long, upperBoundOffset: Long): List[AbortedTxn] = {
+    val segmentEntry = segments.floorEntry(startOffset)
+    val allAbortedTxns = ListBuffer.empty[AbortedTxn]
+    def accumulator(abortedTxns: List[AbortedTxn]): Unit = allAbortedTxns ++= abortedTxns
+    collectAbortedTransactions(logStartOffset, upperBoundOffset, segmentEntry, accumulator)
+    allAbortedTxns.toList
+  }
+
   private def addAbortedTransactions(startOffset: Long, segmentEntry: JEntry[JLong, LogSegment],
                                      fetchInfo: FetchDataInfo): FetchDataInfo = {
     val fetchSize = fetchInfo.records.sizeInBytes
@@ -891,27 +898,28 @@ class Log(@volatile var dir: File,
       else
         logEndOffset
     }
-    val abortedTransactions = collectAbortedTransactions(startOffset, upperBoundOffset, segmentEntry)
+
+    val abortedTransactions = ListBuffer.empty[AbortedTransaction]
+    def accumulator(abortedTxns: List[AbortedTxn]): Unit = abortedTransactions ++= abortedTxns.map(_.asAbortedTransaction)
+    collectAbortedTransactions(startOffset, upperBoundOffset, segmentEntry, accumulator)
+
     FetchDataInfo(fetchOffsetMetadata = fetchInfo.fetchOffsetMetadata,
       records = fetchInfo.records,
       firstEntryIncomplete = fetchInfo.firstEntryIncomplete,
-      abortedTransactions = Some(abortedTransactions))
+      abortedTransactions = Some(abortedTransactions.toList))
   }
 
   private def collectAbortedTransactions(startOffset: Long, upperBoundOffset: Long,
-                                         startingSegmentEntry: JEntry[JLong, LogSegment]): List[AbortedTransaction] = {
+                                         startingSegmentEntry: JEntry[JLong, LogSegment],
+                                         accumulator: List[AbortedTxn] => Unit): Unit = {
     var segmentEntry = startingSegmentEntry
-    val abortedTransactions = ListBuffer.empty[AbortedTransaction]
-
     while (segmentEntry != null) {
       val searchResult = segmentEntry.getValue.collectAbortedTxns(startOffset, upperBoundOffset)
-      abortedTransactions ++= searchResult.abortedTransactions
+      accumulator(searchResult.abortedTransactions)
       if (searchResult.isComplete)
-        return abortedTransactions.toList
-
+        return
       segmentEntry = segments.higherEntry(segmentEntry.getKey)
     }
-    abortedTransactions.toList
   }
 
   /**
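
collectAbortedTransactions is reworked around a caller-supplied accumulator so the same segment walk serves both callers: the fetch path collects FetchResponse-style AbortedTransaction entries, while the cleaner's new overload collects the raw AbortedTxn entries for a given offset range. The shape of that callback, as a standalone sketch with illustrative types:

    // Illustrative stand-in for TxnIndexSearchResult: what one segment's index reports
    case class SearchResult[A](abortedTxns: List[A], isComplete: Boolean)

    // Walk segments in order, handing each segment's aborted transactions to the
    // caller's accumulator, and stop as soon as a segment reports the search complete.
    def collectWith[A](segments: Iterator[SearchResult[A]])(accumulator: List[A] => Unit): Unit = {
      var done = false
      while (!done && segments.hasNext) {
        val result = segments.next()
        accumulator(result.abortedTxns)
        done = result.isComplete
      }
    }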

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/core/src/main/scala/kafka/log/LogCleaner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/LogCleaner.scala b/core/src/main/scala/kafka/log/LogCleaner.scala
index 282e049..8eda2e1 100644
--- a/core/src/main/scala/kafka/log/LogCleaner.scala
+++ b/core/src/main/scala/kafka/log/LogCleaner.scala
@@ -26,16 +26,16 @@ import com.yammer.metrics.core.Gauge
 import kafka.common._
 import kafka.metrics.KafkaMetricsGroup
 import kafka.utils._
-import org.apache.kafka.common.record.{FileRecords, MemoryRecords, Record, RecordBatch}
+import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.Time
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.record.MemoryRecords.RecordFilter
 
-import scala.collection._
-import JavaConverters._
+import scala.collection.mutable
+import scala.collection.JavaConverters._
 
 /**
- * The cleaner is responsible for removing obsolete records from logs which have the dedupe retention strategy.
+ * The cleaner is responsible for removing obsolete records from logs which have the "compact" retention strategy.
  * A message with key K and offset O is obsolete if there exists a message with key K and offset O' such that O < O'.
  * 
  * Each log can be thought of being split into two sections of segments: a "clean" section which has previously been cleaned followed by a
@@ -43,7 +43,7 @@ import JavaConverters._
  * The uncleanable section is excluded from cleaning. The active log segment is always uncleanable. If there is a
  * compaction lag time set, segments whose largest message timestamp is within the compaction lag time of the cleaning operation are also uncleanable.
  *
- * The cleaning is carried out by a pool of background threads. Each thread chooses the dirtiest log that has the "dedupe" retention policy 
+ * The cleaning is carried out by a pool of background threads. Each thread chooses the dirtiest log that has the "compact" retention policy
  * and cleans that. The dirtiness of the log is guessed by taking the ratio of bytes in the dirty section of the log to the total bytes in the log. 
  * 
  * To clean a log the cleaner first builds a mapping of key=>last_offset for the dirty section of the log. See kafka.log.OffsetMap for details of
@@ -332,10 +332,22 @@ private[log] class Cleaner(val id: Int,
    * @return The first offset not cleaned and the statistics for this round of cleaning
    */
   private[log] def clean(cleanable: LogToClean): (Long, CleanerStats) = {
-    val stats = new CleanerStats()
+    // figure out the timestamp below which it is safe to remove delete tombstones
+    // this position is defined to be a configurable time beneath the last modified time of the last clean segment
+    val deleteHorizonMs = 
+      cleanable.log.logSegments(0, cleanable.firstDirtyOffset).lastOption match {
+        case None => 0L
+        case Some(seg) => seg.lastModified - cleanable.log.config.deleteRetentionMs
+    }
+
+    doClean(cleanable, deleteHorizonMs)
+  }
 
+  private[log] def doClean(cleanable: LogToClean, deleteHorizonMs: Long): (Long, CleanerStats) = {
     info("Beginning cleaning of log %s.".format(cleanable.log.name))
+
     val log = cleanable.log
+    val stats = new CleanerStats()
 
     // build the offset map
     info("Building offset map for %s...".format(cleanable.log.name))
@@ -343,14 +355,6 @@ private[log] class Cleaner(val id: Int,
     buildOffsetMap(log, cleanable.firstDirtyOffset, upperBoundOffset, offsetMap, stats)
     val endOffset = offsetMap.latestOffset + 1
     stats.indexDone()
-    
-    // figure out the timestamp below which it is safe to remove delete tombstones
-    // this position is defined to be a configurable time beneath the last modified time of the last clean segment
-    val deleteHorizonMs = 
-      log.logSegments(0, cleanable.firstDirtyOffset).lastOption match {
-        case None => 0L
-        case Some(seg) => seg.lastModified - log.config.deleteRetentionMs
-    }
 
     // determine the timestamp up to which the log will be cleaned
     // this is the lower of the last active segment and the compaction lag
@@ -363,7 +367,7 @@ private[log] class Cleaner(val id: Int,
 
     // record buffer utilization
     stats.bufferUtilization = offsetMap.utilization
-    
+
     stats.allDone()
 
     (endOffset, stats)
@@ -379,8 +383,8 @@ private[log] class Cleaner(val id: Int,
    * @param stats Collector for cleaning statistics
    */
   private[log] def cleanSegments(log: Log,
-                                 segments: Seq[LogSegment], 
-                                 map: OffsetMap, 
+                                 segments: Seq[LogSegment],
+                                 map: OffsetMap,
                                  deleteHorizonMs: Long,
                                  stats: CleanerStats) {
     // create a new segment with the suffix .cleaned appended to both the log and index name
@@ -403,11 +407,24 @@ private[log] class Cleaner(val id: Int,
 
     try {
       // clean segments into the new destination segment
-      for (old <- segments) {
-        val retainDeletes = old.lastModified > deleteHorizonMs
+      val iter = segments.iterator
+      var currentSegmentOpt: Option[LogSegment] = Some(iter.next())
+      while (currentSegmentOpt.isDefined) {
+        val oldSegmentOpt = currentSegmentOpt.get
+        val nextSegmentOpt = if (iter.hasNext) Some(iter.next()) else None
+
+        val startOffset = oldSegmentOpt.baseOffset
+        val upperBoundOffset = nextSegmentOpt.map(_.baseOffset).getOrElse(map.latestOffset + 1)
+        val abortedTransactions = log.collectAbortedTransactions(startOffset, upperBoundOffset)
+        val transactionMetadata = CleanedTransactionMetadata(abortedTransactions, Some(txnIndex))
+
+        val retainDeletes = oldSegmentOpt.lastModified > deleteHorizonMs
         info("Cleaning segment %s in log %s (largest timestamp %s) into %s, %s deletes."
-            .format(old.baseOffset, log.name, new Date(old.largestTimestamp), cleaned.baseOffset, if(retainDeletes) "retaining" else "discarding"))
-        cleanInto(log.topicPartition, old, cleaned, map, retainDeletes, log.config.maxMessageSize, log.activePids, stats)
+          .format(startOffset, log.name, new Date(oldSegmentOpt.largestTimestamp), cleaned.baseOffset, if(retainDeletes) "retaining" else "discarding"))
+        cleanInto(log.topicPartition, oldSegmentOpt, cleaned, map, retainDeletes, log.config.maxMessageSize, transactionMetadata,
+          log.activePids, stats)
+
+        currentSegmentOpt = nextSegmentOpt
       }
 
       // trim excess index
@@ -454,11 +471,39 @@ private[log] class Cleaner(val id: Int,
                              map: OffsetMap,
                              retainDeletes: Boolean,
                              maxLogMessageSize: Int,
-                             activePids: Map[Long, ProducerIdEntry],
+                             transactionMetadata: CleanedTransactionMetadata,
+                             activeProducers: Map[Long, ProducerIdEntry],
                              stats: CleanerStats) {
     val logCleanerFilter = new RecordFilter {
-      def shouldRetain(recordBatch: RecordBatch, record: Record): Boolean =
-        shouldRetainMessage(source, map, retainDeletes, recordBatch, record, stats, activePids)
+      var retainLastBatchSequence: Boolean = false
+      var discardBatchRecords: Boolean = false
+
+      override def shouldDiscard(batch: RecordBatch): Boolean = {
+        // we piggy-back on the tombstone retention logic to delay deletion of transaction markers.
+        // note that we will never delete a marker until all the records from that transaction are removed.
+        discardBatchRecords = shouldDiscardBatch(batch, transactionMetadata, retainTxnMarkers = retainDeletes)
+
+        // check if the batch contains the last sequence number for the producer. if so, we cannot
+        // remove the batch just yet or the producer may see an out of sequence error.
+        if (batch.hasProducerId && activeProducers.get(batch.producerId).exists(_.lastSeq == batch.lastSequence)) {
+          retainLastBatchSequence = true
+          false
+        } else {
+          retainLastBatchSequence = false
+          discardBatchRecords
+        }
+      }
+
+      override def shouldRetain(batch: RecordBatch, record: Record): Boolean = {
+        if (retainLastBatchSequence && batch.lastSequence == record.sequence)
+          // always retain the record with the last sequence number
+          true
+        else if (discardBatchRecords)
+          // remove the record if the batch would have otherwise been discarded
+          false
+        else
+          shouldRetainRecord(source, map, retainDeletes, batch, record, stats)
+      }
     }
 
     var position = 0
@@ -488,7 +533,7 @@ private[log] class Cleaner(val id: Int,
           records = retained)
         throttler.maybeThrottle(writeBuffer.limit)
       }
-      
+
       // if we read bytes but didn't get even one complete message, our I/O buffer is too small, grow it and try again
       if (readBuffer.limit > 0 && result.messagesRead == 0)
         growBuffers(maxLogMessageSize)
@@ -496,24 +541,24 @@ private[log] class Cleaner(val id: Int,
     restoreBuffers()
   }
 
-  private def shouldRetainMessage(source: kafka.log.LogSegment,
-                                  map: kafka.log.OffsetMap,
-                                  retainDeletes: Boolean,
-                                  batch: RecordBatch,
-                                  record: Record,
-                                  stats: CleanerStats,
-                                  activeProducers: Map[Long, ProducerIdEntry]): Boolean = {
-    if (batch.isControlBatch)
-      return true
-
-    // retain the record if it is the last one produced by an active idempotent producer to ensure that
-    // the producerId is not removed from the log before it has been expired
-    if (batch.hasProducerId) {
-      val producerId = batch.producerId
-      if (RecordBatch.NO_PRODUCER_ID < producerId && activeProducers.get(producerId).exists(_.lastOffset == record.offset))
-        return true
+  private def shouldDiscardBatch(batch: RecordBatch,
+                                 transactionMetadata: CleanedTransactionMetadata,
+                                 retainTxnMarkers: Boolean): Boolean = {
+    if (batch.isControlBatch) {
+      val canDiscardControlBatch = transactionMetadata.onControlBatchRead(batch)
+      canDiscardControlBatch && !retainTxnMarkers
+    } else {
+      val canDiscardBatch = transactionMetadata.onBatchRead(batch)
+      canDiscardBatch
     }
+  }
 
+  private def shouldRetainRecord(source: kafka.log.LogSegment,
+                                 map: kafka.log.OffsetMap,
+                                 retainDeletes: Boolean,
+                                 batch: RecordBatch,
+                                 record: Record,
+                                 stats: CleanerStats): Boolean = {
     val pastLatestOffset = record.offset > map.latestOffset
     if (pastLatestOffset)
       return true
@@ -546,7 +591,7 @@ private[log] class Cleaner(val id: Int,
     this.readBuffer = ByteBuffer.allocate(newSize)
     this.writeBuffer = ByteBuffer.allocate(newSize)
   }
-  
+
   /**
    * Restore the I/O buffer capacity to its original size
    */
@@ -609,14 +654,18 @@ private[log] class Cleaner(val id: Int,
     map.clear()
     val dirty = log.logSegments(start, end).toBuffer
     info("Building offset map for log %s for %d segments in offset range [%d, %d).".format(log.name, dirty.size, start, end))
-    
+
+    val abortedTransactions = log.collectAbortedTransactions(start, end)
+    val transactionMetadata = CleanedTransactionMetadata(abortedTransactions)
+
     // Add all the cleanable dirty segments. We must take at least map.slots * load_factor,
     // but we may be able to fit more (if there is lots of duplication in the dirty section of the log)
     var full = false
     for (segment <- dirty if !full) {
       checkDone(log.topicPartition)
 
-      full = buildOffsetMapForSegment(log.topicPartition, segment, map, start, log.config.maxMessageSize, stats)
+      full = buildOffsetMapForSegment(log.topicPartition, segment, map, start, log.config.maxMessageSize,
+        transactionMetadata, stats)
       if (full)
         debug("Offset map is full, %d segments fully mapped, segment with base offset %d is partially mapped".format(dirty.indexOf(segment), segment.baseOffset))
     }
@@ -635,10 +684,11 @@ private[log] class Cleaner(val id: Int,
   private def buildOffsetMapForSegment(topicPartition: TopicPartition,
                                        segment: LogSegment,
                                        map: OffsetMap,
-                                       start: Long,
+                                       startOffset: Long,
                                        maxLogMessageSize: Int,
+                                       transactionMetadata: CleanedTransactionMetadata,
                                        stats: CleanerStats): Boolean = {
-    var position = segment.index.lookup(start).position
+    var position = segment.index.lookup(startOffset).position
     val maxDesiredMapSize = (map.slots * this.dupBufferLoadFactor).toInt
     while (position < segment.log.sizeInBytes) {
       checkDone(topicPartition)
@@ -648,14 +698,30 @@ private[log] class Cleaner(val id: Int,
       throttler.maybeThrottle(records.sizeInBytes)
 
       val startPosition = position
-      for (batch <- records.batches.asScala; record <- batch.asScala) {
-        if (!batch.isControlBatch && record.hasKey && record.offset >= start) {
-          if (map.size < maxDesiredMapSize)
-            map.put(record.key, record.offset)
-          else
-            return true
+      for (batch <- records.batches.asScala) {
+        if (batch.isControlBatch) {
+          transactionMetadata.onControlBatchRead(batch)
+          stats.indexMessagesRead(1)
+        } else {
+          val isAborted = transactionMetadata.onBatchRead(batch)
+          if (isAborted) {
+            // abort markers are supported in v2 and above, which means count is defined
+            stats.indexMessagesRead(batch.countOrNull)
+          } else {
+            for (record <- batch.asScala) {
+              if (record.hasKey && record.offset >= startOffset) {
+                if (map.size < maxDesiredMapSize)
+                  map.put(record.key, record.offset)
+                else
+                  return true
+              }
+              stats.indexMessagesRead(1)
+            }
+          }
         }
-        stats.indexMessagesRead(1)
+
+        if (batch.lastOffset >= startOffset)
+          map.updateLatestOffset(batch.lastOffset)
       }
       val bytesRead = records.validBytes
       position += bytesRead
@@ -694,7 +760,7 @@ private class CleanerStats(time: Time = Time.SYSTEM) {
   def invalidMessage() {
     invalidMessagesRead += 1
   }
-  
+
   def recopyMessages(messagesWritten: Int, bytesWritten: Int) {
     this.messagesWritten += messagesWritten
     this.bytesWritten += bytesWritten
@@ -715,11 +781,11 @@ private class CleanerStats(time: Time = Time.SYSTEM) {
   def allDone() {
     endTime = time.milliseconds
   }
-  
+
   def elapsedSecs = (endTime - startTime)/1000.0
-  
+
   def elapsedIndexSecs = (mapCompleteTime - startTime)/1000.0
-  
+
 }
 
 /**
@@ -734,3 +800,80 @@ private case class LogToClean(topicPartition: TopicPartition, log: Log, firstDir
   val cleanableRatio = cleanableBytes / totalBytes.toDouble
   override def compare(that: LogToClean): Int = math.signum(this.cleanableRatio - that.cleanableRatio).toInt
 }
+
+private[log] object CleanedTransactionMetadata {
+  def apply(abortedTransactions: List[AbortedTxn],
+            transactionIndex: Option[TransactionIndex] = None): CleanedTransactionMetadata = {
+    val queue = mutable.PriorityQueue.empty[AbortedTxn](new Ordering[AbortedTxn] {
+      override def compare(x: AbortedTxn, y: AbortedTxn): Int = x.firstOffset compare y.firstOffset
+    }.reverse)
+    queue ++= abortedTransactions
+    new CleanedTransactionMetadata(queue, transactionIndex)
+  }
+
+  val Empty = CleanedTransactionMetadata(List.empty[AbortedTxn])
+}
+
+/**
+ * This is a helper class to facilitate tracking transaction state while cleaning the log. It is initialized
+ * with the aborted transactions from the transaction index and its state is updated as the cleaner iterates through
+ * the log during a round of cleaning. This class is responsible for deciding when transaction markers can
+ * be removed and is therefore also responsible for updating the cleaned transaction index accordingly.
+ */
+private[log] class CleanedTransactionMetadata(val abortedTransactions: mutable.PriorityQueue[AbortedTxn],
+                                              val transactionIndex: Option[TransactionIndex] = None) {
+  val ongoingCommittedTxns = mutable.Set.empty[Long]
+  val ongoingAbortedTxns = mutable.Map.empty[Long, AbortedTxn]
+
+  /**
+   * Update the cleaned transaction state with a control batch that has just been traversed by the cleaner.
+   * Return true if the control batch can be discarded.
+   */
+  def onControlBatchRead(controlBatch: RecordBatch): Boolean = {
+    consumeAbortedTxnsUpTo(controlBatch.lastOffset)
+
+    val controlRecord = controlBatch.iterator.next()
+    val controlType = ControlRecordType.parse(controlRecord.key)
+    val producerId = controlBatch.producerId
+    controlType match {
+      case ControlRecordType.ABORT =>
+        val maybeAbortedTxn = ongoingAbortedTxns.remove(producerId)
+        maybeAbortedTxn.foreach { abortedTxn =>
+          transactionIndex.foreach(_.append(abortedTxn))
+        }
+        true
+
+      case ControlRecordType.COMMIT =>
+        // this marker is eligible for deletion if we didn't traverse any records from the transaction
+        !ongoingCommittedTxns.remove(producerId)
+
+      case _ => false
+    }
+  }
+
+  private def consumeAbortedTxnsUpTo(offset: Long): Unit = {
+    while (abortedTransactions.headOption.exists(_.firstOffset <= offset)) {
+      val abortedTxn = abortedTransactions.dequeue()
+      ongoingAbortedTxns += abortedTxn.producerId -> abortedTxn
+    }
+  }
+
+  /**
+   * Update the transactional state for the incoming non-control batch. If the batch is part of
+   * an aborted transaction, return true to indicate that it is safe to discard.
+   */
+  def onBatchRead(batch: RecordBatch): Boolean = {
+    consumeAbortedTxnsUpTo(batch.lastOffset)
+    if (batch.isTransactional) {
+      if (ongoingAbortedTxns.contains(batch.producerId))
+        true
+      else {
+        ongoingCommittedTxns += batch.producerId
+        false
+      }
+    } else {
+      false
+    }
+  }
+
+}
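
One detail worth calling out in CleanedTransactionMetadata.apply: scala.collection.mutable.PriorityQueue dequeues the largest element, so the ordering by firstOffset is reversed to pop the earliest-starting aborted transaction first. A small illustrative sketch of that queue and of the consumeAbortedTxnsUpTo step (Abort stands in for AbortedTxn, offsets are made up):

    import scala.collection.mutable

    case class Abort(producerId: Long, firstOffset: Long)  // stand-in for AbortedTxn

    // Reverse the ordering so dequeue() yields the smallest firstOffset first
    val queue = mutable.PriorityQueue.empty[Abort](Ordering.by[Abort, Long](_.firstOffset).reverse)
    queue ++= Seq(Abort(1L, 40L), Abort(2L, 10L), Abort(3L, 25L))

    // Draining up to a batch's last offset marks those transactions as ongoing aborted
    val ongoingAbortedTxns = mutable.Map.empty[Long, Abort]
    val batchLastOffset = 30L
    while (queue.headOption.exists(_.firstOffset <= batchLastOffset)) {
      val txn = queue.dequeue()
      ongoingAbortedTxns += txn.producerId -> txn
    }
    // ongoingAbortedTxns now holds producers 2 and 3; producer 1's transaction
    // (starting at offset 40) has not yet begun at offset 30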

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/core/src/main/scala/kafka/log/LogCleanerManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/LogCleanerManager.scala b/core/src/main/scala/kafka/log/LogCleanerManager.scala
index 2b4d956..6e0ebfb 100755
--- a/core/src/main/scala/kafka/log/LogCleanerManager.scala
+++ b/core/src/main/scala/kafka/log/LogCleanerManager.scala
@@ -304,19 +304,22 @@ private[log] object LogCleanerManager extends Logging {
     // may be cleaned
     val firstUncleanableDirtyOffset: Long = Seq (
 
-        // the active segment is always uncleanable
-        Option(log.activeSegment.baseOffset),
-
-        // the first segment whose largest message timestamp is within a minimum time lag from now
-        if (compactionLagMs > 0) {
-          dirtyNonActiveSegments.find {
-            s =>
-              val isUncleanable = s.largestTimestamp > now - compactionLagMs
-              debug(s"Checking if log segment may be cleaned: log='${log.name}' segment.baseOffset=${s.baseOffset} segment.largestTimestamp=${s.largestTimestamp}; now - compactionLag=${now - compactionLagMs}; is uncleanable=$isUncleanable")
-              isUncleanable
-          } map(_.baseOffset)
-        } else None
-      ).flatten.min
+      // we do not clean beyond the first unstable offset
+      log.firstUnstableOffset.map(_.messageOffset),
+
+      // the active segment is always uncleanable
+      Option(log.activeSegment.baseOffset),
+
+      // the first segment whose largest message timestamp is within a minimum time lag from now
+      if (compactionLagMs > 0) {
+        dirtyNonActiveSegments.find {
+          s =>
+            val isUncleanable = s.largestTimestamp > now - compactionLagMs
+            debug(s"Checking if log segment may be cleaned: log='${log.name}' segment.baseOffset=${s.baseOffset} segment.largestTimestamp=${s.largestTimestamp}; now - compactionLag=${now - compactionLagMs}; is uncleanable=$isUncleanable")
+            isUncleanable
+        } map(_.baseOffset)
+      } else None
+    ).flatten.min
 
     debug(s"Finding range of cleanable offsets for log=${log.name} topicPartition=$topicPartition. Last clean offset=$lastCleanOffset now=$now => firstDirtyOffset=$firstDirtyOffset firstUncleanableOffset=$firstUncleanableDirtyOffset activeSegment.baseOffset=${log.activeSegment.baseOffset}")
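
[Editor's note] To make the candidate logic in the hunk above easier to see in isolation, a standalone sketch (parameter names are assumed and do not match the real method's signature): the first uncleanable offset is simply the minimum of the defined candidates.

    // Sketch only: the cleaner must stop at the first unstable offset (open transactions),
    // never cleans the active segment, and honors the optional compaction lag.
    def firstUncleanableOffset(firstUnstableOffset: Option[Long],
                               activeSegmentBaseOffset: Long,
                               firstSegmentWithinLag: Option[Long]): Long =
      Seq(firstUnstableOffset, Some(activeSegmentBaseOffset), firstSegmentWithinLag).flatten.min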
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/core/src/main/scala/kafka/log/OffsetMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/OffsetMap.scala b/core/src/main/scala/kafka/log/OffsetMap.scala
index 1df0615..8b493c2 100755
--- a/core/src/main/scala/kafka/log/OffsetMap.scala
+++ b/core/src/main/scala/kafka/log/OffsetMap.scala
@@ -27,6 +27,7 @@ trait OffsetMap {
   def slots: Int
   def put(key: ByteBuffer, offset: Long)
   def get(key: ByteBuffer): Long
+  def updateLatestOffset(offset: Long)
   def clear()
   def size: Int
   def utilization: Double = size.toDouble / slots
@@ -167,6 +168,10 @@ class SkimpyOffsetMap(val memory: Int, val hashAlgorithm: String = "MD5") extend
    */
   override def latestOffset: Long = lastOffset
 
+  override def updateLatestOffset(offset: Long): Unit = {
+    lastOffset = offset
+  }
+
   /**
    * Calculate the ith probe position. We first try reading successive integers from the hash itself
    * then if all of those fail we degrade to linear probing.
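
[Editor's note] The doc comment above summarizes SkimpyOffsetMap's probing scheme. As a loose illustration of that idea only (not the actual implementation; assumes the digest is at least four bytes long):

    import java.nio.ByteBuffer

    // Take successive 4-byte integers from the hash while they last, then fall back to
    // linear probing; mask to keep the result non-negative before reducing modulo slots.
    def probePosition(hash: Array[Byte], attempt: Int, slots: Int): Int = {
      val base =
        if ((attempt + 1) * 4 <= hash.length) ByteBuffer.wrap(hash, attempt * 4, 4).getInt
        else ByteBuffer.wrap(hash, 0, 4).getInt + attempt
      (base & 0x7fffffff) % slots
    }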

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/core/src/main/scala/kafka/log/ProducerStateManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/ProducerStateManager.scala b/core/src/main/scala/kafka/log/ProducerStateManager.scala
index b1a43d2..03c60e4 100644
--- a/core/src/main/scala/kafka/log/ProducerStateManager.scala
+++ b/core/src/main/scala/kafka/log/ProducerStateManager.scala
@@ -165,7 +165,7 @@ private[log] class ProducerAppendInfo(val producerId: Long, initialEntry: Produc
     }
 
     val firstOffset = currentTxnFirstOffset match {
-      case Some(firstOffset) => firstOffset
+      case Some(txnFirstOffset) => txnFirstOffset
       case None =>
         transactions += new TxnMetadata(producerId, offset)
         offset

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/core/src/main/scala/kafka/log/TransactionIndex.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/log/TransactionIndex.scala b/core/src/main/scala/kafka/log/TransactionIndex.scala
index bf6a6d4..f8a3879 100644
--- a/core/src/main/scala/kafka/log/TransactionIndex.scala
+++ b/core/src/main/scala/kafka/log/TransactionIndex.scala
@@ -28,7 +28,7 @@ import org.apache.kafka.common.utils.Utils
 
 import scala.collection.mutable.ListBuffer
 
-private[log] case class TxnIndexSearchResult(abortedTransactions: List[AbortedTransaction], isComplete: Boolean)
+private[log] case class TxnIndexSearchResult(abortedTransactions: List[AbortedTxn], isComplete: Boolean)
 
 /**
  * The transaction index maintains metadata about the aborted transactions for each segment. This includes
@@ -114,7 +114,7 @@ class TransactionIndex(val startOffset: Long, @volatile var file: File) extends
     }
   }
 
-  private def iterator(allocate: () => ByteBuffer): Iterator[(AbortedTxn, Int)] = {
+  private def iterator(allocate: () => ByteBuffer = () => ByteBuffer.allocate(AbortedTxn.TotalSize)): Iterator[(AbortedTxn, Int)] = {
     maybeChannel match {
       case None => Iterator.empty
       case Some(channel) =>
@@ -148,7 +148,7 @@ class TransactionIndex(val startOffset: Long, @volatile var file: File) extends
   }
 
   def allAbortedTxns: List[AbortedTxn] = {
-    iterator(() => ByteBuffer.allocate(AbortedTxn.TotalSize)).map(_._1).toList
+    iterator().map(_._1).toList
   }
 
   /**
@@ -160,11 +160,10 @@ class TransactionIndex(val startOffset: Long, @volatile var file: File) extends
    *         into the next log segment.
    */
   def collectAbortedTxns(fetchOffset: Long, upperBoundOffset: Long): TxnIndexSearchResult = {
-    val abortedTransactions = ListBuffer.empty[AbortedTransaction]
-    val buffer = ByteBuffer.allocate(AbortedTxn.TotalSize)
-    for ((abortedTxn, _) <- iterator(() => buffer)) {
+    val abortedTransactions = ListBuffer.empty[AbortedTxn]
+    for ((abortedTxn, _) <- iterator()) {
       if (abortedTxn.lastOffset >= fetchOffset && abortedTxn.firstOffset < upperBoundOffset)
-        abortedTransactions += abortedTxn.asAbortedTransaction
+        abortedTransactions += abortedTxn
 
       if (abortedTxn.lastStableOffset >= upperBoundOffset)
         return TxnIndexSearchResult(abortedTransactions.toList, isComplete = true)
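
[Editor's note] A brief usage sketch of the updated search result (hypothetical offsets; `txnIndex` is an assumed TransactionIndex instance), showing that callers now receive full AbortedTxn entries rather than the fetch-response AbortedTransaction view:

    // Sketch: collect aborted transactions overlapping the range [100, 200).
    val result = txnIndex.collectAbortedTxns(fetchOffset = 100L, upperBoundOffset = 200L)
    result.abortedTransactions.foreach { txn =>
      println(s"aborted producerId=${txn.producerId} offsets=[${txn.firstOffset}, ${txn.lastOffset}]")
    }
    if (!result.isComplete)
      println("upper bound not reached in this index; continue with the next segment's index")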

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
index c2ae155..9f9d982 100644
--- a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala
@@ -166,6 +166,54 @@ class LogCleanerManagerTest extends JUnitSuite with Logging {
     assertEquals("The first uncleanable offset begins with active segment.", log.activeSegment.baseOffset, cleanableOffsets._2)
   }
 
+  @Test
+  def testUndecidedTransactionalDataNotCleanable(): Unit = {
+    val topicPartition = new TopicPartition("log", 0)
+    val compactionLag = 60 * 60 * 1000
+    val logProps = new Properties()
+    logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer)
+    logProps.put(LogConfig.MinCompactionLagMsProp, compactionLag: java.lang.Integer)
+
+    val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps))
+
+    val producerId = 15L
+    val producerEpoch = 0.toShort
+    val sequence = 0
+    log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
+      new SimpleRecord(time.milliseconds(), "1".getBytes, "a".getBytes),
+      new SimpleRecord(time.milliseconds(), "2".getBytes, "b".getBytes)), leaderEpoch = 0)
+    log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 2,
+      new SimpleRecord(time.milliseconds(), "3".getBytes, "c".getBytes)), leaderEpoch = 0)
+    log.roll()
+    log.onHighWatermarkIncremented(3L)
+
+    time.sleep(compactionLag + 1)
+    // although the compaction lag has been exceeded, the undecided data should not be cleaned
+    var cleanableOffsets = LogCleanerManager.cleanableOffsets(log, topicPartition,
+      Map(topicPartition -> 0L), time.milliseconds())
+    assertEquals(0L, cleanableOffsets._1)
+    assertEquals(0L, cleanableOffsets._2)
+
+    log.appendAsLeader(MemoryRecords.withEndTransactionMarker(time.milliseconds(), producerId, producerEpoch,
+      new EndTransactionMarker(ControlRecordType.ABORT, 15)), leaderEpoch = 0, isFromClient = false)
+    log.roll()
+    log.onHighWatermarkIncremented(4L)
+
+    // the first segment should now become cleanable immediately
+    cleanableOffsets = LogCleanerManager.cleanableOffsets(log, topicPartition,
+      Map(topicPartition -> 0L), time.milliseconds())
+    assertEquals(0L, cleanableOffsets._1)
+    assertEquals(3L, cleanableOffsets._2)
+
+    time.sleep(compactionLag + 1)
+
+    // the second segment becomes cleanable after the compaction lag
+    cleanableOffsets = LogCleanerManager.cleanableOffsets(log, topicPartition,
+      Map(topicPartition -> 0L), time.milliseconds())
+    assertEquals(0L, cleanableOffsets._1)
+    assertEquals(4L, cleanableOffsets._2)
+  }
+
   private def createCleanerManager(log: Log): LogCleanerManager = {
     val logs = new Pool[TopicPartition, Log]()
     logs.put(new TopicPartition("log", 0), log)

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
index fe07fdd..c842fc3 100755
--- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
@@ -87,6 +87,159 @@ class LogCleanerTest extends JUnitSuite {
     assertEquals(expectedBytesRead, stats.bytesRead)
   }
 
+  @Test
+  def testBasicTransactionAwareCleaning(): Unit = {
+    val cleaner = makeCleaner(Int.MaxValue)
+    val logProps = new Properties()
+    logProps.put(LogConfig.SegmentBytesProp, 2048: java.lang.Integer)
+    val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps))
+
+    val producerEpoch = 0.toShort
+    val pid1 = 1
+    val pid2 = 2
+
+    val appendProducer1 = appendTransactionalAsLeader(log, pid1, producerEpoch)
+    val appendProducer2 = appendTransactionalAsLeader(log, pid2, producerEpoch)
+
+    appendProducer1(Seq(1, 2))
+    appendProducer2(Seq(2, 3))
+    appendProducer1(Seq(3, 4))
+    log.appendAsLeader(abortMarker(pid1, producerEpoch), leaderEpoch = 0, isFromClient = false)
+    log.appendAsLeader(commitMarker(pid2, producerEpoch), leaderEpoch = 0, isFromClient = false)
+    appendProducer1(Seq(2))
+    log.appendAsLeader(commitMarker(pid1, producerEpoch), leaderEpoch = 0, isFromClient = false)
+
+    val abortedTransactions = log.collectAbortedTransactions(log.logStartOffset, log.logEndOffset)
+
+    log.roll()
+    cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset))
+    assertEquals(List(3, 2), keysInLog(log))
+    assertEquals(List(3, 6, 7, 8, 9), offsetsInLog(log))
+
+    // ensure the transaction index is still correct
+    assertEquals(abortedTransactions, log.collectAbortedTransactions(log.logStartOffset, log.logEndOffset))
+  }
+
+  @Test
+  def testCleanWithTransactionsSpanningSegments(): Unit = {
+    val cleaner = makeCleaner(Int.MaxValue)
+    val logProps = new Properties()
+    logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer)
+    val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps))
+
+    val producerEpoch = 0.toShort
+    val pid1 = 1
+    val pid2 = 2
+    val pid3 = 3
+
+    val appendProducer1 = appendTransactionalAsLeader(log, pid1, producerEpoch)
+    val appendProducer2 = appendTransactionalAsLeader(log, pid2, producerEpoch)
+    val appendProducer3 = appendTransactionalAsLeader(log, pid3, producerEpoch)
+
+    appendProducer1(Seq(1, 2))
+    appendProducer3(Seq(2, 3))
+    appendProducer2(Seq(3, 4))
+
+    log.roll()
+
+    appendProducer2(Seq(5, 6))
+    appendProducer3(Seq(6, 7))
+    appendProducer1(Seq(7, 8))
+    log.appendAsLeader(abortMarker(pid2, producerEpoch), leaderEpoch = 0, isFromClient = false)
+    appendProducer3(Seq(8, 9))
+    log.appendAsLeader(commitMarker(pid3, producerEpoch), leaderEpoch = 0, isFromClient = false)
+    appendProducer1(Seq(9, 10))
+    log.appendAsLeader(abortMarker(pid1, producerEpoch), leaderEpoch = 0, isFromClient = false)
+
+    // this pass only cleans the records in the first segment (cleaning stops at the active segment's base offset)
+    val dirtyOffset = cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset))._1
+    assertEquals(List(2, 3, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10), keysInLog(log))
+
+    log.roll()
+
+    // append a couple extra records in the new segment to ensure we have sequence numbers
+    appendProducer2(Seq(11))
+    appendProducer1(Seq(12))
+
+    // finally only the keys from pid3 should remain
+    cleaner.clean(LogToClean(new TopicPartition("test", 0), log, dirtyOffset, log.activeSegment.baseOffset))
+    assertEquals(List(2, 3, 6, 7, 8, 9, 11, 12), keysInLog(log))
+  }
+
+  @Test
+  def testCommitMarkerRemoval(): Unit = {
+    val tp = new TopicPartition("test", 0)
+    val cleaner = makeCleaner(Int.MaxValue)
+    val logProps = new Properties()
+    logProps.put(LogConfig.SegmentBytesProp, 256: java.lang.Integer)
+    val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps))
+
+    val producerEpoch = 0.toShort
+    val producerId = 1L
+    val appendProducer = appendTransactionalAsLeader(log, producerId, producerEpoch)
+
+    appendProducer(Seq(1))
+    appendProducer(Seq(2, 3))
+    log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, isFromClient = false)
+    appendProducer(Seq(2))
+    log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, isFromClient = false)
+    log.roll()
+
+    // cannot remove the marker in this pass because there are still valid records
+    var dirtyOffset = cleaner.doClean(LogToClean(tp, log, 0L, 100L), deleteHorizonMs = time.milliseconds())._1
+    assertEquals(List(1, 3, 2), keysInLog(log))
+    assertEquals(List(0, 2, 3, 4, 5), offsetsInLog(log))
+
+    appendProducer(Seq(1, 3))
+    log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, isFromClient = false)
+    log.roll()
+
+    // the first cleaning preserves the commit marker (at offset 3) since there were still records for the transaction
+    dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = time.milliseconds())._1
+    assertEquals(List(2, 1, 3), keysInLog(log))
+    assertEquals(List(3, 4, 5, 6, 7, 8), offsetsInLog(log))
+
+    // delete horizon forced to 0 to verify marker is not removed early
+    dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = 0L)._1
+    assertEquals(List(2, 1, 3), keysInLog(log))
+    assertEquals(List(3, 4, 5, 6, 7, 8), offsetsInLog(log))
+
+    // clean again with the delete horizon set back to the current time and verify the marker is removed
+    cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = time.milliseconds())
+    assertEquals(List(2, 1, 3), keysInLog(log))
+    assertEquals(List(4, 5, 6, 7, 8), offsetsInLog(log))
+  }
+
+  @Test
+  def testAbortMarkerRemoval(): Unit = {
+    val tp = new TopicPartition("test", 0)
+    val cleaner = makeCleaner(Int.MaxValue)
+    val logProps = new Properties()
+    logProps.put(LogConfig.SegmentBytesProp, 256: java.lang.Integer)
+    val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps))
+
+    val producerEpoch = 0.toShort
+    val producerId = 1L
+    val appendProducer = appendTransactionalAsLeader(log, producerId, producerEpoch)
+
+    appendProducer(Seq(1))
+    appendProducer(Seq(2, 3))
+    log.appendAsLeader(abortMarker(producerId, producerEpoch), leaderEpoch = 0, isFromClient = false)
+    appendProducer(Seq(3))
+    log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, isFromClient = false)
+    log.roll()
+
+    // delete horizon set to 0 to verify marker is not removed early
+    val dirtyOffset = cleaner.doClean(LogToClean(tp, log, 0L, 100L), deleteHorizonMs = 0L)._1
+    assertEquals(List(3), keysInLog(log))
+    assertEquals(List(3, 4, 5), offsetsInLog(log))
+
+    // clean again with the delete horizon set back to the current time and verify the marker is removed
+    cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = time.milliseconds())
+    assertEquals(List(3), keysInLog(log))
+    assertEquals(List(4, 5), offsetsInLog(log))
+  }
+
   /**
    * Test log cleaning with logs containing messages larger than default message size
    */
@@ -174,24 +327,45 @@ class LogCleanerTest extends JUnitSuite {
   }
 
   @Test
-  def testLogCleanerRetainsLastWrittenRecordForEachPid(): Unit = {
+  def testLogCleanerRetainsProducerLastSequence(): Unit = {
     val cleaner = makeCleaner(10)
     val logProps = new Properties()
     logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer)
 
     val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps))
     log.appendAsLeader(record(0, 0), leaderEpoch = 0) // offset 0
-    log.appendAsLeader(record(0, 1, pid = 1, epoch = 0, sequence = 0), leaderEpoch = 0) // offset 1
-    log.appendAsLeader(record(0, 2, pid = 2, epoch = 0, sequence = 0), leaderEpoch = 0) // offset 2
-    log.appendAsLeader(record(0, 3, pid = 3, epoch = 0, sequence = 0), leaderEpoch = 0) // offset 3
-    log.appendAsLeader(record(1, 1, pid = 2, epoch = 0, sequence = 1), leaderEpoch = 0) // offset 4
+    log.appendAsLeader(record(0, 1, producerId = 1, producerEpoch = 0, sequence = 0), leaderEpoch = 0) // offset 1
+    log.appendAsLeader(record(0, 2, producerId = 2, producerEpoch = 0, sequence = 0), leaderEpoch = 0) // offset 2
+    log.appendAsLeader(record(0, 3, producerId = 3, producerEpoch = 0, sequence = 0), leaderEpoch = 0) // offset 3
+    log.appendAsLeader(record(1, 1, producerId = 2, producerEpoch = 0, sequence = 1), leaderEpoch = 0) // offset 4
 
     // roll the segment, so we can clean the messages already appended
     log.roll()
 
-    cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 2, log.activeSegment.baseOffset))
-    assertEquals(immutable.List(0, 0, 1), keysInLog(log))
-    assertEquals(immutable.List(1, 3, 4), offsetsInLog(log))
+    cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset))
+    assertEquals(List(0, 0, 1), keysInLog(log))
+    assertEquals(List(1, 3, 4), offsetsInLog(log))
+  }
+
+  @Test
+  def testLogCleanerRetainsLastSequenceEvenIfTransactionAborted(): Unit = {
+    val cleaner = makeCleaner(10)
+    val logProps = new Properties()
+    logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer)
+    val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps))
+
+    val producerEpoch = 0.toShort
+    val producerId = 1L
+    val appendProducer = appendTransactionalAsLeader(log, producerId, producerEpoch)
+
+    appendProducer(Seq(1))
+    appendProducer(Seq(2, 3))
+    log.appendAsLeader(abortMarker(producerId, producerEpoch), leaderEpoch = 0, isFromClient = false)
+    log.roll()
+
+    cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset))
+    assertEquals(List(3), keysInLog(log))
+    assertEquals(List(2, 3), offsetsInLog(log))
   }
 
   @Test
@@ -213,17 +387,17 @@ class LogCleanerTest extends JUnitSuite {
 
     // clean the log with only one message removed
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 2, log.activeSegment.baseOffset))
-    assertEquals(immutable.List(1,0,1,0), keysInLog(log))
-    assertEquals(immutable.List(1,2,3,4), offsetsInLog(log))
+    assertEquals(List(1,0,1,0), keysInLog(log))
+    assertEquals(List(1,2,3,4), offsetsInLog(log))
 
     // continue to make progress, even though we can only clean one message at a time
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 3, log.activeSegment.baseOffset))
-    assertEquals(immutable.List(0,1,0), keysInLog(log))
-    assertEquals(immutable.List(2,3,4), offsetsInLog(log))
+    assertEquals(List(0,1,0), keysInLog(log))
+    assertEquals(List(2,3,4), offsetsInLog(log))
 
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 4, log.activeSegment.baseOffset))
-    assertEquals(immutable.List(1,0), keysInLog(log))
-    assertEquals(immutable.List(3,4), offsetsInLog(log))
+    assertEquals(List(1,0), keysInLog(log))
+    assertEquals(List(3,4), offsetsInLog(log))
   }
 
   @Test
@@ -346,8 +520,12 @@ class LogCleanerTest extends JUnitSuite {
   }
 
   /* extract all the keys from a log */
-  def keysInLog(log: Log): Iterable[Int] =
-    log.logSegments.flatMap(s => s.log.records.asScala.filter(_.hasValue).filter(_.hasKey).map(record => TestUtils.readString(record.key).toInt))
+  def keysInLog(log: Log): Iterable[Int] = {
+    for (segment <- log.logSegments;
+         batch <- segment.log.batches.asScala if !batch.isControlBatch;
+         record <- batch.asScala if record.hasValue && record.hasKey)
+      yield TestUtils.readString(record.key).toInt
+  }
 
   /* extract all the offsets from a log */
   def offsetsInLog(log: Log): Iterable[Long] =
@@ -795,12 +973,12 @@ class LogCleanerTest extends JUnitSuite {
   private def messageWithOffset(key: Int, value: Int, offset: Long): MemoryRecords =
     messageWithOffset(key.toString.getBytes, value.toString.getBytes, offset)
 
-  def makeLog(dir: File = dir, config: LogConfig = logConfig) =
+  private def makeLog(dir: File = dir, config: LogConfig = logConfig) =
     new Log(dir = dir, config = config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler, time = time)
 
-  def noOpCheckDone(topicPartition: TopicPartition) { /* do nothing */  }
+  private def noOpCheckDone(topicPartition: TopicPartition) { /* do nothing */  }
 
-  def makeCleaner(capacity: Int, checkDone: TopicPartition => Unit = noOpCheckDone, maxMessageSize: Int = 64*1024) =
+  private def makeCleaner(capacity: Int, checkDone: TopicPartition => Unit = noOpCheckDone, maxMessageSize: Int = 64*1024) =
     new Cleaner(id = 0,
                 offsetMap = new FakeOffsetMap(capacity),
                 ioBufferSize = maxMessageSize,
@@ -810,28 +988,62 @@ class LogCleanerTest extends JUnitSuite {
                 time = time,
                 checkDone = checkDone)
 
-  def writeToLog(log: Log, seq: Iterable[(Int, Int)]): Iterable[Long] = {
+  private def writeToLog(log: Log, seq: Iterable[(Int, Int)]): Iterable[Long] = {
     for((key, value) <- seq)
       yield log.appendAsLeader(record(key, value), leaderEpoch = 0).firstOffset
   }
 
-  def key(id: Int) = ByteBuffer.wrap(id.toString.getBytes)
-
+  private def key(id: Int) = ByteBuffer.wrap(id.toString.getBytes)
 
-  def record(key: Int, value: Int, pid: Long = RecordBatch.NO_PRODUCER_ID, epoch: Short = RecordBatch.NO_PRODUCER_EPOCH,
+  private def record(key: Int, value: Int,
+             producerId: Long = RecordBatch.NO_PRODUCER_ID,
+             producerEpoch: Short = RecordBatch.NO_PRODUCER_EPOCH,
              sequence: Int = RecordBatch.NO_SEQUENCE,
              partitionLeaderEpoch: Int = RecordBatch.NO_PARTITION_LEADER_EPOCH): MemoryRecords = {
-    MemoryRecords.withIdempotentRecords(RecordBatch.CURRENT_MAGIC_VALUE, 0L, CompressionType.NONE, pid, epoch, sequence,
+    MemoryRecords.withIdempotentRecords(RecordBatch.CURRENT_MAGIC_VALUE, 0L, CompressionType.NONE, producerId, producerEpoch, sequence,
       partitionLeaderEpoch, new SimpleRecord(key.toString.getBytes, value.toString.getBytes))
   }
 
-  def record(key: Int, value: Array[Byte]): MemoryRecords =
+  private def transactionalRecords(records: Seq[SimpleRecord],
+                           producerId: Long,
+                           producerEpoch: Short,
+                           sequence: Int): MemoryRecords = {
+    MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence, records: _*)
+  }
+
+  private def appendTransactionalAsLeader(log: Log, producerId: Long, producerEpoch: Short = 0): Seq[Int] => Unit = {
+    var sequence = 0
+    keys: Seq[Int] => {
+      val simpleRecords = keys.map { key =>
+        val keyBytes = key.toString.getBytes
+        new SimpleRecord(keyBytes, keyBytes) // the value doesn't matter too much since we validate offsets
+      }
+      val records = transactionalRecords(simpleRecords, producerId, producerEpoch, sequence)
+      log.appendAsLeader(records, leaderEpoch = 0)
+      sequence += simpleRecords.size
+    }
+  }
+
+  private def commitMarker(producerId: Long, producerEpoch: Short, timestamp: Long = time.milliseconds()): MemoryRecords =
+    endTxnMarker(producerId, producerEpoch, ControlRecordType.COMMIT, 0L, timestamp)
+
+  private def abortMarker(producerId: Long, producerEpoch: Short, timestamp: Long = time.milliseconds()): MemoryRecords =
+    endTxnMarker(producerId, producerEpoch, ControlRecordType.ABORT, 0L, timestamp)
+
+  private def endTxnMarker(producerId: Long, producerEpoch: Short, controlRecordType: ControlRecordType,
+                   offset: Long, timestamp: Long): MemoryRecords = {
+    val endTxnMarker = new EndTransactionMarker(controlRecordType, 0)
+    MemoryRecords.withEndTransactionMarker(offset, timestamp, RecordBatch.NO_PARTITION_LEADER_EPOCH,
+      producerId, producerEpoch, endTxnMarker)
+  }
+
+  private def record(key: Int, value: Array[Byte]): MemoryRecords =
     TestUtils.singletonRecords(key = key.toString.getBytes, value = value)
 
-  def unkeyedRecord(value: Int): MemoryRecords =
+  private def unkeyedRecord(value: Int): MemoryRecords =
     TestUtils.singletonRecords(value = value.toString.getBytes)
 
-  def tombstoneRecord(key: Int): MemoryRecords = record(key, null)
+  private def tombstoneRecord(key: Int): MemoryRecords = record(key, null)
 
 }
 
@@ -842,12 +1054,12 @@ class FakeOffsetMap(val slots: Int) extends OffsetMap {
   private def keyFor(key: ByteBuffer) =
     new String(Utils.readBytes(key.duplicate), "UTF-8")
 
-  def put(key: ByteBuffer, offset: Long): Unit = {
+  override def put(key: ByteBuffer, offset: Long): Unit = {
     lastOffset = offset
     map.put(keyFor(key), offset)
   }
   
-  def get(key: ByteBuffer): Long = {
+  override def get(key: ByteBuffer): Long = {
     val k = keyFor(key)
     if(map.containsKey(k))
       map.get(k)
@@ -855,11 +1067,15 @@ class FakeOffsetMap(val slots: Int) extends OffsetMap {
       -1L
   }
   
-  def clear(): Unit = map.clear()
+  override def clear(): Unit = map.clear()
   
-  def size: Int = map.size
+  override def size: Int = map.size
+
+  override def latestOffset: Long = lastOffset
 
-  def latestOffset: Long = lastOffset
+  override def updateLatestOffset(offset: Long): Unit = {
+    lastOffset = offset
+  }
 
   override def toString: String = map.toString
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/7baa58d7/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
index 4709b77..c3da9b3 100644
--- a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
@@ -21,8 +21,7 @@ import java.io.File
 import kafka.utils.TestUtils
 import kafka.utils.TestUtils.checkEquals
 import org.apache.kafka.common.TopicPartition
-import org.apache.kafka.common.record.MemoryRecords.withEndTransactionMarker
-import org.apache.kafka.common.record.{RecordBatch, _}
+import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.{Time, Utils}
 import org.junit.Assert._
 import org.junit.{After, Before, Test}
@@ -273,7 +272,8 @@ class LogSegmentTest {
   @Test
   def testRecoverTransactionIndex(): Unit = {
     val segment = createSegment(100)
-    val epoch = 0.toShort
+    val producerEpoch = 0.toShort
+    val partitionLeaderEpoch = 15
     val sequence = 0
 
     val pid1 = 5L
@@ -282,25 +282,25 @@ class LogSegmentTest {
     // append transactional records from pid1
     segment.append(firstOffset = 100L, largestOffset = 101L, largestTimestamp = RecordBatch.NO_TIMESTAMP,
       shallowOffsetOfMaxTimestamp = 100L, MemoryRecords.withTransactionalRecords(100L, CompressionType.NONE,
-        pid1, epoch, sequence, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
+        pid1, producerEpoch, sequence, partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
 
     // append transactional records from pid2
     segment.append(firstOffset = 102L, largestOffset = 103L, largestTimestamp = RecordBatch.NO_TIMESTAMP,
       shallowOffsetOfMaxTimestamp = 102L, MemoryRecords.withTransactionalRecords(102L, CompressionType.NONE,
-        pid2, epoch, sequence, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
+        pid2, producerEpoch, sequence, partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
 
     // append non-transactional records
     segment.append(firstOffset = 104L, largestOffset = 105L, largestTimestamp = RecordBatch.NO_TIMESTAMP,
       shallowOffsetOfMaxTimestamp = 104L, MemoryRecords.withRecords(104L, CompressionType.NONE,
-        new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
+        partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)))
 
     // abort the transaction from pid2 (note LSO should be 100L since the txn from pid1 has not completed)
     segment.append(firstOffset = 106L, largestOffset = 106L, largestTimestamp = RecordBatch.NO_TIMESTAMP,
-      shallowOffsetOfMaxTimestamp = 106L,  endTxnRecords(ControlRecordType.ABORT, pid2, epoch, offset = 106L))
+      shallowOffsetOfMaxTimestamp = 106L, endTxnRecords(ControlRecordType.ABORT, pid2, producerEpoch, offset = 106L))
 
     // commit the transaction from pid1
     segment.append(firstOffset = 107L, largestOffset = 107L, largestTimestamp = RecordBatch.NO_TIMESTAMP,
-      shallowOffsetOfMaxTimestamp = 107L, endTxnRecords(ControlRecordType.COMMIT, pid1, epoch, offset = 107L))
+      shallowOffsetOfMaxTimestamp = 107L, endTxnRecords(ControlRecordType.COMMIT, pid1, producerEpoch, offset = 107L))
 
     segment.recover(64 * 1024, new ProducerStateManager(topicPartition, logDir))
 
@@ -314,7 +314,7 @@ class LogSegmentTest {
 
     // recover again, but this time assuming the transaction from pid2 began on a previous segment
     val stateManager = new ProducerStateManager(topicPartition, logDir)
-    stateManager.loadProducerEntry(ProducerIdEntry(pid2, epoch, 10, 90L, 5, RecordBatch.NO_TIMESTAMP, 0, Some(75L)))
+    stateManager.loadProducerEntry(ProducerIdEntry(pid2, producerEpoch, 10, 90L, 5, RecordBatch.NO_TIMESTAMP, 0, Some(75L)))
     segment.recover(64 * 1024, stateManager)
 
     abortedTxns = segment.txnIndex.allAbortedTxns
@@ -328,11 +328,13 @@ class LogSegmentTest {
 
   private def endTxnRecords(controlRecordType: ControlRecordType,
                             producerId: Long,
-                            epoch: Short,
+                            producerEpoch: Short,
                             offset: Long = 0L,
-                            coordinatorEpoch: Int = 0): MemoryRecords = {
+                            partitionLeaderEpoch: Int = 0,
+                            coordinatorEpoch: Int = 0,
+                            timestamp: Long = RecordBatch.NO_TIMESTAMP): MemoryRecords = {
     val marker = new EndTransactionMarker(controlRecordType, coordinatorEpoch)
-    withEndTransactionMarker(offset, producerId, epoch, marker)
+    MemoryRecords.withEndTransactionMarker(offset, timestamp, partitionLeaderEpoch, producerId, producerEpoch, marker)
   }
 
   /**