You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2021/02/19 00:45:56 UTC

[kafka] branch trunk updated: KAFKA-12331: Use LEO for the base offset of LeaderChangeMessage batch (#10138)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new e29f7a3  KAFKA-12331: Use LEO for the base offset of LeaderChangeMessage batch (#10138)
e29f7a3 is described below

commit e29f7a36dbbd316ae03008140a1a0d282a26b82d
Author: José Armando García Sancio <js...@users.noreply.github.com>
AuthorDate: Thu Feb 18 16:44:40 2021 -0800

    KAFKA-12331: Use LEO for the base offset of LeaderChangeMessage batch (#10138)
    
    The `KafkaMetadataLog` implementation of `ReplicatedLog` validates that batches appended using `appendAsLeader` and `appendAsFollower` have an offset that matches the LEO. This is enforced by `KafkaRaftClient` and `BatchAccumulator`. When creating control batches for the `LeaderChangeMessage` the default base offset of `0` was being used instead of using the LEO. This is fixed by:
    
    1. Changing the implementation for `MockLog` to validate against this and throw a `RuntimeException` if this invariant is violated.
    2. Always create a batch for `LeaderChangeMessage` with an offset equal to the LEO.
    
    Reviewers: Jason Gustafson <ja...@confluent.io>
---
 .../apache/kafka/common/record/MemoryRecords.java  | 15 +++--
 .../kafka/common/record/MemoryRecordsTest.java     | 11 +++-
 .../scala/kafka/raft/KafkaMetadataLogTest.scala    | 35 ++++++++++
 .../org/apache/kafka/raft/KafkaRaftClient.java     | 10 ++-
 .../java/org/apache/kafka/raft/ReplicatedLog.java  |  2 +
 .../kafka/raft/KafkaRaftClientSnapshotTest.java    | 27 ++++----
 .../org/apache/kafka/raft/KafkaRaftClientTest.java | 40 ++++++------
 .../test/java/org/apache/kafka/raft/MockLog.java   | 49 ++++++--------
 .../java/org/apache/kafka/raft/MockLogTest.java    | 74 ++++++++++++++++++----
 .../apache/kafka/raft/RaftClientTestContext.java   |  9 ++-
 10 files changed, 190 insertions(+), 82 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
index 7d14f67..ae844bf 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
@@ -637,12 +637,17 @@ public class MemoryRecords extends AbstractRecords {
         builder.close();
     }
 
-    public static MemoryRecords withLeaderChangeMessage(long timestamp, int leaderEpoch, LeaderChangeMessage leaderChangeMessage) {
+    public static MemoryRecords withLeaderChangeMessage(
+        long initialOffset,
+        long timestamp,
+        int leaderEpoch,
+        LeaderChangeMessage leaderChangeMessage
+    ) {
         // To avoid calling message toStruct multiple times, we supply a fixed message size
         // for leader change, as it happens rare and the buffer could still grow if not sufficient in
         // certain edge cases.
         ByteBuffer buffer = ByteBuffer.allocate(256);
-        writeLeaderChangeMessage(buffer, 0L, timestamp, leaderEpoch, leaderChangeMessage);
+        writeLeaderChangeMessage(buffer, initialOffset, timestamp, leaderEpoch, leaderChangeMessage);
         buffer.flip();
         return MemoryRecords.readableRecords(buffer);
     }
@@ -652,10 +657,12 @@ public class MemoryRecords extends AbstractRecords {
                                                  long timestamp,
                                                  int leaderEpoch,
                                                  LeaderChangeMessage leaderChangeMessage) {
-        MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
+        MemoryRecordsBuilder builder = new MemoryRecordsBuilder(
+            buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
             TimestampType.CREATE_TIME, initialOffset, timestamp,
             RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE,
-            false, true, leaderEpoch, buffer.capacity());
+            false, true, leaderEpoch, buffer.capacity()
+        );
         builder.appendLeaderChangeMessage(timestamp, leaderChangeMessage);
         builder.close();
     }
diff --git a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
index 26b1c48..be5b337 100644
--- a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
@@ -488,20 +488,25 @@ public class MemoryRecordsTest {
         final int leaderId = 5;
         final int leaderEpoch = 20;
         final int voterId = 6;
+        long initialOffset = 983L;
 
         LeaderChangeMessage leaderChangeMessage = new LeaderChangeMessage()
             .setLeaderId(leaderId)
             .setVoters(Collections.singletonList(
                 new Voter().setVoterId(voterId)));
-        MemoryRecords records = MemoryRecords.withLeaderChangeMessage(System.currentTimeMillis(),
-            leaderEpoch, leaderChangeMessage);
+        MemoryRecords records = MemoryRecords.withLeaderChangeMessage(
+            initialOffset,
+            System.currentTimeMillis(),
+            leaderEpoch,
+            leaderChangeMessage
+        );
 
         List<MutableRecordBatch> batches = TestUtils.toList(records.batches());
         assertEquals(1, batches.size());
 
         RecordBatch batch = batches.get(0);
         assertTrue(batch.isControlBatch());
-        assertEquals(0, batch.baseOffset());
+        assertEquals(initialOffset, batch.baseOffset());
         assertEquals(leaderEpoch, batch.partitionLeaderEpoch());
         assertTrue(batch.isValid());
 
diff --git a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
index 5229ae7..ce55a2b 100644
--- a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
+++ b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
@@ -64,6 +64,41 @@ final class KafkaMetadataLogTest {
   }
 
   @Test
+  def testUnexpectedAppendOffset(): Unit = {
+    val topicPartition = new TopicPartition("cluster-metadata", 0)
+    val log = buildMetadataLog(tempDir, mockTime, topicPartition)
+
+    val recordFoo = new SimpleRecord("foo".getBytes())
+    val currentEpoch = 3
+    val initialOffset = log.endOffset().offset
+
+    log.appendAsLeader(
+      MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo),
+      currentEpoch
+    )
+
+    // Throw exception for out of order records
+    assertThrows(
+      classOf[RuntimeException],
+      () => {
+        log.appendAsLeader(
+          MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo),
+          currentEpoch
+        )
+      }
+    )
+
+    assertThrows(
+      classOf[RuntimeException],
+      () => {
+        log.appendAsFollower(
+          MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo)
+        )
+      }
+    )
+  }
+
+  @Test
   def testCreateSnapshot(): Unit = {
     val topicPartition = new TopicPartition("cluster-metadata", 0)
     val numberOfRecords = 10
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 164a921..081651a 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -404,7 +404,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
         // The high watermark can only be advanced once we have written a record
         // from the new leader's epoch. Hence we write a control message immediately
         // to ensure there is no delay committing pending data.
-        appendLeaderChangeMessage(state, currentTimeMs);
+        appendLeaderChangeMessage(state, log.endOffset().offset, currentTimeMs);
         updateLeaderEndOffsetAndTimestamp(state, currentTimeMs);
 
         resetConnections();
@@ -429,7 +429,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
             .collect(Collectors.toList());
     }
 
