Posted to commits@kafka.apache.org by jg...@apache.org on 2018/06/08 15:40:07 UTC

[kafka] branch trunk updated: KAFKA-6264; Split log segments as needed if offsets overflow the indexes (#4975)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new d2b2fbd  KAFKA-6264; Split log segments as needed if offsets overflow the indexes (#4975)
d2b2fbd is described below

commit d2b2fbdf94cf48094081650ebc18ee860e67d8d5
Author: Dhruvil Shah <dh...@confluent.io>
AuthorDate: Fri Jun 8 08:39:59 2018 -0700

    KAFKA-6264; Split log segments as needed if offsets overflow the indexes (#4975)
    
    This patch adds logic to detect and fix segments which have overflowed offsets as a result of bugs in older versions of Kafka.
    
    Reviewers: Jun Rao <ju...@gmail.com>, Jason Gustafson <ja...@confluent.io>
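
For readers unfamiliar with the failure mode: a segment's offset and time indexes store each offset as a 4-byte value relative to the segment's base offset, so segments written by older, buggy brokers can contain offsets that no longer fit. A minimal sketch of the condition the patch detects (illustrative only, not code from this patch):

    // An index can only address offsets within Int.MaxValue of the segment's base offset.
    def overflowsIndex(segmentBaseOffset: Long, messageOffset: Long): Boolean = {
      val relative = messageOffset - segmentBaseOffset
      relative < 0 || relative > Int.MaxValue
    }

    overflowsIndex(0L, 3000000000L)           // true:  this segment must be split
    overflowsIndex(2500000000L, 3000000000L)  // false: fits once the base offset is high enough

Splitting an overflowed segment produces new segments with larger base offsets, so every contained offset becomes representable again.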
---
 .../apache/kafka/common/record/FileRecords.java    |  11 +-
 .../kafka/common/record/FileRecordsTest.java       |  24 +-
 .../common/IndexOffsetOverflowException.scala      |  25 +
 .../common/LogSegmentOffsetOverflowException.scala |  31 +
 core/src/main/scala/kafka/log/AbstractIndex.scala  |  31 +-
 core/src/main/scala/kafka/log/Log.scala            | 388 ++++++++--
 core/src/main/scala/kafka/log/LogCleaner.scala     |  44 +-
 core/src/main/scala/kafka/log/LogSegment.scala     |  66 +-
 core/src/main/scala/kafka/log/OffsetIndex.scala    |   5 +-
 core/src/main/scala/kafka/log/TimeIndex.scala      |   2 +-
 .../main/scala/kafka/tools/DumpLogSegments.scala   |   4 +-
 .../test/scala/unit/kafka/log/LogCleanerTest.scala | 156 ++--
 core/src/test/scala/unit/kafka/log/LogTest.scala   | 846 +++++++++++++++------
 .../scala/unit/kafka/log/OffsetIndexTest.scala     |  11 +-
 14 files changed, 1184 insertions(+), 460 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java b/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java
index e44d5d9..20b5105 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java
@@ -131,7 +131,7 @@ public class FileRecords extends AbstractRecords implements Closeable {
      * @param size The number of bytes after the start position to include
      * @return A sliced wrapper on this message set limited based on the given position and size
      */
-    public FileRecords read(int position, int size) throws IOException {
+    public FileRecords slice(int position, int size) throws IOException {
         if (position < 0)
             throw new IllegalArgumentException("Invalid position: " + position + " in read from " + file);
         if (size < 0)
@@ -356,7 +356,14 @@ public class FileRecords extends AbstractRecords implements Closeable {
                 ")";
     }
 
-    private Iterable<FileChannelRecordBatch> batchesFrom(final int start) {
+    /**
+     * Get an iterator over the record batches in the file, starting at a specific position. This is similar to
+     * {@link #batches()} except that callers specify a particular position to start reading the batches from. This
+     * method must be used with caution: the start position passed in must be a known start of a batch.
+     * @param start The position to start record iteration from; must be a known position for start of a batch
+     * @return An iterator over batches starting from {@code start}
+     */
+    public Iterable<FileChannelRecordBatch> batchesFrom(final int start) {
         return new Iterable<FileChannelRecordBatch>() {
             @Override
             public Iterator<FileChannelRecordBatch> iterator() {
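
The rename of read(position, size) to slice(position, size), together with making batchesFrom(start) public, lets callers outside FileRecords iterate record batches from a known batch boundary, which the segment-splitting code later in this patch relies on. A small usage sketch (illustrative; assumes an open FileRecords instance and a start position previously obtained from something like searchForOffsetWithSize):

    import org.apache.kafka.common.record.FileRecords
    import scala.collection.JavaConverters._

    // Walk the batches beginning at a position known to be the start of a batch;
    // slice(position, size), formerly read(position, size), would instead return a
    // bounded FileRecords view over the same byte range.
    def describeBatchesFrom(fileRecords: FileRecords, knownBatchStart: Int): Unit =
      fileRecords.batchesFrom(knownBatchStart).asScala.foreach { batch =>
        println(s"base=${batch.baseOffset} last=${batch.lastOffset} bytes=${batch.sizeInBytes}")
      }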
diff --git a/clients/src/test/java/org/apache/kafka/common/record/FileRecordsTest.java b/clients/src/test/java/org/apache/kafka/common/record/FileRecordsTest.java
index f8b6dd4..bbe84b2 100644
--- a/clients/src/test/java/org/apache/kafka/common/record/FileRecordsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/record/FileRecordsTest.java
@@ -121,7 +121,7 @@ public class FileRecordsTest {
      */
     @Test
     public void testRead() throws IOException {
-        FileRecords read = fileRecords.read(0, fileRecords.sizeInBytes());
+        FileRecords read = fileRecords.slice(0, fileRecords.sizeInBytes());
         assertEquals(fileRecords.sizeInBytes(), read.sizeInBytes());
         TestUtils.checkEquals(fileRecords.batches(), read.batches());
 
@@ -129,35 +129,35 @@ public class FileRecordsTest {
         RecordBatch first = items.get(0);
 
         // read from second message until the end
-        read = fileRecords.read(first.sizeInBytes(), fileRecords.sizeInBytes() - first.sizeInBytes());
+        read = fileRecords.slice(first.sizeInBytes(), fileRecords.sizeInBytes() - first.sizeInBytes());
         assertEquals(fileRecords.sizeInBytes() - first.sizeInBytes(), read.sizeInBytes());
         assertEquals("Read starting from the second message", items.subList(1, items.size()), batches(read));
 
         // read from second message and size is past the end of the file
-        read = fileRecords.read(first.sizeInBytes(), fileRecords.sizeInBytes());
+        read = fileRecords.slice(first.sizeInBytes(), fileRecords.sizeInBytes());
         assertEquals(fileRecords.sizeInBytes() - first.sizeInBytes(), read.sizeInBytes());
         assertEquals("Read starting from the second message", items.subList(1, items.size()), batches(read));
 
         // read from second message and position + size overflows
-        read = fileRecords.read(first.sizeInBytes(), Integer.MAX_VALUE);
+        read = fileRecords.slice(first.sizeInBytes(), Integer.MAX_VALUE);
         assertEquals(fileRecords.sizeInBytes() - first.sizeInBytes(), read.sizeInBytes());
         assertEquals("Read starting from the second message", items.subList(1, items.size()), batches(read));
 
         // read from second message and size is past the end of the file on a view/slice
-        read = fileRecords.read(1, fileRecords.sizeInBytes() - 1)
-                .read(first.sizeInBytes() - 1, fileRecords.sizeInBytes());
+        read = fileRecords.slice(1, fileRecords.sizeInBytes() - 1)
+                .slice(first.sizeInBytes() - 1, fileRecords.sizeInBytes());
         assertEquals(fileRecords.sizeInBytes() - first.sizeInBytes(), read.sizeInBytes());
         assertEquals("Read starting from the second message", items.subList(1, items.size()), batches(read));
 
         // read from second message and position + size overflows on a view/slice
-        read = fileRecords.read(1, fileRecords.sizeInBytes() - 1)
-                .read(first.sizeInBytes() - 1, Integer.MAX_VALUE);
+        read = fileRecords.slice(1, fileRecords.sizeInBytes() - 1)
+                .slice(first.sizeInBytes() - 1, Integer.MAX_VALUE);
         assertEquals(fileRecords.sizeInBytes() - first.sizeInBytes(), read.sizeInBytes());
         assertEquals("Read starting from the second message", items.subList(1, items.size()), batches(read));
 
         // read a single message starting from second message
         RecordBatch second = items.get(1);
-        read = fileRecords.read(first.sizeInBytes(), second.sizeInBytes());
+        read = fileRecords.slice(first.sizeInBytes(), second.sizeInBytes());
         assertEquals(second.sizeInBytes(), read.sizeInBytes());
         assertEquals("Read a single message starting from the second message",
                 Collections.singletonList(second), batches(read));
@@ -207,9 +207,9 @@ public class FileRecordsTest {
         RecordBatch batch = batches(fileRecords).get(1);
         int start = fileRecords.searchForOffsetWithSize(1, 0).position;
         int size = batch.sizeInBytes();
-        FileRecords slice = fileRecords.read(start, size);
+        FileRecords slice = fileRecords.slice(start, size);
         assertEquals(Collections.singletonList(batch), batches(slice));
-        FileRecords slice2 = fileRecords.read(start, size - 1);
+        FileRecords slice2 = fileRecords.slice(start, size - 1);
         assertEquals(Collections.emptyList(), batches(slice2));
     }
 
@@ -344,7 +344,7 @@ public class FileRecordsTest {
         RecordBatch batch = batches(fileRecords).get(1);
         int start = fileRecords.searchForOffsetWithSize(1, 0).position;
         int size = batch.sizeInBytes();
-        FileRecords slice = fileRecords.read(start, size - 1);
+        FileRecords slice = fileRecords.slice(start, size - 1);
         Records messageV0 = slice.downConvert(RecordBatch.MAGIC_VALUE_V0, 0, time).records();
         assertTrue("No message should be there", batches(messageV0).isEmpty());
         assertEquals("There should be " + (size - 1) + " bytes", size - 1, messageV0.sizeInBytes());
diff --git a/core/src/main/scala/kafka/common/IndexOffsetOverflowException.scala b/core/src/main/scala/kafka/common/IndexOffsetOverflowException.scala
new file mode 100644
index 0000000..7f3ea11
--- /dev/null
+++ b/core/src/main/scala/kafka/common/IndexOffsetOverflowException.scala
@@ -0,0 +1,25 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.common
+
+/**
+ * Indicates that an attempt was made to append a message whose offset could cause the index offset to overflow.
+ */
+class IndexOffsetOverflowException(message: String, cause: Throwable) extends KafkaException(message, cause) {
+  def this(message: String) = this(message, null)
+}
diff --git a/core/src/main/scala/kafka/common/LogSegmentOffsetOverflowException.scala b/core/src/main/scala/kafka/common/LogSegmentOffsetOverflowException.scala
new file mode 100644
index 0000000..62379de
--- /dev/null
+++ b/core/src/main/scala/kafka/common/LogSegmentOffsetOverflowException.scala
@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.common
+
+import kafka.log.LogSegment
+
+/**
+ * Indicates that the log segment contains one or more messages that overflow the offset (and / or time) index. This is
+ * not a typical scenario, and could only happen when brokers have log segments that were created before the patch for
+ * KAFKA-5413. With KAFKA-6264, we have the ability to split such log segments into multiple log segments such that we
+ * do not have any segments with offset overflow.
+ */
+class LogSegmentOffsetOverflowException(message: String, cause: Throwable, val logSegment: LogSegment) extends KafkaException(message, cause) {
+  def this(cause: Throwable, logSegment: LogSegment) = this(null, cause, logSegment)
+  def this(message: String, logSegment: LogSegment) = this(message, null, logSegment)
+}
diff --git a/core/src/main/scala/kafka/log/AbstractIndex.scala b/core/src/main/scala/kafka/log/AbstractIndex.scala
index 44083c1..95f0749 100644
--- a/core/src/main/scala/kafka/log/AbstractIndex.scala
+++ b/core/src/main/scala/kafka/log/AbstractIndex.scala
@@ -18,11 +18,12 @@
 package kafka.log
 
 import java.io.{File, RandomAccessFile}
-import java.nio.{ByteBuffer, MappedByteBuffer}
 import java.nio.channels.FileChannel
 import java.nio.file.Files
+import java.nio.{ByteBuffer, MappedByteBuffer}
 import java.util.concurrent.locks.{Lock, ReentrantLock}
 
+import kafka.common.IndexOffsetOverflowException
 import kafka.log.IndexSearchType.IndexSearchEntity
 import kafka.utils.CoreUtils.inLock
 import kafka.utils.{CoreUtils, Logging}
@@ -226,6 +227,26 @@ abstract class AbstractIndex[K, V](@volatile var file: File, val baseOffset: Lon
     resize(maxIndexSize)
   }
 
+  /**
+   * Get offset relative to base offset of this index
+   * @throws IndexOffsetOverflowException
+   */
+  def relativeOffset(offset: Long): Int = {
+    val relativeOffset = toRelative(offset)
+    if (relativeOffset.isEmpty)
+      throw new IndexOffsetOverflowException(s"Integer overflow for offset: $offset (${file.getAbsoluteFile})")
+    relativeOffset.get
+  }
+
+  /**
+   * Check if a particular offset is valid to be appended to this index.
+   * @param offset The offset to check
+   * @return true if this offset is valid to be appended to this index; false otherwise
+   */
+  def canAppendOffset(offset: Long): Boolean = {
+    toRelative(offset).isDefined
+  }
+
   protected def safeForceUnmap(): Unit = {
     try forceUnmap()
     catch {
@@ -325,6 +346,14 @@ abstract class AbstractIndex[K, V](@volatile var file: File, val baseOffset: Lon
    */
   private def roundDownToExactMultiple(number: Int, factor: Int) = factor * (number / factor)
 
+  private def toRelative(offset: Long): Option[Int] = {
+    val relativeOffset = offset - baseOffset
+    if (relativeOffset < 0 || relativeOffset > Int.MaxValue)
+      None
+    else
+      Some(relativeOffset.toInt)
+  }
+
 }
 
 object IndexSearchType extends Enumeration {
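
The new relativeOffset and canAppendOffset helpers let callers test for overflow up front instead of relying on append() to fail. A sketch of the intended calling pattern (illustrative; assumes an existing OffsetIndex instance):

    import kafka.log.OffsetIndex

    // Append only when the offset is representable relative to the index's base offset;
    // otherwise report failure so the caller can split the segment (KAFKA-6264).
    def tryAppend(index: OffsetIndex, offset: Long, position: Int): Boolean =
      if (index.canAppendOffset(offset)) {
        index.append(offset, position)
        true
      } else {
        false
      }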
diff --git a/core/src/main/scala/kafka/log/Log.scala b/core/src/main/scala/kafka/log/Log.scala
index 777dbb5..c7d2a6e 100644
--- a/core/src/main/scala/kafka/log/Log.scala
+++ b/core/src/main/scala/kafka/log/Log.scala
@@ -18,33 +18,34 @@
 package kafka.log
 
 import java.io.{File, IOException}
+import java.lang.{Long => JLong}
+import java.nio.ByteBuffer
 import java.nio.file.{Files, NoSuchFileException}
 import java.text.NumberFormat
+import java.util.Map.{Entry => JEntry}
 import java.util.concurrent.atomic._
 import java.util.concurrent.{ConcurrentNavigableMap, ConcurrentSkipListMap, TimeUnit}
+import java.util.regex.Pattern
 
+import com.yammer.metrics.core.Gauge
 import kafka.api.KAFKA_0_10_0_IV0
-import kafka.common.{InvalidOffsetException, KafkaException, LongRef}
+import kafka.common.{InvalidOffsetException, KafkaException, LogSegmentOffsetOverflowException, LongRef}
+import kafka.message.{BrokerCompressionCodec, CompressionCodec, NoCompressionCodec}
 import kafka.metrics.KafkaMetricsGroup
+import kafka.server.checkpoints.{LeaderEpochCheckpointFile, LeaderEpochFile}
+import kafka.server.epoch.{LeaderEpochCache, LeaderEpochFileCache}
 import kafka.server.{BrokerTopicStats, FetchDataInfo, LogDirFailureChannel, LogOffsetMetadata}
 import kafka.utils._
+import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.{CorruptRecordException, KafkaStorageException, OffsetOutOfRangeException, RecordBatchTooLargeException, RecordTooLargeException, UnsupportedForMessageFormatException}
 import org.apache.kafka.common.record._
+import org.apache.kafka.common.requests.FetchResponse.AbortedTransaction
 import org.apache.kafka.common.requests.{IsolationLevel, ListOffsetRequest}
+import org.apache.kafka.common.utils.{Time, Utils}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 import scala.collection.{Seq, Set, mutable}
-import com.yammer.metrics.core.Gauge
-import org.apache.kafka.common.utils.{Time, Utils}
-import kafka.message.{BrokerCompressionCodec, CompressionCodec, NoCompressionCodec}
-import kafka.server.checkpoints.{LeaderEpochCheckpointFile, LeaderEpochFile}
-import kafka.server.epoch.{LeaderEpochCache, LeaderEpochFileCache}
-import org.apache.kafka.common.TopicPartition
-import org.apache.kafka.common.requests.FetchResponse.AbortedTransaction
-import java.util.Map.{Entry => JEntry}
-import java.lang.{Long => JLong}
-import java.util.regex.Pattern
 
 object LogAppendInfo {
   val UnknownLogAppendInfo = LogAppendInfo(None, -1, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP, -1L,
@@ -85,15 +86,15 @@ case class LogAppendInfo(var firstOffset: Option[Long],
                          validBytes: Int,
                          offsetsMonotonic: Boolean) {
   /**
-    * Get the first offset if it exists, else get the last offset.
-    * @return The offset of first message if it exists; else offset of the last message.
-    */
+   * Get the first offset if it exists, else get the last offset.
+   * @return The offset of first message if it exists; else offset of the last message.
+   */
   def firstOrLastOffset: Long = firstOffset.getOrElse(lastOffset)
 
   /**
-    * Get the (maximum) number of messages described by LogAppendInfo
-    * @return Maximum possible number of messages described by LogAppendInfo
-    */
+   * Get the (maximum) number of messages described by LogAppendInfo
+   * @return Maximum possible number of messages described by LogAppendInfo
+   */
   def numMessages: Long = {
     firstOffset match {
       case Some(firstOffsetVal) if (firstOffsetVal >= 0 && lastOffset >= 0) => (lastOffset - firstOffsetVal + 1)
@@ -157,7 +158,7 @@ class Log(@volatile var dir: File,
           @volatile var recoveryPoint: Long,
           scheduler: Scheduler,
           brokerTopicStats: BrokerTopicStats,
-          time: Time,
+          val time: Time,
           val maxProducerIdExpirationMs: Int,
           val producerIdExpirationCheckIntervalMs: Int,
           val topicPartition: TopicPartition,
@@ -295,42 +296,79 @@ class Log(@volatile var dir: File,
       new LeaderEpochCheckpointFile(LeaderEpochFile.newFile(dir), logDirFailureChannel))
   }
 
+  /**
+   * Removes any temporary files found in log directory, and creates a list of all .swap files which could be swapped
+   * in place of existing segment(s). For log splitting, we know that any .swap file whose base offset is higher than
+   * the smallest offset .clean file could be part of an incomplete split operation. Such .swap files are also deleted
+   * by this method.
+   * @return Set of .swap files that are valid to be swapped in as segment files
+   */
   private def removeTempFilesAndCollectSwapFiles(): Set[File] = {
 
-    def deleteIndicesIfExist(baseFile: File, swapFile: File, fileType: String): Unit = {
-      info(s"Found $fileType file ${swapFile.getAbsolutePath} from interrupted swap operation. Deleting index files (if they exist).")
+    def deleteIndicesIfExist(baseFile: File, suffix: String = ""): Unit = {
+      info(s"Deleting index files with suffix $suffix for baseFile $baseFile")
       val offset = offsetFromFile(baseFile)
-      Files.deleteIfExists(Log.offsetIndexFile(dir, offset).toPath)
-      Files.deleteIfExists(Log.timeIndexFile(dir, offset).toPath)
-      Files.deleteIfExists(Log.transactionIndexFile(dir, offset).toPath)
+      Files.deleteIfExists(Log.offsetIndexFile(dir, offset, suffix).toPath)
+      Files.deleteIfExists(Log.timeIndexFile(dir, offset, suffix).toPath)
+      Files.deleteIfExists(Log.transactionIndexFile(dir, offset, suffix).toPath)
     }
 
     var swapFiles = Set[File]()
+    var cleanFiles = Set[File]()
+    var minCleanedFileOffset = Long.MaxValue
 
     for (file <- dir.listFiles if file.isFile) {
       if (!file.canRead)
         throw new IOException(s"Could not read file $file")
       val filename = file.getName
-      if (filename.endsWith(DeletedFileSuffix) || filename.endsWith(CleanedFileSuffix)) {
+      if (filename.endsWith(DeletedFileSuffix)) {
         debug(s"Deleting stray temporary file ${file.getAbsolutePath}")
         Files.deleteIfExists(file.toPath)
+      } else if (filename.endsWith(CleanedFileSuffix)) {
+        minCleanedFileOffset = Math.min(offsetFromFileName(filename), minCleanedFileOffset)
+        cleanFiles += file
       } else if (filename.endsWith(SwapFileSuffix)) {
         // we crashed in the middle of a swap operation, to recover:
         // if a log, delete the index files, complete the swap operation later
         // if an index just delete the index files, they will be rebuilt
         val baseFile = new File(CoreUtils.replaceSuffix(file.getPath, SwapFileSuffix, ""))
+        info(s"Found file ${file.getAbsolutePath} from interrupted swap operation.")
         if (isIndexFile(baseFile)) {
-          deleteIndicesIfExist(baseFile, file, "index")
+          deleteIndicesIfExist(baseFile)
         } else if (isLogFile(baseFile)) {
-          deleteIndicesIfExist(baseFile, file, "log")
+          deleteIndicesIfExist(baseFile)
           swapFiles += file
         }
       }
     }
-    swapFiles
+
+    // KAFKA-6264: Delete all .swap files whose base offset is greater than the minimum .cleaned segment offset. Such .swap
+    // files could be part of an incomplete split operation that could not complete. See Log#splitOverflowedSegment
+    // for more details about the split operation.
+    val (invalidSwapFiles, validSwapFiles) = swapFiles.partition(file => offsetFromFile(file) >= minCleanedFileOffset)
+    invalidSwapFiles.foreach { file =>
+      debug(s"Deleting invalid swap file ${file.getAbsoluteFile} minCleanedFileOffset: $minCleanedFileOffset")
+      val baseFile = new File(CoreUtils.replaceSuffix(file.getPath, SwapFileSuffix, ""))
+      deleteIndicesIfExist(baseFile, SwapFileSuffix)
+      Files.deleteIfExists(file.toPath)
+    }
+
+    // Now that we have deleted all .swap files that constitute an incomplete split operation, let's delete all .clean files
+    cleanFiles.foreach { file =>
+      debug(s"Deleting stray .clean file ${file.getAbsolutePath}")
+      Files.deleteIfExists(file.toPath)
+    }
+
+    validSwapFiles
   }
 
-  // This method does not need to convert IOException to KafkaStorageException because it is only called before all logs are loaded
+  /**
+   * This method does not need to convert IOException to KafkaStorageException because it is only called before all logs are loaded
+   * It is possible that we encounter a segment with index offset overflow in which case the LogSegmentOffsetOverflowException
+   * will be thrown. Note that any segments that were opened before we encountered the exception will remain open and the
+   * caller is responsible for closing them appropriately, if needed.
+   * @throws LogSegmentOffsetOverflowException if the log directory contains a segment with messages that overflow the index offset
+   */
   private def loadSegmentFiles(): Unit = {
     // load segments in ascending order because transactional data from one segment may depend on the
     // segments that come before it
@@ -369,6 +407,13 @@ class Log(@volatile var dir: File,
     }
   }
 
+  /**
+   * Recover the given segment.
+   * @param segment Segment to recover
+   * @param leaderEpochCache Optional cache for updating the leader epoch during recovery
+   * @return The number of bytes truncated from the segment
+   * @throws LogSegmentOffsetOverflowException if the segment contains messages that cause index offset overflow
+   */
   private def recoverSegment(segment: LogSegment, leaderEpochCache: Option[LeaderEpochCache] = None): Int = lock synchronized {
     val stateManager = new ProducerStateManager(topicPartition, dir, maxProducerIdExpirationMs)
     stateManager.truncateAndReload(logStartOffset, segment.baseOffset, time.milliseconds)
@@ -383,7 +428,6 @@ class Log(@volatile var dir: File,
     // take a snapshot for the first recovered segment to avoid reloading all the segments if we shutdown before we
     // checkpoint the recovery point
     stateManager.takeSnapshot()
-
     val bytesTruncated = segment.recover(stateManager, leaderEpochCache)
 
     // once we have recovered the segment's data, take a snapshot to ensure that we won't
@@ -392,7 +436,16 @@ class Log(@volatile var dir: File,
     bytesTruncated
   }
 
-  // This method does not need to convert IOException to KafkaStorageException because it is only called before all logs are loaded
+  /**
+   * This method does not need to convert IOException to KafkaStorageException because it is only called before all logs
+   * are loaded.
+   * @throws LogSegmentOffsetOverflowException if the swap file contains messages that cause the log segment offset to
+   *                                           overflow. Note that this is currently a fatal exception as we do not have
+   *                                           a way to deal with it. The exception is propagated all the way up to
+   *                                           KafkaServer#startup which will cause the broker to shut down if we are in
+   *                                           this situation. This is expected to be an extremely rare scenario in practice,
+   *                                           and manual intervention might be required to get out of it.
+   */
   private def completeSwapOperations(swapFiles: Set[File]): Unit = {
     for (swapFile <- swapFiles) {
       val logFile = new File(CoreUtils.replaceSuffix(swapFile.getPath, SwapFileSuffix, ""))
@@ -404,20 +457,49 @@ class Log(@volatile var dir: File,
         fileSuffix = SwapFileSuffix)
       info(s"Found log file ${swapFile.getPath} from interrupted swap operation, repairing.")
       recoverSegment(swapSegment)
-      val oldSegments = logSegments(swapSegment.baseOffset, swapSegment.readNextOffset)
-      replaceSegments(swapSegment, oldSegments.toSeq, isRecoveredSwapFile = true)
+
+      var oldSegments = logSegments(swapSegment.baseOffset, swapSegment.readNextOffset)
+
+      // We create swap files for two cases: (1) Log cleaning where multiple segments are merged into one, and
+      // (2) Log splitting where one segment is split into multiple.
+      // Both of these mean that the resultant swap segments must be composed of the original set, i.e. the swap segment
+      // must fall within the range of existing segment(s). If we cannot find such a segment, it means the deletion
+      // of that segment was successful. In such an event, we should simply rename the .swap to .log without having to
+      // do a replace with an existing segment.
+      if (oldSegments.nonEmpty) {
+        val start = oldSegments.head.baseOffset
+        val end = oldSegments.last.readNextOffset
+        if (!(swapSegment.baseOffset >= start && swapSegment.baseOffset <= end))
+          oldSegments = List()
+      }
+
+      replaceSegments(Seq(swapSegment), oldSegments.toSeq, isRecoveredSwapFile = true)
     }
   }
 
-  // Load the log segments from the log files on disk and return the next offset
-  // This method does not need to convert IOException to KafkaStorageException because it is only called before all logs are loaded
+  /**
+   * Load the log segments from the log files on disk and return the next offset.
+   * This method does not need to convert IOException to KafkaStorageException because it is only called before all logs
+   * are loaded.
+   * @throws LogSegmentOffsetOverflowException if we encounter a .swap file with messages that overflow index offset; or when
+   *                                           we find an unexpected number of .log files with overflow
+   */
   private def loadSegments(): Long = {
     // first do a pass through the files in the log directory and remove any temporary files
     // and find any interrupted swap operations
     val swapFiles = removeTempFilesAndCollectSwapFiles()
 
-    // now do a second pass and load all the log and index files
-    loadSegmentFiles()
+    // Now do a second pass and load all the log and index files.
+    // We might encounter legacy log segments with offset overflow (KAFKA-6264). We need to split such segments. When
+    // this happens, restart loading segment files from scratch.
+    retryOnOffsetOverflow {
+      // In case we encounter a segment with offset overflow, the retry logic will split it after which we need to retry
+      // loading of segments. In that case, we also need to close all segments that could have been left open in previous
+      // call to loadSegmentFiles().
+      logSegments.foreach(_.close())
+      segments.clear()
+      loadSegmentFiles()
+    }
 
     // Finally, complete any interrupted swap operations. To be crash-safe,
     // log files that are replaced by the swap segment should be renamed to .deleted
@@ -435,7 +517,10 @@ class Log(@volatile var dir: File,
         preallocate = config.preallocate))
       0
     } else if (!dir.getAbsolutePath.endsWith(Log.DeleteDirSuffix)) {
-      val nextOffset = recoverLog()
+      val nextOffset = retryOnOffsetOverflow {
+        recoverLog()
+      }
+
       // reset the index size of the currently active log segment to allow more entries
       activeSegment.resizeIndexes(config.maxIndexSize)
       nextOffset
@@ -448,9 +533,9 @@ class Log(@volatile var dir: File,
 
   /**
    * Recover the log segments and return the next offset after recovery.
-   *
    * This method does not need to convert IOException to KafkaStorageException because it is only called before all
    * logs are loaded.
+   * @throws LogSegmentOffsetOverflowException if we encountered a legacy segment with offset overflow
    */
   private def recoverLog(): Long = {
     // if we have the clean shutdown marker, skip recovery
@@ -585,10 +670,10 @@ class Log(@volatile var dir: File,
   }
 
   /**
-    * Rename the directory of the log
-    *
-    * @throws KafkaStorageException if rename fails
-    */
+   * Rename the directory of the log
+   *
+   * @throws KafkaStorageException if rename fails
+   */
   def renameDir(name: String) {
     lock synchronized {
       maybeHandleIOException(s"Error while renaming dir for $topicPartition in log dir ${dir.getParent}") {
@@ -1315,9 +1400,9 @@ class Log(@volatile var dir: File,
 
     if (segment.shouldRoll(messagesSize, maxTimestampInMessages, maxOffsetInMessages, now)) {
       debug(s"Rolling new log segment (log_size = ${segment.size}/${config.segmentSize}}, " +
-          s"offset_index_size = ${segment.offsetIndex.entries}/${segment.offsetIndex.maxEntries}, " +
-          s"time_index_size = ${segment.timeIndex.entries}/${segment.timeIndex.maxEntries}, " +
-          s"inactive_time_ms = ${segment.timeWaitedForRoll(now, maxTimestampInMessages)}/${config.segmentMs - segment.rollJitterMs}).")
+        s"offset_index_size = ${segment.offsetIndex.entries}/${segment.offsetIndex.maxEntries}, " +
+        s"time_index_size = ${segment.timeIndex.entries}/${segment.timeIndex.maxEntries}, " +
+        s"inactive_time_ms = ${segment.timeWaitedForRoll(now, maxTimestampInMessages)}/${config.segmentMs - segment.rollJitterMs}).")
 
       /*
         maxOffsetInMessages - Integer.MAX_VALUE is a heuristic value for the first offset in the set of messages.
@@ -1644,51 +1729,59 @@ class Log(@volatile var dir: File,
   }
 
   /**
-   * Swap a new segment in place and delete one or more existing segments in a crash-safe manner. The old segments will
-   * be asynchronously deleted.
+   * Swap one or more new segments in place and delete one or more existing segments in a crash-safe manner. The old
+   * segments will be asynchronously deleted.
    *
    * This method does not need to convert IOException to KafkaStorageException because it is either called before all logs are loaded
    * or the caller will catch and handle IOException
    *
    * The sequence of operations is:
    * <ol>
-   *   <li> Cleaner creates new segment with suffix .cleaned and invokes replaceSegments().
+   *   <li> Cleaner creates one or more new segments with suffix .cleaned and invokes replaceSegments().
    *        If broker crashes at this point, the clean-and-swap operation is aborted and
-   *        the .cleaned file is deleted on recovery in loadSegments().
-   *   <li> New segment is renamed .swap. If the broker crashes after this point before the whole
-   *        operation is completed, the swap operation is resumed on recovery as described in the next step.
+   *        the .cleaned files are deleted on recovery in loadSegments().
+   *   <li> New segments are renamed .swap. If the broker crashes before all segments were renamed to .swap, the
+   *        clean-and-swap operation is aborted - .cleaned as well as .swap files are deleted on recovery in
+   *        loadSegments(). We detect this situation by maintaining a specific order in which files are renamed from
+   *        .cleaned to .swap. Basically, files are renamed in descending order of offsets. On recovery, all .swap files
+   *        whose offset is greater than the minimum-offset .clean file are deleted.
+   *   <li> If the broker crashes after all new segments were renamed to .swap, the operation is still completed:
+   *        the swap operation is resumed on recovery as described in the next step.
    *   <li> Old segment files are renamed to .deleted and asynchronous delete is scheduled.
    *        If the broker crashes, any .deleted files left behind are deleted on recovery in loadSegments().
    *        replaceSegments() is then invoked to complete the swap with newSegment recreated from
    *        the .swap file and oldSegments containing segments which were not renamed before the crash.
-   *   <li> Swap segment is renamed to replace the existing segment, completing this operation.
+   *   <li> Swap segment(s) are renamed to replace the existing segments, completing this operation.
    *        If the broker crashes, any .deleted files which may be left behind are deleted
    *        on recovery in loadSegments().
    * </ol>
    *
-   * @param newSegment The new log segment to add to the log
+   * @param newSegments The new log segments to add to the log
    * @param oldSegments The old log segments to delete from the log
    * @param isRecoveredSwapFile true if the new segment was created from a swap file during recovery after a crash
    */
-  private[log] def replaceSegments(newSegment: LogSegment, oldSegments: Seq[LogSegment], isRecoveredSwapFile: Boolean = false) {
+  private[log] def replaceSegments(newSegments: Seq[LogSegment], oldSegments: Seq[LogSegment], isRecoveredSwapFile: Boolean = false) {
+    val sortedNewSegments = newSegments.sortBy(_.baseOffset)
+    val sortedOldSegments = oldSegments.sortBy(_.baseOffset)
+
     lock synchronized {
       checkIfMemoryMappedBufferClosed()
       // need to do this in two phases to be crash safe AND do the delete asynchronously
       // if we crash in the middle of this we complete the swap in loadSegments()
       if (!isRecoveredSwapFile)
-        newSegment.changeFileSuffixes(Log.CleanedFileSuffix, Log.SwapFileSuffix)
-      addSegment(newSegment)
+        sortedNewSegments.reverse.foreach(_.changeFileSuffixes(Log.CleanedFileSuffix, Log.SwapFileSuffix))
+      sortedNewSegments.reverse.foreach(addSegment(_))
 
       // delete the old files
-      for (seg <- oldSegments) {
+      for (seg <- sortedOldSegments) {
         // remove the index entry
-        if (seg.baseOffset != newSegment.baseOffset)
+        if (seg.baseOffset != sortedNewSegments.head.baseOffset)
           segments.remove(seg.baseOffset)
         // delete segment
         asyncDeleteSegment(seg)
       }
       // okay we are safe now, remove the swap suffix
-      newSegment.changeFileSuffixes(Log.SwapFileSuffix, "")
+      sortedNewSegments.foreach(_.changeFileSuffixes(Log.SwapFileSuffix, ""))
     }
   }
 
@@ -1701,12 +1794,13 @@ class Log(@volatile var dir: File,
     removeMetric("LogEndOffset", tags)
     removeMetric("Size", tags)
   }
+
   /**
    * Add the given segment to the segments in this log. If this segment replaces an existing segment, delete it.
-   *
    * @param segment The segment to add
    */
-  def addSegment(segment: LogSegment) = this.segments.put(segment.baseOffset, segment)
+  @threadsafe
+  def addSegment(segment: LogSegment): LogSegment = this.segments.put(segment.baseOffset, segment)
 
   private def maybeHandleIOException[T](msg: => String)(fun: => T): T = {
     try {
@@ -1718,6 +1812,140 @@ class Log(@volatile var dir: File,
     }
   }
 
+  /**
+   * @throws LogSegmentOffsetOverflowException if we encounter segments with index overflow for more than maxTries
+   */
+  private[log] def retryOnOffsetOverflow[T](fn: => T): T = {
+    var triesSoFar = 0
+    while (true) {
+      try {
+        return fn
+      } catch {
+        case e: LogSegmentOffsetOverflowException =>
+          triesSoFar += 1
+          info(s"Caught LogSegmentOffsetOverflowException ${e.getMessage}. Split segment and retry. retry#: $triesSoFar.")
+          splitOverflowedSegment(e.logSegment)
+      }
+    }
+    throw new IllegalStateException()
+  }
+
+  /**
+   * Split the given log segment into multiple such that there is no offset overflow in the resulting segments. The
+   * resulting segments will contain the exact same messages that are present in the input segment. On successful
+   * completion of this method, the input segment will be deleted and will be replaced by the resulting new segments.
+   * See replaceSegments for recovery logic, in case the broker dies in the middle of this operation.
+   * <p>Note that this method assumes we have already determined that the segment passed in contains records that cause
+   * offset overflow.</p>
+   * <p>The split logic overloads the use of .clean files that LogCleaner typically uses to make the process of replacing
+   * the input segment with multiple new segments atomic and recoverable in the event of a crash. See replaceSegments
+   * and completeSwapOperations for the implementation to make this operation recoverable on crashes.</p>
+   * @param segment Segment to split
+   * @return List of new segments that replace the input segment
+   */
+  private[log] def splitOverflowedSegment(segment: LogSegment): List[LogSegment] = {
+    require(isLogFile(segment.log.file), s"Cannot split file ${segment.log.file.getAbsoluteFile}")
+    info(s"Attempting to split segment ${segment.log.file.getAbsolutePath}")
+
+    val newSegments = ListBuffer[LogSegment]()
+    var position = 0
+    val sourceRecords = segment.log
+    var readBuffer = ByteBuffer.allocate(1024 * 1024)
+
+    class CopyResult(val bytesRead: Int, val overflowOffset: Option[Long])
+
+    // Helper method to copy `records` into `segment`. Makes sure records being appended do not result in offset overflow.
+    def copyRecordsToSegment(records: FileRecords, segment: LogSegment, readBuffer: ByteBuffer): CopyResult = {
+      var bytesRead = 0
+      var maxTimestamp = Long.MinValue
+      var offsetOfMaxTimestamp = Long.MinValue
+      var maxOffset = Long.MinValue
+
+      // find all batches that are valid to be appended to the current log segment
+      val (validBatches, overflowBatches) = records.batches.asScala.span(batch => segment.offsetIndex.canAppendOffset(batch.lastOffset))
+      val overflowOffset = overflowBatches.headOption.map { firstBatch =>
+        info(s"Found overflow at offset ${firstBatch.baseOffset} in segment $segment")
+        firstBatch.baseOffset
+      }
+
+      // return early if no valid batches were found
+      if (validBatches.isEmpty) {
+        require(overflowOffset.isDefined, "No batches found during split")
+        return new CopyResult(0, overflowOffset)
+      }
+
+      // determine the maximum offset and timestamp in batches
+      for (batch <- validBatches) {
+        if (batch.maxTimestamp > maxTimestamp) {
+          maxTimestamp = batch.maxTimestamp
+          offsetOfMaxTimestamp = batch.lastOffset
+        }
+        maxOffset = batch.lastOffset
+        bytesRead += batch.sizeInBytes
+      }
+
+      // read all valid batches into memory
+      val validRecords = records.slice(0, bytesRead)
+      require(readBuffer.capacity >= validRecords.sizeInBytes)
+      readBuffer.clear()
+      readBuffer.limit(validRecords.sizeInBytes)
+      validRecords.readInto(readBuffer, 0)
+
+      // append valid batches into the segment
+      segment.append(maxOffset, maxTimestamp, offsetOfMaxTimestamp, MemoryRecords.readableRecords(readBuffer))
+      readBuffer.clear()
+      info(s"Appended messages till $maxOffset to segment $segment during split")
+
+      new CopyResult(bytesRead, overflowOffset)
+    }
+
+    try {
+      info(s"Splitting segment $segment")
+      newSegments += LogCleaner.createNewCleanedSegment(this, segment.baseOffset)
+      while (position < sourceRecords.sizeInBytes) {
+        val currentSegment = newSegments.last
+
+        // grow buffers if needed
+        val firstBatch = sourceRecords.batchesFrom(position).asScala.head
+        if (firstBatch.sizeInBytes > readBuffer.capacity)
+          readBuffer = ByteBuffer.allocate(firstBatch.sizeInBytes)
+
+        // get records we want to copy and copy them into the new segment
+        val recordsToCopy = sourceRecords.slice(position, readBuffer.capacity)
+        val copyResult = copyRecordsToSegment(recordsToCopy, currentSegment, readBuffer)
+        position += copyResult.bytesRead
+
+        // create a new segment if there was an overflow
+        copyResult.overflowOffset.foreach(overflowOffset => newSegments += LogCleaner.createNewCleanedSegment(this, overflowOffset))
+      }
+      require(newSegments.length > 1, s"No offset overflow found for $segment")
+
+      // prepare new segments
+      var totalSizeOfNewSegments = 0
+      info(s"Split messages from $segment into ${newSegments.length} new segments")
+      newSegments.foreach { splitSegment =>
+        splitSegment.onBecomeInactiveSegment()
+        splitSegment.flush()
+        splitSegment.lastModified = segment.lastModified
+        totalSizeOfNewSegments += splitSegment.log.sizeInBytes
+        info(s"New segment: $splitSegment")
+      }
+      // size of all the new segments combined must equal size of the original segment
+      require(totalSizeOfNewSegments == segment.log.sizeInBytes, "Inconsistent segment sizes after split" +
+        s" before: ${segment.log.sizeInBytes} after: $totalSizeOfNewSegments")
+    } catch {
+      case e: Exception =>
+        newSegments.foreach { splitSegment =>
+          splitSegment.close()
+          splitSegment.deleteIfExists()
+        }
+        throw e
+    }
+
+    // replace old segment with new ones
+    replaceSegments(newSegments.toList, List(segment), isRecoveredSwapFile = false)
+    newSegments.toList
+  }
 }
 
 /**
@@ -1809,17 +2037,17 @@ object Log {
     new File(dir, filenamePrefixFromOffset(offset) + LogFileSuffix + suffix)
 
   /**
-    * Return a directory name to rename the log directory to for async deletion. The name will be in the following
-    * format: topic-partition.uniqueId-delete where topic, partition and uniqueId are variables.
-    */
+   * Return a directory name to rename the log directory to for async deletion. The name will be in the following
+   * format: topic-partition.uniqueId-delete where topic, partition and uniqueId are variables.
+   */
   def logDeleteDirName(topicPartition: TopicPartition): String = {
     logDirNameWithSuffix(topicPartition, DeleteDirSuffix)
   }
 
   /**
-    * Return a future directory name for the given topic partition. The name will be in the following
-    * format: topic-partition.uniqueId-future where topic, partition and uniqueId are variables.
-    */
+   * Return a future directory name for the given topic partition. The name will be in the following
+   * format: topic-partition.uniqueId-future where topic, partition and uniqueId are variables.
+   */
   def logFutureDirName(topicPartition: TopicPartition): String = {
     logDirNameWithSuffix(topicPartition, FutureDirSuffix)
   }
@@ -1830,9 +2058,9 @@ object Log {
   }
 
   /**
-    * Return a directory name for the given topic partition. The name will be in the following
-    * format: topic-partition where topic, partition are variables.
-    */
+   * Return a directory name for the given topic partition. The name will be in the following
+   * format: topic-partition where topic, partition are variables.
+   */
   def logDirName(topicPartition: TopicPartition): String = {
     s"${topicPartition.topic}-${topicPartition.partition}"
   }
@@ -1857,6 +2085,9 @@ object Log {
   def timeIndexFile(dir: File, offset: Long, suffix: String = ""): File =
     new File(dir, filenamePrefixFromOffset(offset) + TimeIndexFileSuffix + suffix)
 
+  def deleteFileIfExists(file: File, suffix: String = ""): Unit =
+    Files.deleteIfExists(new File(file.getPath + suffix).toPath)
+
   /**
    * Construct a producer id snapshot file using the given offset.
    *
@@ -1876,17 +2107,20 @@ object Log {
   def transactionIndexFile(dir: File, offset: Long, suffix: String = ""): File =
     new File(dir, filenamePrefixFromOffset(offset) + TxnIndexFileSuffix + suffix)
 
-  def offsetFromFile(file: File): Long = {
-    val filename = file.getName
+  def offsetFromFileName(filename: String): Long = {
     filename.substring(0, filename.indexOf('.')).toLong
   }
 
+  def offsetFromFile(file: File): Long = {
+    offsetFromFileName(file.getName)
+  }
+
   /**
-    * Calculate a log's size (in bytes) based on its log segments
-    *
-    * @param segments The log segments to calculate the size of
-    * @return Sum of the log segments' sizes (in bytes)
-    */
+   * Calculate a log's size (in bytes) based on its log segments
+   *
+   * @param segments The log segments to calculate the size of
+   * @return Sum of the log segments' sizes (in bytes)
+   */
   def sizeInBytes(segments: Iterable[LogSegment]): Long =
     segments.map(_.size.toLong).sum
 
diff --git a/core/src/main/scala/kafka/log/LogCleaner.scala b/core/src/main/scala/kafka/log/LogCleaner.scala
index aa7cfe2..d79a840 100644
--- a/core/src/main/scala/kafka/log/LogCleaner.scala
+++ b/core/src/main/scala/kafka/log/LogCleaner.scala
@@ -19,7 +19,6 @@ package kafka.log
 
 import java.io.{File, IOException}
 import java.nio._
-import java.nio.file.Files
 import java.util.Date
 import java.util.concurrent.TimeUnit
 
@@ -28,16 +27,16 @@ import kafka.common._
 import kafka.metrics.KafkaMetricsGroup
 import kafka.server.{BrokerReconfigurable, KafkaConfig, LogDirFailureChannel}
 import kafka.utils._
-import org.apache.kafka.common.record._
-import org.apache.kafka.common.utils.Time
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.config.ConfigException
 import org.apache.kafka.common.errors.{CorruptRecordException, KafkaStorageException}
 import org.apache.kafka.common.record.MemoryRecords.RecordFilter
 import org.apache.kafka.common.record.MemoryRecords.RecordFilter.BatchRetention
+import org.apache.kafka.common.record._
+import org.apache.kafka.common.utils.Time
 
-import scala.collection.{Set, mutable}
 import scala.collection.JavaConverters._
+import scala.collection.{Set, mutable}
 
 /**
  * The cleaner is responsible for removing obsolete records from logs which have the "compact" retention strategy.
@@ -382,6 +381,12 @@ object LogCleaner {
       enableCleaner = config.logCleanerEnable)
 
   }
+
+  def createNewCleanedSegment(log: Log, baseOffset: Long): LogSegment = {
+    LogSegment.deleteIfExists(log.dir, baseOffset, fileSuffix = Log.CleanedFileSuffix)
+    LogSegment.open(log.dir, baseOffset, log.config, Time.SYSTEM, fileAlreadyExists = false,
+      fileSuffix = Log.CleanedFileSuffix, initFileSize = log.initFileSize, preallocate = log.config.preallocate)
+  }
 }
 
 /**
@@ -454,7 +459,6 @@ private[log] class Cleaner(val id: Int,
     // this is the lower of the last active segment and the compaction lag
     val cleanableHorizonMs = log.logSegments(0, cleanable.firstUncleanableOffset).lastOption.map(_.lastModified).getOrElse(0L)
 
-
     // group the segments and clean the groups
     info("Cleaning log %s (cleaning prior to %s, discarding tombstones prior to %s)...".format(log.name, new Date(cleanableHorizonMs), new Date(deleteHorizonMs)))
     for (group <- groupSegmentsBySize(log.logSegments(0, endOffset), log.config.segmentSize, log.config.maxIndexSize, cleanable.firstUncleanableOffset))
@@ -482,21 +486,8 @@ private[log] class Cleaner(val id: Int,
                                  map: OffsetMap,
                                  deleteHorizonMs: Long,
                                  stats: CleanerStats) {
-
-    def deleteCleanedFileIfExists(file: File): Unit = {
-      Files.deleteIfExists(new File(file.getPath + Log.CleanedFileSuffix).toPath)
-    }
-
     // create a new segment with a suffix appended to the name of the log and indexes
-    val firstSegment = segments.head
-    deleteCleanedFileIfExists(firstSegment.log.file)
-    deleteCleanedFileIfExists(firstSegment.offsetIndex.file)
-    deleteCleanedFileIfExists(firstSegment.timeIndex.file)
-    deleteCleanedFileIfExists(firstSegment.txnIndex.file)
-
-    val baseOffset = firstSegment.baseOffset
-    val cleaned = LogSegment.open(log.dir, baseOffset, log.config, time, fileSuffix = Log.CleanedFileSuffix,
-      initFileSize = log.initFileSize, preallocate = log.config.preallocate)
+    val cleaned = LogCleaner.createNewCleanedSegment(log, segments.head.baseOffset)
 
     try {
       // clean segments into the new destination segment
@@ -514,9 +505,18 @@ private[log] class Cleaner(val id: Int,
         val retainDeletes = currentSegment.lastModified > deleteHorizonMs
         info(s"Cleaning segment $startOffset in log ${log.name} (largest timestamp ${new Date(currentSegment.largestTimestamp)}) " +
           s"into ${cleaned.baseOffset}, ${if(retainDeletes) "retaining" else "discarding"} deletes.")
-        cleanInto(log.topicPartition, currentSegment.log, cleaned, map, retainDeletes, log.config.maxMessageSize,
-          transactionMetadata, log.activeProducersWithLastSequence, stats)
 
+        try {
+          cleanInto(log.topicPartition, currentSegment.log, cleaned, map, retainDeletes, log.config.maxMessageSize,
+            transactionMetadata, log.activeProducersWithLastSequence, stats)
+        } catch {
+          case e: LogSegmentOffsetOverflowException =>
+            // Split the current segment. It's also safest to abort the current cleaning process, so that we retry from
+            // scratch once the split is complete.
+            info(s"Caught LogSegmentOffsetOverflowException during log cleaning $e")
+            log.splitOverflowedSegment(currentSegment)
+            throw new LogCleaningAbortedException()
+        }
         currentSegmentOpt = nextSegmentOpt
       }
 
@@ -531,7 +531,7 @@ private[log] class Cleaner(val id: Int,
       // swap in new segment
       info(s"Swapping in cleaned segment ${cleaned.baseOffset} for segment(s) ${segments.map(_.baseOffset).mkString(",")} " +
         s"in log ${log.name}")
-      log.replaceSegments(cleaned, segments)
+      log.replaceSegments(List(cleaned), segments)
     } catch {
       case e: LogCleaningAbortedException =>
         try cleaned.deleteIfExists()
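
Both the cleaner and the new split path now go through LogCleaner.createNewCleanedSegment, which removes any leftover .cleaned files for that base offset and then opens fresh log and index files carrying the .cleaned suffix. A sketch of how a caller is expected to clean up on failure, mirroring the catch blocks above (illustrative; assumes an existing kafka.log.Log):

    import kafka.log.{Log, LogCleaner, LogSegment}

    // Copy records into a fresh .cleaned segment and make sure partial output is removed
    // on failure; leftover .cleaned files are also deleted during log loading on restart.
    def withCleanedSegment[T](log: Log, baseOffset: Long)(copy: LogSegment => T): T = {
      val cleaned = LogCleaner.createNewCleanedSegment(log, baseOffset)
      try copy(cleaned)
      catch {
        case e: Throwable =>
          cleaned.deleteIfExists()
          throw e
      }
    }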
diff --git a/core/src/main/scala/kafka/log/LogSegment.scala b/core/src/main/scala/kafka/log/LogSegment.scala
index 55ab088..6d61a41 100755
--- a/core/src/main/scala/kafka/log/LogSegment.scala
+++ b/core/src/main/scala/kafka/log/LogSegment.scala
@@ -21,11 +21,12 @@ import java.nio.file.{Files, NoSuchFileException}
 import java.nio.file.attribute.FileTime
 import java.util.concurrent.TimeUnit
 
+import kafka.common.{IndexOffsetOverflowException, LogSegmentOffsetOverflowException}
 import kafka.metrics.{KafkaMetricsGroup, KafkaTimer}
 import kafka.server.epoch.LeaderEpochCache
 import kafka.server.{FetchDataInfo, LogOffsetMetadata}
 import kafka.utils._
-import org.apache.kafka.common.errors.CorruptRecordException
+import org.apache.kafka.common.errors.{CorruptRecordException, InvalidOffsetException}
 import org.apache.kafka.common.record.FileRecords.LogOffsetPosition
 import org.apache.kafka.common.record._
 import org.apache.kafka.common.utils.Time
@@ -103,7 +104,7 @@ class LogSegment private[log] (val log: FileRecords,
    * checks that the argument offset can be represented as an integer offset relative to the baseOffset.
    */
   def canConvertToRelativeOffset(offset: Long): Boolean = {
-    (offset - baseOffset) <= Integer.MAX_VALUE
+    offsetIndex.canAppendOffset(offset)
   }
 
   /**
@@ -117,6 +118,7 @@ class LogSegment private[log] (val log: FileRecords,
    * @param shallowOffsetOfMaxTimestamp The offset of the message that has the largest timestamp in the messages to append.
    * @param records The log entries to append.
    * @return the physical position in the file of the appended records
+   * @throws LogSegmentOffsetOverflowException if the largest offset causes index offset overflow
    */
   @nonthreadsafe
   def append(largestOffset: Long,
@@ -129,8 +131,13 @@ class LogSegment private[log] (val log: FileRecords,
       val physicalPosition = log.sizeInBytes()
       if (physicalPosition == 0)
         rollingBasedTimestamp = Some(largestTimestamp)
+
+      if (!canConvertToRelativeOffset(largestOffset))
+        throw new LogSegmentOffsetOverflowException(
+          s"largest offset $largestOffset cannot be safely converted to relative offset for segment with baseOffset $baseOffset",
+          this)
+
       // append the messages
-      require(canConvertToRelativeOffset(largestOffset), "largest offset in message set can not be safely converted to relative offset.")
       val appendedBytes = log.append(records)
       trace(s"Appended $appendedBytes to ${log.file()} at end offset $largestOffset")
       // Update the in memory max timestamp and corresponding offset.
@@ -139,9 +146,9 @@ class LogSegment private[log] (val log: FileRecords,
         offsetOfMaxTimestamp = shallowOffsetOfMaxTimestamp
       }
       // append an entry to the index (if needed)
-      if(bytesSinceLastIndexEntry > indexIntervalBytes) {
-        offsetIndex.append(largestOffset, physicalPosition)
-        timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestamp)
+      if (bytesSinceLastIndexEntry > indexIntervalBytes) {
+        appendToOffsetIndex(largestOffset, physicalPosition)
+        maybeAppendToTimeIndex(maxTimestampSoFar, offsetOfMaxTimestamp)
         bytesSinceLastIndexEntry = 0
       }
       bytesSinceLastIndexEntry += records.sizeInBytes
@@ -193,8 +200,8 @@ class LogSegment private[log] (val log: FileRecords,
    * no more than maxSize bytes and will end before maxOffset if a maxOffset is specified.
    *
    * @param startOffset A lower bound on the first offset to include in the message set we read
-   * @param maxSize The maximum number of bytes to include in the message set we read
    * @param maxOffset An optional maximum offset for the message set we read
+   * @param maxSize The maximum number of bytes to include in the message set we read
    * @param maxPosition The maximum position in the log segment that should be exposed for read
    * @param minOneMessage If this is true, the first message will be returned even if it exceeds `maxSize` (if one exists)
    *
@@ -246,7 +253,7 @@ class LogSegment private[log] (val log: FileRecords,
         min(min(maxPosition, endPosition) - startPosition, adjustedMaxSize).toInt
     }
 
-    FetchDataInfo(offsetMetadata, log.read(startPosition, fetchSize),
+    FetchDataInfo(offsetMetadata, log.slice(startPosition, fetchSize),
       firstEntryIncomplete = adjustedMaxSize < startOffsetAndSize.size)
   }
 
@@ -261,6 +268,7 @@ class LogSegment private[log] (val log: FileRecords,
    *                             the transaction index.
    * @param leaderEpochCache Optionally a cache for updating the leader epoch during recovery.
    * @return The number of bytes truncated from the log
+   * @throws LogSegmentOffsetOverflowException if the log segment contains an offset that causes the index offset to overflow
    */
   @nonthreadsafe
   def recover(producerStateManager: ProducerStateManager, leaderEpochCache: Option[LeaderEpochCache] = None): Int = {
@@ -282,9 +290,8 @@ class LogSegment private[log] (val log: FileRecords,
 
         // Build offset index
         if (validBytes - lastIndexEntry > indexIntervalBytes) {
-          val startOffset = batch.baseOffset
-          offsetIndex.append(startOffset, validBytes)
-          timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestamp)
+          appendToOffsetIndex(batch.lastOffset, validBytes)
+          maybeAppendToTimeIndex(maxTimestampSoFar, offsetOfMaxTimestamp)
           lastIndexEntry = validBytes
         }
         validBytes += batch.sizeInBytes()
@@ -309,7 +316,7 @@ class LogSegment private[log] (val log: FileRecords,
     log.truncateTo(validBytes)
     offsetIndex.trimToValidSize()
    // A normally closed segment always appends the largest timestamp ever seen into the log segment; we do this here as well.
-    timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true)
+    maybeAppendToTimeIndex(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true)
     timeIndex.trimToValidSize()
     truncated
   }
@@ -372,11 +379,11 @@ class LogSegment private[log] (val log: FileRecords,
    */
   @threadsafe
   def readNextOffset: Long = {
-    val ms = read(offsetIndex.lastOffset, None, log.sizeInBytes)
-    if (ms == null)
+    val fetchData = read(offsetIndex.lastOffset, None, log.sizeInBytes)
+    if (fetchData == null)
       baseOffset
     else
-      ms.records.batches.asScala.lastOption
+      fetchData.records.batches.asScala.lastOption
         .map(_.nextOffset)
         .getOrElse(baseOffset)
   }
@@ -422,7 +429,7 @@ class LogSegment private[log] (val log: FileRecords,
    * The time index entry appended will be used to decide when to delete the segment.
    */
   def onBecomeInactiveSegment() {
-    timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true)
+    maybeAppendToTimeIndex(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true)
     offsetIndex.trimToValidSize()
     timeIndex.trimToValidSize()
     log.trim()
@@ -486,7 +493,7 @@ class LogSegment private[log] (val log: FileRecords,
    * Close this log segment
    */
   def close() {
-    CoreUtils.swallow(timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true), this)
+    CoreUtils.swallow(maybeAppendToTimeIndex(maxTimestampSoFar, offsetOfMaxTimestamp, skipFullCheck = true), this)
     CoreUtils.swallow(offsetIndex.close(), this)
     CoreUtils.swallow(timeIndex.close(), this)
     CoreUtils.swallow(log.close(), this)
@@ -546,6 +553,25 @@ class LogSegment private[log] (val log: FileRecords,
     Files.setLastModifiedTime(offsetIndex.file.toPath, fileTime)
     Files.setLastModifiedTime(timeIndex.file.toPath, fileTime)
   }
+
+  private def maybeAppendToTimeIndex(timestamp: Long, offset: Long, skipFullCheck: Boolean = false): Unit = {
+    maybeHandleOffsetOverflowException {
+      timeIndex.maybeAppend(timestamp, offset, skipFullCheck)
+    }
+  }
+
+  private def appendToOffsetIndex(offset: Long, position: Int): Unit = {
+    maybeHandleOffsetOverflowException {
+      offsetIndex.append(offset, position)
+    }
+  }
+
+  private def maybeHandleOffsetOverflowException[T](fun: => T): T = {
+    try fun
+    catch {
+      case e: IndexOffsetOverflowException => throw new LogSegmentOffsetOverflowException(e, this)
+    }
+  }
 }
 
 object LogSegment {
@@ -566,6 +592,12 @@ object LogSegment {
       time)
   }
 
+  def deleteIfExists(dir: File, baseOffset: Long, fileSuffix: String = ""): Unit = {
+    Log.deleteFileIfExists(Log.offsetIndexFile(dir, baseOffset, fileSuffix))
+    Log.deleteFileIfExists(Log.timeIndexFile(dir, baseOffset, fileSuffix))
+    Log.deleteFileIfExists(Log.transactionIndexFile(dir, baseOffset, fileSuffix))
+    Log.deleteFileIfExists(Log.logFile(dir, baseOffset, fileSuffix))
+  }
 }
 
 object LogFlushStats extends KafkaMetricsGroup {
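
The LogSegment changes above follow one pattern: every index write goes through a wrapper that rethrows IndexOffsetOverflowException as LogSegmentOffsetOverflowException carrying the segment itself (so the caller knows which segment needs attention), and append now refuses any batch whose largest offset cannot be expressed relative to the segment's base offset. Because each index entry stores the offset as a 4-byte delta from baseOffset, the check presumably reduces to the following. This is an illustrative sketch, not the committed canConvertToRelativeOffset:

    // Illustrative sketch only. Index entries hold a 4-byte delta from the segment's
    // baseOffset, so an offset is appendable only if (offset - baseOffset) fits in an Int.
    object RelativeOffsetFit {
      def fitsInIndex(baseOffset: Long, offset: Long): Boolean = {
        val relative = offset - baseOffset
        relative >= 0 && relative <= Int.MaxValue
      }

      def main(args: Array[String]): Unit = {
        println(fitsInIndex(0L, 100L))                    // true
        println(fitsInIndex(0L, Int.MaxValue.toLong + 1)) // false: such a segment would need splitting
      }
    }
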
diff --git a/core/src/main/scala/kafka/log/OffsetIndex.scala b/core/src/main/scala/kafka/log/OffsetIndex.scala
index 523c88c..d185631 100755
--- a/core/src/main/scala/kafka/log/OffsetIndex.scala
+++ b/core/src/main/scala/kafka/log/OffsetIndex.scala
@@ -21,7 +21,7 @@ import java.io.File
 import java.nio.ByteBuffer
 
 import kafka.utils.CoreUtils.inLock
-import kafka.common.InvalidOffsetException
+import kafka.common.{IndexOffsetOverflowException, InvalidOffsetException}
 
 /**
  * An index that maps offsets to physical file locations for a particular log segment. This index may be sparse:
@@ -134,13 +134,14 @@ class OffsetIndex(_file: File, baseOffset: Long, maxIndexSize: Int = -1, writabl
 
   /**
   * Append an entry for the given offset/location pair to the index. This entry must have a larger offset than all previous entries.
+   * @throws IndexOffsetOverflowException if the offset causes index offset to overflow
    */
   def append(offset: Long, position: Int) {
     inLock(lock) {
       require(!isFull, "Attempt to append to a full index (size = " + _entries + ").")
       if (_entries == 0 || offset > _lastOffset) {
         debug("Adding index entry %d => %d to %s.".format(offset, position, file.getName))
-        mmap.putInt((offset - baseOffset).toInt)
+        mmap.putInt(relativeOffset(offset))
         mmap.putInt(position)
         _entries += 1
         _lastOffset = offset
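
The substitution here replaces the unchecked (offset - baseOffset).toInt, which could silently truncate once the delta exceeds Int.MaxValue, with a call to relativeOffset, presumably inherited from AbstractIndex (not shown in this part of the diff) and backed by the new IndexOffsetOverflowException. A minimal sketch of the behaviour it is assumed to have; the class and message below are stand-ins, not the committed code:

    // Hedged reconstruction of the conversion helper's assumed behaviour.
    class IndexOffsetOverflowException(message: String) extends RuntimeException(message)

    class IndexSketch(val baseOffset: Long) {
      // Convert an absolute offset to the 4-byte relative form stored on disk,
      // failing loudly instead of truncating when the delta does not fit in an Int.
      def relativeOffset(offset: Long): Int = {
        val relative = offset - baseOffset
        if (relative < 0 || relative > Int.MaxValue)
          throw new IndexOffsetOverflowException(
            s"Offset $offset cannot be stored relative to base offset $baseOffset")
        relative.toInt
      }
    }

    // e.g. new IndexSketch(baseOffset = 0L).relativeOffset(5L) == 5
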
diff --git a/core/src/main/scala/kafka/log/TimeIndex.scala b/core/src/main/scala/kafka/log/TimeIndex.scala
index e505f36..7fae130 100644
--- a/core/src/main/scala/kafka/log/TimeIndex.scala
+++ b/core/src/main/scala/kafka/log/TimeIndex.scala
@@ -128,7 +128,7 @@ class TimeIndex(_file: File, baseOffset: Long, maxIndexSize: Int = -1, writable:
       if (timestamp > lastEntry.timestamp) {
         debug("Adding index entry %d => %d to %s.".format(timestamp, offset, file.getName))
         mmap.putLong(timestamp)
-        mmap.putInt((offset - baseOffset).toInt)
+        mmap.putInt(relativeOffset(offset))
         _entries += 1
         _lastEntry = TimestampOffset(timestamp, offset)
         require(_entries * entrySize == mmap.position(), _entries + " entries but file position in index is " + mmap.position() + ".")
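
The TimeIndex hunk is the same substitution, and the two writes make the on-disk entry layout explicit: an 8-byte timestamp followed by a 4-byte relative offset, 12 bytes per entry. A small encoder/decoder illustrating that layout (the names here are invented, not Kafka code):

    import java.nio.ByteBuffer

    // Illustrative of the 12-byte layout implied by putLong/putInt above.
    object TimeIndexEntryLayout {
      val EntrySize = 12

      def encode(timestamp: Long, relativeOffset: Int): ByteBuffer = {
        val buf = ByteBuffer.allocate(EntrySize)
        buf.putLong(timestamp)
        buf.putInt(relativeOffset)
        buf.flip()
        buf
      }

      def decode(buf: ByteBuffer): (Long, Int) =
        (buf.getLong(0), buf.getInt(8))
    }
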
diff --git a/core/src/main/scala/kafka/tools/DumpLogSegments.scala b/core/src/main/scala/kafka/tools/DumpLogSegments.scala
index 2aa7ad3..17fbd8f 100755
--- a/core/src/main/scala/kafka/tools/DumpLogSegments.scala
+++ b/core/src/main/scala/kafka/tools/DumpLogSegments.scala
@@ -189,7 +189,7 @@ object DumpLogSegments {
 
     for(i <- 0 until index.entries) {
       val entry = index.entry(i)
-      val slice = fileRecords.read(entry.position, maxMessageSize)
+      val slice = fileRecords.slice(entry.position, maxMessageSize)
       val firstRecord = slice.records.iterator.next()
       if (firstRecord.offset != entry.offset + index.baseOffset) {
         var misMatchesSeq = misMatchesForIndexFilesMap.getOrElse(file.getAbsolutePath, List[(Long, Long)]())
@@ -227,7 +227,7 @@ object DumpLogSegments {
     for(i <- 0 until timeIndex.entries) {
       val entry = timeIndex.entry(i)
       val position = index.lookup(entry.offset + timeIndex.baseOffset).position
-      val partialFileRecords = fileRecords.read(position, Int.MaxValue)
+      val partialFileRecords = fileRecords.slice(position, Int.MaxValue)
       val batches = partialFileRecords.batches.asScala
       var maxTimestamp = RecordBatch.NO_TIMESTAMP
       // We first find the message by offset then check if the timestamp is correct.
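
DumpLogSegments picks up the FileRecords.read -> slice rename. The pattern at both call sites is the same: slice the file at the physical position recorded in an index entry and inspect what starts there. A hedged usage sketch of that pattern (the helper and its name are hypothetical):

    import java.io.File
    import org.apache.kafka.common.record.FileRecords

    // Hypothetical helper showing the slice-and-inspect pattern used above.
    object IndexEntryProbe {
      def firstBaseOffsetAt(segmentFile: File, position: Int, maxBytes: Int): Long = {
        val fileRecords = FileRecords.open(segmentFile)
        try {
          val slice = fileRecords.slice(position, maxBytes) // bounded view over the file, no copy
          slice.batches.iterator.next().baseOffset          // first batch starting at `position`
        } finally {
          fileRecords.close()
        }
      }
    }
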
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
index 537c561..3207e15 100755
--- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
@@ -71,11 +71,11 @@ class LogCleanerTest extends JUnitSuite {
     // append messages to the log until we have four segments
     while(log.numberOfSegments < 4)
       log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0)
-    val keysFound = keysInLog(log)
+    val keysFound = LogTest.keysInLog(log)
     assertEquals(0L until log.logEndOffset, keysFound)
 
     // pretend we have the following keys
-    val keys = immutable.ListSet(1, 3, 5, 7, 9)
+    val keys = immutable.ListSet(1L, 3L, 5L, 7L, 9L)
     val map = new FakeOffsetMap(Int.MaxValue)
     keys.foreach(k => map.put(key(k), Long.MaxValue))
 
@@ -84,8 +84,8 @@ class LogCleanerTest extends JUnitSuite {
     val stats = new CleanerStats()
     val expectedBytesRead = segments.map(_.size).sum
     cleaner.cleanSegments(log, segments, map, 0L, stats)
-    val shouldRemain = keysInLog(log).filter(!keys.contains(_))
-    assertEquals(shouldRemain, keysInLog(log))
+    val shouldRemain = LogTest.keysInLog(log).filter(!keys.contains(_))
+    assertEquals(shouldRemain, LogTest.keysInLog(log))
     assertEquals(expectedBytesRead, stats.bytesRead)
   }
 
@@ -135,7 +135,7 @@ class LogCleanerTest extends JUnitSuite {
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset))
     assertEquals(List(2, 5, 7), lastOffsetsPerBatchInLog(log))
     assertEquals(Map(pid1 -> 2, pid2 -> 2, pid3 -> 1), lastSequencesInLog(log))
-    assertEquals(List(2, 3, 1, 4), keysInLog(log))
+    assertEquals(List(2, 3, 1, 4), LogTest.keysInLog(log))
     assertEquals(List(1, 3, 6, 7), offsetsInLog(log))
 
     // we have to reload the log to validate that the cleaner maintained sequence numbers correctly
@@ -167,7 +167,7 @@ class LogCleanerTest extends JUnitSuite {
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset))
     assertEquals(Map(pid1 -> 2, pid2 -> 2, pid3 -> 1, pid4 -> 0), lastSequencesInLog(log))
     assertEquals(List(2, 5, 7, 8), lastOffsetsPerBatchInLog(log))
-    assertEquals(List(3, 1, 4, 2), keysInLog(log))
+    assertEquals(List(3, 1, 4, 2), LogTest.keysInLog(log))
     assertEquals(List(3, 6, 7, 8), offsetsInLog(log))
 
     reloadLog()
@@ -204,7 +204,7 @@ class LogCleanerTest extends JUnitSuite {
 
     log.roll()
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset))
-    assertEquals(List(3, 2), keysInLog(log))
+    assertEquals(List(3, 2), LogTest.keysInLog(log))
     assertEquals(List(3, 6, 7, 8, 9), offsetsInLog(log))
 
     // ensure the transaction index is still correct
@@ -244,7 +244,7 @@ class LogCleanerTest extends JUnitSuite {
 
     // we have only cleaned the records in the first segment
     val dirtyOffset = cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset))._1
-    assertEquals(List(2, 3, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10), keysInLog(log))
+    assertEquals(List(2, 3, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10), LogTest.keysInLog(log))
 
     log.roll()
 
@@ -254,7 +254,7 @@ class LogCleanerTest extends JUnitSuite {
 
     // finally only the keys from pid3 should remain
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, dirtyOffset, log.activeSegment.baseOffset))
-    assertEquals(List(2, 3, 6, 7, 8, 9, 11, 12), keysInLog(log))
+    assertEquals(List(2, 3, 6, 7, 8, 9, 11, 12), LogTest.keysInLog(log))
   }
 
   @Test
@@ -278,7 +278,7 @@ class LogCleanerTest extends JUnitSuite {
 
     // cannot remove the marker in this pass because there are still valid records
     var dirtyOffset = cleaner.doClean(LogToClean(tp, log, 0L, 100L), deleteHorizonMs = Long.MaxValue)._1
-    assertEquals(List(1, 3, 2), keysInLog(log))
+    assertEquals(List(1, 3, 2), LogTest.keysInLog(log))
     assertEquals(List(0, 2, 3, 4, 5), offsetsInLog(log))
 
     appendProducer(Seq(1, 3))
@@ -287,17 +287,17 @@ class LogCleanerTest extends JUnitSuite {
 
     // the first cleaning preserves the commit marker (at offset 3) since there were still records for the transaction
     dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = Long.MaxValue)._1
-    assertEquals(List(2, 1, 3), keysInLog(log))
+    assertEquals(List(2, 1, 3), LogTest.keysInLog(log))
     assertEquals(List(3, 4, 5, 6, 7, 8), offsetsInLog(log))
 
     // delete horizon forced to 0 to verify marker is not removed early
     dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = 0L)._1
-    assertEquals(List(2, 1, 3), keysInLog(log))
+    assertEquals(List(2, 1, 3), LogTest.keysInLog(log))
     assertEquals(List(3, 4, 5, 6, 7, 8), offsetsInLog(log))
 
     // clean again with large delete horizon and verify the marker is removed
     dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = Long.MaxValue)._1
-    assertEquals(List(2, 1, 3), keysInLog(log))
+    assertEquals(List(2, 1, 3), LogTest.keysInLog(log))
     assertEquals(List(4, 5, 6, 7, 8), offsetsInLog(log))
   }
 
@@ -326,11 +326,11 @@ class LogCleanerTest extends JUnitSuite {
     log.roll()
 
     cleaner.doClean(LogToClean(tp, log, 0L, 100L), deleteHorizonMs = Long.MaxValue)
-    assertEquals(List(2), keysInLog(log))
+    assertEquals(List(2), LogTest.keysInLog(log))
     assertEquals(List(1, 3, 4), offsetsInLog(log))
 
     cleaner.doClean(LogToClean(tp, log, 0L, 100L), deleteHorizonMs = Long.MaxValue)
-    assertEquals(List(2), keysInLog(log))
+    assertEquals(List(2), LogTest.keysInLog(log))
     assertEquals(List(3, 4), offsetsInLog(log))
   }
 
@@ -356,13 +356,13 @@ class LogCleanerTest extends JUnitSuite {
 
     // first time through the records are removed
     var dirtyOffset = cleaner.doClean(LogToClean(tp, log, 0L, 100L), deleteHorizonMs = Long.MaxValue)._1
-    assertEquals(List(2, 3), keysInLog(log))
+    assertEquals(List(2, 3), LogTest.keysInLog(log))
     assertEquals(List(2, 3, 4), offsetsInLog(log)) // commit marker is retained
     assertEquals(List(1, 2, 3, 4), lastOffsetsPerBatchInLog(log)) // empty batch is retained
 
     // the empty batch remains if cleaned again because it still holds the last sequence
     dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = Long.MaxValue)._1
-    assertEquals(List(2, 3), keysInLog(log))
+    assertEquals(List(2, 3), LogTest.keysInLog(log))
     assertEquals(List(2, 3, 4), offsetsInLog(log)) // commit marker is still retained
     assertEquals(List(1, 2, 3, 4), lastOffsetsPerBatchInLog(log)) // empty batch is retained
 
@@ -371,12 +371,12 @@ class LogCleanerTest extends JUnitSuite {
     log.roll()
 
     dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = Long.MaxValue)._1
-    assertEquals(List(2, 3, 1), keysInLog(log))
+    assertEquals(List(2, 3, 1), LogTest.keysInLog(log))
     assertEquals(List(2, 3, 4, 5), offsetsInLog(log)) // commit marker is still retained
     assertEquals(List(2, 3, 4, 5), lastOffsetsPerBatchInLog(log)) // empty batch should be gone
 
     dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = Long.MaxValue)._1
-    assertEquals(List(2, 3, 1), keysInLog(log))
+    assertEquals(List(2, 3, 1), LogTest.keysInLog(log))
     assertEquals(List(3, 4, 5), offsetsInLog(log)) // commit marker is gone
     assertEquals(List(3, 4, 5), lastOffsetsPerBatchInLog(log)) // empty batch is gone
   }
@@ -402,12 +402,12 @@ class LogCleanerTest extends JUnitSuite {
 
     // delete horizon set to 0 to verify marker is not removed early
     val dirtyOffset = cleaner.doClean(LogToClean(tp, log, 0L, 100L), deleteHorizonMs = 0L)._1
-    assertEquals(List(3), keysInLog(log))
+    assertEquals(List(3), LogTest.keysInLog(log))
     assertEquals(List(3, 4, 5), offsetsInLog(log))
 
     // clean again with large delete horizon and verify the marker is removed
     cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = Long.MaxValue)
-    assertEquals(List(3), keysInLog(log))
+    assertEquals(List(3), LogTest.keysInLog(log))
     assertEquals(List(4, 5), offsetsInLog(log))
   }
 
@@ -440,14 +440,14 @@ class LogCleanerTest extends JUnitSuite {
     // first time through the records are removed
     var dirtyOffset = cleaner.doClean(LogToClean(tp, log, 0L, 100L), deleteHorizonMs = Long.MaxValue)._1
     assertAbortedTransactionIndexed()
-    assertEquals(List(), keysInLog(log))
+    assertEquals(List(), LogTest.keysInLog(log))
     assertEquals(List(2), offsetsInLog(log)) // abort marker is retained
     assertEquals(List(1, 2), lastOffsetsPerBatchInLog(log)) // empty batch is retained
 
     // the empty batch remains if cleaned again because it still holds the last sequence
     dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = Long.MaxValue)._1
     assertAbortedTransactionIndexed()
-    assertEquals(List(), keysInLog(log))
+    assertEquals(List(), LogTest.keysInLog(log))
     assertEquals(List(2), offsetsInLog(log)) // abort marker is still retained
     assertEquals(List(1, 2), lastOffsetsPerBatchInLog(log)) // empty batch is retained
 
@@ -457,12 +457,12 @@ class LogCleanerTest extends JUnitSuite {
 
     dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = Long.MaxValue)._1
     assertAbortedTransactionIndexed()
-    assertEquals(List(1), keysInLog(log))
+    assertEquals(List(1), LogTest.keysInLog(log))
     assertEquals(List(2, 3), offsetsInLog(log)) // abort marker is not yet gone because we read the empty batch
     assertEquals(List(2, 3), lastOffsetsPerBatchInLog(log)) // but we do not preserve the empty batch
 
     dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, 100L), deleteHorizonMs = Long.MaxValue)._1
-    assertEquals(List(1), keysInLog(log))
+    assertEquals(List(1), LogTest.keysInLog(log))
     assertEquals(List(3), offsetsInLog(log)) // abort marker is gone
     assertEquals(List(3), lastOffsetsPerBatchInLog(log))
 
@@ -486,19 +486,19 @@ class LogCleanerTest extends JUnitSuite {
 
     while(log.numberOfSegments < 2)
       log.appendAsLeader(record(log.logEndOffset.toInt, Array.fill(largeMessageSize)(0: Byte)), leaderEpoch = 0)
-    val keysFound = keysInLog(log)
+    val keysFound = LogTest.keysInLog(log)
     assertEquals(0L until log.logEndOffset, keysFound)
 
     // pretend we have the following keys
-    val keys = immutable.ListSet(1, 3, 5, 7, 9)
+    val keys = immutable.ListSet(1L, 3L, 5L, 7L, 9L)
     val map = new FakeOffsetMap(Int.MaxValue)
     keys.foreach(k => map.put(key(k), Long.MaxValue))
 
     // clean the log
     val stats = new CleanerStats()
     cleaner.cleanSegments(log, Seq(log.logSegments.head), map, 0L, stats)
-    val shouldRemain = keysInLog(log).filter(!keys.contains(_))
-    assertEquals(shouldRemain, keysInLog(log))
+    val shouldRemain = LogTest.keysInLog(log).filter(!keys.contains(_))
+    assertEquals(shouldRemain, LogTest.keysInLog(log))
   }
 
   /**
@@ -510,8 +510,8 @@ class LogCleanerTest extends JUnitSuite {
 
     val cleaner = makeCleaner(Int.MaxValue, maxMessageSize=1024)
     cleaner.cleanSegments(log, Seq(log.logSegments.head), offsetMap, 0L, new CleanerStats)
-    val shouldRemain = keysInLog(log).filter(k => !offsetMap.map.containsKey(k.toString))
-    assertEquals(shouldRemain, keysInLog(log))
+    val shouldRemain = LogTest.keysInLog(log).filter(k => !offsetMap.map.containsKey(k.toString))
+    assertEquals(shouldRemain, LogTest.keysInLog(log))
   }
 
   /**
@@ -558,7 +558,7 @@ class LogCleanerTest extends JUnitSuite {
 
     while(log.numberOfSegments < 2)
       log.appendAsLeader(record(log.logEndOffset.toInt, Array.fill(largeMessageSize)(0: Byte)), leaderEpoch = 0)
-    val keysFound = keysInLog(log)
+    val keysFound = LogTest.keysInLog(log)
     assertEquals(0L until log.logEndOffset, keysFound)
 
     // Decrease the log's max message size
@@ -595,7 +595,7 @@ class LogCleanerTest extends JUnitSuite {
       log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0)
 
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0, log.activeSegment.baseOffset))
-    val keys = keysInLog(log).toSet
+    val keys = LogTest.keysInLog(log).toSet
     assertTrue("None of the keys we deleted should still exist.",
                (0 until leo.toInt by 2).forall(!keys.contains(_)))
   }
@@ -647,7 +647,7 @@ class LogCleanerTest extends JUnitSuite {
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset))
     assertEquals(List(1, 3, 4), lastOffsetsPerBatchInLog(log))
     assertEquals(Map(1L -> 0, 2L -> 1, 3L -> 0), lastSequencesInLog(log))
-    assertEquals(List(0, 1), keysInLog(log))
+    assertEquals(List(0, 1), LogTest.keysInLog(log))
     assertEquals(List(3, 4), offsetsInLog(log))
   }
 
@@ -670,7 +670,7 @@ class LogCleanerTest extends JUnitSuite {
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset))
     assertEquals(List(2, 3), lastOffsetsPerBatchInLog(log))
     assertEquals(Map(producerId -> 2), lastSequencesInLog(log))
-    assertEquals(List(), keysInLog(log))
+    assertEquals(List(), LogTest.keysInLog(log))
     assertEquals(List(3), offsetsInLog(log))
 
     // Append a new entry from the producer and verify that the empty batch is cleaned up
@@ -680,7 +680,7 @@ class LogCleanerTest extends JUnitSuite {
 
     assertEquals(List(3, 5), lastOffsetsPerBatchInLog(log))
     assertEquals(Map(producerId -> 4), lastSequencesInLog(log))
-    assertEquals(List(1, 5), keysInLog(log))
+    assertEquals(List(1, 5), LogTest.keysInLog(log))
     assertEquals(List(3, 4, 5), offsetsInLog(log))
   }
 
@@ -703,16 +703,16 @@ class LogCleanerTest extends JUnitSuite {
 
     // clean the log with only one message removed
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 2, log.activeSegment.baseOffset))
-    assertEquals(List(1,0,1,0), keysInLog(log))
+    assertEquals(List(1,0,1,0), LogTest.keysInLog(log))
     assertEquals(List(1,2,3,4), offsetsInLog(log))
 
     // continue to make progress, even though we can only clean one message at a time
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 3, log.activeSegment.baseOffset))
-    assertEquals(List(0,1,0), keysInLog(log))
+    assertEquals(List(0,1,0), LogTest.keysInLog(log))
     assertEquals(List(2,3,4), offsetsInLog(log))
 
     cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 4, log.activeSegment.baseOffset))
-    assertEquals(List(1,0), keysInLog(log))
+    assertEquals(List(1,0), LogTest.keysInLog(log))
     assertEquals(List(3,4), offsetsInLog(log))
   }
 
@@ -835,14 +835,6 @@ class LogCleanerTest extends JUnitSuite {
     assertEquals("Cleaner should have seen %d invalid messages.", numInvalidMessages, stats.invalidMessagesRead)
   }
 
-  /* extract all the keys from a log */
-  def keysInLog(log: Log): Iterable[Int] = {
-    for (segment <- log.logSegments;
-         batch <- segment.log.batches.asScala if !batch.isControlBatch;
-         record <- batch.asScala if record.hasValue && record.hasKey)
-      yield TestUtils.readString(record.key).toInt
-  }
-
   def lastOffsetsPerBatchInLog(log: Log): Iterable[Long] = {
     for (segment <- log.logSegments; batch <- segment.log.batches.asScala)
       yield batch.lastOffset
@@ -880,7 +872,7 @@ class LogCleanerTest extends JUnitSuite {
     while(log.numberOfSegments < 4)
       log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0)
 
-    val keys = keysInLog(log)
+    val keys = LogTest.keysInLog(log)
     val map = new FakeOffsetMap(Int.MaxValue)
     keys.foreach(k => map.put(key(k), Long.MaxValue))
     intercept[LogCleaningAbortedException] {
@@ -1065,6 +1057,43 @@ class LogCleanerTest extends JUnitSuite {
     checkRange(map, segments(3).baseOffset.toInt, log.logEndOffset.toInt)
   }
 
+  @Test
+  def testSegmentWithOffsetOverflow(): Unit = {
+    val cleaner = makeCleaner(Int.MaxValue)
+    val logProps = new Properties()
+    logProps.put(LogConfig.IndexIntervalBytesProp, 1: java.lang.Integer)
+    logProps.put(LogConfig.FileDeleteDelayMsProp, 1000: java.lang.Integer)
+    val config = LogConfig.fromProps(logConfig.originals, logProps)
+
+    val time = new MockTime()
+    val (log, segmentWithOverflow, _) = LogTest.createLogWithOffsetOverflow(dir, new BrokerTopicStats(), Some(config), time.scheduler, time)
+    val numSegmentsInitial = log.logSegments.size
+    val allKeys = LogTest.keysInLog(log).toList
+    val expectedKeysAfterCleaning = mutable.MutableList[Long]()
+
+    // pretend we want to clean every alternate key
+    val offsetMap = new FakeOffsetMap(Int.MaxValue)
+    for (k <- 1 until allKeys.size by 2) {
+      expectedKeysAfterCleaning += allKeys(k - 1)
+      offsetMap.put(key(allKeys(k)), Long.MaxValue)
+    }
+
+    // Try to clean segment with offset overflow. This will trigger log split and the cleaning itself must abort.
+    assertThrows[LogCleaningAbortedException] {
+      cleaner.cleanSegments(log, List(segmentWithOverflow), offsetMap, 0L, new CleanerStats())
+    }
+    assertEquals(numSegmentsInitial + 1, log.logSegments.size)
+    assertEquals(allKeys, LogTest.keysInLog(log))
+    assertFalse(LogTest.hasOffsetOverflow(log))
+
+    // Clean each segment now that split is complete.
+    for (segmentToClean <- log.logSegments)
+      cleaner.cleanSegments(log, List(segmentToClean), offsetMap, 0L, new CleanerStats())
+    assertEquals(expectedKeysAfterCleaning, LogTest.keysInLog(log))
+    assertFalse(LogTest.hasOffsetOverflow(log))
+    log.close()
+  }
+
   /**
    * Tests recovery if broker crashes at the following stages during the cleaning sequence
    * <ol>
@@ -1084,28 +1113,14 @@ class LogCleanerTest extends JUnitSuite {
 
     val config = LogConfig.fromProps(logConfig.originals, logProps)
 
-    def recoverAndCheck(config: LogConfig, expectedKeys: Iterable[Int]): Log = {
-      // Recover log file and check that after recovery, keys are as expected
-      // and all temporary files have been deleted
-      val recoveredLog = makeLog(config = config)
-      time.sleep(config.fileDeleteDelayMs + 1)
-      for (file <- dir.listFiles) {
-        assertFalse("Unexpected .deleted file after recovery", file.getName.endsWith(Log.DeletedFileSuffix))
-        assertFalse("Unexpected .cleaned file after recovery", file.getName.endsWith(Log.CleanedFileSuffix))
-        assertFalse("Unexpected .swap file after recovery", file.getName.endsWith(Log.SwapFileSuffix))
-      }
-      assertEquals(expectedKeys, keysInLog(recoveredLog))
-      recoveredLog
-    }
-
-    // create a log and append some messages
+   // create a log and append some messages
     var log = makeLog(config = config)
     var messageCount = 0
     while (log.numberOfSegments < 10) {
       log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0)
       messageCount += 1
     }
-    val allKeys = keysInLog(log)
+    val allKeys = LogTest.keysInLog(log)
 
     // pretend we have odd-numbered keys
     val offsetMap = new FakeOffsetMap(Int.MaxValue)
@@ -1116,7 +1131,7 @@ class LogCleanerTest extends JUnitSuite {
     cleaner.cleanSegments(log, log.logSegments.take(9).toSeq, offsetMap, 0L, new CleanerStats())
     // clear scheduler so that async deletes don't run
     time.scheduler.clear()
-    var cleanedKeys = keysInLog(log)
+    var cleanedKeys = LogTest.keysInLog(log)
     log.close()
 
     // 1) Simulate recovery just after .cleaned file is created, before rename to .swap
@@ -1131,7 +1146,7 @@ class LogCleanerTest extends JUnitSuite {
     cleaner.cleanSegments(log, log.logSegments.take(9).toSeq, offsetMap, 0L, new CleanerStats())
     // clear scheduler so that async deletes don't run
     time.scheduler.clear()
-    cleanedKeys = keysInLog(log)
+    cleanedKeys = LogTest.keysInLog(log)
     log.close()
 
     // 2) Simulate recovery just after swap file is created, before old segment files are
@@ -1152,7 +1167,7 @@ class LogCleanerTest extends JUnitSuite {
     cleaner.cleanSegments(log, log.logSegments.take(9).toSeq, offsetMap, 0L, new CleanerStats())
     // clear scheduler so that async deletes don't run
     time.scheduler.clear()
-    cleanedKeys = keysInLog(log)
+    cleanedKeys = LogTest.keysInLog(log)
 
     // 3) Simulate recovery after swap file is created and old segments files are renamed
     //    to .deleted. Clean operation is resumed during recovery.
@@ -1169,7 +1184,7 @@ class LogCleanerTest extends JUnitSuite {
     cleaner.cleanSegments(log, log.logSegments.take(9).toSeq, offsetMap, 0L, new CleanerStats())
     // clear scheduler so that async deletes don't run
     time.scheduler.clear()
-    cleanedKeys = keysInLog(log)
+    cleanedKeys = LogTest.keysInLog(log)
     log.close()
 
     // 4) Simulate recovery after swap is complete, but async deletion
@@ -1375,7 +1390,7 @@ class LogCleanerTest extends JUnitSuite {
     for ((key, value) <- seq) yield log.appendAsLeader(record(key, value), leaderEpoch = 0).firstOffset.get
   }
 
-  private def key(id: Int) = ByteBuffer.wrap(id.toString.getBytes)
+  private def key(id: Long) = ByteBuffer.wrap(id.toString.getBytes)
 
   private def record(key: Int, value: Int,
              producerId: Long = RecordBatch.NO_PRODUCER_ID,
@@ -1429,6 +1444,9 @@ class LogCleanerTest extends JUnitSuite {
 
   private def tombstoneRecord(key: Int): MemoryRecords = record(key, null)
 
+  private def recoverAndCheck(config: LogConfig, expectedKeys: Iterable[Long]): Log = {
+    LogTest.recoverAndCheck(dir, config, expectedKeys, new BrokerTopicStats())
+  }
 }
 
 class FakeOffsetMap(val slots: Int) extends OffsetMap {
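
Several helpers this test previously defined locally (keysInLog, recoverAndCheck) are now shared through LogTest's companion object so the new offset-overflow tests can reuse them, and the key type widens from Int to Long along the way (note the ListSet(1L, 3L, ...) call sites). The shared keysInLog is presumably equivalent to the local version removed above, modulo the Long keys; a reconstruction for reference, with a stand-in object name:

    import kafka.utils.TestUtils
    import scala.collection.JavaConverters._

    object KeysInLogSketch {
      // Extract all keys from a log, skipping control batches and any record
      // without both a key and a value, mirroring the helper removed above.
      def keysInLog(log: kafka.log.Log): Iterable[Long] =
        for (segment <- log.logSegments;
             batch <- segment.log.batches.asScala if !batch.isControlBatch;
             record <- batch.asScala if record.hasValue && record.hasKey)
          yield TestUtils.readString(record.key).toLong
    }
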
diff --git a/core/src/test/scala/unit/kafka/log/LogTest.scala b/core/src/test/scala/unit/kafka/log/LogTest.scala
index bf74a3e..1171e5e 100755
--- a/core/src/test/scala/unit/kafka/log/LogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTest.scala
@@ -19,18 +19,16 @@ package kafka.log
 
 import java.io._
 import java.nio.ByteBuffer
-import java.nio.file.Files
+import java.nio.file.{Files, Paths}
 import java.util.Properties
 
-import org.apache.kafka.common.errors._
 import kafka.common.KafkaException
 import kafka.log.Log.DeleteDirSuffix
-import org.junit.Assert._
-import org.junit.{After, Before, Test}
-import kafka.utils._
-import kafka.server.{BrokerTopicStats, FetchDataInfo, KafkaConfig, LogDirFailureChannel}
 import kafka.server.epoch.{EpochEntry, LeaderEpochCache, LeaderEpochFileCache}
+import kafka.server.{BrokerTopicStats, FetchDataInfo, KafkaConfig, LogDirFailureChannel}
+import kafka.utils._
 import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.errors._
 import org.apache.kafka.common.record.MemoryRecords.RecordFilter
 import org.apache.kafka.common.record.MemoryRecords.RecordFilter.BatchRetention
 import org.apache.kafka.common.record._
@@ -38,17 +36,19 @@ import org.apache.kafka.common.requests.FetchResponse.AbortedTransaction
 import org.apache.kafka.common.requests.IsolationLevel
 import org.apache.kafka.common.utils.{Time, Utils}
 import org.easymock.EasyMock
+import org.junit.Assert._
+import org.junit.{After, Before, Test}
 
+import scala.collection.Iterable
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 
 class LogTest {
-
+  var config: KafkaConfig = null
+  val brokerTopicStats = new BrokerTopicStats
   val tmpDir = TestUtils.tempDir()
   val logDir = TestUtils.randomPartitionLogDir(tmpDir)
   val mockTime = new MockTime()
-  var config: KafkaConfig = null
-  val brokerTopicStats = new BrokerTopicStats
 
   @Before
   def setUp() {
@@ -93,10 +93,10 @@ class LogTest {
   @Test
   def testTimeBasedLogRoll() {
     def createRecords = TestUtils.singletonRecords("test".getBytes)
-    val logConfig = createLogConfig(segmentMs = 1 * 60 * 60L)
+    val logConfig = LogTest.createLogConfig(segmentMs = 1 * 60 * 60L)
 
     // create a log
-    val log = createLog(logDir, logConfig, maxProducerIdExpirationMs = 24 * 60)
+    val log = createLog(logDir, logConfig, maxProducerIdExpirationMs = 24 * 60, brokerTopicStats = brokerTopicStats)
     assertEquals("Log begins with a single empty segment.", 1, log.numberOfSegments)
     // Test the segment rolling behavior when messages do not have a timestamp.
     mockTime.sleep(log.config.segmentMs + 1)
@@ -141,7 +141,7 @@ class LogTest {
   @Test(expected = classOf[OutOfOrderSequenceException])
   def testNonSequentialAppend(): Unit = {
     // create a log
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats)
     val pid = 1L
     val epoch: Short = 0
 
@@ -154,7 +154,7 @@ class LogTest {
 
   @Test
   def testTruncateToEmptySegment(): Unit = {
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats)
 
     // Force a segment roll by using a large offset. The first segment will be empty
     val records = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes)),
@@ -178,8 +178,8 @@ class LogTest {
   def testInitializationOfProducerSnapshotsUpgradePath(): Unit = {
     // simulate the upgrade path by creating a new log with several segments, deleting the
     // snapshot files, and then reloading the log
-    val logConfig = createLogConfig(segmentBytes = 64 * 10)
-    var log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 64 * 10)
+    var log = createLog(logDir, logConfig, brokerTopicStats)
     assertEquals(None, log.oldestProducerSnapshotOffset)
 
     for (i <- 0 to 100) {
@@ -194,7 +194,7 @@ class LogTest {
     deleteProducerSnapshotFiles()
 
     // Reload after clean shutdown
-    log = createLog(logDir, logConfig, recoveryPoint = logEndOffset)
+    log = createLog(logDir, logConfig, brokerTopicStats, recoveryPoint = logEndOffset)
     var expectedSnapshotOffsets = log.logSegments.map(_.baseOffset).takeRight(2).toVector :+ log.logEndOffset
     assertEquals(expectedSnapshotOffsets, listProducerSnapshotOffsets)
     log.close()
@@ -203,7 +203,7 @@ class LogTest {
     deleteProducerSnapshotFiles()
 
     // Reload after unclean shutdown with recoveryPoint set to log end offset
-    log = createLog(logDir, logConfig, recoveryPoint = logEndOffset)
+    log = createLog(logDir, logConfig, brokerTopicStats, recoveryPoint = logEndOffset)
     // Note that we don't maintain the guarantee of having a snapshot for the 2 most recent segments in this case
     expectedSnapshotOffsets = Vector(log.logSegments.last.baseOffset, log.logEndOffset)
     assertEquals(expectedSnapshotOffsets, listProducerSnapshotOffsets)
@@ -212,7 +212,7 @@ class LogTest {
     deleteProducerSnapshotFiles()
 
     // Reload after unclean shutdown with recoveryPoint set to 0
-    log = createLog(logDir, logConfig, recoveryPoint = 0L)
+    log = createLog(logDir, logConfig, brokerTopicStats, recoveryPoint = 0L)
     // Is this working as intended?
     expectedSnapshotOffsets = log.logSegments.map(_.baseOffset).tail.toVector :+ log.logEndOffset
     assertEquals(expectedSnapshotOffsets, listProducerSnapshotOffsets)
@@ -221,8 +221,8 @@ class LogTest {
 
   @Test
   def testProducerSnapshotsRecoveryAfterUncleanShutdown(): Unit = {
-    val logConfig = createLogConfig(segmentBytes = 64 * 10)
-    var log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 64 * 10)
+    var log = createLog(logDir, logConfig, brokerTopicStats)
     assertEquals(None, log.oldestProducerSnapshotOffset)
 
     for (i <- 0 to 100) {
@@ -320,8 +320,8 @@ class LogTest {
 
   @Test
   def testProducerIdMapOffsetUpdatedForNonIdempotentData() {
-    val logConfig = createLogConfig(segmentBytes = 2048 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val records = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes)))
     log.appendAsLeader(records, leaderEpoch = 0)
     log.takeProducerSnapshot()
@@ -511,8 +511,8 @@ class LogTest {
 
   @Test
   def testRebuildProducerIdMapWithCompactedData() {
-    val logConfig = createLogConfig(segmentBytes = 2048 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val pid = 1L
     val epoch = 0.toShort
     val seq = 0
@@ -554,8 +554,8 @@ class LogTest {
 
   @Test
   def testRebuildProducerStateWithEmptyCompactedBatch() {
-    val logConfig = createLogConfig(segmentBytes = 2048 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val pid = 1L
     val epoch = 0.toShort
     val seq = 0
@@ -595,8 +595,8 @@ class LogTest {
 
   @Test
   def testUpdateProducerIdMapWithCompactedData() {
-    val logConfig = createLogConfig(segmentBytes = 2048 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val pid = 1L
     val epoch = 0.toShort
     val seq = 0
@@ -628,8 +628,8 @@ class LogTest {
 
   @Test
   def testProducerIdMapTruncateTo() {
-    val logConfig = createLogConfig(segmentBytes = 2048 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes))), leaderEpoch = 0)
     log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes))), leaderEpoch = 0)
     log.takeProducerSnapshot()
@@ -649,8 +649,8 @@ class LogTest {
   @Test
   def testProducerIdMapTruncateToWithNoSnapshots() {
     // This ensures that the upgrade optimization path cannot be hit after initial loading
-    val logConfig = createLogConfig(segmentBytes = 2048 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val pid = 1L
     val epoch = 0.toShort
 
@@ -673,8 +673,8 @@ class LogTest {
 
   @Test
   def testLoadProducersAfterDeleteRecordsMidSegment(): Unit = {
-    val logConfig = createLogConfig(segmentBytes = 2048 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val pid1 = 1L
     val pid2 = 2L
     val epoch = 0.toShort
@@ -694,7 +694,7 @@ class LogTest {
 
     log.close()
 
-    val reloadedLog = createLog(logDir, logConfig, logStartOffset = 1L)
+    val reloadedLog = createLog(logDir, logConfig, logStartOffset = 1L, brokerTopicStats = brokerTopicStats)
     assertEquals(1, reloadedLog.activeProducersWithLastSequence.size)
     val reloadedLastSeqOpt = log.activeProducersWithLastSequence.get(pid2)
     assertEquals(retainedLastSeqOpt, reloadedLastSeqOpt)
@@ -702,8 +702,8 @@ class LogTest {
 
   @Test
   def testLoadProducersAfterDeleteRecordsOnSegment(): Unit = {
-    val logConfig = createLogConfig(segmentBytes = 2048 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val pid1 = 1L
     val pid2 = 2L
     val epoch = 0.toShort
@@ -729,7 +729,7 @@ class LogTest {
 
     log.close()
 
-    val reloadedLog = createLog(logDir, logConfig, logStartOffset = 1L)
+    val reloadedLog = createLog(logDir, logConfig, logStartOffset = 1L, brokerTopicStats = brokerTopicStats)
     assertEquals(1, reloadedLog.activeProducersWithLastSequence.size)
     val reloadedEntryOpt = log.activeProducersWithLastSequence.get(pid2)
     assertEquals(retainedLastSeqOpt, reloadedEntryOpt)
@@ -738,8 +738,8 @@ class LogTest {
   @Test
   def testProducerIdMapTruncateFullyAndStartAt() {
     val records = TestUtils.singletonRecords("foo".getBytes)
-    val logConfig = createLogConfig(segmentBytes = records.sizeInBytes, retentionBytes = records.sizeInBytes * 2)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = records.sizeInBytes, retentionBytes = records.sizeInBytes * 2)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     log.appendAsLeader(records, leaderEpoch = 0)
     log.takeProducerSnapshot()
 
@@ -761,8 +761,8 @@ class LogTest {
   def testProducerIdExpirationOnSegmentDeletion() {
     val pid1 = 1L
     val records = TestUtils.records(Seq(new SimpleRecord("foo".getBytes)), producerId = pid1, producerEpoch = 0, sequence = 0)
-    val logConfig = createLogConfig(segmentBytes = records.sizeInBytes, retentionBytes = records.sizeInBytes * 2)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = records.sizeInBytes, retentionBytes = records.sizeInBytes * 2)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     log.appendAsLeader(records, leaderEpoch = 0)
     log.takeProducerSnapshot()
 
@@ -785,8 +785,8 @@ class LogTest {
 
   @Test
   def testTakeSnapshotOnRollAndDeleteSnapshotOnRecoveryPointCheckpoint() {
-    val logConfig = createLogConfig(segmentBytes = 2048 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     log.appendAsLeader(TestUtils.singletonRecords("a".getBytes), leaderEpoch = 0)
     log.roll(1L)
     assertEquals(Some(1L), log.latestProducerSnapshotOffset)
@@ -818,8 +818,8 @@ class LogTest {
   @Test
   def testProducerSnapshotAfterSegmentRollOnAppend(): Unit = {
     val producerId = 1L
-    val logConfig = createLogConfig(segmentBytes = 1024)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord(mockTime.milliseconds(), new Array[Byte](512))),
       producerId = producerId, producerEpoch = 0, sequence = 0),
@@ -850,8 +850,8 @@ class LogTest {
 
   @Test
   def testRebuildTransactionalState(): Unit = {
-    val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024 * 1024 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     val pid = 137L
     val epoch = 5.toShort
@@ -872,7 +872,7 @@ class LogTest {
 
     log.close()
 
-    val reopenedLog = createLog(logDir, logConfig)
+    val reopenedLog = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     reopenedLog.onHighWatermarkIncremented(commitAppendInfo.lastOffset + 1)
     assertEquals(None, reopenedLog.firstUnstableOffset)
   }
@@ -893,9 +893,9 @@ class LogTest {
     val producerIdExpirationCheckIntervalMs = 100
 
     val pid = 23L
-    val logConfig = createLogConfig(segmentBytes = 2048 * 5)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 2048 * 5)
     val log = createLog(logDir, logConfig, maxProducerIdExpirationMs = maxProducerIdExpirationMs,
-      producerIdExpirationCheckIntervalMs = producerIdExpirationCheckIntervalMs)
+      producerIdExpirationCheckIntervalMs = producerIdExpirationCheckIntervalMs, brokerTopicStats = brokerTopicStats)
     val records = Seq(new SimpleRecord(mockTime.milliseconds(), "foo".getBytes))
     log.appendAsLeader(TestUtils.records(records, producerId = pid, producerEpoch = 0, sequence = 0), leaderEpoch = 0)
 
@@ -911,7 +911,7 @@ class LogTest {
   @Test
   def testDuplicateAppends(): Unit = {
     // create a log
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats)
     val pid = 1L
     val epoch: Short = 0
 
@@ -985,7 +985,7 @@ class LogTest {
   @Test
   def testMultipleProducerIdsPerMemoryRecord() : Unit = {
     // create a log
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats)
 
     val epoch: Short = 0
     val buffer = ByteBuffer.allocate(512)
@@ -1030,8 +1030,8 @@ class LogTest {
 
   @Test
   def testDuplicateAppendToFollower() : Unit = {
-    val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024 * 1024 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val epoch: Short = 0
     val pid = 1L
     val baseSequence = 0
@@ -1051,8 +1051,8 @@ class LogTest {
 
   @Test
   def testMultipleProducersWithDuplicatesInSingleAppend() : Unit = {
-    val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024 * 1024 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     val pid1 = 1L
     val pid2 = 2L
@@ -1104,7 +1104,7 @@ class LogTest {
   @Test(expected = classOf[ProducerFencedException])
   def testOldProducerEpoch(): Unit = {
     // create a log
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats)
     val pid = 1L
     val newEpoch: Short = 1
     val oldEpoch: Short = 0
@@ -1125,8 +1125,8 @@ class LogTest {
     var set = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds)
     val maxJitter = 20 * 60L
     // create a log
-    val logConfig = createLogConfig(segmentMs = 1 * 60 * 60L, segmentJitterMs = maxJitter)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentMs = 1 * 60 * 60L, segmentJitterMs = maxJitter)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     assertEquals("Log begins with a single empty segment.", 1, log.numberOfSegments)
     log.appendAsLeader(set, leaderEpoch = 0)
 
@@ -1150,8 +1150,8 @@ class LogTest {
     val msgPerSeg = 10
     val segmentSize = msgPerSeg * (setSize - 1) // each segment will be 10 messages
     // create a log
-    val logConfig = createLogConfig(segmentBytes = segmentSize)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = segmentSize)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     assertEquals("There should be exactly 1 segment.", 1, log.numberOfSegments)
 
     // segments expire in size
@@ -1166,7 +1166,7 @@ class LogTest {
   @Test
   def testLoadEmptyLog() {
     createEmptyLogs(logDir, 0)
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats)
     log.appendAsLeader(TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds), leaderEpoch = 0)
   }
 
@@ -1175,8 +1175,8 @@ class LogTest {
    */
   @Test
   def testAppendAndReadWithSequentialOffsets() {
-    val logConfig = createLogConfig(segmentBytes = 71)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 71)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val values = (0 until 100 by 2).map(id => id.toString.getBytes).toArray
 
     for(value <- values)
@@ -1199,8 +1199,8 @@ class LogTest {
    */
   @Test
   def testAppendAndReadWithNonSequentialOffsets() {
-    val logConfig = createLogConfig(segmentBytes = 72)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 72)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val messageIds = ((0 until 50) ++ (50 until 200 by 7)).toArray
     val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes))
 
@@ -1223,8 +1223,8 @@ class LogTest {
    */
   @Test
   def testReadAtLogGap() {
-    val logConfig = createLogConfig(segmentBytes = 300)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 300)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // keep appending until we have two segments with only a single message in the second segment
     while(log.numberOfSegments == 1)
@@ -1239,16 +1239,16 @@ class LogTest {
 
   @Test(expected = classOf[KafkaStorageException])
   def testLogRollAfterLogHandlerClosed() {
-    val logConfig = createLogConfig()
-    val log = createLog(logDir,  logConfig)
+    val logConfig = LogTest.createLogConfig()
+    val log = createLog(logDir,  logConfig, brokerTopicStats = brokerTopicStats)
     log.closeHandlers()
     log.roll(1)
   }
 
   @Test
   def testReadWithMinMessage() {
-    val logConfig = createLogConfig(segmentBytes = 72)
-    val log = createLog(logDir,  logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 72)
+    val log = createLog(logDir,  logConfig, brokerTopicStats = brokerTopicStats)
     val messageIds = ((0 until 50) ++ (50 until 200 by 7)).toArray
     val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes))
 
@@ -1274,8 +1274,8 @@ class LogTest {
 
   @Test
   def testReadWithTooSmallMaxLength() {
-    val logConfig = createLogConfig(segmentBytes = 72)
-    val log = createLog(logDir,  logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 72)
+    val log = createLog(logDir,  logConfig, brokerTopicStats = brokerTopicStats)
     val messageIds = ((0 until 50) ++ (50 until 200 by 7)).toArray
     val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes))
 
@@ -1308,8 +1308,8 @@ class LogTest {
   def testReadOutOfRange() {
     createEmptyLogs(logDir, 1024)
     // set up replica log starting with offset 1024 and with one message (at offset 1024)
-    val logConfig = createLogConfig(segmentBytes = 1024)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     log.appendAsLeader(TestUtils.singletonRecords(value = "42".getBytes), leaderEpoch = 0)
 
     assertEquals("Reading at the log end offset should produce 0 byte read.", 0,
@@ -1340,8 +1340,8 @@ class LogTest {
   @Test
   def testLogRolls() {
     /* create a multipart log with 100 messages */
-    val logConfig = createLogConfig(segmentBytes = 100)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 100)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val numMessages = 100
     val messageSets = (0 until numMessages).map(i => TestUtils.singletonRecords(value = i.toString.getBytes,
                                                                                 timestamp = mockTime.milliseconds))
@@ -1378,8 +1378,8 @@ class LogTest {
   @Test
   def testCompressedMessages() {
     /* this log should roll after every messageset */
-    val logConfig = createLogConfig(segmentBytes = 110)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 110)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     /* append 2 compressed message sets, each with two messages giving offsets 0, 1, 2, 3 */
     log.appendAsLeader(MemoryRecords.withRecords(CompressionType.GZIP, new SimpleRecord("hello".getBytes), new SimpleRecord("there".getBytes)), leaderEpoch = 0)
@@ -1402,8 +1402,8 @@ class LogTest {
     for(messagesToAppend <- List(0, 1, 25)) {
       logDir.mkdirs()
       // first test a log segment starting at 0
-      val logConfig = createLogConfig(segmentBytes = 100, retentionMs = 0)
-      val log = createLog(logDir, logConfig)
+      val logConfig = LogTest.createLogConfig(segmentBytes = 100, retentionMs = 0)
+      val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
       for(i <- 0 until messagesToAppend)
         log.appendAsLeader(TestUtils.singletonRecords(value = i.toString.getBytes, timestamp = mockTime.milliseconds - 10), leaderEpoch = 0)
 
@@ -1436,8 +1436,8 @@ class LogTest {
     val messageSet = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("You".getBytes), new SimpleRecord("bethe".getBytes))
     // append messages to log
     val configSegmentSize = messageSet.sizeInBytes - 1
-    val logConfig = createLogConfig(segmentBytes = configSegmentSize)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = configSegmentSize)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     try {
       log.appendAsLeader(messageSet, leaderEpoch = 0)
@@ -1461,8 +1461,8 @@ class LogTest {
     val messageSetWithKeyedMessage = MemoryRecords.withRecords(CompressionType.NONE, keyedMessage)
     val messageSetWithKeyedMessages = MemoryRecords.withRecords(CompressionType.NONE, keyedMessage, anotherKeyedMessage)
 
-    val logConfig = createLogConfig(cleanupPolicy = LogConfig.Compact)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(cleanupPolicy = LogConfig.Compact)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     try {
       log.appendAsLeader(messageSetWithUnkeyedMessage, leaderEpoch = 0)
@@ -1502,8 +1502,8 @@ class LogTest {
 
     // append messages to log
     val maxMessageSize = second.sizeInBytes - 1
-    val logConfig = createLogConfig(maxMessageBytes = maxMessageSize)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(maxMessageBytes = maxMessageSize)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // should be able to append the small message
     log.appendAsLeader(first, leaderEpoch = 0)
@@ -1524,8 +1524,8 @@ class LogTest {
     val messageSize = 100
     val segmentSize = 7 * messageSize
     val indexInterval = 3 * messageSize
-    val logConfig = createLogConfig(segmentBytes = segmentSize, indexIntervalBytes = indexInterval, segmentIndexBytes = 4096)
-    var log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = segmentSize, indexIntervalBytes = indexInterval, segmentIndexBytes = 4096)
+    var log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     for(i <- 0 until numMessages)
       log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(messageSize),
         timestamp = mockTime.milliseconds + i * 10), leaderEpoch = 0)
@@ -1552,12 +1552,12 @@ class LogTest {
       assertEquals("Should have same number of time index entries as before.", numTimeIndexEntries, log.activeSegment.timeIndex.entries)
     }
 
-    log = createLog(logDir, logConfig, recoveryPoint = lastOffset)
+    log = createLog(logDir, logConfig, recoveryPoint = lastOffset, brokerTopicStats = brokerTopicStats)
     verifyRecoveredLog(log, lastOffset)
     log.close()
 
     // test recovery case
-    log = createLog(logDir, logConfig)
+    log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     verifyRecoveredLog(log, lastOffset)
     log.close()
   }
@@ -1568,8 +1568,8 @@ class LogTest {
   @Test
   def testBuildTimeIndexWhenNotAssigningOffsets() {
     val numMessages = 100
-    val logConfig = createLogConfig(segmentBytes = 10000, indexIntervalBytes = 1)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 10000, indexIntervalBytes = 1)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     val messages = (0 until numMessages).map { i =>
       MemoryRecords.withRecords(100 + i, CompressionType.NONE, 0, new SimpleRecord(mockTime.milliseconds + i, i.toString.getBytes()))
@@ -1588,8 +1588,8 @@ class LogTest {
   def testIndexRebuild() {
     // publish the messages and close the log
     val numMessages = 200
-    val logConfig = createLogConfig(segmentBytes = 200, indexIntervalBytes = 1)
-    var log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 200, indexIntervalBytes = 1)
+    var log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     for(i <- 0 until numMessages)
       log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(10), timestamp = mockTime.milliseconds + i * 10), leaderEpoch = 0)
     val indexFiles = log.logSegments.map(_.offsetIndex.file)
@@ -1601,7 +1601,7 @@ class LogTest {
     timeIndexFiles.foreach(_.delete())
 
     // reopen the log
-    log = createLog(logDir, logConfig)
+    log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     assertEquals("Should have %d messages when log is reopened".format(numMessages), numMessages, log.logEndOffset)
     assertTrue("The index should have been rebuilt", log.logSegments.head.offsetIndex.entries > 0)
     assertTrue("The time index should have been rebuilt", log.logSegments.head.timeIndex.entries > 0)
@@ -1622,8 +1622,8 @@ class LogTest {
   def testRebuildTimeIndexForOldMessages() {
     val numMessages = 200
     val segmentSize = 200
-    val logConfig = createLogConfig(segmentBytes = segmentSize, indexIntervalBytes = 1, messageFormatVersion = "0.9.0")
-    var log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = segmentSize, indexIntervalBytes = 1, messageFormatVersion = "0.9.0")
+    var log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     for (i <- 0 until numMessages)
       log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(10),
         timestamp = mockTime.milliseconds + i * 10, magicValue = RecordBatch.MAGIC_VALUE_V1), leaderEpoch = 0)
@@ -1634,7 +1634,7 @@ class LogTest {
     timeIndexFiles.foreach(file => Files.delete(file.toPath))
 
     // The rebuilt time index should be empty
-    log = createLog(logDir, logConfig, recoveryPoint = numMessages + 1)
+    log = createLog(logDir, logConfig, recoveryPoint = numMessages + 1, brokerTopicStats = brokerTopicStats)
     for (segment <- log.logSegments.init) {
       assertEquals("The time index should be empty", 0, segment.timeIndex.entries)
       assertEquals("The time index file size should be 0", 0, segment.timeIndex.file.length)
@@ -1648,8 +1648,8 @@ class LogTest {
   def testCorruptIndexRebuild() {
     // publish the messages and close the log
     val numMessages = 200
-    val logConfig = createLogConfig(segmentBytes = 200, indexIntervalBytes = 1)
-    var log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 200, indexIntervalBytes = 1)
+    var log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     for(i <- 0 until numMessages)
       log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(10), timestamp = mockTime.milliseconds + i * 10), leaderEpoch = 0)
     val indexFiles = log.logSegments.map(_.offsetIndex.file)
@@ -1671,7 +1671,7 @@ class LogTest {
     }
 
     // reopen the log
-    log = createLog(logDir, logConfig, recoveryPoint = 200L)
+    log = createLog(logDir, logConfig, recoveryPoint = 200L, brokerTopicStats = brokerTopicStats)
     assertEquals("Should have %d messages when log is reopened".format(numMessages), numMessages, log.logEndOffset)
     for(i <- 0 until numMessages) {
       assertEquals(i, log.readUncommitted(i, 100, None).records.batches.iterator.next().lastOffset)
@@ -1694,8 +1694,8 @@ class LogTest {
     val segmentSize = msgPerSeg * setSize  // each segment will be 10 messages
 
     // create a log
-    val logConfig = createLogConfig(segmentBytes = segmentSize)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = segmentSize)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     assertEquals("There should be exactly 1 segment.", 1, log.numberOfSegments)
 
     for (_ <- 1 to msgPerSeg)
@@ -1746,8 +1746,8 @@ class LogTest {
     val setSize = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds).sizeInBytes
     val msgPerSeg = 10
     val segmentSize = msgPerSeg * setSize  // each segment will be 10 messages
-    val logConfig = createLogConfig(segmentBytes = segmentSize, indexIntervalBytes = setSize - 1)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = segmentSize, indexIntervalBytes = setSize - 1)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     assertEquals("There should be exactly 1 segment.", 1, log.numberOfSegments)
 
     for (i<- 1 to msgPerSeg)
@@ -1785,8 +1785,8 @@ class LogTest {
     val bogusTimeIndex2 = Log.timeIndexFile(logDir, 5)
 
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, indexIntervalBytes = 1)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, indexIntervalBytes = 1)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     assertTrue("The first index file should have been replaced with a larger file", bogusIndex1.length > 0)
     assertTrue("The first time index file should have been replaced with a larger file", bogusTimeIndex1.length > 0)
@@ -1807,14 +1807,14 @@ class LogTest {
   def testReopenThenTruncate() {
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds)
     // create a log
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, indexIntervalBytes = 10000)
-    var log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, indexIntervalBytes = 10000)
+    var log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // add enough messages to roll over several segments then close and re-open and attempt to truncate
     for (_ <- 0 until 100)
       log.appendAsLeader(createRecords, leaderEpoch = 0)
     log.close()
-    log = createLog(logDir, logConfig)
+    log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     log.truncateTo(3)
     assertEquals("All but one segment should be deleted.", 1, log.numberOfSegments)
     assertEquals("Log end offset should be 3.", 3, log.logEndOffset)
@@ -1827,9 +1827,9 @@ class LogTest {
   def testAsyncDelete() {
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds - 1000L)
     val asyncDeleteMs = 1000
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, indexIntervalBytes = 10000,
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, indexIntervalBytes = 10000,
                                     retentionMs = 999, fileDeleteDelayMs = asyncDeleteMs)
-    val log = createLog(logDir, logConfig)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // append some messages to create some segments
     for (_ <- 0 until 100)
@@ -1861,8 +1861,8 @@ class LogTest {
   @Test
   def testOpenDeletesObsoleteFiles() {
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds - 1000)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, retentionMs = 999)
-    var log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, retentionMs = 999)
+    var log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // append some messages to create some segments
     for (_ <- 0 until 100)
@@ -1872,13 +1872,13 @@ class LogTest {
     log.onHighWatermarkIncremented(log.logEndOffset)
     log.deleteOldSegments()
     log.close()
-    log = createLog(logDir, logConfig)
+    log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     assertEquals("The deleted segments should be gone.", 1, log.numberOfSegments)
   }
 
   @Test
   def testAppendMessageWithNullPayload() {
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats)
     log.appendAsLeader(TestUtils.singletonRecords(value = null), leaderEpoch = 0)
     val head = log.readUncommitted(0, 4096, None).records.records.iterator.next()
     assertEquals(0, head.offset)
@@ -1887,7 +1887,7 @@ class LogTest {
 
   @Test(expected = classOf[IllegalArgumentException])
   def testAppendWithOutOfOrderOffsetsThrowsException() {
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats)
     val records = (0 until 2).map(id => new SimpleRecord(id.toString.getBytes)).toArray
     records.foreach(record => log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, record), leaderEpoch = 0))
     val invalidRecord = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord(1.toString.getBytes))
@@ -1896,7 +1896,7 @@ class LogTest {
 
   @Test
   def testAppendWithNoTimestamp(): Unit = {
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats)
     log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE,
       new SimpleRecord(RecordBatch.NO_TIMESTAMP, "key".getBytes, "value".getBytes)), leaderEpoch = 0)
   }
@@ -1904,13 +1904,13 @@ class LogTest {
   @Test
   def testCorruptLog() {
     // append some messages to create some segments
-    val logConfig = createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds)
     val recoveryPoint = 50L
     for (_ <- 0 until 10) {
       // create a log and write some messages to it
       logDir.mkdirs()
-      var log = createLog(logDir, logConfig)
+      var log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
       val numMessages = 50 + TestUtils.random.nextInt(50)
       for (_ <- 0 until numMessages)
         log.appendAsLeader(createRecords, leaderEpoch = 0)
@@ -1922,7 +1922,7 @@ class LogTest {
       TestUtils.appendNonsenseToFile(log.activeSegment.log.file, TestUtils.random.nextInt(1024) + 1)
 
       // attempt recovery
-      log = createLog(logDir, logConfig, 0L, recoveryPoint)
+      log = createLog(logDir, logConfig, brokerTopicStats, 0L, recoveryPoint)
       assertEquals(numMessages, log.logEndOffset)
 
       val recovered = log.logSegments.flatMap(_.log.records.asScala.toList).toList
@@ -1943,8 +1943,8 @@ class LogTest {
   @Test
   def testOverCompactedLogRecovery(): Unit = {
     // append some messages to create some segments
-    val logConfig = createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val set1 = MemoryRecords.withRecords(0, CompressionType.NONE, 0, new SimpleRecord("v1".getBytes(), "k1".getBytes()))
     val set2 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 2, CompressionType.NONE, 0, new SimpleRecord("v3".getBytes(), "k3".getBytes()))
     val set3 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 3, CompressionType.NONE, 0, new SimpleRecord("v4".getBytes(), "k4".getBytes()))
@@ -1976,8 +1976,8 @@ class LogTest {
   @Test
   def testOverCompactedLogRecoveryMultiRecord(): Unit = {
     // append some messages to create some segments
-    val logConfig = createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val set1 = MemoryRecords.withRecords(0, CompressionType.NONE, 0, new SimpleRecord("v1".getBytes(), "k1".getBytes()))
     val set2 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 2, CompressionType.GZIP, 0,
       new SimpleRecord("v3".getBytes(), "k3".getBytes()),
@@ -2015,8 +2015,8 @@ class LogTest {
   @Test
   def testOverCompactedLogRecoveryMultiRecordV1(): Unit = {
     // append some messages to create some segments
-    val logConfig = createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val set1 = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, 0, CompressionType.NONE,
       new SimpleRecord("v1".getBytes(), "k1".getBytes()))
     val set2 = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, Integer.MAX_VALUE.toLong + 2, CompressionType.GZIP,
@@ -2054,16 +2054,184 @@ class LogTest {
   }
 
   @Test
+  def testSplitOnOffsetOverflow(): Unit = {
+    // create a log such that one log segment has offsets that overflow, and call the split API on that segment
+    val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
+    val (log, segmentWithOverflow, inputRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    assertTrue("At least one segment must have offset overflow", LogTest.hasOffsetOverflow(log))
+
+    // split the segment with overflow
+    log.splitOverflowedSegment(segmentWithOverflow)
+
+    // assert we were successfully able to split the segment
+    assertEquals(4, log.numberOfSegments)
+    assertTrue(LogTest.verifyRecordsInLog(log, inputRecords))
+
+    // verify we do not have offset overflow anymore
+    assertFalse(LogTest.hasOffsetOverflow(log))
+  }
+
+  @Test
+  def testRecoveryOfSegmentWithOffsetOverflow(): Unit = {
+    val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
+    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val expectedKeys = LogTest.keysInLog(log)
+
+    // Run recovery on the log. This should split the segment underneath. Ignore .deleted files as we could still
+    // have them lying around after the split.
+    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
+    assertEquals(expectedKeys, LogTest.keysInLog(log))
+
+    // Running split again should throw an error since none of the segments have offset overflow anymore
+    for (segment <- log.logSegments) {
+      try {
+        log.splitOverflowedSegment(segment)
+        fail("Splitting a segment without offset overflow should not be allowed")
+      } catch {
+        case _: IllegalArgumentException => // expected
+      }
+    }
+  }
+
+  @Test
+  def testRecoveryAfterCrashDuringSplitPhase1(): Unit = {
+    val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
+    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val expectedKeys = LogTest.keysInLog(log)
+    val numSegmentsInitial = log.logSegments.size
+
+    // Split the segment
+    val newSegments = log.splitOverflowedSegment(segmentWithOverflow)
+
+    // Simulate recovery just after the .cleaned files are created, before the rename to .swap. On recovery, the
+    // existing split operation is aborted but the recovery process itself kicks off a split which should complete.
+    newSegments.reverse.foreach(segment => {
+      segment.changeFileSuffixes("", Log.CleanedFileSuffix)
+      segment.truncateTo(0)
+    })
+    for (file <- logDir.listFiles if file.getName.endsWith(Log.DeletedFileSuffix))
+      Utils.atomicMoveWithFallback(file.toPath, Paths.get(CoreUtils.replaceSuffix(file.getPath, Log.DeletedFileSuffix, "")))
+    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
+    assertEquals(expectedKeys, LogTest.keysInLog(log))
+    assertEquals(numSegmentsInitial + 1, log.logSegments.size)
+    log.close()
+  }
+
+  @Test
+  def testRecoveryAfterCrashDuringSplitPhase2(): Unit = {
+    val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
+    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val expectedKeys = LogTest.keysInLog(log)
+    val numSegmentsInitial = log.logSegments.size
+
+    // Split the segment
+    val newSegments = log.splitOverflowedSegment(segmentWithOverflow)
+
+    // Simulate recovery just after one of the new segments has been renamed to .swap. On recovery, the existing split
+    // operation is aborted but the recovery process itself kicks off a split which should complete.
+    newSegments.reverse.foreach(segment => {
+      if (segment != newSegments.last)
+        segment.changeFileSuffixes("", Log.CleanedFileSuffix)
+      else
+        segment.changeFileSuffixes("", Log.SwapFileSuffix)
+      segment.truncateTo(0)
+    })
+    for (file <- logDir.listFiles if file.getName.endsWith(Log.DeletedFileSuffix))
+      Utils.atomicMoveWithFallback(file.toPath, Paths.get(CoreUtils.replaceSuffix(file.getPath, Log.DeletedFileSuffix, "")))
+    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
+    assertEquals(expectedKeys, LogTest.keysInLog(log))
+    assertEquals(numSegmentsInitial + 1, log.logSegments.size)
+    log.close()
+  }
+
+  @Test
+  def testRecoveryAfterCrashDuringSplitPhase3(): Unit = {
+    val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
+    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val expectedKeys = LogTest.keysInLog(log)
+    val numSegmentsInitial = log.logSegments.size
+
+    // Split the segment
+    val newSegments = log.splitOverflowedSegment(segmentWithOverflow)
+
+    // Simulate recovery right after all new segments have been renamed to .swap. On recovery, the existing split
+    // operation is completed and the old segment is deleted.
+    newSegments.reverse.foreach(segment => {
+      segment.changeFileSuffixes("", Log.SwapFileSuffix)
+    })
+    for (file <- logDir.listFiles if file.getName.endsWith(Log.DeletedFileSuffix))
+      Utils.atomicMoveWithFallback(file.toPath, Paths.get(CoreUtils.replaceSuffix(file.getPath, Log.DeletedFileSuffix, "")))
+
+    // Truncate the old segment
+    segmentWithOverflow.truncateTo(0)
+
+    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
+    assertEquals(expectedKeys, LogTest.keysInLog(log))
+    assertEquals(numSegmentsInitial + 1, log.logSegments.size)
+    log.close()
+  }
+
+  @Test
+  def testRecoveryAfterCrashDuringSplitPhase4(): Unit = {
+    val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
+    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val expectedKeys = LogTest.keysInLog(log)
+    val numSegmentsInitial = log.logSegments.size
+
+    // Split the segment
+    val newSegments = log.splitOverflowedSegment(segmentWithOverflow)
+
+    // Simulate recovery right after all new segments have been renamed to .swap and the old segment has been deleted.
+    // On recovery, the existing split operation is completed.
+    newSegments.reverse.foreach(segment => {
+      segment.changeFileSuffixes("", Log.SwapFileSuffix)
+    })
+    for (file <- logDir.listFiles if file.getName.endsWith(Log.DeletedFileSuffix))
+      Utils.delete(file)
+
+    // Truncate the old segment
+    segmentWithOverflow.truncateTo(0)
+
+    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
+    assertEquals(expectedKeys, LogTest.keysInLog(log))
+    assertEquals(numSegmentsInitial + 1, log.logSegments.size)
+    log.close()
+  }
+
+  @Test
+  def testRecoveryAfterCrashDuringSplitPhase5(): Unit = {
+    val logConfig = LogTest.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000)
+    var (log, segmentWithOverflow, initialRecords) = createLogWithOffsetOverflow(Some(logConfig))
+    val expectedKeys = LogTest.keysInLog(log)
+    val numSegmentsInitial = log.logSegments.size
+
+    // Split the segment
+    val newSegments = log.splitOverflowedSegment(segmentWithOverflow)
+
+    // Simulate recovery right after one of the new segments has been renamed to .swap and the other has already been
+    // renamed to .log. On recovery, the existing split operation is completed.
+    newSegments.last.changeFileSuffixes("", Log.SwapFileSuffix)
+
+    // Truncate the old segment
+    segmentWithOverflow.truncateTo(0)
+
+    log = LogTest.recoverAndCheck(logDir, logConfig, expectedKeys, brokerTopicStats, expectDeletedFiles = true)
+    assertEquals(expectedKeys, LogTest.keysInLog(log))
+    assertEquals(numSegmentsInitial + 1, log.logSegments.size)
+    log.close()
+  }
+
+  @Test
   def testCleanShutdownFile() {
     // append some messages to create some segments
-    val logConfig = createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024)
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds)
 
     val cleanShutdownFile = createCleanShutdownFile()
     assertTrue(".kafka_cleanshutdown must exist", cleanShutdownFile.exists())
     var recoveryPoint = 0L
     // create a log and write some messages to it
-    var log = createLog(logDir, logConfig)
+    var log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     for (_ <- 0 until 100)
       log.appendAsLeader(createRecords, leaderEpoch = 0)
     log.close()
@@ -2071,7 +2239,7 @@ class LogTest {
     // check if recovery was attempted. Even if the recovery point is 0L, recovery should not be attempted as the
     // clean shutdown file exists.
     recoveryPoint = log.logEndOffset
-    log = createLog(logDir, logConfig)
+    log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     assertEquals(recoveryPoint, log.logEndOffset)
     Utils.delete(cleanShutdownFile)
   }
@@ -2231,8 +2399,8 @@ class LogTest {
   @Test
   def testDeleteOldSegments() {
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds - 1000)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, retentionMs = 999)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, retentionMs = 999)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // append some messages to create some segments
     for (_ <- 0 until 100)
@@ -2281,8 +2449,8 @@ class LogTest {
   @Test
   def testLogDeletionAfterClose() {
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds - 1000)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, retentionMs = 999)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, retentionMs = 999)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // append some messages to create some segments
     log.appendAsLeader(createRecords, leaderEpoch = 0)
@@ -2299,8 +2467,8 @@ class LogTest {
   @Test
   def testLogDeletionAfterDeleteRecords() {
     def createRecords = TestUtils.singletonRecords("test".getBytes)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     for (_ <- 0 until 15)
       log.appendAsLeader(createRecords, leaderEpoch = 0)
@@ -2331,8 +2499,8 @@ class LogTest {
   @Test
   def shouldDeleteSizeBasedSegments() {
     def createRecords = TestUtils.singletonRecords("test".getBytes)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 10)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 10)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // append some messages to create some segments
     for (_ <- 0 until 15)
@@ -2346,8 +2514,8 @@ class LogTest {
   @Test
   def shouldNotDeleteSizeBasedSegmentsWhenUnderRetentionSize() {
     def createRecords = TestUtils.singletonRecords("test".getBytes)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 15)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 15)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // append some messages to create some segments
     for (_ <- 0 until 15)
@@ -2361,8 +2529,8 @@ class LogTest {
   @Test
   def shouldDeleteTimeBasedSegmentsReadyToBeDeleted() {
     def createRecords = TestUtils.singletonRecords("test".getBytes, timestamp = 10)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // append some messages to create some segments
     for (_ <- 0 until 15)
@@ -2376,8 +2544,8 @@ class LogTest {
   @Test
   def shouldNotDeleteTimeBasedSegmentsWhenNoneReadyToBeDeleted() {
     def createRecords = TestUtils.singletonRecords("test".getBytes, timestamp = mockTime.milliseconds)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000000)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000000)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // append some messages to create some segments
     for (_ <- 0 until 15)
@@ -2391,8 +2559,8 @@ class LogTest {
   @Test
   def shouldNotDeleteSegmentsWhenPolicyDoesNotIncludeDelete() {
     def createRecords = TestUtils.singletonRecords("test".getBytes, key = "test".getBytes(), timestamp = 10L)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000, cleanupPolicy = "compact")
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000, cleanupPolicy = "compact")
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // append some messages to create some segments
     for (_ <- 0 until 15)
@@ -2410,8 +2578,8 @@ class LogTest {
   @Test
   def shouldDeleteSegmentsReadyToBeDeletedWhenCleanupPolicyIsCompactAndDelete() {
     def createRecords = TestUtils.singletonRecords("test".getBytes, key = "test".getBytes, timestamp = 10L)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000, cleanupPolicy = "compact,delete")
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000, cleanupPolicy = "compact,delete")
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // append some messages to create some segments
     for (_ <- 0 until 15)
@@ -2428,7 +2596,7 @@ class LogTest {
 
     //Given this partition is on leader epoch 72
     val epoch = 72
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats)
     log.leaderEpochCache.assign(epoch, records.size)
 
     //When appending messages as a leader (i.e. assignOffsets = true)
@@ -2460,7 +2628,7 @@ class LogTest {
       recs
     }
 
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats)
 
     //When appending as follower (assignOffsets = false)
     for (i <- records.indices)
@@ -2472,8 +2640,8 @@ class LogTest {
   @Test
   def shouldTruncateLeaderEpochsWhenDeletingSegments() {
     def createRecords = TestUtils.singletonRecords("test".getBytes)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 10)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 10)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val cache = epochCache(log)
 
     // Given three segments of 5 messages each
@@ -2497,8 +2665,8 @@ class LogTest {
   @Test
   def shouldUpdateOffsetForLeaderEpochsWhenDeletingSegments() {
     def createRecords = TestUtils.singletonRecords("test".getBytes)
-    val logConfig = createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 10)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 10)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val cache = epochCache(log)
 
     // Given three segments of 5 messages each
@@ -2522,8 +2690,8 @@ class LogTest {
   @Test
   def shouldTruncateLeaderEpochFileWhenTruncatingLog() {
     def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds)
-    val logConfig = createLogConfig(segmentBytes = 10 * createRecords.sizeInBytes)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 10 * createRecords.sizeInBytes)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val cache = epochCache(log)
 
     //Given 2 segments, 10 messages per segment
@@ -2564,11 +2732,11 @@ class LogTest {
   }
 
   /**
-    * Append a bunch of messages to a log and then re-open it with recovery and check that the leader epochs are recovered properly.
-    */
+   * Append a bunch of messages to a log and then re-open it with recovery and check that the leader epochs are recovered properly.
+   */
   @Test
   def testLogRecoversForLeaderEpoch() {
-    val log = createLog(logDir, LogConfig())
+    val log = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats)
     val leaderEpochCache = epochCache(log)
     val firstBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 1, offset = 0)
     log.appendAsFollower(records = firstBatch)
@@ -2590,7 +2758,7 @@ class LogTest {
     log.close()
 
     // reopen the log and recover from the beginning
-    val recoveredLog = createLog(logDir, LogConfig())
+    val recoveredLog = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats)
     val recoveredLeaderEpochCache = epochCache(recoveredLog)
 
     // epoch entries should be recovered
@@ -2599,15 +2767,15 @@ class LogTest {
   }
 
   /**
-    * Wrap a single record log buffer with leader epoch.
-    */
+   * Wrap a single record log buffer with leader epoch.
+   */
   private def singletonRecordsWithLeaderEpoch(value: Array[Byte],
-                                      key: Array[Byte] = null,
-                                      leaderEpoch: Int,
-                                      offset: Long,
-                                      codec: CompressionType = CompressionType.NONE,
-                                      timestamp: Long = RecordBatch.NO_TIMESTAMP,
-                                      magicValue: Byte = RecordBatch.CURRENT_MAGIC_VALUE): MemoryRecords = {
+                                              key: Array[Byte] = null,
+                                              leaderEpoch: Int,
+                                              offset: Long,
+                                              codec: CompressionType = CompressionType.NONE,
+                                              timestamp: Long = RecordBatch.NO_TIMESTAMP,
+                                              magicValue: Byte = RecordBatch.CURRENT_MAGIC_VALUE): MemoryRecords = {
     val records = Seq(new SimpleRecord(timestamp, key, value))
 
     val buf = ByteBuffer.allocate(DefaultRecordBatch.sizeInBytes(records.asJava))
@@ -2618,8 +2786,8 @@ class LogTest {
   }
 
   def testFirstUnstableOffsetNoTransactionalData() {
-    val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024 * 1024 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     val records = MemoryRecords.withRecords(CompressionType.NONE,
       new SimpleRecord("foo".getBytes),
@@ -2632,8 +2800,8 @@ class LogTest {
 
   @Test
   def testFirstUnstableOffsetWithTransactionalData() {
-    val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024 * 1024 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     val pid = 137L
     val epoch = 5.toShort
@@ -2670,8 +2838,8 @@ class LogTest {
 
   @Test
   def testTransactionIndexUpdated(): Unit = {
-    val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024 * 1024 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val epoch = 0.toShort
 
     val pid1 = 1L
@@ -2711,8 +2879,8 @@ class LogTest {
 
   @Test
   def testFullTransactionIndexRecovery(): Unit = {
-    val logConfig = createLogConfig(segmentBytes = 128 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 128 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val epoch = 0.toShort
 
     val pid1 = 1L
@@ -2754,16 +2922,16 @@ class LogTest {
 
     log.close()
 
-    val reloadedLogConfig = createLogConfig(segmentBytes = 1024 * 5)
-    val reloadedLog = createLog(logDir, reloadedLogConfig)
+    val reloadedLogConfig = LogTest.createLogConfig(segmentBytes = 1024 * 5)
+    val reloadedLog = createLog(logDir, reloadedLogConfig, brokerTopicStats = brokerTopicStats)
     val abortedTransactions = allAbortedTransactions(reloadedLog)
     assertEquals(List(new AbortedTxn(pid1, 0L, 29L, 8L), new AbortedTxn(pid2, 8L, 74L, 36L)), abortedTransactions)
   }
 
   @Test
   def testRecoverOnlyLastSegment(): Unit = {
-    val logConfig = createLogConfig(segmentBytes = 128 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 128 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val epoch = 0.toShort
 
     val pid1 = 1L
@@ -2805,16 +2973,16 @@ class LogTest {
 
     log.close()
 
-    val reloadedLogConfig = createLogConfig(segmentBytes = 1024 * 5)
-    val reloadedLog = createLog(logDir, reloadedLogConfig, recoveryPoint = recoveryPoint)
+    val reloadedLogConfig = LogTest.createLogConfig(segmentBytes = 1024 * 5)
+    val reloadedLog = createLog(logDir, reloadedLogConfig, recoveryPoint = recoveryPoint, brokerTopicStats = brokerTopicStats)
     val abortedTransactions = allAbortedTransactions(reloadedLog)
     assertEquals(List(new AbortedTxn(pid1, 0L, 29L, 8L), new AbortedTxn(pid2, 8L, 74L, 36L)), abortedTransactions)
   }
 
   @Test
   def testRecoverLastSegmentWithNoSnapshots(): Unit = {
-    val logConfig = createLogConfig(segmentBytes = 128 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 128 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val epoch = 0.toShort
 
     val pid1 = 1L
@@ -2859,8 +3027,8 @@ class LogTest {
 
     log.close()
 
-    val reloadedLogConfig = createLogConfig(segmentBytes = 1024 * 5)
-    val reloadedLog = createLog(logDir, reloadedLogConfig, recoveryPoint = recoveryPoint)
+    val reloadedLogConfig = LogTest.createLogConfig(segmentBytes = 1024 * 5)
+    val reloadedLog = createLog(logDir, reloadedLogConfig, recoveryPoint = recoveryPoint, brokerTopicStats = brokerTopicStats)
     val abortedTransactions = allAbortedTransactions(reloadedLog)
     assertEquals(List(new AbortedTxn(pid1, 0L, 29L, 8L), new AbortedTxn(pid2, 8L, 74L, 36L)), abortedTransactions)
   }
@@ -2868,8 +3036,8 @@ class LogTest {
   @Test
   def testTransactionIndexUpdatedThroughReplication(): Unit = {
     val epoch = 0.toShort
-    val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024 * 1024 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val buffer = ByteBuffer.allocate(2048)
 
     val pid1 = 1L
@@ -2914,8 +3082,8 @@ class LogTest {
   def testZombieCoordinatorFenced(): Unit = {
     val pid = 1L
     val epoch = 0.toShort
-    val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024 * 1024 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     val append = appendTransactionalAsLeader(log, pid, epoch)
 
@@ -2930,8 +3098,8 @@ class LogTest {
 
   @Test
   def testFirstUnstableOffsetDoesNotExceedLogStartOffsetMidSegment(): Unit = {
-    val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024 * 1024 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val epoch = 0.toShort
     val pid = 1L
     val appendPid = appendTransactionalAsLeader(log, pid, epoch)
@@ -2954,8 +3122,8 @@ class LogTest {
 
   @Test
   def testFirstUnstableOffsetDoesNotExceedLogStartOffsetAfterSegmentDeletion(): Unit = {
-    val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024 * 1024 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
     val epoch = 0.toShort
     val pid = 1L
     val appendPid = appendTransactionalAsLeader(log, pid, epoch)
@@ -2981,8 +3149,8 @@ class LogTest {
 
   @Test
   def testLastStableOffsetWithMixedProducerData() {
-    val logConfig = createLogConfig(segmentBytes = 1024 * 1024 * 5)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = 1024 * 1024 * 5)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     // for convenience, both producers share the same epoch
     val epoch = 5.toShort
@@ -3042,8 +3210,8 @@ class LogTest {
       new SimpleRecord("b".getBytes),
       new SimpleRecord("c".getBytes))
 
-    val logConfig = createLogConfig(segmentBytes = records.sizeInBytes)
-    val log = createLog(logDir, logConfig)
+    val logConfig = LogTest.createLogConfig(segmentBytes = records.sizeInBytes)
+    val log = createLog(logDir, logConfig, brokerTopicStats = brokerTopicStats)
 
     val firstAppendInfo = log.appendAsLeader(records, leaderEpoch = 0)
     assertEquals(Some(firstAppendInfo.firstOffset.get), log.firstUnstableOffset.map(_.messageOffset))
@@ -3073,55 +3241,7 @@ class LogTest {
     assertEquals(new AbortedTransaction(pid, 0), fetchDataInfo.abortedTransactions.get.head)
   }
 
-  def createLogConfig(segmentMs: Long = Defaults.SegmentMs,
-                      segmentBytes: Int = Defaults.SegmentSize,
-                      retentionMs: Long = Defaults.RetentionMs,
-                      retentionBytes: Long = Defaults.RetentionSize,
-                      segmentJitterMs: Long = Defaults.SegmentJitterMs,
-                      cleanupPolicy: String = Defaults.CleanupPolicy,
-                      maxMessageBytes: Int = Defaults.MaxMessageSize,
-                      indexIntervalBytes: Int = Defaults.IndexInterval,
-                      segmentIndexBytes: Int = Defaults.MaxIndexSize,
-                      messageFormatVersion: String = Defaults.MessageFormatVersion,
-                      fileDeleteDelayMs: Long = Defaults.FileDeleteDelayMs): LogConfig = {
-    val logProps = new Properties()
-
-    logProps.put(LogConfig.SegmentMsProp, segmentMs: java.lang.Long)
-    logProps.put(LogConfig.SegmentBytesProp, segmentBytes: Integer)
-    logProps.put(LogConfig.RetentionMsProp, retentionMs: java.lang.Long)
-    logProps.put(LogConfig.RetentionBytesProp, retentionBytes: java.lang.Long)
-    logProps.put(LogConfig.SegmentJitterMsProp, segmentJitterMs: java.lang.Long)
-    logProps.put(LogConfig.CleanupPolicyProp, cleanupPolicy)
-    logProps.put(LogConfig.MaxMessageBytesProp, maxMessageBytes: Integer)
-    logProps.put(LogConfig.IndexIntervalBytesProp, indexIntervalBytes: Integer)
-    logProps.put(LogConfig.SegmentIndexBytesProp, segmentIndexBytes: Integer)
-    logProps.put(LogConfig.MessageFormatVersionProp, messageFormatVersion)
-    logProps.put(LogConfig.FileDeleteDelayMsProp, fileDeleteDelayMs: java.lang.Long)
-    LogConfig(logProps)
-  }
-
-  def createLog(dir: File,
-                config: LogConfig,
-                logStartOffset: Long = 0L,
-                recoveryPoint: Long = 0L,
-                scheduler: Scheduler = mockTime.scheduler,
-                brokerTopicStats: BrokerTopicStats = brokerTopicStats,
-                time: Time = mockTime,
-                maxProducerIdExpirationMs: Int = 60 * 60 * 1000,
-                producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs): Log = {
-    Log(dir = dir,
-        config = config,
-        logStartOffset = logStartOffset,
-        recoveryPoint = recoveryPoint,
-        scheduler = scheduler,
-        brokerTopicStats = brokerTopicStats,
-        time = time,
-        maxProducerIdExpirationMs = maxProducerIdExpirationMs,
-        producerIdExpirationCheckIntervalMs = producerIdExpirationCheckIntervalMs,
-        logDirFailureChannel = new LogDirFailureChannel(10))
-  }
-
-  private def allAbortedTransactions(log: Log) = log.logSegments.flatMap(_.txnIndex.allAbortedTxns)
+  private def allAbortedTransactions(log: Log) = log.logSegments.flatMap(_.txnIndex.allAbortedTxns)
 
   private def appendTransactionalAsLeader(log: Log, producerId: Long, producerEpoch: Short): Int => Unit = {
     var sequence = 0
@@ -3200,4 +3320,230 @@ class LogTest {
   private def listProducerSnapshotOffsets: Seq[Long] =
     ProducerStateManager.listSnapshotFiles(logDir).map(Log.offsetFromFile).sorted
 
+  private def createLog(dir: File,
+                        config: LogConfig,
+                        brokerTopicStats: BrokerTopicStats = brokerTopicStats,
+                        logStartOffset: Long = 0L,
+                        recoveryPoint: Long = 0L,
+                        scheduler: Scheduler = mockTime.scheduler,
+                        time: Time = mockTime,
+                        maxProducerIdExpirationMs: Int = 60 * 60 * 1000,
+                        producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs): Log = {
+    LogTest.createLog(dir, config, brokerTopicStats, scheduler, time, logStartOffset, recoveryPoint,
+      maxProducerIdExpirationMs, producerIdExpirationCheckIntervalMs)
+  }
+
+  private def createLogWithOffsetOverflow(logConfig: Option[LogConfig]): (Log, LogSegment, List[Record]) = {
+    LogTest.createLogWithOffsetOverflow(logDir, brokerTopicStats, logConfig, mockTime.scheduler, mockTime)
+  }
+}
+
+object LogTest {
+  def createLogConfig(segmentMs: Long = Defaults.SegmentMs,
+                      segmentBytes: Int = Defaults.SegmentSize,
+                      retentionMs: Long = Defaults.RetentionMs,
+                      retentionBytes: Long = Defaults.RetentionSize,
+                      segmentJitterMs: Long = Defaults.SegmentJitterMs,
+                      cleanupPolicy: String = Defaults.CleanupPolicy,
+                      maxMessageBytes: Int = Defaults.MaxMessageSize,
+                      indexIntervalBytes: Int = Defaults.IndexInterval,
+                      segmentIndexBytes: Int = Defaults.MaxIndexSize,
+                      messageFormatVersion: String = Defaults.MessageFormatVersion,
+                      fileDeleteDelayMs: Long = Defaults.FileDeleteDelayMs): LogConfig = {
+    val logProps = new Properties()
+
+    logProps.put(LogConfig.SegmentMsProp, segmentMs: java.lang.Long)
+    logProps.put(LogConfig.SegmentBytesProp, segmentBytes: Integer)
+    logProps.put(LogConfig.RetentionMsProp, retentionMs: java.lang.Long)
+    logProps.put(LogConfig.RetentionBytesProp, retentionBytes: java.lang.Long)
+    logProps.put(LogConfig.SegmentJitterMsProp, segmentJitterMs: java.lang.Long)
+    logProps.put(LogConfig.CleanupPolicyProp, cleanupPolicy)
+    logProps.put(LogConfig.MaxMessageBytesProp, maxMessageBytes: Integer)
+    logProps.put(LogConfig.IndexIntervalBytesProp, indexIntervalBytes: Integer)
+    logProps.put(LogConfig.SegmentIndexBytesProp, segmentIndexBytes: Integer)
+    logProps.put(LogConfig.MessageFormatVersionProp, messageFormatVersion)
+    logProps.put(LogConfig.FileDeleteDelayMsProp, fileDeleteDelayMs: java.lang.Long)
+    LogConfig(logProps)
+  }
+
+  def createLog(dir: File,
+                config: LogConfig,
+                brokerTopicStats: BrokerTopicStats,
+                scheduler: Scheduler,
+                time: Time,
+                logStartOffset: Long = 0L,
+                recoveryPoint: Long = 0L,
+                maxProducerIdExpirationMs: Int = 60 * 60 * 1000,
+                producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs): Log = {
+    Log(dir = dir,
+      config = config,
+      logStartOffset = logStartOffset,
+      recoveryPoint = recoveryPoint,
+      scheduler = scheduler,
+      brokerTopicStats = brokerTopicStats,
+      time = time,
+      maxProducerIdExpirationMs = maxProducerIdExpirationMs,
+      producerIdExpirationCheckIntervalMs = producerIdExpirationCheckIntervalMs,
+      logDirFailureChannel = new LogDirFailureChannel(10))
+  }
+
+  /**
+   * Check if the given log contains any segment with records that cause offset overflow.
+   * @param log Log to check
+   * @return true if log contains at least one segment with offset overflow; false otherwise
+   */
+  def hasOffsetOverflow(log: Log): Boolean = {
+    for (logSegment <- log.logSegments) {
+      val baseOffset = logSegment.baseOffset
+      for (batch <- logSegment.log.batches.asScala) {
+        val it = batch.iterator()
+        while (it.hasNext()) {
+          val record = it.next()
+          if (record.offset > baseOffset + Int.MaxValue || record.offset < baseOffset)
+            return true
+        }
+      }
+    }
+    false
+  }
+
+  /**
+   * Create a log such that one of the log segments has messages with offsets that cause index offset overflow.
+   * @param logDir Directory in which log should be created
+   * @param brokerTopicStats Container for Broker Topic Yammer Metrics
+   * @param logConfigOpt Optional log configuration to use
+   * @param scheduler The thread pool scheduler used for background actions
+   * @param time The time instance to use
+   * @return (1) Created log containing segment with offset overflow, (2) Log segment within log containing messages with
+   *         offset overflow, and (3) List of messages in the log
+   */
+  def createLogWithOffsetOverflow(logDir: File, brokerTopicStats: BrokerTopicStats, logConfigOpt: Option[LogConfig] = None,
+                                  scheduler: Scheduler, time: Time): (Log, LogSegment, List[Record]) = {
+    val logConfig = logConfigOpt.getOrElse(createLogConfig(indexIntervalBytes = 1))
+
+    var log = createLog(logDir, logConfig, brokerTopicStats, scheduler, time)
+    val inputRecords = ListBuffer[Record]()
+
+    // References to files we want to "merge" to emulate offset overflow
+    val toMerge = ListBuffer[File]()
+
+    def getRecords(baseOffset: Long): List[MemoryRecords] = {
+      def toBytes(value: Long): Array[Byte] = value.toString.getBytes
+
+      val set1 = MemoryRecords.withRecords(baseOffset, CompressionType.NONE, 0,
+        new SimpleRecord(toBytes(baseOffset), toBytes(baseOffset)))
+      val set2 = MemoryRecords.withRecords(baseOffset + 1, CompressionType.NONE, 0,
+        new SimpleRecord(toBytes(baseOffset + 1), toBytes(baseOffset + 1)),
+        new SimpleRecord(toBytes(baseOffset + 2), toBytes(baseOffset + 2)))
+      val set3 = MemoryRecords.withRecords(baseOffset + Int.MaxValue - 1, CompressionType.NONE, 0,
+        new SimpleRecord(toBytes(baseOffset + Int.MaxValue - 1), toBytes(baseOffset + Int.MaxValue - 1)))
+      List(set1, set2, set3)
+    }
+
+    // Append some messages to the log. This will create four log segments.
+    var firstOffset = 0L
+    for (i <- 0 until 4) {
+      val recordsToAppend = getRecords(firstOffset)
+      for (records <- recordsToAppend)
+        log.appendAsFollower(records)
+
+      if (i == 1 || i == 2)
+        toMerge += log.activeSegment.log.file
+
+      firstOffset += Int.MaxValue + 1L
+    }
+
+    // assert that we have the correct number of segments
+    assertEquals(4, log.numberOfSegments)
+
+    // assert number of batches
+    for (logSegment <- log.logSegments) {
+      var numBatches = 0
+      for (_ <- logSegment.log.batches.asScala)
+        numBatches += 1
+      assertEquals(3, numBatches)
+    }
+
+    // create a list of appended records
+    for (logSegment <- log.logSegments) {
+      for (batch <- logSegment.log.batches.asScala) {
+        val it = batch.iterator()
+        while (it.hasNext())
+          inputRecords += it.next()
+      }
+    }
+
+    log.flush()
+    log.close()
+
+    // We want to "merge" segment #2 and segment #3. This is where the offset overflow will be.
+    // Current: segment #1 | segment #2 | segment #3 | segment #4
+    // Final:   segment #1 | segment #2' | segment #4
+    // where 2' corresponds to segment #2 and segment #3 combined together.
+    // Append the contents of segment #3's log file to segment #2's log file to create 2'
+    val bytesToAppend = Files.readAllBytes(toMerge(1).toPath)
+    Files.write(toMerge(0).toPath, bytesToAppend, java.nio.file.StandardOpenOption.APPEND)
+
+    // Delete segment #3's log file now that its contents live at the end of segment #2's log file
+    toMerge(1).delete()
+    log = createLog(logDir, logConfig, brokerTopicStats, scheduler, time, recoveryPoint = Long.MaxValue)
+
+    // assert that there is now one less segment than before, and that the records in the log are the same as before
+    assertEquals(3, log.numberOfSegments)
+    assertTrue(verifyRecordsInLog(log, inputRecords.toList))
+
+    (log, log.logSegments.toList(1), inputRecords.toList)
+  }
+
+  def verifyRecordsInLog(log: Log, expectedRecords: List[Record]): Boolean = {
+    val recordsFound = ListBuffer[Record]()
+    for (logSegment <- log.logSegments) {
+      for (batch <- logSegment.log.batches.asScala) {
+        val it = batch.iterator()
+        while (it.hasNext())
+          recordsFound += it.next()
+      }
+    }
+    recordsFound.toList == expectedRecords
+  }
+
+  /* extract all the keys from a log */
+  def keysInLog(log: Log): Iterable[Long] = {
+    for (logSegment <- log.logSegments;
+         batch <- logSegment.log.batches.asScala if !batch.isControlBatch;
+         record <- batch.asScala if record.hasValue && record.hasKey)
+      yield TestUtils.readString(record.key).toLong
+  }
+
+  def recoverAndCheck(logDir: File, config: LogConfig, expectedKeys: Iterable[Long],
+                      brokerTopicStats: BrokerTopicStats, expectDeletedFiles: Boolean = false): Log = {
+    val time = new MockTime()
+    // Recover log file and check that after recovery, keys are as expected
+    // and all temporary files have been deleted
+    val recoveredLog = createLog(logDir, config, brokerTopicStats, time.scheduler, time)
+    time.sleep(config.fileDeleteDelayMs + 1)
+    for (file <- logDir.listFiles) {
+      if (!expectDeletedFiles)
+        assertFalse("Unexpected .deleted file after recovery", file.getName.endsWith(Log.DeletedFileSuffix))
+      assertFalse("Unexpected .cleaned file after recovery", file.getName.endsWith(Log.CleanedFileSuffix))
+      assertFalse("Unexpected .swap file after recovery", file.getName.endsWith(Log.SwapFileSuffix))
+    }
+    assertEquals(expectedKeys, LogTest.keysInLog(recoveredLog))
+    assertFalse(LogTest.hasOffsetOverflow(recoveredLog))
+    recoveredLog
+  }
 }
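
A note for readers skimming the test helpers above: the overflow condition that hasOffsetOverflow probes comes from
the offset index storing each offset relative to the segment's base offset in four bytes, so any record whose offset
is below the base offset or more than Int.MaxValue above it cannot be indexed. A minimal standalone sketch of that
predicate (the helper name below is illustrative and not part of the patch):

    // Sketch only: mirrors the per-record check in LogTest.hasOffsetOverflow above.
    // An offset is indexable iff its distance from the segment's base offset fits in an Int.
    def isIndexable(recordOffset: Long, segmentBaseOffset: Long): Boolean =
      recordOffset >= segmentBaseOffset && recordOffset <= segmentBaseOffset + Int.MaxValue

    isIndexable(Int.MaxValue.toLong + 2, segmentBaseOffset = 0L)  // false: the overflow case the split repairs
    isIndexable(45L + 3, segmentBaseOffset = 45L)                 // true

This is the same condition the crash-recovery tests rely on when they merge two segment files on disk to
manufacture an overflowing segment.
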
diff --git a/core/src/test/scala/unit/kafka/log/OffsetIndexTest.scala b/core/src/test/scala/unit/kafka/log/OffsetIndexTest.scala
index 8fa3cc1..1e4e892 100644
--- a/core/src/test/scala/unit/kafka/log/OffsetIndexTest.scala
+++ b/core/src/test/scala/unit/kafka/log/OffsetIndexTest.scala
@@ -35,10 +35,11 @@ class OffsetIndexTest extends JUnitSuite {
   
   var idx: OffsetIndex = null
   val maxEntries = 30
+  val baseOffset = 45L
   
   @Before
   def setup() {
-    this.idx = new OffsetIndex(nonExistentTempFile(), baseOffset = 45L, maxIndexSize = 30 * 8)
+    this.idx = new OffsetIndex(nonExistentTempFile(), baseOffset, maxIndexSize = 30 * 8)
   }
   
   @After
@@ -102,10 +103,10 @@ class OffsetIndexTest extends JUnitSuite {
 
   @Test
   def testFetchUpperBoundOffset() {
-    val first = OffsetPosition(0, 0)
-    val second = OffsetPosition(1, 10)
-    val third = OffsetPosition(2, 23)
-    val fourth = OffsetPosition(3, 37)
+    val first = OffsetPosition(baseOffset + 0, 0)
+    val second = OffsetPosition(baseOffset + 1, 10)
+    val third = OffsetPosition(baseOffset + 2, 23)
+    val fourth = OffsetPosition(baseOffset + 3, 37)
 
     assertEquals(None, idx.fetchUpperBoundOffset(first, 5))
 

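One more illustrative note on the OffsetIndexTest change above (a sketch under the assumption that index entries are
stored relative to the index's base offset; the helper below is not part of the patch): with baseOffset = 45, an
OffsetPosition at absolute offset 0 would imply a negative relative offset, which the index cannot represent, hence
the test now builds its positions from baseOffset:

    // Sketch: the relative value an offset index entry would store for an absolute offset.
    val baseOffset = 45L
    def relativeOffset(absoluteOffset: Long): Int = {
      require(absoluteOffset >= baseOffset && absoluteOffset - baseOffset <= Int.MaxValue,
        s"offset $absoluteOffset cannot be represented relative to base $baseOffset")
      (absoluteOffset - baseOffset).toInt
    }
    relativeOffset(baseOffset + 3)  // 3, as in OffsetPosition(baseOffset + 3, 37) above
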
-- 
To stop receiving notification emails like this one, please contact
jgus@apache.org.