Posted to commits@kafka.apache.org by sh...@apache.org on 2022/07/22 03:04:56 UTC

[kafka] 03/03: KAFKA-13919: expose log recovery metrics (#12347)

This is an automated email from the ASF dual-hosted git repository.

showuon pushed a commit to branch 3.3
in repository https://gitbox.apache.org/repos/asf/kafka.git

commit 158fc5f0269670d80d6c8f49186c51930bd600a1
Author: Luke Chen <sh...@gmail.com>
AuthorDate: Fri Jul 22 11:00:15 2022 +0800

    KAFKA-13919: expose log recovery metrics (#12347)
    
    Implementation for KIP-831:
    1. Add a remainingLogsToRecover metric for the number of remaining logs to be recovered in each log.dir (see the JMX read-out sketch after the changed-file summary below).
    2. Add a remainingSegmentsToRecover metric for the number of remaining segments to be recovered in the current log assigned to each recovery thread.
    3. Remove these metrics once log loading completes.
    4. Add tests.
    
    Reviewers: Jun Rao <ju...@confluent.io>, Tom Bentley <tb...@redhat.com>
---
 core/src/main/scala/kafka/log/LogLoader.scala      |  26 ++-
 core/src/main/scala/kafka/log/LogManager.scala     |  85 ++++++--
 core/src/main/scala/kafka/log/UnifiedLog.scala     |  10 +-
 .../test/scala/unit/kafka/log/LogLoaderTest.scala  |   3 +-
 .../test/scala/unit/kafka/log/LogManagerTest.scala | 220 ++++++++++++++++++++-
 .../test/scala/unit/kafka/log/LogTestUtils.scala   |   7 +-
 .../test/scala/unit/kafka/utils/TestUtils.scala    |   5 +-
 .../apache/kafka/jmh/server/CheckpointBench.java   |   2 +-
 8 files changed, 319 insertions(+), 39 deletions(-)
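
Both gauges land in the LogManager metrics group, so they are visible over JMX for the
whole duration of log recovery. The following is a minimal sketch of how they could be
polled from a remote JMX client; the service URL/port is a placeholder, and the MBean
coordinates assume the usual kafka.log:type=LogManager naming that KafkaMetricsGroup
derives for LogManager gauges (with dir/threadNum tags appended), so treat it as an
illustration rather than part of this patch:

    import javax.management.ObjectName
    import javax.management.remote.{JMXConnectorFactory, JMXServiceURL}

    object RecoveryMetricsProbe {
      def main(args: Array[String]): Unit = {
        // Placeholder endpoint: point this at the broker's JMX port.
        val url = new JMXServiceURL("service:jmx:rmi:///jndi/rmi://localhost:9999/jmxrmi")
        val connector = JMXConnectorFactory.connect(url)
        try {
          val mbsc = connector.getMBeanServerConnection
          // Match every per-directory remainingLogsToRecover gauge, whatever its dir tag is.
          val pattern = new ObjectName("kafka.log:type=LogManager,name=remainingLogsToRecover,*")
          mbsc.queryNames(pattern, null).forEach { name =>
            val remaining = mbsc.getAttribute(name, "Value")
            println(s"${name.getKeyProperty("dir")} -> $remaining logs left to recover")
          }
        } finally {
          connector.close()
        }
      }
    }

The same query with name=remainingSegmentsToRecover (and a threadNum key in the
ObjectName) reads the per-thread segment gauge added further down.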

diff --git a/core/src/main/scala/kafka/log/LogLoader.scala b/core/src/main/scala/kafka/log/LogLoader.scala
index 581d016e5e..25ee89c72b 100644
--- a/core/src/main/scala/kafka/log/LogLoader.scala
+++ b/core/src/main/scala/kafka/log/LogLoader.scala
@@ -29,6 +29,7 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.InvalidOffsetException
 import org.apache.kafka.common.utils.Time
 
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import scala.collection.{Set, mutable}
 
 case class LoadedLogOffsets(logStartOffset: Long,
@@ -64,6 +65,7 @@ object LogLoader extends Logging {
  * @param recoveryPointCheckpoint The checkpoint of the offset at which to begin the recovery
  * @param leaderEpochCache An optional LeaderEpochFileCache instance to be updated during recovery
  * @param producerStateManager The ProducerStateManager instance to be updated during recovery
+ * @param numRemainingSegments The number of remaining segments to be recovered in this log, keyed by recovery thread name
  */
 class LogLoader(
   dir: File,
@@ -77,7 +79,8 @@ class LogLoader(
   logStartOffsetCheckpoint: Long,
   recoveryPointCheckpoint: Long,
   leaderEpochCache: Option[LeaderEpochFileCache],
-  producerStateManager: ProducerStateManager
+  producerStateManager: ProducerStateManager,
+  numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]
 ) extends Logging {
   logIdent = s"[LogLoader partition=$topicPartition, dir=${dir.getParent}] "
 
@@ -404,12 +407,18 @@ class LogLoader(
 
     // If we have the clean shutdown marker, skip recovery.
     if (!hadCleanShutdown) {
-      val unflushed = segments.values(recoveryPointCheckpoint, Long.MaxValue).iterator
+      val unflushed = segments.values(recoveryPointCheckpoint, Long.MaxValue)
+      val numUnflushed = unflushed.size
+      val unflushedIter = unflushed.iterator
       var truncated = false
+      var numFlushed = 0
+      val threadName = Thread.currentThread().getName
+      numRemainingSegments.put(threadName, numUnflushed)
+
+      while (unflushedIter.hasNext && !truncated) {
+        val segment = unflushedIter.next()
+        info(s"Recovering unflushed segment ${segment.baseOffset}. $numFlushed/$numUnflushed recovered for $topicPartition.")
 
-      while (unflushed.hasNext && !truncated) {
-        val segment = unflushed.next()
-        info(s"Recovering unflushed segment ${segment.baseOffset}")
         val truncatedBytes =
           try {
             recoverSegment(segment)
@@ -424,8 +433,13 @@ class LogLoader(
           // we had an invalid message, delete all remaining log
           warn(s"Corruption found in segment ${segment.baseOffset}," +
             s" truncating to offset ${segment.readNextOffset}")
-          removeAndDeleteSegmentsAsync(unflushed.toList)
+          removeAndDeleteSegmentsAsync(unflushedIter.toList)
           truncated = true
+          // segment is truncated, so set remaining segments to 0
+          numRemainingSegments.put(threadName, 0)
+        } else {
+          numFlushed += 1
+          numRemainingSegments.put(threadName, numUnflushed - numFlushed)
         }
       }
     }
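
The recovery loop above publishes the per-thread workload up front and then counts it
down one segment at a time, zeroing the entry when corruption truncates the rest of the
log. A minimal, self-contained sketch of that bookkeeping pattern follows;
RemainingSegmentsSketch, recoverOne, and the sample data are illustrative stand-ins, not
Kafka APIs:

    import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}

    object RemainingSegmentsSketch {
      // Hypothetical stand-in for LogLoader.recoverSegment: returns the number of truncated bytes.
      private def recoverOne(segment: Int): Int = 0

      def recover(segments: Seq[Int], numRemainingSegments: ConcurrentMap[String, Int]): Unit = {
        val threadName = Thread.currentThread().getName
        val total = segments.size
        // Publish the full workload before recovering anything.
        numRemainingSegments.put(threadName, total)

        var recovered = 0
        var truncated = false
        val iter = segments.iterator
        while (iter.hasNext && !truncated) {
          val segment = iter.next()
          if (recoverOne(segment) > 0) {
            // Corruption found: the remaining segments get dropped, so nothing is left to recover.
            numRemainingSegments.put(threadName, 0)
            truncated = true
          } else {
            recovered += 1
            numRemainingSegments.put(threadName, total - recovered)
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val gaugeBacking = new ConcurrentHashMap[String, Int]
        recover(Seq(1, 2, 3), gaugeBacking)
        println(gaugeBacking) // e.g. {main=0}
      }
    }

The real loop also logs a "$numFlushed/$numUnflushed recovered" progress line per
segment, which the sketch omits.
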
diff --git a/core/src/main/scala/kafka/log/LogManager.scala b/core/src/main/scala/kafka/log/LogManager.scala
index bdc7ffd74d..886f56c63c 100755
--- a/core/src/main/scala/kafka/log/LogManager.scala
+++ b/core/src/main/scala/kafka/log/LogManager.scala
@@ -262,7 +262,8 @@ class LogManager(logDirs: Seq[File],
                            recoveryPoints: Map[TopicPartition, Long],
                            logStartOffsets: Map[TopicPartition, Long],
                            defaultConfig: LogConfig,
-                           topicConfigOverrides: Map[String, LogConfig]): UnifiedLog = {
+                           topicConfigOverrides: Map[String, LogConfig],
+                           numRemainingSegments: ConcurrentMap[String, Int]): UnifiedLog = {
     val topicPartition = UnifiedLog.parseTopicPartitionName(logDir)
     val config = topicConfigOverrides.getOrElse(topicPartition.topic, defaultConfig)
     val logRecoveryPoint = recoveryPoints.getOrElse(topicPartition, 0L)
@@ -282,7 +283,8 @@ class LogManager(logDirs: Seq[File],
       logDirFailureChannel = logDirFailureChannel,
       lastShutdownClean = hadCleanShutdown,
       topicId = None,
-      keepPartitionMetadataFile = keepPartitionMetadataFile)
+      keepPartitionMetadataFile = keepPartitionMetadataFile,
+      numRemainingSegments = numRemainingSegments)
 
     if (logDir.getName.endsWith(UnifiedLog.DeleteDirSuffix)) {
       addLogToBeDeleted(log)
@@ -307,6 +309,27 @@ class LogManager(logDirs: Seq[File],
     log
   }
 
+  // factory class for naming the log recovery threads used in metrics
+  class LogRecoveryThreadFactory(val dirPath: String) extends ThreadFactory {
+    val threadNum = new AtomicInteger(0)
+
+    override def newThread(runnable: Runnable): Thread = {
+      KafkaThread.nonDaemon(logRecoveryThreadName(dirPath, threadNum.getAndIncrement()), runnable)
+    }
+  }
+
+  // create a unique log recovery thread name for each log dir, in the format prefix-dirPath-threadNum, e.g. "log-recovery-/tmp/kafkaLogs-0"
+  private def logRecoveryThreadName(dirPath: String, threadNum: Int, prefix: String = "log-recovery"): String = s"$prefix-$dirPath-$threadNum"
+
+  /*
+   * Decrement the number of remaining logs for the given directory path
+   * @return the number of remaining logs after decrementing by 1
+   */
+  private[log] def decNumRemainingLogs(numRemainingLogs: ConcurrentMap[String, Int], path: String): Int = {
+    require(path != null, "path cannot be null to update remaining logs metric.")
+    numRemainingLogs.compute(path, (_, oldVal) => oldVal - 1)
+  }
+
   /**
    * Recover and load all logs in the given data directories
    */
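
decNumRemainingLogs relies on ConcurrentMap.compute, which applies the remapping
function atomically per key, so recovery threads finishing logs in the same directory
cannot lose a decrement. A small standalone sketch of the same idea (the directory path
and seed count are made up for illustration):

    import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}

    object RemainingLogsSketch {
      // Same shape as decNumRemainingLogs: compute() runs the remapping function
      // atomically for the given key.
      def decRemaining(numRemainingLogs: ConcurrentMap[String, Int], path: String): Int = {
        require(path != null, "path cannot be null")
        numRemainingLogs.compute(path, (_, oldVal) => oldVal - 1)
      }

      def main(args: Array[String]): Unit = {
        val remaining = new ConcurrentHashMap[String, Int]
        remaining.put("/tmp/kafka-logs-0", 3)                  // illustrative dir path and count
        println(decRemaining(remaining, "/tmp/kafka-logs-0"))  // 2
        println(decRemaining(remaining, "/tmp/kafka-logs-0"))  // 1
      }
    }

In loadLogs below, each directory's entry is seeded with logsToLoad.length before the
recovery jobs are submitted, so compute never runs against a missing key.
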
@@ -317,6 +340,10 @@ class LogManager(logDirs: Seq[File],
     val offlineDirs = mutable.Set.empty[(String, IOException)]
     val jobs = ArrayBuffer.empty[Seq[Future[_]]]
     var numTotalLogs = 0
+    // log dir path -> number of remaining logs map for remainingLogsToRecover metric
+    val numRemainingLogs: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]
+    // log recovery thread name -> number of remaining segments map for remainingSegmentsToRecover metric
+    val numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]
 
     def handleIOException(logDirAbsolutePath: String, e: IOException): Unit = {
       offlineDirs.add((logDirAbsolutePath, e))
@@ -328,7 +355,7 @@ class LogManager(logDirs: Seq[File],
       var hadCleanShutdown: Boolean = false
       try {
         val pool = Executors.newFixedThreadPool(numRecoveryThreadsPerDataDir,
-          KafkaThread.nonDaemon(s"log-recovery-$logDirAbsolutePath", _))
+          new LogRecoveryThreadFactory(logDirAbsolutePath))
         threadPools.append(pool)
 
         val cleanShutdownFile = new File(dir, LogLoader.CleanShutdownFile)
@@ -363,28 +390,32 @@ class LogManager(logDirs: Seq[File],
 
         val logsToLoad = Option(dir.listFiles).getOrElse(Array.empty).filter(logDir =>
           logDir.isDirectory && UnifiedLog.parseTopicPartitionName(logDir).topic != KafkaRaftServer.MetadataTopic)
-        val numLogsLoaded = new AtomicInteger(0)
         numTotalLogs += logsToLoad.length
+        numRemainingLogs.put(dir.getAbsolutePath, logsToLoad.length)
 
         val jobsForDir = logsToLoad.map { logDir =>
           val runnable: Runnable = () => {
+            debug(s"Loading log $logDir")
+            var log = None: Option[UnifiedLog]
+            val logLoadStartMs = time.hiResClockMs()
             try {
-              debug(s"Loading log $logDir")
-
-              val logLoadStartMs = time.hiResClockMs()
-              val log = loadLog(logDir, hadCleanShutdown, recoveryPoints, logStartOffsets,
-                defaultConfig, topicConfigOverrides)
-              val logLoadDurationMs = time.hiResClockMs() - logLoadStartMs
-              val currentNumLoaded = numLogsLoaded.incrementAndGet()
-
-              info(s"Completed load of $log with ${log.numberOfSegments} segments in ${logLoadDurationMs}ms " +
-                s"($currentNumLoaded/${logsToLoad.length} loaded in $logDirAbsolutePath)")
+              log = Some(loadLog(logDir, hadCleanShutdown, recoveryPoints, logStartOffsets,
+                defaultConfig, topicConfigOverrides, numRemainingSegments))
             } catch {
               case e: IOException =>
                 handleIOException(logDirAbsolutePath, e)
               case e: KafkaStorageException if e.getCause.isInstanceOf[IOException] =>
                 // KafkaStorageException might be thrown, ex: during writing LeaderEpochFileCache
                 // And while converting IOException to KafkaStorageException, we've already handled the exception. So we can ignore it here.
+            } finally {
+              val logLoadDurationMs = time.hiResClockMs() - logLoadStartMs
+              val remainingLogs = decNumRemainingLogs(numRemainingLogs, dir.getAbsolutePath)
+              val currentNumLoaded = logsToLoad.length - remainingLogs
+              log match {
+                case Some(loadedLog) => info(s"Completed load of $loadedLog with ${loadedLog.numberOfSegments} segments in ${logLoadDurationMs}ms " +
+                  s"($currentNumLoaded/${logsToLoad.length} completed in $logDirAbsolutePath)")
+                case None => info(s"Error while loading logs in $logDir in ${logLoadDurationMs}ms ($currentNumLoaded/${logsToLoad.length} completed in $logDirAbsolutePath)")
+              }
             }
           }
           runnable
@@ -398,6 +429,7 @@ class LogManager(logDirs: Seq[File],
     }
 
     try {
+      addLogRecoveryMetrics(numRemainingLogs, numRemainingSegments)
       for (dirJobs <- jobs) {
         dirJobs.foreach(_.get)
       }
@@ -410,12 +442,37 @@ class LogManager(logDirs: Seq[File],
         error(s"There was an error in one of the threads during logs loading: ${e.getCause}")
         throw e.getCause
     } finally {
+      removeLogRecoveryMetrics()
       threadPools.foreach(_.shutdown())
     }
 
     info(s"Loaded $numTotalLogs logs in ${time.hiResClockMs() - startMs}ms.")
   }
 
+  private[log] def addLogRecoveryMetrics(numRemainingLogs: ConcurrentMap[String, Int],
+                                         numRemainingSegments: ConcurrentMap[String, Int]): Unit = {
+    debug("Adding log recovery metrics")
+    for (dir <- logDirs) {
+      newGauge("remainingLogsToRecover", () => numRemainingLogs.get(dir.getAbsolutePath),
+        Map("dir" -> dir.getAbsolutePath))
+      for (i <- 0 until numRecoveryThreadsPerDataDir) {
+        val threadName = logRecoveryThreadName(dir.getAbsolutePath, i)
+        newGauge("remainingSegmentsToRecover", () => numRemainingSegments.get(threadName),
+          Map("dir" -> dir.getAbsolutePath, "threadNum" -> i.toString))
+      }
+    }
+  }
+
+  private[log] def removeLogRecoveryMetrics(): Unit = {
+    debug("Removing log recovery metrics")
+    for (dir <- logDirs) {
+      removeMetric("remainingLogsToRecover", Map("dir" -> dir.getAbsolutePath))
+      for (i <- 0 until numRecoveryThreadsPerDataDir) {
+        removeMetric("remainingSegmentsToRecover", Map("dir" -> dir.getAbsolutePath, "threadNum" -> i.toString))
+      }
+    }
+  }
+
   /**
    *  Start the background threads to flush logs and do log cleanup
    */
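
addLogRecoveryMetrics and removeLogRecoveryMetrics use KafkaMetricsGroup's
newGauge/removeMetric helpers, which also attach the dir and threadNum tags shown
above. The register-then-always-remove lifecycle can be sketched with the underlying
Yammer API directly; the metric coordinates, directory path, and count below are
illustrative, not the broker's actual registration code:

    import com.yammer.metrics.Metrics
    import com.yammer.metrics.core.{Gauge, MetricName}
    import java.util.concurrent.ConcurrentHashMap

    object RecoveryGaugeLifecycleSketch {
      def main(args: Array[String]): Unit = {
        val numRemainingLogs = new ConcurrentHashMap[String, Int]
        numRemainingLogs.put("/tmp/kafka-logs", 42)   // illustrative dir and count

        // Illustrative coordinates; the real gauges carry dir/threadNum tags in the MBean name.
        val name = new MetricName("kafka.log", "LogManager", "remainingLogsToRecover")

        // Register before recovery starts: the gauge simply reads the shared map.
        Metrics.defaultRegistry().newGauge(name, new Gauge[Int] {
          override def value(): Int = numRemainingLogs.getOrDefault("/tmp/kafka-logs", 0)
        })

        // ... log recovery would run here and drive the map down to 0 ...

        // Always remove afterwards, mirroring the finally block in loadLogs().
        Metrics.defaultRegistry().removeMetric(name)
      }
    }

Because the removal sits in the finally block of loadLogs, the gauges disappear even
when a recovery job throws.
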
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index ddd66eb160..c4a2300237 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -18,11 +18,11 @@
 package kafka.log
 
 import com.yammer.metrics.core.MetricName
+
 import java.io.{File, IOException}
 import java.nio.file.Files
 import java.util.Optional
-import java.util.concurrent.TimeUnit
-
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
 import kafka.common.{LongRef, OffsetsOutOfOrderException, UnexpectedAppendOffsetException}
 import kafka.log.AppendOrigin.RaftLeader
 import kafka.message.{BrokerCompressionCodec, CompressionCodec, NoCompressionCodec}
@@ -1803,7 +1803,8 @@ object UnifiedLog extends Logging {
             logDirFailureChannel: LogDirFailureChannel,
             lastShutdownClean: Boolean = true,
             topicId: Option[Uuid],
-            keepPartitionMetadataFile: Boolean): UnifiedLog = {
+            keepPartitionMetadataFile: Boolean,
+            numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]): UnifiedLog = {
     // create the log directory if it doesn't exist
     Files.createDirectories(dir.toPath)
     val topicPartition = UnifiedLog.parseTopicPartitionName(dir)
@@ -1828,7 +1829,8 @@ object UnifiedLog extends Logging {
       logStartOffset,
       recoveryPoint,
       leaderEpochCache,
-      producerStateManager
+      producerStateManager,
+      numRemainingSegments
     ).load()
     val localLog = new LocalLog(dir, config, segments, offsets.recoveryPoint,
       offsets.nextOffsetMetadata, scheduler, time, topicPartition, logDirFailureChannel)
diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
index 0d41a5073b..c6379ff3f3 100644
--- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
@@ -38,6 +38,7 @@ import org.mockito.ArgumentMatchers
 import org.mockito.ArgumentMatchers.{any, anyLong}
 import org.mockito.Mockito.{mock, reset, times, verify, when}
 
+import java.util.concurrent.ConcurrentMap
 import scala.annotation.nowarn
 import scala.collection.mutable.ListBuffer
 import scala.collection.{Iterable, Map, mutable}
@@ -117,7 +118,7 @@ class LogLoaderTest {
 
         override def loadLog(logDir: File, hadCleanShutdown: Boolean, recoveryPoints: Map[TopicPartition, Long],
                              logStartOffsets: Map[TopicPartition, Long], defaultConfig: LogConfig,
-                             topicConfigs: Map[String, LogConfig]): UnifiedLog = {
+                             topicConfigs: Map[String, LogConfig], numRemainingSegments: ConcurrentMap[String, Int]): UnifiedLog = {
           if (simulateError.hasError) {
             simulateError.errorType match {
               case ErrorTypes.KafkaStorageExceptionWithIOExceptionCause =>
diff --git a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
index 5353df6db3..1b2dd7809f 100755
--- a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
@@ -17,10 +17,10 @@
 
 package kafka.log
 
-import com.yammer.metrics.core.MetricName
+import com.yammer.metrics.core.{Gauge, MetricName}
 import kafka.server.checkpoints.OffsetCheckpointFile
 import kafka.server.metadata.{ConfigRepository, MockConfigRepository}
-import kafka.server.{FetchDataInfo, FetchLogEnd}
+import kafka.server.{BrokerTopicStats, FetchDataInfo, FetchLogEnd, LogDirFailureChannel}
 import kafka.utils._
 import org.apache.directory.api.util.FileUtils
 import org.apache.kafka.common.errors.OffsetOutOfRangeException
@@ -29,16 +29,17 @@ import org.apache.kafka.common.{KafkaException, TopicPartition}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.ArgumentMatchers.any
-import org.mockito.{ArgumentMatchers, Mockito}
-import org.mockito.Mockito.{doAnswer, mock, never, spy, times, verify}
+import org.mockito.{ArgumentCaptor, ArgumentMatchers, Mockito}
+import org.mockito.Mockito.{doAnswer, doNothing, mock, never, spy, times, verify}
+
 import java.io._
 import java.nio.file.Files
-import java.util.concurrent.Future
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Future}
 import java.util.{Collections, Properties}
-
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
 
-import scala.collection.mutable
+import scala.collection.{Map, mutable}
+import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 import scala.util.{Failure, Try}
 
@@ -421,12 +422,14 @@ class LogManagerTest {
   }
 
   private def createLogManager(logDirs: Seq[File] = Seq(this.logDir),
-                               configRepository: ConfigRepository = new MockConfigRepository): LogManager = {
+                               configRepository: ConfigRepository = new MockConfigRepository,
+                               recoveryThreadsPerDataDir: Int = 1): LogManager = {
     TestUtils.createLogManager(
       defaultConfig = logConfig,
       configRepository = configRepository,
       logDirs = logDirs,
-      time = this.time)
+      time = this.time,
+      recoveryThreadsPerDataDir = recoveryThreadsPerDataDir)
   }
 
   @Test
@@ -638,6 +641,205 @@ class LogManagerTest {
     assertTrue(logManager.partitionsInitializing.isEmpty)
   }
 
+  private def appendRecordsToLog(time: MockTime, parentLogDir: File, partitionId: Int, brokerTopicStats: BrokerTopicStats, expectedSegmentsPerLog: Int): Unit = {
+    def createRecord = TestUtils.singletonRecords(value = "test".getBytes, timestamp = time.milliseconds)
+    val tpFile = new File(parentLogDir, s"$name-$partitionId")
+    val segmentBytes = 1024
+
+    val log = LogTestUtils.createLog(tpFile, logConfig, brokerTopicStats, time.scheduler, time, 0, 0,
+      5 * 60 * 1000, 60 * 60 * 1000, LogManager.ProducerIdExpirationCheckIntervalMs)
+
+    assertTrue(expectedSegmentsPerLog > 0)
+    // calculate numMessages to append to logs. It'll create "expectedSegmentsPerLog" log segments with segment.bytes=1024
+    val numMessages = Math.floor(segmentBytes * expectedSegmentsPerLog / createRecord.sizeInBytes).asInstanceOf[Int]
+    try {
+      for (_ <- 0 until numMessages) {
+        log.appendAsLeader(createRecord, leaderEpoch = 0)
+      }
+
+      assertEquals(expectedSegmentsPerLog, log.numberOfSegments)
+    } finally {
+      log.close()
+    }
+  }
+
+  private def verifyRemainingLogsToRecoverMetric(spyLogManager: LogManager, expectedParams: Map[String, Int]): Unit = {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingLogsToRecover` metrics
+    val logMetrics: ArrayBuffer[Gauge[Int]] = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala
+      .filter { case (metric, _) => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingLogsToRecover" }
+      .map { case (_, gauge) => gauge }
+      .asInstanceOf[ArrayBuffer[Gauge[Int]]]
+
+    assertEquals(expectedParams.size, logMetrics.size)
+
+    val capturedPath: ArgumentCaptor[String] = ArgumentCaptor.forClass(classOf[String])
+
+    val expectedCallTimes = expectedParams.values.sum
+    verify(spyLogManager, times(expectedCallTimes)).decNumRemainingLogs(any[ConcurrentMap[String, Int]], capturedPath.capture());
+
+    val paths = capturedPath.getAllValues
+    expectedParams.foreach {
+      case (path, totalLogs) =>
+        // make sure decNumRemainingLogs is called "totalLogs" times for each path, i.e. the count is decremented to 0 in the end
+        assertEquals(totalLogs, Collections.frequency(paths, path))
+    }
+
+    // expect the final value to be 0
+    logMetrics.foreach { gauge => assertEquals(0, gauge.value()) }
+  }
+
+  private def verifyRemainingSegmentsToRecoverMetric(spyLogManager: LogManager,
+                                                     logDirs: Seq[File],
+                                                     recoveryThreadsPerDataDir: Int,
+                                                     mockMap: ConcurrentHashMap[String, Int],
+                                                     expectedParams: Map[String, Int]): Unit = {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingSegmentsToRecover` metrics
+    val logSegmentMetrics: ArrayBuffer[Gauge[Int]] = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala
+          .filter { case (metric, _) => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingSegmentsToRecover" }
+          .map { case (_, gauge) => gauge }
+          .asInstanceOf[ArrayBuffer[Gauge[Int]]]
+
+    // expect each log dir to have 1 metric for each recovery thread
+    assertEquals(recoveryThreadsPerDataDir * logDirs.size, logSegmentMetrics.size)
+
+    val capturedThreadName: ArgumentCaptor[String] = ArgumentCaptor.forClass(classOf[String])
+    val capturedNumRemainingSegments: ArgumentCaptor[Int] = ArgumentCaptor.forClass(classOf[Int])
+
+    // Since numRemainingSegments is updated from totalSegments down to 0 for each thread, we need to add 1 here
+    val expectedCallTimes = expectedParams.values.map( num => num + 1 ).sum
+    verify(mockMap, times(expectedCallTimes)).put(capturedThreadName.capture(), capturedNumRemainingSegments.capture());
+
+    // expect the final value to be 0
+    logSegmentMetrics.foreach { gauge => assertEquals(0, gauge.value()) }
+
+    val threadNames = capturedThreadName.getAllValues
+    val numRemainingSegments = capturedNumRemainingSegments.getAllValues
+
+    expectedParams.foreach {
+      case (threadName, totalSegments) =>
+        // make sure numRemainingSegments is updated from totalSegments down to 0, in order, for each thread
+        var expectedCurRemainingSegments = totalSegments + 1
+        for (i <- 0 until threadNames.size) {
+          if (threadNames.get(i).contains(threadName)) {
+            expectedCurRemainingSegments -= 1
+            assertEquals(expectedCurRemainingSegments, numRemainingSegments.get(i))
+          }
+        }
+        assertEquals(0, expectedCurRemainingSegments)
+    }
+  }
+
+  private def verifyLogRecoverMetricsRemoved(spyLogManager: LogManager): Unit = {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingLogsToRecover` metrics
+    def logMetrics: mutable.Set[MetricName] = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala
+      .filter { metric => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingLogsToRecover" }
+
+    assertTrue(logMetrics.isEmpty)
+
+    // get all `remainingSegmentsToRecover` metrics
+    val logSegmentMetrics: mutable.Set[MetricName] = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala
+      .filter { metric => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingSegmentsToRecover" }
+
+    assertTrue(logSegmentMetrics.isEmpty)
+  }
+
+  @Test
+  def testLogRecoveryMetrics(): Unit = {
+    logManager.shutdown()
+    val logDir1 = TestUtils.tempDir()
+    val logDir2 = TestUtils.tempDir()
+    val logDirs = Seq(logDir1, logDir2)
+    val recoveryThreadsPerDataDir = 2
+    // create logManager with expected recovery thread number
+    logManager = createLogManager(logDirs, recoveryThreadsPerDataDir = recoveryThreadsPerDataDir)
+    val spyLogManager = spy(logManager)
+
+    assertEquals(2, spyLogManager.liveLogDirs.size)
+
+    val mockTime = new MockTime()
+    val mockMap = mock(classOf[ConcurrentHashMap[String, Int]])
+    val mockBrokerTopicStats = mock(classOf[BrokerTopicStats])
+    val expectedSegmentsPerLog = 2
+
+    // create log segments for log recovery in each log dir
+    appendRecordsToLog(mockTime, logDir1, 0, mockBrokerTopicStats, expectedSegmentsPerLog)
+    appendRecordsToLog(mockTime, logDir2, 1, mockBrokerTopicStats, expectedSegmentsPerLog)
+
+    // intercept the loadLog method to pass the expected parameters for log recovery
+    doAnswer { invocation =>
+      val dir: File = invocation.getArgument(0)
+      val topicConfigOverrides: mutable.Map[String, LogConfig] = invocation.getArgument(5)
+
+      val topicPartition = UnifiedLog.parseTopicPartitionName(dir)
+      val config = topicConfigOverrides.getOrElse(topicPartition.topic, logConfig)
+
+      UnifiedLog(
+        dir = dir,
+        config = config,
+        logStartOffset = 0,
+        recoveryPoint = 0,
+        maxTransactionTimeoutMs = 5 * 60 * 1000,
+        maxProducerIdExpirationMs = 5 * 60 * 1000,
+        producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
+        scheduler = mockTime.scheduler,
+        time = mockTime,
+        brokerTopicStats = mockBrokerTopicStats,
+        logDirFailureChannel = mock(classOf[LogDirFailureChannel]),
+        // not clean shutdown
+        lastShutdownClean = false,
+        topicId = None,
+        keepPartitionMetadataFile = false,
+        // pass mock map for verification later
+        numRemainingSegments = mockMap)
+
+    }.when(spyLogManager).loadLog(any[File], any[Boolean], any[Map[TopicPartition, Long]], any[Map[TopicPartition, Long]],
+      any[LogConfig], any[Map[String, LogConfig]], any[ConcurrentMap[String, Int]])
+
+    // do nothing for removeLogRecoveryMetrics for metrics verification
+    doNothing().when(spyLogManager).removeLogRecoveryMetrics()
+
+    // start the logManager to do log recovery
+    spyLogManager.startup(Set.empty)
+
+    // make sure log recovery metrics are added and removed
+    verify(spyLogManager, times(1)).addLogRecoveryMetrics(any[ConcurrentMap[String, Int]], any[ConcurrentMap[String, Int]])
+    verify(spyLogManager, times(1)).removeLogRecoveryMetrics()
+
+    // expected 1 log in each log dir since we created 2 partitions with 2 log dirs
+    val expectedRemainingLogsParams = Map[String, Int](logDir1.getAbsolutePath -> 1, logDir2.getAbsolutePath -> 1)
+    verifyRemainingLogsToRecoverMetric(spyLogManager, expectedRemainingLogsParams)
+
+    val expectedRemainingSegmentsParams = Map[String, Int](
+      logDir1.getAbsolutePath -> expectedSegmentsPerLog, logDir2.getAbsolutePath -> expectedSegmentsPerLog)
+    verifyRemainingSegmentsToRecoverMetric(spyLogManager, logDirs, recoveryThreadsPerDataDir, mockMap, expectedRemainingSegmentsParams)
+  }
+
+  @Test
+  def testLogRecoveryMetricsShouldBeRemovedAfterLogRecovered(): Unit = {
+    logManager.shutdown()
+    val logDir1 = TestUtils.tempDir()
+    val logDir2 = TestUtils.tempDir()
+    val logDirs = Seq(logDir1, logDir2)
+    val recoveryThreadsPerDataDir = 2
+    // create logManager with expected recovery thread number
+    logManager = createLogManager(logDirs, recoveryThreadsPerDataDir = recoveryThreadsPerDataDir)
+    val spyLogManager = spy(logManager)
+
+    assertEquals(2, spyLogManager.liveLogDirs.size)
+
+    // start the logManager to do log recovery
+    spyLogManager.startup(Set.empty)
+
+    // make sure log recovery metrics are added and removed once
+    verify(spyLogManager, times(1)).addLogRecoveryMetrics(any[ConcurrentMap[String, Int]], any[ConcurrentMap[String, Int]])
+    verify(spyLogManager, times(1)).removeLogRecoveryMetrics()
+
+    verifyLogRecoverMetricsRemoved(spyLogManager)
+  }
+
   @Test
   def testMetricsExistWhenLogIsRecreatedBeforeDeletion(): Unit = {
     val topicName = "metric-test"
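
The verification helpers above lean on Mockito's ArgumentCaptor to check not just how
many times decNumRemainingLogs and mockMap.put were invoked, but the exact order of the
captured values. A minimal sketch of that capture-then-assert-order pattern, using a
hypothetical Sink trait instead of the real collaborators:

    import org.mockito.ArgumentCaptor
    import org.mockito.Mockito.{mock, times, verify}

    object CaptorSketch {
      trait Sink { def put(key: String, value: Int): Unit }

      def main(args: Array[String]): Unit = {
        val sink = mock(classOf[Sink])
        // Simulate one recovery thread counting 2 segments down to 0.
        sink.put("t0", 2); sink.put("t0", 1); sink.put("t0", 0)

        val keys: ArgumentCaptor[String] = ArgumentCaptor.forClass(classOf[String])
        val values: ArgumentCaptor[Int] = ArgumentCaptor.forClass(classOf[Int])
        verify(sink, times(3)).put(keys.capture(), values.capture())

        // getAllValues preserves invocation order, which the countdown check relies on.
        println(keys.getAllValues)   // [t0, t0, t0]
        println(values.getAllValues) // [2, 1, 0]
      }
    }
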
diff --git a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
index f6b58d78ce..50af76f556 100644
--- a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
@@ -28,6 +28,7 @@ import org.apache.kafka.common.utils.{Time, Utils}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse}
 
 import java.nio.file.Files
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import scala.collection.Iterable
 import scala.jdk.CollectionConverters._
 
@@ -83,7 +84,8 @@ object LogTestUtils {
                 producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs,
                 lastShutdownClean: Boolean = true,
                 topicId: Option[Uuid] = None,
-                keepPartitionMetadataFile: Boolean = true): UnifiedLog = {
+                keepPartitionMetadataFile: Boolean = true,
+                numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]): UnifiedLog = {
     UnifiedLog(
       dir = dir,
       config = config,
@@ -98,7 +100,8 @@ object LogTestUtils {
       logDirFailureChannel = new LogDirFailureChannel(10),
       lastShutdownClean = lastShutdownClean,
       topicId = topicId,
-      keepPartitionMetadataFile = keepPartitionMetadataFile
+      keepPartitionMetadataFile = keepPartitionMetadataFile,
+      numRemainingSegments = numRemainingSegments
     )
   }
 
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index 5a8d43795a..c49a7bdde0 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -1281,13 +1281,14 @@ object TestUtils extends Logging {
                        configRepository: ConfigRepository = new MockConfigRepository,
                        cleanerConfig: CleanerConfig = CleanerConfig(enableCleaner = false),
                        time: MockTime = new MockTime(),
-                       interBrokerProtocolVersion: MetadataVersion = MetadataVersion.latest): LogManager = {
+                       interBrokerProtocolVersion: MetadataVersion = MetadataVersion.latest,
+                       recoveryThreadsPerDataDir: Int = 4): LogManager = {
     new LogManager(logDirs = logDirs.map(_.getAbsoluteFile),
                    initialOfflineDirs = Array.empty[File],
                    configRepository = configRepository,
                    initialDefaultConfig = defaultConfig,
                    cleanerConfig = cleanerConfig,
-                   recoveryThreadsPerDataDir = 4,
+                   recoveryThreadsPerDataDir = recoveryThreadsPerDataDir,
                    flushCheckMs = 1000L,
                    flushRecoveryOffsetCheckpointMs = 10000L,
                    flushStartOffsetCheckpointMs = 10000L,
diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
index 3bf65afc22..99fb814327 100644
--- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
@@ -108,7 +108,7 @@ public class CheckpointBench {
         this.logManager = TestUtils.createLogManager(JavaConverters.asScalaBuffer(files),
                 LogConfig.apply(), new MockConfigRepository(), CleanerConfig.apply(1, 4 * 1024 * 1024L, 0.9d,
                         1024 * 1024, 32 * 1024 * 1024,
-                        Double.MAX_VALUE, 15 * 1000, true, "MD5"), time, MetadataVersion.latest());
+                        Double.MAX_VALUE, 15 * 1000, true, "MD5"), time, MetadataVersion.latest(), 4);
         scheduler.startup();
         final BrokerTopicStats brokerTopicStats = new BrokerTopicStats();
         final MetadataCache metadataCache =