You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by da...@apache.org on 2023/10/30 18:51:38 UTC

(kafka) branch trunk updated: KAFKA-15628: Refactor ConsumerRebalanceListener invocation for reuse (#14638)

This is an automated email from the ASF dual-hosted git repository.

dajac pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 2e2f32c0500 KAFKA-15628: Refactor ConsumerRebalanceListener invocation for reuse (#14638)
2e2f32c0500 is described below

commit 2e2f32c05008cdd7009e5f76fdd92f98996aab84
Author: Kirk True <ki...@kirktrue.pro>
AuthorDate: Mon Oct 30 11:51:30 2023 -0700

    KAFKA-15628: Refactor ConsumerRebalanceListener invocation for reuse (#14638)
    
    Straightforward refactoring to extract an inner class and methods related to `ConsumerRebalanceListener` for reuse in the KIP-848 implementation of the consumer group protocol. It also uses `Optional` to mark explicitly whether a `ConsumerRebalanceListener` is in use, allowing us to make some (forthcoming) optimizations when there is no listener to invoke.
    
    Reviewers: David Jacot <dj...@confluent.io>
---
 .../kafka/clients/consumer/KafkaConsumer.java      | 133 ++++++++++++------
 .../kafka/clients/consumer/MockConsumer.java       |  45 ++++---
 .../consumer/internals/ConsumerCoordinator.java    | 148 +++------------------
 .../internals/ConsumerCoordinatorMetrics.java      |  81 +++++++++++
 .../ConsumerRebalanceListenerInvoker.java          | 128 ++++++++++++++++++
 .../internals/NoOpConsumerRebalanceListener.java   |  32 -----
 .../consumer/internals/PrototypeAsyncConsumer.java | 125 +++++++++--------
 .../consumer/internals/SubscriptionState.java      |  14 +-
 .../internals/ConsumerCoordinatorTest.java         | 106 +++++++--------
 .../consumer/internals/ConsumerMetadataTest.java   |   6 +-
 .../internals/FetchRequestManagerTest.java         |   6 +-
 .../clients/consumer/internals/FetcherTest.java    |   6 +-
 .../internals/HeartbeatRequestManagerTest.java     |   2 +-
 .../consumer/internals/SubscriptionStateTest.java  |  42 +++---
 .../kafka/api/AuthorizerIntegrationTest.scala      |   3 +-
 15 files changed, 509 insertions(+), 368 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
index dd273c38c43..141ac66c5b7 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
@@ -31,7 +31,6 @@ import org.apache.kafka.clients.consumer.internals.FetchConfig;
 import org.apache.kafka.clients.consumer.internals.FetchMetricsManager;
 import org.apache.kafka.clients.consumer.internals.Fetcher;
 import org.apache.kafka.clients.consumer.internals.KafkaConsumerMetrics;
-import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.internals.OffsetFetcher;
 import org.apache.kafka.clients.consumer.internals.SubscriptionState;
 import org.apache.kafka.clients.consumer.internals.TopicMetadataFetcher;
@@ -903,6 +902,59 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
      */
     @Override
     public void subscribe(Collection<String> topics, ConsumerRebalanceListener listener) {
+        if (listener == null)
+            throw new IllegalArgumentException("RebalanceListener cannot be null");
+
+        subscribe(topics, Optional.of(listener));
+    }
+
+    /**
+     * Subscribe to the given list of topics to get dynamically assigned partitions.
+     * <b>Topic subscriptions are not incremental. This list will replace the current
+     * assignment (if there is one).</b> It is not possible to combine topic subscription with group management
+     * with manual partition assignment through {@link #assign(Collection)}.
+     *
+     * If the given list of topics is empty, it is treated the same as {@link #unsubscribe()}.
+     *
+     * <p>
+     * This is a short-hand for {@link #subscribe(Collection, ConsumerRebalanceListener)}, which
+     * uses a no-op listener. If you need the ability to seek to particular offsets, you should prefer
+     * {@link #subscribe(Collection, ConsumerRebalanceListener)}, since group rebalances will cause partition offsets
+     * to be reset. You should also provide your own listener if you are doing your own offset
+     * management since the listener gives you an opportunity to commit offsets before a rebalance finishes.
+     *
+     * @param topics The list of topics to subscribe to
+     * @throws IllegalArgumentException If topics is null or contains null or empty elements
+     * @throws IllegalStateException If {@code subscribe()} is called previously with pattern, or assign is called
+     *                               previously (without a subsequent call to {@link #unsubscribe()}), or if not
+     *                               configured at-least one partition assignment strategy
+     */
+    @Override
+    public void subscribe(Collection<String> topics) {
+        subscribe(topics, Optional.empty());
+    }
+
+    /**
+     * Internal helper method for {@link #subscribe(Collection)} and
+     * {@link #subscribe(Collection, ConsumerRebalanceListener)}
+     * <p>
+     * Subscribe to the given list of topics to get dynamically assigned partitions.
+     * <b>Topic subscriptions are not incremental. This list will replace the current
+     * assignment (if there is one).</b> It is not possible to combine topic subscription with group management
+     * with manual partition assignment through {@link #assign(Collection)}.
+     *
+     * If the given list of topics is empty, it is treated the same as {@link #unsubscribe()}.
+     *
+     * <p>
+     * @param topics The list of topics to subscribe to
+     * @param listener {@link Optional} listener instance to get notifications on partition assignment/revocation
+     *                 for the subscribed topics
+     * @throws IllegalArgumentException If topics is null or contains null or empty elements
+     * @throws IllegalStateException If {@code subscribe()} is called previously with pattern, or assign is called
+     *                               previously (without a subsequent call to {@link #unsubscribe()}), or if not
+     *                               configured at-least one partition assignment strategy
+     */
+    private void subscribe(Collection<String> topics, Optional<ConsumerRebalanceListener> listener) {
         acquireAndEnsureOpen();
         try {
             maybeThrowInvalidGroupIdException();
@@ -939,32 +991,57 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
     }
 
     /**
-     * Subscribe to the given list of topics to get dynamically assigned partitions.
-     * <b>Topic subscriptions are not incremental. This list will replace the current
-     * assignment (if there is one).</b> It is not possible to combine topic subscription with group management
-     * with manual partition assignment through {@link #assign(Collection)}.
-     *
-     * If the given list of topics is empty, it is treated the same as {@link #unsubscribe()}.
+     * Subscribe to all topics matching specified pattern to get dynamically assigned partitions.
+     * The pattern matching will be done periodically against all topics existing at the time of check.
+     * This can be controlled through the {@code metadata.max.age.ms} configuration: by lowering
+     * the max metadata age, the consumer will refresh metadata more often and check for matching topics.
+     * <p>
+     * See {@link #subscribe(Collection, ConsumerRebalanceListener)} for details on the
+     * use of the {@link ConsumerRebalanceListener}. Generally rebalances are triggered when there
+     * is a change to the topics matching the provided pattern and when consumer group membership changes.
+     * Group rebalances only take place during an active call to {@link #poll(Duration)}.
      *
+     * @param pattern Pattern to subscribe to
+     * @param listener Non-null listener instance to get notifications on partition assignment/revocation for the
+     *                 subscribed topics
+     * @throws IllegalArgumentException If pattern or listener is null
+     * @throws IllegalStateException If {@code subscribe()} is called previously with topics, or assign is called
+     *                               previously (without a subsequent call to {@link #unsubscribe()}), or if not
+     *                               configured at-least one partition assignment strategy
+     */
+    @Override
+    public void subscribe(Pattern pattern, ConsumerRebalanceListener listener) {
+        if (listener == null)
+            throw new IllegalArgumentException("RebalanceListener cannot be null");
+
+        subscribe(pattern, Optional.of(listener));
+    }
+
+    /**
+     * Subscribe to all topics matching specified pattern to get dynamically assigned partitions.
+     * The pattern matching will be done periodically against topics existing at the time of check.
      * <p>
-     * This is a short-hand for {@link #subscribe(Collection, ConsumerRebalanceListener)}, which
+     * This is a short-hand for {@link #subscribe(Pattern, ConsumerRebalanceListener)}, which
      * uses a no-op listener. If you need the ability to seek to particular offsets, you should prefer
-     * {@link #subscribe(Collection, ConsumerRebalanceListener)}, since group rebalances will cause partition offsets
+     * {@link #subscribe(Pattern, ConsumerRebalanceListener)}, since group rebalances will cause partition offsets
      * to be reset. You should also provide your own listener if you are doing your own offset
      * management since the listener gives you an opportunity to commit offsets before a rebalance finishes.
      *
-     * @param topics The list of topics to subscribe to
-     * @throws IllegalArgumentException If topics is null or contains null or empty elements
-     * @throws IllegalStateException If {@code subscribe()} is called previously with pattern, or assign is called
+     * @param pattern Pattern to subscribe to
+     * @throws IllegalArgumentException If pattern is null
+     * @throws IllegalStateException If {@code subscribe()} is called previously with topics, or assign is called
      *                               previously (without a subsequent call to {@link #unsubscribe()}), or if not
      *                               configured at-least one partition assignment strategy
      */
     @Override
-    public void subscribe(Collection<String> topics) {
-        subscribe(topics, new NoOpConsumerRebalanceListener());
+    public void subscribe(Pattern pattern) {
+        subscribe(pattern, Optional.empty());
     }
 
     /**
+     * Internal helper method for {@link #subscribe(Pattern)} and
+     * {@link #subscribe(Pattern, ConsumerRebalanceListener)}
+     * <p>
      * Subscribe to all topics matching specified pattern to get dynamically assigned partitions.
      * The pattern matching will be done periodically against all topics existing at the time of check.
      * This can be controlled through the {@code metadata.max.age.ms} configuration: by lowering
@@ -976,15 +1053,14 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
      * Group rebalances only take place during an active call to {@link #poll(Duration)}.
      *
      * @param pattern Pattern to subscribe to
-     * @param listener Non-null listener instance to get notifications on partition assignment/revocation for the
-     *                 subscribed topics
+     * @param listener {@link Optional} listener instance to get notifications on partition assignment/revocation
+     *                 for the subscribed topics
      * @throws IllegalArgumentException If pattern or listener is null
      * @throws IllegalStateException If {@code subscribe()} is called previously with topics, or assign is called
      *                               previously (without a subsequent call to {@link #unsubscribe()}), or if not
      *                               configured at-least one partition assignment strategy
      */
-    @Override
-    public void subscribe(Pattern pattern, ConsumerRebalanceListener listener) {
+    private void subscribe(Pattern pattern, Optional<ConsumerRebalanceListener> listener) {
         maybeThrowInvalidGroupIdException();
         if (pattern == null || pattern.toString().equals(""))
             throw new IllegalArgumentException("Topic pattern to subscribe to cannot be " + (pattern == null ?
@@ -1002,27 +1078,6 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
         }
     }
 
-    /**
-     * Subscribe to all topics matching specified pattern to get dynamically assigned partitions.
-     * The pattern matching will be done periodically against topics existing at the time of check.
-     * <p>
-     * This is a short-hand for {@link #subscribe(Pattern, ConsumerRebalanceListener)}, which
-     * uses a no-op listener. If you need the ability to seek to particular offsets, you should prefer
-     * {@link #subscribe(Pattern, ConsumerRebalanceListener)}, since group rebalances will cause partition offsets
-     * to be reset. You should also provide your own listener if you are doing your own offset
-     * management since the listener gives you an opportunity to commit offsets before a rebalance finishes.
-     *
-     * @param pattern Pattern to subscribe to
-     * @throws IllegalArgumentException If pattern is null
-     * @throws IllegalStateException If {@code subscribe()} is called previously with topics, or assign is called
-     *                               previously (without a subsequent call to {@link #unsubscribe()}), or if not
-     *                               configured at-least one partition assignment strategy
-     */
-    @Override
-    public void subscribe(Pattern pattern) {
-        subscribe(pattern, new NoOpConsumerRebalanceListener());
-    }
-
     /**
      * Unsubscribe from topics currently subscribed with {@link #subscribe(Collection)} or {@link #subscribe(Pattern)}.
      * This also clears any partitions directly assigned through {@link #assign(Collection)}.
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
index 5f143c9a079..53c13e4b98b 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
@@ -17,7 +17,6 @@
 package org.apache.kafka.clients.consumer;
 
 import org.apache.kafka.clients.Metadata;
-import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.internals.SubscriptionState;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.Metric;
@@ -109,10 +108,10 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
 
         // rebalance callbacks
         if (!added.isEmpty()) {
-            this.subscriptions.rebalanceListener().onPartitionsAssigned(added);
+            this.subscriptions.rebalanceListener().ifPresent(crl -> crl.onPartitionsAssigned(added));
         }
         if (!removed.isEmpty()) {
-            this.subscriptions.rebalanceListener().onPartitionsRevoked(removed);
+            this.subscriptions.rebalanceListener().ifPresent(crl -> crl.onPartitionsRevoked(removed));
         }
     }
 
@@ -123,11 +122,37 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
 
     @Override
     public synchronized void subscribe(Collection<String> topics) {
-        subscribe(topics, new NoOpConsumerRebalanceListener());
+        subscribe(topics, Optional.empty());
     }
 
     @Override
     public synchronized void subscribe(Pattern pattern, final ConsumerRebalanceListener listener) {
+        if (listener == null)
+            throw new IllegalArgumentException("RebalanceListener cannot be null");
+
+        subscribe(pattern, Optional.of(listener));
+    }
+
+    @Override
+    public synchronized void subscribe(Pattern pattern) {
+        subscribe(pattern, Optional.empty());
+    }
+
+    @Override
+    public void subscribe(Collection<String> topics, final ConsumerRebalanceListener listener) {
+        if (listener == null)
+            throw new IllegalArgumentException("RebalanceListener cannot be null");
+
+        subscribe(topics, Optional.of(listener));
+    }
+
+    private synchronized void subscribe(Collection<String> topics, Optional<ConsumerRebalanceListener> listener) {
+        ensureNotClosed();
+        committed.clear();
+        this.subscriptions.subscribe(new HashSet<>(topics), listener);
+    }
+
+    private synchronized void subscribe(Pattern pattern, Optional<ConsumerRebalanceListener> listener) {
         ensureNotClosed();
         committed.clear();
         this.subscriptions.subscribe(pattern, listener);
@@ -149,18 +174,6 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
         subscriptions.assignFromSubscribed(assignedPartitions);
     }
 
-    @Override
-    public synchronized void subscribe(Pattern pattern) {
-        subscribe(pattern, new NoOpConsumerRebalanceListener());
-    }
-
-    @Override
-    public synchronized void subscribe(Collection<String> topics, final ConsumerRebalanceListener listener) {
-        ensureNotClosed();
-        committed.clear();
-        this.subscriptions.subscribe(new HashSet<>(topics), listener);
-    }
-
     @Override
     public synchronized void assign(Collection<TopicPartition> partitions) {
         ensureNotClosed();
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index bdcbfc39dfc..38dda759e43 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -25,7 +25,6 @@ import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Assignment;
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.GroupSubscription;
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.RebalanceProtocol;
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription;
-import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetCommitCallback;
 import org.apache.kafka.clients.consumer.RetriableCommitFailedException;
@@ -48,11 +47,7 @@ import org.apache.kafka.common.message.JoinGroupRequestData;
 import org.apache.kafka.common.message.JoinGroupResponseData;
 import org.apache.kafka.common.message.OffsetCommitRequestData;
 import org.apache.kafka.common.message.OffsetCommitResponseData;
-import org.apache.kafka.common.metrics.Measurable;
 import org.apache.kafka.common.metrics.Metrics;
-import org.apache.kafka.common.metrics.Sensor;
-import org.apache.kafka.common.metrics.stats.Avg;
-import org.apache.kafka.common.metrics.stats.Max;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.record.RecordBatch;
@@ -100,7 +95,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     private final Logger log;
     private final List<ConsumerPartitionAssignor> assignors;
     private final ConsumerMetadata metadata;
-    private final ConsumerCoordinatorMetrics sensors;
+    private final ConsumerCoordinatorMetrics coordinatorMetrics;
     private final SubscriptionState subscriptions;
     private final OffsetCommitCallback defaultOffsetCommitCallback;
     private final boolean autoCommitEnabled;
@@ -148,6 +143,8 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
     }
 
     private final RebalanceProtocol protocol;
+    // Wraps the logic for invoking the ConsumerRebalanceListener methods
+    private final ConsumerRebalanceListenerInvoker rebalanceListenerInvoker;
     // pending commit offset request in onJoinPrepare
     private RequestFuture<Void> autoCommitOffsetRequestFuture = null;
     // a timer for join prepare to know when to stop.
@@ -189,7 +186,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         this.autoCommitIntervalMs = autoCommitIntervalMs;
         this.assignors = assignors;
         this.completedOffsetCommits = new ConcurrentLinkedQueue<>();
-        this.sensors = new ConsumerCoordinatorMetrics(metrics, metricGrpPrefix);
+        this.coordinatorMetrics = new ConsumerCoordinatorMetrics(subscriptions, metrics, metricGrpPrefix);
         this.interceptors = interceptors;
         this.inFlightAsyncCommits = new AtomicInteger();
         this.pendingAsyncCommits = new AtomicInteger();
@@ -227,6 +224,12 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             protocol = null;
         }
 
+        this.rebalanceListenerInvoker = new ConsumerRebalanceListenerInvoker(
+            logContext,
+            subscriptions,
+            time,
+            coordinatorMetrics
+        );
         this.metadata.requestUpdate(true);
     }
 
@@ -321,71 +324,6 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         return null;
     }
 
-    private Exception invokePartitionsAssigned(final SortedSet<TopicPartition> assignedPartitions) {
-        log.info("Adding newly assigned partitions: {}", Utils.join(assignedPartitions, ", "));
-
-        ConsumerRebalanceListener listener = subscriptions.rebalanceListener();
-        try {
-            final long startMs = time.milliseconds();
-            listener.onPartitionsAssigned(assignedPartitions);
-            sensors.assignCallbackSensor.record(time.milliseconds() - startMs);
-        } catch (WakeupException | InterruptException e) {
-            throw e;
-        } catch (Exception e) {
-            log.error("User provided listener {} failed on invocation of onPartitionsAssigned for partitions {}",
-                listener.getClass().getName(), assignedPartitions, e);
-            return e;
-        }
-
-        return null;
-    }
-
-    private Exception invokePartitionsRevoked(final SortedSet<TopicPartition> revokedPartitions) {
-        log.info("Revoke previously assigned partitions {}", Utils.join(revokedPartitions, ", "));
-        Set<TopicPartition> revokePausedPartitions = subscriptions.pausedPartitions();
-        revokePausedPartitions.retainAll(revokedPartitions);
-        if (!revokePausedPartitions.isEmpty())
-            log.info("The pause flag in partitions [{}] will be removed due to revocation.", Utils.join(revokePausedPartitions, ", "));
-
-        ConsumerRebalanceListener listener = subscriptions.rebalanceListener();
-        try {
-            final long startMs = time.milliseconds();
-            listener.onPartitionsRevoked(revokedPartitions);
-            sensors.revokeCallbackSensor.record(time.milliseconds() - startMs);
-        } catch (WakeupException | InterruptException e) {
-            throw e;
-        } catch (Exception e) {
-            log.error("User provided listener {} failed on invocation of onPartitionsRevoked for partitions {}",
-                listener.getClass().getName(), revokedPartitions, e);
-            return e;
-        }
-
-        return null;
-    }
-
-    private Exception invokePartitionsLost(final SortedSet<TopicPartition> lostPartitions) {
-        log.info("Lost previously assigned partitions {}", Utils.join(lostPartitions, ", "));
-        Set<TopicPartition> lostPausedPartitions = subscriptions.pausedPartitions();
-        lostPausedPartitions.retainAll(lostPartitions);
-        if (!lostPausedPartitions.isEmpty())
-            log.info("The pause flag in partitions [{}] will be removed due to partition lost.", Utils.join(lostPausedPartitions, ", "));
-
-        ConsumerRebalanceListener listener = subscriptions.rebalanceListener();
-        try {
-            final long startMs = time.milliseconds();
-            listener.onPartitionsLost(lostPartitions);
-            sensors.loseCallbackSensor.record(time.milliseconds() - startMs);
-        } catch (WakeupException | InterruptException e) {
-            throw e;
-        } catch (Exception e) {
-            log.error("User provided listener {} failed on invocation of onPartitionsLost for partitions {}",
-                listener.getClass().getName(), lostPartitions, e);
-            return e;
-        }
-
-        return null;
-    }
-
     @Override
     protected void onJoinComplete(int generation,
                                   String memberId,
@@ -453,7 +391,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                 // Revoke partitions that were previously owned but no longer assigned;
                 // note that we should only change the assignment (or update the assignor's state)
                 // AFTER we've triggered  the revoke callback
-                firstException.compareAndSet(null, invokePartitionsRevoked(revokedPartitions));
+                firstException.compareAndSet(null, rebalanceListenerInvoker.invokePartitionsRevoked(revokedPartitions));
 
                 // If revoked any partitions, need to re-join the group afterwards
                 final String fullReason = String.format("need to revoke partitions %s as indicated " +
@@ -476,7 +414,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         subscriptions.assignFromSubscribed(assignedPartitions);
 
         // Add partitions that were not previously owned but are now assigned
-        firstException.compareAndSet(null, invokePartitionsAssigned(addedPartitions));
+        firstException.compareAndSet(null, rebalanceListenerInvoker.invokePartitionsAssigned(addedPartitions));
 
         if (firstException.get() != null) {
             if (firstException.get() instanceof KafkaException) {
@@ -831,7 +769,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             if (!revokedPartitions.isEmpty()) {
                 log.info("Giving away all assigned partitions as lost since generation/memberID has been reset," +
                     "indicating that consumer is in old state or no longer part of the group");
-                exception = invokePartitionsLost(revokedPartitions);
+                exception = rebalanceListenerInvoker.invokePartitionsLost(revokedPartitions);
 
                 subscriptions.assignFromSubscribed(Collections.emptySet());
             }
@@ -840,7 +778,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                 case EAGER:
                     // revoke all partitions
                     revokedPartitions.addAll(subscriptions.assignedPartitions());
-                    exception = invokePartitionsRevoked(revokedPartitions);
+                    exception = rebalanceListenerInvoker.invokePartitionsRevoked(revokedPartitions);
 
                     subscriptions.assignFromSubscribed(Collections.emptySet());
 
@@ -854,7 +792,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
                         .collect(Collectors.toSet()));
 
                     if (!revokedPartitions.isEmpty()) {
-                        exception = invokePartitionsRevoked(revokedPartitions);
+                        exception = rebalanceListenerInvoker.invokePartitionsRevoked(revokedPartitions);
 
                         ownedPartitions.removeAll(revokedPartitions);
                         subscriptions.assignFromSubscribed(ownedPartitions);
@@ -908,9 +846,9 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             if ((currentGeneration.generationId == Generation.NO_GENERATION.generationId ||
                 currentGeneration.memberId.equals(Generation.NO_GENERATION.memberId)) ||
                 rebalanceInProgress()) {
-                e = invokePartitionsLost(droppedPartitions);
+                e = rebalanceListenerInvoker.invokePartitionsLost(droppedPartitions);
             } else {
-                e = invokePartitionsRevoked(droppedPartitions);
+                e = rebalanceListenerInvoker.invokePartitionsRevoked(droppedPartitions);
             }
 
             subscriptions.assignFromSubscribed(Collections.emptySet());
@@ -1366,7 +1304,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
 
         @Override
         public void handle(OffsetCommitResponse commitResponse, RequestFuture<Void> future) {
-            sensors.commitSensor.record(response.requestLatencyMs());
+            coordinatorMetrics.commitSensor.record(response.requestLatencyMs());
             Set<String> unauthorizedTopics = new HashSet<>();
 
             for (OffsetCommitResponseData.OffsetCommitResponseTopic topic : commitResponse.data().topics()) {
@@ -1579,56 +1517,6 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         }
     }
 
-    private class ConsumerCoordinatorMetrics {
-        private final String metricGrpName;
-        private final Sensor commitSensor;
-        private final Sensor revokeCallbackSensor;
-        private final Sensor assignCallbackSensor;
-        private final Sensor loseCallbackSensor;
-
-        private ConsumerCoordinatorMetrics(Metrics metrics, String metricGrpPrefix) {
-            this.metricGrpName = metricGrpPrefix + "-coordinator-metrics";
-
-            this.commitSensor = metrics.sensor("commit-latency");
-            this.commitSensor.add(metrics.metricName("commit-latency-avg",
-                this.metricGrpName,
-                "The average time taken for a commit request"), new Avg());
-            this.commitSensor.add(metrics.metricName("commit-latency-max",
-                this.metricGrpName,
-                "The max time taken for a commit request"), new Max());
-            this.commitSensor.add(createMeter(metrics, metricGrpName, "commit", "commit calls"));
-
-            this.revokeCallbackSensor = metrics.sensor("partition-revoked-latency");
-            this.revokeCallbackSensor.add(metrics.metricName("partition-revoked-latency-avg",
-                this.metricGrpName,
-                "The average time taken for a partition-revoked rebalance listener callback"), new Avg());
-            this.revokeCallbackSensor.add(metrics.metricName("partition-revoked-latency-max",
-                this.metricGrpName,
-                "The max time taken for a partition-revoked rebalance listener callback"), new Max());
-
-            this.assignCallbackSensor = metrics.sensor("partition-assigned-latency");
-            this.assignCallbackSensor.add(metrics.metricName("partition-assigned-latency-avg",
-                this.metricGrpName,
-                "The average time taken for a partition-assigned rebalance listener callback"), new Avg());
-            this.assignCallbackSensor.add(metrics.metricName("partition-assigned-latency-max",
-                this.metricGrpName,
-                "The max time taken for a partition-assigned rebalance listener callback"), new Max());
-
-            this.loseCallbackSensor = metrics.sensor("partition-lost-latency");
-            this.loseCallbackSensor.add(metrics.metricName("partition-lost-latency-avg",
-                this.metricGrpName,
-                "The average time taken for a partition-lost rebalance listener callback"), new Avg());
-            this.loseCallbackSensor.add(metrics.metricName("partition-lost-latency-max",
-                this.metricGrpName,
-                "The max time taken for a partition-lost rebalance listener callback"), new Max());
-
-            Measurable numParts = (config, now) -> subscriptions.numAssignedPartitions();
-            metrics.addMetric(metrics.metricName("assigned-partitions",
-                this.metricGrpName,
-                "The number of partitions currently assigned to this consumer"), numParts);
-        }
-    }
-
     private static class MetadataSnapshot {
         private final int version;
         private final Map<String, List<PartitionRackInfo>> partitionsPerTopic;
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorMetrics.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorMetrics.java
new file mode 100644
index 00000000000..378aded216c
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorMetrics.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.common.metrics.Measurable;
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.metrics.stats.Avg;
+import org.apache.kafka.common.metrics.stats.Max;
+import org.apache.kafka.common.metrics.stats.Meter;
+import org.apache.kafka.common.metrics.stats.WindowedCount;
+
+/**
+ * Registers the consumer coordinator metrics: commit latency, rate and total, the latency of each
+ * rebalance listener callback (partition assigned / revoked / lost), and the number of assigned partitions.
+ * Every metric is created in the {@code <metricGrpPrefix>-coordinator-metrics} group.
+ */
+class ConsumerCoordinatorMetrics {
+
+    /** Commit request latency (avg/max) plus commit call rate and total. */
+    final Sensor commitSensor;
+    /** Latency (avg/max) of the partition-revoked rebalance listener callback. */
+    final Sensor revokeCallbackSensor;
+    /** Latency (avg/max) of the partition-assigned rebalance listener callback. */
+    final Sensor assignCallbackSensor;
+    /** Latency (avg/max) of the partition-lost rebalance listener callback. */
+    final Sensor loseCallbackSensor;
+
+    /**
+     * Creates all coordinator sensors and registers their metrics in {@code metrics}.
+     *
+     * @param subscriptions   read whenever the "assigned-partitions" metric is measured
+     * @param metrics         registry in which the sensors and metrics are created
+     * @param metricGrpPrefix prefix of the metric group name; "-coordinator-metrics" is appended to it
+     */
+    ConsumerCoordinatorMetrics(SubscriptionState subscriptions,
+                               Metrics metrics,
+                               String metricGrpPrefix) {
+        final String group = metricGrpPrefix + "-coordinator-metrics";
+
+        commitSensor = metrics.sensor("commit-latency");
+        addAvgAndMax(metrics, commitSensor, group,
+                "commit-latency-avg", "The average time taken for a commit request",
+                "commit-latency-max", "The max time taken for a commit request");
+        // The same sensor also feeds the commit call rate and running total.
+        final Meter commitMeter = new Meter(new WindowedCount(),
+                metrics.metricName("commit-rate", group, "The number of commit calls per second"),
+                metrics.metricName("commit-total", group, "The total number of commit calls"));
+        commitSensor.add(commitMeter);
+
+        revokeCallbackSensor = metrics.sensor("partition-revoked-latency");
+        addAvgAndMax(metrics, revokeCallbackSensor, group,
+                "partition-revoked-latency-avg",
+                "The average time taken for a partition-revoked rebalance listener callback",
+                "partition-revoked-latency-max",
+                "The max time taken for a partition-revoked rebalance listener callback");
+
+        assignCallbackSensor = metrics.sensor("partition-assigned-latency");
+        addAvgAndMax(metrics, assignCallbackSensor, group,
+                "partition-assigned-latency-avg",
+                "The average time taken for a partition-assigned rebalance listener callback",
+                "partition-assigned-latency-max",
+                "The max time taken for a partition-assigned rebalance listener callback");
+
+        loseCallbackSensor = metrics.sensor("partition-lost-latency");
+        addAvgAndMax(metrics, loseCallbackSensor, group,
+                "partition-lost-latency-avg",
+                "The average time taken for a partition-lost rebalance listener callback",
+                "partition-lost-latency-max",
+                "The max time taken for a partition-lost rebalance listener callback");
+
+        // Computed from the subscription state each time the metric is measured.
+        final Measurable assignedPartitions = (config, now) -> subscriptions.numAssignedPartitions();
+        metrics.addMetric(metrics.metricName("assigned-partitions",
+                group,
+                "The number of partitions currently assigned to this consumer"), assignedPartitions);
+    }
+
+    /**
+     * Adds an {@link Avg} stat and then a {@link Max} stat to {@code sensor}, both in the given group.
+     */
+    private static void addAvgAndMax(Metrics metrics, Sensor sensor, String group,
+                                     String avgName, String avgDescription,
+                                     String maxName, String maxDescription) {
+        sensor.add(metrics.metricName(avgName, group, avgDescription), new Avg());
+        sensor.add(metrics.metricName(maxName, group, maxDescription), new Max());
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerRebalanceListenerInvoker.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerRebalanceListenerInvoker.java
new file mode 100644
index 00000000000..d4527aa0b76
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerRebalanceListenerInvoker.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.WakeupException;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Utils;
+import org.slf4j.Logger;
+
+import java.util.Optional;
+import java.util.Set;
+import java.util.SortedSet;
+import java.util.function.Consumer;
+
+/**
+ * This class encapsulates the invocation of the callback methods defined in the {@link ConsumerRebalanceListener}
+ * interface. When consumer group partition assignment changes, these methods are invoked. This class wraps those
+ * callback calls with some logging, optional {@link Sensor} updates, etc.
+ *
+ * <p>The listener is looked up in {@link SubscriptionState#rebalanceListener()} on every invocation. When none is
+ * registered, only the informational log lines are emitted and no callback latency is recorded.
+ */
+class ConsumerRebalanceListenerInvoker {
+
+    private final Logger log;
+    private final SubscriptionState subscriptions;
+    private final Time time;
+    private final ConsumerCoordinatorMetrics coordinatorMetrics;
+
+    ConsumerRebalanceListenerInvoker(LogContext logContext,
+                                     SubscriptionState subscriptions,
+                                     Time time,
+                                     ConsumerCoordinatorMetrics coordinatorMetrics) {
+        this.log = logContext.logger(getClass());
+        this.subscriptions = subscriptions;
+        this.time = time;
+        this.coordinatorMetrics = coordinatorMetrics;
+    }
+
+    /**
+     * Logs the newly assigned partitions and, if a listener is registered, invokes
+     * {@link ConsumerRebalanceListener#onPartitionsAssigned(java.util.Collection)} with them.
+     *
+     * @param assignedPartitions the partitions newly assigned to this consumer
+     * @return the exception thrown by the listener, or {@code null} if it succeeded or no listener is registered
+     * @throws WakeupException    propagated from the listener instead of being returned
+     * @throws InterruptException propagated from the listener instead of being returned
+     */
+    Exception invokePartitionsAssigned(final SortedSet<TopicPartition> assignedPartitions) {
+        log.info("Adding newly assigned partitions: {}", Utils.join(assignedPartitions, ", "));
+
+        return invokeListener("onPartitionsAssigned", assignedPartitions, coordinatorMetrics.assignCallbackSensor,
+                listener -> listener.onPartitionsAssigned(assignedPartitions));
+    }
+
+    /**
+     * Logs the revoked partitions (and which of them are currently paused) and, if a listener is registered,
+     * invokes {@link ConsumerRebalanceListener#onPartitionsRevoked(java.util.Collection)} with them.
+     *
+     * @param revokedPartitions the partitions revoked from this consumer
+     * @return the exception thrown by the listener, or {@code null} if it succeeded or no listener is registered
+     * @throws WakeupException    propagated from the listener instead of being returned
+     * @throws InterruptException propagated from the listener instead of being returned
+     */
+    Exception invokePartitionsRevoked(final SortedSet<TopicPartition> revokedPartitions) {
+        log.info("Revoke previously assigned partitions {}", Utils.join(revokedPartitions, ", "));
+        Set<TopicPartition> revokePausedPartitions = subscriptions.pausedPartitions();
+        revokePausedPartitions.retainAll(revokedPartitions);
+        if (!revokePausedPartitions.isEmpty())
+            log.info("The pause flag in partitions [{}] will be removed due to revocation.", Utils.join(revokePausedPartitions, ", "));
+
+        return invokeListener("onPartitionsRevoked", revokedPartitions, coordinatorMetrics.revokeCallbackSensor,
+                listener -> listener.onPartitionsRevoked(revokedPartitions));
+    }
+
+    /**
+     * Logs the lost partitions (and which of them are currently paused) and, if a listener is registered,
+     * invokes {@link ConsumerRebalanceListener#onPartitionsLost(java.util.Collection)} with them.
+     *
+     * @param lostPartitions the partitions lost by this consumer
+     * @return the exception thrown by the listener, or {@code null} if it succeeded or no listener is registered
+     * @throws WakeupException    propagated from the listener instead of being returned
+     * @throws InterruptException propagated from the listener instead of being returned
+     */
+    Exception invokePartitionsLost(final SortedSet<TopicPartition> lostPartitions) {
+        log.info("Lost previously assigned partitions {}", Utils.join(lostPartitions, ", "));
+        Set<TopicPartition> lostPausedPartitions = subscriptions.pausedPartitions();
+        lostPausedPartitions.retainAll(lostPartitions);
+        if (!lostPausedPartitions.isEmpty())
+            log.info("The pause flag in partitions [{}] will be removed due to partition lost.", Utils.join(lostPausedPartitions, ", "));
+
+        return invokeListener("onPartitionsLost", lostPartitions, coordinatorMetrics.loseCallbackSensor,
+                listener -> listener.onPartitionsLost(lostPartitions));
+    }
+
+    /**
+     * Applies {@code callback} to the registered listener, if any, and records the callback's duration in
+     * {@code latencySensor} when it completes without throwing.
+     *
+     * @param callbackName  name of the listener method, used only in the error log
+     * @param partitions    partitions handed to the callback, used only in the error log
+     * @param latencySensor sensor that receives the callback duration in milliseconds
+     * @param callback      calls the appropriate {@link ConsumerRebalanceListener} method
+     * @return the exception thrown by the callback, or {@code null} if it succeeded or no listener is registered
+     */
+    private Exception invokeListener(final String callbackName,
+                                     final SortedSet<TopicPartition> partitions,
+                                     final Sensor latencySensor,
+                                     final Consumer<ConsumerRebalanceListener> callback) {
+        final Optional<ConsumerRebalanceListener> registered = subscriptions.rebalanceListener();
+        if (!registered.isPresent())
+            return null;
+
+        final ConsumerRebalanceListener listener = registered.get();
+        try {
+            final long startMs = time.milliseconds();
+            callback.accept(listener);
+            latencySensor.record(time.milliseconds() - startMs);
+        } catch (WakeupException | InterruptException e) {
+            // Wakeups and interrupts are signals for the caller, not listener failures to be returned.
+            throw e;
+        } catch (Exception e) {
+            // Name the user's listener class, not the Optional wrapper around it.
+            log.error("User provided listener {} failed on invocation of {} for partitions {}",
+                    listener.getClass().getName(), callbackName, partitions, e);
+            return e;
+        }
+
+        return null;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/NoOpConsumerRebalanceListener.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/NoOpConsumerRebalanceListener.java
deleted file mode 100644
index a3acc834713..00000000000
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/NoOpConsumerRebalanceListener.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.clients.consumer.internals;
-
-import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
-import org.apache.kafka.common.TopicPartition;
-
-import java.util.Collection;
-
-public class NoOpConsumerRebalanceListener implements ConsumerRebalanceListener {
-
-    @Override
-    public void onPartitionsAssigned(Collection<TopicPartition> partitions) {}
-
-    @Override
-    public void onPartitionsRevoked(Collection<TopicPartition> partitions) {}
-
-}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/PrototypeAsyncConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/PrototypeAsyncConsumer.java
index 949616daa85..6389eb416e8 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/PrototypeAsyncConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/PrototypeAsyncConsumer.java
@@ -783,42 +783,6 @@ public class PrototypeAsyncConsumer<K, V> implements Consumer<K, V> {
         return Collections.unmodifiableSet(subscriptions.subscription());
     }
 
-    @Override
-    public void subscribe(Collection<String> topics) {
-        subscribe(topics, new NoOpConsumerRebalanceListener());
-    }
-
-    @Override
-    public void subscribe(Collection<String> topics, ConsumerRebalanceListener callback) {
-        maybeThrowInvalidGroupIdException();
-        if (topics == null)
-            throw new IllegalArgumentException("Topic collection to subscribe to cannot be null");
-        if (topics.isEmpty()) {
-            // treat subscribing to empty topic list as the same as unsubscribing
-            unsubscribe();
-        } else {
-            for (String topic : topics) {
-                if (isBlank(topic))
-                    throw new IllegalArgumentException("Topic collection to subscribe to cannot contain null or empty topic");
-            }
-
-            throwIfNoAssignorsConfigured();
-
-            // Clear the buffered data which are not a part of newly assigned topics
-            final Set<TopicPartition> currentTopicPartitions = new HashSet<>();
-
-            for (TopicPartition tp : subscriptions.assignedPartitions()) {
-                if (topics.contains(tp.topic()))
-                    currentTopicPartitions.add(tp);
-            }
-
-            fetchBuffer.retainAll(currentTopicPartitions);
-            log.info("Subscribed to topic(s): {}", join(topics, ", "));
-            if (subscriptions.subscribe(new HashSet<>(topics), callback))
-                metadata.requestUpdateForNewTopics();
-        }
-    }
-
     @Override
     public void assign(Collection<TopicPartition> partitions) {
         if (partitions == null) {
@@ -858,20 +822,6 @@ public class PrototypeAsyncConsumer<K, V> implements Consumer<K, V> {
             applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());
     }
 
-    @Override
-    public void subscribe(Pattern pattern, ConsumerRebalanceListener listener) {
-        maybeThrowInvalidGroupIdException();
-        if (pattern == null || pattern.toString().isEmpty())
-            throw new IllegalArgumentException("Topic pattern to subscribe to cannot be " + (pattern == null ?
-                    "null" : "empty"));
-
-        throwIfNoAssignorsConfigured();
-        log.info("Subscribed to pattern: '{}'", pattern);
-        subscriptions.subscribe(pattern, listener);
-        updatePatternSubscription(metadata.fetch());
-        metadata.requestUpdateForNewTopics();
-    }
-
     /**
      * TODO: remove this when we implement the KIP-848 protocol.
      *
@@ -890,11 +840,6 @@ public class PrototypeAsyncConsumer<K, V> implements Consumer<K, V> {
             metadata.requestUpdateForNewTopics();
     }
 
-    @Override
-    public void subscribe(Pattern pattern) {
-        subscribe(pattern, new NoOpConsumerRebalanceListener());
-    }
-
     @Override
     public void unsubscribe() {
         fetchBuffer.retainAll(Collections.emptySet());
@@ -1081,4 +1026,100 @@ public class PrototypeAsyncConsumer<K, V> implements Consumer<K, V> {
         // logic
         return updateFetchPositions(timer);
     }
+
+    @Override
+    public void subscribe(Collection<String> topics) {
+        subscribeInternal(topics, Optional.empty());
+    }
+
+    @Override
+    public void subscribe(Collection<String> topics, ConsumerRebalanceListener listener) {
+        subscribeInternal(topics, Optional.of(requireRebalanceListener(listener)));
+    }
+
+    @Override
+    public void subscribe(Pattern pattern) {
+        subscribeInternal(pattern, Optional.empty());
+    }
+
+    @Override
+    public void subscribe(Pattern pattern, ConsumerRebalanceListener listener) {
+        subscribeInternal(pattern, Optional.of(requireRebalanceListener(listener)));
+    }
+
+    /**
+     * Rejects a {@code null} listener passed explicitly to a {@code subscribe} overload. Callers that want no
+     * listener use the overloads without one, which register {@link Optional#empty()} instead.
+     *
+     * @param listener the caller-supplied listener
+     * @return {@code listener}, unchanged
+     * @throws IllegalArgumentException if {@code listener} is {@code null}
+     */
+    private static ConsumerRebalanceListener requireRebalanceListener(ConsumerRebalanceListener listener) {
+        if (listener == null)
+            throw new IllegalArgumentException("RebalanceListener cannot be null");
+        return listener;
+    }
+
+    /**
+     * Shared implementation of the pattern-based {@code subscribe} overloads.
+     *
+     * @param pattern  the topic pattern; must be non-null with a non-empty string form
+     * @param listener the rebalance listener to register, or empty for none
+     * @throws IllegalArgumentException if {@code pattern} is {@code null} or empty
+     */
+    private void subscribeInternal(Pattern pattern, Optional<ConsumerRebalanceListener> listener) {
+        maybeThrowInvalidGroupIdException();
+        if (pattern == null || pattern.toString().isEmpty()) {
+            final String problem = pattern == null ? "null" : "empty";
+            throw new IllegalArgumentException("Topic pattern to subscribe to cannot be " + problem);
+        }
+
+        throwIfNoAssignorsConfigured();
+        log.info("Subscribed to pattern: '{}'", pattern);
+        subscriptions.subscribe(pattern, listener);
+        updatePatternSubscription(metadata.fetch());
+        metadata.requestUpdateForNewTopics();
+    }
+
+    /**
+     * Shared implementation of the topic-collection {@code subscribe} overloads. An empty collection is handled
+     * exactly like {@link #unsubscribe()}.
+     *
+     * @param topics   the topics to subscribe to; must be non-null and contain no null or blank names
+     * @param listener the rebalance listener to register, or empty for none
+     * @throws IllegalArgumentException if {@code topics} is {@code null} or contains a null or blank topic
+     */
+    private void subscribeInternal(Collection<String> topics, Optional<ConsumerRebalanceListener> listener) {
+        maybeThrowInvalidGroupIdException();
+        if (topics == null)
+            throw new IllegalArgumentException("Topic collection to subscribe to cannot be null");
+
+        if (topics.isEmpty()) {
+            // treat subscribing to empty topic list as the same as unsubscribing
+            unsubscribe();
+            return;
+        }
+
+        for (String topic : topics) {
+            if (isBlank(topic))
+                throw new IllegalArgumentException("Topic collection to subscribe to cannot contain null or empty topic");
+        }
+
+        throwIfNoAssignorsConfigured();
+
+        // Keep buffered data only for assigned partitions whose topic is still part of the subscription.
+        final Set<TopicPartition> retainedPartitions = new HashSet<>();
+        for (TopicPartition partition : subscriptions.assignedPartitions()) {
+            if (topics.contains(partition.topic()))
+                retainedPartitions.add(partition);
+        }
+        fetchBuffer.retainAll(retainedPartitions);
+
+        log.info("Subscribed to topic(s): {}", join(topics, ", "));
+        final boolean subscriptionChanged = subscriptions.subscribe(new HashSet<>(topics), listener);
+        if (subscriptionChanged)
+            metadata.requestUpdateForNewTopics();
+    }
+
 }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
index fdf89944d68..edac65fcd41 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
@@ -99,7 +99,8 @@ public class SubscriptionState {
     private final OffsetResetStrategy defaultResetStrategy;
 
     /* User-provided listener to be invoked when assignment changes */
-    private ConsumerRebalanceListener rebalanceListener;
+    // Initialized to empty so callers of rebalanceListener() never see null before the first subscribe().
+    private Optional<ConsumerRebalanceListener> rebalanceListener = Optional.empty();
 
     private int assignmentId = 0;
 
@@ -162,13 +163,13 @@ public class SubscriptionState {
             throw new IllegalStateException(SUBSCRIPTION_EXCEPTION_MESSAGE);
     }
 
-    public synchronized boolean subscribe(Set<String> topics, ConsumerRebalanceListener listener) {
+    public synchronized boolean subscribe(Set<String> topics, Optional<ConsumerRebalanceListener> listener) {
         registerRebalanceListener(listener);
         setSubscriptionType(SubscriptionType.AUTO_TOPICS);
         return changeSubscription(topics);
     }
 
-    public synchronized void subscribe(Pattern pattern, ConsumerRebalanceListener listener) {
+    public synchronized void subscribe(Pattern pattern, Optional<ConsumerRebalanceListener> listener) {
         registerRebalanceListener(listener);
         setSubscriptionType(SubscriptionType.AUTO_PATTERN);
         this.subscribedPattern = pattern;
@@ -285,10 +286,14 @@ public class SubscriptionState {
         this.assignment.set(assignedPartitionStates);
     }
 
-    private void registerRebalanceListener(ConsumerRebalanceListener listener) {
-        if (listener == null)
-            throw new IllegalArgumentException("RebalanceListener cannot be null");
-        this.rebalanceListener = listener;
+    /**
+     * Replaces the registered rebalance listener.
+     *
+     * @param listener the listener to register, or {@link Optional#empty()} for none; must not itself be null
+     * @throws NullPointerException if {@code listener} is null
+     */
+    private void registerRebalanceListener(Optional<ConsumerRebalanceListener> listener) {
+        this.rebalanceListener = Objects.requireNonNull(listener, "RebalanceListener cannot be null");
     }
 
     /**
@@ -764,7 +769,12 @@ public class SubscriptionState {
         assignment.moveToEnd(tp);
     }
 
-    public synchronized ConsumerRebalanceListener rebalanceListener() {
+    /**
+     * Returns the currently registered rebalance listener.
+     *
+     * @return the listener, or {@link Optional#empty()} if none is registered
+     */
+    public synchronized Optional<ConsumerRebalanceListener> rebalanceListener() {
         return rebalanceListener;
     }
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
index cc0b1294fc2..40995b7f4f4 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
@@ -504,7 +504,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testGroupReadUnauthorized() {
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
@@ -728,7 +728,7 @@ public abstract class ConsumerCoordinatorTest {
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // illegal_generation will cause re-partition
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         subscriptions.assignFromSubscribed(Collections.singletonList(t1p));
 
         time.sleep(sessionTimeoutMs);
@@ -756,7 +756,7 @@ public abstract class ConsumerCoordinatorTest {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         ByteBuffer buffer = ConsumerProtocol.serializeAssignment(
             new ConsumerPartitionAssignor.Assignment(Collections.singletonList(t1p), ByteBuffer.wrap(new byte[0])));
         coordinator.onJoinComplete(1, "memberId", partitionAssignor.name(), buffer);
@@ -871,7 +871,7 @@ public abstract class ConsumerCoordinatorTest {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         ByteBuffer buffer = ConsumerProtocol.serializeAssignment(
             new ConsumerPartitionAssignor.Assignment(Collections.singletonList(t1p), ByteBuffer.wrap(new byte[0])));
         subscriptions.assignFromSubscribed(singleton(t2p));
@@ -898,7 +898,7 @@ public abstract class ConsumerCoordinatorTest {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         subscriptions.assignFromSubscribed(Collections.singletonList(t1p));
 
         coordinator.onLeavePrepare();
@@ -912,7 +912,7 @@ public abstract class ConsumerCoordinatorTest {
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
         // illegal_generation will cause re-partition
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         subscriptions.assignFromSubscribed(Collections.singletonList(t1p));
 
         time.sleep(sessionTimeoutMs);
@@ -960,7 +960,7 @@ public abstract class ConsumerCoordinatorTest {
     public void testJoinGroupInvalidGroupId() {
         final String consumerId = "leader";
 
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         // ensure metadata is up-to-date for leader
         client.updateMetadata(metadataResponse);
@@ -980,7 +980,7 @@ public abstract class ConsumerCoordinatorTest {
         final List<TopicPartition> owned = Collections.emptyList();
         final List<TopicPartition> assigned = Arrays.asList(t1p);
 
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         // ensure metadata is up-to-date for leader
         client.updateMetadata(metadataResponse);
@@ -1019,7 +1019,7 @@ public abstract class ConsumerCoordinatorTest {
         final List<String> newSubscription = singletonList(topic1);
         final List<TopicPartition> newAssignment = Arrays.asList(t1p);
 
-        subscriptions.subscribe(toSet(oldSubscription), rebalanceListener);
+        subscriptions.subscribe(toSet(oldSubscription), Optional.of(rebalanceListener));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
@@ -1053,7 +1053,7 @@ public abstract class ConsumerCoordinatorTest {
         coordinator.poll(time.timer(0));
 
         // Before the sync group response gets completed change the subscription
-        subscriptions.subscribe(toSet(newSubscription), rebalanceListener);
+        subscriptions.subscribe(toSet(newSubscription), Optional.of(rebalanceListener));
         coordinator.poll(time.timer(0));
 
         coordinator.poll(time.timer(Long.MAX_VALUE));
@@ -1076,7 +1076,7 @@ public abstract class ConsumerCoordinatorTest {
         final List<String> newSubscription = singletonList(topic2);
         final List<TopicPartition> newAssignment = Collections.singletonList(t2p);
 
-        subscriptions.subscribe(toSet(oldSubscription), rebalanceListener);
+        subscriptions.subscribe(toSet(oldSubscription), Optional.of(rebalanceListener));
         assertEquals(toSet(oldSubscription), subscriptions.metadataTopics());
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
@@ -1087,7 +1087,7 @@ public abstract class ConsumerCoordinatorTest {
         coordinator.poll(time.timer(0));
         assertEquals(toSet(oldSubscription), subscriptions.metadataTopics());
 
-        subscriptions.subscribe(toSet(newSubscription), rebalanceListener);
+        subscriptions.subscribe(toSet(newSubscription), Optional.of(rebalanceListener));
         assertEquals(Utils.mkSet(topic1, topic2), subscriptions.metadataTopics());
 
         prepareJoinAndSyncResponse(consumerId, 2, newSubscription, newAssignment);
@@ -1103,7 +1103,7 @@ public abstract class ConsumerCoordinatorTest {
         final List<TopicPartition> assigned = Arrays.asList(t1p, t2p);
         final List<TopicPartition> owned = Collections.emptyList();
 
-        subscriptions.subscribe(Pattern.compile("test.*"), rebalanceListener);
+        subscriptions.subscribe(Pattern.compile("test.*"), Optional.of(rebalanceListener));
 
         // partially update the metadata with one topic first,
         // let the leader to refresh metadata during assignment
@@ -1144,7 +1144,7 @@ public abstract class ConsumerCoordinatorTest {
         final String consumerId = "leader";
         final List<TopicPartition> owned = Collections.emptyList();
         final List<TopicPartition> oldAssigned = singletonList(t1p);
-        subscriptions.subscribe(Pattern.compile(".*"), rebalanceListener);
+        subscriptions.subscribe(Pattern.compile(".*"), Optional.of(rebalanceListener));
         client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1)));
         coordinator.maybeUpdateSubscriptionMetadata();
 
@@ -1254,7 +1254,7 @@ public abstract class ConsumerCoordinatorTest {
     public void testForceMetadataRefreshForPatternSubscriptionDuringRebalance() {
         // Set up a non-leader consumer with pattern subscription and a cluster containing one topic matching the
         // pattern.
-        subscriptions.subscribe(Pattern.compile(".*"), rebalanceListener);
+        subscriptions.subscribe(Pattern.compile(".*"), Optional.of(rebalanceListener));
         client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1)));
         coordinator.maybeUpdateSubscriptionMetadata();
         assertEquals(singleton(topic1), subscriptions.subscription());
@@ -1292,7 +1292,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testForceMetadataDeleteForPatternSubscriptionDuringRebalance() {
         try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions)) {
-            subscriptions.subscribe(Pattern.compile("test.*"), rebalanceListener);
+            subscriptions.subscribe(Pattern.compile("test.*"), Optional.of(rebalanceListener));
             client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, new HashMap<String, Integer>() {
                 {
                     put(topic1, 1);
@@ -1580,7 +1580,7 @@ public abstract class ConsumerCoordinatorTest {
         final List<TopicPartition> partitions = metadataResponse1.topicMetadata().stream()
                 .flatMap(t -> t.partitionMetadata().stream().map(p -> new TopicPartition(t.topic(), p.partition())))
                 .collect(Collectors.toList());
-        subscriptions.subscribe(toSet(topics), rebalanceListener);
+        subscriptions.subscribe(toSet(topics), Optional.of(rebalanceListener));
         client.updateMetadata(metadataResponse1);
         coordinator.maybeUpdateSubscriptionMetadata();
 
@@ -1652,7 +1652,7 @@ public abstract class ConsumerCoordinatorTest {
         final List<TopicPartition> owned = Collections.emptyList();
         final List<TopicPartition> assigned = singletonList(t1p);
 
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         // ensure metadata is up-to-date for leader
         client.updateMetadata(metadataResponse);
@@ -1691,7 +1691,7 @@ public abstract class ConsumerCoordinatorTest {
         final List<TopicPartition> owned = Collections.emptyList();
         final List<TopicPartition> assigned = singletonList(t1p);
 
-        subscriptions.subscribe(subscription, rebalanceListener);
+        subscriptions.subscribe(subscription, Optional.of(rebalanceListener));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
@@ -1720,7 +1720,7 @@ public abstract class ConsumerCoordinatorTest {
     public void testUpdateLastHeartbeatPollWhenCoordinatorUnknown() throws Exception {
         // If we are part of an active group and we cannot find the coordinator, we should nevertheless
         // continue to update the last poll time so that we do not expire the consumer
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
@@ -1751,7 +1751,7 @@ public abstract class ConsumerCoordinatorTest {
         final List<TopicPartition> owned = Collections.emptyList();
         final List<TopicPartition> assigned = Arrays.asList(t1p, t2p);
 
-        subscriptions.subscribe(Pattern.compile("test.*"), rebalanceListener);
+        subscriptions.subscribe(Pattern.compile("test.*"), Optional.of(rebalanceListener));
 
         // partially update the metadata with one topic first,
         // let the leader to refresh metadata during assignment
@@ -1785,7 +1785,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testLeaveGroupOnClose() {
 
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p));
 
         final AtomicBoolean received = new AtomicBoolean(false);
@@ -1801,7 +1801,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testMaybeLeaveGroup() {
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p));
 
         final AtomicBoolean received = new AtomicBoolean(false);
@@ -1836,7 +1836,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testPendingMemberShouldLeaveGroup() {
         final String consumerId = "consumer-id";
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
@@ -1860,7 +1860,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testUnexpectedErrorOnSyncGroup() {
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
@@ -1873,7 +1873,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testUnknownMemberIdOnSyncGroup() {
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
@@ -1897,7 +1897,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testRebalanceInProgressOnSyncGroup() {
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
@@ -1918,7 +1918,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testIllegalGenerationOnSyncGroup() {
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
@@ -1944,7 +1944,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testMetadataChangeTriggersRebalance() {
         // ensure metadata is up-to-date for leader
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         client.updateMetadata(metadataResponse);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
@@ -1972,7 +1972,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testStaticLeaderRejoinsGroupAndCanTriggersRebalance() {
         // ensure metadata is up-to-date for leader
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         client.updateMetadata(metadataResponse);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
@@ -2001,7 +2001,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testStaticLeaderRejoinsGroupAndCanDetectMetadataChangesForOtherMembers() {
         // ensure metadata is up-to-date for leader
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         client.updateMetadata(metadataResponse);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
@@ -2040,7 +2040,7 @@ public abstract class ConsumerCoordinatorTest {
 
         List<String> topics = Arrays.asList(topic1, topic2);
 
-        subscriptions.subscribe(new HashSet<>(topics), rebalanceListener);
+        subscriptions.subscribe(new HashSet<>(topics), Optional.of(rebalanceListener));
 
         // we only have metadata for one topic initially
         client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1)));
@@ -2086,7 +2086,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testSubscriptionChangeWithAuthorizationFailure() {
         // Subscribe to two topics of which only one is authorized and verify that metadata failure is propagated.
-        subscriptions.subscribe(Utils.mkSet(topic1, topic2), rebalanceListener);
+        subscriptions.subscribe(Utils.mkSet(topic1, topic2), Optional.of(rebalanceListener));
         client.prepareMetadataUpdate(RequestTestUtils.metadataUpdateWith("kafka-cluster", 1,
                 Collections.singletonMap(topic2, Errors.TOPIC_AUTHORIZATION_FAILED), singletonMap(topic1, 1)));
         assertThrows(TopicAuthorizationException.class, () -> coordinator.poll(time.timer(Long.MAX_VALUE)));
@@ -2101,7 +2101,7 @@ public abstract class ConsumerCoordinatorTest {
 
         // Change subscription to include only the authorized topic. Complete rebalance and check that
         // references to topic2 have been removed from SubscriptionState.
-        subscriptions.subscribe(Utils.mkSet(topic1), rebalanceListener);
+        subscriptions.subscribe(Utils.mkSet(topic1), Optional.of(rebalanceListener));
         assertEquals(Collections.singleton(topic1), subscriptions.metadataTopics());
         client.prepareMetadataUpdate(RequestTestUtils.metadataUpdateWith("kafka-cluster", 1,
                 Collections.emptyMap(), singletonMap(topic1, 1)));
@@ -2133,7 +2133,7 @@ public abstract class ConsumerCoordinatorTest {
             }
         };
 
-        subscriptions.subscribe(topics, rebalanceListener);
+        subscriptions.subscribe(topics, Optional.of(rebalanceListener));
 
         // we only have metadata for one topic initially
         client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1)));
@@ -2179,9 +2179,9 @@ public abstract class ConsumerCoordinatorTest {
 
     private void unavailableTopicTest(boolean patternSubscribe, Set<String> unavailableTopicsInLastMetadata) {
         if (patternSubscribe)
-            subscriptions.subscribe(Pattern.compile("test.*"), rebalanceListener);
+            subscriptions.subscribe(Pattern.compile("test.*"), Optional.of(rebalanceListener));
         else
-            subscriptions.subscribe(singleton(topic1), rebalanceListener);
+            subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         client.prepareMetadataUpdate(RequestTestUtils.metadataUpdateWith("kafka-cluster", 1,
                 Collections.singletonMap(topic1, Errors.UNKNOWN_TOPIC_OR_PARTITION), Collections.emptyMap()));
@@ -2232,7 +2232,7 @@ public abstract class ConsumerCoordinatorTest {
                 false, subscriptions, new LogContext(), new ClusterResourceListeners());
         client = new MockClient(time, metadata);
         try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, false, subscriptions)) {
-            subscriptions.subscribe(Pattern.compile(".*"), rebalanceListener);
+            subscriptions.subscribe(Pattern.compile(".*"), Optional.of(rebalanceListener));
             Node node = new Node(0, "localhost", 9999);
             MetadataResponse.PartitionMetadata partitionMetadata =
                 new MetadataResponse.PartitionMetadata(Errors.NONE, new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0),
@@ -2255,7 +2255,7 @@ public abstract class ConsumerCoordinatorTest {
         final List<TopicPartition> owned = Collections.emptyList();
         final List<TopicPartition> assigned = Arrays.asList(t1p);
 
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         // join the group once
         joinAsFollowerAndReceiveAssignment(coordinator, assigned);
@@ -2268,7 +2268,7 @@ public abstract class ConsumerCoordinatorTest {
         // and join the group again
         rebalanceListener.revoked = null;
         rebalanceListener.assigned = null;
-        subscriptions.subscribe(new HashSet<>(Arrays.asList(topic1, otherTopic)), rebalanceListener);
+        subscriptions.subscribe(new HashSet<>(Arrays.asList(topic1, otherTopic)), Optional.of(rebalanceListener));
         client.prepareResponse(joinGroupFollowerResponse(2, consumerId, "leader", Errors.NONE));
         client.prepareResponse(syncGroupResponse(assigned, Errors.NONE));
         coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
@@ -2283,7 +2283,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testDisconnectInJoin() {
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         final List<TopicPartition> owned = Collections.emptyList();
         final List<TopicPartition> assigned = Arrays.asList(t1p);
 
@@ -2308,7 +2308,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testInvalidSessionTimeout() {
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
@@ -2370,7 +2370,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testAutoCommitDynamicAssignment() {
         try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions)) {
-            subscriptions.subscribe(singleton(topic1), rebalanceListener);
+            subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
             joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p));
             subscriptions.seek(t1p, 100);
             prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE);
@@ -2383,7 +2383,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testAutoCommitRetryBackoff() {
         try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions)) {
-            subscriptions.subscribe(singleton(topic1), rebalanceListener);
+            subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
             joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p));
 
             subscriptions.seek(t1p, 100);
@@ -2416,7 +2416,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testAutoCommitAwaitsInterval() {
         try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions)) {
-            subscriptions.subscribe(singleton(topic1), rebalanceListener);
+            subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
             joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p));
 
             subscriptions.seek(t1p, 100);
@@ -2453,7 +2453,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testAutoCommitDynamicAssignmentRebalance() {
         try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions)) {
-            subscriptions.subscribe(singleton(topic1), rebalanceListener);
+            subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
             client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
             coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
@@ -2562,7 +2562,7 @@ public abstract class ConsumerCoordinatorTest {
     @Test
     public void testCommitAfterLeaveGroup() {
         // enable auto-assignment
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p));
 
@@ -2802,7 +2802,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCommitOffsetIllegalGenerationShouldResetGenerationId() {
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
@@ -2905,7 +2905,7 @@ public abstract class ConsumerCoordinatorTest {
 
     @Test
     public void testCommitOffsetUnknownMemberShouldResetToNoGeneration() {
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
 
@@ -3013,7 +3013,7 @@ public abstract class ConsumerCoordinatorTest {
         // we cannot retry if a rebalance occurs before the commit completed
         final String consumerId = "leader";
 
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
 
         // ensure metadata is up-to-date for leader
         client.updateMetadata(metadataResponse);
@@ -3508,7 +3508,7 @@ public abstract class ConsumerCoordinatorTest {
             }
         };
 
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         {
             ByteBuffer buffer = ConsumerProtocol.serializeAssignment(
                 new ConsumerPartitionAssignor.Assignment(Collections.singletonList(t1p), ByteBuffer.wrap(new byte[0])));
@@ -3660,7 +3660,7 @@ public abstract class ConsumerCoordinatorTest {
         RackAwareAssignor assignor = new RackAwareAssignor(protocol);
         createRackAwareCoordinator(rackId, assignor);
 
-        subscriptions.subscribe(singleton(topic1), rebalanceListener);
+        subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         client.updateMetadata(metadataResponse);
 
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
@@ -3756,7 +3756,7 @@ public abstract class ConsumerCoordinatorTest {
         client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
         coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
         if (useGroupManagement) {
-            subscriptions.subscribe(singleton(topic1), rebalanceListener);
+            subscriptions.subscribe(singleton(topic1), Optional.of(rebalanceListener));
             client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE));
             client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE));
             coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE));
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java
index 4184c459bd2..a6864a5c7b3 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java
@@ -66,7 +66,7 @@ public class ConsumerMetadataTest {
     }
 
     private void testPatternSubscription(boolean includeInternalTopics) {
-        subscription.subscribe(Pattern.compile("__.*"), new NoOpConsumerRebalanceListener());
+        subscription.subscribe(Pattern.compile("__.*"), Optional.empty());
         ConsumerMetadata metadata = newConsumerMetadata(includeInternalTopics);
 
         MetadataRequest.Builder builder = metadata.newMetadataRequestBuilder();
@@ -103,7 +103,7 @@ public class ConsumerMetadataTest {
 
     @Test
     public void testNormalSubscription() {
-        subscription.subscribe(Utils.mkSet("foo", "bar", "__consumer_offsets"), new NoOpConsumerRebalanceListener());
+        subscription.subscribe(Utils.mkSet("foo", "bar", "__consumer_offsets"), Optional.empty());
         subscription.groupSubscribe(Utils.mkSet("baz", "foo", "bar", "__consumer_offsets"));
         testBasicSubscription(Utils.mkSet("foo", "bar", "baz"), Utils.mkSet("__consumer_offsets"));
 
@@ -115,7 +115,7 @@ public class ConsumerMetadataTest {
     public void testTransientTopics() {
         Map<String, Uuid> topicIds = new HashMap<>();
         topicIds.put("foo", Uuid.randomUuid());
-        subscription.subscribe(singleton("foo"), new NoOpConsumerRebalanceListener());
+        subscription.subscribe(singleton("foo"), Optional.empty());
         ConsumerMetadata metadata = newConsumerMetadata(false);
         metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds(1, singletonMap("foo", 1), topicIds), false, time.milliseconds());
         assertEquals(topicIds.get("foo"), metadata.topicIds().get("foo"));
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java
index c4270c3f4ca..88de3355472 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetchRequestManagerTest.java
@@ -25,7 +25,6 @@ import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.clients.NodeApiVersions;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
-import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.OffsetOutOfRangeException;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
@@ -150,7 +149,6 @@ public class FetchRequestManagerTest {
 
     private static final double EPSILON = 0.0001;
 
-    private ConsumerRebalanceListener listener = new NoOpConsumerRebalanceListener();
     private String topicName = "test";
     private String groupId = "test-group";
     private Uuid topicId = Uuid.randomUuid();
@@ -1270,7 +1268,7 @@ public class FetchRequestManagerTest {
     public void testFetchDuringEagerRebalance() {
         buildFetcher();
 
-        subscriptions.subscribe(singleton(topicName), listener);
+        subscriptions.subscribe(singleton(topicName), Optional.empty());
         subscriptions.assignFromSubscribed(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
@@ -1294,7 +1292,7 @@ public class FetchRequestManagerTest {
     public void testFetchDuringCooperativeRebalance() {
         buildFetcher();
 
-        subscriptions.subscribe(singleton(topicName), listener);
+        subscriptions.subscribe(singleton(topicName), Optional.empty());
         subscriptions.assignFromSubscribed(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index a66d153f098..81551b95239 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -24,7 +24,6 @@ import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.clients.NodeApiVersions;
-import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.OffsetOutOfRangeException;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
@@ -147,7 +146,6 @@ import static org.mockito.Mockito.verify;
 public class FetcherTest {
     private static final double EPSILON = 0.0001;
 
-    private ConsumerRebalanceListener listener = new NoOpConsumerRebalanceListener();
     private String topicName = "test";
     private String groupId = "test-group";
     private Uuid topicId = Uuid.randomUuid();
@@ -1268,7 +1266,7 @@ public class FetcherTest {
     public void testFetchDuringEagerRebalance() {
         buildFetcher();
 
-        subscriptions.subscribe(singleton(topicName), listener);
+        subscriptions.subscribe(singleton(topicName), Optional.empty());
         subscriptions.assignFromSubscribed(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
@@ -1292,7 +1290,7 @@ public class FetcherTest {
     public void testFetchDuringCooperativeRebalance() {
         buildFetcher();
 
-        subscriptions.subscribe(singleton(topicName), listener);
+        subscriptions.subscribe(singleton(topicName), Optional.empty());
         subscriptions.assignFromSubscribed(singleton(tp0));
         subscriptions.seek(tp0, 0);
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatRequestManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatRequestManagerTest.java
index 480a6242e40..56c0ba1d50d 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatRequestManagerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatRequestManagerTest.java
@@ -210,7 +210,7 @@ public class HeartbeatRequestManagerTest {
         resetWithZeroHeartbeatInterval(Optional.of(DEFAULT_GROUP_INSTANCE_ID));
 
         List<String> subscribedTopics = Collections.singletonList("topic");
-        subscriptions.subscribe(new HashSet<>(subscribedTopics), new NoOpConsumerRebalanceListener());
+        subscriptions.subscribe(new HashSet<>(subscribedTopics), Optional.empty());
 
         // Update membershipManager's memberId and memberEpoch
         ConsumerGroupHeartbeatResponse result =
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
index 97c61616c75..7df06e3b3e9 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
@@ -91,7 +91,7 @@ public class SubscriptionStateTest {
         assertTrue(state.assignedPartitions().isEmpty());
         assertEquals(0, state.numAssignedPartitions());
 
-        state.subscribe(singleton(topic1), rebalanceListener);
+        state.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         // assigned partitions should remain unchanged
         assertTrue(state.assignedPartitions().isEmpty());
         assertEquals(0, state.numAssignedPartitions());
@@ -102,7 +102,7 @@ public class SubscriptionStateTest {
         assertEquals(singleton(t1p0), state.assignedPartitions());
         assertEquals(1, state.numAssignedPartitions());
 
-        state.subscribe(singleton(topic), rebalanceListener);
+        state.subscribe(singleton(topic), Optional.of(rebalanceListener));
         // assigned partitions should remain unchanged
         assertEquals(singleton(t1p0), state.assignedPartitions());
         assertEquals(1, state.numAssignedPartitions());
@@ -115,7 +115,7 @@ public class SubscriptionStateTest {
 
     @Test
     public void testGroupSubscribe() {
-        state.subscribe(singleton(topic1), rebalanceListener);
+        state.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         assertEquals(singleton(topic1), state.metadataTopics());
 
         assertFalse(state.groupSubscribe(singleton(topic1)));
@@ -128,7 +128,7 @@ public class SubscriptionStateTest {
         assertFalse(state.groupSubscribe(singleton(topic1)));
         assertEquals(singleton(topic1), state.metadataTopics());
 
-        state.subscribe(singleton("anotherTopic"), rebalanceListener);
+        state.subscribe(singleton("anotherTopic"), Optional.of(rebalanceListener));
         assertEquals(Utils.mkSet(topic1, "anotherTopic"), state.metadataTopics());
 
         assertFalse(state.groupSubscribe(singleton("anotherTopic")));
@@ -137,7 +137,7 @@ public class SubscriptionStateTest {
 
     @Test
     public void partitionAssignmentChangeOnPatternSubscription() {
-        state.subscribe(Pattern.compile(".*"), rebalanceListener);
+        state.subscribe(Pattern.compile(".*"), Optional.of(rebalanceListener));
         // assigned partitions should remain unchanged
         assertTrue(state.assignedPartitions().isEmpty());
         assertEquals(0, state.numAssignedPartitions());
@@ -163,7 +163,7 @@ public class SubscriptionStateTest {
         assertEquals(1, state.numAssignedPartitions());
         assertEquals(singleton(topic), state.subscription());
 
-        state.subscribe(Pattern.compile(".*t"), rebalanceListener);
+        state.subscribe(Pattern.compile(".*t"), Optional.of(rebalanceListener));
         // assigned partitions should remain unchanged
         assertEquals(singleton(t1p0), state.assignedPartitions());
         assertEquals(1, state.numAssignedPartitions());
@@ -200,7 +200,7 @@ public class SubscriptionStateTest {
         assertEquals(Collections.emptySet(), state.assignedPartitions());
 
         Set<TopicPartition> autoAssignment = Utils.mkSet(t1p0);
-        state.subscribe(singleton(topic1), rebalanceListener);
+        state.subscribe(singleton(topic1), Optional.of(rebalanceListener));
         assertTrue(state.checkAssignmentMatchedSubscription(autoAssignment));
         state.assignFromSubscribed(autoAssignment);
         assertEquals(3, state.assignmentId());
@@ -225,7 +225,7 @@ public class SubscriptionStateTest {
 
     @Test
     public void topicSubscription() {
-        state.subscribe(singleton(topic), rebalanceListener);
+        state.subscribe(singleton(topic), Optional.of(rebalanceListener));
         assertEquals(1, state.subscription().size());
         assertTrue(state.assignedPartitions().isEmpty());
         assertEquals(0, state.numAssignedPartitions());
@@ -268,7 +268,7 @@ public class SubscriptionStateTest {
 
     @Test
     public void invalidPositionUpdate() {
-        state.subscribe(singleton(topic), rebalanceListener);
+        state.subscribe(singleton(topic), Optional.of(rebalanceListener));
         assertTrue(state.checkAssignmentMatchedSubscription(singleton(tp0)));
         state.assignFromSubscribed(singleton(tp0));
 
@@ -278,13 +278,13 @@ public class SubscriptionStateTest {
 
     @Test
     public void cantAssignPartitionForUnsubscribedTopics() {
-        state.subscribe(singleton(topic), rebalanceListener);
+        state.subscribe(singleton(topic), Optional.of(rebalanceListener));
         assertFalse(state.checkAssignmentMatchedSubscription(Collections.singletonList(t1p0)));
     }
 
     @Test
     public void cantAssignPartitionForUnmatchedPattern() {
-        state.subscribe(Pattern.compile(".*t"), rebalanceListener);
+        state.subscribe(Pattern.compile(".*t"), Optional.of(rebalanceListener));
         state.subscribeFromPattern(Collections.singleton(topic));
         assertFalse(state.checkAssignmentMatchedSubscription(Collections.singletonList(t1p0)));
     }
@@ -297,31 +297,31 @@ public class SubscriptionStateTest {
 
     @Test
     public void cantSubscribeTopicAndPattern() {
-        state.subscribe(singleton(topic), rebalanceListener);
-        assertThrows(IllegalStateException.class, () -> state.subscribe(Pattern.compile(".*"), rebalanceListener));
+        state.subscribe(singleton(topic), Optional.of(rebalanceListener));
+        assertThrows(IllegalStateException.class, () -> state.subscribe(Pattern.compile(".*"), Optional.of(rebalanceListener)));
     }
 
     @Test
     public void cantSubscribePartitionAndPattern() {
         state.assignFromUser(singleton(tp0));
-        assertThrows(IllegalStateException.class, () -> state.subscribe(Pattern.compile(".*"), rebalanceListener));
+        assertThrows(IllegalStateException.class, () -> state.subscribe(Pattern.compile(".*"), Optional.of(rebalanceListener)));
     }
 
     @Test
     public void cantSubscribePatternAndTopic() {
-        state.subscribe(Pattern.compile(".*"), rebalanceListener);
-        assertThrows(IllegalStateException.class, () -> state.subscribe(singleton(topic), rebalanceListener));
+        state.subscribe(Pattern.compile(".*"), Optional.of(rebalanceListener));
+        assertThrows(IllegalStateException.class, () -> state.subscribe(singleton(topic), Optional.of(rebalanceListener)));
     }
 
     @Test
     public void cantSubscribePatternAndPartition() {
-        state.subscribe(Pattern.compile(".*"), rebalanceListener);
+        state.subscribe(Pattern.compile(".*"), Optional.of(rebalanceListener));
         assertThrows(IllegalStateException.class, () -> state.assignFromUser(singleton(tp0)));
     }
 
     @Test
     public void patternSubscription() {
-        state.subscribe(Pattern.compile(".*"), rebalanceListener);
+        state.subscribe(Pattern.compile(".*"), Optional.of(rebalanceListener));
         state.subscribeFromPattern(new HashSet<>(Arrays.asList(topic, topic1)));
         assertEquals(2, state.subscription().size(), "Expected subscribed topics count is incorrect");
     }
@@ -330,13 +330,13 @@ public class SubscriptionStateTest {
     public void unsubscribeUserAssignment() {
         state.assignFromUser(new HashSet<>(Arrays.asList(tp0, tp1)));
         state.unsubscribe();
-        state.subscribe(singleton(topic), rebalanceListener);
+        state.subscribe(singleton(topic), Optional.of(rebalanceListener));
         assertEquals(singleton(topic), state.subscription());
     }
 
     @Test
     public void unsubscribeUserSubscribe() {
-        state.subscribe(singleton(topic), rebalanceListener);
+        state.subscribe(singleton(topic), Optional.of(rebalanceListener));
         state.unsubscribe();
         state.assignFromUser(singleton(tp0));
         assertEquals(singleton(tp0), state.assignedPartitions());
@@ -345,7 +345,7 @@ public class SubscriptionStateTest {
 
     @Test
     public void unsubscription() {
-        state.subscribe(Pattern.compile(".*"), rebalanceListener);
+        state.subscribe(Pattern.compile(".*"), Optional.of(rebalanceListener));
         state.subscribeFromPattern(new HashSet<>(Arrays.asList(topic, topic1)));
         assertTrue(state.checkAssignmentMatchedSubscription(singleton(tp1)));
         state.assignFromSubscribed(singleton(tp1));
diff --git a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
index c7e7fb0ac0c..06f92edc94c 100644
--- a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
@@ -26,7 +26,6 @@ import kafka.utils.{TestInfoUtils, TestUtils}
 import kafka.utils.TestUtils.waitUntilTrue
 import org.apache.kafka.clients.admin.{Admin, AlterConfigOp, NewTopic}
 import org.apache.kafka.clients.consumer._
-import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener
 import org.apache.kafka.clients.producer._
 import org.apache.kafka.common.acl.AclOperation._
 import org.apache.kafka.common.acl.AclPermissionType.{ALLOW, DENY}
@@ -1174,7 +1173,7 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
     addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource)
 
     val consumer = createConsumer()
-    consumer.subscribe(Pattern.compile(topicPattern), new NoOpConsumerRebalanceListener)
+    consumer.subscribe(Pattern.compile(topicPattern))
     consumer.poll(0)
     assertTrue(consumer.subscription.isEmpty)
   }