You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by sh...@apache.org on 2022/07/20 02:05:46 UTC

[kafka] branch 3.2 updated: KAFKA-14024: Consumer keeps Commit offset in onJoinPrepare in Cooperative rebalance (#12349)

This is an automated email from the ASF dual-hosted git repository.

showuon pushed a commit to branch 3.2
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.2 by this push:
     new d8541b20a1 KAFKA-14024: Consumer keeps Commit offset in onJoinPrepare in Cooperative rebalance (#12349)
d8541b20a1 is described below

commit d8541b20a106f22736947e0c2f293833f3c3873b
Author: Shawn <wa...@bytedance.com>
AuthorDate: Wed Jul 20 10:03:43 2022 +0800

    KAFKA-14024: Consumer keeps Commit offset in onJoinPrepare in Cooperative rebalance (#12349)
    
    In KAFKA-13310, we tried to fix an issue where consumer#poll(duration) could return later than the provided duration. This happened because, when a rebalance was needed, we first committed the current offsets synchronously before rebalancing; if the offset commit took too long, consumer#poll spent more time than the provided duration. To fix that, we replaced the synchronous commit with an asynchronous commit before the rebalance (i.e. in onJoinPrepare).
    
    However, in this ticket, we found that the async commit keeps sending a new commit request on each Consumer#poll, because the offset commit never completes in time. The impact is that the existing consumer is kicked out of the group after the rebalance timeout without ever rejoining the group. For example, suppose consumer A is in group G and consumer B now joins the group; after the rebalance, only consumer B remains in the group.
    
    In addition, another bug was found while fixing this one. Before KAFKA-13310, we committed offsets synchronously, bounded by the rebalance timeout, and retried on retriable errors until that timeout. After KAFKA-13310, we assumed the commit was still retried, but the retry only happens after the partitions have been revoked. That is, even if the retried offset commit succeeds, the offsets of some partitions are still left uncommitted, and after the rebalance, other consumers will consume overlapping records.
    
    Reviewers: RivenSun <ri...@zoom.us>, Luke Chen <sh...@gmail.com>
---
 .../consumer/internals/AbstractCoordinator.java    |   5 +-
 .../consumer/internals/ConsumerCoordinator.java    |  73 ++++++++++---
 .../internals/AbstractCoordinatorTest.java         |   2 +-
 .../internals/ConsumerCoordinatorTest.java         | 120 ++++++++++++++++-----
 .../runtime/distributed/WorkerCoordinator.java     |   2 +-
 .../kafka/api/AbstractConsumerTest.scala           |  11 +-
 .../kafka/api/PlaintextConsumerTest.scala          |  87 +++++++++++++++
 7 files changed, 251 insertions(+), 49 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
index 5fe8a6a0e1..4d71482562 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
@@ -187,11 +187,12 @@ public abstract class AbstractCoordinator implements Closeable {
     /**
      * Invoked prior to each group join or rejoin. This is typically used to perform any
      * cleanup from the previous generation (such as committing offsets for the consumer)
+     * @param timer Timer bounding how long this method can block
      * @param generation The previous generation or -1 if there was none
      * @param memberId The identifier of this member in the previous group or "" if there was none
      * @return true If onJoinPrepare async commit succeeded, false otherwise
      */
-    protected abstract boolean onJoinPrepare(int generation, String memberId);
+    protected abstract boolean onJoinPrepare(Timer timer, int generation, String memberId);
 
     /**
      * Invoked when the leader is elected. This is used by the leader to perform the assignment
@@ -426,7 +427,7 @@ public abstract class AbstractCoordinator implements Closeable {
                 // exception, in which case upon retry we should not retry onJoinPrepare either.
                 needsJoinPrepare = false;
                 // return false when onJoinPrepare is waiting for committing offset
-                if (!onJoinPrepare(generation.generationId, generation.memberId)) {
+                if (!onJoinPrepare(timer, generation.generationId, generation.memberId)) {
                     needsJoinPrepare = true;
                     //should not initiateJoinGroup if needsJoinPrepare still is true
                     return false;
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index b853ff99e8..9838e7dc8f 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -141,6 +141,12 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     }
 
     private final RebalanceProtocol protocol;
+    // pending commit offset request in onJoinPrepare
+    private RequestFuture<Void> autoCommitOffsetRequestFuture = null;
+    // a timer for join prepare to know when to stop.
+    // it'll set to rebalance timeout so that the member can join the group successfully
+    // even though offset commit failed.
+    private Timer joinPrepareTimer = null;
 
     /**
      * Initialize the coordination manager.
@@ -740,24 +746,58 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     }
 
     @Override
-    protected boolean onJoinPrepare(int generation, String memberId) {
+    protected boolean onJoinPrepare(Timer timer, int generation, String memberId) {
         log.debug("Executing onJoinPrepare with generation {} and memberId {}", generation, memberId);
-        boolean onJoinPrepareAsyncCommitCompleted = false;
+        if (joinPrepareTimer == null) {
+            // We should complete onJoinPrepare before rebalanceTimeout,
+            // and continue to join group to avoid member got kicked out from group
+            joinPrepareTimer = time.timer(rebalanceConfig.rebalanceTimeoutMs);
+        } else {
+            joinPrepareTimer.update();
+        }
+
         // async commit offsets prior to rebalance if auto-commit enabled
-        RequestFuture<Void> future = maybeAutoCommitOffsetsAsync();
-        // return true when
-        // 1. future is null, which means no commit request sent, so it is still considered completed
-        // 2. offset commit completed
-        // 3. offset commit failed with non-retriable exception
-        if (future == null)
-            onJoinPrepareAsyncCommitCompleted = true;
-        else if (future.succeeded())
-            onJoinPrepareAsyncCommitCompleted = true;
-        else if (future.failed() && !future.isRetriable()) {
-            log.error("Asynchronous auto-commit of offsets failed: {}", future.exception().getMessage());
-            onJoinPrepareAsyncCommitCompleted = true;
+        // and there is no in-flight offset commit request
+        if (autoCommitEnabled && autoCommitOffsetRequestFuture == null) {
+            autoCommitOffsetRequestFuture = maybeAutoCommitOffsetsAsync();
         }
 
+        // wait for commit offset response before timer expired
+        if (autoCommitOffsetRequestFuture != null) {
+            Timer pollTimer = timer.remainingMs() < joinPrepareTimer.remainingMs() ?
+                    timer : joinPrepareTimer;
+            client.poll(autoCommitOffsetRequestFuture, pollTimer);
+            joinPrepareTimer.update();
+
+            // Keep retrying/waiting the offset commit when:
+            // 1. offset commit haven't done (and joinPrepareTimer not expired)
+            // 2. failed with retryable exception (and joinPrepareTimer not expired)
+            // Otherwise, continue to revoke partitions, ex:
+            // 1. if joinPrepareTime has expired
+            // 2. if offset commit failed with no-retryable exception
+            // 3. if offset commit success
+            boolean onJoinPrepareAsyncCommitCompleted = true;
+            if (joinPrepareTimer.isExpired()) {
+                log.error("Asynchronous auto-commit of offsets failed: joinPrepare timeout. Will continue to join group");
+            } else if (!autoCommitOffsetRequestFuture.isDone()) {
+                onJoinPrepareAsyncCommitCompleted = false;
+            } else if (autoCommitOffsetRequestFuture.failed() && autoCommitOffsetRequestFuture.isRetriable()) {
+                log.debug("Asynchronous auto-commit of offsets failed with retryable error: {}. Will retry it.",
+                        autoCommitOffsetRequestFuture.exception().getMessage());
+                onJoinPrepareAsyncCommitCompleted = false;
+            } else if (autoCommitOffsetRequestFuture.failed() && !autoCommitOffsetRequestFuture.isRetriable()) {
+                log.error("Asynchronous auto-commit of offsets failed: {}. Will continue to join group.",
+                        autoCommitOffsetRequestFuture.exception().getMessage());
+            }
+            if (autoCommitOffsetRequestFuture.isDone()) {
+                autoCommitOffsetRequestFuture = null;
+            }
+            if (!onJoinPrepareAsyncCommitCompleted) {
+                pollTimer.sleep(Math.min(pollTimer.remainingMs(), rebalanceConfig.retryBackoffMs));
+                timer.update();
+                return false;
+            }
+        }
 
         // the generation / member-id can possibly be reset by the heartbeat thread
         // upon getting errors or heartbeat timeouts; in this case whatever is previously
@@ -809,11 +849,14 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
 
         isLeader = false;
         subscriptions.resetGroupSubscription();
+        joinPrepareTimer = null;
+        autoCommitOffsetRequestFuture = null;
+        timer.update();
 
         if (exception != null) {
             throw new KafkaException("User rebalance callback throws an error", exception);
         }
-        return onJoinPrepareAsyncCommitCompleted;
+        return true;
     }
 
     @Override
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
index 7cf6ee0e66..69dd893e12 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java
@@ -1661,7 +1661,7 @@ public class AbstractCoordinatorTest {
         }
 
         @Override
-        protected boolean onJoinPrepare(int generation, String memberId) {
+        protected boolean onJoinPrepare(Timer timer, int generation, String memberId) {
             onJoinPrepareInvokes++;
             return true;
         }
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
index c65d33176f..df88d84f08 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
@@ -77,6 +77,7 @@ import org.apache.kafka.common.requests.SyncGroupResponse;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.SystemTime;
+import org.apache.kafka.common.utils.Timer;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.test.TestUtils;
 import org.junit.jupiter.api.AfterEach;
@@ -1299,9 +1300,71 @@ public abstract class ConsumerCoordinatorTest {
         }
     }
 
+    @Test
+    public void testOnJoinPrepareWithOffsetCommitShouldSuccessAfterRetry() {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty(), false)) {
+            int generationId = 42;
+            String memberId = "consumer-42";
+
+            Timer pollTimer = time.timer(100L);
+            client.prepareResponse(offsetCommitResponse(singletonMap(t1p, Errors.UNKNOWN_TOPIC_OR_PARTITION)));
+            boolean res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertFalse(res);
+
+            pollTimer = time.timer(100L);
+            client.prepareResponse(offsetCommitResponse(singletonMap(t1p, Errors.NONE)));
+            res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertTrue(res);
+
+            assertFalse(client.hasPendingResponses());
+            assertFalse(client.hasInFlightRequests());
+            assertFalse(coordinator.coordinatorUnknown());
+        }
+    }
+
+    @Test
+    public void testOnJoinPrepareWithOffsetCommitShouldKeepJoinAfterNonRetryableException() {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty(), false)) {
+            int generationId = 42;
+            String memberId = "consumer-42";
+
+            Timer pollTimer = time.timer(100L);
+            client.prepareResponse(offsetCommitResponse(singletonMap(t1p, Errors.UNKNOWN_MEMBER_ID)));
+            boolean res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertTrue(res);
+
+            assertFalse(client.hasPendingResponses());
+            assertFalse(client.hasInFlightRequests());
+            assertFalse(coordinator.coordinatorUnknown());
+        }
+    }
+
+    @Test
+    public void testOnJoinPrepareWithOffsetCommitShouldKeepJoinAfterRebalanceTimeout() {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty(), false)) {
+            int generationId = 42;
+            String memberId = "consumer-42";
+
+            Timer pollTimer = time.timer(100L);
+            time.sleep(150);
+            boolean res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertFalse(res);
+
+            pollTimer = time.timer(100L);
+            time.sleep(rebalanceTimeoutMs);
+            client.respond(offsetCommitResponse(singletonMap(t1p, Errors.UNKNOWN_TOPIC_OR_PARTITION)));
+            res = coordinator.onJoinPrepare(pollTimer, generationId, memberId);
+            assertTrue(res);
+
+            assertFalse(client.hasPendingResponses());
+            assertFalse(client.hasInFlightRequests());
+            assertFalse(coordinator.coordinatorUnknown());
+        }
+    }
+
     @Test
     public void testJoinPrepareWithDisableAutoCommit() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
@@ -1309,7 +1372,7 @@ public abstract class ConsumerCoordinatorTest {
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), generationId, memberId);
 
             assertTrue(res);
             assertTrue(client.hasPendingResponses());
@@ -1320,14 +1383,14 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testJoinPrepareAndCommitCompleted() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), generationId, memberId);
             coordinator.invokeCompletedOffsetCommitCallbacks();
 
             assertTrue(res);
@@ -1339,7 +1402,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testJoinPrepareAndCommitWithCoordinatorNotAvailable() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.COORDINATOR_NOT_AVAILABLE);
@@ -1347,7 +1410,7 @@ public abstract class ConsumerCoordinatorTest {
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), generationId, memberId);
             coordinator.invokeCompletedOffsetCommitCallbacks();
 
             assertFalse(res);
@@ -1359,7 +1422,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testJoinPrepareAndCommitWithUnknownMemberId() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.UNKNOWN_MEMBER_ID);
@@ -1367,7 +1430,7 @@ public abstract class ConsumerCoordinatorTest {
             int generationId = 42;
             String memberId = "consumer-42";
 
-            boolean res = coordinator.onJoinPrepare(generationId, memberId);
+            boolean res = coordinator.onJoinPrepare(time.timer(0L), generationId, memberId);
             coordinator.invokeCompletedOffsetCommitCallbacks();
 
             assertTrue(res);
@@ -3078,21 +3141,21 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseDynamicAssignment() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty(), true)) {
             gracefulCloseTest(coordinator, true);
         }
     }
 
     @Test
     public void testCloseManualAssignment() {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(false, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(false, true, Optional.empty(), true)) {
             gracefulCloseTest(coordinator, false);
         }
     }
 
     @Test
     public void testCloseCoordinatorNotKnownManualAssignment() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(false, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(false, true, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 1000, 1000, 1000);
@@ -3101,7 +3164,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseCoordinatorNotKnownNoCommits() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR);
             closeVerifyTimeout(coordinator, 1000, 0, 0);
         }
@@ -3109,7 +3172,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseCoordinatorNotKnownWithCommits() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 1000, 1000, 1000);
@@ -3118,7 +3181,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseCoordinatorUnavailableNoCommits() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty(), true)) {
             makeCoordinatorUnknown(coordinator, Errors.COORDINATOR_NOT_AVAILABLE);
             closeVerifyTimeout(coordinator, 1000, 0, 0);
         }
@@ -3126,7 +3189,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseTimeoutCoordinatorUnavailableForCommit() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             makeCoordinatorUnknown(coordinator, Errors.COORDINATOR_NOT_AVAILABLE);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 1000, 1000, 1000);
@@ -3135,7 +3198,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseMaxWaitCoordinatorUnavailableForCommit() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             makeCoordinatorUnknown(coordinator, Errors.COORDINATOR_NOT_AVAILABLE);
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, requestTimeoutMs);
@@ -3144,7 +3207,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseNoResponseForCommit() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, requestTimeoutMs);
         }
@@ -3152,14 +3215,14 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCloseNoResponseForLeaveGroup() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty())) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty(), true)) {
             closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, requestTimeoutMs);
         }
     }
 
     @Test
     public void testCloseNoWait() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             time.sleep(autoCommitIntervalMs);
             closeVerifyTimeout(coordinator, 0, 0, 0);
         }
@@ -3167,7 +3230,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testHeartbeatThreadClose() throws Exception {
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             coordinator.ensureActiveGroup();
             time.sleep(heartbeatIntervalMs + 100);
             Thread.yield(); // Give heartbeat thread a chance to attempt heartbeat
@@ -3234,7 +3297,7 @@ public abstract class ConsumerCoordinatorTest {
         assertEquals(JoinGroupRequest.UNKNOWN_MEMBER_ID, groupMetadata.memberId());
         assertFalse(groupMetadata.groupInstanceId().isPresent());
 
-        try (final ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) {
+        try (final ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId, true)) {
             coordinator.ensureActiveGroup();
 
             final ConsumerGroupMetadata joinedGroupMetadata = coordinator.groupMetadata();
@@ -3270,7 +3333,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testPrepareJoinAndRejoinAfterFailedRebalance() {
         final List<TopicPartition> partitions = singletonList(t1p);
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             coordinator.ensureActiveGroup();
 
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.REBALANCE_IN_PROGRESS);
@@ -3290,7 +3353,7 @@ public abstract class ConsumerCoordinatorTest {
             MockTime time = new MockTime(1);
 
             // onJoinPrepare will be executed and onJoinComplete will not.
-            boolean res = coordinator.joinGroupIfNeeded(time.timer(2));
+            boolean res = coordinator.joinGroupIfNeeded(time.timer(100));
 
             assertFalse(res);
             assertFalse(client.hasPendingResponses());
@@ -3335,7 +3398,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void shouldLoseAllOwnedPartitionsBeforeRejoiningAfterDroppingOutOfTheGroup() {
         final List<TopicPartition> partitions = singletonList(t1p);
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             final SystemTime realTime = new SystemTime();
             coordinator.ensureActiveGroup();
 
@@ -3368,7 +3431,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void shouldLoseAllOwnedPartitionsBeforeRejoiningAfterResettingGenerationId() {
         final List<TopicPartition> partitions = singletonList(t1p);
-        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) {
+        try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"), true)) {
             final SystemTime realTime = new SystemTime();
             coordinator.ensureActiveGroup();
 
@@ -3462,7 +3525,8 @@ public abstract class ConsumerCoordinatorTest {
 
     private ConsumerCoordinator prepareCoordinatorForCloseTest(final boolean useGroupManagement,
                                                                final boolean autoCommit,
-                                                               final Optional<String> groupInstanceId) {
+                                                               final Optional<String> groupInstanceId,
+                                                               final boolean shouldPoll) {
         rebalanceConfig = buildRebalanceConfig(groupInstanceId);
         ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig,
                                                            new Metrics(),
@@ -3481,7 +3545,9 @@ public abstract class ConsumerCoordinatorTest {
         }
 
         subscriptions.seek(t1p, 100);
-        coordinator.poll(time.timer(Long.MAX_VALUE));
+        if (shouldPoll) {
+            coordinator.poll(time.timer(Long.MAX_VALUE));
+        }
 
         return coordinator;
     }
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
index 65720e2a78..ce1b82a9d5 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java
@@ -224,7 +224,7 @@ public class WorkerCoordinator extends AbstractCoordinator implements Closeable
     }
 
     @Override
-    protected boolean onJoinPrepare(int generation, String memberId) {
+    protected boolean onJoinPrepare(Timer timer, int generation, String memberId) {
         log.info("Rebalance started");
         leaderState(null);
         final ExtendedAssignment localAssignmentSnapshot = assignmentSnapshot;
diff --git a/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala b/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala
index 56bc47c79e..23b56b8e91 100644
--- a/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala
@@ -342,15 +342,16 @@ abstract class AbstractConsumerTest extends BaseRequestTest {
 
   protected class ConsumerAssignmentPoller(consumer: Consumer[Array[Byte], Array[Byte]],
                                            topicsToSubscribe: List[String],
-                                           partitionsToAssign: Set[TopicPartition])
+                                           partitionsToAssign: Set[TopicPartition],
+                                           userRebalanceListener: ConsumerRebalanceListener)
     extends ShutdownableThread("daemon-consumer-assignment", false) {
 
     def this(consumer: Consumer[Array[Byte], Array[Byte]], topicsToSubscribe: List[String]) = {
-      this(consumer, topicsToSubscribe, Set.empty[TopicPartition])
+      this(consumer, topicsToSubscribe, Set.empty[TopicPartition], null)
     }
 
     def this(consumer: Consumer[Array[Byte], Array[Byte]], partitionsToAssign: Set[TopicPartition]) = {
-      this(consumer, List.empty[String], partitionsToAssign)
+      this(consumer, List.empty[String], partitionsToAssign, null)
     }
 
     @volatile var thrownException: Option[Throwable] = None
@@ -363,10 +364,14 @@ abstract class AbstractConsumerTest extends BaseRequestTest {
     val rebalanceListener: ConsumerRebalanceListener = new ConsumerRebalanceListener {
       override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]) = {
         partitionAssignment ++= partitions.toArray(new Array[TopicPartition](0))
+        if (userRebalanceListener != null)
+          userRebalanceListener.onPartitionsAssigned(partitions)
       }
 
       override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]) = {
         partitionAssignment --= partitions.toArray(new Array[TopicPartition](0))
+        if (userRebalanceListener != null)
+          userRebalanceListener.onPartitionsRevoked(partitions)
       }
     }
 
diff --git a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
index 4ede241b0c..5dc7c2ada1 100644
--- a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
+++ b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala
@@ -37,7 +37,11 @@ import kafka.server.QuotaType
 import kafka.server.KafkaServer
 import org.apache.kafka.clients.admin.NewPartitions
 import org.apache.kafka.clients.admin.NewTopic
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.ValueSource
 
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.locks.ReentrantLock
 import scala.collection.mutable
 
 /* We have some tests in this class instead of `BaseConsumerTest` in order to keep the build time under control. */
@@ -969,6 +973,89 @@ class PlaintextConsumerTest extends BaseConsumerTest {
     }
   }
 
+  /**
+   * Verifies that a consumer in a stable group keeps its member id and rejoins the group when a
+   * second consumer joins and triggers a rebalance, for both a cooperative and an eager assignor.
+   * The generation seen by the first consumer must advance by 2 (cooperative) or by 1 (eager).
+   */
+  @ParameterizedTest
+  @ValueSource(strings = Array(
+    "org.apache.kafka.clients.consumer.CooperativeStickyAssignor",
+    "org.apache.kafka.clients.consumer.RangeAssignor"))
+  def testRebalanceAndRejoin(assignmentStrategy: String): Unit = {
+    // create 2 consumers in the same group, with auto commit enabled
+    this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "rebalance-and-rejoin-group")
+    this.consumerConfig.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, assignmentStrategy)
+    this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true")
+    val consumer1 = createConsumer()
+    val consumer2 = createConsumer()
+
+    // create a new topic with 2 partitions
+    val topic = "topic1"
+    val producer = createProducer()
+    val expectedAssignment = createTopicAndSendRecords(producer, topic, 2, 100)
+
+    assertEquals(0, consumer1.assignment().size)
+    assertEquals(0, consumer2.assignment().size)
+
+    // guards generationId1/memberId1: written in the rebalance callback, read by this test thread
+    val lock = new ReentrantLock()
+    var generationId1 = -1
+    var memberId1 = ""
+
+    // runs `body` while holding `lock`; fails the test if the lock is not acquired within 3 seconds
+    def withLock[T](body: => T): T = {
+      if (!lock.tryLock(3000, TimeUnit.MILLISECONDS)) {
+        fail("Time out while awaiting for lock.")
+      }
+      try {
+        body
+      } finally {
+        lock.unlock()
+      }
+    }
+
+    val customRebalanceListener = new ConsumerRebalanceListener {
+      override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]): Unit = {}
+      override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]): Unit = withLock {
+        generationId1 = consumer1.groupMetadata().generationId()
+        memberId1 = consumer1.groupMetadata().memberId()
+      }
+    }
+    val consumerPoller1 = new ConsumerAssignmentPoller(consumer1, List(topic), Set.empty, customRebalanceListener)
+    consumerPoller1.start()
+    // From here on the poller threads are always shut down, even when a wait or an assertion fails,
+    // so that a failing run does not leave a thread polling a consumer that teardown will close.
+    try {
+      TestUtils.waitUntilTrue(() => consumerPoller1.consumerAssignment() == expectedAssignment,
+        s"Timed out while awaiting expected assignment change to $expectedAssignment.")
+
+      // Since the consumer1 already completed the rebalance,
+      // the `onPartitionsAssigned` rebalance listener will be invoked to set the generationId and memberId
+      val (stableGeneration, stableMemberId1) = withLock((generationId1, memberId1))
+
+      val consumerPoller2 = subscribeConsumerAndStartPolling(consumer2, List(topic))
+      try {
+        TestUtils.waitUntilTrue(() => consumerPoller1.consumerAssignment().size == 1,
+          "Timed out while awaiting expected assignment size change to 1.")
+        TestUtils.waitUntilTrue(() => consumerPoller2.consumerAssignment().size == 1,
+          "Timed out while awaiting expected assignment size change to 1.")
+
+        withLock {
+          // cooperative rebalance should rebalance twice before finally stable,
+          // eager rebalance should rebalance once before finally stable
+          val expectedBumps = if (assignmentStrategy == classOf[CooperativeStickyAssignor].getName) 2 else 1
+          assertEquals(stableGeneration + expectedBumps, generationId1)
+          assertEquals(stableMemberId1, memberId1)
+        }
+      } finally {
+        consumerPoller2.shutdown()
+      }
+    } finally {
+      consumerPoller1.shutdown()
+    }
+  }
+
   /**
    * This test re-uses BaseConsumerTest's consumers.
    * As a result, it is testing the default assignment strategy set by BaseConsumerTest