You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2021/02/19 03:47:52 UTC

[kafka] branch trunk updated: KAFKA-12258; Add support for splitting appending records (#10063)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 9243c10  KAFKA-12258; Add support for splitting appending records (#10063)
9243c10 is described below

commit 9243c10161eb10631353f34a821a8dfb8cab51ed
Author: José Armando García Sancio <js...@users.noreply.github.com>
AuthorDate: Thu Feb 18 19:46:23 2021 -0800

    KAFKA-12258; Add support for splitting appending records (#10063)
    
    1. Type `BatchAccumulator`. Add support for appending records into one or more batches.
    2. Type `RaftClient`. Rename `scheduleAppend` to `scheduleAtomicAppend`.
    3. Type `RaftClient`. Add a new method `scheduleAppend` which appends records to the log using as many batches as necessary.
    4. Increase the batch size from 1MB to 8MB.
    
    Reviewers: David Arthur <mu...@gmail.com>, Jason Gustafson <ja...@confluent.io>
---
 .../org/apache/kafka/raft/KafkaRaftClient.java     |  25 +++-
 .../java/org/apache/kafka/raft/RaftClient.java     |  48 +++++--
 .../kafka/raft/internals/BatchAccumulator.java     | 124 +++++++++++-----
 .../apache/kafka/raft/internals/BatchBuilder.java  | 102 +++++++++----
 .../kafka/raft/internals/BatchAccumulatorTest.java | 159 ++++++++++++++++-----
 .../kafka/raft/internals/BatchBuilderTest.java     |   9 +-
 6 files changed, 339 insertions(+), 128 deletions(-)

diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 081651a..9fbbe31 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -141,7 +141,7 @@ import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition;
 public class KafkaRaftClient<T> implements RaftClient<T> {
     private static final int RETRY_BACKOFF_BASE_MS = 100;
     private static final int FETCH_MAX_WAIT_MS = 1000;
-    static final int MAX_BATCH_SIZE = 1024 * 1024;
+    static final int MAX_BATCH_SIZE = 8 * 1024 * 1024;
 
     private final AtomicReference<GracefulShutdown> shutdown = new AtomicReference<>();
     private final Logger logger;
@@ -2188,13 +2188,27 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
 
     @Override
     public Long scheduleAppend(int epoch, List<T> records) {
+        return append(epoch, records, false);
+    }
+
+    @Override
+    public Long scheduleAtomicAppend(int epoch, List<T> records) {
+        return append(epoch, records, true);
+    }
+
+    private Long append(int epoch, List<T> records, boolean isAtomic) {
         BatchAccumulator<T> accumulator = this.accumulator;
         if (accumulator == null) {
             return Long.MAX_VALUE;
         }
 
         boolean isFirstAppend = accumulator.isEmpty();
-        Long offset = accumulator.append(epoch, records);
+        final Long offset;
+        if (isAtomic) {
+            offset = accumulator.appendAtomic(epoch, records);
+        } else {
+            offset = accumulator.append(epoch, records);
+        }
 
         // Wakeup the network channel if either this is the first append
         // or the accumulator is ready to drain now. Checking for the first
@@ -2351,9 +2365,10 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
 
         /**
          * This API is used for committed records originating from {@link #scheduleAppend(int, List)}
-         * on this instance. In this case, we are able to save the original record objects,
-         * which saves the need to read them back from disk. This is a nice optimization
-         * for the leader which is typically doing more work than all of the followers.
+         * or {@link #scheduleAtomicAppend(int, List)} on this instance. In this case, we are able to
+         * save the original record objects, which saves the need to read them back from disk. This is
+         * a nice optimization for the leader which is typically doing more work than all of the
+         * followers.
          */
         public void fireHandleCommit(long baseOffset, int epoch, List<T> records) {
             BatchReader.Batch<T> batch = new BatchReader.Batch<>(baseOffset, epoch, records);
diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftClient.java b/raft/src/main/java/org/apache/kafka/raft/RaftClient.java
index e2bec0e..74488b4 100644
--- a/raft/src/main/java/org/apache/kafka/raft/RaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/RaftClient.java
@@ -32,11 +32,13 @@ public interface RaftClient<T> extends Closeable {
          * after consuming the reader.
          *
          * Note that there is not a one-to-one correspondence between writes through
-         * {@link #scheduleAppend(int, List)} and this callback. The Raft implementation
-         * is free to batch together the records from multiple append calls provided
-         * that batch boundaries are respected. This means that each batch specified
-         * through {@link #scheduleAppend(int, List)} is guaranteed to be a subset of
-         * a batch provided by the {@link BatchReader}.
+         * {@link #scheduleAppend(int, List)} or {@link #scheduleAtomicAppend(int, List)}
+         * and this callback. The Raft implementation is free to batch together the records
+         * from multiple append calls provided that batch boundaries are respected. Records
+         * specified through {@link #scheduleAtomicAppend(int, List)} are guaranteed to be a
+         * subset of a batch provided by the {@link BatchReader}. Records specified through
+         * {@link #scheduleAppend(int, List)} are guaranteed to be in the same order but
+         * they can map to any number of batches provided by the {@link BatchReader}.
          *
          * @param reader reader instance which must be iterated and closed
          */
@@ -48,7 +50,7 @@ public interface RaftClient<T> extends Closeable {
          * {@link #handleCommit(BatchReader)}.
          *
          * After becoming a leader, the client is eligible to write to the log
-         * using {@link #scheduleAppend(int, List)}.
+         * using {@link #scheduleAppend(int, List)} or {@link #scheduleAtomicAppend(int, List)}.
          *
          * @param epoch the claimed leader epoch
          */
@@ -87,6 +89,30 @@ public interface RaftClient<T> extends Closeable {
     /**
      * Append a list of records to the log. The write will be scheduled for some time
      * in the future. There is no guarantee that appended records will be written to
+     * the log and eventually committed. While the order of the records is preserved, they can
+     * be appended to the log using one or more batches. Each record may be committed independently.
+     * If a record is committed, then all records scheduled for append during this epoch
+     * and prior to this record are also committed.
+     *
+     * If the provided current leader epoch does not match the current epoch, which
+     * is possible when the state machine has yet to observe the epoch change, then
+     * this method will return {@link Long#MAX_VALUE} to indicate an offset which is
+     * not possible to become committed. The state machine is expected to discard all
+     * uncommitted entries after observing an epoch change.
+     *
+     * @param epoch the current leader epoch
+     * @param records the list of records to append
+     * @return the expected offset of the last record; {@link Long#MAX_VALUE} if the records could
+     *         not be committed; null if no memory could be allocated for a batch (some records
+     *         may already have been appended)
+     * @throws RecordBatchTooLargeException if the size of any one record is greater than the
+     *         maximum batch size; if thrown, some of the records may already have been committed
+     */
+    Long scheduleAppend(int epoch, List<T> records);
+
+    /**
+     * Append a list of records to the log. The write will be scheduled for some time
+     * in the future. There is no guarantee that appended records will be written to
      * the log and eventually committed. However, it is guaranteed that if any of the
      * records become committed, then all of them will be.
      *
@@ -98,11 +124,13 @@ public interface RaftClient<T> extends Closeable {
      *
      * @param epoch the current leader epoch
      * @param records the list of records to append
-     * @return the offset within the current epoch that the log entries will be appended,
-     *         or null if the leader was unable to accept the write (e.g. due to memory
-     *         being reached).
+     * @return the expected offset of the last record; {@link Long#MAX_VALUE} if the records could
+     *         not be committed; null if no memory could be allocated for the batch at this time
+     * @throws RecordBatchTooLargeException if the size of the records is greater than the maximum
+     *         batch size; if this exception is thrown, none of the elements in records were
+     *         committed
      */
-    Long scheduleAppend(int epoch, List<T> records);
+    Long scheduleAtomicAppend(int epoch, List<T> records);
 
     /**
      * Attempt a graceful shutdown of the client. This allows the leader to proactively
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java b/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java
index 5331e4d..07d1015 100644
--- a/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.raft.internals;
 
+import org.apache.kafka.common.errors.RecordBatchTooLargeException;
 import org.apache.kafka.common.memory.MemoryPool;
 import org.apache.kafka.common.protocol.ObjectSerializationCache;
 import org.apache.kafka.common.record.CompressionType;
@@ -26,8 +27,10 @@ import org.apache.kafka.raft.RecordSerde;
 import java.io.Closeable;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
+import java.util.OptionalInt;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.locks.ReentrantLock;
@@ -79,70 +82,111 @@ public class BatchAccumulator<T> implements Closeable {
     }
 
     /**
-     * Append a list of records into an atomic batch. We guarantee all records
-     * are included in the same underlying record batch so that either all of
-     * the records become committed or none of them do.
+     * Append a list of records into as many batches as necessary.
      *
-     * @param epoch the expected leader epoch. If this does not match, then
-     *              {@link Long#MAX_VALUE} will be returned as an offset which
-     *              cannot become committed.
-     * @param records the list of records to include in a batch
-     * @return the expected offset of the last record (which will be
-     *         {@link Long#MAX_VALUE} if the epoch does not match), or null if
-     *         no memory could be allocated for the batch at this time
+     * The order of the elements in the records argument will match the order in the batches.
+     * This method will use as many batches as necessary to serialize all of the records. Since
+     * this method can split the records into multiple batches it is possible that some of the
+     * records will get committed while others will not when the leader fails.
+     *
+     * @param epoch the expected leader epoch. If this does not match, then {@link Long#MAX_VALUE}
+     *              will be returned as an offset which cannot become committed
+     * @param records the list of records to include in the batches
+     * @return the expected offset of the last record; {@link Long#MAX_VALUE} if the epoch does not
+     *         match; null if no memory could be allocated for a batch (a prefix of records may
+     *         already have been appended)
+     * @throws RecordBatchTooLargeException if the size of one record T is greater than the maximum
+     *         batch size; if thrown, some of the elements in records may have been committed
      */
     public Long append(int epoch, List<T> records) {
+        return append(epoch, records, false);
+    }
+
+    /**
+     * Append a list of records into an atomic batch. We guarantee all records are included in the
+     * same underlying record batch so that either all of the records become committed or none of
+     * them do.
+     *
+     * @param epoch the expected leader epoch. If this does not match, then {@link Long#MAX_VALUE}
+     *              will be returned as an offset which cannot become committed
+     * @param records the list of records to include in a batch
+     * @return the expected offset of the last record; {@link Long#MAX_VALUE} if the epoch does not
+     *         match; null if no memory could be allocated for the batch at this time
+     * @throws RecordBatchTooLargeException if the size of the records is greater than the maximum
+     *         batch size; if this exception is thrown, none of the elements in records were
+     *         committed
+     */
+    public Long appendAtomic(int epoch, List<T> records) {
+        return append(epoch, records, true);
+    }
+
+    private Long append(int epoch, List<T> records, boolean isAtomic) {
         if (epoch != this.epoch) {
-            // If the epoch does not match, then the state machine probably
-            // has not gotten the notification about the latest epoch change.
-            // In this case, ignore the append and return a large offset value
-            // which will never be committed.
             return Long.MAX_VALUE;
         }
 
         ObjectSerializationCache serializationCache = new ObjectSerializationCache();
-        int batchSize = 0;
-        for (T record : records) {
-            batchSize += serde.recordSize(record, serializationCache);
-        }
-
-        if (batchSize > maxBatchSize) {
-            throw new IllegalArgumentException("The total size of " + records + " is " + batchSize +
-                ", which exceeds the maximum allowed batch size of " + maxBatchSize);
-        }
 
         appendLock.lock();
         try {
             maybeCompleteDrain();
 
-            BatchBuilder<T> batch = maybeAllocateBatch(batchSize);
-            if (batch == null) {
-                return null;
-            }
-
-            // Restart the linger timer if necessary
-            if (!lingerTimer.isRunning()) {
-                lingerTimer.reset(time.milliseconds() + lingerMs);
+            BatchBuilder<T> batch = null;
+            if (isAtomic) {
+                batch = maybeAllocateBatch(records, serializationCache);
             }
 
             for (T record : records) {
+                if (!isAtomic) {
+                    batch = maybeAllocateBatch(Collections.singleton(record), serializationCache);
+                }
+
+                if (batch == null) {
+                    return null;
+                }
+
                 batch.appendRecord(record, serializationCache);
                 nextOffset += 1;
             }
 
+            maybeResetLinger();
+
             return nextOffset - 1;
         } finally {
             appendLock.unlock();
         }
     }
 
-    private BatchBuilder<T> maybeAllocateBatch(int batchSize) {
+    private void maybeResetLinger() {
+        if (!lingerTimer.isRunning()) {
+            lingerTimer.reset(time.milliseconds() + lingerMs);
+        }
+    }
+
+    private BatchBuilder<T> maybeAllocateBatch(
+        Collection<T> records,
+        ObjectSerializationCache serializationCache
+    ) {
         if (currentBatch == null) {
             startNewBatch();
-        } else if (!currentBatch.hasRoomFor(batchSize)) {
-            completeCurrentBatch();
-            startNewBatch();
         }
+
+        if (currentBatch != null) {
+            OptionalInt bytesNeeded = currentBatch.bytesNeeded(records, serializationCache);
+            if (bytesNeeded.isPresent() && bytesNeeded.getAsInt() > maxBatchSize) {
+                throw new RecordBatchTooLargeException(
+                    String.format(
+                        "The total record(s) size of %s exceeds the maximum allowed batch size of %s",
+                        bytesNeeded.getAsInt(),
+                        maxBatchSize
+                    )
+                );
+            } else if (bytesNeeded.isPresent()) {
+                completeCurrentBatch();
+                startNewBatch();
+            }
+        }
+
         return currentBatch;
     }
 
@@ -298,20 +342,22 @@ public class BatchAccumulator<T> implements Closeable {
         public final List<T> records;
         public final MemoryRecords data;
         private final MemoryPool pool;
-        private final ByteBuffer buffer;
+        // Buffer that was allocated by the MemoryPool (pool). This may not be the buffer used in
+        // the MemoryRecords (data) object.
+        private final ByteBuffer initialBuffer;
 
         private CompletedBatch(
             long baseOffset,
             List<T> records,
             MemoryRecords data,
             MemoryPool pool,
-            ByteBuffer buffer
+            ByteBuffer initialBuffer
         ) {
             this.baseOffset = baseOffset;
             this.records = records;
             this.data = data;
             this.pool = pool;
-            this.buffer = buffer;
+            this.initialBuffer = initialBuffer;
         }
 
         public int sizeInBytes() {
@@ -319,7 +365,7 @@ public class BatchAccumulator<T> implements Closeable {
         }
 
         public void release() {
-            pool.release(buffer);
+            pool.release(initialBuffer);
         }
     }
 
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BatchBuilder.java b/raft/src/main/java/org/apache/kafka/raft/internals/BatchBuilder.java
index 542bb51..c953b6a 100644
--- a/raft/src/main/java/org/apache/kafka/raft/internals/BatchBuilder.java
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/BatchBuilder.java
@@ -33,12 +33,14 @@ import org.apache.kafka.raft.RecordSerde;
 import java.io.DataOutputStream;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
+import java.util.OptionalInt;
 
 /**
  * Collect a set of records into a single batch. New records are added
  * through {@link #appendRecord(Object, ObjectSerializationCache)}, but the caller must first
- * check whether there is room using {@link #hasRoomFor(int)}. Once the
+ * check whether there is room using {@link #bytesNeeded(Collection, ObjectSerializationCache)}. Once the
  * batch is ready, then {@link #build()} should be used to get the resulting
  * {@link MemoryRecords} instance.
  *
@@ -85,8 +87,8 @@ public class BatchBuilder<T> {
         this.maxBytes = maxBytes;
         this.records = new ArrayList<>();
 
-        int batchHeaderSizeInBytes = AbstractRecords.recordBatchHeaderSizeInBytes(
-            RecordBatch.MAGIC_VALUE_V2, compressionType);
+        // field compressionType must be set before calculating the batch header size
+        int batchHeaderSizeInBytes = batchHeaderSizeInBytes();
         batchOutput.position(initialPosition + batchHeaderSizeInBytes);
 
         this.recordOutput = new DataOutputStreamWritable(new DataOutputStream(
@@ -95,7 +97,7 @@ public class BatchBuilder<T> {
 
     /**
      * Append a record to this patch. The caller must first verify there is room for the batch
-     * using {@link #hasRoomFor(int)}.
+     * using {@link #bytesNeeded(Collection, ObjectSerializationCache)}.
      *
      * @param record the record to append
      * @param serializationCache serialization cache for use in {@link RecordSerde#write(Object, ObjectSerializationCache, Writable)}
@@ -103,7 +105,7 @@ public class BatchBuilder<T> {
      */
     public long appendRecord(T record, ObjectSerializationCache serializationCache) {
         if (!isOpenForAppends) {
-            throw new IllegalArgumentException("Cannot append new records after the batch has been built");
+            throw new IllegalStateException("Cannot append new records after the batch has been built");
         }
 
         if (nextOffset - baseOffset > Integer.MAX_VALUE) {
@@ -123,39 +125,39 @@ public class BatchBuilder<T> {
     }
 
     /**
-     * Check whether the batch has enough room for a record of the given size in bytes.
+     * Check whether the batch has enough room for all the record values.
      *
-     * @param sizeInBytes the size of the record to be appended
-     * @return true if there is room for the record to be appended, false otherwise
+     * Returns an empty {@link OptionalInt} if the batch builder has room for this list of records.
+     * Otherwise it returns the expected number of bytes needed for a batch to contain these records.
+     *
+     * @param records the records to use when checking for room
+     * @param serializationCache serialization cache for computing sizes
+     * @return empty {@link OptionalInt} if there is room for the records to be appended, otherwise
+     *         returns the number of bytes needed
      */
-    public boolean hasRoomFor(int sizeInBytes) {
-        if (!isOpenForAppends) {
-            return false;
-        }
+    public OptionalInt bytesNeeded(Collection<T> records, ObjectSerializationCache serializationCache) {
+        int bytesNeeded = bytesNeededForRecords(
+            records,
+            serializationCache
+        );
 
-        if (nextOffset - baseOffset >= Integer.MAX_VALUE) {
-            return false;
+        if (!isOpenForAppends) {
+            return OptionalInt.of(batchHeaderSizeInBytes() + bytesNeeded);
         }
 
-        int recordSizeInBytes = DefaultRecord.sizeOfBodyInBytes(
-            (int) (nextOffset - baseOffset),
-            0,
-            -1,
-            sizeInBytes,
-            DefaultRecord.EMPTY_HEADERS
-        );
-
-        int unusedSizeInBytes = maxBytes - approximateSizeInBytes();
-        if (unusedSizeInBytes >= recordSizeInBytes) {
-            return true;
+        int approxUnusedSizeInBytes = maxBytes - approximateSizeInBytes();
+        if (approxUnusedSizeInBytes >= bytesNeeded) {
+            return OptionalInt.empty();
         } else if (unflushedBytes > 0) {
             recordOutput.flush();
             unflushedBytes = 0;
-            unusedSizeInBytes = maxBytes - flushedSizeInBytes();
-            return unusedSizeInBytes >= recordSizeInBytes;
-        } else {
-            return false;
+            int unusedSizeInBytes = maxBytes - flushedSizeInBytes();
+            if (unusedSizeInBytes >= bytesNeeded) {
+                return OptionalInt.empty();
+            }
         }
+
+        return OptionalInt.of(batchHeaderSizeInBytes() + bytesNeeded);
     }
 
     private int flushedSizeInBytes() {
@@ -307,4 +309,46 @@ public class BatchBuilder<T> {
         recordOutput.writeVarint(0);
         return ByteUtils.sizeOfVarint(sizeInBytes) + sizeInBytes;
     }
+
+    private int batchHeaderSizeInBytes() {
+        return AbstractRecords.recordBatchHeaderSizeInBytes(
+            RecordBatch.MAGIC_VALUE_V2,
+            compressionType
+        );
+    }
+
+    private int bytesNeededForRecords(
+        Collection<T> records,
+        ObjectSerializationCache serializationCache
+    ) {
+        long expectedNextOffset = nextOffset;
+        int bytesNeeded = 0;
+        for (T record : records) {
+            if (expectedNextOffset - baseOffset >= Integer.MAX_VALUE) {
+                throw new IllegalArgumentException(
+                    String.format(
+                        "Adding %s records to a batch with base offset of %s and next offset of %s",
+                        records.size(),
+                        baseOffset,
+                        expectedNextOffset
+                    )
+                );
+            }
+
+            int recordSizeInBytes = DefaultRecord.sizeOfBodyInBytes(
+                (int) (expectedNextOffset  - baseOffset),
+                0,
+                -1,
+                serde.recordSize(record, serializationCache),
+                DefaultRecord.EMPTY_HEADERS
+            );
+
+            bytesNeeded = Math.addExact(bytesNeeded, ByteUtils.sizeOfVarint(recordSizeInBytes));
+            bytesNeeded = Math.addExact(bytesNeeded, recordSizeInBytes);
+
+            expectedNextOffset += 1;
+        }
+
+        return bytesNeeded;
+    }
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BatchAccumulatorTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/BatchAccumulatorTest.java
index 24e289d..b32168e 100644
--- a/raft/src/test/java/org/apache/kafka/raft/internals/BatchAccumulatorTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/BatchAccumulatorTest.java
@@ -19,7 +19,11 @@ package org.apache.kafka.raft.internals;
 import org.apache.kafka.common.memory.MemoryPool;
 import org.apache.kafka.common.protocol.ObjectSerializationCache;
 import org.apache.kafka.common.protocol.Writable;
+import org.apache.kafka.common.record.AbstractRecords;
 import org.apache.kafka.common.record.CompressionType;
+import org.apache.kafka.common.record.DefaultRecord;
+import org.apache.kafka.common.record.RecordBatch;
+import org.apache.kafka.common.utils.ByteUtils;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
 import org.junit.jupiter.api.Test;
@@ -29,6 +33,8 @@ import java.nio.ByteBuffer;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.CountDownLatch;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
@@ -164,47 +170,85 @@ class BatchAccumulatorTest {
 
     @Test
     public void testSingleBatchAccumulation() {
-        int leaderEpoch = 17;
-        long baseOffset = 157;
-        int lingerMs = 50;
-        int maxBatchSize = 512;
-
-        Mockito.when(memoryPool.tryAllocate(maxBatchSize))
-            .thenReturn(ByteBuffer.allocate(maxBatchSize));
-
-        BatchAccumulator<String> acc = buildAccumulator(
-            leaderEpoch,
-            baseOffset,
-            lingerMs,
-            maxBatchSize
-        );
-
-        List<String> records = asList("a", "b", "c", "d", "e", "f", "g", "h", "i");
-        assertEquals(baseOffset, acc.append(leaderEpoch, records.subList(0, 1)));
-        assertEquals(baseOffset + 2, acc.append(leaderEpoch, records.subList(1, 3)));
-        assertEquals(baseOffset + 5, acc.append(leaderEpoch, records.subList(3, 6)));
-        assertEquals(baseOffset + 7, acc.append(leaderEpoch, records.subList(6, 8)));
-        assertEquals(baseOffset + 8, acc.append(leaderEpoch, records.subList(8, 9)));
-
-        time.sleep(lingerMs);
-        assertTrue(acc.needsDrain(time.milliseconds()));
-
-        List<BatchAccumulator.CompletedBatch<String>> batches = acc.drain();
-        assertEquals(1, batches.size());
-        assertFalse(acc.needsDrain(time.milliseconds()));
-        assertEquals(Long.MAX_VALUE - time.milliseconds(), acc.timeUntilDrain(time.milliseconds()));
-
-        BatchAccumulator.CompletedBatch<String> batch = batches.get(0);
-        assertEquals(records, batch.records);
-        assertEquals(baseOffset, batch.baseOffset);
+        asList(APPEND, APPEND_ATOMIC).forEach(appender -> {
+            int leaderEpoch = 17;
+            long baseOffset = 157;
+            int lingerMs = 50;
+            int maxBatchSize = 512;
+
+            Mockito.when(memoryPool.tryAllocate(maxBatchSize))
+                .thenReturn(ByteBuffer.allocate(maxBatchSize));
+
+            BatchAccumulator<String> acc = buildAccumulator(
+                leaderEpoch,
+                baseOffset,
+                lingerMs,
+                maxBatchSize
+            );
+
+            List<String> records = asList("a", "b", "c", "d", "e", "f", "g", "h", "i");
+            assertEquals(baseOffset, appender.call(acc, leaderEpoch, records.subList(0, 1)));
+            assertEquals(baseOffset + 2, appender.call(acc, leaderEpoch, records.subList(1, 3)));
+            assertEquals(baseOffset + 5, appender.call(acc, leaderEpoch, records.subList(3, 6)));
+            assertEquals(baseOffset + 7, appender.call(acc, leaderEpoch, records.subList(6, 8)));
+            assertEquals(baseOffset + 8, appender.call(acc, leaderEpoch, records.subList(8, 9)));
+
+            time.sleep(lingerMs);
+            assertTrue(acc.needsDrain(time.milliseconds()));
+
+            List<BatchAccumulator.CompletedBatch<String>> batches = acc.drain();
+            assertEquals(1, batches.size());
+            assertFalse(acc.needsDrain(time.milliseconds()));
+            assertEquals(Long.MAX_VALUE - time.milliseconds(), acc.timeUntilDrain(time.milliseconds()));
+
+            BatchAccumulator.CompletedBatch<String> batch = batches.get(0);
+            assertEquals(records, batch.records);
+            assertEquals(baseOffset, batch.baseOffset);
+        });
     }
 
     @Test
     public void testMultipleBatchAccumulation() {
+        asList(APPEND, APPEND_ATOMIC).forEach(appender -> {
+            int leaderEpoch = 17;
+            long baseOffset = 157;
+            int lingerMs = 50;
+            int maxBatchSize = 256;
+
+            Mockito.when(memoryPool.tryAllocate(maxBatchSize))
+                .thenReturn(ByteBuffer.allocate(maxBatchSize));
+
+            BatchAccumulator<String> acc = buildAccumulator(
+                leaderEpoch,
+                baseOffset,
+                lingerMs,
+                maxBatchSize
+            );
+
+            // Append entries until we have 4 batches to drain (3 completed, 1 building)
+            while (acc.numCompletedBatches() < 3) {
+                appender.call(acc, leaderEpoch, singletonList("foo"));
+            }
+
+            List<BatchAccumulator.CompletedBatch<String>> batches = acc.drain();
+            assertEquals(4, batches.size());
+            assertTrue(batches.stream().allMatch(batch -> batch.data.sizeInBytes() <= maxBatchSize));
+        });
+    }
+
+    @Test
+    public void testRecordsAreSplit() {
         int leaderEpoch = 17;
         long baseOffset = 157;
         int lingerMs = 50;
-        int maxBatchSize = 256;
+        String record = "a";
+        int numberOfRecords = 9;
+        int recordsPerBatch = 2;
+        int batchHeaderSize = AbstractRecords.recordBatchHeaderSizeInBytes(
+            RecordBatch.MAGIC_VALUE_V2,
+            CompressionType.NONE
+        );
+        int maxBatchSize = batchHeaderSize + recordsPerBatch * recordSizeInBytes(record, recordsPerBatch);
 
         Mockito.when(memoryPool.tryAllocate(maxBatchSize))
             .thenReturn(ByteBuffer.allocate(maxBatchSize));
@@ -216,13 +260,19 @@ class BatchAccumulatorTest {
             maxBatchSize
         );
 
-        // Append entries until we have 4 batches to drain (3 completed, 1 building)
-        while (acc.numCompletedBatches() < 3) {
-            acc.append(leaderEpoch, singletonList("foo"));
-        }
+        List<String> records = Stream
+            .generate(() -> record)
+            .limit(numberOfRecords)
+            .collect(Collectors.toList());
+        assertEquals(baseOffset + numberOfRecords - 1, acc.append(leaderEpoch, records));
+
+        time.sleep(lingerMs);
+        assertTrue(acc.needsDrain(time.milliseconds()));
 
         List<BatchAccumulator.CompletedBatch<String>> batches = acc.drain();
-        assertEquals(4, batches.size());
+        // ceilingDiv(records.size(), recordsPerBatch)
+        int expectedBatches = (records.size() + recordsPerBatch - 1) / recordsPerBatch;
+        assertEquals(expectedBatches, batches.size());
         assertTrue(batches.stream().allMatch(batch -> batch.data.sizeInBytes() <= maxBatchSize));
     }
 
@@ -306,4 +356,35 @@ class BatchAccumulatorTest {
         });
     }
 
+    int recordSizeInBytes(String record, int numberOfRecords) {
+        int serdeSize = serde.recordSize(record, new ObjectSerializationCache());
+        // numberOfRecords is used as the offset delta: an upper bound on every delta in the batch
+        int recordSizeInBytes = DefaultRecord.sizeOfBodyInBytes(
+            numberOfRecords,
+            0,
+            -1,
+            serdeSize,
+            DefaultRecord.EMPTY_HEADERS
+        );
+        // each record body is prefixed with its varint-encoded length, as in BatchBuilder
+        return ByteUtils.sizeOfVarint(recordSizeInBytes) + recordSizeInBytes;
+    }
+
+    static interface Appender {
+        Long call(BatchAccumulator<String> acc, int epoch, List<String> records);
+    }
+
+    static final Appender APPEND_ATOMIC = new Appender() {
+        @Override
+        public Long call(BatchAccumulator<String> acc, int epoch, List<String> records) {
+            return acc.appendAtomic(epoch, records);
+        }
+    };
+
+    static final Appender APPEND = new Appender() {
+        @Override
+        public Long call(BatchAccumulator<String> acc, int epoch, List<String> records) {
+            return acc.append(epoch, records);
+        }
+    };
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BatchBuilderTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/BatchBuilderTest.java
index f860df7..e4611f1 100644
--- a/raft/src/test/java/org/apache/kafka/raft/internals/BatchBuilderTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/BatchBuilderTest.java
@@ -31,7 +31,6 @@ import java.util.List;
 import java.util.stream.Collectors;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
@@ -69,8 +68,8 @@ class BatchBuilderTest {
 
         records.forEach(record -> builder.appendRecord(record, null));
         MemoryRecords builtRecordSet = builder.build();
-        assertFalse(builder.hasRoomFor(1));
-        assertThrows(IllegalArgumentException.class, () -> builder.appendRecord("a", null));
+        assertTrue(builder.bytesNeeded(Arrays.asList("a"), null).isPresent());
+        assertThrows(IllegalStateException.class, () -> builder.appendRecord("a", null));
 
         List<MutableRecordBatch> builtBatches = Utils.toList(builtRecordSet.batchIterator());
         assertEquals(1, builtBatches.size());
@@ -112,9 +111,8 @@ class BatchBuilderTest {
         );
 
         String record = "i am a record";
-        int recordSize = serde.recordSize(record);
 
-        while (builder.hasRoomFor(recordSize)) {
+        while (!builder.bytesNeeded(Arrays.asList(record), null).isPresent()) {
             builder.appendRecord(record, null);
         }
 
@@ -125,5 +123,4 @@ class BatchBuilderTest {
         assertTrue(sizeInBytes <= batchSize, "Built batch size "
             + sizeInBytes + " is larger than max batch size " + batchSize);
     }
-
 }