You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2018/06/19 20:26:37 UTC

[kafka] branch trunk updated: MINOR: Handle segment splitting edge cases and fix recovery bug (#5169)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 3afe2ed  MINOR: Handle segment splitting edge cases and fix recovery bug  (#5169)
3afe2ed is described below

commit 3afe2ed8e39da3605274baddc7245be569b4aed7
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Tue Jun 19 13:26:29 2018 -0700

    MINOR: Handle segment splitting edge cases and fix recovery bug  (#5169)
    
    This patch fixes the following issues in the log splitting logic added to address KAFKA-6264:
    
    1. We were not handling the case when all messages in the segment overflowed the index. In this case, there is only one resulting segment following the split.
    2. There was an off-by-one error in the recovery logic for completing an interrupted swap operation; this error caused an unintended segment deletion.
    
    Additionally, this patch factors out of `splitOverflowedSegment` a method, `LogSegment.appendFromFile`, that appends to a segment from an instance of `FileRecords`. This allows for future reuse and isolated testing.
    
    Reviewers: Dhruvil Shah <dh...@confluent.io>, Ismael Juma <is...@juma.me.uk>, Jun Rao <ju...@gmail.com>
---
 .../apache/kafka/common/record/BufferSupplier.java |  32 ++
 .../apache/kafka/common/record/FileRecords.java    |   4 +-
 .../kafka/common/record/BufferSupplierTest.java    |  46 +++
 .../common/LogSegmentOffsetOverflowException.scala |   5 +-
 .../coordinator/group/GroupMetadataManager.scala   |   4 +-
 .../transaction/TransactionStateManager.scala      |   4 +-
 core/src/main/scala/kafka/log/Log.scala            | 122 ++------
 core/src/main/scala/kafka/log/LogCleaner.scala     |   5 +-
 core/src/main/scala/kafka/log/LogSegment.scala     | 112 +++++--
 .../group/GroupMetadataManagerTest.scala           |  13 +-
 .../TransactionCoordinatorConcurrencyTest.scala    |  32 +-
 .../transaction/TransactionStateManagerTest.scala  |  11 +-
 .../test/scala/unit/kafka/log/LogCleanerTest.scala |  11 +-
 .../test/scala/unit/kafka/log/LogSegmentTest.scala |  30 ++
 core/src/test/scala/unit/kafka/log/LogTest.scala   | 327 +++++++++++----------
 15 files changed, 448 insertions(+), 310 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/record/BufferSupplier.java b/clients/src/main/java/org/apache/kafka/common/record/BufferSupplier.java
index 2e09f7d..1a6c92c 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/BufferSupplier.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/BufferSupplier.java
@@ -93,4 +93,36 @@ public abstract class BufferSupplier implements AutoCloseable {
             bufferMap.clear();
         }
     }
+
+    /**
+     * Simple buffer supplier for single-threaded usage. It caches a single buffer, which grows
+     * monotonically as needed to fulfill the allocation request.
+     */
+    public static class GrowableBufferSupplier extends BufferSupplier {
+        private ByteBuffer cachedBuffer;
+
+        @Override
+        public ByteBuffer get(int minCapacity) {
+            if (cachedBuffer != null && cachedBuffer.capacity() >= minCapacity) {
+                ByteBuffer res = cachedBuffer;
+                cachedBuffer = null;
+                return res;
+            } else {
+                cachedBuffer = null;
+                return ByteBuffer.allocate(minCapacity);
+            }
+        }
+
+        @Override
+        public void release(ByteBuffer buffer) {
+            buffer.clear();
+            cachedBuffer = buffer;
+        }
+
+        @Override
+        public void close() {
+            cachedBuffer = null;
+        }
+    }
+
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java b/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java
index 20b5105..df38ac7 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java
@@ -109,14 +109,12 @@ public class FileRecords extends AbstractRecords implements Closeable {
      *
      * @param buffer The buffer to write the batches to
      * @param position Position in the buffer to read from
-     * @return The same buffer
      * @throws IOException If an I/O error occurs, see {@link FileChannel#read(ByteBuffer, long)} for details on the
      * possible exceptions
      */
-    public ByteBuffer readInto(ByteBuffer buffer, int position) throws IOException {
+    public void readInto(ByteBuffer buffer, int position) throws IOException {
         Utils.readFully(channel, buffer, position + this.start);
         buffer.flip();
-        return buffer;
     }
 
     /**
diff --git a/clients/src/test/java/org/apache/kafka/common/record/BufferSupplierTest.java b/clients/src/test/java/org/apache/kafka/common/record/BufferSupplierTest.java
new file mode 100644
index 0000000..dea0c98
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/record/BufferSupplierTest.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.record;
+
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertSame;
+
+public class BufferSupplierTest {
+
+    @Test
+    public void testGrowableBuffer() {
+        BufferSupplier.GrowableBufferSupplier supplier = new BufferSupplier.GrowableBufferSupplier();
+        ByteBuffer buffer = supplier.get(1024);
+        assertEquals(0, buffer.position());
+        assertEquals(1024, buffer.capacity());
+        supplier.release(buffer);
+
+        ByteBuffer cached = supplier.get(512);
+        assertEquals(0, cached.position());
+        assertSame(buffer, cached);
+
+        ByteBuffer increased = supplier.get(2048);
+        assertEquals(2048, increased.capacity());
+        assertEquals(0, increased.position());
+    }
+
+}
diff --git a/core/src/main/scala/kafka/common/LogSegmentOffsetOverflowException.scala b/core/src/main/scala/kafka/common/LogSegmentOffsetOverflowException.scala
index 62379de..9a24efe 100644
--- a/core/src/main/scala/kafka/common/LogSegmentOffsetOverflowException.scala
+++ b/core/src/main/scala/kafka/common/LogSegmentOffsetOverflowException.scala
@@ -25,7 +25,6 @@ import kafka.log.LogSegment
  * KAFKA-5413. With KAFKA-6264, we have the ability to split such log segments into multiple log segments such that we
  * do not have any segments with offset overflow.
  */
-class LogSegmentOffsetOverflowException(message: String, cause: Throwable, val logSegment: LogSegment) extends KafkaException(message, cause) {
-  def this(cause: Throwable, logSegment: LogSegment) = this(null, cause, logSegment)
-  def this(message: String, logSegment: LogSegment) = this(message, null, logSegment)
+class LogSegmentOffsetOverflowException(val segment: LogSegment, val offset: Long)
+  extends KafkaException(s"Detected offset overflow at offset $offset in segment $segment") {
 }
diff --git a/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala
index 35a0574..233a76e 100644
--- a/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala
+++ b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala
@@ -532,8 +532,8 @@ class GroupMetadataManager(brokerId: Int,
             case records: MemoryRecords => records
             case fileRecords: FileRecords =>
               buffer.clear()
-              val bufferRead = fileRecords.readInto(buffer, 0)
-              MemoryRecords.readableRecords(bufferRead)
+              fileRecords.readInto(buffer, 0)
+              MemoryRecords.readableRecords(buffer)
           }
 
           memRecords.batches.asScala.foreach { batch =>
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
index 5b82be4..e3b0321 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala
@@ -313,8 +313,8 @@ class TransactionStateManager(brokerId: Int,
               case records: MemoryRecords => records
               case fileRecords: FileRecords =>
                 buffer.clear()
-                val bufferRead = fileRecords.readInto(buffer, 0)
-                MemoryRecords.readableRecords(bufferRead)
+                fileRecords.readInto(buffer, 0)
+                MemoryRecords.readableRecords(buffer)
             }
 
             memRecords.batches.asScala.foreach { batch =>
diff --git a/core/src/main/scala/kafka/log/Log.scala b/core/src/main/scala/kafka/log/Log.scala
index c92beee..64048fb 100644
--- a/core/src/main/scala/kafka/log/Log.scala
+++ b/core/src/main/scala/kafka/log/Log.scala
@@ -19,7 +19,6 @@ package kafka.log
 
 import java.io.{File, IOException}
 import java.lang.{Long => JLong}
-import java.nio.ByteBuffer
 import java.nio.file.{Files, NoSuchFileException}
 import java.text.NumberFormat
 import java.util.Map.{Entry => JEntry}
@@ -462,21 +461,17 @@ class Log(@volatile var dir: File,
       info(s"Found log file ${swapFile.getPath} from interrupted swap operation, repairing.")
       recoverSegment(swapSegment)
 
-      var oldSegments = logSegments(swapSegment.baseOffset, swapSegment.readNextOffset)
-
-      // We create swap files for two cases: (1) Log cleaning where multiple segments are merged into one, and
+      // We create swap files for two cases:
+      // (1) Log cleaning where multiple segments are merged into one, and
       // (2) Log splitting where one segment is split into multiple.
+      //
       // Both of these mean that the resultant swap segments be composed of the original set, i.e. the swap segment
       // must fall within the range of existing segment(s). If we cannot find such a segment, it means the deletion
       // of that segment was successful. In such an event, we should simply rename the .swap to .log without having to
       // do a replace with an existing segment.
-      if (oldSegments.nonEmpty) {
-        val start = oldSegments.head.baseOffset
-        val end = oldSegments.last.readNextOffset
-        if (!(swapSegment.baseOffset >= start && swapSegment.baseOffset <= end))
-          oldSegments = List()
+      val oldSegments = logSegments(swapSegment.baseOffset, swapSegment.readNextOffset).filter { segment =>
+        segment.readNextOffset > swapSegment.baseOffset
       }
-
       replaceSegments(Seq(swapSegment), oldSegments.toSeq, isRecoveredSwapFile = true)
     }
   }
@@ -494,7 +489,7 @@ class Log(@volatile var dir: File,
     val swapFiles = removeTempFilesAndCollectSwapFiles()
 
     // Now do a second pass and load all the log and index files.
-    // We might encounter legacy log segments with offset overflow (KAFKA-6264). We need to split such segments. Whe
+    // We might encounter legacy log segments with offset overflow (KAFKA-6264). We need to split such segments. When
     // this happens, restart loading segment files from scratch.
     retryOnOffsetOverflow {
       // In case we encounter a segment with offset overflow, the retry logic will split it after which we need to retry
@@ -1838,26 +1833,21 @@ class Log(@volatile var dir: File,
     }
   }
 
-  /**
-   * @throws LogSegmentOffsetOverflowException if we encounter segments with index overflow for more than maxTries
-   */
   private[log] def retryOnOffsetOverflow[T](fn: => T): T = {
-    var triesSoFar = 0
     while (true) {
       try {
         return fn
       } catch {
         case e: LogSegmentOffsetOverflowException =>
-          triesSoFar += 1
-          info(s"Caught LogOffsetOverflowException ${e.getMessage}. Split segment and retry. retry#: $triesSoFar.")
-          splitOverflowedSegment(e.logSegment)
+          info(s"Caught segment overflow error: ${e.getMessage}. Split segment and retry.")
+          splitOverflowedSegment(e.segment)
       }
     }
     throw new IllegalStateException()
   }
 
   /**
-   * Split the given log segment into multiple such that there is no offset overflow in the resulting segments. The
+   * Split a segment into one or more segments such that there is no offset overflow in any of them. The
    * resulting segments will contain the exact same messages that are present in the input segment. On successful
    * completion of this method, the input segment will be deleted and will be replaced by the resulting new segments.
    * See replaceSegments for recovery logic, in case the broker dies in the middle of this operation.
@@ -1871,94 +1861,44 @@ class Log(@volatile var dir: File,
    */
   private[log] def splitOverflowedSegment(segment: LogSegment): List[LogSegment] = {
     require(isLogFile(segment.log.file), s"Cannot split file ${segment.log.file.getAbsoluteFile}")
-    info(s"Attempting to split segment ${segment.log.file.getAbsolutePath}")
-
-    val newSegments = ListBuffer[LogSegment]()
-    var position = 0
-    val sourceRecords = segment.log
-    var readBuffer = ByteBuffer.allocate(1024 * 1024)
-
-    class CopyResult(val bytesRead: Int, val overflowOffset: Option[Long])
-
-    // Helper method to copy `records` into `segment`. Makes sure records being appended do not result in offset overflow.
-    def copyRecordsToSegment(records: FileRecords, segment: LogSegment, readBuffer: ByteBuffer): CopyResult = {
-      var bytesRead = 0
-      var maxTimestamp = Long.MinValue
-      var offsetOfMaxTimestamp = Long.MinValue
-      var maxOffset = Long.MinValue
-
-      // find all batches that are valid to be appended to the current log segment
-      val (validBatches, overflowBatches) = records.batches.asScala.span(batch => segment.offsetIndex.canAppendOffset(batch.lastOffset))
-      val overflowOffset = overflowBatches.headOption.map { firstBatch =>
-        info(s"Found overflow at offset ${firstBatch.baseOffset} in segment $segment")
-        firstBatch.baseOffset
-      }
+    require(segment.hasOverflow, "Split operation is only permitted for segments with overflow")
 
-      // return early if no valid batches were found
-      if (validBatches.isEmpty) {
-        require(overflowOffset.isDefined, "No batches found during split")
-        return new CopyResult(0, overflowOffset)
-      }
-
-      // determine the maximum offset and timestamp in batches
-      for (batch <- validBatches) {
-        if (batch.maxTimestamp > maxTimestamp) {
-          maxTimestamp = batch.maxTimestamp
-          offsetOfMaxTimestamp = batch.lastOffset
-        }
-        maxOffset = batch.lastOffset
-        bytesRead += batch.sizeInBytes
-      }
-
-      // read all valid batches into memory
-      val validRecords = records.slice(0, bytesRead)
-      require(readBuffer.capacity >= validRecords.sizeInBytes)
-      readBuffer.clear()
-      readBuffer.limit(validRecords.sizeInBytes)
-      validRecords.readInto(readBuffer, 0)
-
-      // append valid batches into the segment
-      segment.append(maxOffset, maxTimestamp, offsetOfMaxTimestamp, MemoryRecords.readableRecords(readBuffer))
-      readBuffer.clear()
-      info(s"Appended messages till $maxOffset to segment $segment during split")
-
-      new CopyResult(bytesRead, overflowOffset)
-    }
+    info(s"Splitting overflowed segment $segment")
 
+    val newSegments = ListBuffer[LogSegment]()
     try {
-      info(s"Splitting segment $segment")
-      newSegments += LogCleaner.createNewCleanedSegment(this, segment.baseOffset)
-      while (position < sourceRecords.sizeInBytes) {
-        val currentSegment = newSegments.last
+      var position = 0
+      val sourceRecords = segment.log
 
-        // grow buffers if needed
+      while (position < sourceRecords.sizeInBytes) {
         val firstBatch = sourceRecords.batchesFrom(position).asScala.head
-        if (firstBatch.sizeInBytes > readBuffer.capacity)
-          readBuffer = ByteBuffer.allocate(firstBatch.sizeInBytes)
+        val newSegment = LogCleaner.createNewCleanedSegment(this, firstBatch.baseOffset)
+        newSegments += newSegment
 
-        // get records we want to copy and copy them into the new segment
-        val recordsToCopy = sourceRecords.slice(position, readBuffer.capacity)
-        val copyResult = copyRecordsToSegment(recordsToCopy, currentSegment, readBuffer)
-        position += copyResult.bytesRead
+        val bytesAppended = newSegment.appendFromFile(sourceRecords, position)
+        if (bytesAppended == 0)
+          throw new IllegalStateException(s"Failed to append records from position $position in $segment")
 
-        // create a new segment if there was an overflow
-        copyResult.overflowOffset.foreach(overflowOffset => newSegments += LogCleaner.createNewCleanedSegment(this, overflowOffset))
+        position += bytesAppended
       }
-      require(newSegments.length > 1, s"No offset overflow found for $segment")
 
       // prepare new segments
       var totalSizeOfNewSegments = 0
-      info(s"Split messages from $segment into ${newSegments.length} new segments")
       newSegments.foreach { splitSegment =>
         splitSegment.onBecomeInactiveSegment()
         splitSegment.flush()
         splitSegment.lastModified = segment.lastModified
         totalSizeOfNewSegments += splitSegment.log.sizeInBytes
-        info(s"New segment: $splitSegment")
       }
       // size of all the new segments combined must equal size of the original segment
-      require(totalSizeOfNewSegments == segment.log.sizeInBytes, "Inconsistent segment sizes after split" +
-        s" before: ${segment.log.sizeInBytes} after: $totalSizeOfNewSegments")
+      if (totalSizeOfNewSegments != segment.log.sizeInBytes)
+        throw new IllegalStateException("Inconsistent segment sizes after split" +
+          s" before: ${segment.log.sizeInBytes} after: $totalSizeOfNewSegments")
+
+      // replace old segment with new ones
+      info(s"Replacing overflowed segment $segment with split segments $newSegments")
+      replaceSegments(newSegments.toList, List(segment), isRecoveredSwapFile = false)
+      newSegments.toList
     } catch {
       case e: Exception =>
         newSegments.foreach { splitSegment =>
@@ -1967,10 +1907,6 @@ class Log(@volatile var dir: File,
         }
         throw e
     }
-
-    // replace old segment with new ones
-    replaceSegments(newSegments.toList, List(segment), isRecoveredSwapFile = false)
-    newSegments.toList
   }
 }
 
diff --git a/core/src/main/scala/kafka/log/LogCleaner.scala b/core/src/main/scala/kafka/log/LogCleaner.scala
index d79a840..08bfa4f 100644
--- a/core/src/main/scala/kafka/log/LogCleaner.scala
+++ b/core/src/main/scala/kafka/log/LogCleaner.scala
@@ -513,7 +513,7 @@ private[log] class Cleaner(val id: Int,
           case e: LogSegmentOffsetOverflowException =>
             // Split the current segment. It's also safest to abort the current cleaning process, so that we retry from
             // scratch once the split is complete.
-            info(s"Caught LogSegmentOverflowException during log cleaning $e")
+            info(s"Caught segment overflow error during cleaning: ${e.getMessage}")
             log.splitOverflowedSegment(currentSegment)
             throw new LogCleaningAbortedException()
         }
@@ -529,8 +529,7 @@ private[log] class Cleaner(val id: Int,
       cleaned.lastModified = modified
 
       // swap in new segment
-      info(s"Swapping in cleaned segment ${cleaned.baseOffset} for segment(s) ${segments.map(_.baseOffset).mkString(",")} " +
-        s"in log ${log.name}")
+      info(s"Swapping in cleaned segment $cleaned for segment(s) $segments in log $log")
       log.replaceSegments(List(cleaned), segments)
     } catch {
       case e: LogCleaningAbortedException =>
diff --git a/core/src/main/scala/kafka/log/LogSegment.scala b/core/src/main/scala/kafka/log/LogSegment.scala
index f066106..0b71670 100755
--- a/core/src/main/scala/kafka/log/LogSegment.scala
+++ b/core/src/main/scala/kafka/log/LogSegment.scala
@@ -21,7 +21,7 @@ import java.nio.file.{Files, NoSuchFileException}
 import java.nio.file.attribute.FileTime
 import java.util.concurrent.TimeUnit
 
-import kafka.common.{IndexOffsetOverflowException, LogSegmentOffsetOverflowException}
+import kafka.common.LogSegmentOffsetOverflowException
 import kafka.metrics.{KafkaMetricsGroup, KafkaTimer}
 import kafka.server.epoch.LeaderEpochCache
 import kafka.server.{FetchDataInfo, LogOffsetMetadata}
@@ -132,14 +132,11 @@ class LogSegment private[log] (val log: FileRecords,
       if (physicalPosition == 0)
         rollingBasedTimestamp = Some(largestTimestamp)
 
-      if (!canConvertToRelativeOffset(largestOffset))
-        throw new LogSegmentOffsetOverflowException(
-          s"largest offset $largestOffset cannot be safely converted to relative offset for segment with baseOffset $baseOffset",
-          this)
+      ensureOffsetInRange(largestOffset)
 
       // append the messages
       val appendedBytes = log.append(records)
-      trace(s"Appended $appendedBytes to ${log.file()} at end offset $largestOffset")
+      trace(s"Appended $appendedBytes to ${log.file} at end offset $largestOffset")
       // Update the in memory max timestamp and corresponding offset.
       if (largestTimestamp > maxTimestampSoFar) {
         maxTimestampSoFar = largestTimestamp
@@ -147,14 +144,76 @@ class LogSegment private[log] (val log: FileRecords,
       }
       // append an entry to the index (if needed)
       if (bytesSinceLastIndexEntry > indexIntervalBytes) {
-        appendToOffsetIndex(largestOffset, physicalPosition)
-        maybeAppendToTimeIndex(maxTimestampSoFar, offsetOfMaxTimestamp)
+        offsetIndex.append(largestOffset, physicalPosition)
+        timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestamp)
         bytesSinceLastIndexEntry = 0
       }
       bytesSinceLastIndexEntry += records.sizeInBytes
     }
   }
 
+  private def ensureOffsetInRange(offset: Long): Unit = {
+    if (!canConvertToRelativeOffset(offset))
+      throw new LogSegmentOffsetOverflowException(this, offset)
+  }
+
+  private def appendChunkFromFile(records: FileRecords, position: Int, bufferSupplier: BufferSupplier): Int = {
+    var bytesToAppend = 0
+    var maxTimestamp = Long.MinValue
+    var offsetOfMaxTimestamp = Long.MinValue
+    var maxOffset = Long.MinValue
+    var readBuffer = bufferSupplier.get(1024 * 1024)
+
+    def canAppend(batch: RecordBatch) =
+      canConvertToRelativeOffset(batch.lastOffset) &&
+        (bytesToAppend == 0 || bytesToAppend + batch.sizeInBytes < readBuffer.capacity)
+
+    // find all batches that are valid to be appended to the current log segment and
+    // determine the maximum offset and timestamp
+    val nextBatches = records.batchesFrom(position).asScala.iterator
+    for (batch <- nextBatches.takeWhile(canAppend)) {
+      if (batch.maxTimestamp > maxTimestamp) {
+        maxTimestamp = batch.maxTimestamp
+        offsetOfMaxTimestamp = batch.lastOffset
+      }
+      maxOffset = batch.lastOffset
+      bytesToAppend += batch.sizeInBytes
+    }
+
+    if (bytesToAppend > 0) {
+      // Grow buffer if needed to ensure we copy at least one batch
+      if (readBuffer.capacity < bytesToAppend)
+        readBuffer = bufferSupplier.get(bytesToAppend)
+
+      readBuffer.limit(bytesToAppend)
+      records.readInto(readBuffer, position)
+
+      append(maxOffset, maxTimestamp, offsetOfMaxTimestamp, MemoryRecords.readableRecords(readBuffer))
+    }
+
+    bufferSupplier.release(readBuffer)
+    bytesToAppend
+  }
+
+  /**
+   * Append records from a file beginning at the given position until either the end of the file
+   * is reached or an offset is found which is too large to convert to a relative offset for the indexes.
+   *
+   * @return the number of bytes appended to the log (may be less than the size of the input if an
+   *         offset is encountered which would overflow this segment)
+   */
+  def appendFromFile(records: FileRecords, start: Int): Int = {
+    var position = start
+    val bufferSupplier: BufferSupplier = new BufferSupplier.GrowableBufferSupplier
+    while (position < start + records.sizeInBytes) {
+      val bytesAppended = appendChunkFromFile(records, position, bufferSupplier)
+      if (bytesAppended == 0)
+        return position - start
+      position += bytesAppended
+    }
+    position - start
+  }
+
   @nonthreadsafe
   def updateTxnIndex(completedTxn: CompletedTxn, lastStableOffset: Long) {
     if (completedTxn.isAborted) {
@@ -281,6 +340,7 @@ class LogSegment private[log] (val log: FileRecords,
     try {
       for (batch <- log.batches.asScala) {
         batch.ensureValid()
+        ensureOffsetInRange(batch.lastOffset)
 
         // The max timestamp is exposed at the batch level, so no need to iterate the records
         if (batch.maxTimestamp > maxTimestampSoFar) {
@@ -290,8 +350,8 @@ class LogSegment private[log] (val log: FileRecords,
 
         // Build offset index
         if (validBytes - lastIndexEntry > indexIntervalBytes) {
-          appendToOffsetIndex(batch.lastOffset, validBytes)
-          maybeAppendToTimeIndex(maxTimestampSoFar, offsetOfMaxTimestamp)
+          offsetIndex.append(batch.lastOffset, validBytes)
+          timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestamp)
           lastIndexEntry = validBytes
         }
         validBytes += batch.sizeInBytes()
@@ -316,7 +376,7 @@ class LogSegment private[log] (val log: FileRecords,
     log.truncateTo(validBytes)
     offsetIndex.trimToValidSize()
     // A normally closed segment always appends the biggest timestamp ever seen into log segment, we do this as well.
-    maybeAppendToTimeIndex(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true)
+    timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true)
     timeIndex.trimToValidSize()
     truncated
   }
@@ -336,6 +396,14 @@ class LogSegment private[log] (val log: FileRecords,
     }
   }
 
+  /**
+   * Check whether the last offset of the last batch in this segment overflows the indexes.
+   */
+  def hasOverflow: Boolean = {
+    val nextOffset = readNextOffset
+    nextOffset > baseOffset && !canConvertToRelativeOffset(nextOffset - 1)
+  }
+
   def collectAbortedTxns(fetchOffset: Long, upperBoundOffset: Long): TxnIndexSearchResult =
     txnIndex.collectAbortedTxns(fetchOffset, upperBoundOffset)
 
@@ -429,7 +497,7 @@ class LogSegment private[log] (val log: FileRecords,
    * The time index entry appended will be used to decide when to delete the segment.
    */
   def onBecomeInactiveSegment() {
-    maybeAppendToTimeIndex(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true)
+    timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true)
     offsetIndex.trimToValidSize()
     timeIndex.trimToValidSize()
     log.trim()
@@ -493,7 +561,7 @@ class LogSegment private[log] (val log: FileRecords,
    * Close this log segment
    */
   def close() {
-    CoreUtils.swallow(maybeAppendToTimeIndex(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true), this)
+    CoreUtils.swallow(timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true), this)
     CoreUtils.swallow(offsetIndex.close(), this)
     CoreUtils.swallow(timeIndex.close(), this)
     CoreUtils.swallow(log.close(), this)
@@ -554,24 +622,6 @@ class LogSegment private[log] (val log: FileRecords,
     Files.setLastModifiedTime(timeIndex.file.toPath, fileTime)
   }
 
-  private def maybeAppendToTimeIndex(timestamp: Long, offset: Long, skipFullCheck: Boolean = false): Unit = {
-    maybeHandleOffsetOverflowException {
-      timeIndex.maybeAppend(timestamp, offset, skipFullCheck)
-    }
-  }
-
-  private def appendToOffsetIndex(offset: Long, position: Int): Unit = {
-    maybeHandleOffsetOverflowException {
-      offsetIndex.append(offset, position)
-    }
-  }
-
-  private def maybeHandleOffsetOverflowException[T](fun: => T): T = {
-    try fun
-    catch {
-      case e: IndexOffsetOverflowException => throw new LogSegmentOffsetOverflowException(e, this)
-    }
-  }
 }
 
 object LogSegment {
diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala
index b358c4e..3bfacab 100644
--- a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala
@@ -1494,8 +1494,17 @@ class GroupMetadataManagerTest {
     EasyMock.expect(logMock.read(EasyMock.eq(startOffset), EasyMock.anyInt(), EasyMock.eq(None),
       EasyMock.eq(true), EasyMock.eq(IsolationLevel.READ_UNCOMMITTED)))
       .andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), fileRecordsMock))
-    EasyMock.expect(fileRecordsMock.readInto(EasyMock.anyObject(classOf[ByteBuffer]), EasyMock.anyInt()))
-      .andReturn(records.buffer)
+
+    val bufferCapture = EasyMock.newCapture[ByteBuffer]
+    fileRecordsMock.readInto(EasyMock.capture(bufferCapture), EasyMock.anyInt())
+    EasyMock.expectLastCall().andAnswer(new IAnswer[Unit] {
+      override def answer: Unit = {
+        val buffer = bufferCapture.getValue
+        buffer.put(records.buffer.duplicate)
+        buffer.flip()
+      }
+    })
+
     EasyMock.replay(fileRecordsMock)
 
     endOffset
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index 6168077..873b88d 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -22,21 +22,19 @@ import kafka.coordinator.AbstractCoordinatorConcurrencyTest
 import kafka.coordinator.AbstractCoordinatorConcurrencyTest._
 import kafka.coordinator.transaction.TransactionCoordinatorConcurrencyTest._
 import kafka.log.Log
-import kafka.server.{ DelayedOperationPurgatory, FetchDataInfo, KafkaConfig, LogOffsetMetadata, MetadataCache }
+import kafka.server.{DelayedOperationPurgatory, FetchDataInfo, KafkaConfig, LogOffsetMetadata, MetadataCache}
 import kafka.utils.timer.MockTimer
-import kafka.utils.{ Pool, TestUtils}
-
-import org.apache.kafka.clients.{ ClientResponse, NetworkClient }
-import org.apache.kafka.common.{ Node, TopicPartition }
+import kafka.utils.{Pool, TestUtils}
+import org.apache.kafka.clients.{ClientResponse, NetworkClient}
+import org.apache.kafka.common.{Node, TopicPartition}
 import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME
-import org.apache.kafka.common.protocol.{ ApiKeys, Errors }
-import org.apache.kafka.common.record.{ CompressionType, FileRecords, MemoryRecords, SimpleRecord }
+import org.apache.kafka.common.protocol.{ApiKeys, Errors}
+import org.apache.kafka.common.record.{CompressionType, FileRecords, MemoryRecords, SimpleRecord}
 import org.apache.kafka.common.requests._
-import org.apache.kafka.common.utils.{ LogContext, MockTime }
-
-import org.easymock.EasyMock
+import org.apache.kafka.common.utils.{LogContext, MockTime}
+import org.easymock.{EasyMock, IAnswer}
 import org.junit.Assert._
-import org.junit.{ After, Before, Test }
+import org.junit.{After, Before, Test}
 
 import scala.collection.Map
 import scala.collection.mutable
@@ -260,8 +258,16 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
     EasyMock.expect(logMock.read(EasyMock.eq(startOffset), EasyMock.anyInt(), EasyMock.eq(None),
       EasyMock.eq(true), EasyMock.eq(IsolationLevel.READ_UNCOMMITTED)))
       .andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), fileRecordsMock))
-    EasyMock.expect(fileRecordsMock.readInto(EasyMock.anyObject(classOf[ByteBuffer]), EasyMock.anyInt()))
-      .andReturn(records.buffer)
+
+    val bufferCapture = EasyMock.newCapture[ByteBuffer]
+    fileRecordsMock.readInto(EasyMock.capture(bufferCapture), EasyMock.anyInt())
+    EasyMock.expectLastCall().andAnswer(new IAnswer[Unit] {
+      override def answer: Unit = {
+        val buffer = bufferCapture.getValue
+        buffer.put(records.buffer.duplicate)
+        buffer.flip()
+      }
+    })
 
     EasyMock.replay(logMock, fileRecordsMock)
     synchronized {
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index 20dfaa6..34b82d9 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -585,9 +585,16 @@ class TransactionStateManagerTest {
     EasyMock.expect(logMock.read(EasyMock.eq(startOffset), EasyMock.anyInt(), EasyMock.eq(None),
       EasyMock.eq(true), EasyMock.eq(IsolationLevel.READ_UNCOMMITTED)))
       .andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), fileRecordsMock))
-    EasyMock.expect(fileRecordsMock.readInto(EasyMock.anyObject(classOf[ByteBuffer]), EasyMock.anyInt()))
-      .andReturn(records.buffer)
 
+    val bufferCapture = EasyMock.newCapture[ByteBuffer]
+    fileRecordsMock.readInto(EasyMock.capture(bufferCapture), EasyMock.anyInt())
+    EasyMock.expectLastCall().andAnswer(new IAnswer[Unit] {
+      override def answer: Unit = {
+        val buffer = bufferCapture.getValue
+        buffer.put(records.buffer.duplicate)
+        buffer.flip()
+      }
+    })
     EasyMock.replay(logMock, fileRecordsMock, replicaManager)
   }
 
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
index 3207e15..b351311 100755
--- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
@@ -1065,8 +1065,13 @@ class LogCleanerTest extends JUnitSuite {
     logProps.put(LogConfig.FileDeleteDelayMsProp, 1000: java.lang.Integer)
     val config = LogConfig.fromProps(logConfig.originals, logProps)
 
-    val time = new MockTime()
-    val (log, segmentWithOverflow, _) = LogTest.createLogWithOffsetOverflow(dir, new BrokerTopicStats(), Some(config), time.scheduler, time)
+    LogTest.initializeLogDirWithOverflowedSegment(dir)
+
+    val log = makeLog(config = config, recoveryPoint = Long.MaxValue)
+    val segmentWithOverflow = LogTest.firstOverflowSegment(log).getOrElse {
+      fail("Failed to create log with a segment which has overflowed offsets")
+    }
+
     val numSegmentsInitial = log.logSegments.size
     val allKeys = LogTest.keysInLog(log).toList
     val expectedKeysAfterCleaning = mutable.MutableList[Long]()
@@ -1445,7 +1450,7 @@ class LogCleanerTest extends JUnitSuite {
   private def tombstoneRecord(key: Int): MemoryRecords = record(key, null)
 
   private def recoverAndCheck(config: LogConfig, expectedKeys: Iterable[Long]): Log = {
-    LogTest.recoverAndCheck(dir, config, expectedKeys, new BrokerTopicStats())
+    LogTest.recoverAndCheck(dir, config, expectedKeys, new BrokerTopicStats(), time, time.scheduler)
   }
 }
 
diff --git a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
index 79d6ea2..8976c68 100644
--- a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala
@@ -517,4 +517,34 @@ class LogSegmentTest {
     assertEquals(1, log.records.batches.asScala.size)
   }
 
+  @Test
+  def testAppendFromFile(): Unit = {
+    def records(offset: Long, size: Int): MemoryRecords =
+      MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, offset, CompressionType.NONE, TimestampType.CREATE_TIME,
+        new SimpleRecord(new Array[Byte](size)))
+
+    // create a log file in a separate directory to avoid conflicting with created segments
+    val tempDir = TestUtils.tempDir()
+    val fileRecords = FileRecords.open(Log.logFile(tempDir, 0))
+
+    // Simulate a scenario where we have a single log with an offset range exceeding Int.MaxValue
+    fileRecords.append(records(0, 1024))
+    fileRecords.append(records(500, 1024 * 1024 + 1))
+    val sizeBeforeOverflow = fileRecords.sizeInBytes()
+    fileRecords.append(records(Int.MaxValue + 5L, 1024))
+    val sizeAfterOverflow = fileRecords.sizeInBytes()
+
+    val segment = createSegment(0)
+    val bytesAppended = segment.appendFromFile(fileRecords, 0)
+    assertEquals(sizeBeforeOverflow, bytesAppended)
+    assertEquals(sizeBeforeOverflow, segment.size)
+
+    val overflowSegment = createSegment(Int.MaxValue)
+    val overflowBytesAppended = overflowSegment.appendFromFile(fileRecords, sizeBeforeOverflow)
+    assertEquals(sizeAfterOverflow - sizeBeforeOverflow, overflowBytesAppended)
+    assertEquals(overflowBytesAppended, overflowSegment.size)
+    fileRecords.close()
+    Utils.delete(tempDir)
+  }
+
 }
diff --git a/core/src/test/scala/unit/kafka/log/LogTest.scala b/core/src/test/scala/unit/kafka/log/LogTest.scala
index 550b929..f3b4e95 100755
--- a/core/src/test/scala/unit/kafka/log/LogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTest.scala
@@ -22,7 +22,6 @@ import java.nio.ByteBuffer
 import java.nio.file.{Files, Paths}
 import java.util.Properties
 
-import org.apache.kafka.common.errors._
 import kafka.common.{OffsetsOutOfOrderException, UnexpectedAppendOffsetException, KafkaException}
 import kafka.log.Log.DeleteDirSuffix
 import kafka.server.epoch.{EpochEntry, LeaderEpochCache, LeaderEpochFileCache}
@@ -39,6 +38,7 @@ import org.apache.kafka.common.utils.{Time, Utils}
 import org.easymock.EasyMock
 import org.junit.Assert._
 import org.junit.{After, Before, Test}
+import org.scalatest.Assertions
 
 import scala.collection.Iterable
 import scala.collection.JavaConverters._
@@ -2118,33 +2118,90 @@ class LogTest {
   def testSplitOnOffsetOverflow(): Unit = {
     // create a log such that one log segment has offsets that overflow, and call the split API on that segment
     val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
-    val (log, segmentWithOverflow, inputRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig)
     assertTrue("At least one segment must have offset overflow", LogTest.hasOffsetOverflow(log))
 
+    val allRecordsBeforeSplit = LogTest.allRecords(log)
+
     // split the segment with overflow
     log.splitOverflowedSegment(segmentWithOverflow)
 
     // assert we were successfully able to split the segment
-    assertEquals(log.numberOfSegments, 4)
-    assertTrue(LogTest.verifyRecordsInLog(log, inputRecords))
+    assertEquals(4, log.numberOfSegments)
+    LogTest.verifyRecordsInLog(log, allRecordsBeforeSplit)
 
     // verify we do not have offset overflow anymore
     assertFalse(LogTest.hasOffsetOverflow(log))
   }
 
   @Test
+  def testDegenerateSegmentSplit(): Unit = {
+    // This tests a scenario where all of the batches appended to a segment have overflowed.
+    // When we split the overflowed segment, only one new segment will be created.
+
+    val overflowOffset = Int.MaxValue + 1L
+    val batch1 = MemoryRecords.withRecords(overflowOffset, CompressionType.NONE, 0,
+      new SimpleRecord("a".getBytes))
+    val batch2 = MemoryRecords.withRecords(overflowOffset + 1, CompressionType.NONE, 0,
+      new SimpleRecord("b".getBytes))
+
+    testDegenerateSplitSegmentWithOverflow(segmentBaseOffset = 0L, List(batch1, batch2))
+  }
+
+  @Test
+  def testDegenerateSegmentSplitWithOutOfRangeBatchLastOffset(): Unit = {
+    // Degenerate case where the only batch in the segment overflows. In this scenario,
+    // the first offset of the batch is valid, but the last overflows.
+
+    val firstBatchBaseOffset = Int.MaxValue - 1
+    val records = MemoryRecords.withRecords(firstBatchBaseOffset, CompressionType.NONE, 0,
+      new SimpleRecord("a".getBytes),
+      new SimpleRecord("b".getBytes),
+      new SimpleRecord("c".getBytes))
+
+    testDegenerateSplitSegmentWithOverflow(segmentBaseOffset = 0L, List(records))
+  }
+
+  private def testDegenerateSplitSegmentWithOverflow(segmentBaseOffset: Long, records: List[MemoryRecords]): Unit = {
+    val segment = LogTest.rawSegment(logDir, segmentBaseOffset)
+    records.foreach(segment.append _)
+    segment.close()
+
+    // Create clean shutdown file so that we do not split during the load
+    createCleanShutdownFile()
+
+    val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
+    val log = createLog(logDir, logConfig, recoveryPoint = Long.MaxValue)
+
+    val segmentWithOverflow = LogTest.firstOverflowSegment(log).getOrElse {
+      Assertions.fail("Failed to create log with a segment which has overflowed offsets")
+    }
+
+    val allRecordsBeforeSplit = LogTest.allRecords(log)
+    log.splitOverflowedSegment(segmentWithOverflow)
+
+    assertEquals(1, log.numberOfSegments)
+
+    val firstBatchBaseOffset = records.head.batches.asScala.head.baseOffset
+    assertEquals(firstBatchBaseOffset, log.activeSegment.baseOffset)
+    LogTest.verifyRecordsInLog(log, allRecordsBeforeSplit)
+
+    assertFalse(LogTest.hasOffsetOverflow(log))
+  }
+
+  @Test
   def testRecoveryOfSegmentWithOffsetOverflow(): Unit = {
     val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
-    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val (log, _) = createLogWithOffsetOverflow(logConfig)
     val expectedKeys = LogTest.keysInLog(log)
 
     // Run recovery on the log. This should split the segment underneath. Ignore .deleted files as we could have still
     // have them lying around after the split.
-    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
-    assertEquals(expectedKeys, LogTest.keysInLog(log))
+    val recoveredLog = recoverAndCheck(logConfig, expectedKeys)
+    assertEquals(expectedKeys, LogTest.keysInLog(recoveredLog))
 
     // Running split again would throw an error
-    for (segment <- log.logSegments) {
+    for (segment <- recoveredLog.logSegments) {
       try {
-        log.splitOverflowedSegment(segment)
+        recoveredLog.splitOverflowedSegment(segment)
         fail()
@@ -2157,7 +2214,7 @@ class LogTest {
   @Test
   def testRecoveryAfterCrashDuringSplitPhase1(): Unit = {
     val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
-    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig)
     val expectedKeys = LogTest.keysInLog(log)
     val numSegmentsInitial = log.logSegments.size
 
@@ -2172,16 +2229,17 @@ class LogTest {
     })
     for (file <- logDir.listFiles if file.getName.endsWith(Log.DeletedFileSuffix))
       Utils.atomicMoveWithFallback(file.toPath, Paths.get(CoreUtils.replaceSuffix(file.getPath, Log.DeletedFileSuffix, "")))
-    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
-    assertEquals(expectedKeys, LogTest.keysInLog(log))
-    assertEquals(numSegmentsInitial + 1, log.logSegments.size)
-    log.close()
+
+    val recoveredLog = recoverAndCheck(logConfig, expectedKeys)
+    assertEquals(expectedKeys, LogTest.keysInLog(recoveredLog))
+    assertEquals(numSegmentsInitial + 1, recoveredLog.logSegments.size)
+    recoveredLog.close()
   }
 
   @Test
   def testRecoveryAfterCrashDuringSplitPhase2(): Unit = {
     val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
-    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig)
     val expectedKeys = LogTest.keysInLog(log)
     val numSegmentsInitial = log.logSegments.size
 
@@ -2190,25 +2248,26 @@ class LogTest {
 
     // Simulate recovery just after one of the new segments has been renamed to .swap. On recovery, existing split
     // operation is aborted but the recovery process itself kicks off split which should complete.
-    newSegments.reverse.foreach(segment => {
-      if (segment != newSegments.tail)
+    newSegments.reverse.foreach { segment =>
+      if (segment != newSegments.last)
         segment.changeFileSuffixes("", Log.CleanedFileSuffix)
       else
         segment.changeFileSuffixes("", Log.SwapFileSuffix)
       segment.truncateTo(0)
-    })
+    }
     for (file <- logDir.listFiles if file.getName.endsWith(Log.DeletedFileSuffix))
       Utils.atomicMoveWithFallback(file.toPath, Paths.get(CoreUtils.replaceSuffix(file.getPath, Log.DeletedFileSuffix, "")))
-    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
-    assertEquals(expectedKeys, LogTest.keysInLog(log))
-    assertEquals(numSegmentsInitial + 1, log.logSegments.size)
-    log.close()
+
+    val recoveredLog = recoverAndCheck(logConfig, expectedKeys)
+    assertEquals(expectedKeys, LogTest.keysInLog(recoveredLog))
+    assertEquals(numSegmentsInitial + 1, recoveredLog.logSegments.size)
+    recoveredLog.close()
   }
 
   @Test
   def testRecoveryAfterCrashDuringSplitPhase3(): Unit = {
     val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
-    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig)
     val expectedKeys = LogTest.keysInLog(log)
     val numSegmentsInitial = log.logSegments.size
 
@@ -2226,16 +2285,16 @@ class LogTest {
     // Truncate the old segment
     segmentWithOverflow.truncateTo(0)
 
-    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
-    assertEquals(expectedKeys, LogTest.keysInLog(log))
-    assertEquals(numSegmentsInitial + 1, log.logSegments.size)
-    log.close()
+    val recoveredLog = recoverAndCheck(logConfig, expectedKeys)
+    assertEquals(expectedKeys, LogTest.keysInLog(recoveredLog))
+    assertEquals(numSegmentsInitial + 1, recoveredLog.logSegments.size)
+    recoveredLog.close()
   }
 
   @Test
   def testRecoveryAfterCrashDuringSplitPhase4(): Unit = {
     val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
-    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig)
     val expectedKeys = LogTest.keysInLog(log)
     val numSegmentsInitial = log.logSegments.size
 
@@ -2244,25 +2303,24 @@ class LogTest {
 
     // Simulate recovery right after all new segments have been renamed to .swap and old segment has been deleted. On
     // recovery, existing split operation is completed.
-    newSegments.reverse.foreach(segment => {
-      segment.changeFileSuffixes("", Log.SwapFileSuffix)
-    })
+    newSegments.reverse.foreach(_.changeFileSuffixes("", Log.SwapFileSuffix))
+
     for (file <- logDir.listFiles if file.getName.endsWith(Log.DeletedFileSuffix))
       Utils.delete(file)
 
     // Truncate the old segment
     segmentWithOverflow.truncateTo(0)
 
-    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
-    assertEquals(expectedKeys, LogTest.keysInLog(log))
-    assertEquals(numSegmentsInitial + 1, log.logSegments.size)
-    log.close()
+    val recoveredLog = recoverAndCheck(logConfig, expectedKeys)
+    assertEquals(expectedKeys, LogTest.keysInLog(recoveredLog))
+    assertEquals(numSegmentsInitial + 1, recoveredLog.logSegments.size)
+    recoveredLog.close()
   }
 
   @Test
   def testRecoveryAfterCrashDuringSplitPhase5(): Unit = {
     val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
-    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig)
     val expectedKeys = LogTest.keysInLog(log)
     val numSegmentsInitial = log.logSegments.size
 
@@ -2276,10 +2334,10 @@ class LogTest {
     // Truncate the old segment
     segmentWithOverflow.truncateTo(0)
 
-    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
-    assertEquals(expectedKeys, LogTest.keysInLog(log))
-    assertEquals(numSegmentsInitial + 1, log.logSegments.size)
-    log.close()
+    val recoveredLog = recoverAndCheck(logConfig, expectedKeys)
+    assertEquals(expectedKeys, LogTest.keysInLog(recoveredLog))
+    assertEquals(numSegmentsInitial + 1, recoveredLog.logSegments.size)
+    recoveredLog.close()
   }
 
   @Test
@@ -3390,13 +3448,28 @@ class LogTest {
                         time: Time = mockTime,
                         maxProducerIdExpirationMs: Int = 60 * 60 * 1000,
                         producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs): Log = {
-    return LogTest.createLog(dir, config, brokerTopicStats, scheduler, time, logStartOffset, recoveryPoint,
+    LogTest.createLog(dir, config, brokerTopicStats, scheduler, time, logStartOffset, recoveryPoint,
       maxProducerIdExpirationMs, producerIdExpirationCheckIntervalMs)
   }
 
-  private def createLogWithOffsetOverflow(logConfig: Option[LogConfig]): (Log, LogSegment, List[Record]) = {
-    return LogTest.createLogWithOffsetOverflow(logDir, brokerTopicStats, logConfig, mockTime.scheduler, mockTime)
+  private def createLogWithOffsetOverflow(logConfig: LogConfig): (Log, LogSegment) = {
+    LogTest.initializeLogDirWithOverflowedSegment(logDir)
+
+    val log = createLog(logDir, logConfig, recoveryPoint = Long.MaxValue)
+    val segmentWithOverflow = LogTest.firstOverflowSegment(log).getOrElse {
+      Assertions.fail("Failed to create log with a segment which has overflowed offsets")
+    }
+
+    (log, segmentWithOverflow)
   }
+
+  private def recoverAndCheck(config: LogConfig,
+                              expectedKeys: Iterable[Long],
+                              expectDeletedFiles: Boolean = true): Log = {
+    LogTest.recoverAndCheck(logDir, config, expectedKeys, brokerTopicStats, mockTime, mockTime.scheduler,
+      expectDeletedFiles)
+  }
+
 }
 
 object LogTest {
@@ -3453,133 +3526,77 @@ object LogTest {
    * @param log Log to check
    * @return true if log contains at least one segment with offset overflow; false otherwise
    */
-  def hasOffsetOverflow(log: Log): Boolean = {
-    for (logSegment <- log.logSegments) {
-      val baseOffset = logSegment.baseOffset
-      for (batch <- logSegment.log.batches.asScala) {
-        val it = batch.iterator()
-        while (it.hasNext()) {
-          val record = it.next()
-          if (record.offset > baseOffset + Int.MaxValue || record.offset < baseOffset)
-            return true
-        }
-      }
-    }
-    false
-  }
-
-  /**
-   * Create a log such that one of the log segments has messages with offsets that cause index offset overflow.
-   * @param logDir Directory in which log should be created
-   * @param brokerTopicStats Container for Broker Topic Yammer Metrics
-   * @param logConfigOpt Optional log configuration to use
-   * @param scheduler The thread pool scheduler used for background actions
-   * @param time The time instance to use
-   * @return (1) Created log containing segment with offset overflow, (2) Log segment within log containing messages with
-   *         offset overflow, and (3) List of messages in the log
-   */
-  def createLogWithOffsetOverflow(logDir: File, brokerTopicStats: BrokerTopicStats, logConfigOpt: Option[LogConfig] = None,
-                                  scheduler: Scheduler, time: Time): (Log, LogSegment, List[Record]) = {
-    val logConfig =
-      if (logConfigOpt.isDefined)
-        logConfigOpt.get
-      else
-        createLogConfig(indexIntervalBytes = 1)
-
-    var log = createLog(logDir, logConfig, brokerTopicStats, scheduler, time)
-    val inputRecords = ListBuffer[Record]()
+  def hasOffsetOverflow(log: Log): Boolean = firstOverflowSegment(log).isDefined
 
-    // References to files we want to "merge" to emulate offset overflow
-    val toMerge = ListBuffer[File]()
+  def firstOverflowSegment(log: Log): Option[LogSegment] = {
+    def hasOverflow(baseOffset: Long, batch: RecordBatch): Boolean =
+      batch.lastOffset > baseOffset + Int.MaxValue || batch.baseOffset < baseOffset
 
-    def getRecords(baseOffset: Long): List[MemoryRecords] = {
-      def toBytes(value: Long): Array[Byte] = value.toString.getBytes
-
-      val set1 = MemoryRecords.withRecords(baseOffset, CompressionType.NONE, 0,
-        new SimpleRecord(toBytes(baseOffset), toBytes(baseOffset)))
-      val set2 = MemoryRecords.withRecords(baseOffset + 1, CompressionType.NONE, 0,
-        new SimpleRecord(toBytes(baseOffset + 1), toBytes(baseOffset + 1)),
-        new SimpleRecord(toBytes(baseOffset + 2), toBytes(baseOffset + 2)));
-      val set3 = MemoryRecords.withRecords(baseOffset + Int.MaxValue - 1, CompressionType.NONE, 0,
-        new SimpleRecord(toBytes(baseOffset + Int.MaxValue - 1), toBytes(baseOffset + Int.MaxValue - 1)));
-      List(set1, set2, set3)
+    log.logSegments.find { segment =>
+      // A segment has overflowed if any batch lies outside [baseOffset, baseOffset + Int.MaxValue]
+      segment.log.batches.asScala.exists { batch =>
+        hasOverflow(segment.baseOffset, batch)
+      }
     }
+  }
 
-    // Append some messages to the log. This will create four log segments.
-    var firstOffset = 0L
-    for (i <- 0 until 4) {
-      val recordsToAppend = getRecords(firstOffset)
-      for (records <- recordsToAppend)
-        log.appendAsFollower(records)
-
-      if (i == 1 || i == 2)
-        toMerge += log.activeSegment.log.file
-
-      firstOffset += Int.MaxValue + 1L
-    }
+  private def rawSegment(logDir: File, baseOffset: Long): FileRecords =
+    FileRecords.open(Log.logFile(logDir, baseOffset))
 
-    // assert that we have the correct number of segments
-    assertEquals(log.numberOfSegments, 4)
+  /**
+   * Initialize the given log directory with a set of segments, one of which will have an
+   * offset which overflows the segment
+   */
+  def initializeLogDirWithOverflowedSegment(logDir: File): Unit = {
+    def writeSampleBatches(baseOffset: Long, segment: FileRecords): Long = {
+      def record(offset: Long) = {
+        val data = offset.toString.getBytes
+        new SimpleRecord(data, data)
+      }
 
-    // assert number of batches
-    for (logSegment <- log.logSegments) {
-      var numBatches = 0
-      for (_ <- logSegment.log.batches.asScala)
-        numBatches += 1
-      assertEquals(numBatches, 3)
+      segment.append(MemoryRecords.withRecords(baseOffset, CompressionType.NONE, 0,
+        record(baseOffset)))
+      segment.append(MemoryRecords.withRecords(baseOffset + 1, CompressionType.NONE, 0,
+        record(baseOffset + 1),
+        record(baseOffset + 2)))
+      segment.append(MemoryRecords.withRecords(baseOffset + Int.MaxValue - 1, CompressionType.NONE, 0,
+        record(baseOffset + Int.MaxValue - 1)))
+      baseOffset + Int.MaxValue
     }
 
-    // create a list of appended records
-    for (logSegment <- log.logSegments) {
-      for (batch <- logSegment.log.batches.asScala) {
-        val it = batch.iterator()
-        while (it.hasNext())
-          inputRecords += it.next()
-      }
+    def writeNormalSegment(baseOffset: Long): Long = {
+      val segment = rawSegment(logDir, baseOffset)
+      try writeSampleBatches(baseOffset, segment)
+      finally segment.close()
     }
 
-    log.flush()
-    log.close()
-
-    // We want to "merge" log segments 1 and 2. This is where the offset overflow will be.
-    // Current: segment #1 | segment #2 | segment #3 | segment# 4
-    // Final: segment #1 | segment #2' | segment #4
-    // where 2' corresponds to segment #2 and segment #3 combined together.
-    // Append segment #3 at the end of segment #2 to create 2'
-    var dest: FileOutputStream = null
-    var source: FileInputStream = null
-    try {
-      dest = new FileOutputStream(toMerge(0), true)
-      source = new FileInputStream(toMerge(1))
-      val sourceBytes = new Array[Byte](toMerge(1).length.toInt)
-      source.read(sourceBytes)
-      dest.write(sourceBytes)
-    } finally {
-      dest.close()
-      source.close()
+    def writeOverflowSegment(baseOffset: Long): Long = {
+      val segment = rawSegment(logDir, baseOffset)
+      try {
+        val nextOffset = writeSampleBatches(baseOffset, segment)
+        writeSampleBatches(nextOffset, segment)
+      } finally segment.close()
     }
 
-    // Delete segment #3 including any index, etc.
-    toMerge(1).delete()
-    log = createLog(logDir, logConfig, brokerTopicStats, scheduler, time, recoveryPoint = Long.MaxValue)
-
-    // assert that there is now one less segment than before, and that the records in the log are same as before
-    assertEquals(log.numberOfSegments, 3)
-    assertTrue(verifyRecordsInLog(log, inputRecords.toList))
-
-    (log, log.logSegments.toList(1), inputRecords.toList)
+    // We create three segments, the second of which contains offsets which overflow
+    var nextOffset = 0L
+    nextOffset = writeNormalSegment(nextOffset)
+    nextOffset = writeOverflowSegment(nextOffset)
+    writeNormalSegment(nextOffset)
   }
 
-  def verifyRecordsInLog(log: Log, expectedRecords: List[Record]): Boolean = {
+  def allRecords(log: Log): List[Record] = {
     val recordsFound = ListBuffer[Record]()
     for (logSegment <- log.logSegments) {
       for (batch <- logSegment.log.batches.asScala) {
-        val it = batch.iterator()
-        while (it.hasNext())
-          recordsFound += it.next()
+        recordsFound ++= batch.iterator().asScala
       }
     }
-    return recordsFound.equals(expectedRecords)
+    recordsFound.toList
+  }
+
+  def verifyRecordsInLog(log: Log, expectedRecords: List[Record]): Unit = {
+    assertEquals(expectedRecords, allRecords(log))
   }
 
   /* extract all the keys from a log */
@@ -3590,12 +3607,16 @@ object LogTest {
       yield TestUtils.readString(record.key).toLong
   }
 
-  def recoverAndCheck(logDir: File, config: LogConfig, expectedKeys: Iterable[Long],
-                      brokerTopicStats: BrokerTopicStats, expectDeletedFiles: Boolean = false): Log = {
-    val time = new MockTime()
+  def recoverAndCheck(logDir: File,
+                      config: LogConfig,
+                      expectedKeys: Iterable[Long],
+                      brokerTopicStats: BrokerTopicStats,
+                      time: Time,
+                      scheduler: Scheduler,
+                      expectDeletedFiles: Boolean = false): Log = {
     // Recover log file and check that after recovery, keys are as expected
     // and all temporary files have been deleted
-    val recoveredLog = createLog(logDir, config, brokerTopicStats, time.scheduler, time)
+    val recoveredLog = createLog(logDir, config, brokerTopicStats, scheduler, time)
     time.sleep(config.fileDeleteDelayMs + 1)
     for (file <- logDir.listFiles) {
       if (!expectDeletedFiles)