You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2020/06/01 23:18:02 UTC

[kafka] branch 2.4 updated: KAFKA-9987: optimize sticky assignment algorithm for same-subscription case (#8668)

This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch 2.4
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.4 by this push:
     new 1163686  KAFKA-9987: optimize sticky assignment algorithm for same-subscription case (#8668)
1163686 is described below

commit 116368604bd28427429844582ddb5fe4623bbb5e
Author: A. Sophie Blee-Goldman <so...@confluent.io>
AuthorDate: Mon Jun 1 15:57:15 2020 -0700

    KAFKA-9987: optimize sticky assignment algorithm for same-subscription case (#8668)
    
    The motivation and the pseudo-code for the algorithm are described in the ticket.
    
    Added a scale test with a large number of topic partitions and consumers, and a 30s timeout.
    With these changes, assignment with 2,000 consumers and 200 topics with 2,000 partitions each completes within a few seconds.
    
    When the same test was ported to trunk, it took 2 minutes even with a 100x reduction in the number of topics (i.e., 2 minutes for 2,000 consumers and 2 topics with 2,000 partitions each)
    
    Should be cherry-picked to 2.6, 2.5, and 2.4
    
    Reviewers: Guozhang Wang <wa...@gmail.com>
---
 checkstyle/suppressions.xml                        |   4 +-
 .../consumer/CooperativeStickyAssignor.java        |  46 +--
 .../consumer/internals/AbstractStickyAssignor.java | 319 ++++++++++++++-------
 .../kafka/clients/consumer/StickyAssignorTest.java |   5 +-
 .../internals/AbstractStickyAssignorTest.java      |  82 +++---
 5 files changed, 283 insertions(+), 173 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index c8678ac..bee700b 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -57,13 +57,13 @@
               files="(Utils|Topic|KafkaLZ4BlockOutputStream|AclData|JoinGroupRequest).java"/>
 
     <suppress checks="CyclomaticComplexity"
-              files="(ConsumerCoordinator|Fetcher|Sender|KafkaProducer|BufferPool|ConfigDef|RecordAccumulator|KerberosLogin|AbstractRequest|AbstractResponse|Selector|SslFactory|SslTransportLayer|SaslClientAuthenticator|SaslClientCallbackHandler|SaslServerAuthenticator|AbstractCoordinator).java"/>
+              files="(ConsumerCoordinator|Fetcher|Sender|KafkaProducer|BufferPool|ConfigDef|RecordAccumulator|KerberosLogin|AbstractRequest|AbstractResponse|Selector|SslFactory|SslTransportLayer|SaslClientAuthenticator|SaslClientCallbackHandler|SaslServerAuthenticator|AbstractCoordinator|TransactionManager|AbstractStickyAssignor).java"/>
 
     <suppress checks="JavaNCSS"
               files="(AbstractRequest|KerberosLogin|WorkerSinkTaskTest|TransactionManagerTest|SenderTest|KafkaAdminClient|ConsumerCoordinatorTest).java"/>
 
     <suppress checks="NPathComplexity"
-              files="(BufferPool|Fetcher|MetricName|Node|ConfigDef|RecordBatch|SslFactory|SslTransportLayer|MetadataResponse|KerberosLogin|Selector|Sender|Serdes|TokenInformation|Agent|Values|PluginUtils|MiniTrogdorCluster|TasksRequest|KafkaProducer).java"/>
+              files="(ConsumerCoordinator|BufferPool|Fetcher|MetricName|Node|ConfigDef|RecordBatch|SslFactory|SslTransportLayer|MetadataResponse|KerberosLogin|Selector|Sender|Serdes|TokenInformation|Agent|Values|PluginUtils|MiniTrogdorCluster|TasksRequest|KafkaProducer|AbstractStickyAssignor).java"/>
 
     <suppress checks="(JavaNCSS|CyclomaticComplexity|MethodLength)"
               files="CoordinatorClient.java"/>
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/CooperativeStickyAssignor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/CooperativeStickyAssignor.java
index bef32bf..c7c0679 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/CooperativeStickyAssignor.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/CooperativeStickyAssignor.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.clients.consumer;
 
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -62,16 +61,26 @@ public class CooperativeStickyAssignor extends AbstractStickyAssignor {
     @Override
     public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
                                                     Map<String, Subscription> subscriptions) {
+        Map<String, List<TopicPartition>> assignments = super.assign(partitionsPerTopic, subscriptions);
 
-        final Map<String, List<TopicPartition>> assignments = super.assign(partitionsPerTopic, subscriptions);
-        adjustAssignment(subscriptions, assignments);
+        Map<TopicPartition, String> partitionsTransferringOwnership = super.partitionsTransferringOwnership == null ?
+            computePartitionsTransferringOwnership(subscriptions, assignments) :
+            super.partitionsTransferringOwnership;
+
+        adjustAssignment(assignments, partitionsTransferringOwnership);
         return assignments;
     }
 
     // Following the cooperative rebalancing protocol requires removing partitions that must first be revoked from the assignment
-    private void adjustAssignment(final Map<String, Subscription> subscriptions,
-                                  final Map<String, List<TopicPartition>> assignments) {
+    private void adjustAssignment(Map<String, List<TopicPartition>> assignments,
+                                  Map<TopicPartition, String> partitionsTransferringOwnership) {
+        // a partition moving between consumers is withheld from its new owner until it has been revoked
+        partitionsTransferringOwnership.forEach((partition, newOwner) ->
+            assignments.get(newOwner).remove(partition));
+    }
 
+    private Map<TopicPartition, String> computePartitionsTransferringOwnership(Map<String, Subscription> subscriptions,
+                                                                               Map<String, List<TopicPartition>> assignments) {
         Map<TopicPartition, String> allAddedPartitions = new HashMap<>();
         Set<TopicPartition> allRevokedPartitions = new HashSet<>();
 
@@ -81,25 +90,20 @@ public class CooperativeStickyAssignor extends AbstractStickyAssignor {
             List<TopicPartition> ownedPartitions = subscriptions.get(consumer).ownedPartitions();
             List<TopicPartition> assignedPartitions = entry.getValue();
 
-            List<TopicPartition> addedPartitions = new ArrayList<>(assignedPartitions);
-            addedPartitions.removeAll(ownedPartitions);
-            for (TopicPartition tp : addedPartitions) {
-                allAddedPartitions.put(tp, consumer);
+            Set<TopicPartition> ownedPartitionsSet = new HashSet<>(ownedPartitions);
+            for (TopicPartition tp : assignedPartitions) {
+                if (!ownedPartitionsSet.contains(tp))
+                    allAddedPartitions.put(tp, consumer);
             }
 
-            final Set<TopicPartition> revokedPartitions = new HashSet<>(ownedPartitions);
-            revokedPartitions.removeAll(assignedPartitions);
-            allRevokedPartitions.addAll(revokedPartitions);
-        }
-
-        // remove any partitions to be revoked from the current assignment
-        for (TopicPartition tp : allRevokedPartitions) {
-            // if partition is being migrated to another consumer, don't assign it there yet
-            if (allAddedPartitions.containsKey(tp)) {
-                String assignedConsumer = allAddedPartitions.get(tp);
-                assignments.get(assignedConsumer).remove(tp);
+            Set<TopicPartition> assignedPartitionsSet = new HashSet<>(assignedPartitions);
+            for (TopicPartition tp : ownedPartitions) {
+                if (!assignedPartitionsSet.contains(tp))
+                    allRevokedPartitions.add(tp);
             }
         }
-    }
 
+        allAddedPartitions.keySet().retainAll(allRevokedPartitions);
+        return allAddedPartitions;
+    }
 }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignor.java
index 12864de..d4e023c 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignor.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignor.java
@@ -18,19 +18,22 @@ package org.apache.kafka.clients.consumer.internals;
 
 import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Optional;
+import java.util.Queue;
 import java.util.Set;
+import java.util.SortedSet;
 import java.util.TreeMap;
 import java.util.TreeSet;
+import java.util.stream.Collectors;
 import org.apache.kafka.common.TopicPartition;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -40,7 +43,11 @@ public abstract class AbstractStickyAssignor extends AbstractPartitionAssignor {
 
     public static final int DEFAULT_GENERATION = -1;
 
-    private PartitionMovements partitionMovements;
+    private PartitionMovements partitionMovements = new PartitionMovements();
+
+    // Keep track of the partitions being migrated from one consumer to another during assignment
+    // so the cooperative assignor can adjust the assignment
+    protected Map<TopicPartition, String> partitionsTransferringOwnership = new HashMap<>();
 
     static final class ConsumerGenerationPair {
         final String consumer;
@@ -65,9 +72,206 @@ public abstract class AbstractStickyAssignor extends AbstractPartitionAssignor {
     @Override
     public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
                                                     Map<String, Subscription> subscriptions) {
+        Map<String, List<TopicPartition>> consumerToOwnedPartitions = new HashMap<>();
+        if (allSubscriptionsEqual(partitionsPerTopic.keySet(), subscriptions, consumerToOwnedPartitions)) {
+            log.debug("Detected that all consumers were subscribed to same set of topics, invoking the "
+                          + "optimized assignment algorithm");
+            partitionsTransferringOwnership = new HashMap<>(); // filled in by constrainedAssign
+            return constrainedAssign(partitionsPerTopic, consumerToOwnedPartitions);
+        } else {
+            log.debug("Detected that not all consumers were subscribed to same set of topics, falling back to the "
+                          + "general case assignment algorithm");
+            partitionsTransferringOwnership = null; // null tells CooperativeStickyAssignor to compute it itself
+            return generalAssign(partitionsPerTopic, subscriptions);
+        }
+    }
+
+    /**
+     * Returns true iff all consumers have an identical subscription. Also fills {@code consumerToOwnedPartitions}
+     * with each consumer's owned partitions of topics in {@code allTopics} (empty unless of the highest generation)
+     */
+    private boolean allSubscriptionsEqual(Set<String> allTopics,
+                                          Map<String, Subscription> subscriptions,
+                                          Map<String, List<TopicPartition>> consumerToOwnedPartitions) {
+        Set<String> membersWithOldGeneration = new HashSet<>();
+        Set<String> membersOfCurrentHighestGeneration = new HashSet<>();
+        int maxGeneration = DEFAULT_GENERATION;
+
+        Set<String> subscribedTopics = null;
+
+        for (Map.Entry<String, Subscription> subscriptionEntry : subscriptions.entrySet()) {
+            String consumer = subscriptionEntry.getKey();
+            Subscription subscription = subscriptionEntry.getValue();
+
+            // The first subscription is the reference; a null sentinel (not isEmpty) so an empty one still counts
+            if (subscribedTopics == null) {
+                subscribedTopics = new HashSet<>(subscription.topics());
+            } else if (!(subscription.topics().size() == subscribedTopics.size()
+                && subscribedTopics.containsAll(subscription.topics()))) {
+                return false;
+            }
+
+            MemberData memberData = memberData(subscription);
+
+            List<TopicPartition> ownedPartitions = new ArrayList<>();
+            consumerToOwnedPartitions.put(consumer, ownedPartitions);
+
+            // Only consider this consumer's owned partitions as valid if it is a member of the current highest
+            // generation, or its generation is not present but we have not seen any known generation so far
+            if (memberData.generation.isPresent() && memberData.generation.get() >= maxGeneration
+                || !memberData.generation.isPresent() && maxGeneration == DEFAULT_GENERATION) {
+                // A strictly higher generation invalidates the members seen so far (but not this one), so demote
+                // them BEFORE registering this consumer, otherwise its own owned partitions would be cleared too
+                if (memberData.generation.isPresent() && memberData.generation.get() > maxGeneration) {
+                    membersWithOldGeneration.addAll(membersOfCurrentHighestGeneration);
+                    membersOfCurrentHighestGeneration.clear();
+                    maxGeneration = memberData.generation.get();
+                }
+                membersOfCurrentHighestGeneration.add(consumer);
+
+                for (final TopicPartition tp : memberData.partitions) {
+                    // filter out any topics that no longer exist or aren't part of the current subscription
+                    if (allTopics.contains(tp.topic())) {
+                        ownedPartitions.add(tp);
+                    }
+                }
+            }
+        }
+
+        for (String consumer : membersWithOldGeneration) {
+            consumerToOwnedPartitions.get(consumer).clear();
+        }
+        return true;
+    }
+
+    private Map<String, List<TopicPartition>> constrainedAssign(Map<String, Integer> partitionsPerTopic,
+                                                                Map<String, List<TopicPartition>> consumerToOwnedPartitions) {
+        SortedSet<TopicPartition> unassignedPartitions = getTopicPartitions(partitionsPerTopic);
+        // Goal: every member ends with minQuota or maxQuota partitions, keeping owned partitions where possible
+        Set<TopicPartition> allRevokedPartitions = new HashSet<>();
+
+        // Each consumer is first placed in one of the below (in both min and max if minQuota == maxQuota)
+        List<String> unfilledMembers = new LinkedList<>();
+        Queue<String> maxCapacityMembers = new LinkedList<>();
+        Queue<String> minCapacityMembers = new LinkedList<>();
+
+        int numberOfConsumers = consumerToOwnedPartitions.size();
+        int minQuota = (int) Math.floor(((double) unassignedPartitions.size()) / numberOfConsumers);
+        int maxQuota = (int) Math.ceil(((double) unassignedPartitions.size()) / numberOfConsumers);
+
+        // initialize the assignment map with an empty list (initial capacity minQuota) for every member
+        Map<String, List<TopicPartition>> assignment = new HashMap<>(
+            consumerToOwnedPartitions.keySet().stream().collect(Collectors.toMap(c -> c, c -> new ArrayList<>(minQuota))));
+
+        for (Map.Entry<String, List<TopicPartition>> consumerEntry : consumerToOwnedPartitions.entrySet()) {
+            String consumer = consumerEntry.getKey();
+            List<TopicPartition> ownedPartitions = consumerEntry.getValue();
+
+            List<TopicPartition> consumerAssignment = assignment.get(consumer);
+            int i = 0;
+            // keep owned partitions up to maxQuota, mark the rest revoked (NOTE(review): no duplicate/existence check)
+            for (TopicPartition tp : ownedPartitions) {
+                if (i < maxQuota) {
+                    consumerAssignment.add(tp);
+                    unassignedPartitions.remove(tp);
+                } else {
+                    allRevokedPartitions.add(tp);
+                }
+                ++i;
+            }
+
+            if (ownedPartitions.size() < minQuota) {
+                unfilledMembers.add(consumer);
+            } else {
+                // It's possible for a consumer to be at both min and max capacity if minQuota == maxQuota
+                if (consumerAssignment.size() == minQuota)
+                    minCapacityMembers.add(consumer);
+                if (consumerAssignment.size() == maxQuota)
+                    maxCapacityMembers.add(consumer);
+            }
+        }
+
+        Collections.sort(unfilledMembers);
+        Iterator<TopicPartition> unassignedPartitionsIter = unassignedPartitions.iterator();
+
+        while (!unfilledMembers.isEmpty() && !unassignedPartitions.isEmpty()) {
+            Iterator<String> unfilledConsumerIter = unfilledMembers.iterator();
+
+            while (unfilledConsumerIter.hasNext()) {
+                String consumer = unfilledConsumerIter.next();
+                List<TopicPartition> consumerAssignment = assignment.get(consumer);
+
+                if (unassignedPartitionsIter.hasNext()) {
+                    TopicPartition tp = unassignedPartitionsIter.next();
+                    consumerAssignment.add(tp);
+                    unassignedPartitionsIter.remove();
+                    // We already assigned all possible ownedPartitions, so this one must be newly assigned to this consumer
+                    if (allRevokedPartitions.contains(tp))
+                        partitionsTransferringOwnership.put(tp, consumer);
+                } else {
+                    break;
+                }
+
+                if (consumerAssignment.size() == minQuota) {
+                    minCapacityMembers.add(consumer);
+                    unfilledConsumerIter.remove();
+                }
+            }
+        }
+
+        // If we ran out of unassigned partitions before filling all consumers, we need to start stealing partitions
+        // from the over-full consumers at max capacity
+        for (String consumer : unfilledMembers) {
+            List<TopicPartition> consumerAssignment = assignment.get(consumer);
+            int remainingCapacity = minQuota - consumerAssignment.size();
+            while (remainingCapacity > 0) {
+                String overloadedConsumer = maxCapacityMembers.poll();
+                if (overloadedConsumer == null) {
+                    throw new IllegalStateException("Some consumers are under capacity but all partitions have been assigned");
+                }
+                TopicPartition swappedPartition = assignment.get(overloadedConsumer).remove(0);
+                consumerAssignment.add(swappedPartition);
+                --remainingCapacity;
+                // This partition is by definition transferring ownership; the swapped partition must have come from
+                // the max capacity member's owned partitions since it can only reach max capacity with owned partitions
+                partitionsTransferringOwnership.put(swappedPartition, consumer);
+            }
+            minCapacityMembers.add(consumer);
+        }
+
+        // Otherwise we may have run out of unfilled consumers before assigning all partitions, in which case we
+        // should just distribute one partition each to all consumers at min capacity
+        for (TopicPartition unassignedPartition : unassignedPartitions) {
+            String underCapacityConsumer = minCapacityMembers.poll();
+            if (underCapacityConsumer == null) {
+                throw new IllegalStateException("Some partitions are unassigned but all consumers are at maximum capacity");
+            }
+            // We can skip the bookkeeping of unassignedPartitions and maxCapacityMembers here since we are at the end
+            assignment.get(underCapacityConsumer).add(unassignedPartition);
+
+            if (allRevokedPartitions.contains(unassignedPartition))
+                partitionsTransferringOwnership.put(unassignedPartition, underCapacityConsumer);
+        }
+
+        return assignment;
+    }
+
+    private SortedSet<TopicPartition> getTopicPartitions(Map<String, Integer> partitionsPerTopic) {
+        Comparator<TopicPartition> byTopicThenPartition =
+            Comparator.comparing(TopicPartition::topic).thenComparingInt(TopicPartition::partition);
+        SortedSet<TopicPartition> sorted = new TreeSet<>(byTopicThenPartition);
+        // expand each (topic, partition count) entry into partitions 0 .. count - 1
+        partitionsPerTopic.forEach((topicName, numPartitions) -> {
+            for (int p = 0; p < numPartitions; p++)
+                sorted.add(new TopicPartition(topicName, p));
+        });
+        return sorted;
+    }
+
+    private Map<String, List<TopicPartition>> generalAssign(Map<String, Integer> partitionsPerTopic,
+                                                            Map<String, Subscription> subscriptions) {
         Map<String, List<TopicPartition>> currentAssignment = new HashMap<>();
         Map<TopicPartition, ConsumerGenerationPair> prevAssignment = new HashMap<>();
-        partitionMovements = new PartitionMovements();
 
         prepopulateCurrentAssignments(subscriptions, currentAssignment, prevAssignment);
         boolean isFreshAssignment = currentAssignment.isEmpty();
@@ -105,8 +309,7 @@ public abstract class AbstractStickyAssignor extends AbstractPartitionAssignor {
             for (TopicPartition topicPartition: entry.getValue())
                 currentPartitionConsumer.put(topicPartition, entry.getKey());
 
-        List<TopicPartition> sortedPartitions = sortPartitions(
-            currentAssignment, prevAssignment.keySet(), isFreshAssignment, partition2AllPotentialConsumers, consumer2AllPotentialPartitions);
+        List<TopicPartition> sortedPartitions = sortPartitions(partition2AllPotentialConsumers);
 
         // all partitions that need to be assigned (initially set to all partitions but adjusted in the following loop)
         List<TopicPartition> unassignedPartitions = new ArrayList<>(sortedPartitions);
@@ -287,96 +490,16 @@ public abstract class AbstractStickyAssignor extends AbstractPartitionAssignor {
      * Sort valid partitions so they are processed in the potential reassignment phase in the proper order
      * that causes minimal partition movement among consumers (hence honoring maximal stickiness)
      *
-     * @param currentAssignment the calculated assignment so far
-     * @param partitionsWithADifferentPreviousAssignment partitions that had a different consumer before (for every
-     *                                                   such partition there should also be a mapping in
-     *                                                   @currentAssignment to a different consumer)
-     * @param isFreshAssignment whether this is a new assignment, or a reassignment of an existing one
      * @param partition2AllPotentialConsumers a mapping of partitions to their potential consumers
-     * @param consumer2AllPotentialPartitions a mapping of consumers to potential partitions they can consumer from
-     * @return sorted list of valid partitions
+     * @return  an ascending sorted list of topic partitions based on how many consumers can potentially use them
      */
-    private List<TopicPartition> sortPartitions(Map<String, List<TopicPartition>> currentAssignment,
-                                                Set<TopicPartition> partitionsWithADifferentPreviousAssignment,
-                                                boolean isFreshAssignment,
-                                                Map<TopicPartition, List<String>> partition2AllPotentialConsumers,
-                                                Map<String, List<TopicPartition>> consumer2AllPotentialPartitions) {
-        List<TopicPartition> sortedPartitions = new ArrayList<>();
-
-        if (!isFreshAssignment && areSubscriptionsIdentical(partition2AllPotentialConsumers, consumer2AllPotentialPartitions)) {
-            // if this is a reassignment and the subscriptions are identical (all consumers can consumer from all topics)
-            // then we just need to simply list partitions in a round robin fashion (from consumers with
-            // most assigned partitions to those with least)
-            Map<String, List<TopicPartition>> assignments = deepCopy(currentAssignment);
-            for (Entry<String, List<TopicPartition>> entry: assignments.entrySet()) {
-                List<TopicPartition> toRemove = new ArrayList<>();
-                for (TopicPartition partition: entry.getValue())
-                    if (!partition2AllPotentialConsumers.keySet().contains(partition))
-                        toRemove.add(partition);
-                for (TopicPartition partition: toRemove)
-                    entry.getValue().remove(partition);
-            }
-            TreeSet<String> sortedConsumers = new TreeSet<>(new SubscriptionComparator(assignments));
-            sortedConsumers.addAll(assignments.keySet());
-            // at this point, sortedConsumers contains an ascending-sorted list of consumers based on
-            // how many valid partitions are currently assigned to them
-
-            while (!sortedConsumers.isEmpty()) {
-                // take the consumer with the most partitions
-                String consumer = sortedConsumers.pollLast();
-                // currently assigned partitions to this consumer
-                List<TopicPartition> remainingPartitions = assignments.get(consumer);
-                // partitions that were assigned to a different consumer last time
-                List<TopicPartition> prevPartitions = new ArrayList<>(partitionsWithADifferentPreviousAssignment);
-                // from partitions that had a different consumer before, keep only those that are
-                // assigned to this consumer now
-                prevPartitions.retainAll(remainingPartitions);
-                if (!prevPartitions.isEmpty()) {
-                    // if there is a partition of this consumer that was assigned to another consumer before
-                    // mark it as good options for reassignment
-                    TopicPartition partition = prevPartitions.remove(0);
-                    remainingPartitions.remove(partition);
-                    sortedPartitions.add(partition);
-                    sortedConsumers.add(consumer);
-                } else if (!remainingPartitions.isEmpty()) {
-                    // otherwise, mark any other one of the current partitions as a reassignment candidate
-                    sortedPartitions.add(remainingPartitions.remove(0));
-                    sortedConsumers.add(consumer);
-                }
-            }
-
-            for (TopicPartition partition: partition2AllPotentialConsumers.keySet()) {
-                if (!sortedPartitions.contains(partition))
-                    sortedPartitions.add(partition);
-            }
-
-        } else {
-            // an ascending sorted set of topic partitions based on how many consumers can potentially use them
-            TreeSet<TopicPartition> sortedAllPartitions = new TreeSet<>(new PartitionComparator(partition2AllPotentialConsumers));
-            sortedAllPartitions.addAll(partition2AllPotentialConsumers.keySet());
-
-            while (!sortedAllPartitions.isEmpty())
-                sortedPartitions.add(sortedAllPartitions.pollFirst());
-        }
-
+    private List<TopicPartition> sortPartitions(Map<TopicPartition, List<String>> partition2AllPotentialConsumers) {
+        List<TopicPartition> sortedPartitions = new ArrayList<>(partition2AllPotentialConsumers.keySet());
+        Collections.sort(sortedPartitions, new PartitionComparator(partition2AllPotentialConsumers));
         return sortedPartitions;
     }
 
     /**
-     * @param partition2AllPotentialConsumers a mapping of partitions to their potential consumers
-     * @param consumer2AllPotentialPartitions a mapping of consumers to potential partitions they can consumer from
-     * @return true if potential consumers of partitions are the same, and potential partitions consumers can
-     * consumer from are the same too
-     */
-    private boolean areSubscriptionsIdentical(Map<TopicPartition, List<String>> partition2AllPotentialConsumers,
-        Map<String, List<TopicPartition>> consumer2AllPotentialPartitions) {
-        if (!hasIdenticalListElements(partition2AllPotentialConsumers.values()))
-            return false;
-
-        return hasIdenticalListElements(consumer2AllPotentialPartitions.values());
-    }
-
-    /**
      * The assignment should improve the overall balance of the partition assignments to consumers.
      */
     private void assignPartition(TopicPartition partition,
@@ -601,24 +724,6 @@ public abstract class AbstractStickyAssignor extends AbstractPartitionAssignor {
         return partitionMovements.isSticky();
     }
 
-    /**
-     * @param col a collection of elements of type list
-     * @return true if all lists in the collection have the same members; false otherwise
-     */
-    private <T> boolean hasIdenticalListElements(Collection<List<T>> col) {
-        Iterator<List<T>> it = col.iterator();
-        if (!it.hasNext())
-            return true;
-        List<T> cur = it.next();
-        while (it.hasNext()) {
-            List<T> next = it.next();
-            if (!(cur.containsAll(next) && next.containsAll(cur)))
-                return false;
-            cur = next;
-        }
-        return true;
-    }
-
     private void deepCopy(Map<String, List<TopicPartition>> source, Map<String, List<TopicPartition>> dest) {
         dest.clear();
         for (Entry<String, List<TopicPartition>> entry: source.entrySet())
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/StickyAssignorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/StickyAssignorTest.java
index 01a8f3e..fb89944 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/StickyAssignorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/StickyAssignorTest.java
@@ -169,10 +169,10 @@ public class StickyAssignorTest extends AbstractStickyAssignorTest {
         TopicPartition tp5 = new TopicPartition(topic, 5);
 
         List<TopicPartition> c1partitions0 = partitions(tp0, tp1, tp4);
-        List<TopicPartition> c2partitions0 = partitions(tp0, tp2, tp3);
+        List<TopicPartition> c2partitions0 = partitions(tp0, tp1, tp2);
         List<TopicPartition> c3partitions0 = partitions(tp3, tp4, tp5);
         subscriptions.put(consumer1, buildSubscriptionWithGeneration(topics(topic), c1partitions0, 1));
-        subscriptions.put(consumer2, buildSubscriptionWithGeneration(topics(topic), c2partitions0, 1));
+        subscriptions.put(consumer2, buildSubscriptionWithGeneration(topics(topic), c2partitions0, 2));
         subscriptions.put(consumer3, buildSubscriptionWithGeneration(topics(topic), c3partitions0, 2));
 
         Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions);
@@ -181,7 +181,6 @@ public class StickyAssignorTest extends AbstractStickyAssignorTest {
         List<TopicPartition> c3partitions = assignment.get(consumer3);
 
         assertTrue(c1partitions.size() == 2 && c2partitions.size() == 2 && c3partitions.size() == 2);
-        assertTrue(c1partitions0.containsAll(c1partitions));
         assertTrue(c2partitions0.containsAll(c2partitions));
         assertTrue(c3partitions0.containsAll(c3partitions));
         verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic);
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignorTest.java
index 52f6747..c7b4523 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignorTest.java
@@ -33,12 +33,13 @@ import org.apache.kafka.common.utils.Utils;
 import org.junit.Before;
 import org.junit.Test;
 
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
 public abstract class AbstractStickyAssignorTest {
-
     protected AbstractStickyAssignor assignor;
     protected String consumerId = "consumer";
     protected Map<String, Subscription> subscriptions;
@@ -105,12 +106,16 @@ public abstract class AbstractStickyAssignorTest {
         String otherTopic = "other";
 
         Map<String, Integer> partitionsPerTopic = new HashMap<>();
-        partitionsPerTopic.put(topic, 3);
-        partitionsPerTopic.put(otherTopic, 3);
-        subscriptions = Collections.singletonMap(consumerId, new Subscription(topics(topic)));
+        partitionsPerTopic.put(topic, 2);
+        subscriptions = mkMap(
+                mkEntry(consumerId, buildSubscription(
+                        topics(topic),
+                        Arrays.asList(tp(topic, 0), tp(topic, 1), tp(otherTopic, 0), tp(otherTopic, 1)))
+                )
+        );
 
         Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions);
-        assertEquals(partitions(tp(topic, 0), tp(topic, 1), tp(topic, 2)), assignment.get(consumerId));
+        assertEquals(partitions(tp(topic, 0), tp(topic, 1)), assignment.get(consumerId));
 
         verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic);
         assertTrue(isFullyBalanced(assignment));
@@ -145,8 +150,6 @@ public abstract class AbstractStickyAssignorTest {
         subscriptions.put(consumer2, new Subscription(topics(topic)));
 
         Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions);
-        assertEquals(partitions(tp(topic, 0)), assignment.get(consumer1));
-        assertEquals(Collections.<TopicPartition>emptyList(), assignment.get(consumer2));
 
         verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic);
         assertTrue(isFullyBalanced(assignment));
@@ -238,8 +241,8 @@ public abstract class AbstractStickyAssignorTest {
         assignment = assignor.assign(partitionsPerTopic, subscriptions);
 
         verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic);
-        assertEquals(partitions(tp(topic, 2), tp(topic, 1)), assignment.get(consumer1));
-        assertEquals(partitions(tp(topic, 0)), assignment.get(consumer2));
+        assertEquals(partitions(tp(topic, 0), tp(topic, 1)), assignment.get(consumer1));
+        assertEquals(partitions(tp(topic, 2)), assignment.get(consumer2));
         assertTrue(isFullyBalanced(assignment));
         assertTrue(assignor.isSticky());
 
@@ -425,8 +428,37 @@ public abstract class AbstractStickyAssignorTest {
         assertTrue(assignor.isSticky());
     }
 
+    @Test(timeout = 30 * 1000)
+    public void testLargeAssignmentAndGroupWithUniformSubscription() {
+        // 1 million partitions (500 topics x 2,000 partitions each); every consumer subscribes to all topics
+        int topicCount = 500;
+        int partitionCount = 2_000;
+        int consumerCount = 2_000;
+
+        List<String> topics = new ArrayList<>();
+        Map<String, Integer> partitionsPerTopic = new HashMap<>();
+        for (int i = 0; i < topicCount; i++) {
+            String topicName = getTopicName(i, topicCount);
+            topics.add(topicName);
+            partitionsPerTopic.put(topicName, partitionCount);
+        }
+
+        for (int i = 0; i < consumerCount; i++) { // initial subscriptions: same topic list, no owned partitions passed
+            subscriptions.put(getConsumerName(i, consumerCount), new Subscription(topics));
+        }
+
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions);
+
+        for (int i = 1; i < consumerCount; i++) { // consumer 0 is skipped and keeps its subscription without owned partitions
+            String consumer = getConsumerName(i, consumerCount);
+            subscriptions.put(consumer, buildSubscription(topics, assignment.get(consumer))); // report first assignment as owned
+        }
+
+        assignor.assign(partitionsPerTopic, subscriptions); // no result assertions: fails only on an exception or on exceeding the @Test timeout
+    }
+
     @Test
-    public void testLargeAssignmentWithMultipleConsumersLeaving() {
+    public void testLargeAssignmentWithMultipleConsumersLeavingAndRandomSubscription() {
         Random rand = new Random();
         int topicCount = 40;
         int consumerCount = 200;
@@ -555,7 +587,6 @@ public abstract class AbstractStickyAssignorTest {
         }
     }
 
-
     @Test
     public void testAssignmentUpdatedForDeletedTopic() {
         Map<String, Integer> partitionsPerTopic = new HashMap<>();
@@ -583,35 +614,6 @@ public abstract class AbstractStickyAssignorTest {
     }
 
     @Test
-    public void testConflictingPreviousAssignments() {
-        String consumer1 = "consumer1";
-        String consumer2 = "consumer2";
-
-        Map<String, Integer> partitionsPerTopic = new HashMap<>();
-        partitionsPerTopic.put(topic, 2);
-        subscriptions.put(consumer1, new Subscription(topics(topic)));
-        subscriptions.put(consumer2, new Subscription(topics(topic)));
-
-        TopicPartition tp0 = new TopicPartition(topic, 0);
-        TopicPartition tp1 = new TopicPartition(topic, 1);
-
-        // both c1 and c2 have partition 1 assigned to them in generation 1
-        List<TopicPartition> c1partitions0 = partitions(tp0, tp1);
-        List<TopicPartition> c2partitions0 = partitions(tp0, tp1);
-        subscriptions.put(consumer1, buildSubscription(topics(topic), c1partitions0));
-        subscriptions.put(consumer2, buildSubscription(topics(topic), c2partitions0));
-
-        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions);
-        List<TopicPartition> c1partitions = assignment.get(consumer1);
-        List<TopicPartition> c2partitions = assignment.get(consumer2);
-
-        assertTrue(c1partitions.size() == 1 && c2partitions.size() == 1);
-        verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic);
-        assertTrue(isFullyBalanced(assignment));
-        assertTrue(assignor.isSticky());
-    }
-
-    @Test
     public void testReassignmentWithRandomSubscriptionsAndChanges() {
         final int minNumConsumers = 20;
         final int maxNumConsumers = 40;