-    private void appendLeaderChangeMessage(LeaderState state, long currentTimeMs) {
+    private void appendLeaderChangeMessage(LeaderState state, long baseOffset, long currentTimeMs) {
         List<Voter> voters = convertToVoters(state.followers());
         List<Voter> grantingVoters = convertToVoters(state.grantingVoters());
 
@@ -442,7 +442,11 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
             .setGrantingVoters(grantingVoters);
 
         MemoryRecords records = MemoryRecords.withLeaderChangeMessage(
-            currentTimeMs, quorum.epoch(), leaderChangeMessage);
+            baseOffset,
+            currentTimeMs,
+            quorum.epoch(),
+            leaderChangeMessage
+        );
 
         appendAsLeader(records);
         flushLeaderLog(state, currentTimeMs);
diff --git a/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java
index 417b769..6b4adce 100644
--- a/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java
+++ b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java
@@ -34,6 +34,7 @@ public interface ReplicatedLog extends Closeable {
      *
      * @return the metadata information of the appended batch
      * @throws IllegalArgumentException if the record set is empty
+     * @throws RuntimeException if the batch base offset doesn't match the log end offset
      */
     LogAppendInfo appendAsLeader(Records records, int epoch);
 
@@ -44,6 +45,7 @@ public interface ReplicatedLog extends Closeable {
      *
      * @return the metadata information of the appended batch
      * @throws IllegalArgumentException if the record set is empty
+     * @throws RuntimeException if the batch base offset doesn't match the log end offset
      */
     LogAppendInfo appendAsFollower(Records records);
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
index 614ff32..9ebb776 100644
--- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
@@ -106,7 +106,8 @@ final public class KafkaRaftClientSnapshotTest {
         OffsetAndEpoch oldestSnapshotId = new OffsetAndEpoch(3, 2);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .appendToLog(oldestSnapshotId.offset, oldestSnapshotId.epoch, Arrays.asList("a", "b", "c"))
+            .appendToLog(oldestSnapshotId.epoch, Arrays.asList("a", "b", "c"))
+            .appendToLog(oldestSnapshotId.epoch, Arrays.asList("d", "e", "f"))
             .withAppendLingerMs(1)
             .build();
 
@@ -115,10 +116,11 @@ final public class KafkaRaftClientSnapshotTest {
         assertEquals(oldestSnapshotId.epoch + 1, epoch);
 
         // Advance the highWatermark
-        context.deliverRequest(context.fetchRequest(epoch, otherNodeId, oldestSnapshotId.offset, oldestSnapshotId.epoch, 0));
+        long localLogEndOffset = context.log.endOffset().offset;
+        context.deliverRequest(context.fetchRequest(epoch, otherNodeId, localLogEndOffset, epoch, 0));
         context.pollUntilResponse();
         context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
-        assertEquals(oldestSnapshotId.offset, context.client.highWatermark().getAsLong());
+        assertEquals(localLogEndOffset, context.client.highWatermark().getAsLong());
 
         // Create a snapshot at the high watermark
         try (SnapshotWriter<String> snapshot = context.client.createSnapshot(oldestSnapshotId)) {
@@ -146,8 +148,8 @@ final public class KafkaRaftClientSnapshotTest {
         OffsetAndEpoch oldestSnapshotId = new OffsetAndEpoch(3, 2);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .appendToLog(oldestSnapshotId.offset, oldestSnapshotId.epoch, Arrays.asList("a", "b", "c"))
-            .appendToLog(oldestSnapshotId.offset + 3, oldestSnapshotId.epoch + 2, Arrays.asList("a", "b", "c"))
+            .appendToLog(oldestSnapshotId.epoch, Arrays.asList("a", "b", "c"))
+            .appendToLog(oldestSnapshotId.epoch + 2, Arrays.asList("d", "e", "f"))
             .withAppendLingerMs(1)
             .build();
 
@@ -192,8 +194,9 @@ final public class KafkaRaftClientSnapshotTest {
         OffsetAndEpoch oldestSnapshotId = new OffsetAndEpoch(3, 2);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .appendToLog(oldestSnapshotId.offset, oldestSnapshotId.epoch, Arrays.asList("a", "b", "c"))
-            .appendToLog(oldestSnapshotId.offset + 3, oldestSnapshotId.epoch + 2, Arrays.asList("a", "b", "c"))
+            .appendToLog(oldestSnapshotId.epoch, Arrays.asList("a", "b", "c"))
+            .appendToLog(oldestSnapshotId.epoch, Arrays.asList("d", "e", "f"))
+            .appendToLog(oldestSnapshotId.epoch + 2, Arrays.asList("g", "h", "i"))
             .withAppendLingerMs(1)
             .build();
 
@@ -233,8 +236,9 @@ final public class KafkaRaftClientSnapshotTest {
         OffsetAndEpoch oldestSnapshotId = new OffsetAndEpoch(3, 2);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .appendToLog(oldestSnapshotId.offset, oldestSnapshotId.epoch, Arrays.asList("a", "b", "c"))
-            .appendToLog(oldestSnapshotId.offset + 3, oldestSnapshotId.epoch + 2, Arrays.asList("a", "b", "c"))
+            .appendToLog(oldestSnapshotId.epoch, Arrays.asList("a", "b", "c"))
+            .appendToLog(oldestSnapshotId.epoch, Arrays.asList("d", "e", "f"))
+            .appendToLog(oldestSnapshotId.epoch + 2, Arrays.asList("g", "h", "i"))
             .withAppendLingerMs(1)
             .build();
 
@@ -279,8 +283,9 @@ final public class KafkaRaftClientSnapshotTest {
         OffsetAndEpoch oldestSnapshotId = new OffsetAndEpoch(3, 2);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .appendToLog(oldestSnapshotId.offset, oldestSnapshotId.epoch, Arrays.asList("a", "b", "c"))
-            .appendToLog(oldestSnapshotId.offset + 3, oldestSnapshotId.epoch + 2, Arrays.asList("a", "b", "c"))
+            .appendToLog(oldestSnapshotId.epoch, Arrays.asList("a", "b", "c"))
+            .appendToLog(oldestSnapshotId.epoch, Arrays.asList("d", "e", "f"))
+            .appendToLog(oldestSnapshotId.epoch + 2, Arrays.asList("g", "h", "i"))
             .withAppendLingerMs(1)
             .build();
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
index b29f1ef..fb188f1 100644
--- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
@@ -31,7 +31,6 @@ import org.apache.kafka.common.record.MutableRecordBatch;
 import org.apache.kafka.common.record.Record;
 import org.apache.kafka.common.record.RecordBatch;
 import org.apache.kafka.common.record.Records;
-import org.apache.kafka.common.record.SimpleRecord;
 import org.apache.kafka.common.requests.DescribeQuorumRequest;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.test.TestUtils;
@@ -226,7 +225,10 @@ public class KafkaRaftClientTest {
         context.client.poll();
 
         // append some record, but the fetch in purgatory will still fail
-        context.log.appendAsLeader(Collections.singleton(new SimpleRecord("raft".getBytes())), epoch);
+        context.log.appendAsLeader(
+            context.buildBatch(context.log.endOffset().offset, epoch, Arrays.asList("raft")),
+            epoch
+        );
 
         // when transition to resign, all request in fetchPurgatory will fail
         context.client.shutdown(1000);
@@ -449,7 +451,7 @@ public class KafkaRaftClientTest {
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
             .withUnknownLeader(epoch)
             .build();
-        
+
         context.deliverRequest(context.endEpochRequest(epoch, voter2,
             Arrays.asList(localId, voter3)));
 
@@ -945,7 +947,7 @@ public class KafkaRaftClientTest {
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
             .withElectedLeader(epoch, otherNodeId)
-            .appendToLog(0L, lastEpoch, singletonList("foo"))
+            .appendToLog(lastEpoch, singletonList("foo"))
             .build();
 
         context.assertElectedLeader(epoch, otherNodeId);
@@ -964,7 +966,7 @@ public class KafkaRaftClientTest {
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
             .withElectedLeader(epoch, otherNodeId)
-            .appendToLog(0L, lastEpoch, singletonList("foo"))
+            .appendToLog(lastEpoch, singletonList("foo"))
             .build();
         context.assertElectedLeader(epoch, otherNodeId);
 
@@ -1781,8 +1783,8 @@ public class KafkaRaftClientTest {
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
             .withElectedLeader(epoch, otherNodeId)
-            .appendToLog(0L, lastEpoch, Arrays.asList("foo", "bar"))
-            .appendToLog(2L, lastEpoch, Arrays.asList("baz"))
+            .appendToLog(lastEpoch, Arrays.asList("foo", "bar"))
+            .appendToLog(lastEpoch, Arrays.asList("baz"))
             .build();
 
         context.assertElectedLeader(epoch, otherNodeId);
@@ -1964,9 +1966,9 @@ public class KafkaRaftClientTest {
         List<String> batch3 = Arrays.asList("7", "8", "9");
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .appendToLog(0L, 1, batch1)
-            .appendToLog(3L, 1, batch2)
-            .appendToLog(6L, 2, batch3)
+            .appendToLog(1, batch1)
+            .appendToLog(1, batch2)
+            .appendToLog(2, batch3)
             .withUnknownLeader(epoch - 1)
             .build();
 
@@ -2016,9 +2018,9 @@ public class KafkaRaftClientTest {
         List<String> batch3 = Arrays.asList("7", "8", "9");
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .appendToLog(0L, 1, batch1)
-            .appendToLog(3L, 1, batch2)
-            .appendToLog(6L, 2, batch3)
+            .appendToLog(1, batch1)
+            .appendToLog(1, batch2)
+            .appendToLog(2, batch3)
             .withUnknownLeader(epoch - 1)
             .build();
 
@@ -2105,9 +2107,9 @@ public class KafkaRaftClientTest {
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .appendToLog(0L, 2, Arrays.asList("a", "b", "c"))
-            .appendToLog(3L, 4, Arrays.asList("d", "e", "f"))
-            .appendToLog(6L, 4, Arrays.asList("g", "h", "i"))
+            .appendToLog(2, Arrays.asList("a", "b", "c"))
+            .appendToLog(4, Arrays.asList("d", "e", "f"))
+            .appendToLog(4, Arrays.asList("g", "h", "i"))
             .withUnknownLeader(epoch - 1)
             .build();
 
@@ -2146,9 +2148,9 @@ public class KafkaRaftClientTest {
         Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
-            .appendToLog(0L, 2, Arrays.asList("a", "b", "c"))
-            .appendToLog(3L, 4, Arrays.asList("d", "e", "f"))
-            .appendToLog(6L, 4, Arrays.asList("g", "h", "i"))
+            .appendToLog(2, Arrays.asList("a", "b", "c"))
+            .appendToLog(4, Arrays.asList("d", "e", "f"))
+            .appendToLog(4, Arrays.asList("g", "h", "i"))
             .withUnknownLeader(epoch - 1)
             .build();
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLog.java b/raft/src/test/java/org/apache/kafka/raft/MockLog.java
index 3252d54..9d80465 100644
--- a/raft/src/test/java/org/apache/kafka/raft/MockLog.java
+++ b/raft/src/test/java/org/apache/kafka/raft/MockLog.java
@@ -35,7 +35,6 @@ import org.apache.kafka.snapshot.RawSnapshotWriter;
 
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
@@ -250,30 +249,7 @@ public class MockLog implements ReplicatedLog {
 
     @Override
     public LogAppendInfo appendAsLeader(Records records, int epoch) {
-        if (records.sizeInBytes() == 0)
-            throw new IllegalArgumentException("Attempt to append an empty record set");
-
-        long baseOffset = endOffset().offset;
-        AtomicLong offsetSupplier = new AtomicLong(baseOffset);
-        for (RecordBatch batch : records.batches()) {
-            List<LogEntry> entries = buildEntries(batch, record -> offsetSupplier.getAndIncrement());
-            appendBatch(new LogBatch(epoch, batch.isControlBatch(), entries));
-        }
-
-        return new LogAppendInfo(baseOffset, offsetSupplier.get() - 1);
-    }
-
-    LogAppendInfo appendAsLeader(Collection<SimpleRecord> records, int epoch) {
-        long baseOffset = endOffset().offset;
-        long offset = baseOffset;
-
-        List<LogEntry> entries = new ArrayList<>();
-        for (SimpleRecord record : records) {
-            entries.add(buildEntry(offset, record));
-            offset += 1;
-        }
-        appendBatch(new LogBatch(epoch, false, entries));
-        return new LogAppendInfo(baseOffset, offset - 1);
+        return append(records, OptionalInt.of(epoch));
     }
 
     private Long appendBatch(LogBatch batch) {
@@ -286,6 +262,10 @@ public class MockLog implements ReplicatedLog {
 
     @Override
     public LogAppendInfo appendAsFollower(Records records) {
+        return append(records, OptionalInt.empty());
+    }
+
+    private LogAppendInfo append(Records records, OptionalInt epoch) {
         if (records.sizeInBytes() == 0)
             throw new IllegalArgumentException("Attempt to append an empty record set");
 
@@ -293,13 +273,26 @@ public class MockLog implements ReplicatedLog {
         long lastOffset = baseOffset;
         for (RecordBatch batch : records.batches()) {
             if (batch.baseOffset() != endOffset().offset) {
-                throw new IllegalArgumentException(
-                    String.format("Illegal append at offset %s with current end offset of %", batch.baseOffset(), endOffset().offset)
+                /* KafkaMetadataLog throws an kafka.common.UnexpectedAppendOffsetException this is the
+                 * best we can do from this module.
+                 */
+                throw new RuntimeException(
+                    String.format(
+                        "Illegal append at offset %s with current end offset of %s",
+                        batch.baseOffset(),
+                        endOffset().offset
+                    )
                 );
             }
 
             List<LogEntry> entries = buildEntries(batch, Record::offset);
-            appendBatch(new LogBatch(batch.partitionLeaderEpoch(), batch.isControlBatch(), entries));
+            appendBatch(
+                new LogBatch(
+                    epoch.orElseGet(batch::partitionLeaderEpoch),
+                    batch.isControlBatch(),
+                    entries
+                )
+            );
             lastOffset = entries.get(entries.size() - 1).offset;
         }
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
index 95f098f..67231eb 100644
--- a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
@@ -36,6 +36,7 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
@@ -65,7 +66,7 @@ public class MockLogTest {
     public void testAppendAsLeaderHelper() {
         int epoch = 2;
         SimpleRecord recordOne = new SimpleRecord("one".getBytes());
-        log.appendAsLeader(Collections.singleton(recordOne), epoch);
+        appendAsLeader(Collections.singleton(recordOne), epoch);
         assertEquals(epoch, log.lastFetchedEpoch());
         assertEquals(0L, log.startOffset());
         assertEquals(1L, log.endOffset().offset);
@@ -84,7 +85,7 @@ public class MockLogTest {
 
         SimpleRecord recordTwo = new SimpleRecord("two".getBytes());
         SimpleRecord recordThree = new SimpleRecord("three".getBytes());
-        log.appendAsLeader(Arrays.asList(recordTwo, recordThree), epoch);
+        appendAsLeader(Arrays.asList(recordTwo, recordThree), epoch);
         assertEquals(0L, log.startOffset());
         assertEquals(3L, log.endOffset().offset);
 
@@ -109,10 +110,10 @@ public class MockLogTest {
         int epoch = 2;
         SimpleRecord recordOne = new SimpleRecord("one".getBytes());
         SimpleRecord recordTwo = new SimpleRecord("two".getBytes());
-        log.appendAsLeader(Arrays.asList(recordOne, recordTwo), epoch);
+        appendAsLeader(Arrays.asList(recordOne, recordTwo), epoch);
 
         SimpleRecord recordThree = new SimpleRecord("three".getBytes());
-        log.appendAsLeader(Collections.singleton(recordThree), epoch);
+        appendAsLeader(Collections.singleton(recordThree), epoch);
 
         assertEquals(0L, log.startOffset());
         assertEquals(3L, log.endOffset().offset);
@@ -150,11 +151,14 @@ public class MockLogTest {
 
     @Test
     public void testAppendAsLeader() {
-        // The record passed-in offsets are not going to affect the eventual offsets.
-        final long initialOffset = 5L;
         SimpleRecord recordFoo = new SimpleRecord("foo".getBytes());
         final int currentEpoch = 3;
-        log.appendAsLeader(MemoryRecords.withRecords(initialOffset, CompressionType.NONE, recordFoo), currentEpoch);
+        final long initialOffset = log.endOffset().offset;
+
+        log.appendAsLeader(
+            MemoryRecords.withRecords(initialOffset, CompressionType.NONE, recordFoo),
+            currentEpoch
+        );
 
         assertEquals(0, log.startOffset());
         assertEquals(1, log.endOffset().offset);
@@ -172,12 +176,46 @@ public class MockLogTest {
     }
 
     @Test
+    public void testUnexpectedAppendOffset() {
+        SimpleRecord recordFoo = new SimpleRecord("foo".getBytes());
+        final int currentEpoch = 3;
+        final long initialOffset = log.endOffset().offset;
+
+        log.appendAsLeader(
+            MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo),
+            currentEpoch
+        );
+
+        // Throw exception for out of order records
+        assertThrows(
+            RuntimeException.class,
+            () -> {
+                log.appendAsLeader(
+                    MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo),
+                    currentEpoch
+                );
+            }
+        );
+
+        assertThrows(
+            RuntimeException.class,
+            () -> {
+                log.appendAsFollower(
+                    MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo)
+                );
+            }
+        );
+    }
+
+    @Test
     public void testAppendControlRecord() {
-        final long initialOffset = 5L;
+        final long initialOffset = 0;
         final int currentEpoch = 3;
         LeaderChangeMessage messageData =  new LeaderChangeMessage().setLeaderId(0);
-        log.appendAsLeader(MemoryRecords.withLeaderChangeMessage(
-            initialOffset, 2, messageData), currentEpoch);
+        log.appendAsLeader(
+            MemoryRecords.withLeaderChangeMessage(initialOffset, 0L, 2, messageData),
+            currentEpoch
+        );
 
         assertEquals(0, log.startOffset());
         assertEquals(1, log.endOffset().offset);
@@ -239,7 +277,7 @@ public class MockLogTest {
         recordTwoBuffer.putInt(2);
         SimpleRecord recordTwo = new SimpleRecord(recordTwoBuffer);
 
-        log.appendAsLeader(Arrays.asList(recordOne, recordTwo), epoch);
+        appendAsLeader(Arrays.asList(recordOne, recordTwo), epoch);
 
         Records records = log.read(0, Isolation.UNCOMMITTED).records;
 
@@ -597,11 +635,23 @@ public class MockLogTest {
         }
     }
 
+    private void appendAsLeader(Collection<SimpleRecord> records, int epoch) {
+        log.appendAsLeader(
+            MemoryRecords.withRecords(
+                log.endOffset().offset,
+                CompressionType.NONE,
+                records.toArray(new SimpleRecord[records.size()])
+            ),
+            epoch
+        );
+    }
+
     private void appendBatch(int numRecords, int epoch) {
         List<SimpleRecord> records = new ArrayList<>(numRecords);
         for (int i = 0; i < numRecords; i++) {
             records.add(new SimpleRecord(String.valueOf(i).getBytes()));
         }
-        log.appendAsLeader(records, epoch);
+
+        appendAsLeader(records, epoch);
     }
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
index 9d19b86..12fda05 100644
--- a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
+++ b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
@@ -171,8 +171,13 @@ public final class RaftClientTestContext {
             return this;
         }
 
-        Builder appendToLog(long baseOffset, int epoch, List<String> records) {
-            MemoryRecords batch = buildBatch(time.milliseconds(), baseOffset, epoch, records);
+        Builder appendToLog(int epoch, List<String> records) {
+            MemoryRecords batch = buildBatch(
+                time.milliseconds(),
+                log.endOffset().offset,
+                epoch,
+                records
+            );
             log.appendAsLeader(batch, epoch);
             return this;
         }