Posted to commits@kafka.apache.org by ju...@apache.org on 2022/01/27 23:03:26 UTC

[kafka] branch trunk updated: KAFKA-13603: Allow the empty active segment to have missing offset index during recovery (#11345)

This is an automated email from the ASF dual-hosted git repository.

junrao pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new a21aec8  KAFKA-13603: Allow the empty active segment to have missing offset index during recovery (#11345)
a21aec8 is described below

commit a21aec8d6287f170ccb2568eadde515980dd597c
Author: Cong Ding <co...@ccding.com>
AuthorDate: Thu Jan 27 16:59:21 2022 -0600

    KAFKA-13603: Allow the empty active segment to have missing offset index during recovery (#11345)
    
    Within a LogSegment, the TimeIndex and OffsetIndex are lazy indices that are not created on disk until they are accessed for the first time. However, the log recovery logic expects an offset index file to be present on disk for each segment; otherwise, the segment is considered corrupted.
    
    This PR introduces a forceFlushActiveSegment boolean on the log.flush function so that the shutdown process can flush the empty active segment, which ensures the offset index file exists on disk.
    
    Co-Author: Kowshik Prakasam <kowshik@gmail.com>
    Reviewers: Jason Gustafson <ja...@confluent.io>, Jun Rao <ju...@gmail.com>
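
A sketch of the failure mode described above, illustrating how a lazily
created offset index surfaces during log loading. The helper and path below
are illustrative only, not the actual LogSegment.sanityCheck implementation:

    import java.nio.file.{Files, NoSuchFileException, Paths}

    object MissingIndexSketch {
      // Stand-in for the segment sanity check: a segment whose offset index
      // was never materialized on disk fails with NoSuchFileException.
      def sanityCheck(indexPath: String): Unit =
        if (!Files.exists(Paths.get(indexPath)))
          throw new NoSuchFileException(indexPath)

      def main(args: Array[String]): Unit = {
        // Lazy indices are created on first access, so an empty active
        // segment whose index was never accessed leaves no .index file.
        try sanityCheck("/tmp/kafka-logs/topic-0/00000000000000000100.index")
        catch {
          case e: NoSuchFileException =>
            println(s"missing offset index, segment needs recovery: ${e.getFile}")
        }
      }
    }
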
---
 core/src/main/scala/kafka/log/LogLoader.scala      |  5 ++-
 core/src/main/scala/kafka/log/LogManager.scala     |  4 +-
 core/src/main/scala/kafka/log/UnifiedLog.scala     | 37 ++++++++++++----
 .../main/scala/kafka/raft/KafkaMetadataLog.scala   |  4 +-
 .../test/scala/unit/kafka/log/LogLoaderTest.scala  | 49 ++++++++++++++++++++--
 .../test/scala/unit/kafka/log/LogManagerTest.scala |  4 +-
 .../test/scala/unit/kafka/log/UnifiedLogTest.scala | 21 ++++++++--
 .../scala/unit/kafka/server/LogOffsetTest.scala    | 12 +++---
 ...chDrivenReplicationProtocolAcceptanceTest.scala |  2 +-
 .../unit/kafka/tools/DumpLogSegmentsTest.scala     |  4 +-
 .../org/apache/kafka/raft/KafkaRaftClient.java     |  5 ++-
 .../java/org/apache/kafka/raft/ReplicatedLog.java  |  4 +-
 .../test/java/org/apache/kafka/raft/MockLog.java   |  4 +-
 .../java/org/apache/kafka/raft/MockLogTest.java    |  2 +-
 14 files changed, 119 insertions(+), 38 deletions(-)

diff --git a/core/src/main/scala/kafka/log/LogLoader.scala b/core/src/main/scala/kafka/log/LogLoader.scala
index f9c5dab..e22dcd4 100644
--- a/core/src/main/scala/kafka/log/LogLoader.scala
+++ b/core/src/main/scala/kafka/log/LogLoader.scala
@@ -326,8 +326,9 @@ class LogLoader(
         try segment.sanityCheck(timeIndexFileNewlyCreated)
         catch {
           case _: NoSuchFileException =>
-            error(s"Could not find offset index file corresponding to log file" +
-              s" ${segment.log.file.getAbsolutePath}, recovering segment and rebuilding index files...")
+            if (hadCleanShutdown || segment.baseOffset < recoveryPointCheckpoint)
+              error(s"Could not find offset index file corresponding to log file" +
+                s" ${segment.log.file.getAbsolutePath}, recovering segment and rebuilding index files...")
             recoverSegment(segment)
           case e: CorruptIndexException =>
             warn(s"Found a corrupted index file corresponding to log file" +
diff --git a/core/src/main/scala/kafka/log/LogManager.scala b/core/src/main/scala/kafka/log/LogManager.scala
index c4ee18c..66dc581 100755
--- a/core/src/main/scala/kafka/log/LogManager.scala
+++ b/core/src/main/scala/kafka/log/LogManager.scala
@@ -524,7 +524,7 @@ class LogManager(logDirs: Seq[File],
       val jobsForDir = logs.map { log =>
         val runnable: Runnable = () => {
           // flush the log to ensure latest possible recovery point
-          log.flush()
+          log.flush(true)
           log.close()
         }
         runnable
@@ -1242,7 +1242,7 @@ class LogManager(logDirs: Seq[File],
         debug(s"Checking if flush is needed on ${topicPartition.topic} flush interval ${log.config.flushMs}" +
               s" last flushed ${log.lastFlushTime} time since last flush: $timeSinceLastFlush")
         if(timeSinceLastFlush >= log.config.flushMs)
-          log.flush()
+          log.flush(false)
       } catch {
         case e: Throwable =>
           error(s"Error flushing topic ${topicPartition.topic}", e)
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index 543087a..a69273b 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -930,7 +930,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
                 s"next offset: ${localLog.logEndOffset}, " +
                 s"and messages: $validRecords")
 
-              if (localLog.unflushedMessages >= config.flushInterval) flush()
+              if (localLog.unflushedMessages >= config.flushInterval) flush(false)
           }
           appendInfo
         }
@@ -1498,28 +1498,47 @@ class UnifiedLog(@volatile var logStartOffset: Long,
     producerStateManager.takeSnapshot()
     updateHighWatermarkWithLogEndOffset()
     // Schedule an asynchronous flush of the old segment
-    scheduler.schedule("flush-log", () => flush(newSegment.baseOffset))
+    scheduler.schedule("flush-log", () => flushUptoOffsetExclusive(newSegment.baseOffset))
     newSegment
   }
 
   /**
    * Flush all local log segments
+   *
+   * @param forceFlushActiveSegment should be true during a clean shutdown, and false otherwise. When true, we pass
+   * logEndOffset + 1 to `localLog.flush(offset: Long): Unit` so that the empty active segment is also flushed,
+   * which ensures the active segment's files, including its offset index, are persisted during shutdown,
+   * particularly when the segment is empty.
    */
-  def flush(): Unit = flush(logEndOffset)
+  def flush(forceFlushActiveSegment: Boolean): Unit = flush(logEndOffset, forceFlushActiveSegment)
 
   /**
    * Flush local log segments for all offsets up to offset-1
    *
    * @param offset The offset to flush up to (non-inclusive); the new recovery point
    */
-  def flush(offset: Long): Unit = {
-    maybeHandleIOException(s"Error while flushing log for $topicPartition in dir ${dir.getParent} with offset $offset") {
-      if (offset > localLog.recoveryPoint) {
-        debug(s"Flushing log up to offset $offset, last flushed: $lastFlushTime,  current time: ${time.milliseconds()}, " +
+  def flushUptoOffsetExclusive(offset: Long): Unit = flush(offset, false)
+
+  /**
+   * Flush local log segments for all offsets up to offset-1 if includingOffset=false; up to offset
+   * if includingOffset=true. The recovery point is set to offset.
+   *
+   * @param offset The offset to flush up to; the new recovery point
+   * @param includingOffset Whether the flush includes the provided offset.
+   */
+  private def flush(offset: Long, includingOffset: Boolean): Unit = {
+    val flushOffset = if (includingOffset) offset + 1 else offset
+    val newRecoveryPoint = offset
+    val includingOffsetStr = if (includingOffset) "inclusive" else "exclusive"
+    maybeHandleIOException(s"Error while flushing log for $topicPartition in dir ${dir.getParent} with offset $offset " +
+      s"($includingOffsetStr) and recovery point $newRecoveryPoint") {
+      if (flushOffset > localLog.recoveryPoint) {
+        debug(s"Flushing log up to offset $offset ($includingOffsetStr) " +
+          s"with recovery point $newRecoveryPoint, last flushed: $lastFlushTime, current time: ${time.milliseconds()}, " +
           s"unflushed: ${localLog.unflushedMessages}")
-        localLog.flush(offset)
+        localLog.flush(flushOffset)
         lock synchronized {
-          localLog.markFlushed(offset)
+          localLog.markFlushed(newRecoveryPoint)
         }
       }
     }
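
The net effect on the UnifiedLog flush entry points, sketched as a minimal,
self-contained model. SimpleLog below is illustrative, not the real class; it
shows why only the inclusive path can cover an empty active segment that
starts exactly at logEndOffset:

    object FlushSketch {
      final class SimpleLog(var logEndOffset: Long, var recoveryPoint: Long) {
        // Mirrors the private UnifiedLog.flush(offset, includingOffset).
        private def flush(offset: Long, includingOffset: Boolean): Unit = {
          val flushOffset = if (includingOffset) offset + 1 else offset
          if (flushOffset > recoveryPoint) {
            // ... the real code flushes segments below flushOffset here ...
            recoveryPoint = offset // markFlushed(newRecoveryPoint)
          }
        }
        def flush(forceFlushActiveSegment: Boolean): Unit =
          flush(logEndOffset, forceFlushActiveSegment)
        def flushUptoOffsetExclusive(offset: Long): Unit = flush(offset, false)
      }

      def main(args: Array[String]): Unit = {
        val log = new SimpleLog(logEndOffset = 100L, recoveryPoint = 100L)
        log.flush(false) // exclusive: 100 > 100 is false, nothing is flushed
        log.flush(true)  // inclusive: 101 > 100, the empty active segment is flushed
      }
    }
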
diff --git a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
index c83aec6..8c9132b 100644
--- a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
+++ b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala
@@ -211,8 +211,8 @@ final class KafkaMetadataLog private (
     new LogOffsetMetadata(hwm.messageOffset, segmentPosition)
   }
 
-  override def flush(): Unit = {
-    log.flush()
+  override def flush(forceFlushActiveSegment: Boolean): Unit = {
+    log.flush(forceFlushActiveSegment)
   }
 
   override def lastFlushedOffset(): Long = {
diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
index ea236b9..496b1d1 100644
--- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
@@ -19,7 +19,7 @@ package kafka.log
 
 import java.io.{BufferedWriter, File, FileWriter}
 import java.nio.ByteBuffer
-import java.nio.file.{Files, Paths}
+import java.nio.file.{Files, NoSuchFileException, Paths}
 import java.util.Properties
 import kafka.api.{ApiVersion, KAFKA_0_11_0_IV0}
 import kafka.server.epoch.{EpochEntry, LeaderEpochFileCache}
@@ -764,7 +764,7 @@ class LogLoaderTest {
     val lastOffset = log.logEndOffset
     // After segment is closed, the last entry in the time index should be (largest timestamp -> last offset).
     val lastTimeIndexOffset = log.logEndOffset - 1
-    val lastTimeIndexTimestamp  = log.activeSegment.largestTimestamp
+    val lastTimeIndexTimestamp = log.activeSegment.largestTimestamp
     // Depending on when the last time index entry is inserted, an entry may or may not be inserted into the time index.
     val numTimeIndexEntries = log.activeSegment.timeIndex.entries + {
       if (log.activeSegment.timeIndex.lastEntry.offset == log.logEndOffset - 1) 0 else 1
@@ -790,7 +790,7 @@ class LogLoaderTest {
     log = createLog(logDir, logConfig, recoveryPoint = recoveryPoint, lastShutdownClean = false)
     // the recovery point should not be updated after unclean shutdown until the log is flushed
     verifyRecoveredLog(log, recoveryPoint)
-    log.flush()
+    log.flush(false)
     verifyRecoveredLog(log, lastOffset)
     log.close()
   }
@@ -1681,4 +1681,47 @@ class LogLoaderTest {
       s"Found offsets with missing producer state snapshot files: $offsetsWithMissingSnapshotFiles")
     assertFalse(logDir.list().exists(_.endsWith(UnifiedLog.DeletedFileSuffix)), "Expected no files to be present with the deleted file suffix")
   }
+
+  @Test
+  def testRecoverWithEmptyActiveSegment(): Unit = {
+    val numMessages = 100
+    val messageSize = 100
+    val segmentSize = 7 * messageSize
+    val indexInterval = 3 * messageSize
+    val logConfig = LogTestUtils.createLogConfig(segmentBytes = segmentSize, indexIntervalBytes = indexInterval, segmentIndexBytes = 4096)
+    var log = createLog(logDir, logConfig)
+    for (i <- 0 until numMessages)
+      log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(messageSize),
+        timestamp = mockTime.milliseconds + i * 10), leaderEpoch = 0)
+    assertEquals(numMessages, log.logEndOffset,
+      "After appending %d messages to an empty log, the log end offset should be %d".format(numMessages, numMessages))
+    log.roll()
+    log.flush(false)
+    assertThrows(classOf[NoSuchFileException], () => log.activeSegment.sanityCheck(true))
+    var lastOffset = log.logEndOffset
+    log.closeHandlers()
+
+    log = createLog(logDir, logConfig, recoveryPoint = lastOffset, lastShutdownClean = false)
+    assertEquals(lastOffset, log.recoveryPoint, s"Unexpected recovery point")
+    assertEquals(numMessages, log.logEndOffset, s"Should have $numMessages messages when log is reopened w/o recovery")
+    assertEquals(0, log.activeSegment.timeIndex.entries, "Should have same number of time index entries as before.")
+    log.activeSegment.sanityCheck(true) // this should not throw because the LogLoader created the empty active log index file during recovery
+
+    for (i <- 0 until numMessages)
+      log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(messageSize),
+        timestamp = mockTime.milliseconds + i * 10), leaderEpoch = 0)
+    log.roll()
+    assertThrows(classOf[NoSuchFileException], () => log.activeSegment.sanityCheck(true))
+    log.flush(true)
+    log.activeSegment.sanityCheck(true) // this should not throw because we flushed the active segment which created the empty log index file
+    lastOffset = log.logEndOffset
+
+    log = createLog(logDir, logConfig, recoveryPoint = lastOffset, lastShutdownClean = false)
+    assertEquals(lastOffset, log.recoveryPoint, s"Unexpected recovery point")
+    assertEquals(2 * numMessages, log.logEndOffset, s"Should have ${2 * numMessages} messages when log is reopened w/o recovery")
+    assertEquals(0, log.activeSegment.timeIndex.entries, "Should have same number of time index entries as before.")
+    log.activeSegment.sanityCheck(true) // this should not throw
+
+    log.close()
+  }
 }
diff --git a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
index 937d80c..11b511e 100755
--- a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
@@ -408,7 +408,7 @@ class LogManagerTest {
       for (_ <- 0 until 50)
         log.appendAsLeader(TestUtils.singletonRecords("test".getBytes()), leaderEpoch = 0)
 
-      log.flush()
+      log.flush(false)
     }
 
     logManager.checkpointLogRecoveryOffsets()
@@ -484,7 +484,7 @@ class LogManagerTest {
     allLogs.foreach { log =>
       for (_ <- 0 until 50)
         log.appendAsLeader(TestUtils.singletonRecords("test".getBytes), leaderEpoch = 0)
-      log.flush()
+      log.flush(false)
     }
 
     logManager.checkpointRecoveryOffsetsInDir(logDir)
diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
index be63413..6ad78da 100755
--- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
+++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala
@@ -1099,7 +1099,7 @@ class UnifiedLogTest {
 
     // even if we flush within the active segment, the snapshot should remain
     log.appendAsLeader(TestUtils.singletonRecords("baz".getBytes), leaderEpoch = 0)
-    log.flush(4L)
+    log.flushUptoOffsetExclusive(4L)
     assertEquals(Some(3L), log.latestProducerSnapshotOffset)
   }
 
@@ -1293,7 +1293,7 @@ class UnifiedLogTest {
     val memoryRecords = MemoryRecords.readableRecords(buffer)
 
     log.appendAsFollower(memoryRecords)
-    log.flush()
+    log.flush(false)
 
     val fetchedData = LogTestUtils.readLog(log, 0, Int.MaxValue)
 
@@ -1630,6 +1630,21 @@ class UnifiedLogTest {
     assertThrows(classOf[OffsetOutOfRangeException], () => LogTestUtils.readLog(log, 1026, 1000))
   }
 
+  @Test
+  def testFlushingEmptyActiveSegments(): Unit = {
+    val logConfig = LogTestUtils.createLogConfig()
+    val log = createLog(logDir, logConfig)
+    val message = TestUtils.singletonRecords(value = "Test".getBytes, timestamp = mockTime.milliseconds)
+    log.appendAsLeader(message, leaderEpoch = 0)
+    log.roll()
+    assertEquals(2, logDir.listFiles(_.getName.endsWith(".log")).length)
+    assertEquals(1, logDir.listFiles(_.getName.endsWith(".index")).length)
+    assertEquals(0, log.activeSegment.size)
+    log.flush(true)
+    assertEquals(2, logDir.listFiles(_.getName.endsWith(".log")).length)
+    assertEquals(2, logDir.listFiles(_.getName.endsWith(".index")).length)
+  }
+
   /**
    * Test that covers reads and writes on a multisegment log. This test appends a bunch of messages
    * and then reads them all back and checks that the message read and offset matches what was appended.
@@ -1643,7 +1658,7 @@ class UnifiedLogTest {
     val messageSets = (0 until numMessages).map(i => TestUtils.singletonRecords(value = i.toString.getBytes,
                                                                                 timestamp = mockTime.milliseconds))
     messageSets.foreach(log.appendAsLeader(_, leaderEpoch = 0))
-    log.flush()
+    log.flush(false)
 
     /* do successive reads to ensure all our messages are there */
     var offset = 0L
diff --git a/core/src/test/scala/unit/kafka/server/LogOffsetTest.scala b/core/src/test/scala/unit/kafka/server/LogOffsetTest.scala
index be62211..6c71e9c 100755
--- a/core/src/test/scala/unit/kafka/server/LogOffsetTest.scala
+++ b/core/src/test/scala/unit/kafka/server/LogOffsetTest.scala
@@ -69,7 +69,7 @@ class LogOffsetTest extends BaseRequestTest {
 
     for (_ <- 0 until 20)
       log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes()), leaderEpoch = 0)
-    log.flush()
+    log.flush(false)
 
     log.updateHighWatermark(log.logEndOffset)
     log.maybeIncrementLogStartOffset(3, ClientRecordDeletion)
@@ -94,7 +94,7 @@ class LogOffsetTest extends BaseRequestTest {
 
     for (timestamp <- 0 until 20)
       log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes(), timestamp = timestamp.toLong), leaderEpoch = 0)
-    log.flush()
+    log.flush(false)
 
     log.updateHighWatermark(log.logEndOffset)
 
@@ -117,7 +117,7 @@ class LogOffsetTest extends BaseRequestTest {
 
     for (timestamp <- List(0L, 1L, 2L, 3L, 4L, 6L, 5L))
       log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes(), timestamp = timestamp), leaderEpoch = 0)
-    log.flush()
+    log.flush(false)
 
     log.updateHighWatermark(log.logEndOffset)
 
@@ -139,7 +139,7 @@ class LogOffsetTest extends BaseRequestTest {
 
     for (_ <- 0 until 20)
       log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes()), leaderEpoch = 0)
-    log.flush()
+    log.flush(false)
 
     val offsets = log.legacyFetchOffsetsBefore(ListOffsetsRequest.LATEST_TIMESTAMP, 15)
     assertEquals(Seq(20L, 18L, 16L, 14L, 12L, 10L, 8L, 6L, 4L, 2L, 0L), offsets)
@@ -210,7 +210,7 @@ class LogOffsetTest extends BaseRequestTest {
 
     for (_ <- 0 until 20)
       log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes()), leaderEpoch = 0)
-    log.flush()
+    log.flush(false)
 
     val now = time.milliseconds + 30000 // pretend it is the future to avoid race conditions with the fs
 
@@ -238,7 +238,7 @@ class LogOffsetTest extends BaseRequestTest {
     val log = logManager.getOrCreateLog(topicPartition, topicId = None)
     for (_ <- 0 until 20)
       log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes()), leaderEpoch = 0)
-    log.flush()
+    log.flush(false)
 
     val offsets = log.legacyFetchOffsetsBefore(ListOffsetsRequest.EARLIEST_TIMESTAMP, 10)
 
diff --git a/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceTest.scala b/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceTest.scala
index a6d4019..1e617a8 100644
--- a/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceTest.scala
+++ b/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceTest.scala
@@ -419,7 +419,7 @@ class EpochDrivenReplicationProtocolAcceptanceTest extends QuorumTestHarness wit
 
   private def getLogFile(broker: KafkaServer, partition: Int): File = {
     val log: UnifiedLog = getLog(broker, partition)
-    log.flush()
+    log.flush(false)
     log.dir.listFiles.filter(_.getName.endsWith(".log"))(0)
   }
 
diff --git a/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
index e0fac30..acc5c69 100644
--- a/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
+++ b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala
@@ -81,7 +81,7 @@ class DumpLogSegmentsTest {
         leaderEpoch = 0)
     }
     // Flush, but don't close so that the indexes are not trimmed and contain some zero entries
-    log.flush()
+    log.flush(false)
   }
 
   @AfterEach
@@ -242,7 +242,7 @@ class DumpLogSegmentsTest {
       new SimpleRecord(null, buf.array)
     }).toArray
     log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, records:_*), leaderEpoch = 1)
-    log.flush()
+    log.flush(false)
 
     var output = runDumpLogSegments(Array("--cluster-metadata-decoder", "false", "--files", logFilePath))
     assert(output.contains("TOPIC_RECORD"))
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 9a4d414..dced48d 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -443,7 +443,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
     private void flushLeaderLog(LeaderState<T> state, long currentTimeMs) {
         // We update the end offset before flushing so that parked fetches can return sooner.
         updateLeaderEndOffsetAndTimestamp(state, currentTimeMs);
-        log.flush();
+        log.flush(false);
     }
 
     private boolean maybeTransitionToLeader(CandidateState state, long currentTimeMs) {
@@ -1136,7 +1136,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
         Records records
     ) {
         LogAppendInfo info = log.appendAsFollower(records);
-        log.flush();
+        log.flush(false);
 
         OffsetAndEpoch endOffset = endOffset();
         kafkaRaftMetrics.updateFetchedRecords(info.lastOffset - info.firstOffset + 1);
@@ -2362,6 +2362,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
 
     @Override
     public void close() {
+        log.flush(true);
         if (kafkaRaftMetrics != null) {
             kafkaRaftMetrics.close();
         }
diff --git a/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java
index f60a796..b71de32 100644
--- a/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java
+++ b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java
@@ -184,8 +184,10 @@ public interface ReplicatedLog extends AutoCloseable {
 
     /**
      * Flush the current log to disk.
+     *
+     * @param forceFlushActiveSegment Whether to force flush the active segment. Should be `true` during close; otherwise false.
      */
-    void flush();
+    void flush(boolean forceFlushActiveSegment);
 
     /**
      * Possibly perform cleaning of snapshots and logs
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLog.java b/raft/src/test/java/org/apache/kafka/raft/MockLog.java
index 50cdeec..c4b4c10 100644
--- a/raft/src/test/java/org/apache/kafka/raft/MockLog.java
+++ b/raft/src/test/java/org/apache/kafka/raft/MockLog.java
@@ -100,7 +100,7 @@ public class MockLog implements ReplicatedLog {
                 epochStartOffsets.clear();
                 snapshots.headMap(snapshotId, false).clear();
                 updateHighWatermark(new LogOffsetMetadata(snapshotId.offset));
-                flush();
+                flush(false);
 
                 truncated.set(true);
             }
@@ -328,7 +328,7 @@ public class MockLog implements ReplicatedLog {
     }
 
     @Override
-    public void flush() {
+    public void flush(boolean forceFlushActiveSegment) {
         lastFlushedOffset = endOffset().offset;
     }
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
index 9365640..88ee3a6 100644
--- a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java
@@ -420,7 +420,7 @@ public class MockLogTest {
     public void testUnflushedRecordsLostAfterReopen() {
         appendBatch(5, 1);
         appendBatch(10, 2);
-        log.flush();
+        log.flush(false);
 
         appendBatch(5, 3);
         appendBatch(10, 4);