Posted to commits@kafka.apache.org by rs...@apache.org on 2020/05/29 21:14:47 UTC

[kafka] branch 2.6 updated (50482bc -> dddd73f)

This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a change to branch 2.6
in repository https://gitbox.apache.org/repos/asf/kafka.git.


    from 50482bc  KAFKA-10061; Fix flaky `ReassignPartitionsIntegrationTest.testCancellation` (#8749)
     new 6fee20b  KAFKA-10029; Don't update completedReceives when channels are closed to avoid ConcurrentModificationException (#8705)
     new dddd73f  KAFKA-10056; Ensure consumer metadata contains new topics on subscription change (#8739)

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../consumer/internals/SubscriptionState.java      | 12 ++-
 .../apache/kafka/common/network/KafkaChannel.java  |  5 +-
 .../org/apache/kafka/common/network/Selector.java  | 35 +++++++--
 .../internals/ConsumerCoordinatorTest.java         | 42 +++++++++++
 .../consumer/internals/SubscriptionStateTest.java  |  6 ++
 .../apache/kafka/common/network/SelectorTest.java  | 85 +++++++++++++++++++++
 .../main/scala/kafka/network/SocketServer.scala    |  4 +-
 .../unit/kafka/network/SocketServerTest.scala      | 87 +++++++++++++++++-----
 8 files changed, 247 insertions(+), 29 deletions(-)


[kafka] 01/02: KAFKA-10029; Don't update completedReceives when channels are closed to avoid ConcurrentModificationException (#8705)

Posted by rs...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a commit to branch 2.6
in repository https://gitbox.apache.org/repos/asf/kafka.git

commit 6fee20b28e32d4919b583838d49f7c2c095bcca5
Author: Rajini Sivaram <ra...@googlemail.com>
AuthorDate: Fri May 29 09:32:57 2020 +0100

    KAFKA-10029; Don't update completedReceives when channels are closed to avoid ConcurrentModificationException (#8705)
    
    Reviewers: Ismael Juma <is...@juma.me.uk>, Chia-Ping Tsai <ch...@gmail.com>
---
 .../apache/kafka/common/network/KafkaChannel.java  |  5 +-
 .../org/apache/kafka/common/network/Selector.java  | 35 +++++++--
 .../apache/kafka/common/network/SelectorTest.java  | 85 +++++++++++++++++++++
 .../main/scala/kafka/network/SocketServer.scala    |  4 +-
 .../unit/kafka/network/SocketServerTest.scala      | 87 +++++++++++++++++-----
 5 files changed, 188 insertions(+), 28 deletions(-)
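
For context on this fix: `completedReceives` was previously keyed by KafkaChannel, and Selector.close() removed the closed channel's entry from the map. When SocketServer closed a channel while still iterating over the completed receives, that removal invalidated the iterator and threw ConcurrentModificationException. The sketch below (hypothetical stand-in types, not the Kafka classes) reproduces the failure mode and the shape of the fix: close() no longer mutates the map, and the caller drops the whole batch once it has been processed.

    import java.util.LinkedHashMap;
    import java.util.Map;

    // Hypothetical illustration; connection ids stand in for KafkaChannel,
    // strings stand in for NetworkReceive.
    public class CompletedReceivesSketch {
        public static void main(String[] args) {
            Map<String, String> completedReceives = new LinkedHashMap<>();
            completedReceives.put("conn-0", "request-0");
            completedReceives.put("conn-1", "request-1");
            try {
                for (Map.Entry<String, String> entry : completedReceives.entrySet()) {
                    // Pre-fix behaviour: closing a channel removed its entry
                    // while this iteration was still in flight.
                    completedReceives.remove(entry.getKey());
                }
            } catch (java.util.ConcurrentModificationException e) {
                System.out.println("old pattern fails: " + e);
            }

            // Post-fix behaviour: process the batch without mutating the map,
            // then clear it in one step (cf. Selector#clearCompletedReceives()).
            completedReceives.clear();
            completedReceives.put("conn-0", "request-0");
            completedReceives.put("conn-1", "request-1");
            for (Map.Entry<String, String> entry : completedReceives.entrySet())
                process(entry.getKey(), entry.getValue());
            completedReceives.clear();
        }

        private static void process(String connectionId, String request) {
            // stand-in for handing the request to the request channel
        }
    }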

diff --git a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
index 4e4edd4..0ed9ee0 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
@@ -28,7 +28,6 @@ import java.net.Socket;
 import java.net.SocketAddress;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.SocketChannel;
-import java.util.Objects;
 import java.util.Optional;
 import java.util.function.Supplier;
 
@@ -471,12 +470,12 @@ public class KafkaChannel implements AutoCloseable {
             return false;
         }
         KafkaChannel that = (KafkaChannel) o;
-        return Objects.equals(id, that.id);
+        return id.equals(that.id);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(id);
+        return id.hashCode();
     }
 
     @Override
diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
index cb91cad..06f7048 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
@@ -107,7 +107,7 @@ public class Selector implements Selectable, AutoCloseable {
     private final Set<KafkaChannel> explicitlyMutedChannels;
     private boolean outOfMemory;
     private final List<Send> completedSends;
-    private final LinkedHashMap<KafkaChannel, NetworkReceive> completedReceives;
+    private final LinkedHashMap<String, NetworkReceive> completedReceives;
     private final Set<SelectionKey> immediatelyConnectedKeys;
     private final Map<String, KafkaChannel> closingChannels;
     private Set<SelectionKey> keysWithBufferedRead;
@@ -804,7 +804,33 @@ public class Selector implements Selectable, AutoCloseable {
     }
 
     /**
-     * Clear the results from the prior poll
+     * Clears completed receives. This is used by SocketServer to remove references to
+     * receive buffers after processing completed receives, without waiting for the next
+     * poll().
+     */
+    public void clearCompletedReceives() {
+        this.completedReceives.clear();
+    }
+
+    /**
+     * Clears completed sends. This is used by SocketServer to remove references to
+     * send buffers after processing completed sends, without waiting for the next
+     * poll().
+     */
+    public void clearCompletedSends() {
+        this.completedSends.clear();
+    }
+
+    /**
+     * Clears all the results from the previous poll. This is invoked by Selector at the start of
+     * a poll() when all the results from the previous poll are expected to have been handled.
+     * <p>
+     * SocketServer uses {@link #clearCompletedSends()} and {@link #clearCompletedReceives()} to
+     * clear `completedSends` and `completedReceives` as soon as they are processed to avoid
+     * holding onto large request/response buffers from multiple connections longer than necessary.
+     * Clients rely on Selector invoking {@link #clear()} at the start of each poll() since memory usage
+     * is less critical and clearing once-per-poll provides the flexibility to process these results in
+     * any order before the next poll.
      */
     private void clear() {
         this.completedSends.clear();
@@ -935,7 +961,6 @@ public class Selector implements Selectable, AutoCloseable {
         }
 
         this.sensors.connectionClosed.record();
-        this.completedReceives.remove(channel);
         this.explicitlyMutedChannels.remove(channel);
         if (notifyDisconnect)
             this.disconnected.put(channel.id(), channel.state());
@@ -1015,7 +1040,7 @@ public class Selector implements Selectable, AutoCloseable {
      * Check if given channel has a completed receive
      */
     private boolean hasCompletedReceive(KafkaChannel channel) {
-        return completedReceives.containsKey(channel);
+        return completedReceives.containsKey(channel.id());
     }
 
     /**
@@ -1025,7 +1050,7 @@ public class Selector implements Selectable, AutoCloseable {
         if (hasCompletedReceive(channel))
             throw new IllegalStateException("Attempting to add second completed receive to channel " + channel.id());
 
-        this.completedReceives.put(channel, networkReceive);
+        this.completedReceives.put(channel.id(), networkReceive);
         sensors.recordCompletedReceive(channel.id(), networkReceive.size(), currentTimeMs);
     }
 
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
index 57b0153..ac773ee 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
@@ -48,6 +48,7 @@ import java.nio.channels.SocketChannel;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -341,6 +342,36 @@ public class SelectorTest {
         assertEquals("", blockingRequest(node, ""));
     }
 
+    @Test
+    public void testClearCompletedSendsAndReceives() throws Exception {
+        int bufferSize = 1024;
+        String node = "0";
+        InetSocketAddress addr = new InetSocketAddress("localhost", server.port);
+        connect(node, addr);
+        String request = TestUtils.randomString(bufferSize);
+        selector.send(createSend(node, request));
+        boolean sent = false;
+        boolean received = false;
+        while (!sent || !received) {
+            selector.poll(1000L);
+            assertEquals("No disconnects should have occurred.", 0, selector.disconnected().size());
+            if (!selector.completedSends().isEmpty()) {
+                assertEquals(1, selector.completedSends().size());
+                selector.clearCompletedSends();
+                assertEquals(0, selector.completedSends().size());
+                sent = true;
+            }
+
+            if (!selector.completedReceives().isEmpty()) {
+                assertEquals(1, selector.completedReceives().size());
+                assertEquals(request, asString(selector.completedReceives().iterator().next()));
+                selector.clearCompletedReceives();
+                assertEquals(0, selector.completedReceives().size());
+                received = true;
+            }
+        }
+    }
+
     @Test(expected = IllegalStateException.class)
     public void testExistingConnectionId() throws IOException {
         blockingConnect("0");
@@ -904,6 +935,60 @@ public class SelectorTest {
         assertEquals(asList(send), selector.completedSends());
     }
 
+    /**
+     * Ensure that no errors are thrown if channels are closed while processing multiple completed receives
+     */
+    @Test
+    public void testChannelCloseWhileProcessingReceives() throws Exception {
+        int numChannels = 4;
+        Map<String, KafkaChannel> channels = TestUtils.fieldValue(selector, Selector.class, "channels");
+        Set<SelectionKey> selectionKeys = new HashSet<>();
+        for (int i = 0; i < numChannels; i++) {
+            String id = String.valueOf(i);
+            KafkaChannel channel = mock(KafkaChannel.class);
+            channels.put(id, channel);
+            when(channel.id()).thenReturn(id);
+            when(channel.state()).thenReturn(ChannelState.READY);
+            when(channel.isConnected()).thenReturn(true);
+            when(channel.ready()).thenReturn(true);
+            when(channel.read()).thenReturn(1L);
+
+            SelectionKey selectionKey = mock(SelectionKey.class);
+            when(channel.selectionKey()).thenReturn(selectionKey);
+            when(selectionKey.isValid()).thenReturn(true);
+            when(selectionKey.readyOps()).thenReturn(SelectionKey.OP_READ);
+            selectionKey.attach(channel);
+            selectionKeys.add(selectionKey);
+
+            NetworkReceive receive = mock(NetworkReceive.class);
+            when(receive.source()).thenReturn(id);
+            when(receive.size()).thenReturn(10);
+            when(receive.bytesRead()).thenReturn(1);
+            when(receive.payload()).thenReturn(ByteBuffer.allocate(10));
+            when(channel.maybeCompleteReceive()).thenReturn(receive);
+        }
+
+        selector.pollSelectionKeys(selectionKeys, false, System.nanoTime());
+        assertEquals(numChannels, selector.completedReceives().size());
+        Set<KafkaChannel> closed = new HashSet<>();
+        Set<KafkaChannel> notClosed = new HashSet<>();
+        for (NetworkReceive receive : selector.completedReceives()) {
+            KafkaChannel channel = selector.channel(receive.source());
+            assertNotNull(channel);
+            if (closed.size() < 2) {
+                selector.close(channel.id());
+                closed.add(channel);
+            } else
+                notClosed.add(channel);
+        }
+        assertEquals(notClosed, new HashSet<>(selector.channels()));
+        closed.forEach(channel -> assertNull(selector.channel(channel.id())));
+
+        selector.poll(0);
+        assertEquals(0, selector.completedReceives().size());
+    }
+
+
     private String blockingRequest(String node, String s) throws IOException {
         selector.send(createSend(node, s));
         selector.poll(1000L);
diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala
index 16d9322..bff185d 100644
--- a/core/src/main/scala/kafka/network/SocketServer.scala
+++ b/core/src/main/scala/kafka/network/SocketServer.scala
@@ -835,7 +835,7 @@ private[kafka] class Processor(val id: Int,
     }
   }
 
-  private def processException(errorMessage: String, throwable: Throwable): Unit = {
+  private[network] def processException(errorMessage: String, throwable: Throwable): Unit = {
     throwable match {
       case e: ControlThrowable => throw e
       case e => error(errorMessage, e)
@@ -968,6 +968,7 @@ private[kafka] class Processor(val id: Int,
           processChannelException(receive.source, s"Exception while processing request from ${receive.source}", e)
       }
     }
+    selector.clearCompletedReceives()
   }
 
   private def processCompletedSends(): Unit = {
@@ -991,6 +992,7 @@ private[kafka] class Processor(val id: Int,
           s"Exception while processing completed send to ${send.destination}", e)
       }
     }
+    selector.clearCompletedSends()
   }
 
   private def updateRequestMetrics(response: RequestChannel.Response): Unit = {
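
With those hooks in place, the Processor releases buffer references eagerly: each batch of completed receives and sends is cleared as soon as it has been handled, instead of lingering until the next poll() clears it internally. A minimal sketch of that loop shape (hypothetical Poller interface, not the Kafka API):

    import java.util.Collection;

    // Assumed interface for illustration only.
    interface Poller {
        void poll(long timeoutMs);
        Collection<byte[]> completedReceives();
        Collection<byte[]> completedSends();
        void clearCompletedReceives();
        void clearCompletedSends();
    }

    class ProcessorLoopSketch {
        void runOnce(Poller selector) {
            selector.poll(300L);

            for (byte[] receive : selector.completedReceives())
                handleRequest(receive);
            // Drop request buffer references immediately so they can be
            // collected before the next poll.
            selector.clearCompletedReceives();

            for (byte[] send : selector.completedSends())
                finishResponse(send);
            selector.clearCompletedSends();
        }

        private void handleRequest(byte[] receive) { /* enqueue on the request channel */ }
        private void finishResponse(byte[] send) { /* response bookkeeping */ }
    }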
diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
index 6743dc8..8f1091e 100644
--- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
+++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
@@ -1597,6 +1597,8 @@ class SocketServerTest {
       testableSelector.waitForOperations(SelectorOperation.Poll, 1)
 
       testableSelector.waitForOperations(SelectorOperation.CloseSelector, 1)
+      assertEquals(1, testableServer.uncaughtExceptions)
+      testableServer.uncaughtExceptions = 0
     })
   }
 
@@ -1675,6 +1677,7 @@ class SocketServerTest {
         testWithServer(testableServer)
     } finally {
       shutdownServerAndMetrics(testableServer)
+      assertEquals(0, testableServer.uncaughtExceptions)
     }
   }
 
@@ -1730,6 +1733,7 @@ class SocketServerTest {
       new Metrics, time, credentialProvider) {
 
     @volatile var selector: Option[TestableSelector] = None
+    @volatile var uncaughtExceptions = 0
 
     override def newProcessor(id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, listenerName: ListenerName,
                                 protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = {
@@ -1742,6 +1746,12 @@ class SocketServerTest {
            selector = Some(testableSelector)
            testableSelector
         }
+
+        override private[network] def processException(errorMessage: String, throwable: Throwable): Unit = {
+          if (errorMessage.contains("uncaught exception"))
+            uncaughtExceptions += 1
+          super.processException(errorMessage, throwable)
+        }
       }
     }
 
@@ -1794,12 +1804,19 @@ class SocketServerTest {
     // Enable data from `Selector.poll()` to be deferred to a subsequent poll() until
     // the number of elements of that type reaches `minPerPoll`. This enables tests to verify
     // that failed processing doesn't impact subsequent processing within the same iteration.
-    class PollData[T] {
+    abstract class PollData[T] {
       var minPerPoll = 1
       val deferredValues = mutable.Buffer[T]()
-      val currentPollValues = mutable.Buffer[T]()
-      def update(newValues: mutable.Buffer[T]): Unit = {
-        if (currentPollValues.nonEmpty || deferredValues.size + newValues.size >= minPerPoll) {
+
+      /**
+       * Process new results and return the results for the current poll if at least
+       * `minPerPoll` results are available including any deferred results. Otherwise
+       * add the provided values to the deferred set and return an empty buffer. This allows
+       * tests to process `minPerPoll` elements as the results of a single poll iteration.
+       */
+      protected def update(newValues: mutable.Buffer[T]): mutable.Buffer[T] = {
+        val currentPollValues = mutable.Buffer[T]()
+        if (deferredValues.size + newValues.size >= minPerPoll) {
           if (deferredValues.nonEmpty) {
             currentPollValues ++= deferredValues
             deferredValues.clear()
@@ -1807,14 +1824,49 @@ class SocketServerTest {
           currentPollValues ++= newValues
         } else
           deferredValues ++= newValues
+
+        currentPollValues
       }
-      def reset(): Unit = {
-        currentPollValues.clear()
+
+      /**
+       * Process results from the appropriate buffer in Selector and update the buffer to either
+       * defer and return nothing or return all results including previously deferred values.
+       */
+      def updateResults(): Unit
+    }
+
+    class CompletedReceivesPollData(selector: TestableSelector) extends PollData[NetworkReceive] {
+      val completedReceivesMap: util.Map[String, NetworkReceive] = JTestUtils.fieldValue(selector, classOf[Selector], "completedReceives")
+
+      override def updateResults(): Unit = {
+        val currentReceives = update(selector.completedReceives.asScala.toBuffer)
+        completedReceivesMap.clear()
+        currentReceives.foreach { receive =>
+          val channelOpt = Option(selector.channel(receive.source)).orElse(Option(selector.closingChannel(receive.source)))
+          channelOpt.foreach { channel => completedReceivesMap.put(channel.id, receive) }
+        }
       }
     }
-    val cachedCompletedReceives = new PollData[NetworkReceive]()
-    val cachedCompletedSends = new PollData[Send]()
-    val cachedDisconnected = new PollData[(String, ChannelState)]()
+
+    class CompletedSendsPollData(selector: TestableSelector) extends PollData[Send] {
+      override def updateResults(): Unit = {
+        val currentSends = update(selector.completedSends.asScala)
+        selector.completedSends.clear()
+        currentSends.foreach { selector.completedSends.add }
+      }
+    }
+
+    class DisconnectedPollData(selector: TestableSelector) extends PollData[(String, ChannelState)] {
+      override def updateResults(): Unit = {
+        val currentDisconnected = update(selector.disconnected.asScala.toBuffer)
+        selector.disconnected.clear()
+        currentDisconnected.foreach { case (channelId, state) => selector.disconnected.put(channelId, state) }
+      }
+    }
+
+    val cachedCompletedReceives = new CompletedReceivesPollData(this)
+    val cachedCompletedSends = new CompletedSendsPollData(this)
+    val cachedDisconnected = new DisconnectedPollData(this)
     val allCachedPollData = Seq(cachedCompletedReceives, cachedCompletedSends, cachedDisconnected)
     val pendingClosingChannels = new ConcurrentLinkedQueue[KafkaChannel]()
     @volatile var minWakeupCount = 0
@@ -1861,20 +1913,23 @@ class SocketServerTest {
 
     override def poll(timeout: Long): Unit = {
       try {
+        assertEquals(0, super.completedReceives().size)
+        assertEquals(0, super.completedSends().size)
+
         pollCallback.apply()
         while (!pendingClosingChannels.isEmpty) {
           makeClosing(pendingClosingChannels.poll())
         }
-        allCachedPollData.foreach(_.reset)
         runOp(SelectorOperation.Poll, None) {
           super.poll(pollTimeoutOverride.getOrElse(timeout))
         }
       } finally {
         super.channels.forEach(allChannels += _.id)
         allDisconnectedChannels ++= super.disconnected.asScala.keys
-        cachedCompletedReceives.update(super.completedReceives.asScala.toBuffer)
-        cachedCompletedSends.update(super.completedSends.asScala)
-        cachedDisconnected.update(super.disconnected.asScala.toBuffer)
+
+        cachedCompletedReceives.updateResults()
+        cachedCompletedSends.updateResults()
+        cachedDisconnected.updateResults()
       }
     }
 
@@ -1899,12 +1954,6 @@ class SocketServerTest {
       }
     }
 
-    override def disconnected: java.util.Map[String, ChannelState] = cachedDisconnected.currentPollValues.toMap.asJava
-
-    override def completedSends: java.util.List[Send] = cachedCompletedSends.currentPollValues.asJava
-
-    override def completedReceives: java.util.List[NetworkReceive] = cachedCompletedReceives.currentPollValues.asJava
-
     override def close(id: String): Unit = {
       runOp(SelectorOperation.Close, Some(id)) {
         super.close(id)

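The test changes above replace the cached-result overrides of disconnected/completedSends/completedReceives with PollData subclasses that write deferred results back into the selector's own buffers, so the production accessors stay in use. The core defer-until-minPerPoll logic, rendered here as a Java sketch of the Scala helper (names are illustrative):

    import java.util.ArrayList;
    import java.util.List;

    // Values are withheld until at least `minPerPoll` have accumulated, then
    // released as one batch, letting tests exercise multi-result poll iterations.
    class DeferringPollData<T> {
        int minPerPoll = 1;
        private final List<T> deferred = new ArrayList<>();

        List<T> update(List<T> newValues) {
            List<T> current = new ArrayList<>();
            if (deferred.size() + newValues.size() >= minPerPoll) {
                current.addAll(deferred);
                deferred.clear();
                current.addAll(newValues);
            } else {
                deferred.addAll(newValues);
            }
            return current;
        }
    }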

[kafka] 02/02: KAFKA-10056; Ensure consumer metadata contains new topics on subscription change (#8739)

Posted by rs...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a commit to branch 2.6
in repository https://gitbox.apache.org/repos/asf/kafka.git

commit dddd73fb2efdc2a0b1fa71c9946406b12886c1bf
Author: Rajini Sivaram <ra...@googlemail.com>
AuthorDate: Fri May 29 17:03:49 2020 +0100

    KAFKA-10056; Ensure consumer metadata contains new topics on subscription change (#8739)
    
    Reviewers: Jason Gustafson <ja...@confluent.io>
---
 .../consumer/internals/SubscriptionState.java      | 12 ++++++-
 .../internals/ConsumerCoordinatorTest.java         | 42 ++++++++++++++++++++++
 .../consumer/internals/SubscriptionStateTest.java  |  6 ++++
 3 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
index 4d93a1b..86f4e68 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
@@ -345,7 +345,17 @@ public class SubscriptionState {
      *   of the current generation; otherwise it returns the same set as {@link #subscription()}
      */
     synchronized Set<String> metadataTopics() {
-        return groupSubscription.isEmpty() ? subscription : groupSubscription;
+        if (groupSubscription.isEmpty())
+            return subscription;
+        else if (groupSubscription.containsAll(subscription))
+            return groupSubscription;
+        else {
+            // When subscription changes `groupSubscription` may be outdated, ensure that
+            // new subscription topics are returned.
+            Set<String> topics = new HashSet<>(groupSubscription);
+            topics.addAll(subscription);
+            return topics;
+        }
     }
 
     synchronized boolean needsMetadata(String topic) {
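
The issue here: groupSubscription holds the group-wide topic set from the last rebalance, so right after a subscription change it can lag behind the member's own subscription; returning it alone would leave the newly subscribed topics out of metadata requests until the next rebalance completes. The new branch returns the union instead. A standalone sketch of the logic (illustration only, not the Kafka class):

    import java.util.HashSet;
    import java.util.Set;

    class MetadataTopicsSketch {
        static Set<String> metadataTopics(Set<String> subscription, Set<String> groupSubscription) {
            if (groupSubscription.isEmpty())
                return subscription;                 // not (yet) part of a group
            if (groupSubscription.containsAll(subscription))
                return groupSubscription;            // common case: group view is current
            // Subscription changed since the last rebalance: groupSubscription
            // is stale, so include the newly subscribed topics as well.
            Set<String> topics = new HashSet<>(groupSubscription);
            topics.addAll(subscription);
            return topics;
        }

        public static void main(String[] args) {
            Set<String> group = Set.of("topic1");        // from the last rebalance
            Set<String> subscription = Set.of("topic2"); // member just resubscribed
            System.out.println(metadataTopics(subscription, group)); // [topic1, topic2]
        }
    }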
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
index afba13b..d61010e 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
@@ -783,6 +783,35 @@ public class ConsumerCoordinatorTest {
     }
 
     @Test
+    public void testMetadataTopicsDuringSubscriptionChange() {
+        final String consumerId = "subscription_change";
+        final List<String> oldSubscription = singletonList(topic1);
+        final List<TopicPartition> oldAssignment = Collections.singletonList(t1p);
+        final List<String> newSubscription = singletonList(topic2);
+        final List<TopicPartition> newAssignment = Collections.singletonList(t2p);
+
+        subscriptions.subscribe(toSet(oldSubscription), rebalanceListener);
+        assertEquals(toSet(oldSubscription), subscriptions.metadataTopics());
+
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
+
+        prepareJoinAndSyncResponse(consumerId, 1, oldSubscription, oldAssignment);
+
+        coordinator.poll(time.timer(0));
+        assertEquals(toSet(oldSubscription), subscriptions.metadataTopics());
+
+        subscriptions.subscribe(toSet(newSubscription), rebalanceListener);
+        assertEquals(Utils.mkSet(topic1, topic2), subscriptions.metadataTopics());
+
+        prepareJoinAndSyncResponse(consumerId, 2, newSubscription, newAssignment);
+        coordinator.poll(time.timer(Long.MAX_VALUE));
+        assertFalse(coordinator.rejoinNeededOrPending());
+        assertEquals(toSet(newAssignment), subscriptions.assignedPartitions());
+        assertEquals(toSet(newSubscription), subscriptions.metadataTopics());
+    }
+
+    @Test
     public void testPatternJoinGroupLeader() {
         final String consumerId = "leader";
         final List<TopicPartition> assigned = Arrays.asList(t1p, t2p);
@@ -3107,6 +3136,19 @@ public class ConsumerCoordinatorTest {
         client.prepareResponse(offsetCommitRequestMatcher(expectedOffsets), offsetCommitResponse(errors), disconnected);
     }
 
+    private void prepareJoinAndSyncResponse(String consumerId, int generation, List<String> subscription, List<TopicPartition> assignment) {
+        partitionAssignor.prepare(singletonMap(consumerId, assignment));
+        client.prepareResponse(
+                joinGroupLeaderResponse(
+                        generation, consumerId, singletonMap(consumerId, subscription), Errors.NONE));
+        client.prepareResponse(body -> {
+            SyncGroupRequest sync = (SyncGroupRequest) body;
+            return sync.data.memberId().equals(consumerId) &&
+                    sync.data.generationId() == generation &&
+                    sync.groupAssignments().containsKey(consumerId);
+        }, syncGroupResponse(assignment, Errors.NONE));
+    }
+
     private Map<TopicPartition, Errors> partitionErrors(Collection<TopicPartition> partitions, Errors error) {
         final Map<TopicPartition, Errors> errors = new HashMap<>();
         for (TopicPartition partition : partitions) {
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
index cc74652..ae00b8e 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
@@ -122,6 +122,12 @@ public class SubscriptionStateTest {
         // `groupSubscribe` does not accumulate
         assertFalse(state.groupSubscribe(singleton(topic1)));
         assertEquals(singleton(topic1), state.metadataTopics());
+
+        state.subscribe(singleton("anotherTopic"), rebalanceListener);
+        assertEquals(Utils.mkSet(topic1, "anotherTopic"), state.metadataTopics());
+
+        assertFalse(state.groupSubscribe(singleton("anotherTopic")));
+        assertEquals(singleton("anotherTopic"), state.metadataTopics());
     }
 
     @Test