Posted to commits@kafka.apache.org by vv...@apache.org on 2020/06/10 15:11:56 UTC

[kafka] branch 2.6 updated: KAFKA-10079: improve thread-level stickiness (#8775)

This is an automated email from the ASF dual-hosted git repository.

vvcephei pushed a commit to branch 2.6
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.6 by this push:
     new b146248  KAFKA-10079: improve thread-level stickiness (#8775)
b146248 is described below

commit b146248442de7199b3458af479a2fcdc061fb127
Author: A. Sophie Blee-Goldman <so...@confluent.io>
AuthorDate: Wed Jun 10 07:56:06 2020 -0700

    KAFKA-10079: improve thread-level stickiness (#8775)
    
    Uses an algorithm similar to (but slightly different from) the one in KAFKA-9987 to produce a maximally sticky -- and perfectly balanced -- assignment of tasks to threads within a single client. This is important for in-memory stores, which are wiped out when a task is transferred between threads.
    
    Reviewers: John Roesler <vv...@apache.org>
---
 .../internals/StreamsPartitionAssignor.java        | 391 ++++++++-------------
 .../internals/assignment/ClientState.java          |  41 ++-
 .../internals/StreamsPartitionAssignorTest.java    | 252 +++++++------
 .../internals/assignment/ClientStateTest.java      |  51 ++-
 .../assignment/TaskAssignorConvergenceTest.java    |   2 +-
 5 files changed, 341 insertions(+), 396 deletions(-)
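
In outline, the new thread-level assignment works in two passes over the stateful tasks: each thread first gets back the tasks it previously owned, up to the per-thread minimum (floor of stateful tasks divided by threads), threads still below that minimum are then topped up from the leftovers, and stateless tasks are distributed round-robin at the end to even out the total load. The sketch below is a simplified, self-contained model of that idea; it uses plain strings instead of TaskId and omits the lag-based ordering and the second sticky pass that the real assignTasksToThreads (further down in this diff) performs. It is illustrative only, not the committed code.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Collections;
    import java.util.HashMap;
    import java.util.Iterator;
    import java.util.LinkedList;
    import java.util.List;
    import java.util.Map;
    import java.util.Queue;
    import java.util.Set;
    import java.util.SortedSet;
    import java.util.TreeSet;

    public class StickyThreadAssignmentSketch {

        static Map<String, List<String>> assign(final SortedSet<String> statefulTasks,
                                                final SortedSet<String> statelessTasks,
                                                final SortedSet<String> threads,
                                                final Map<String, String> previousOwner) {
            final Map<String, List<String>> assignment = new HashMap<>();
            threads.forEach(thread -> assignment.put(thread, new ArrayList<>()));

            final int minStatefulPerThread = statefulTasks.size() / threads.size();
            final Set<String> unassigned = new TreeSet<>(statefulTasks);
            final Queue<String> toFill = new LinkedList<>();

            // Phase 1: sticky pass -- each thread gets its own previous tasks back, up to the minimum
            for (final String thread : threads) {
                for (final String task : statefulTasks) {
                    if (thread.equals(previousOwner.get(task))
                            && unassigned.contains(task)
                            && assignment.get(thread).size() < minStatefulPerThread) {
                        assignment.get(thread).add(task);
                        unassigned.remove(task);
                    }
                }
                if (assignment.get(thread).size() < minStatefulPerThread) {
                    toFill.offer(thread);
                }
            }

            // Phase 2: top up threads still below the minimum with leftover stateful tasks
            final Iterator<String> leftovers = unassigned.iterator();
            while (!toFill.isEmpty() && leftovers.hasNext()) {
                final String thread = toFill.poll();
                assignment.get(thread).add(leftovers.next());
                leftovers.remove();
                if (assignment.get(thread).size() < minStatefulPerThread) {
                    toFill.offer(thread);
                }
            }

            // Phase 3: round-robin whatever is left (remaining stateful tasks, then stateless tasks)
            final Queue<String> roundRobin = new LinkedList<>(threads);
            final List<String> remaining = new ArrayList<>(unassigned);
            remaining.addAll(statelessTasks);
            for (final String task : remaining) {
                final String thread = roundRobin.poll();
                assignment.get(thread).add(task);
                roundRobin.offer(thread);
            }
            return assignment;
        }

        public static void main(final String[] args) {
            final SortedSet<String> threads = new TreeSet<>(Arrays.asList("consumer1", "consumer2"));
            final SortedSet<String> stateful = new TreeSet<>(Arrays.asList("0_0", "0_1", "1_0"));
            final SortedSet<String> stateless = new TreeSet<>(Collections.singletonList("2_0"));
            final Map<String, String> previousOwner = new HashMap<>();
            previousOwner.put("0_0", "consumer1");
            previousOwner.put("0_1", "consumer2");
            previousOwner.put("1_0", "consumer1");
            // Prints a sticky, balanced split: consumer1 keeps 0_0 and picks up 1_0,
            // consumer2 keeps 0_1 and gets the stateless task 2_0
            System.out.println(assign(stateful, stateless, threads, previousOwner));
        }
    }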

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
index a6fbdfb..3f2cc87 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
@@ -16,6 +16,11 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.util.Iterator;
+import java.util.PriorityQueue;
+import java.util.Queue;
+import java.util.SortedSet;
+import java.util.TreeSet;
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
 import org.apache.kafka.clients.consumer.ConsumerGroupMetadata;
@@ -54,24 +59,24 @@ import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
-import java.util.TreeMap;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
+import static java.util.Comparator.comparingLong;
 import static java.util.UUID.randomUUID;
 import static org.apache.kafka.streams.processor.internals.ClientUtils.fetchEndOffsets;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.EARLIEST_PROBEABLE_VERSION;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.UNKNOWN;
+import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
 
 public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Configurable {
 
@@ -112,7 +117,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
     private static class ClientMetadata {
 
         private final HostInfo hostInfo;
-        private final Set<String> consumers;
+        private final SortedSet<String> consumers;
         private final ClientState state;
 
         ClientMetadata(final String endPoint) {
@@ -121,7 +126,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
             hostInfo = HostInfo.buildFromEndpoint(endPoint);
 
             // initialize the consumer memberIds
-            consumers = new HashSet<>();
+            consumers = new TreeSet<>();
 
             // initialize the client state
             state = new ClientState();
@@ -133,8 +138,8 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
             state.addOwnedPartitions(ownedPartitions, consumerMemberId);
         }
 
-        void addPreviousTasksAndOffsetSums(final Map<TaskId, Long> taskOffsetSums) {
-            state.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        void addPreviousTasksAndOffsetSums(final String consumerId, final Map<TaskId, Long> taskOffsetSums) {
+            state.addPreviousTasksAndOffsetSums(consumerId, taskOffsetSums);
         }
 
         @Override
@@ -318,7 +323,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
             // add the consumer and any info in its subscription to the client
             clientMetadata.addConsumer(consumerId, subscription.ownedPartitions());
             allOwnedPartitions.addAll(subscription.ownedPartitions());
-            clientMetadata.addPreviousTasksAndOffsetSums(info.taskOffsetSums());
+            clientMetadata.addPreviousTasksAndOffsetSums(consumerId, info.taskOffsetSums());
         }
 
         final boolean versionProbing =
@@ -380,19 +385,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
         // compute the assignment of tasks to threads within each client and build the final group assignment
 
-        final Map<String, Assignment> assignment;
-        if (versionProbing) {
-            assignment = versionProbingAssignment(
-                clientMetadataMap,
-                partitionsForTask,
-                partitionsByHost,
-                standbyPartitionsByHost,
-                allOwnedPartitions,
-                minReceivedMetadataVersion,
-                minSupportedMetadataVersion
-            );
-        } else {
-            assignment = computeNewAssignment(
+        final Map<String, Assignment> assignment = computeNewAssignment(
                 clientMetadataMap,
                 partitionsForTask,
                 partitionsByHost,
@@ -400,9 +393,9 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                 allOwnedPartitions,
                 minReceivedMetadataVersion,
                 minSupportedMetadataVersion,
+                versionProbing,
                 probingRebalanceNeeded
-            );
-        }
+        );
 
         return new GroupAssignment(assignment);
     }
@@ -775,7 +768,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
             allTaskEndOffsetSums = computeEndOffsetSumsByTask(endOffsets, changelogsByStatefulTask, allNewlyCreatedChangelogPartitions);
             fetchEndOffsetsSuccessful = true;
         } catch (final StreamsException e) {
-            allTaskEndOffsetSums = null;
+            allTaskEndOffsetSums = changelogsByStatefulTask.keySet().stream().collect(Collectors.toMap(t -> t, t -> UNKNOWN_OFFSET_SUM));
             fetchEndOffsetsSuccessful = false;
         }
 
@@ -784,9 +777,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
             final ClientState state = entry.getValue().state;
             state.initializePrevTasks(taskForPartition);
 
-            if (fetchEndOffsetsSuccessful) {
-                state.computeTaskLags(uuid, allTaskEndOffsetSums);
-            }
+            state.computeTaskLags(uuid, allTaskEndOffsetSums);
             clientStates.put(uuid, state);
         }
         return fetchEndOffsetsSuccessful;
@@ -878,8 +869,9 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                                                          final Set<TopicPartition> allOwnedPartitions,
                                                          final int minUserMetadataVersion,
                                                          final int minSupportedMetadataVersion,
+                                                         final boolean versionProbing,
                                                          final boolean shouldTriggerProbingRebalance) {
-        boolean rebalanceRequired = shouldTriggerProbingRebalance;
+        boolean rebalanceRequired = shouldTriggerProbingRebalance || versionProbing;
         final Map<String, Assignment> assignment = new HashMap<>();
 
         // within the client, distribute tasks to its owned consumers
@@ -887,41 +879,40 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
             final UUID clientId = clientEntry.getKey();
             final ClientMetadata clientMetadata = clientEntry.getValue();
             final ClientState state = clientMetadata.state;
-            final Set<String> consumers = clientMetadata.consumers;
-            Map<String, List<TaskId>> activeTaskAssignments;
-
-            // Try to avoid triggering another rebalance by giving active tasks back to their previous owners within a
-            // client, without violating load balance. If we already know another rebalance will be required, or the
-            // client had no owned partitions, try to balance the workload as evenly as possible by interleaving tasks
-            if (rebalanceRequired || state.ownedPartitions().isEmpty()) {
-                activeTaskAssignments = interleaveConsumerTasksByGroupId(state.activeTasks(), consumers);
-            } else if ((activeTaskAssignments = tryStickyAndBalancedTaskAssignmentWithinClient(state, consumers, partitionsForTask, allOwnedPartitions))
-                        .equals(Collections.emptyMap())) {
-                rebalanceRequired = true;
-                activeTaskAssignments = interleaveConsumerTasksByGroupId(state.activeTasks(), consumers);
-            }
+            final SortedSet<String> consumers = clientMetadata.consumers;
 
-            final Map<String, List<TaskId>> interleavedStandby =
-                interleaveConsumerTasksByGroupId(state.standbyTasks(), consumers);
+            final Map<String, List<TaskId>> activeTaskAssignment = assignTasksToThreads(
+                state.statefulActiveTasks(),
+                state.statelessActiveTasks(),
+                consumers,
+                state
+            );
+
+            final Map<String, List<TaskId>> standbyTaskAssignment = assignTasksToThreads(
+                state.standbyTasks(),
+                Collections.emptySet(),
+                consumers,
+                state
+            );
 
             // Arbitrarily choose the leader's client to be responsible for triggering the probing rebalance
-            final boolean encodeNextProbingRebalanceTime = clientId.equals(taskManager.processId()) && shouldTriggerProbingRebalance;
+            final boolean encodeNextProbingRebalanceTime = shouldTriggerProbingRebalance && clientId.equals(taskManager.processId());
 
-            final boolean followupRebalanceScheduled = addClientAssignments(
+            final boolean tasksRevoked = addClientAssignments(
                 assignment,
                 clientMetadata,
                 partitionsForTask,
                 partitionsByHostState,
                 standbyPartitionsByHost,
                 allOwnedPartitions,
-                activeTaskAssignments,
-                interleavedStandby,
+                activeTaskAssignment,
+                standbyTaskAssignment,
                 minUserMetadataVersion,
                 minSupportedMetadataVersion,
-                false,
-                encodeNextProbingRebalanceTime);
+                encodeNextProbingRebalanceTime
+            );
 
-            if (followupRebalanceScheduled) {
+            if (tasksRevoked || encodeNextProbingRebalanceTime) {
                 rebalanceRequired = true;
                 log.debug("Requested client {} to schedule a followup rebalance", clientId);
             }
@@ -939,56 +930,8 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
     }
 
     /**
-     * Computes the assignment of tasks to threads within each client and assembles the final assignment to send out,
-     * in the special case of version probing where some members are on different versions and have sent different
-     * subscriptions.
-     *
-     * @return the final assignment for each StreamThread consumer
-     */
-    private Map<String, Assignment> versionProbingAssignment(final Map<UUID, ClientMetadata> clientsMetadata,
-                                                             final Map<TaskId, Set<TopicPartition>> partitionsForTask,
-                                                             final Map<HostInfo, Set<TopicPartition>> partitionsByHost,
-                                                             final Map<HostInfo, Set<TopicPartition>> standbyPartitionsByHost,
-                                                             final Set<TopicPartition> allOwnedPartitions,
-                                                             final int minUserMetadataVersion,
-                                                             final int minSupportedMetadataVersion) {
-        final Map<String, Assignment> assignment = new HashMap<>();
-
-        // Since we know another rebalance will be triggered anyway, just try and generate a balanced assignment
-        // (without violating cooperative protocol) now so that on the second rebalance we can just give tasks
-        // back to their previous owners
-        // within the client, distribute tasks to its owned consumers
-        for (final ClientMetadata clientMetadata : clientsMetadata.values()) {
-            final ClientState state = clientMetadata.state;
-
-            final Map<String, List<TaskId>> interleavedActive =
-                interleaveConsumerTasksByGroupId(state.activeTasks(), clientMetadata.consumers);
-            final Map<String, List<TaskId>> interleavedStandby =
-                interleaveConsumerTasksByGroupId(state.standbyTasks(), clientMetadata.consumers);
-
-            addClientAssignments(
-                assignment,
-                clientMetadata,
-                partitionsForTask,
-                partitionsByHost,
-                standbyPartitionsByHost,
-                allOwnedPartitions,
-                interleavedActive,
-                interleavedStandby,
-                minUserMetadataVersion,
-                minSupportedMetadataVersion,
-                true,
-                false);
-        }
-
-        log.info("Finished unstable assignment of tasks, a followup rebalance will be scheduled due to version probing.");
-
-        return assignment;
-    }
-
-    /**
      * Adds the encoded assignment for each StreamThread consumer in the client to the overall assignment map
-     * @return true if this client has been told to schedule a followup rebalance
+     * @return true if a followup rebalance will be required due to revoked tasks
      */
     private boolean addClientAssignments(final Map<String, Assignment> assignment,
                                          final ClientMetadata clientMetadata,
@@ -1000,9 +943,10 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                                          final Map<String, List<TaskId>> standbyTaskAssignments,
                                          final int minUserMetadataVersion,
                                          final int minSupportedMetadataVersion,
-                                         final boolean versionProbing,
                                          final boolean probingRebalanceNeeded) {
-        boolean rebalanceRequested = probingRebalanceNeeded || versionProbing;
+        boolean followupRebalanceRequiredForRevokedTasks = false;
+
+        // We only want to encode a scheduled probing rebalance for a single member in this client
         boolean shouldEncodeProbingRebalance = probingRebalanceNeeded;
 
         // Loop through the consumers and build their assignment
@@ -1020,7 +964,8 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                 clientMetadata.state,
                 activeTasksForConsumer,
                 partitionsForTask,
-                allOwnedPartitions);
+                allOwnedPartitions
+            );
 
             final Map<TaskId, Set<TopicPartition>> standbyTaskMap =
                 buildStandbyTaskMap(standbyTaskAssignments.get(consumer), partitionsForTask);
@@ -1036,13 +981,15 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
             );
 
             if (tasksRevoked) {
-                // TODO: once KAFKA-9821 is resolved we can leave it to the client to trigger this rebalance
-                log.debug("Requesting followup rebalance be scheduled immediately due to tasks changing ownership.");
+                // TODO: once KAFKA-10078 is resolved we can leave it to the client to trigger this rebalance
+                log.info("Requesting followup rebalance be scheduled immediately due to tasks changing ownership.");
                 info.setNextRebalanceTime(0L);
-                rebalanceRequested = true;
+                followupRebalanceRequiredForRevokedTasks = true;
+                // Don't bother to schedule a probing rebalance if an immediate one is already scheduled
+                shouldEncodeProbingRebalance = false;
             } else if (shouldEncodeProbingRebalance) {
                 final long nextRebalanceTimeMs = time.milliseconds() + probingRebalanceIntervalMs();
-                log.debug("Requesting followup rebalance be scheduled for {} ms to probe for caught-up replica tasks.", nextRebalanceTimeMs);
+                log.info("Requesting followup rebalance be scheduled for {} ms to probe for caught-up replica tasks.", nextRebalanceTimeMs);
                 info.setNextRebalanceTime(nextRebalanceTimeMs);
                 shouldEncodeProbingRebalance = false;
             }
@@ -1055,7 +1002,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                 )
             );
         }
-        return rebalanceRequested;
+        return followupRebalanceRequiredForRevokedTasks;
     }
 
     /**
@@ -1078,7 +1025,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         for (final TaskId taskId : activeTasksForConsumer) {
             final List<AssignedPartition> assignedPartitionsForTask = new ArrayList<>();
             for (final TopicPartition partition : partitionsForTask.get(taskId)) {
-                final String oldOwner = clientState.ownedPartitions().get(partition);
+                final String oldOwner = clientState.previousOwnerForPartition(partition);
                 final boolean newPartitionForConsumer = oldOwner == null || !oldOwner.equals(consumer);
 
                 // If the partition is new to this consumer but is still owned by another, remove from the assignment
@@ -1121,172 +1068,124 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
     }
 
     /**
-     * Generates an assignment that tries to satisfy two conditions: no active task previously owned by a consumer
-     * be assigned to another (ie nothing gets revoked), and the number of tasks is evenly distributed throughout
-     * the client.
-     * <p>
-     * If it is impossible to satisfy both constraints we abort early and return an empty map so we can use a
-     * different assignment strategy that tries to distribute tasks of a single subtopology across different threads.
-     *
-     * @param state state for this client
-     * @param consumers the consumers in this client
-     * @param partitionsForTask mapping from task to its associated partitions
-     * @param allOwnedPartitions set of all partitions claimed as owned by the group
-     * @return task assignment for the consumers of this client
-     *         empty map if it is not possible to generate a balanced assignment without moving a task to a new consumer
+     * Generate an assignment that tries to preserve thread-level stickiness of stateful tasks without violating
+     * balance. The stateful and total task load are both balanced across threads. Tasks without previous owners
+     * will be interleaved by group id to spread subtopologies across threads and further balance the workload.
      */
-    Map<String, List<TaskId>> tryStickyAndBalancedTaskAssignmentWithinClient(final ClientState state,
-                                                                             final Set<String> consumers,
-                                                                             final Map<TaskId, Set<TopicPartition>> partitionsForTask,
-                                                                             final Set<TopicPartition> allOwnedPartitions) {
-        final Map<String, List<TaskId>> assignments = new HashMap<>();
-        final LinkedList<TaskId> newTasks = new LinkedList<>();
-        final Set<String> unfilledConsumers = new HashSet<>(consumers);
-
-        final int maxTasksPerClient = (int) Math.ceil(((double) state.activeTaskCount()) / consumers.size());
-
-        // initialize task list for consumers
+    static Map<String, List<TaskId>> assignTasksToThreads(final Collection<TaskId> statefulTasksToAssign,
+                                                          final Collection<TaskId> statelessTasksToAssign,
+                                                          final SortedSet<String> consumers,
+                                                          final ClientState state) {
+        final Map<String, List<TaskId>> assignment = new HashMap<>();
         for (final String consumer : consumers) {
-            assignments.put(consumer, new ArrayList<>());
+            assignment.put(consumer, new ArrayList<>());
         }
 
-        for (final TaskId task : state.activeTasks()) {
-            final Set<String> previousConsumers = previousConsumersOfTaskPartitions(partitionsForTask.get(task), state.ownedPartitions(), allOwnedPartitions);
+        final List<TaskId> unassignedStatelessTasks = new ArrayList<>(statelessTasksToAssign);
+        Collections.sort(unassignedStatelessTasks);
 
-            // If this task's partitions were owned by different consumers, we can't avoid revoking partitions
-            if (previousConsumers.size() > 1) {
-                log.warn("The partitions of task {} were claimed as owned by different StreamThreads. " +
-                    "This indicates the mapping from partitions to tasks has changed!", task);
-                return Collections.emptyMap();
-            }
+        final Iterator<TaskId> unassignedStatelessTasksIter = unassignedStatelessTasks.iterator();
 
-            // If this is a new task, or its old consumer no longer exists, it can be freely (re)assigned
-            if (previousConsumers.isEmpty()) {
-                log.debug("Task {} was not previously owned by any consumers still in the group. It's owner may " +
-                    "have died or it may be a new task", task);
-                newTasks.add(task);
-            } else {
-                final String consumer = previousConsumers.iterator().next();
+        final int minStatefulTasksPerThread = (int) Math.floor(((double) statefulTasksToAssign.size()) / consumers.size());
+        final PriorityQueue<TaskId> unassignedStatefulTasks = new PriorityQueue<>(statefulTasksToAssign);
 
-                // If the previous consumer was from another client, these partitions will have to be revoked
-                if (!consumers.contains(consumer)) {
-                    log.debug("This client was assigned a task {} whose partition(s) were previously owned by another " +
-                        "client, falling back to an interleaved assignment since a rebalance is inevitable.", task);
-                    return Collections.emptyMap();
-                }
+        final Queue<String> consumersToFill = new LinkedList<>();
+        // keep track of tasks that we have to skip during the first pass in case we can reassign them later
+        final Map<TaskId, String> unassignedTaskToPreviousOwner = new HashMap<>();
 
-                // If this consumer previously owned more tasks than it has capacity for, some must be revoked
-                if (assignments.get(consumer).size() >= maxTasksPerClient) {
-                    log.debug("Cannot create a sticky and balanced assignment as this client's consumers owned more " +
-                        "previous tasks than it has capacity for during this assignment, falling back to interleaved " +
-                        "assignment since a realance is inevitable.");
-                    return Collections.emptyMap();
-                }
+        if (!unassignedStatefulTasks.isEmpty()) {
+            // First assign stateful tasks to previous owner, up to the min expected tasks/thread
+            for (final String consumer : consumers) {
+                final List<TaskId> threadAssignment = assignment.get(consumer);
 
-                assignments.get(consumer).add(task);
+                for (final TaskId task : getPreviousTasksByLag(state, consumer)) {
+                    if (unassignedStatefulTasks.contains(task)) {
+                        if (threadAssignment.size() < minStatefulTasksPerThread) {
+                            threadAssignment.add(task);
+                            unassignedStatefulTasks.remove(task);
+                        } else {
+                            unassignedTaskToPreviousOwner.put(task, consumer);
+                        }
+                    }
+                }
 
-                // If we have now reached capacity, remove it from set of consumers who still need more tasks
-                if (assignments.get(consumer).size() == maxTasksPerClient) {
-                    unfilledConsumers.remove(consumer);
+                if (threadAssignment.size() < minStatefulTasksPerThread) {
+                    consumersToFill.offer(consumer);
                 }
             }
-        }
 
-        // Interleave any remaining tasks by groupId among the consumers with remaining capacity. For further
-        // explanation, see the javadocs for #interleaveConsumerTasksByGroupId
-        Collections.sort(newTasks);
-        while (!newTasks.isEmpty()) {
-            if (unfilledConsumers.isEmpty()) {
-                throw new IllegalStateException("Some tasks could not be distributed");
+            // Next interleave remaining unassigned tasks amongst unfilled consumers
+            while (!consumersToFill.isEmpty()) {
+                final TaskId task = unassignedStatefulTasks.poll();
+                if (task != null) {
+                    final String consumer = consumersToFill.poll();
+                    final List<TaskId> threadAssignment = assignment.get(consumer);
+                    threadAssignment.add(task);
+                    if (threadAssignment.size() < minStatefulTasksPerThread) {
+                        consumersToFill.offer(consumer);
+                    }
+                } else {
+                    throw new IllegalStateException("Ran out of unassigned stateful tasks but some members were not at capacity");
+                }
             }
 
-            final Iterator<String> consumerIt = unfilledConsumers.iterator();
+            // At this point all consumers are at the min capacity, so there may be up to N - 1 unassigned
+            // stateful tasks still remaining that should now be distributed over the consumers
+            if (!unassignedStatefulTasks.isEmpty()) {
+                consumersToFill.addAll(consumers);
+
+                // Go over the tasks we skipped earlier and assign them to their previous owner when possible
+                for (final Map.Entry<TaskId, String> taskEntry : unassignedTaskToPreviousOwner.entrySet()) {
+                    final TaskId task = taskEntry.getKey();
+                    final String consumer = taskEntry.getValue();
+                    if (consumersToFill.contains(consumer) && unassignedStatefulTasks.contains(task)) {
+                        assignment.get(consumer).add(task);
+                        unassignedStatefulTasks.remove(task);
+                        // Remove this consumer since we know it is now at minCapacity + 1
+                        consumersToFill.remove(consumer);
+                    }
+                }
 
-            // Loop through the unfilled consumers and distribute tasks until newTasks is empty
-            while (consumerIt.hasNext()) {
-                final String consumer = consumerIt.next();
-                final List<TaskId> consumerAssignment = assignments.get(consumer);
-                final TaskId task = newTasks.poll();
-                if (task == null) {
-                    break;
+                // Now just distribute the remaining unassigned stateful tasks over the consumers still at min capacity
+                for (final TaskId task : unassignedStatefulTasks) {
+                    final String consumer = consumersToFill.poll();
+                    final List<TaskId> threadAssignment = assignment.get(consumer);
+                    threadAssignment.add(task);
                 }
 
-                consumerAssignment.add(task);
-                if (consumerAssignment.size() == maxTasksPerClient) {
-                    consumerIt.remove();
+
+                // There must be at least one consumer still at min capacity while all the others are at min
+                // capacity + 1, so start distributing stateless tasks to get all consumers back to the same count
+                while (unassignedStatelessTasksIter.hasNext()) {
+                    final String consumer = consumersToFill.poll();
+                    if (consumer != null) {
+                        final TaskId task = unassignedStatelessTasksIter.next();
+                        unassignedStatelessTasksIter.remove();
+                        assignment.get(consumer).add(task);
+                    } else {
+                        break;
+                    }
                 }
             }
         }
 
-        return assignments;
-    }
+        // Now just distribute tasks while circling through all the consumers
+        consumersToFill.addAll(consumers);
 
-    /**
-     * Get the previous consumer for the partitions of a task
-     *
-     * @param taskPartitions the TopicPartitions for a single given task
-     * @param clientOwnedPartitions the partitions owned by all consumers in a client
-     * @param allOwnedPartitions all partitions claimed as owned by any consumer in any client
-     * @return set of consumer(s) that previously owned the partitions in this task
-     *         empty set signals that it is a new task, or its previous owner is no longer in the group
-     */
-    private Set<String> previousConsumersOfTaskPartitions(final Set<TopicPartition> taskPartitions,
-                                                          final Map<TopicPartition, String> clientOwnedPartitions,
-                                                          final Set<TopicPartition> allOwnedPartitions) {
-        // this "foreignConsumer" indicates a partition was owned by someone from another client -- we don't really care who
-        final String foreignConsumer = "";
-        final Set<String> previousConsumers = new HashSet<>();
-
-        for (final TopicPartition tp : taskPartitions) {
-            final String currentPartitionConsumer = clientOwnedPartitions.get(tp);
-            if (currentPartitionConsumer != null) {
-                previousConsumers.add(currentPartitionConsumer);
-            } else if (allOwnedPartitions.contains(tp)) {
-                previousConsumers.add(foreignConsumer);
-            }
+        while (unassignedStatelessTasksIter.hasNext()) {
+            final TaskId task = unassignedStatelessTasksIter.next();
+            final String consumer = consumersToFill.poll();
+            assignment.get(consumer).add(task);
+            consumersToFill.offer(consumer);
         }
 
-        return previousConsumers;
+        return assignment;
     }
 
-    /**
-     * Generate an assignment that attempts to maximize load balance without regard for stickiness, by spreading
-     * tasks of the same groupId (subtopology) over different consumers.
-     *
-     * @param taskIds the set of tasks to be distributed
-     * @param consumers the set of consumers to receive tasks
-     * @return a map of task assignments keyed by the consumer id
-     */
-    static Map<String, List<TaskId>> interleaveConsumerTasksByGroupId(final Collection<TaskId> taskIds,
-                                                                      final Set<String> consumers) {
-        // First we make a sorted list of the tasks, grouping them by groupId
-        final LinkedList<TaskId> sortedTasks = new LinkedList<>(taskIds);
-        Collections.sort(sortedTasks);
-
-        // Initialize the assignment map and task list for each consumer. We use a TreeMap here for a consistent
-        // ordering of the consumers in the hope they will end up with the same set of tasks in subsequent assignments
-        final Map<String, List<TaskId>> taskIdsForConsumerAssignment = new TreeMap<>();
-        for (final String consumer : consumers) {
-            taskIdsForConsumerAssignment.put(consumer, new ArrayList<>());
-        }
-
-        // We loop until the tasks have all been assigned, removing them from the list when they are given to a
-        // consumer. To interleave the tasks, we loop through the consumers and give each one task from the head
-        // of the list. When we finish going through the list of consumers we start over at the beginning of the
-        // consumers list, continuing until we run out of tasks.
-        while (!sortedTasks.isEmpty()) {
-            for (final Map.Entry<String, List<TaskId>> consumerTaskIds : taskIdsForConsumerAssignment.entrySet()) {
-                final List<TaskId> taskIdList = consumerTaskIds.getValue();
-                final TaskId taskId = sortedTasks.poll();
-
-                // Check for null here as we may run out of tasks before giving every consumer exactly the same number
-                if (taskId == null) {
-                    break;
-                }
-                taskIdList.add(taskId);
-            }
-        }
-        return taskIdsForConsumerAssignment;
+    private static SortedSet<TaskId> getPreviousTasksByLag(final ClientState state, final String consumer) {
+        final SortedSet<TaskId> prevTasksByLag = new TreeSet<>(comparingLong(state::lagFor).thenComparing(TaskId::compareTo));
+        prevTasksByLag.addAll(state.previousTasksForConsumer(consumer));
+        return prevTasksByLag;
     }
 
     private void validateMetadataVersions(final int receivedAssignmentMetadataVersion,
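
The new assignTasksToThreads helper above is a static method that takes only the stateful tasks, the stateless tasks, the client's consumers, and its ClientState, so it can be exercised directly, which is exactly what the updated StreamsPartitionAssignorTest further down does. The driver below is hypothetical (the class name and the numbers are made up) and assumes the ClientState API exactly as shown in this diff; it is declared in the org.apache.kafka.streams.processor.internals package because the helper is package-private.

    package org.apache.kafka.streams.processor.internals; // same package, so the package-private helper is visible

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.SortedSet;
    import java.util.TreeSet;
    import java.util.UUID;

    import org.apache.kafka.streams.processor.TaskId;
    import org.apache.kafka.streams.processor.internals.assignment.ClientState;

    public class AssignTasksToThreadsExample {

        public static void main(final String[] args) {
            final TaskId task00 = new TaskId(0, 0);
            final TaskId task01 = new TaskId(0, 1);
            final TaskId task10 = new TaskId(1, 0);

            // Thread "consumer1" previously owned 0_0, "consumer2" owned 0_1; 1_0 is brand new
            final ClientState state = new ClientState();
            state.addPreviousTasksAndOffsetSums("consumer1", Collections.singletonMap(task00, 0L));
            state.addPreviousTasksAndOffsetSums("consumer2", Collections.singletonMap(task01, 0L));
            state.initializePrevTasks(Collections.emptyMap());

            // End offset sums for every stateful task, so lagFor() is defined for all of them
            final Map<TaskId, Long> endOffsetSums = new HashMap<>();
            endOffsetSums.put(task00, 10L);
            endOffsetSums.put(task01, 10L);
            endOffsetSums.put(task10, 10L);
            state.computeTaskLags(UUID.randomUUID(), endOffsetSums);

            final SortedSet<String> consumers = new TreeSet<>(Arrays.asList("consumer1", "consumer2"));
            final Map<String, List<TaskId>> assignment = StreamsPartitionAssignor.assignTasksToThreads(
                Arrays.asList(task00, task01, task10), // stateful tasks to distribute
                Collections.emptySet(),                // no stateless tasks in this example
                consumers,
                state
            );

            // Each thread keeps its previous task; the new task 1_0 lands on one of them
            System.out.println(assignment);
        }
    }
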
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
index 4c96ade..616cd42 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import java.util.stream.Collectors;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.Task;
@@ -48,6 +49,7 @@ public class ClientState {
     private final Set<TaskId> prevActiveTasks;
     private final Set<TaskId> prevStandbyTasks;
 
+    private final Map<String, Set<TaskId>> consumerToPreviousStatefulTaskIds;
     private final Map<TopicPartition, String> ownedPartitions;
     private final Map<TaskId, Long> taskOffsetSums; // contains only stateful tasks we previously owned
     private final Map<TaskId, Long> taskLagTotals;  // contains lag for all stateful tasks in the app topology
@@ -63,6 +65,7 @@ public class ClientState {
         standbyTasks = new TreeSet<>();
         prevActiveTasks = new TreeSet<>();
         prevStandbyTasks = new TreeSet<>();
+        consumerToPreviousStatefulTaskIds = new TreeMap<>();
         ownedPartitions = new TreeMap<>(TOPIC_PARTITION_COMPARATOR);
         taskOffsetSums = new TreeMap<>();
         taskLagTotals = new TreeMap<>();
@@ -73,6 +76,7 @@ public class ClientState {
                         final Set<TaskId> standbyTasks,
                         final Set<TaskId> prevActiveTasks,
                         final Set<TaskId> prevStandbyTasks,
+                        final Map<String, Set<TaskId>> consumerToPreviousStatefulTaskIds,
                         final SortedMap<TopicPartition, String> ownedPartitions,
                         final Map<TaskId, Long> taskOffsetSums,
                         final Map<TaskId, Long> taskLagTotals,
@@ -81,6 +85,7 @@ public class ClientState {
         this.standbyTasks = standbyTasks;
         this.prevActiveTasks = prevActiveTasks;
         this.prevStandbyTasks = prevStandbyTasks;
+        this.consumerToPreviousStatefulTaskIds = consumerToPreviousStatefulTaskIds;
         this.ownedPartitions = ownedPartitions;
         this.taskOffsetSums = taskOffsetSums;
         this.taskLagTotals = taskLagTotals;
@@ -95,6 +100,7 @@ public class ClientState {
         standbyTasks = new TreeSet<>();
         prevActiveTasks = unmodifiableSet(new TreeSet<>(previousActiveTasks));
         prevStandbyTasks = unmodifiableSet(new TreeSet<>(previousStandbyTasks));
+        consumerToPreviousStatefulTaskIds = new TreeMap<>();
         ownedPartitions = new TreeMap<>(TOPIC_PARTITION_COMPARATOR);
         taskOffsetSums = emptyMap();
         this.taskLagTotals = unmodifiableMap(taskLagTotals);
@@ -109,6 +115,7 @@ public class ClientState {
             new TreeSet<>(standbyTasks),
             new TreeSet<>(prevActiveTasks),
             new TreeSet<>(prevStandbyTasks),
+            new TreeMap<>(consumerToPreviousStatefulTaskIds),
             newOwnedPartitions,
             new TreeMap<>(taskOffsetSums),
             new TreeMap<>(taskLagTotals),
@@ -143,7 +150,7 @@ public class ClientState {
         activeTasks.addAll(tasks);
     }
 
-    void assignActive(final TaskId task) {
+    public void assignActive(final TaskId task) {
         assertNotAssigned(task);
         activeTasks.add(task);
     }
@@ -232,8 +239,9 @@ public class ClientState {
         return union(() -> new HashSet<>(prevActiveTasks.size() + prevStandbyTasks.size()), prevActiveTasks, prevStandbyTasks);
     }
 
-    public Map<TopicPartition, String> ownedPartitions() {
-        return unmodifiableMap(ownedPartitions);
+    // May return null
+    public String previousOwnerForPartition(final TopicPartition partition) {
+        return ownedPartitions.get(partition);
     }
 
     public void addOwnedPartitions(final Collection<TopicPartition> ownedPartitions, final String consumer) {
@@ -242,8 +250,9 @@ public class ClientState {
         }
     }
 
-    public void addPreviousTasksAndOffsetSums(final Map<TaskId, Long> taskOffsetSums) {
+    public void addPreviousTasksAndOffsetSums(final String consumerId, final Map<TaskId, Long> taskOffsetSums) {
         this.taskOffsetSums.putAll(taskOffsetSums);
+        consumerToPreviousStatefulTaskIds.put(consumerId, taskOffsetSums.keySet());
     }
 
     public void initializePrevTasks(final Map<TopicPartition, TaskId> taskForPartitionMap) {
@@ -291,14 +300,24 @@ public class ClientState {
      * @return end offset sum - offset sum
      *          Task.LATEST_OFFSET if this was previously an active running task on this client
      */
-    long lagFor(final TaskId task) {
+    public long lagFor(final TaskId task) {
         final Long totalLag = taskLagTotals.get(task);
-
         if (totalLag == null) {
             throw new IllegalStateException("Tried to lookup lag for unknown task " + task);
-        } else {
-            return totalLag;
         }
+        return totalLag;
+    }
+
+    public Set<TaskId> statefulActiveTasks() {
+        return activeTasks.stream().filter(this::isStateful).collect(Collectors.toSet());
+    }
+
+    public Set<TaskId> statelessActiveTasks() {
+        return activeTasks.stream().filter(task -> !isStateful(task)).collect(Collectors.toSet());
+    }
+
+    public Set<TaskId> previousTasksForConsumer(final String memberId) {
+        return consumerToPreviousStatefulTaskIds.get(memberId);
     }
 
     boolean hasUnfulfilledQuota(final int tasksPerThread) {
@@ -340,12 +359,16 @@ public class ClientState {
             "]";
     }
 
+    private boolean isStateful(final TaskId task) {
+        return taskLagTotals.containsKey(task);
+    }
+
     private void initializePrevActiveTasksFromOwnedPartitions(final Map<TopicPartition, TaskId> taskForPartitionMap) {
         // there are three cases where we need to construct some or all of the prevTasks from the ownedPartitions:
         // 1) COOPERATIVE clients on version 2.4-2.5 do not encode active tasks at all and rely on ownedPartitions
         // 2) future client during version probing, when we can't decode the future subscription info's prev tasks
         // 3) stateless tasks are not encoded in the task lags, and must be figured out from the ownedPartitions
-        for (final Map.Entry<TopicPartition, String> partitionEntry : ownedPartitions().entrySet()) {
+        for (final Map.Entry<TopicPartition, String> partitionEntry : ownedPartitions.entrySet()) {
             final TopicPartition tp = partitionEntry.getKey();
             final TaskId task = taskForPartitionMap.get(tp);
             if (task != null) {
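
getPreviousTasksByLag (in the assignor diff above) orders a consumer's previous stateful tasks by ClientState#lagFor, so the most caught-up tasks come back to that thread first; in particular, tasks that were previously active and running, for which lagFor returns the Task.LATEST_OFFSET sentinel, sort ahead of any standby, and TaskId order breaks ties deterministically. The sketch below illustrates that ordering with a plain map of lags; the sentinel value is an assumption here and is stood in by a local constant.

    import static java.util.Comparator.comparingLong;

    import java.util.HashMap;
    import java.util.Map;
    import java.util.SortedSet;
    import java.util.TreeSet;

    import org.apache.kafka.streams.processor.TaskId;

    public class PreviousTasksByLagExample {

        // Assumed stand-in for Task.LATEST_OFFSET, the sentinel lagFor() returns for tasks
        // that were active and running on this client in the previous assignment
        private static final long LATEST_OFFSET_SENTINEL = -2L;

        public static void main(final String[] args) {
            final Map<TaskId, Long> lags = new HashMap<>();
            lags.put(new TaskId(0, 0), 500L);                   // standby with some lag
            lags.put(new TaskId(0, 1), LATEST_OFFSET_SENTINEL); // previously active and running
            lags.put(new TaskId(1, 0), 0L);                     // fully caught-up standby

            final SortedSet<TaskId> byLag = new TreeSet<>(
                comparingLong((final TaskId task) -> lags.get(task)).thenComparing(TaskId::compareTo)
            );
            byLag.addAll(lags.keySet());

            // Iteration order: 0_1 (previously active), then 1_0, then 0_0
            System.out.println(byLag);
        }
    }
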
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
index 350a598..03ab1a7 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.util.SortedSet;
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.admin.AdminClient;
 import org.apache.kafka.clients.admin.ListOffsetsResult;
@@ -88,6 +89,8 @@ import static java.util.Collections.singletonMap;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.common.utils.Utils.mkSortedSet;
+import static org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor.assignTasksToThreads;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_CHANGELOG_END_OFFSETS;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASKS;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASK_OFFSET_SUMS;
@@ -101,8 +104,6 @@ import static org.apache.kafka.streams.processor.internals.assignment.Assignment
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_3;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_0;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_1;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_2;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_3;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION;
@@ -126,7 +127,7 @@ public class StreamsPartitionAssignorTest {
     private static final String CONSUMER_2 = "consumer2";
     private static final String CONSUMER_3 = "consumer3";
     private static final String CONSUMER_4 = "consumer4";
-    
+
     private final Set<String> allTopics = mkSet("topic1", "topic2");
 
     private final TopicPartition t1p0 = new TopicPartition("topic1", 0);
@@ -141,27 +142,6 @@ public class StreamsPartitionAssignorTest {
     private final TopicPartition t3p1 = new TopicPartition("topic3", 1);
     private final TopicPartition t3p2 = new TopicPartition("topic3", 2);
     private final TopicPartition t3p3 = new TopicPartition("topic3", 3);
-    private final TopicPartition t4p0 = new TopicPartition("topic4", 0);
-    private final TopicPartition t4p1 = new TopicPartition("topic4", 1);
-    private final TopicPartition t4p2 = new TopicPartition("topic4", 2);
-    private final TopicPartition t4p3 = new TopicPartition("topic4", 3);
-
-    private final Map<TaskId, Set<TopicPartition>> partitionsForTask = mkMap(
-        mkEntry(TASK_0_0, mkSet(t1p0, t2p0)),
-        mkEntry(TASK_0_1, mkSet(t1p1, t2p1)),
-        mkEntry(TASK_0_2, mkSet(t1p2, t2p2)),
-        mkEntry(TASK_0_3, mkSet(t1p3, t2p3)),
-
-        mkEntry(TASK_1_0, mkSet(t3p0)),
-        mkEntry(TASK_1_1, mkSet(t3p1)),
-        mkEntry(TASK_1_2, mkSet(t3p2)),
-        mkEntry(TASK_1_3, mkSet(t3p3)),
-
-        mkEntry(TASK_2_0, mkSet(t4p0)),
-        mkEntry(TASK_2_1, mkSet(t4p1)),
-        mkEntry(TASK_2_2, mkSet(t4p2)),
-        mkEntry(TASK_2_3, mkSet(t4p3))
-    );
 
     private final List<PartitionInfo> infos = asList(
         new PartitionInfo("topic1", 0, Node.noNode(), new Node[0], new Node[0]),
@@ -325,10 +305,8 @@ public class StreamsPartitionAssignorTest {
 
     @Test
     public void shouldProduceStickyAndBalancedAssignmentWhenNothingChanges() {
-        configureDefault();
-        final ClientState state = new ClientState();
-        final List<TaskId> allTasks = asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2,
-            TASK_1_3);
+        final List<TaskId> allTasks =
+                asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3);
 
         final Map<String, List<TaskId>> previousAssignment = mkMap(
             mkEntry(CONSUMER_1, asList(TASK_0_0, TASK_1_1, TASK_1_3)),
@@ -336,33 +314,29 @@ public class StreamsPartitionAssignorTest {
             mkEntry(CONSUMER_3, asList(TASK_0_1, TASK_0_2, TASK_1_2))
         );
 
-        for (final Map.Entry<String, List<TaskId>> entry : previousAssignment.entrySet()) {
-            for (final TaskId task : entry.getValue()) {
-                state.addOwnedPartitions(partitionsForTask.get(task), entry.getKey());
-            }
-        }
-
-        final Set<String> consumers = mkSet(CONSUMER_1, CONSUMER_2, CONSUMER_3);
-        state.assignActiveTasks(allTasks);
+        final ClientState state = new ClientState();
+        final SortedSet<String> consumers = mkSortedSet(CONSUMER_1, CONSUMER_2, CONSUMER_3);
+        state.addPreviousTasksAndOffsetSums(CONSUMER_1, getTaskOffsetSums(asList(TASK_0_0, TASK_1_1, TASK_1_3), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_2, getTaskOffsetSums(asList(TASK_0_3, TASK_1_0), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_3, getTaskOffsetSums(asList(TASK_0_1, TASK_0_2, TASK_1_2), EMPTY_TASKS));
+        state.initializePrevTasks(emptyMap());
+        state.computeTaskLags(UUID_1, getTaskEndOffsetSums(allTasks));
 
         assertEquivalentAssignment(
             previousAssignment,
-            partitionAssignor.tryStickyAndBalancedTaskAssignmentWithinClient(
-                state,
+            assignTasksToThreads(
+                allTasks,
+                emptySet(),
                 consumers,
-                partitionsForTask,
-                emptySet()
+                state
             )
         );
     }
 
     @Test
     public void shouldProduceStickyAndBalancedAssignmentWhenNewTasksAreAdded() {
-        configureDefault();
-        final ClientState state = new ClientState();
-
-        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2,
-            TASK_1_3);
+        final List<TaskId> allTasks =
+            new ArrayList<>(asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3));
 
         final Map<String, List<TaskId>> previousAssignment = mkMap(
             mkEntry(CONSUMER_1, new ArrayList<>(asList(TASK_0_0, TASK_1_1, TASK_1_3))),
@@ -370,33 +344,35 @@ public class StreamsPartitionAssignorTest {
             mkEntry(CONSUMER_3, new ArrayList<>(asList(TASK_0_1, TASK_0_2, TASK_1_2)))
         );
 
-        for (final Map.Entry<String, List<TaskId>> entry : previousAssignment.entrySet()) {
-            for (final TaskId task : entry.getValue()) {
-                state.addOwnedPartitions(partitionsForTask.get(task), entry.getKey());
-            }
-        }
-
-        final Set<String> consumers = mkSet(CONSUMER_1, CONSUMER_2, CONSUMER_3);
-
-        // We should be able to add a new task without sacrificing stickyness
+        final ClientState state = new ClientState();
+        final SortedSet<String> consumers = mkSortedSet(CONSUMER_1, CONSUMER_2, CONSUMER_3);
+        state.addPreviousTasksAndOffsetSums(CONSUMER_1, getTaskOffsetSums(asList(TASK_0_0, TASK_1_1, TASK_1_3), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_2, getTaskOffsetSums(asList(TASK_0_3, TASK_1_0), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_3, getTaskOffsetSums(asList(TASK_0_1, TASK_0_2, TASK_1_2), EMPTY_TASKS));
+        state.initializePrevTasks(emptyMap());
+        state.computeTaskLags(UUID_1, getTaskEndOffsetSums(allTasks));
+
+        // We should be able to add a new task without sacrificing stickiness
         final TaskId newTask = TASK_2_0;
         allTasks.add(newTask);
         state.assignActiveTasks(allTasks);
 
         final Map<String, List<TaskId>> newAssignment =
-            partitionAssignor.tryStickyAndBalancedTaskAssignmentWithinClient(state, consumers, partitionsForTask, emptySet());
+            assignTasksToThreads(
+                allTasks,
+                emptySet(),
+                consumers,
+                state
+            );
 
         previousAssignment.get(CONSUMER_2).add(newTask);
         assertEquivalentAssignment(previousAssignment, newAssignment);
     }
 
     @Test
-    public void shouldReturnEmptyMapWhenStickyAndBalancedAssignmentIsNotPossibleBecauseNewConsumerJoined() {
-        configureDefault();
-        final ClientState state = new ClientState();
-
-        final List<TaskId> allTasks = asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2,
-            TASK_1_3);
+    public void shouldProduceMaximallyStickyAssignmentWhenMemberLeaves() {
+        final List<TaskId> allTasks =
+            asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3);
 
         final Map<String, List<TaskId>> previousAssignment = mkMap(
             mkEntry(CONSUMER_1, asList(TASK_0_0, TASK_1_1, TASK_1_3)),
@@ -404,84 +380,103 @@ public class StreamsPartitionAssignorTest {
             mkEntry(CONSUMER_3, asList(TASK_0_1, TASK_0_2, TASK_1_2))
         );
 
-        for (final Map.Entry<String, List<TaskId>> entry : previousAssignment.entrySet()) {
-            for (final TaskId task : entry.getValue()) {
-                state.addOwnedPartitions(partitionsForTask.get(task), entry.getKey());
-            }
-        }
+        final ClientState state = new ClientState();
+        final SortedSet<String> consumers = mkSortedSet(CONSUMER_1, CONSUMER_2, CONSUMER_3);
+        state.addPreviousTasksAndOffsetSums(CONSUMER_1, getTaskOffsetSums(asList(TASK_0_0, TASK_1_1, TASK_1_3), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_2, getTaskOffsetSums(asList(TASK_0_3, TASK_1_0), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_3, getTaskOffsetSums(asList(TASK_0_1, TASK_0_2, TASK_1_2), EMPTY_TASKS));
+        state.initializePrevTasks(emptyMap());
+        state.computeTaskLags(UUID_1, getTaskEndOffsetSums(allTasks));
+
+        // Consumer 3 leaves the group
+        consumers.remove(CONSUMER_3);
+
+        final Map<String, List<TaskId>> assignment = assignTasksToThreads(
+            allTasks,
+            emptySet(),
+            consumers,
+            state
+        );
 
-        // If we add a new consumer here, we cannot produce an assignment that is both sticky and balanced
-        final Set<String> consumers = mkSet(CONSUMER_1, CONSUMER_2, CONSUMER_3, CONSUMER_4);
-        state.assignActiveTasks(allTasks);
+        // Each member should have all of its previous tasks reassigned plus some of consumer 3's tasks
+        // We should give one of its tasks to consumer 1, and two of its tasks to consumer 2
+        assertTrue(assignment.get(CONSUMER_1).containsAll(previousAssignment.get(CONSUMER_1)));
+        assertTrue(assignment.get(CONSUMER_2).containsAll(previousAssignment.get(CONSUMER_2)));
 
-        assertThat(partitionAssignor.tryStickyAndBalancedTaskAssignmentWithinClient(state, consumers, partitionsForTask, emptySet()),
-                   equalTo(emptyMap()));
+        assertThat(assignment.get(CONSUMER_1).size(), equalTo(4));
+        assertThat(assignment.get(CONSUMER_2).size(), equalTo(4));
     }
 
     @Test
-    public void shouldReturnEmptyMapWhenStickyAndBalancedAssignmentIsNotPossibleBecauseOtherClientOwnedPartition() {
-        configureDefault();
-        final ClientState state = new ClientState();
-
-        final List<TaskId> allTasks = asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2,
-            TASK_1_3);
+    public void shouldProduceStickyEnoughAssignmentWhenNewMemberJoins() {
+        final List<TaskId> allTasks =
+            asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3);
 
         final Map<String, List<TaskId>> previousAssignment = mkMap(
-            mkEntry(CONSUMER_1, new ArrayList<>(asList(TASK_1_1, TASK_1_3))),
-            mkEntry(CONSUMER_2, new ArrayList<>(asList(TASK_0_3, TASK_1_0))),
-            mkEntry(CONSUMER_3, new ArrayList<>(asList(TASK_0_1, TASK_0_2, TASK_1_2)))
+            mkEntry(CONSUMER_1, asList(TASK_0_0, TASK_1_1, TASK_1_3)),
+            mkEntry(CONSUMER_2, asList(TASK_0_3, TASK_1_0)),
+            mkEntry(CONSUMER_3, asList(TASK_0_1, TASK_0_2, TASK_1_2))
         );
 
-        for (final Map.Entry<String, List<TaskId>> entry : previousAssignment.entrySet()) {
-            for (final TaskId task : entry.getValue()) {
-                state.addOwnedPartitions(partitionsForTask.get(task), entry.getKey());
-            }
-        }
-
-        // Add the partitions of TASK_0_0 to allOwnedPartitions but not c1's ownedPartitions/previousAssignment
-        final Set<TopicPartition> allOwnedPartitions = new HashSet<>(partitionsForTask.get(TASK_0_0));
+        final ClientState state = new ClientState();
+        final SortedSet<String> consumers = mkSortedSet(CONSUMER_1, CONSUMER_2, CONSUMER_3);
+        state.addPreviousTasksAndOffsetSums(CONSUMER_1, getTaskOffsetSums(asList(TASK_0_0, TASK_1_1, TASK_1_3), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_2, getTaskOffsetSums(asList(TASK_0_3, TASK_1_0), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_3, getTaskOffsetSums(asList(TASK_0_1, TASK_0_2, TASK_1_2), EMPTY_TASKS));
 
-        final Set<String> consumers = mkSet(CONSUMER_1, CONSUMER_2, CONSUMER_3);
-        state.assignActiveTasks(allTasks);
+        // Consumer 4 joins the group
+        consumers.add(CONSUMER_4);
+        state.addPreviousTasksAndOffsetSums(CONSUMER_4, getTaskOffsetSums(EMPTY_TASKS, EMPTY_TASKS));
 
-        assertThat(partitionAssignor.tryStickyAndBalancedTaskAssignmentWithinClient(state, consumers, partitionsForTask, allOwnedPartitions),
-                   equalTo(emptyMap()));
-    }
+        state.initializePrevTasks(emptyMap());
+        state.computeTaskLags(UUID_1, getTaskEndOffsetSums(allTasks));
 
-    @Test
-    public void shouldInterleaveTasksByGroupId() {
-        final TaskId taskIdA0 = new TaskId(0, 0);
-        final TaskId taskIdA1 = new TaskId(0, 1);
-        final TaskId taskIdA2 = new TaskId(0, 2);
-        final TaskId taskIdA3 = new TaskId(0, 3);
+        final Map<String, List<TaskId>> assignment = assignTasksToThreads(
+            allTasks,
+            emptySet(),
+            consumers,
+            state
+        );
 
-        final TaskId taskIdB0 = new TaskId(1, 0);
-        final TaskId taskIdB1 = new TaskId(1, 1);
-        final TaskId taskIdB2 = new TaskId(1, 2);
+        // We should move one task each from consumer 1 and consumer 3 to the new member, and none from consumer 2
+        assertTrue(previousAssignment.get(CONSUMER_1).containsAll(assignment.get(CONSUMER_1)));
+        assertTrue(previousAssignment.get(CONSUMER_3).containsAll(assignment.get(CONSUMER_3)));
 
-        final TaskId taskIdC0 = new TaskId(2, 0);
-        final TaskId taskIdC1 = new TaskId(2, 1);
+        assertTrue(assignment.get(CONSUMER_2).containsAll(previousAssignment.get(CONSUMER_2)));
 
-        final String c1 = "c1";
-        final String c2 = "c2";
-        final String c3 = "c3";
 
-        final Set<String> consumers = mkSet(c1, c2, c3);
+        assertThat(assignment.get(CONSUMER_1).size(), equalTo(2));
+        assertThat(assignment.get(CONSUMER_2).size(), equalTo(2));
+        assertThat(assignment.get(CONSUMER_3).size(), equalTo(2));
+        assertThat(assignment.get(CONSUMER_4).size(), equalTo(2));
+    }
 
-        final List<TaskId> expectedSubList1 = asList(taskIdA0, taskIdA3, taskIdB2);
-        final List<TaskId> expectedSubList2 = asList(taskIdA1, taskIdB0, taskIdC0);
-        final List<TaskId> expectedSubList3 = asList(taskIdA2, taskIdB1, taskIdC1);
+    @Test
+    public void shouldInterleaveTasksByGroupIdDuringNewAssignment() {
+        final List<TaskId> allTasks =
+            asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1);
+
+        final Map<String, List<TaskId>> assignment = mkMap(
+            mkEntry(CONSUMER_1, new ArrayList<>(asList(TASK_0_0, TASK_0_3, TASK_1_2))),
+            mkEntry(CONSUMER_2, new ArrayList<>(asList(TASK_0_1, TASK_1_0, TASK_2_0))),
+            mkEntry(CONSUMER_3, new ArrayList<>(asList(TASK_0_2, TASK_1_1, TASK_2_1)))
+        );
 
-        final Map<String, List<TaskId>> assignment = new HashMap<>();
-        assignment.put(c1, expectedSubList1);
-        assignment.put(c2, expectedSubList2);
-        assignment.put(c3, expectedSubList3);
+        final ClientState state = new ClientState();
+        final SortedSet<String> consumers = mkSortedSet(CONSUMER_1, CONSUMER_2, CONSUMER_3);
+        state.addPreviousTasksAndOffsetSums(CONSUMER_1, emptyMap());
+        state.addPreviousTasksAndOffsetSums(CONSUMER_2, emptyMap());
+        state.addPreviousTasksAndOffsetSums(CONSUMER_3, emptyMap());
 
-        final List<TaskId> tasks = asList(taskIdC0, taskIdC1, taskIdB0, taskIdB1, taskIdB2, taskIdA0, taskIdA1, taskIdA2, taskIdA3);
-        Collections.shuffle(tasks);
+        Collections.shuffle(allTasks);
 
         final Map<String, List<TaskId>> interleavedTaskIds =
-            StreamsPartitionAssignor.interleaveConsumerTasksByGroupId(tasks, consumers);
+            assignTasksToThreads(
+                allTasks,
+                emptySet(),
+                consumers,
+                state
+            );
 
         assertThat(interleavedTaskIds, equalTo(assignment));
     }
@@ -637,16 +632,6 @@ public class StreamsPartitionAssignorTest {
 
         final List<String> topics = asList("topic1", "topic2");
 
-        final TaskId taskIdA0 = new TaskId(0, 0);
-        final TaskId taskIdA1 = new TaskId(0, 1);
-        final TaskId taskIdA2 = new TaskId(0, 2);
-        final TaskId taskIdA3 = new TaskId(0, 3);
-
-        final TaskId taskIdB0 = new TaskId(1, 0);
-        final TaskId taskIdB1 = new TaskId(1, 1);
-        final TaskId taskIdB2 = new TaskId(1, 2);
-        final TaskId taskIdB3 = new TaskId(1, 3);
-
         configureDefault();
 
         subscriptions.put("consumer10",
@@ -669,12 +654,12 @@ public class StreamsPartitionAssignorTest {
         // the first consumer
         final AssignmentInfo info10 = AssignmentInfo.decode(assignments.get("consumer10").userData());
 
-        final List<TaskId> expectedInfo10TaskIds = asList(taskIdA0, taskIdA2, taskIdB0, taskIdB2);
+        final List<TaskId> expectedInfo10TaskIds = asList(TASK_0_0, TASK_0_2, TASK_1_0, TASK_1_2);
         assertEquals(expectedInfo10TaskIds, info10.activeTasks());
 
         // the second consumer
         final AssignmentInfo info11 = AssignmentInfo.decode(assignments.get("consumer11").userData());
-        final List<TaskId> expectedInfo11TaskIds = asList(taskIdA1, taskIdA3, taskIdB1, taskIdB3);
+        final List<TaskId> expectedInfo11TaskIds = asList(TASK_0_1, TASK_0_3, TASK_1_1, TASK_1_3);
 
         assertEquals(expectedInfo11TaskIds, info11.activeTasks());
     }
@@ -2019,10 +2004,15 @@ public class StreamsPartitionAssignorTest {
     }
 
     // Stub offset sums for when we only care about the prev/standby task sets, not the actual offsets
-    private static Map<TaskId, Long> getTaskOffsetSums(final Set<TaskId> activeTasks, final Set<TaskId> standbyTasks) {
+    private static Map<TaskId, Long> getTaskOffsetSums(final Collection<TaskId> activeTasks, final Collection<TaskId> standbyTasks) {
         final Map<TaskId, Long> taskOffsetSums = activeTasks.stream().collect(Collectors.toMap(t -> t, t -> Task.LATEST_OFFSET));
         taskOffsetSums.putAll(standbyTasks.stream().collect(Collectors.toMap(t -> t, t -> 0L)));
         return taskOffsetSums;
     }
 
+    // Stub end offset sums for situations where we don't really care about computing exact lags
+    private static Map<TaskId, Long> getTaskEndOffsetSums(final Collection<TaskId> allStatefulTasks) {
+        return allStatefulTasks.stream().collect(Collectors.toMap(t -> t, t -> Long.MAX_VALUE));
+    }
+
 }
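
For readers tracing the new StreamsPartitionAssignorTest cases above, here is a minimal, hypothetical sketch of the "keep up to the per-thread cap of your previous tasks, then deal the rest round-robin in TaskId order" pattern that these test expectations pin down. It is not the actual assignTasksToThreads implementation (which also weighs task lag and handles uneven remainders); the class and method names are invented for illustration, and only TaskId is a real Kafka class.

    import java.util.ArrayList;
    import java.util.Collection;
    import java.util.Collections;
    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    import java.util.SortedSet;
    import java.util.TreeSet;

    import org.apache.kafka.streams.processor.TaskId;

    public final class StickyThreadAssignmentSketch {

        // Assign tasks to the client's consumers (threads): each consumer first keeps up to
        // `cap` of its own previous tasks, then the leftover tasks are dealt round-robin, in
        // TaskId order (group id, then partition), to whichever consumers still have room.
        public static Map<String, List<TaskId>> assign(final Collection<TaskId> allTasks,
                                                       final SortedSet<String> consumers,
                                                       final Map<String, Set<TaskId>> prevTasksPerConsumer) {
            // Simplified cap; in the tests above the task count divides evenly across consumers.
            final int cap = (int) Math.ceil((double) allTasks.size() / consumers.size());

            final Map<String, List<TaskId>> assignment = new LinkedHashMap<>();
            final SortedSet<TaskId> unassigned = new TreeSet<>(allTasks);

            // Sticky step: keep previously owned tasks, up to the cap.
            for (final String consumer : consumers) {
                final List<TaskId> threadTasks = new ArrayList<>();
                final Set<TaskId> prev = prevTasksPerConsumer.getOrDefault(consumer, Collections.emptySet());
                for (final TaskId task : new TreeSet<>(prev)) {
                    if (threadTasks.size() < cap && unassigned.remove(task)) {
                        threadTasks.add(task);
                    }
                }
                assignment.put(consumer, threadTasks);
            }

            // Balance step: deal the remaining tasks round-robin to consumers below the cap.
            final List<String> order = new ArrayList<>(consumers);
            int i = 0;
            for (final TaskId task : unassigned) {
                while (assignment.get(order.get(i % order.size())).size() >= cap) {
                    i++;
                }
                assignment.get(order.get(i % order.size())).add(task);
                i++;
            }
            return assignment;
        }
    }

Under these assumptions the sketch reproduces the expectations above: with no previous tasks it yields the interleaved assignment of shouldInterleaveTasksByGroupIdDuringNewAssignment, and when consumer 4 joins it leaves consumer 2 untouched while taking one task each from consumers 1 and 3. Which specific previous tasks a consumer keeps when it is over the cap may differ from the real code.
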
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
index 8e5fa36..d50c00e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
@@ -32,6 +32,7 @@ import static org.apache.kafka.streams.processor.internals.assignment.Assignment
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_3;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasActiveTasks;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasStandbyTasks;
 import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
@@ -44,7 +45,6 @@ import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 public class ClientStateTest {
-
     private final ClientState client = new ClientState(1);
     private final ClientState zeroCapacityClient = new ClientState(0);
 
@@ -300,7 +300,7 @@ public class ClientStateTest {
     @Test
     public void shouldAddTasksWithLatestOffsetToPrevActiveTasks() {
         final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(TASK_0_1, Task.LATEST_OFFSET);
-        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums);
         client.initializePrevTasks(Collections.emptyMap());
         assertThat(client.prevActiveTasks(), equalTo(Collections.singleton(TASK_0_1)));
         assertThat(client.previousAssignedTasks(), equalTo(Collections.singleton(TASK_0_1)));
@@ -308,12 +308,45 @@ public class ClientStateTest {
     }
 
     @Test
+    public void shouldReturnPreviousStatefulTasksForConsumer() {
+        client.addPreviousTasksAndOffsetSums("c1", Collections.singletonMap(TASK_0_1, Task.LATEST_OFFSET));
+        client.addPreviousTasksAndOffsetSums("c2", Collections.singletonMap(TASK_0_2, 0L));
+        client.addPreviousTasksAndOffsetSums("c3", Collections.emptyMap());
+
+        client.initializePrevTasks(Collections.emptyMap());
+        client.computeTaskLags(
+            UUID_1,
+            mkMap(
+                mkEntry(TASK_0_1, 1_000L),
+                mkEntry(TASK_0_2, 1_000L)
+            )
+        );
+
+        assertThat(client.previousTasksForConsumer("c1"), equalTo(mkSet(TASK_0_1)));
+        assertThat(client.previousTasksForConsumer("c2"), equalTo(mkSet(TASK_0_2)));
+        assertTrue(client.previousTasksForConsumer("c3").isEmpty());
+    }
+
+    @Test
+    public void shouldReturnPreviousTasksForConsumer() {
+        client.addPreviousTasksAndOffsetSums("c1", mkMap(
+            mkEntry(TASK_0_1, 100L),
+            mkEntry(TASK_0_2, 0L),
+            mkEntry(TASK_0_3, Task.LATEST_OFFSET)
+        ));
+
+        client.initializePrevTasks(Collections.emptyMap());
+
+        assertThat(client.previousTasksForConsumer("c1"), equalTo(mkSet(TASK_0_3, TASK_0_2, TASK_0_1)));
+    }
+
+    @Test
     public void shouldAddTasksInOffsetSumsMapToPrevStandbyTasks() {
         final Map<TaskId, Long> taskOffsetSums = mkMap(
             mkEntry(TASK_0_1, 0L),
             mkEntry(TASK_0_2, 100L)
         );
-        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums);
         client.initializePrevTasks(Collections.emptyMap());
         assertThat(client.prevStandbyTasks(), equalTo(mkSet(TASK_0_1, TASK_0_2)));
         assertThat(client.previousAssignedTasks(), equalTo(mkSet(TASK_0_1, TASK_0_2)));
@@ -330,7 +363,7 @@ public class ClientStateTest {
             mkEntry(TASK_0_1, 500L),
             mkEntry(TASK_0_2, 100L)
         );
-        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums);
         client.computeTaskLags(null, allTaskEndOffsetSums);
 
         assertThat(client.lagFor(TASK_0_1), equalTo(500L));
@@ -341,7 +374,7 @@ public class ClientStateTest {
     public void shouldReturnEndOffsetSumForLagOfTaskWeDidNotPreviouslyOwn() {
         final Map<TaskId, Long> taskOffsetSums = Collections.emptyMap();
         final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(TASK_0_1, 500L);
-        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums);
         client.computeTaskLags(null, allTaskEndOffsetSums);
         assertThat(client.lagFor(TASK_0_1), equalTo(500L));
     }
@@ -350,7 +383,7 @@ public class ClientStateTest {
     public void shouldReturnLatestOffsetForLagOfPreviousActiveRunningTask() {
         final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(TASK_0_1, Task.LATEST_OFFSET);
         final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(TASK_0_1, 500L);
-        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums);
         client.computeTaskLags(null, allTaskEndOffsetSums);
         assertThat(client.lagFor(TASK_0_1), equalTo(Task.LATEST_OFFSET));
     }
@@ -359,7 +392,7 @@ public class ClientStateTest {
     public void shouldReturnUnknownOffsetSumForLagOfTaskWithUnknownOffset() {
         final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(TASK_0_1, UNKNOWN_OFFSET_SUM);
         final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(TASK_0_1, 500L);
-        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums);
         client.computeTaskLags(null, allTaskEndOffsetSums);
         assertThat(client.lagFor(TASK_0_1), equalTo(UNKNOWN_OFFSET_SUM));
     }
@@ -368,7 +401,7 @@ public class ClientStateTest {
     public void shouldReturnEndOffsetSumIfOffsetSumIsGreaterThanEndOffsetSum() {
         final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(TASK_0_1, 5L);
         final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(TASK_0_1, 1L);
-        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums);
         client.computeTaskLags(null, allTaskEndOffsetSums);
         assertThat(client.lagFor(TASK_0_1), equalTo(1L));
     }
@@ -385,7 +418,7 @@ public class ClientStateTest {
     public void shouldThrowIllegalStateExceptionOnLagForUnknownTask() {
         final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(TASK_0_1, 0L);
         final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(TASK_0_1, 500L);
-        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums);
         client.computeTaskLags(null, allTaskEndOffsetSums);
         assertThrows(IllegalStateException.class, () -> client.lagFor(TASK_0_2));
     }
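
The ClientStateTest cases above pin down the per-task lag semantics case by case. A rough sketch of those rules as a standalone function follows; it is a hypothetical helper, not the real ClientState.computeTaskLags/lagFor code, and the exception message is invented.

    import org.apache.kafka.streams.processor.internals.Task;

    import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;

    public final class TaskLagRulesSketch {

        // offsetSum: the client's reported offset sum for the task, or null if it never owned it.
        // endOffsetSum: the summed end offsets of the task's changelogs, or null if the task is unknown.
        public static long lagFor(final Long offsetSum, final Long endOffsetSum) {
            if (endOffsetSum == null) {
                throw new IllegalStateException("Tried to look up lag for unknown task");
            }
            if (offsetSum == null) {
                return endOffsetSum;           // never owned: assume it is fully behind
            }
            if (offsetSum == Task.LATEST_OFFSET) {
                return Task.LATEST_OFFSET;     // previously active and still running
            }
            if (offsetSum == UNKNOWN_OFFSET_SUM) {
                return UNKNOWN_OFFSET_SUM;     // offsets could not be summed on the client
            }
            if (offsetSum > endOffsetSum) {
                return endOffsetSum;           // reported sum overshoots: cap at the end offsets
            }
            return endOffsetSum - offsetSum;
        }
    }
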
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
index 8253bcd..68c9dfe 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
@@ -185,7 +185,7 @@ public class TaskAssignorConvergenceTest {
                 }
                 newClientState.addPreviousActiveTasks(clientState.activeTasks());
                 newClientState.addPreviousStandbyTasks(clientState.standbyTasks());
-                newClientState.addPreviousTasksAndOffsetSums(taskOffsetSums);
+                newClientState.addPreviousTasksAndOffsetSums("consumer", taskOffsetSums);
                 newClientState.computeTaskLags(uuid, statefulTaskEndOffsetSums);
                 newClientStates.put(uuid, newClientState);
             }
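
The TaskAssignorConvergenceTest change above only needs the client-level view, which is why a single dummy "consumer" name is enough with the new per-consumer signature. A hypothetical illustration of that bookkeeping (invented class and field names, not the real ClientState internals): offset sums are recorded per consumer, while the client-level previous-task set is simply the union over all consumers.

    import java.util.Collections;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Map;
    import java.util.Set;

    import org.apache.kafka.streams.processor.TaskId;

    final class PerConsumerOffsetSumsSketch {

        // Offset sums are now keyed by consumer (thread) rather than kept only per client.
        private final Map<String, Map<TaskId, Long>> offsetSumsPerConsumer = new HashMap<>();

        void addPreviousTasksAndOffsetSums(final String consumer, final Map<TaskId, Long> offsetSums) {
            offsetSumsPerConsumer.put(consumer, offsetSums);
        }

        // Per-thread view, used for thread-level stickiness within the client.
        Set<TaskId> previousTasksForConsumer(final String consumer) {
            return offsetSumsPerConsumer.getOrDefault(consumer, Collections.emptyMap()).keySet();
        }

        // Client-level view: the union over all consumers, which is all the convergence test needs.
        Set<TaskId> allPreviousTasks() {
            final Set<TaskId> union = new HashSet<>();
            for (final Map<TaskId, Long> sums : offsetSumsPerConsumer.values()) {
                union.addAll(sums.keySet());
            }
            return union;
        }
    }
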