Posted to commits@kafka.apache.org by sh...@apache.org on 2022/07/22 03:04:53 UTC

[kafka] branch 3.3 updated (327439504b -> 158fc5f026)

This is an automated email from the ASF dual-hosted git repository.

showuon pushed a change to branch 3.3
in repository https://gitbox.apache.org/repos/asf/kafka.git


    from 327439504b KAFKA-14076: Fix issues with KafkaStreams.CloseOptions (#12408)
     new 9aabf49882 KAFKA-14024: Consumer keeps Commit offset in onJoinPrepare in Cooperative rebalance (#12349)
     new 95dc3fdedc Merge branch '3.3' of https://github.com/apache/kafka into 3.3
     new 158fc5f026 KAFKA-13919: expose log recovery metrics (#12347)

The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../consumer/internals/AbstractCoordinator.java    |   5 +-
 .../consumer/internals/ConsumerCoordinator.java    |  73 +++++--
 .../internals/AbstractCoordinatorTest.java         |   2 +-
 .../internals/ConsumerCoordinatorTest.java         | 120 ++++++++---
 .../runtime/distributed/WorkerCoordinator.java     |   2 +-
 core/src/main/scala/kafka/log/LogLoader.scala      |  26 ++-
 core/src/main/scala/kafka/log/LogManager.scala     |  85 ++++++--
 core/src/main/scala/kafka/log/UnifiedLog.scala     |  10 +-
 .../kafka/api/AbstractConsumerTest.scala           |  11 +-
 .../kafka/api/PlaintextConsumerTest.scala          |  87 ++++++++
 .../test/scala/unit/kafka/log/LogLoaderTest.scala  |   3 +-
 .../test/scala/unit/kafka/log/LogManagerTest.scala | 220 ++++++++++++++++++++-
 .../test/scala/unit/kafka/log/LogTestUtils.scala   |   7 +-
 .../test/scala/unit/kafka/utils/TestUtils.scala    |   5 +-
 .../apache/kafka/jmh/server/CheckpointBench.java   |   2 +-
 15 files changed, 570 insertions(+), 88 deletions(-)


[kafka] 01/03: KAFKA-14024: Consumer keeps Commit offset in onJoinPrepare in Cooperative rebalance (#12349)

Posted by sh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

showuon pushed a commit to branch 3.3
in repository https://gitbox.apache.org/repos/asf/kafka.git

commit 9aabf498824907816f50a914d98113dcbd73e4d1
Author: Shawn <wa...@bytedance.com>
AuthorDate: Wed Jul 20 10:03:43 2022 +0800

    KAFKA-14024: Consumer keeps Commit offset in onJoinPrepare in Cooperative rebalance (#12349)
    
    In KAFKA-13310, we tried to fix an issue where consumer#poll(duration) could return later than the provided duration. That happened because, when a rebalance was needed, we committed the current offsets synchronously before rebalancing, and if the offset commit took too long, consumer#poll would exceed the provided duration. To fix that, we replaced the synchronous commit with an asynchronous commit before the rebalance (i.e. in onJoinPrepare).
    
    However, in this ticket we found that the async commit keeps sending a new commit request on each Consumer#poll, because the offset commit never completes in time. The impact is that the existing consumer is kicked out of the group after the rebalance timeout without ever rejoining. That is, suppose we have consumer A in group G and consumer B joins the group; after the rebalance, only consumer B remains in the group.
    
    In addition, another bug was found while fixing this one. Before KAFKA-13310, we committed offsets synchronously bounded by the rebalanceTimeout, retrying on retriable errors until the timeout. After KAFKA-13310, we thought we still retried, but the retry actually happens after the partitions have been revoked. That is, even if the retried offset commit succeeds, some partitions' offsets remain uncommitted, and after the rebalance other consumers consume overlapping records.
    
    Reviewers: RivenSun <ri...@zoom.us>, Luke Chen <sh...@gmail.com>
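
A minimal, self-contained sketch of the bounded-retry flow described above. The class and helper names (JoinPrepareSketch, sendAsyncCommit, isRetriable) are illustrative stand-ins, not Kafka APIs; the actual change is in ConsumerCoordinator#onJoinPrepare in the diff below.

    import java.util.concurrent.CompletableFuture;

    public class JoinPrepareSketch {
        private CompletableFuture<Void> pendingCommit;   // survives across poll() calls
        private long joinPrepareDeadlineMs = -1;         // bounded by rebalance.timeout.ms

        // Returns true when it is safe to revoke partitions and continue joining the group.
        boolean onJoinPrepare(long nowMs, long rebalanceTimeoutMs) {
            if (joinPrepareDeadlineMs < 0)
                joinPrepareDeadlineMs = nowMs + rebalanceTimeoutMs;  // start the prepare timer once
            if (pendingCommit == null)
                pendingCommit = sendAsyncCommit();                   // at most one in-flight commit

            boolean timedOut = nowMs >= joinPrepareDeadlineMs;
            if (!pendingCommit.isDone() && !timedOut)
                return false;                                        // keep waiting on the same request
            if (!timedOut && pendingCommit.isCompletedExceptionally() && isRetriable(pendingCommit)) {
                pendingCommit = null;                                // retriable failure: resend on next poll
                return false;
            }
            // success, non-retriable failure, or rebalance timeout: stop waiting and join the group
            pendingCommit = null;
            joinPrepareDeadlineMs = -1;
            return true;
        }

        private CompletableFuture<Void> sendAsyncCommit() { return new CompletableFuture<>(); }
        private boolean isRetriable(CompletableFuture<Void> f) { return true; }  // placeholder
    }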
---
 .../consumer/internals/AbstractCoordinator.java    |   5 +-
 .../consumer/internals/ConsumerCoordinator.java    |  73 ++++++++++---
 .../internals/AbstractCoordinatorTest.java         |   2 +-
 .../internals/ConsumerCoordinatorTest.java         | 120 ++++++++++++++++-----
 .../runtime/distributed/WorkerCoordinator.java     |   2 +-
 .../kafka/api/AbstractConsumerTest.scala           |  11 +-
 .../kafka/api/PlaintextConsumerTest.scala          |  87 +++++++++++++++
 7 files changed, 251 insertions(+), 49 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
index c9ad797ebe..d2ece9efc5 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
@@ -187,11 +187,12 @@ public abstract class AbstractCoordinator implements Closeable {
     /**
      * Invoked prior to each group join or rejoin. This is typically used to perform any
      * cleanup from the previous generation (such as committing offsets for the consumer)
+     * @param timer Timer bounding how long this method can block
      * @param generation The previous generation or -1 if there was none
      * @param memberId The identifier of this member in the previous group or "" if there was none
      * @return true If onJoinPrepare async commit succeeded, false otherwise
      */
-    protected abstract boolean onJoinPrepare(int generation, String memberId);
+    protected abstract boolean onJoinPrepare(Timer timer, int generation, String memberId);
 
     /**
      * Invoked when the leader is elected. This is used by the leader to perform the assignment
@@ -426,7 +427,7 @@ public abstract class AbstractCoordinator implements Closeable {
                 // exception, in which case upon retry we should not retry onJoinPrepare either.
                 needsJoinPrepare = false;
                 // return false when onJoinPrepare is waiting for committing offset
-                if (!onJoinPrepare(generation.generationId, generation.memberId)) {
+                if (!onJoinPrepare(timer, generation.generationId, generation.memberId)) {
                     needsJoinPrepare = true;
                     //should not initiateJoinGroup if needsJoinPrepare still is true
                     return false;
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index b853ff99e8..9838e7dc8f 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -141,6 +141,12 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     }
 
     private final RebalanceProtocol protocol;
+    // pending commit offset request in onJoinPrepare
+    private RequestFuture<Void> autoCommitOffsetRequestFuture = null;
+    // a timer for join prepare to know when to stop.
+    // it'll set to rebalance timeout so that the member can join the group successfully
+    // even though offset commit failed.
+    private Timer joinPrepareTimer = null;
 
     /**
      * Initialize the coordination manager.
@@ -740,24 +746,58 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     }
 
     @Override
-    protected boolean onJoinPrepare(int generation, String memberId) {
+    protected boolean onJoinPrepare(Timer timer, int generation, String memberId) {
         log.debug("Executing onJoinPrepare with generation {} and memberId {}", generation, memberId);
-        boolean onJoinPrepareAsyncCommitCompleted = false;
+        if (joinPrepareTimer == null) {
+            // We should complete onJoinPrepare before rebalanceTimeout,
+            // and continue to join group to avoid member got kicked out from group
+            joinPrepareTimer = time.timer(rebalanceConfig.rebalanceTimeoutMs);
+        } else {
+            joinPrepareTimer.update();
+        }
+
         // async commit offsets prior to rebalance if auto-commit enabled
-        RequestFuture<Void> future = maybeAutoCommitOffsetsAsync();
-        // return true when
-        // 1. future is null, which means no commit request sent, so it is still considered completed
-        // 2. offset commit completed
-        // 3. offset commit failed with non-retriable exception
-        if (future == null)
-            onJoinPrepareAsyncCommitCompleted = true;
-        else if (future.succeeded())
-            onJoinPrepareAsyncCommitCompleted = true;
-        else if (future.failed() && !future.isRetriable()) {
-            log.error("Asynchronous auto-commit of offsets failed: {}", future.exception().getMessage());
-            onJoinPrepareAsyncCommitCompleted = true;
+        // and there is no in-flight offset commit request
+        if (autoCommitEnabled && autoCommitOffsetRequestFuture == null) {
+            autoCommitOffsetRequestFuture = maybeAutoCommitOffsetsAsync();
         }
 
+        // wait for commit offset response before timer expired
+        if (autoCommitOffsetRequestFuture != null) {
+            Timer pollTimer = timer.remainingMs() < joinPrepareTimer.remainingMs() ?
+                    timer : joinPrepareTimer;
+            client.poll(autoCommitOffsetRequestFuture, pollTimer);
+            joinPrepareTimer.update();
+
+            // Keep retrying/waiting the offset commit when:
+            // 1. offset commit haven't done (and joinPrepareTimer not expired)
+            // 2. failed with retryable exception (and joinPrepareTimer not expired)
+            // Otherwise, continue to revoke partitions, ex:
+            // 1. if joinPrepareTime has expired
+            // 2. if offset commit failed with no-retryable exception
+            // 3. if offset commit success
+            boolean onJoinPrepareAsyncCommitCompleted = true;
+            if (joinPrepareTimer.isExpired()) {
+                log.error("Asynchronous auto-commit of offsets failed: joinPrepare timeout. Will continue to join group");
+            } else if (!autoCommitOffsetRequestFuture.isDone()) {
+                onJoinPrepareAsyncCommitCompleted = false;
+            } else if (autoCommitOffsetRequestFuture.failed() && autoCommitOffsetRequestFuture.isRetriable()) {
+                log.debug("Asynchronous auto-commit of offsets failed with retryable error: {}. Will retry it.",
+                        autoCommitOffsetRequestFuture.exception().getMessage());
+                onJoinPrepareAsyncCommitCompleted = false;
+            } else if (autoCommitOffsetRequestFuture.failed() && !autoCommitOffsetRequestFuture.isRetriable()) {
+                log.error("Asynchronous auto-commit of offsets failed: {}. Will continue to join group.",
+                        autoCommitOffsetRequestFuture.exception().getMessage());
+            }
+            if (autoCommitOffsetRequestFuture.isDone()) {
+                autoCommitOffsetRequestFuture = null;
+            }
+            if (!onJoinPrepareAsyncCommitCompleted) {
+                pollTimer.sleep(Math.min(pollTimer.remainingMs(), rebalanceConfig.retryBackoffMs));
+                timer.update();
+                return false;
+            }
+        }
 
         // the generation / member-id can possibly be reset by the heartbeat thread
         // upon getting errors or heartbeat timeouts; in this case whatever is previously
@@ -809,11 +849,14 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
 
         isLeader = false;
         subscriptions.resetGroupSubscription();
+        joinPrepareTimer = null;
+        autoCommitOffsetRequestFuture = null;
+        timer.update();
 
         if (exception != null) {
             throw new KafkaException("User rebalance callback throws an error", exception);
         }
-        return onJoinPrepareAsyncCommitCompleted;
+        return true;
     }
 
     @Override
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
index ddbebb6dde..cbc4e7495e 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
@@ -1726,7 +1726,7 @@ public class AbstractCoordinatorTest {
         }
 
         @Override
-        protected boolean onJoinPrepare(int generation, String memberId) {
+        protected boolean onJoinPrepare(Timer timer, int generation, String memberId) {
             onJoinPrepareInvokes++;
             return true;
         }
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
index db483c6c0f..d948990d69 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
@@ -78,6 +78,7 @@ import org.apache.kafka.common.requests.SyncGroupResponse;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.SystemTime;
+import org.apache.kafka.common.utils.Timer;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.test.TestUtils;
 import org.junit.jupiter.api.AfterEach;
@@ -1301,9 +1302,71 @@ public abstract class ConsumerCoordinatorTest {
         }
     }
 
+    @Test
+    public void testOnJoinPrepareWithOffsetCommitShouldSuccessAfterRetry() {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty(), false)) {
+            int generationId = 42;
+            String memberId = "consumer-42";
+
+            Timer pollTimer = time.timer(100L);
+            client.prepareResponse(offsetCommitResponse(singletonMap(t1p, Errors.UNKNOWN_TOPIC_OR_PARTITION)));
+            boolean res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertFalse(res);
+
+            pollTimer = time.timer(100L);
+            client.prepareResponse(offsetCommitResponse(singletonMap(t1p, Errors.NONE)));
+            res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertTrue(res);
+
+            assertFalse(client.hasPendingResponses());
+            assertFalse(client.hasInFlightRequests());
+            assertFalse(coordinator.coordinatorUnknown());
+        }
+    }
+
+    @Test
+    public void testOnJoinPrepareWithOffsetCommitShouldKeepJoinAfterNonRetryableException() {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty(), false)) {
+            int generationId = 42;
+            String memberId = "consumer-42";
+
+            Timer pollTimer = time.timer(100L);
+            client.prepareResponse(offsetCommitResponse(singletonMap(t1p, Errors.UNKNOWN_MEMBER_ID)));
+            boolean res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertTrue(res);
+
+            assertFalse(client.hasPendingResponses());
+            assertFalse(client.hasInFlightRequests());
+            assertFalse(coordinator.coordinatorUnknown());
+        }
+    }
+
+    @Test
+    public void testOnJoinPrepareWithOffsetCommitShouldKeepJoinAfterRebalanceTimeout() {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty(), false)) {
+            int generationId = 42;
+            String memberId = "consumer-42";
+
+            Timer pollTimer = time.timer(100L);
+            time.sleep(150);
+            boolean res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertFalse(res);
+
+            pollTimer = time.timer(100L);
+            time.sleep(rebalanceTimeoutMs);
+            client.respond(offsetCommitResponse(singletonMap(t1p, Errors.UNKNOWN_TOPIC_OR_PARTITION)));
+            res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertTrue(res);
+
+            assertFalse(client.hasPendingResponses());
+            assertFalse(client.hasInFlightRequests());
+            assertFalse(coordinator.coordinatorUnknown());
+        }
+    }
+
     @Test
     public void testJoinPrepareWithDisableAutoCommit() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
@@ -1311,7 +1374,7 @@ public abstract class ConsumerCoordinatorTest {
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), generationId, memberId);
 
             assertTrue(res);
             assertTrue(client.hasPendingResponses());
@@ -1322,14 +1385,14 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testJoinPrepareAndCommitCompleted() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), generationId, memberId);
             coordinator.invokeCompletedOffsetCommitCallbacks();
 
             assertTrue(res);
@@ -1341,7 +1404,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testJoinPrepareAndCommitWithCoordinatorNotAvailable() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.COORDINATOR_NOT_AVAILABLE);
@@ -1349,7 +1412,7 @@ public abstract class ConsumerCoordinatorTest {
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), generationId, memberId);
             coordinator.invokeCompletedOffsetCommitCallbacks();
 
             assertFalse(res);
@@ -1361,7 +1424,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testJoinPrepareAndCommitWithUnknownMemberId() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.UNKNOWN_MEMBER_ID);
@@ -1369,7 +1432,7 @@ public abstract class ConsumerCoordinatorTest {
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), generationId, memberId);
             coordinator.invokeCompletedOffsetCommitCallbacks();
 
             assertTrue(res);
@@ -3080,21 +3143,21 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseDynamicAssignment() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty(), true)) {
             gracefulCloseTest(coordinator, true);
         }
     }
 
     @Test
     public void testCloseManualAssignment() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(false, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(false, true, Optional.empty(), true)) {
             gracefulCloseTest(coordinator, false);
         }
     }
 
     @Test
     public void testCloseCoordinatorNotKnownManualAssignment() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(false, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(false, true, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 1000, 1000, 1000);
@@ -3103,7 +3166,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseCoordinatorNotKnownNoCommits() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR);
             closeVerifyTimeout(coordinator, 1000, 0, 0);
         }
@@ -3111,7 +3174,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseCoordinatorNotKnownWithCommits() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 1000, 1000, 1000);
@@ -3120,7 +3183,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseCoordinatorUnavailableNoCommits() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, Errors.COORDINATOR_NOT_AVAILABLE);
             closeVerifyTimeout(coordinator, 1000, 0, 0);
         }
@@ -3128,7 +3191,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseTimeoutCoordinatorUnavailableForCommit() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             makeCoordinatorUnknown(coordinator, Errors.COORDINATOR_NOT_AVAILABLE);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 1000, 1000, 1000);
@@ -3137,7 +3200,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseMaxWaitCoordinatorUnavailableForCommit() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             makeCoordinatorUnknown(coordinator, Errors.COORDINATOR_NOT_AVAILABLE);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, requestTimeoutMs);
@@ -3146,7 +3209,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseNoResponseForCommit() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, requestTimeoutMs);
         }
@@ -3154,14 +3217,14 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseNoResponseForLeaveGroup() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty(), true)) {
             closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, requestTimeoutMs);
         }
     }
 
     @Test
     public void testCloseNoWait() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 0, 0, 0);
         }
@@ -3169,7 +3232,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testHeartbeatThreadClose() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             coordinator.ensureActiveGroup();
             time.sleep(heartbeatIntervalMs + 100);
             Thread.yield(); // Give heartbeat thread a chance to attempt heartbeat
@@ -3236,7 +3299,7 @@ public abstract class ConsumerCoordinatorTest {
         assertEquals(JoinGroupRequest.UNKNOWN_MEMBER_ID, groupMetadata.memberId());
         assertFalse(groupMetadata.groupInstanceId().isPresent());
 
-        try (final ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (final ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             coordinator.ensureActiveGroup();
 
             final ConsumerGroupMetadata joinedGroupMetadata = coordinator.groupMetadata();
@@ -3272,7 +3335,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testPrepareJoinAndRejoinAfterFailedRebalance() {
         final List<TopicPartition> partitions = singletonList(t1p);
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.REBALANCE_IN_PROGRESS);
@@ -3292,7 +3355,7 @@ public abstract class ConsumerCoordinatorTest {
             MockTime time = new MockTime(1);
 
             // onJoinPrepare will be executed and onJoinComplete will not.
-            boolean res = coordinator.joinGroupIfNeeded(time.timer(2));
+            boolean res = coordinator.joinGroupIfNeeded(time.timer(100));
 
             assertFalse(res);
             assertFalse(client.hasPendingResponses());
@@ -3337,7 +3400,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void shouldLoseAllOwnedPartitionsBeforeRejoiningAfterDroppingOutOfTheGroup() {
         final List<TopicPartition> partitions = singletonList(t1p);
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             final SystemTime realTime = new SystemTime();
             coordinator.ensureActiveGroup();
 
@@ -3370,7 +3433,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void shouldLoseAllOwnedPartitionsBeforeRejoiningAfterResettingGenerationId() {
         final List<TopicPartition> partitions = singletonList(t1p);
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             final SystemTime realTime = new SystemTime();
             coordinator.ensureActiveGroup();
 
@@ -3468,7 +3531,8 @@ public abstract class ConsumerCoordinatorTest {
 
     private ConsumerCoordinator prepareCoordinatorForCloseTest(final boolean useGroupManagement,
                                                                final boolean autoCommit,
-                                                               final Optional<String> groupInstanceId) {
+                                                               final Optional<String> groupInstanceId,
+                                                               final boolean shouldPoll) {
         rebalanceConfig = buildRebalanceConfig(groupInstanceId);
         ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig,
                                                            new Metrics(),
@@ -3487,7 +3551,9 @@ public abstract class ConsumerCoordinatorTest {
         }
 
         subscriptions.seek(t1p, 100);
-        coordinator.poll(time.timer(Long.MAX_VALUE));
+        if (shouldPoll) {
+            coordinator.poll(time.timer(Long.MAX_VALUE));
+        }
 
         return coordinator;
     }
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
index 1ab3543361..ced67427a3 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
@@ -225,7 +225,7 @@ public class WorkerCoordinator extends AbstractCoordinator implements Closeable
     }
 
     @Override
-    protected boolean onJoinPrepare(int generation, String memberId) {
+    protected boolean onJoinPrepare(Timer timer, int generation, String memberId) {
         log.info("Rebalance started");
         leaderState(null);
         final ExtendedAssignment localAssignmentSnapshot = assignmentSnapshot;
diff --git a/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala b/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala
index 56bc47c79e..23b56b8e91 100644
--- a/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala
@@ -342,15 +342,16 @@ abstract class AbstractConsumerTest extends BaseRequestTest {
 
   protected class ConsumerAssignmentPoller(consumer: Consumer[Array[Byte], Array[Byte]],
                                            topicsToSubscribe: List[String],
-                                           partitionsToAssign: Set[TopicPartition])
+                                           partitionsToAssign: Set[TopicPartition],
+                                           userRebalanceListener: ConsumerRebalanceListener)
     extends ShutdownableThread("daemon-consumer-assignment", false) {
 
     def this(consumer: Consumer[Array[Byte], Array[Byte]], topicsToSubscribe: List[String]) = {
-      this(consumer, topicsToSubscribe, Set.empty[TopicPartition])
+      this(consumer, topicsToSubscribe, Set.empty[TopicPartition], null)
     }
 
     def this(consumer: Consumer[Array[Byte], Array[Byte]], partitionsToAssign: Set[TopicPartition]) = {
-      this(consumer, List.empty[String], partitionsToAssign)
+      this(consumer, List.empty[String], partitionsToAssign, null)
     }
 
     @volatile var thrownException: Option[Throwable] = None
@@ -363,10 +364,14 @@ abstract class AbstractConsumerTest extends BaseRequestTest {
     val rebalanceListener: ConsumerRebalanceListener = new ConsumerRebalanceListener {
       override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]) = {
         partitionAssignment ++= partitions.toArray(new Array[TopicPartition](0))
+        if (userRebalanceListener != null)
+          userRebalanceListener.onPartitionsAssigned(partitions)
       }
 
       override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]) = {
         partitionAssignment --= partitions.toArray(new Array[TopicPartition](0))
+        if (userRebalanceListener != null)
+          userRebalanceListener.onPartitionsRevoked(partitions)
       }
     }
 
diff --git a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
index 4ede241b0c..5dc7c2ada1 100644
--- a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
@@ -37,7 +37,11 @@ import kafka.server.QuotaType
 import kafka.server.KafkaServer
 import org.apache.kafka.clients.admin.NewPartitions
 import org.apache.kafka.clients.admin.NewTopic
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.ValueSource
 
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.locks.ReentrantLock
 import scala.collection.mutable
 
 /* We have some tests in this class instead of `BaseConsumerTest` in order to keep the build time under control. */
@@ -969,6 +973,89 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     }
   }
 
+  @ParameterizedTest
+  @ValueSource(strings = Array(
+    "org.apache.kafka.clients.consumer.CooperativeStickyAssignor",
+    "org.apache.kafka.clients.consumer.RangeAssignor"))
+  def testRebalanceAndRejoin(assignmentStrategy: String): Unit = {
+    // create 2 consumers
+    this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "rebalance-and-rejoin-group")
+    this.consumerConfig.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, assignmentStrategy)
+    this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true")
+    val consumer1 = createConsumer()
+    val consumer2 = createConsumer()
+
+    // create a new topic, have 2 partitions
+    val topic = "topic1"
+    val producer = createProducer()
+    val expectedAssignment = createTopicAndSendRecords(producer, topic, 2, 100)
+
+    assertEquals(0, consumer1.assignment().size)
+    assertEquals(0, consumer2.assignment().size)
+
+    val lock = new ReentrantLock()
+    var generationId1 = -1
+    var memberId1 = ""
+    val customRebalanceListener = new ConsumerRebalanceListener {
+      override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]): Unit = {
+      }
+      override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]): Unit = {
+        if (!lock.tryLock(3000, TimeUnit.MILLISECONDS)) {
+          fail(s"Time out while awaiting for lock.")
+        }
+        try {
+          generationId1 = consumer1.groupMetadata().generationId()
+          memberId1 = consumer1.groupMetadata().memberId()
+        } finally {
+          lock.unlock()
+        }
+      }
+    }
+    val consumerPoller1 = new ConsumerAssignmentPoller(consumer1, List(topic), Set.empty, customRebalanceListener)
+    consumerPoller1.start()
+    TestUtils.waitUntilTrue(() => consumerPoller1.consumerAssignment() == expectedAssignment,
+      s"Timed out while awaiting expected assignment change to $expectedAssignment.")
+
+    // Since the consumer1 already completed the rebalance,
+    // the `onPartitionsAssigned` rebalance listener will be invoked to set the generationId and memberId
+    var stableGeneration = -1
+    var stableMemberId1 = ""
+    if (!lock.tryLock(3000, TimeUnit.MILLISECONDS)) {
+      fail(s"Time out while awaiting for lock.")
+    }
+    try {
+      stableGeneration = generationId1
+      stableMemberId1 = memberId1
+    } finally {
+      lock.unlock()
+    }
+
+    val consumerPoller2 = subscribeConsumerAndStartPolling(consumer2, List(topic))
+    TestUtils.waitUntilTrue(() => consumerPoller1.consumerAssignment().size == 1,
+      s"Timed out while awaiting expected assignment size change to 1.")
+    TestUtils.waitUntilTrue(() => consumerPoller2.consumerAssignment().size == 1,
+      s"Timed out while awaiting expected assignment size change to 1.")
+
+    if (!lock.tryLock(3000, TimeUnit.MILLISECONDS)) {
+      fail(s"Time out while awaiting for lock.")
+    }
+    try {
+      if (assignmentStrategy.equals(classOf[CooperativeStickyAssignor].getName)) {
+        // cooperative rebalance should rebalance twice before finally stable
+        assertEquals(stableGeneration + 2, generationId1)
+      } else {
+        // eager rebalance should rebalance once before finally stable
+        assertEquals(stableGeneration + 1, generationId1)
+      }
+      assertEquals(stableMemberId1, memberId1)
+    } finally {
+      lock.unlock()
+    }
+
+    consumerPoller1.shutdown()
+    consumerPoller2.shutdown()
+  }
+
   /**
    * This test re-uses BaseConsumerTest's consumers.
    * As a result, it is testing the default assignment strategy set by BaseConsumerTest


[kafka] 02/03: Merge branch '3.3' of https://github.com/apache/kafka into 3.3

Posted by sh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

showuon pushed a commit to branch 3.3
in repository https://gitbox.apache.org/repos/asf/kafka.git

commit 95dc3fdedc322da8c2e51ba840c7cbf48aa74fed
Merge: 9aabf49882 327439504b
Author: Luke Chen <sh...@gmail.com>
AuthorDate: Fri Jul 22 11:04:04 2022 +0800

    Merge branch '3.3' of https://github.com/apache/kafka into 3.3

 .../kafka/clients/producer/KafkaProducer.java      |  30 +++-
 .../producer/internals/RecordAccumulator.java      |  13 +-
 .../org/apache/kafka/streams/KafkaStreams.java     |  69 ++++---
 .../KafkaStreamsCloseOptionsIntegrationTest.java   | 198 +++++++++++++++++++++
 4 files changed, 261 insertions(+), 49 deletions(-)


[kafka] 03/03: KAFKA-13919: expose log recovery metrics (#12347)

Posted by sh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

showuon pushed a commit to branch 3.3
in repository https://gitbox.apache.org/repos/asf/kafka.git

commit 158fc5f0269670d80d6c8f49186c51930bd600a1
Author: Luke Chen <sh...@gmail.com>
AuthorDate: Fri Jul 22 11:00:15 2022 +0800

    KAFKA-13919: expose log recovery metrics (#12347)
    
    Implementation for KIP-831.
    1. Add a remainingLogsToRecover metric for the number of logs in each log.dir that remain to be recovered.
    2. Add a remainingSegmentsToRecover metric for the number of remaining segments of the log currently assigned to each recovery thread.
    3. Remove these metrics once log loading completes.
    4. Add tests.
    
    Reviewers: Jun Rao <ju...@confluent.io>, Tom Bentley <tb...@redhat.com>
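
A hypothetical monitoring snippet for the new gauges, assuming they are exposed over JMX under the LogManager metric group with a "dir" tag; the object-name pattern is inferred from the gauge names in this patch and should be verified against the broker's JMX tree. Run it inside the broker JVM, or adapt it with a JMXConnector for remote access.

    import java.lang.management.ManagementFactory;
    import javax.management.MBeanServer;
    import javax.management.ObjectName;

    public class LogRecoveryMetricsProbe {
        public static void main(String[] args) throws Exception {
            MBeanServer server = ManagementFactory.getPlatformMBeanServer();
            // Assumed pattern: kafka.log:type=LogManager,name=remainingLogsToRecover,dir=<log.dir>
            ObjectName pattern = new ObjectName("kafka.log:type=LogManager,name=remainingLogsToRecover,*");
            for (ObjectName name : server.queryNames(pattern, null)) {
                Object remaining = server.getAttribute(name, "Value");
                System.out.println(name.getKeyProperty("dir") + " -> " + remaining + " logs left to recover");
            }
        }
    }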
---
 core/src/main/scala/kafka/log/LogLoader.scala      |  26 ++-
 core/src/main/scala/kafka/log/LogManager.scala     |  85 ++++++--
 core/src/main/scala/kafka/log/UnifiedLog.scala     |  10 +-
 .../test/scala/unit/kafka/log/LogLoaderTest.scala  |   3 +-
 .../test/scala/unit/kafka/log/LogManagerTest.scala | 220 ++++++++++++++++++++-
 .../test/scala/unit/kafka/log/LogTestUtils.scala   |   7 +-
 .../test/scala/unit/kafka/utils/TestUtils.scala    |   5 +-
 .../apache/kafka/jmh/server/CheckpointBench.java   |   2 +-
 8 files changed, 319 insertions(+), 39 deletions(-)

diff --git a/core/src/main/scala/kafka/log/LogLoader.scala b/core/src/main/scala/kafka/log/LogLoader.scala
index 581d016e5e..25ee89c72b 100644
--- a/core/src/main/scala/kafka/log/LogLoader.scala
+++ b/core/src/main/scala/kafka/log/LogLoader.scala
@@ -29,6 +29,7 @@ import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.InvalidOffsetException
 import org.apache.kafka.common.utils.Time
 
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import scala.collection.{Set, mutable}
 
 case class LoadedLogOffsets(logStartOffset: Long,
@@ -64,6 +65,7 @@ object LogLoader extends Logging {
  * @param recoveryPointCheckpoint The checkpoint of the offset at which to begin the recovery
  * @param leaderEpochCache An optional LeaderEpochFileCache instance to be updated during recovery
  * @param producerStateManager The ProducerStateManager instance to be updated during recovery
+ * @param numRemainingSegments The remaining segments to be recovered in this log keyed by recovery thread name
  */
 class LogLoader(
   dir: File,
@@ -77,7 +79,8 @@ class LogLoader(
   logStartOffsetCheckpoint: Long,
   recoveryPointCheckpoint: Long,
   leaderEpochCache: Option[LeaderEpochFileCache],
-  producerStateManager: ProducerStateManager
+  producerStateManager: ProducerStateManager,
+  numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]
 ) extends Logging {
   logIdent = s"[LogLoader partition=$topicPartition, dir=${dir.getParent}] "
 
@@ -404,12 +407,18 @@ class LogLoader(
 
     // If we have the clean shutdown marker, skip recovery.
     if (!hadCleanShutdown) {
-      val unflushed = segments.values(recoveryPointCheckpoint, Long.MaxValue).iterator
+      val unflushed = segments.values(recoveryPointCheckpoint, Long.MaxValue)
+      val numUnflushed = unflushed.size
+      val unflushedIter = unflushed.iterator
       var truncated = false
+      var numFlushed = 0
+      val threadName = Thread.currentThread().getName
+      numRemainingSegments.put(threadName, numUnflushed)
+
+      while (unflushedIter.hasNext && !truncated) {
+        val segment = unflushedIter.next()
+        info(s"Recovering unflushed segment ${segment.baseOffset}. $numFlushed/$numUnflushed recovered for $topicPartition.")
 
-      while (unflushed.hasNext && !truncated) {
-        val segment = unflushed.next()
-        info(s"Recovering unflushed segment ${segment.baseOffset}")
         val truncatedBytes =
           try {
             recoverSegment(segment)
@@ -424,8 +433,13 @@ class LogLoader(
           // we had an invalid message, delete all remaining log
           warn(s"Corruption found in segment ${segment.baseOffset}," +
             s" truncating to offset ${segment.readNextOffset}")
-          removeAndDeleteSegmentsAsync(unflushed.toList)
+          removeAndDeleteSegmentsAsync(unflushedIter.toList)
           truncated = true
+          // segment is truncated, so set remaining segments to 0
+          numRemainingSegments.put(threadName, 0)
+        } else {
+          numFlushed += 1
+          numRemainingSegments.put(threadName, numUnflushed - numFlushed)
         }
       }
     }
diff --git a/core/src/main/scala/kafka/log/LogManager.scala b/core/src/main/scala/kafka/log/LogManager.scala
index bdc7ffd74d..886f56c63c 100755
--- a/core/src/main/scala/kafka/log/LogManager.scala
+++ b/core/src/main/scala/kafka/log/LogManager.scala
@@ -262,7 +262,8 @@ class LogManager(logDirs: Seq[File],
                            recoveryPoints: Map[TopicPartition, Long],
                            logStartOffsets: Map[TopicPartition, Long],
                            defaultConfig: LogConfig,
-                           topicConfigOverrides: Map[String, LogConfig]): UnifiedLog = {
+                           topicConfigOverrides: Map[String, LogConfig],
+                           numRemainingSegments: ConcurrentMap[String, Int]): UnifiedLog = {
     val topicPartition = UnifiedLog.parseTopicPartitionName(logDir)
     val config = topicConfigOverrides.getOrElse(topicPartition.topic, defaultConfig)
     val logRecoveryPoint = recoveryPoints.getOrElse(topicPartition, 0L)
@@ -282,7 +283,8 @@ class LogManager(logDirs: Seq[File],
       logDirFailureChannel = logDirFailureChannel,
       lastShutdownClean = hadCleanShutdown,
       topicId = None,
-      keepPartitionMetadataFile = keepPartitionMetadataFile)
+      keepPartitionMetadataFile = keepPartitionMetadataFile,
+      numRemainingSegments = numRemainingSegments)
 
     if (logDir.getName.endsWith(UnifiedLog.DeleteDirSuffix)) {
       addLogToBeDeleted(log)
@@ -307,6 +309,27 @@ class LogManager(logDirs: Seq[File],
     log
   }
 
+  // factory class for naming the log recovery threads used in metrics
+  class LogRecoveryThreadFactory(val dirPath: String) extends ThreadFactory {
+    val threadNum = new AtomicInteger(0)
+
+    override def newThread(runnable: Runnable): Thread = {
+      KafkaThread.nonDaemon(logRecoveryThreadName(dirPath, threadNum.getAndIncrement()), runnable)
+    }
+  }
+
+  // create a unique log recovery thread name for each log dir as the format: prefix-dirPath-threadNum, ex: "log-recovery-/tmp/kafkaLogs-0"
+  private def logRecoveryThreadName(dirPath: String, threadNum: Int, prefix: String = "log-recovery"): String = s"$prefix-$dirPath-$threadNum"
+
+  /*
+   * decrement the number of remaining logs
+   * @return the number of remaining logs after decremented 1
+   */
+  private[log] def decNumRemainingLogs(numRemainingLogs: ConcurrentMap[String, Int], path: String): Int = {
+    require(path != null, "path cannot be null to update remaining logs metric.")
+    numRemainingLogs.compute(path, (_, oldVal) => oldVal - 1)
+  }
+
   /**
    * Recover and load all logs in the given data directories
    */
@@ -317,6 +340,10 @@ class LogManager(logDirs: Seq[File],
     val offlineDirs = mutable.Set.empty[(String, IOException)]
     val jobs = ArrayBuffer.empty[Seq[Future[_]]]
     var numTotalLogs = 0
+    // log dir path -> number of Remaining logs map for remainingLogsToRecover metric
+    val numRemainingLogs: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]
+    // log recovery thread name -> number of remaining segments map for remainingSegmentsToRecover metric
+    val numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]
 
     def handleIOException(logDirAbsolutePath: String, e: IOException): Unit = {
       offlineDirs.add((logDirAbsolutePath, e))
@@ -328,7 +355,7 @@ class LogManager(logDirs: Seq[File],
       var hadCleanShutdown: Boolean = false
       try {
         val pool = Executors.newFixedThreadPool(numRecoveryThreadsPerDataDir,
-          KafkaThread.nonDaemon(s"log-recovery-$logDirAbsolutePath", _))
+          new LogRecoveryThreadFactory(logDirAbsolutePath))
         threadPools.append(pool)
 
         val cleanShutdownFile = new File(dir, LogLoader.CleanShutdownFile)
@@ -363,28 +390,32 @@ class LogManager(logDirs: Seq[File],
 
         val logsToLoad = Option(dir.listFiles).getOrElse(Array.empty).filter(logDir =>
           logDir.isDirectory && UnifiedLog.parseTopicPartitionName(logDir).topic != KafkaRaftServer.MetadataTopic)
-        val numLogsLoaded = new AtomicInteger(0)
         numTotalLogs += logsToLoad.length
+        numRemainingLogs.put(dir.getAbsolutePath, logsToLoad.length)
 
         val jobsForDir = logsToLoad.map { logDir =>
           val runnable: Runnable = () => {
+            debug(s"Loading log $logDir")
+            var log = None: Option[UnifiedLog]
+            val logLoadStartMs = time.hiResClockMs()
             try {
-              debug(s"Loading log $logDir")
-
-              val logLoadStartMs = time.hiResClockMs()
-              val log = loadLog(logDir, hadCleanShutdown, recoveryPoints, logStartOffsets,
-                defaultConfig, topicConfigOverrides)
-              val logLoadDurationMs = time.hiResClockMs() - logLoadStartMs
-              val currentNumLoaded = numLogsLoaded.incrementAndGet()
-
-              info(s"Completed load of $log with ${log.numberOfSegments} segments in ${logLoadDurationMs}ms " +
-                s"($currentNumLoaded/${logsToLoad.length} loaded in $logDirAbsolutePath)")
+              log = Some(loadLog(logDir, hadCleanShutdown, recoveryPoints, logStartOffsets,
+                defaultConfig, topicConfigOverrides, numRemainingSegments))
             } catch {
               case e: IOException =>
                 handleIOException(logDirAbsolutePath, e)
               case e: KafkaStorageException if e.getCause.isInstanceOf[IOException] =>
                 // KafkaStorageException might be thrown, ex: during writing LeaderEpochFileCache
                 // And while converting IOException to KafkaStorageException, we've already handled the exception. So we can ignore it here.
+            } finally {
+              val logLoadDurationMs = time.hiResClockMs() - logLoadStartMs
+              val remainingLogs = decNumRemainingLogs(numRemainingLogs, dir.getAbsolutePath)
+              val currentNumLoaded = logsToLoad.length - remainingLogs
+              log match {
+                case Some(loadedLog) => info(s"Completed load of $loadedLog with ${loadedLog.numberOfSegments} segments in ${logLoadDurationMs}ms " +
+                  s"($currentNumLoaded/${logsToLoad.length} completed in $logDirAbsolutePath)")
+                case None => info(s"Error while loading logs in $logDir in ${logLoadDurationMs}ms ($currentNumLoaded/${logsToLoad.length} completed in $logDirAbsolutePath)")
+              }
             }
           }
           runnable
@@ -398,6 +429,7 @@ class LogManager(logDirs: Seq[File],
     }
 
     try {
+      addLogRecoveryMetrics(numRemainingLogs, numRemainingSegments)
       for (dirJobs <- jobs) {
         dirJobs.foreach(_.get)
       }
@@ -410,12 +442,37 @@ class LogManager(logDirs: Seq[File],
         error(s"There was an error in one of the threads during logs loading: ${e.getCause}")
         throw e.getCause
     } finally {
+      removeLogRecoveryMetrics()
       threadPools.foreach(_.shutdown())
     }
 
     info(s"Loaded $numTotalLogs logs in ${time.hiResClockMs() - startMs}ms.")
   }
 
+  private[log] def addLogRecoveryMetrics(numRemainingLogs: ConcurrentMap[String, Int],
+                                         numRemainingSegments: ConcurrentMap[String, Int]): Unit = {
+    debug("Adding log recovery metrics")
+    for (dir <- logDirs) {
+      newGauge("remainingLogsToRecover", () => numRemainingLogs.get(dir.getAbsolutePath),
+        Map("dir" -> dir.getAbsolutePath))
+      for (i <- 0 until numRecoveryThreadsPerDataDir) {
+        val threadName = logRecoveryThreadName(dir.getAbsolutePath, i)
+        newGauge("remainingSegmentsToRecover", () => numRemainingSegments.get(threadName),
+          Map("dir" -> dir.getAbsolutePath, "threadNum" -> i.toString))
+      }
+    }
+  }
+
+  private[log] def removeLogRecoveryMetrics(): Unit = {
+    debug("Removing log recovery metrics")
+    for (dir <- logDirs) {
+      removeMetric("remainingLogsToRecover", Map("dir" -> dir.getAbsolutePath))
+      for (i <- 0 until numRecoveryThreadsPerDataDir) {
+        removeMetric("remainingSegmentsToRecover", Map("dir" -> dir.getAbsolutePath, "threadNum" -> i.toString))
+      }
+    }
+  }
+
   /**
    *  Start the background threads to flush logs and do log cleanup
    */
diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala
index ddd66eb160..c4a2300237 100644
--- a/core/src/main/scala/kafka/log/UnifiedLog.scala
+++ b/core/src/main/scala/kafka/log/UnifiedLog.scala
@@ -18,11 +18,11 @@
 package kafka.log
 
 import com.yammer.metrics.core.MetricName
+
 import java.io.{File, IOException}
 import java.nio.file.Files
 import java.util.Optional
-import java.util.concurrent.TimeUnit
-
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
 import kafka.common.{LongRef, OffsetsOutOfOrderException, UnexpectedAppendOffsetException}
 import kafka.log.AppendOrigin.RaftLeader
 import kafka.message.{BrokerCompressionCodec, CompressionCodec, NoCompressionCodec}
@@ -1803,7 +1803,8 @@ object UnifiedLog extends Logging {
             logDirFailureChannel: LogDirFailureChannel,
             lastShutdownClean: Boolean = true,
             topicId: Option[Uuid],
-            keepPartitionMetadataFile: Boolean): UnifiedLog = {
+            keepPartitionMetadataFile: Boolean,
+            numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]): UnifiedLog = {
     // create the log directory if it doesn't exist
     Files.createDirectories(dir.toPath)
     val topicPartition = UnifiedLog.parseTopicPartitionName(dir)
@@ -1828,7 +1829,8 @@ object UnifiedLog extends Logging {
       logStartOffset,
       recoveryPoint,
       leaderEpochCache,
-      producerStateManager
+      producerStateManager,
+      numRemainingSegments
     ).load()
     val localLog = new LocalLog(dir, config, segments, offsets.recoveryPoint,
       offsets.nextOffsetMetadata, scheduler, time, topicPartition, logDirFailureChannel)
diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
index 0d41a5073b..c6379ff3f3 100644
--- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala
@@ -38,6 +38,7 @@ import org.mockito.ArgumentMatchers
 import org.mockito.ArgumentMatchers.{any, anyLong}
 import org.mockito.Mockito.{mock, reset, times, verify, when}
 
+import java.util.concurrent.ConcurrentMap
 import scala.annotation.nowarn
 import scala.collection.mutable.ListBuffer
 import scala.collection.{Iterable, Map, mutable}
@@ -117,7 +118,7 @@ class LogLoaderTest {
 
         override def loadLog(logDir: File, hadCleanShutdown: Boolean, recoveryPoints: Map[TopicPartition, Long],
                              logStartOffsets: Map[TopicPartition, Long], defaultConfig: LogConfig,
-                             topicConfigs: Map[String, LogConfig]): UnifiedLog = {
+                             topicConfigs: Map[String, LogConfig], numRemainingSegments: ConcurrentMap[String, Int]): UnifiedLog = {
           if (simulateError.hasError) {
             simulateError.errorType match {
               case ErrorTypes.KafkaStorageExceptionWithIOExceptionCause =>
diff --git a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
index 5353df6db3..1b2dd7809f 100755
--- a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
@@ -17,10 +17,10 @@
 
 package kafka.log
 
-import com.yammer.metrics.core.MetricName
+import com.yammer.metrics.core.{Gauge, MetricName}
 import kafka.server.checkpoints.OffsetCheckpointFile
 import kafka.server.metadata.{ConfigRepository, MockConfigRepository}
-import kafka.server.{FetchDataInfo, FetchLogEnd}
+import kafka.server.{BrokerTopicStats, FetchDataInfo, FetchLogEnd, LogDirFailureChannel}
 import kafka.utils._
 import org.apache.directory.api.util.FileUtils
 import org.apache.kafka.common.errors.OffsetOutOfRangeException
@@ -29,16 +29,17 @@ import org.apache.kafka.common.{KafkaException, TopicPartition}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.ArgumentMatchers.any
-import org.mockito.{ArgumentMatchers, Mockito}
-import org.mockito.Mockito.{doAnswer, mock, never, spy, times, verify}
+import org.mockito.{ArgumentCaptor, ArgumentMatchers, Mockito}
+import org.mockito.Mockito.{doAnswer, doNothing, mock, never, spy, times, verify}
+
 import java.io._
 import java.nio.file.Files
-import java.util.concurrent.Future
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Future}
 import java.util.{Collections, Properties}
-
 import org.apache.kafka.server.metrics.KafkaYammerMetrics
 
-import scala.collection.mutable
+import scala.collection.{Map, mutable}
+import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 import scala.util.{Failure, Try}
 
@@ -421,12 +422,14 @@ class LogManagerTest {
   }
 
   private def createLogManager(logDirs: Seq[File] = Seq(this.logDir),
-                               configRepository: ConfigRepository = new MockConfigRepository): LogManager = {
+                               configRepository: ConfigRepository = new MockConfigRepository,
+                               recoveryThreadsPerDataDir: Int = 1): LogManager = {
     TestUtils.createLogManager(
       defaultConfig = logConfig,
       configRepository = configRepository,
       logDirs = logDirs,
-      time = this.time)
+      time = this.time,
+      recoveryThreadsPerDataDir = recoveryThreadsPerDataDir)
   }
 
   @Test
@@ -638,6 +641,205 @@ class LogManagerTest {
     assertTrue(logManager.partitionsInitializing.isEmpty)
   }
 
+  private def appendRecordsToLog(time: MockTime, parentLogDir: File, partitionId: Int, brokerTopicStats: BrokerTopicStats, expectedSegmentsPerLog: Int): Unit = {
+    def createRecord = TestUtils.singletonRecords(value = "test".getBytes, timestamp = time.milliseconds)
+    val tpFile = new File(parentLogDir, s"$name-$partitionId")
+    val segmentBytes = 1024
+
+    val log = LogTestUtils.createLog(tpFile, logConfig, brokerTopicStats, time.scheduler, time, 0, 0,
+      5 * 60 * 1000, 60 * 60 * 1000, LogManager.ProducerIdExpirationCheckIntervalMs)
+
+    assertTrue(expectedSegmentsPerLog > 0)
+    // calculate the number of messages to append so that "expectedSegmentsPerLog" log segments are created with segment.bytes=1024
+    val numMessages = Math.floor(segmentBytes * expectedSegmentsPerLog / createRecord.sizeInBytes).asInstanceOf[Int]
+    try {
+      for (_ <- 0 until numMessages) {
+        log.appendAsLeader(createRecord, leaderEpoch = 0)
+      }
+
+      assertEquals(expectedSegmentsPerLog, log.numberOfSegments)
+    } finally {
+      log.close()
+    }
+  }
+
+  private def verifyRemainingLogsToRecoverMetric(spyLogManager: LogManager, expectedParams: Map[String, Int]): Unit = {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingLogsToRecover` metrics
+    val logMetrics: ArrayBuffer[Gauge[Int]] = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala
+      .filter { case (metric, _) => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingLogsToRecover" }
+      .map { case (_, gauge) => gauge }
+      .asInstanceOf[ArrayBuffer[Gauge[Int]]]
+
+    assertEquals(expectedParams.size, logMetrics.size)
+
+    val capturedPath: ArgumentCaptor[String] = ArgumentCaptor.forClass(classOf[String])
+
+    val expectedCallTimes = expectedParams.values.sum
+    verify(spyLogManager, times(expectedCallTimes)).decNumRemainingLogs(any[ConcurrentMap[String, Int]], capturedPath.capture());
+
+    val paths = capturedPath.getAllValues
+    expectedParams.foreach {
+      case (path, totalLogs) =>
+        // make sure decNumRemainingLogs is called "totalLogs" times for each path, i.e. the count is decremented to 0 in the end
+        assertEquals(totalLogs, Collections.frequency(paths, path))
+    }
+
+    // expect the final value to be 0
+    logMetrics.foreach { gauge => assertEquals(0, gauge.value()) }
+  }
+
+  private def verifyRemainingSegmentsToRecoverMetric(spyLogManager: LogManager,
+                                                     logDirs: Seq[File],
+                                                     recoveryThreadsPerDataDir: Int,
+                                                     mockMap: ConcurrentHashMap[String, Int],
+                                                     expectedParams: Map[String, Int]): Unit = {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingSegmentsToRecover` metrics
+    val logSegmentMetrics: ArrayBuffer[Gauge[Int]] = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala
+          .filter { case (metric, _) => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingSegmentsToRecover" }
+          .map { case (_, gauge) => gauge }
+          .asInstanceOf[ArrayBuffer[Gauge[Int]]]
+
+    // expect each log dir to have 1 metric per recovery thread
+    assertEquals(recoveryThreadsPerDataDir * logDirs.size, logSegmentMetrics.size)
+
+    val capturedThreadName: ArgumentCaptor[String] = ArgumentCaptor.forClass(classOf[String])
+    val capturedNumRemainingSegments: ArgumentCaptor[Int] = ArgumentCaptor.forClass(classOf[Int])
+
+    // Since numRemainingSegments is updated from totalSegments down to 0 for each thread, we need to add 1 here
+    val expectedCallTimes = expectedParams.values.map( num => num + 1 ).sum
+    verify(mockMap, times(expectedCallTimes)).put(capturedThreadName.capture(), capturedNumRemainingSegments.capture());
+
+    // expect the final value to be 0
+    logSegmentMetrics.foreach { gauge => assertEquals(0, gauge.value()) }
+
+    val threadNames = capturedThreadName.getAllValues
+    val numRemainingSegments = capturedNumRemainingSegments.getAllValues
+
+    expectedParams.foreach {
+      case (threadName, totalSegments) =>
+        // make sure numRemainingSegments is updated from totalSegments down to 0, in order, for each thread
+        var expectedCurRemainingSegments = totalSegments + 1
+        for (i <- 0 until threadNames.size) {
+          if (threadNames.get(i).contains(threadName)) {
+            expectedCurRemainingSegments -= 1
+            assertEquals(expectedCurRemainingSegments, numRemainingSegments.get(i))
+          }
+        }
+        assertEquals(0, expectedCurRemainingSegments)
+    }
+  }
+
+  private def verifyLogRecoverMetricsRemoved(spyLogManager: LogManager): Unit = {
+    val spyLogManagerClassName = spyLogManager.getClass().getSimpleName
+    // get all `remainingLogsToRecover` metrics
+    def logMetrics: mutable.Set[MetricName] = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala
+      .filter { metric => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingLogsToRecover" }
+
+    assertTrue(logMetrics.isEmpty)
+
+    // get all `remainingSegmentsToRecover` metrics
+    val logSegmentMetrics: mutable.Set[MetricName] = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala
+      .filter { metric => metric.getType == s"$spyLogManagerClassName" && metric.getName == "remainingSegmentsToRecover" }
+
+    assertTrue(logSegmentMetrics.isEmpty)
+  }
+
+  @Test
+  def testLogRecoveryMetrics(): Unit = {
+    logManager.shutdown()
+    val logDir1 = TestUtils.tempDir()
+    val logDir2 = TestUtils.tempDir()
+    val logDirs = Seq(logDir1, logDir2)
+    val recoveryThreadsPerDataDir = 2
+    // create a logManager with the expected number of recovery threads
+    logManager = createLogManager(logDirs, recoveryThreadsPerDataDir = recoveryThreadsPerDataDir)
+    val spyLogManager = spy(logManager)
+
+    assertEquals(2, spyLogManager.liveLogDirs.size)
+
+    val mockTime = new MockTime()
+    val mockMap = mock(classOf[ConcurrentHashMap[String, Int]])
+    val mockBrokerTopicStats = mock(classOf[BrokerTopicStats])
+    val expectedSegmentsPerLog = 2
+
+    // create log segments for log recovery in each log dir
+    appendRecordsToLog(mockTime, logDir1, 0, mockBrokerTopicStats, expectedSegmentsPerLog)
+    appendRecordsToLog(mockTime, logDir2, 1, mockBrokerTopicStats, expectedSegmentsPerLog)
+
+    // intercept the loadLog method so we can pass the expected parameters into log recovery
+    doAnswer { invocation =>
+      val dir: File = invocation.getArgument(0)
+      val topicConfigOverrides: mutable.Map[String, LogConfig] = invocation.getArgument(5)
+
+      val topicPartition = UnifiedLog.parseTopicPartitionName(dir)
+      val config = topicConfigOverrides.getOrElse(topicPartition.topic, logConfig)
+
+      UnifiedLog(
+        dir = dir,
+        config = config,
+        logStartOffset = 0,
+        recoveryPoint = 0,
+        maxTransactionTimeoutMs = 5 * 60 * 1000,
+        maxProducerIdExpirationMs = 5 * 60 * 1000,
+        producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
+        scheduler = mockTime.scheduler,
+        time = mockTime,
+        brokerTopicStats = mockBrokerTopicStats,
+        logDirFailureChannel = mock(classOf[LogDirFailureChannel]),
+        // not clean shutdown
+        lastShutdownClean = false,
+        topicId = None,
+        keepPartitionMetadataFile = false,
+        // pass mock map for verification later
+        numRemainingSegments = mockMap)
+
+    }.when(spyLogManager).loadLog(any[File], any[Boolean], any[Map[TopicPartition, Long]], any[Map[TopicPartition, Long]],
+      any[LogConfig], any[Map[String, LogConfig]], any[ConcurrentMap[String, Int]])
+
+    // do nothing for removeLogRecoveryMetrics for metrics verification
+    doNothing().when(spyLogManager).removeLogRecoveryMetrics()
+
+    // start the logManager to do log recovery
+    spyLogManager.startup(Set.empty)
+
+    // make sure log recovery metrics are added and removed
+    verify(spyLogManager, times(1)).addLogRecoveryMetrics(any[ConcurrentMap[String, Int]], any[ConcurrentMap[String, Int]])
+    verify(spyLogManager, times(1)).removeLogRecoveryMetrics()
+
+    // expect 1 log in each log dir since we created 2 partitions across 2 log dirs
+    val expectedRemainingLogsParams = Map[String, Int](logDir1.getAbsolutePath -> 1, logDir2.getAbsolutePath -> 1)
+    verifyRemainingLogsToRecoverMetric(spyLogManager, expectedRemainingLogsParams)
+
+    val expectedRemainingSegmentsParams = Map[String, Int](
+      logDir1.getAbsolutePath -> expectedSegmentsPerLog, logDir2.getAbsolutePath -> expectedSegmentsPerLog)
+    verifyRemainingSegmentsToRecoverMetric(spyLogManager, logDirs, recoveryThreadsPerDataDir, mockMap, expectedRemainingSegmentsParams)
+  }
+
+  @Test
+  def testLogRecoveryMetricsShouldBeRemovedAfterLogRecovered(): Unit = {
+    logManager.shutdown()
+    val logDir1 = TestUtils.tempDir()
+    val logDir2 = TestUtils.tempDir()
+    val logDirs = Seq(logDir1, logDir2)
+    val recoveryThreadsPerDataDir = 2
+    // create a logManager with the expected number of recovery threads
+    logManager = createLogManager(logDirs, recoveryThreadsPerDataDir = recoveryThreadsPerDataDir)
+    val spyLogManager = spy(logManager)
+
+    assertEquals(2, spyLogManager.liveLogDirs.size)
+
+    // start the logManager to do log recovery
+    spyLogManager.startup(Set.empty)
+
+    // make sure log recovery metrics are added and removed once
+    verify(spyLogManager, times(1)).addLogRecoveryMetrics(any[ConcurrentMap[String, Int]], any[ConcurrentMap[String, Int]])
+    verify(spyLogManager, times(1)).removeLogRecoveryMetrics()
+
+    verifyLogRecoverMetricsRemoved(spyLogManager)
+  }
+
   @Test
   def testMetricsExistWhenLogIsRecreatedBeforeDeletion(): Unit = {
     val topicName = "metric-test"
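The two tests above assert that gauges named remainingLogsToRecover and remainingSegmentsToRecover are registered while recovery runs and are removed once it finishes. A rough sketch of how such a gauge can be backed by the shared ConcurrentMap, using assumed names and an assumed metric group (the real wiring lives in the LogManager.scala changes earlier in this patch):

    import com.yammer.metrics.core.{Gauge, MetricName}
    import java.util.concurrent.ConcurrentHashMap
    import org.apache.kafka.server.metrics.KafkaYammerMetrics

    // Illustrative sketch only: a gauge that reads the remaining-log counter for one log dir.
    val numRemainingLogs = new ConcurrentHashMap[String, Int]()
    val dirPath = "/tmp/kafka-logs" // hypothetical log dir path
    val metricName = new MetricName("kafka.log", "LogManager", "remainingLogsToRecover")
    KafkaYammerMetrics.defaultRegistry.newGauge(metricName, new Gauge[Int] {
      override def value(): Int = numRemainingLogs.getOrDefault(dirPath, 0)
    })
    // ... and once recovery completes, the gauge is unregistered again:
    KafkaYammerMetrics.defaultRegistry.removeMetric(metricName)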
diff --git a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
index f6b58d78ce..50af76f556 100644
--- a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
+++ b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala
@@ -28,6 +28,7 @@ import org.apache.kafka.common.utils.{Time, Utils}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse}
 
 import java.nio.file.Files
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import scala.collection.Iterable
 import scala.jdk.CollectionConverters._
 
@@ -83,7 +84,8 @@ object LogTestUtils {
                 producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs,
                 lastShutdownClean: Boolean = true,
                 topicId: Option[Uuid] = None,
-                keepPartitionMetadataFile: Boolean = true): UnifiedLog = {
+                keepPartitionMetadataFile: Boolean = true,
+                numRemainingSegments: ConcurrentMap[String, Int] = new ConcurrentHashMap[String, Int]): UnifiedLog = {
     UnifiedLog(
       dir = dir,
       config = config,
@@ -98,7 +100,8 @@ object LogTestUtils {
       logDirFailureChannel = new LogDirFailureChannel(10),
       lastShutdownClean = lastShutdownClean,
       topicId = topicId,
-      keepPartitionMetadataFile = keepPartitionMetadataFile
+      keepPartitionMetadataFile = keepPartitionMetadataFile,
+      numRemainingSegments = numRemainingSegments
     )
   }
 
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index 5a8d43795a..c49a7bdde0 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -1281,13 +1281,14 @@ object TestUtils extends Logging {
                        configRepository: ConfigRepository = new MockConfigRepository,
                        cleanerConfig: CleanerConfig = CleanerConfig(enableCleaner = false),
                        time: MockTime = new MockTime(),
-                       interBrokerProtocolVersion: MetadataVersion = MetadataVersion.latest): LogManager = {
+                       interBrokerProtocolVersion: MetadataVersion = MetadataVersion.latest,
+                       recoveryThreadsPerDataDir: Int = 4): LogManager = {
     new LogManager(logDirs = logDirs.map(_.getAbsoluteFile),
                    initialOfflineDirs = Array.empty[File],
                    configRepository = configRepository,
                    initialDefaultConfig = defaultConfig,
                    cleanerConfig = cleanerConfig,
-                   recoveryThreadsPerDataDir = 4,
+                   recoveryThreadsPerDataDir = recoveryThreadsPerDataDir,
                    flushCheckMs = 1000L,
                    flushRecoveryOffsetCheckpointMs = 10000L,
                    flushStartOffsetCheckpointMs = 10000L,
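Scala callers keep the previous behaviour through the new default of 4 recovery threads per data dir, while Java callers such as CheckpointBench below now have to pass the value explicitly, since Scala default arguments are not visible from Java. A hedged usage sketch from a Scala test (parameter choices here are for illustration only):

    import kafka.log.LogConfig
    import kafka.utils.{MockTime, TestUtils}

    // Illustrative only: build a LogManager that recovers logs with 2 threads per data dir.
    val logManager = TestUtils.createLogManager(
      logDirs = Seq(TestUtils.tempDir()),
      defaultConfig = LogConfig(),
      time = new MockTime(),
      recoveryThreadsPerDataDir = 2)
    logManager.startup(Set.empty)
    logManager.shutdown()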
diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
index 3bf65afc22..99fb814327 100644
--- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java
@@ -108,7 +108,7 @@ public class CheckpointBench {
         this.logManager = TestUtils.createLogManager(JavaConverters.asScalaBuffer(files),
                 LogConfig.apply(), new MockConfigRepository(), CleanerConfig.apply(1, 4 * 1024 * 1024L, 0.9d,
                         1024 * 1024, 32 * 1024 * 1024,
-                        Double.MAX_VALUE, 15 * 1000, true, "MD5"), time, MetadataVersion.latest());
+                        Double.MAX_VALUE, 15 * 1000, true, "MD5"), time, MetadataVersion.latest(), 4);
         scheduler.startup();
         final BrokerTopicStats brokerTopicStats = new BrokerTopicStats();
         final MetadataCache metadataCache =