You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this rendering.
Posted to commits@kafka.apache.org by vv...@apache.org on 2020/04/21 22:10:39 UTC

[kafka] branch trunk updated: KAFKA-6145: KIP-441: Build state constrained assignment from balanced one (#8497)

This is an automated email from the ASF dual-hosted git repository.

vvcephei pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 5c548e5  KAFKA-6145: KIP-441: Build state constrained assignment from balanced one (#8497)
5c548e5 is described below

commit 5c548e5dfc223371f3109de14eddf0918b8dcad2
Author: A. Sophie Blee-Goldman <so...@confluent.io>
AuthorDate: Tue Apr 21 15:09:59 2020 -0700

    KAFKA-6145: KIP-441: Build state constrained assignment from balanced one (#8497)
    
    Implements: KIP-441
    Reviewers: Bruno Cadonna <br...@confluent.io>, John Roesler <vv...@apache.org>
---
 .../internals/StreamsPartitionAssignor.java        |   2 +-
 ...dBalancedAssignor.java => AssignmentUtils.java} |  26 +-
 .../internals/assignment/ClientState.java          |   8 +
 .../DefaultStateConstrainedBalancedAssignor.java   | 304 -------
 .../assignment/HighAvailabilityTaskAssignor.java   | 261 ++----
 .../internals/assignment/TaskMovement.java         | 169 ++--
 .../assignment/ValidClientsByTaskLoadQueue.java    | 112 +++
 .../internals/StreamsPartitionAssignorTest.java    |  77 +-
 .../internals/assignment/AssignmentTestUtils.java  |  20 +-
 .../internals/assignment/AssignmentUtilsTest.java  |  57 ++
 ...efaultStateConstrainedBalancedAssignorTest.java | 978 ---------------------
 .../HighAvailabilityTaskAssignorTest.java          |  94 +-
 .../assignment/TaskAssignorConvergenceTest.java    |   3 -
 .../internals/assignment/TaskMovementTest.java     | 382 ++++----
 .../ValidClientsByTaskLoadQueueTest.java           | 126 +++
 15 files changed, 649 insertions(+), 1970 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
index 9ccd68a..d285e31 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
@@ -861,7 +861,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                                                          final int minSupportedMetadataVersion,
                                                          final boolean shouldTriggerProbingRebalance) {
         // keep track of whether a 2nd rebalance is unavoidable so we can skip trying to get a completely sticky assignment
-        boolean rebalanceRequired = false;
+        boolean rebalanceRequired = shouldTriggerProbingRebalance;
         final Map<String, Assignment> assignment = new HashMap<>();
 
         // within the client, distribute tasks to its owned consumers
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StateConstrainedBalancedAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentUtils.java
similarity index 61%
rename from streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StateConstrainedBalancedAssignor.java
rename to streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentUtils.java
index a52ea2a..b88c24b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StateConstrainedBalancedAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentUtils.java
@@ -16,20 +16,26 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import java.util.UUID;
 import org.apache.kafka.streams.processor.TaskId;
 
-import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.SortedMap;
 import java.util.SortedSet;
+import java.util.UUID;
+
+final class AssignmentUtils {
+
+    private AssignmentUtils() {}
+
+    /**
+     * @return true if this client is caught-up for this task, or the task has no caught-up clients
+     */
+    static boolean taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(final TaskId task,
+                                                                  final UUID client,
+                                                                  final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients) {
+        final Set<UUID> caughtUpClients = tasksToCaughtUpClients.get(task);
+        return caughtUpClients == null || caughtUpClients.contains(client);
+    }
 
-public interface StateConstrainedBalancedAssignor {
 
-    Map<UUID, List<TaskId>> assign(final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedClients,
-                                   final int balanceFactor,
-                                   final Set<UUID> clients,
-                                   final Map<UUID, Integer> clientsToNumberOfStreamThread,
-                                   final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients);
-}
+}
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
index b2b284b..5b8857c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
@@ -245,6 +245,14 @@ public class ClientState {
         return capacity;
     }
 
+    double activeTaskLoad() {
+        return ((double) activeTaskCount()) / capacity;
+    }
+
+    double taskLoad() {
+        return ((double) assignedTaskCount()) / capacity;
+    }
+
     boolean hasUnfulfilledQuota(final int tasksPerThread) {
         return activeTasks.size() < capacity * tasksPerThread;
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultStateConstrainedBalancedAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultStateConstrainedBalancedAssignor.java
deleted file mode 100644
index 57e744e..0000000
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultStateConstrainedBalancedAssignor.java
+++ /dev/null
@@ -1,304 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.streams.processor.internals.assignment;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.SortedMap;
-import java.util.SortedSet;
-import java.util.TreeSet;
-import java.util.UUID;
-import org.apache.kafka.streams.processor.TaskId;
-import org.apache.kafka.streams.processor.internals.Task;
-
-public class DefaultStateConstrainedBalancedAssignor implements StateConstrainedBalancedAssignor {
-
-    /**
-     * This assignment algorithm guarantees that all task for which caught-up clients exist are assigned to one of the
-     * caught-up clients. Tasks for which no caught-up client exist are assigned best-effort to satisfy the balance
-     * factor. There is no guarantee that the balance factor is satisfied.
-     *
-     * @param statefulTasksToRankedClients ranked clients map
-     * @param balanceFactor balance factor (at least 1)
-     * @param clients set of clients to assign tasks to
-     * @param clientsToNumberOfStreamThreads map of clients to their number of stream threads
-     * @return assignment
-     */
-    @Override
-    public Map<UUID, List<TaskId>> assign(final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedClients,
-                                          final int balanceFactor,
-                                          final Set<UUID> clients,
-                                          final Map<UUID, Integer> clientsToNumberOfStreamThreads,
-                                          final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients) {
-        checkClientsAndNumberOfStreamThreads(clientsToNumberOfStreamThreads, clients);
-        final Map<UUID, List<TaskId>> assignment = initAssignment(clients);
-        assignTasksWithCaughtUpClients(
-            assignment,
-            tasksToCaughtUpClients,
-            statefulTasksToRankedClients
-        );
-        assignTasksWithoutCaughtUpClients(
-            assignment,
-            tasksToCaughtUpClients,
-            statefulTasksToRankedClients
-        );
-        balance(
-            assignment,
-            balanceFactor,
-            statefulTasksToRankedClients,
-            tasksToCaughtUpClients,
-            clientsToNumberOfStreamThreads
-        );
-        return assignment;
-    }
-
-    private void checkClientsAndNumberOfStreamThreads(final Map<UUID, Integer> clientsToNumberOfStreamThreads,
-                                                      final Set<UUID> clients) {
-        if (clients.isEmpty()) {
-            throw new IllegalStateException("Set of clients must not be empty");
-        }
-        if (clientsToNumberOfStreamThreads.isEmpty()) {
-            throw new IllegalStateException("Map from clients to their number of stream threads must not be empty");
-        }
-        final Set<UUID> copyOfClients = new HashSet<>(clients);
-        copyOfClients.removeAll(clientsToNumberOfStreamThreads.keySet());
-        if (!copyOfClients.isEmpty()) {
-            throw new IllegalStateException(
-                "Map from clients to their number of stream threads must contain an entry for each client involved in "
-                    + "the assignment."
-            );
-        }
-    }
-
-    /**
-     * Initialises the assignment with an empty list for each client.
-     *
-     * @param clients list of clients
-     * @return initialised assignment with empty lists
-     */
-    private Map<UUID, List<TaskId>> initAssignment(final Set<UUID> clients) {
-        final Map<UUID, List<TaskId>> assignment = new HashMap<>();
-        clients.forEach(client -> assignment.put(client, new ArrayList<>()));
-        return assignment;
-    }
-
-    /**
-     * Maps a task to the client that host the task according to the previous assignment.
-     *
-     * @return map from task UUIDs to clients hosting the corresponding task
-     */
-    private Map<TaskId, UUID> previouslyRunningTasksToPreviousClients(final Map<TaskId, SortedSet<RankedClient>> statefulTasksToRankedClients) {
-        final Map<TaskId, UUID> tasksToPreviousClients = new HashMap<>();
-        for (final Map.Entry<TaskId, SortedSet<RankedClient>> taskToRankedClients : statefulTasksToRankedClients.entrySet()) {
-            final RankedClient topRankedClient = taskToRankedClients.getValue().first();
-            if (topRankedClient.rank() == Task.LATEST_OFFSET) {
-                tasksToPreviousClients.put(taskToRankedClients.getKey(), topRankedClient.clientId());
-            }
-        }
-        return tasksToPreviousClients;
-    }
-
-    /**
-     * Assigns tasks for which one or more caught-up clients exist to one of the caught-up clients.
-     * @param assignment assignment
-     * @param tasksToCaughtUpClients map from task UUIDs to lists of caught-up clients
-     */
-    private void assignTasksWithCaughtUpClients(final Map<UUID, List<TaskId>> assignment,
-                                                final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
-                                                final Map<TaskId, SortedSet<RankedClient>> statefulTasksToRankedClients) {
-        // If a task was previously assigned to a client that is caught-up and still exists, give it back to the client
-        final Map<TaskId, UUID> previouslyRunningTasksToPreviousClients =
-            previouslyRunningTasksToPreviousClients(statefulTasksToRankedClients);
-        previouslyRunningTasksToPreviousClients.forEach((task, client) -> assignment.get(client).add(task));
-        final List<TaskId> unassignedTasksWithCaughtUpClients = new ArrayList<>(tasksToCaughtUpClients.keySet());
-        unassignedTasksWithCaughtUpClients.removeAll(previouslyRunningTasksToPreviousClients.keySet());
-
-        // If a task's previous host client was not caught-up or no longer exists, assign it to the caught-up client
-        // with the least tasks
-        for (final TaskId taskId : unassignedTasksWithCaughtUpClients) {
-            final SortedSet<UUID> caughtUpClients = tasksToCaughtUpClients.get(taskId);
-            UUID clientWithLeastTasks = null;
-            int minTaskPerStreamThread = Integer.MAX_VALUE;
-            for (final UUID client : caughtUpClients) {
-                final int assignedTasks = assignment.get(client).size();
-                if (minTaskPerStreamThread > assignedTasks) {
-                    clientWithLeastTasks = client;
-                    minTaskPerStreamThread = assignedTasks;
-                }
-            }
-            assignment.get(clientWithLeastTasks).add(taskId);
-        }
-    }
-
-    /**
-     * Assigns tasks for which no caught-up clients exist.
-     * A task is assigned to one of the clients with the highest rank and the least tasks assigned.
-     * @param assignment assignment
-     * @param tasksToCaughtUpClients map from task UUIDs to lists of caught-up clients
-     * @param statefulTasksToRankedClients ranked clients map
-     */
-    private void assignTasksWithoutCaughtUpClients(final Map<UUID, List<TaskId>> assignment,
-                                                   final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
-                                                   final Map<TaskId, SortedSet<RankedClient>> statefulTasksToRankedClients) {
-        final SortedSet<TaskId> unassignedTasksWithoutCaughtUpClients = new TreeSet<>(statefulTasksToRankedClients.keySet());
-        unassignedTasksWithoutCaughtUpClients.removeAll(tasksToCaughtUpClients.keySet());
-        for (final TaskId taskId : unassignedTasksWithoutCaughtUpClients) {
-            final SortedSet<RankedClient> rankedClients = statefulTasksToRankedClients.get(taskId);
-            final long topRank = rankedClients.first().rank();
-            int minTasksPerStreamThread = Integer.MAX_VALUE;
-            UUID clientWithLeastTasks = rankedClients.first().clientId();
-            for (final RankedClient rankedClient : rankedClients) {
-                if (rankedClient.rank() == topRank) {
-                    final UUID clientId = rankedClient.clientId();
-                    final int assignedTasks = assignment.get(clientId).size();
-                    if (minTasksPerStreamThread > assignedTasks) {
-                        clientWithLeastTasks = clientId;
-                        minTasksPerStreamThread = assignedTasks;
-                    }
-                } else {
-                    break;
-                }
-            }
-            assignment.get(clientWithLeastTasks).add(taskId);
-        }
-    }
-
-    /**
-     * Balance the assignment.
-     * @param assignment assignment
-     * @param balanceFactor balance factor
-     * @param statefulTasksToRankedClients ranked clients map
-     * @param tasksToCaughtUpClients map from task UUIDs to lists of caught-up clients
-     * @param clientsToNumberOfStreamThreads map from clients to their number of stream threads
-     */
-    private void balance(final Map<UUID, List<TaskId>> assignment,
-                         final int balanceFactor,
-                         final Map<TaskId, SortedSet<RankedClient>> statefulTasksToRankedClients,
-                         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
-                         final Map<UUID, Integer> clientsToNumberOfStreamThreads) {
-        final List<UUID> clients = new ArrayList<>(assignment.keySet());
-        Collections.sort(clients);
-        for (final UUID sourceClientId : clients) {
-            final List<TaskId> sourceTasks = assignment.get(sourceClientId);
-            maybeMoveSourceTasksWithoutCaughtUpClients(
-                assignment,
-                balanceFactor,
-                statefulTasksToRankedClients,
-                tasksToCaughtUpClients,
-                clientsToNumberOfStreamThreads,
-                sourceClientId,
-                sourceTasks
-            );
-            maybeMoveSourceTasksWithCaughtUpClients(
-                assignment,
-                balanceFactor,
-                tasksToCaughtUpClients,
-                clientsToNumberOfStreamThreads,
-                sourceClientId,
-                sourceTasks
-            );
-        }
-    }
-
-    private void maybeMoveSourceTasksWithoutCaughtUpClients(final Map<UUID, List<TaskId>> assignment,
-                                                            final int balanceFactor,
-                                                            final Map<TaskId, SortedSet<RankedClient>> statefulTasksToRankedClients,
-                                                            final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
-                                                            final Map<UUID, Integer> clientsToNumberOfStreamThreads,
-                                                            final UUID sourceClientId,
-                                                            final List<TaskId> sourceTasks) {
-        for (final TaskId task : assignedTasksWithoutCaughtUpClientsThatMightBeMoved(sourceTasks, tasksToCaughtUpClients)) {
-            final int assignedTasksPerStreamThreadAtSource =
-                sourceTasks.size() / clientsToNumberOfStreamThreads.get(sourceClientId);
-            for (final RankedClient clientAndRank : statefulTasksToRankedClients.get(task)) {
-                final UUID destinationClientId = clientAndRank.clientId();
-                final List<TaskId> destination = assignment.get(destinationClientId);
-                final int assignedTasksPerStreamThreadAtDestination =
-                    destination.size() / clientsToNumberOfStreamThreads.get(destinationClientId);
-                if (assignedTasksPerStreamThreadAtSource - assignedTasksPerStreamThreadAtDestination > balanceFactor) {
-                    sourceTasks.remove(task);
-                    destination.add(task);
-                    break;
-                }
-            }
-        }
-    }
-
-    private void maybeMoveSourceTasksWithCaughtUpClients(final Map<UUID, List<TaskId>> assignment,
-                                                         final int balanceFactor,
-                                                         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
-                                                         final Map<UUID, Integer> clientsToNumberOfStreamThreads,
-                                                         final UUID sourceClientId,
-                                                         final List<TaskId> sourceTasks) {
-        for (final TaskId task : assignedTasksWithCaughtUpClientsThatMightBeMoved(sourceTasks, tasksToCaughtUpClients)) {
-            final int assignedTasksPerStreamThreadAtSource =
-                sourceTasks.size() / clientsToNumberOfStreamThreads.get(sourceClientId);
-            for (final UUID destinationClientId : tasksToCaughtUpClients.get(task)) {
-                final List<TaskId> destination = assignment.get(destinationClientId);
-                final int assignedTasksPerStreamThreadAtDestination =
-                    destination.size() / clientsToNumberOfStreamThreads.get(destinationClientId);
-                if (assignedTasksPerStreamThreadAtSource - assignedTasksPerStreamThreadAtDestination > balanceFactor) {
-                    sourceTasks.remove(task);
-                    destination.add(task);
-                    break;
-                }
-            }
-        }
-    }
-
-    /**
-     * Returns a sublist of tasks in the given list that does not have a caught-up client.
-     *
-     * @param tasks list of task UUIDs
-     * @param tasksToCaughtUpClients map from task UUIDs to lists of caught-up clients
-     * @return a list of task UUIDs that does not have a caught-up client
-     */
-    private List<TaskId> assignedTasksWithoutCaughtUpClientsThatMightBeMoved(final List<TaskId> tasks,
-                                                                             final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients) {
-        return assignedTasksThatMightBeMoved(tasks, tasksToCaughtUpClients, false);
-    }
-
-    /**
-     * Returns a sublist of tasks in the given list that have a caught-up client.
-     *
-     * @param tasks list of task UUIDs
-     * @param tasksToCaughtUpClients map from task UUIDs to lists of caught-up clients
-     * @return a list of task UUIDs that have a caught-up client
-     */
-    private List<TaskId> assignedTasksWithCaughtUpClientsThatMightBeMoved(final List<TaskId> tasks,
-                                                                          final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients) {
-        return assignedTasksThatMightBeMoved(tasks, tasksToCaughtUpClients, true);
-    }
-
-    private List<TaskId> assignedTasksThatMightBeMoved(final List<TaskId> tasks,
-                                                       final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
-                                                       final boolean isCaughtUp) {
-        final List<TaskId> tasksWithCaughtUpClients = new ArrayList<>();
-        for (int i = tasks.size() - 1; i >= 0; --i) {
-            final TaskId task = tasks.get(i);
-            if (isCaughtUp == tasksToCaughtUpClients.containsKey(task)) {
-                tasksWithCaughtUpClients.add(task);
-            }
-        }
-        return Collections.unmodifiableList(tasksWithCaughtUpClients);
-    }
-}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
index 4cc7df1..b1570fb 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
@@ -16,19 +16,16 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import static java.util.Arrays.asList;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentUtils.taskIsCaughtUpOnClientOrNoCaughtUpClientsExist;
 import static org.apache.kafka.streams.processor.internals.assignment.RankedClient.buildClientRankingsByTask;
 import static org.apache.kafka.streams.processor.internals.assignment.RankedClient.tasksToCaughtUpClients;
-import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.getMovements;
+import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.assignTaskMovements;
 
-import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.LinkedList;
 import java.util.List;
-import java.util.PriorityQueue;
 import java.util.SortedMap;
 import java.util.SortedSet;
 import java.util.TreeSet;
@@ -89,95 +86,74 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
             return false;
         }
 
-        final Map<UUID, List<TaskId>> warmupTaskAssignment = initializeEmptyTaskAssignmentMap(sortedClients);
-        final Map<UUID, List<TaskId>> standbyTaskAssignment = initializeEmptyTaskAssignmentMap(sortedClients);
-        final Map<UUID, List<TaskId>> statelessActiveTaskAssignment = initializeEmptyTaskAssignmentMap(sortedClients);
+        final Map<TaskId, Integer> tasksToRemainingStandbys =
+            statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> configs.numStandbyReplicas));
 
-        // ---------------- Stateful Active Tasks ---------------- //
+        final boolean followupRebalanceNeeded = assignStatefulActiveTasks(tasksToRemainingStandbys);
 
-        final Map<UUID, List<TaskId>> statefulActiveTaskAssignment =
-            new DefaultStateConstrainedBalancedAssignor().assign(
-                statefulTasksToRankedCandidates,
-                configs.balanceFactor,
-                sortedClients,
-                clientsToNumberOfThreads,
-                tasksToCaughtUpClients
-            );
+        assignStandbyReplicaTasks(tasksToRemainingStandbys);
 
-        // ---------------- Warmup Replica Tasks ---------------- //
+        assignStatelessActiveTasks();
 
-        final Map<UUID, List<TaskId>> balancedStatefulActiveTaskAssignment =
-            new DefaultBalancedAssignor().assign(
-                sortedClients,
-                statefulTasks,
-                clientsToNumberOfThreads,
-                configs.balanceFactor);
+        return followupRebalanceNeeded;
+    }
 
-        final Map<TaskId, Integer> tasksToRemainingStandbys =
-            statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> configs.numStandbyReplicas));
+    private boolean assignStatefulActiveTasks(final Map<TaskId, Integer> tasksToRemainingStandbys) {
+        final Map<UUID, List<TaskId>> statefulActiveTaskAssignment = new DefaultBalancedAssignor().assign(
+            sortedClients,
+            statefulTasks,
+            clientsToNumberOfThreads,
+            configs.balanceFactor
+        );
 
-        final List<TaskMovement> movements = getMovements(
+        return assignTaskMovements(
             statefulActiveTaskAssignment,
-            balancedStatefulActiveTaskAssignment,
             tasksToCaughtUpClients,
             clientStates,
             tasksToRemainingStandbys,
-            configs.maxWarmupReplicas);
-
-        for (final TaskMovement movement : movements) {
-            warmupTaskAssignment.get(movement.destination).add(movement.task);
-        }
-
-        // ---------------- Standby Replica Tasks ---------------- //
-
-        final List<Map<UUID, List<TaskId>>> allTaskAssignmentMaps = asList(
-            statefulActiveTaskAssignment,
-            warmupTaskAssignment,
-            standbyTaskAssignment,
-            statelessActiveTaskAssignment
+            configs.maxWarmupReplicas
         );
+    }
 
-        final ValidClientsByTaskLoadQueue<UUID> clientsByStandbyTaskLoad =
-            new ValidClientsByTaskLoadQueue<>(
-                getClientPriorityQueueByTaskLoad(allTaskAssignmentMaps),
-                allTaskAssignmentMaps
-            );
+    private void assignStandbyReplicaTasks(final Map<TaskId, Integer> tasksToRemainingStandbys) {
+        final ValidClientsByTaskLoadQueue standbyTaskClientsByTaskLoad = new ValidClientsByTaskLoadQueue(
+            clientStates,
+            (client, task) -> !clientStates.get(client).assignedTasks().contains(task)
+        );
+        standbyTaskClientsByTaskLoad.offerAll(clientStates.keySet());
 
         for (final TaskId task : statefulTasksToRankedCandidates.keySet()) {
             final int numRemainingStandbys = tasksToRemainingStandbys.get(task);
-            final List<UUID> clients = clientsByStandbyTaskLoad.poll(task, numRemainingStandbys);
+            final List<UUID> clients = standbyTaskClientsByTaskLoad.poll(task, numRemainingStandbys);
             for (final UUID client : clients) {
-                standbyTaskAssignment.get(client).add(task);
+                clientStates.get(client).assignStandby(task);
             }
-            clientsByStandbyTaskLoad.offer(clients);
+            standbyTaskClientsByTaskLoad.offerAll(clients);
+
             final int numStandbysAssigned = clients.size();
-            if (numStandbysAssigned < configs.numStandbyReplicas) {
+            if (numStandbysAssigned < numRemainingStandbys) {
                 log.warn("Unable to assign {} of {} standby tasks for task [{}]. " +
                              "There is not enough available capacity. You should " +
                              "increase the number of threads and/or application instances " +
                              "to maintain the requested number of standby replicas.",
-                    configs.numStandbyReplicas - numStandbysAssigned, configs.numStandbyReplicas, task);
+                         numRemainingStandbys - numStandbysAssigned, configs.numStandbyReplicas, task);
             }
         }
+    }
 
-        // ---------------- Stateless Active Tasks ---------------- //
-
-        final PriorityQueue<UUID> statelessActiveTaskClientsQueue = getClientPriorityQueueByTaskLoad(allTaskAssignmentMaps);
+    private void assignStatelessActiveTasks() {
+        final ValidClientsByTaskLoadQueue statelessActiveTaskClientsByTaskLoad = new ValidClientsByTaskLoadQueue(
+            clientStates,
+            (client, task) -> true
+        );
+        statelessActiveTaskClientsByTaskLoad.offerAll(clientStates.keySet());
 
         for (final TaskId task : statelessTasks) {
-            final UUID client = statelessActiveTaskClientsQueue.poll();
-            statelessActiveTaskAssignment.get(client).add(task);
-            statelessActiveTaskClientsQueue.offer(client);
+            final UUID client = statelessActiveTaskClientsByTaskLoad.poll(task);
+            final ClientState state = clientStates.get(client);
+            state.assignActive(task);
+            statelessActiveTaskClientsByTaskLoad.offer(client);
         }
-
-        // ---------------- Assign Tasks To Clients ---------------- //
-
-        assignActiveTasksToClients(statefulActiveTaskAssignment);
-        assignStandbyTasksToClients(warmupTaskAssignment);
-        assignStandbyTasksToClients(standbyTaskAssignment);
-        assignActiveTasksToClients(statelessActiveTaskAssignment);
-
-        return !movements.isEmpty();
     }
 
     /**
@@ -198,52 +174,30 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
 
             // Verify that this client was caught-up on all stateful active tasks
             for (final TaskId activeTask : prevActiveTasks) {
-                if (!taskIsCaughtUpOnClient(activeTask, client)) {
+                if (!taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(activeTask, client, tasksToCaughtUpClients)) {
                     return false;
                 }
             }
+            if (!unassignedActiveTasks.containsAll(prevActiveTasks)) {
+                return false;
+            }
             unassignedActiveTasks.removeAll(prevActiveTasks);
 
-            if (!unassignedStandbyTasks.isEmpty()) {
-                for (final TaskId task : state.prevStandbyTasks()) {
-                    final Integer remainingStandbys = unassignedStandbyTasks.get(task);
-                    if (remainingStandbys != null) {
-                        if (remainingStandbys == 1) {
-                            unassignedStandbyTasks.remove(task);
-                        } else {
-                            unassignedStandbyTasks.put(task, remainingStandbys - 1);
-                        }
+            for (final TaskId task : state.prevStandbyTasks()) {
+                final Integer remainingStandbys = unassignedStandbyTasks.get(task);
+                if (remainingStandbys != null) {
+                    if (remainingStandbys == 1) {
+                        unassignedStandbyTasks.remove(task);
+                    } else {
+                        unassignedStandbyTasks.put(task, remainingStandbys - 1);
                     }
-                }
-            }
-        }
-        return unassignedActiveTasks.isEmpty() && unassignedStandbyTasks.isEmpty();
-    }
-
-    /**
-     * @return true if this client is caught-up for this task, or the task has no caught-up clients
-     */
-    boolean taskIsCaughtUpOnClient(final TaskId task, final UUID client) {
-        boolean hasNoCaughtUpClients = true;
-        final SortedSet<RankedClient> rankedClients = statefulTasksToRankedCandidates.get(task);
-        if (rankedClients == null) {
-            return true;
-        }
-        for (final RankedClient rankedClient : rankedClients) {
-            if (rankedClient.rank() <= 0L) {
-                if (rankedClient.clientId().equals(client)) {
-                    return true;
                 } else {
-                    hasNoCaughtUpClients = false;
+                    return false;
                 }
             }
 
-            // If we haven't found our client yet, it must not be caught-up
-            if (rankedClient.rank() > 0L) {
-                break;
-            }
         }
-        return hasNoCaughtUpClients;
+        return unassignedActiveTasks.isEmpty() && unassignedStandbyTasks.isEmpty();
     }
 
     /**
@@ -276,6 +230,7 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
      *   1) it satisfies the state constraint, ie all tasks with caught up clients are assigned to one of those clients
      *   2) it satisfies the balance factor
      *   3) there are no unassigned tasks (eg due to a client that dropped out of the group)
+     *   4) there are no warmup tasks
      */
     private boolean shouldUsePreviousAssignment() {
         if (previousAssignmentIsValid()) {
@@ -287,26 +242,6 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
         }
     }
 
-    private static Map<UUID, List<TaskId>> initializeEmptyTaskAssignmentMap(final Set<UUID> clients) {
-        return clients.stream().collect(Collectors.toMap(id -> id, id -> new ArrayList<>()));
-    }
-
-    private void assignActiveTasksToClients(final Map<UUID, List<TaskId>> activeTasks) {
-        for (final Map.Entry<UUID, ClientState> clientEntry : clientStates.entrySet()) {
-            final UUID clientId = clientEntry.getKey();
-            final ClientState state = clientEntry.getValue();
-            state.assignActiveTasks(activeTasks.get(clientId));
-        }
-    }
-
-    private void assignStandbyTasksToClients(final Map<UUID, List<TaskId>> standbyTasks) {
-        for (final Map.Entry<UUID, ClientState> clientEntry : clientStates.entrySet()) {
-            final UUID clientId = clientEntry.getKey();
-            final ClientState state = clientEntry.getValue();
-            state.assignStandbyTasks(standbyTasks.get(clientId));
-        }
-    }
-
     private void assignPreviousTasksToClientStates() {
         for (final ClientState clientState : clientStates.values()) {
             clientState.assignActiveTasks(clientState.prevActiveTasks());
@@ -314,88 +249,4 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
         }
     }
 
-    private PriorityQueue<UUID> getClientPriorityQueueByTaskLoad(final List<Map<UUID, List<TaskId>>> taskLoadsByClient) {
-        final PriorityQueue<UUID> queue = new PriorityQueue<>(
-            (client, other) -> {
-                final int clientTasksPerThread = tasksPerThread(client, taskLoadsByClient);
-                final int otherTasksPerThread = tasksPerThread(other, taskLoadsByClient);
-                if (clientTasksPerThread != otherTasksPerThread) {
-                    return clientTasksPerThread - otherTasksPerThread;
-                } else {
-                    return client.compareTo(other);
-                }
-            });
-
-        queue.addAll(sortedClients);
-        return queue;
-    }
-
-    private int tasksPerThread(final UUID client, final List<Map<UUID, List<TaskId>>> taskLoadsByClient) {
-        double numTasks = 0;
-        for (final Map<UUID, List<TaskId>> assignment : taskLoadsByClient) {
-            numTasks += assignment.get(client).size();
-        }
-        return (int) Math.ceil(numTasks / clientsToNumberOfThreads.get(client));
-    }
-    
-    /**
-     * Wraps a priority queue of clients and returns the next valid candidate(s) based on the current task assignment
-     */
-    static class ValidClientsByTaskLoadQueue<UUID> {
-        private final PriorityQueue<UUID> clientsByTaskLoad;
-        private final List<Map<UUID, List<TaskId>>> allStatefulTaskAssignments;
-
-        ValidClientsByTaskLoadQueue(final PriorityQueue<UUID> clientsByTaskLoad,
-                                    final List<Map<UUID, List<TaskId>>> allStatefulTaskAssignments) {
-            this.clientsByTaskLoad = clientsByTaskLoad;
-            this.allStatefulTaskAssignments = allStatefulTaskAssignments;
-        }
-
-        /**
-         * @return the next N <= {@code numClientsPerTask} clients in the underlying priority queue that are valid
-         * candidates for the given task (ie do not already have any version of this task assigned)
-         */
-        List<UUID> poll(final TaskId task, final int numClients) {
-            final List<UUID> nextLeastLoadedValidClients = new LinkedList<>();
-            final Set<UUID> invalidPolledClients = new HashSet<>();
-            while (nextLeastLoadedValidClients.size() < numClients) {
-                UUID candidateClient;
-                while (true) {
-                    candidateClient = clientsByTaskLoad.poll();
-                    if (candidateClient == null) {
-                        returnPolledClientsToQueue(invalidPolledClients);
-                        return nextLeastLoadedValidClients;
-                    }
-
-                    if (canBeAssignedToClient(task, candidateClient)) {
-                        nextLeastLoadedValidClients.add(candidateClient);
-                        break;
-                    } else {
-                        invalidPolledClients.add(candidateClient);
-                    }
-                }
-            }
-            returnPolledClientsToQueue(invalidPolledClients);
-            return nextLeastLoadedValidClients;
-        }
-
-        void offer(final Collection<UUID> clients) {
-            returnPolledClientsToQueue(clients);
-        }
-
-        private boolean canBeAssignedToClient(final TaskId task, final UUID client) {
-            for (final Map<UUID, List<TaskId>> taskAssignment : allStatefulTaskAssignments) {
-                if (taskAssignment.get(client).contains(task)) {
-                    return false;
-                }
-            }
-            return true;
-        }
-
-        private void returnPolledClientsToQueue(final Collection<UUID> polledClients) {
-            for (final UUID client : polledClients) {
-                clientsByTaskLoad.offer(client);
-            }
-        }
-    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
index e1e9bb7..1bc2751 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
@@ -16,128 +16,103 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.LinkedList;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentUtils.taskIsCaughtUpOnClientOrNoCaughtUpClientsExist;
+
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
 import java.util.SortedSet;
+import java.util.TreeSet;
 import java.util.UUID;
+import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.kafka.streams.processor.TaskId;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-public class TaskMovement {
-    private static final Logger log = LoggerFactory.getLogger(TaskMovement.class);
 
-    final TaskId task;
-    final UUID source;
-    final UUID destination;
+class TaskMovement {
+    private final TaskId task;
+    private final UUID destination;
+    private final SortedSet<UUID> caughtUpClients;
 
-    TaskMovement(final TaskId task, final UUID source, final UUID destination) {
+    private TaskMovement(final TaskId task, final UUID destination, final SortedSet<UUID> caughtUpClients) {
         this.task = task;
-        this.source = source;
         this.destination = destination;
-    }
+        this.caughtUpClients = caughtUpClients;
 
-    @Override
-    public boolean equals(final Object o) {
-        if (this == o) {
-            return true;
-        }
-        if (o == null || getClass() != o.getClass()) {
-            return false;
+        if (caughtUpClients == null || caughtUpClients.isEmpty()) {
+            throw new IllegalStateException("Should not attempt to move a task if no caught up clients exist");
         }
-        final TaskMovement movement = (TaskMovement) o;
-        return Objects.equals(task, movement.task) &&
-                   Objects.equals(source, movement.source) &&
-                   Objects.equals(destination, movement.destination);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(task, source, destination);
     }
 
     /**
-     * Computes the movement of tasks from the state constrained to the balanced assignment, up to the configured
-     * {@code max.warmup.replicas}. A movement corresponds to a warmup replica on the destination client, with
-     * a few exceptional cases:
-     * <p>
-     * 1. Tasks whose destination clients are caught-up, or whose source clients are not caught-up, will be moved
-     * immediately from the source to the destination in the state constrained assignment
-     * 2. Tasks whose destination client previously had this task as a standby will not be counted towards the total
-     * {@code max.warmup.replicas}. Instead they will be counted against that task's total {@code num.standby.replicas}.
-     *
-     * @param statefulActiveTaskAssignment the initial, state constrained assignment, with the source clients
-     * @param balancedStatefulActiveTaskAssignment the final, balanced assignment, with the destination clients
-     * @return list of the task movements from statefulActiveTaskAssignment to balancedStatefulActiveTaskAssignment
+     * @return whether any warmup replicas were assigned
      */
-    static List<TaskMovement> getMovements(final Map<UUID, List<TaskId>> statefulActiveTaskAssignment,
-                                           final Map<UUID, List<TaskId>> balancedStatefulActiveTaskAssignment,
-                                           final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
-                                           final Map<UUID, ClientState> clientStates,
-                                           final Map<TaskId, Integer> tasksToRemainingStandbys,
-                                           final int maxWarmupReplicas) {
-        if (statefulActiveTaskAssignment.size() != balancedStatefulActiveTaskAssignment.size()) {
-            throw new IllegalStateException("Tried to compute movements but assignments differ in size.");
-        }
+    static boolean assignTaskMovements(final Map<UUID, List<TaskId>> statefulActiveTaskAssignment,
+                                       final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
+                                       final Map<UUID, ClientState> clientStates,
+                                       final Map<TaskId, Integer> tasksToRemainingStandbys,
+                                       final int maxWarmupReplicas) {
+        boolean warmupReplicasAssigned = false;
+
+        final ValidClientsByTaskLoadQueue clientsByTaskLoad = new ValidClientsByTaskLoadQueue(
+            clientStates,
+            (client, task) -> taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients)
+        );
 
-        final Map<TaskId, UUID> taskToDestinationClient = new HashMap<>();
-        for (final Map.Entry<UUID, List<TaskId>> clientEntry : balancedStatefulActiveTaskAssignment.entrySet()) {
-            final UUID destination = clientEntry.getKey();
-            for (final TaskId task : clientEntry.getValue()) {
-                taskToDestinationClient.put(task, destination);
+        final SortedSet<TaskMovement> taskMovements = new TreeSet<>(
+            (movement, other) -> {
+                final int numCaughtUpClients = movement.caughtUpClients.size();
+                final int otherNumCaughtUpClients = other.caughtUpClients.size();
+                if (numCaughtUpClients != otherNumCaughtUpClients) {
+                    return Integer.compare(numCaughtUpClients, otherNumCaughtUpClients);
+                } else {
+                    return movement.task.compareTo(other.task);
+                }
             }
+        );
+
+        for (final Map.Entry<UUID, List<TaskId>> assignmentEntry : statefulActiveTaskAssignment.entrySet()) {
+            final UUID client = assignmentEntry.getKey();
+            final ClientState state = clientStates.get(client);
+            for (final TaskId task : assignmentEntry.getValue()) {
+                if (taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients)) {
+                    state.assignActive(task);
+                } else {
+                    final TaskMovement taskMovement = new TaskMovement(task, client, tasksToCaughtUpClients.get(task));
+                    taskMovements.add(taskMovement);
+                }
+            }
+            clientsByTaskLoad.offer(client);
         }
 
-        int remainingAllowedWarmupReplicas = maxWarmupReplicas;
-        final List<TaskMovement> movements = new LinkedList<>();
-        for (final Map.Entry<UUID, List<TaskId>> sourceClientEntry : statefulActiveTaskAssignment.entrySet()) {
-            final UUID source = sourceClientEntry.getKey();
+        final AtomicInteger remainingWarmupReplicas = new AtomicInteger(maxWarmupReplicas);
+        for (final TaskMovement movement : taskMovements) {
+            final UUID sourceClient = clientsByTaskLoad.poll(movement.task);
+            if (sourceClient == null) {
+                throw new IllegalStateException("Tried to move task to caught-up client but none exist");
+            }
 
-            final Iterator<TaskId> sourceClientTasksIterator = sourceClientEntry.getValue().iterator();
-            while (sourceClientTasksIterator.hasNext()) {
-                final TaskId task = sourceClientTasksIterator.next();
-                final UUID destination = taskToDestinationClient.get(task);
-                if (destination == null) {
-                    log.error("Task {} is assigned to client {} in initial assignment but has no owner in the final " +
-                                  "balanced assignment.", task, source);
-                    throw new IllegalStateException("Found task in initial assignment that was not assigned in the final.");
-                } else if (!source.equals(destination)) {
-                    if (destinationClientIsCaughtUp(task, destination, tasksToCaughtUpClients)) {
-                        sourceClientTasksIterator.remove();
-                        statefulActiveTaskAssignment.get(destination).add(task);
-                    } else {
-                        if (clientStates.get(destination).prevStandbyTasks().contains(task)
-                                && tasksToRemainingStandbys.get(task) > 0
-                        ) {
-                            decrementRemainingStandbys(task, tasksToRemainingStandbys);
-                        } else {
-                            --remainingAllowedWarmupReplicas;
-                        }
+            final ClientState sourceClientState = clientStates.get(sourceClient);
+            sourceClientState.assignActive(movement.task);
+            clientsByTaskLoad.offer(sourceClient);
 
-                        movements.add(new TaskMovement(task, source, destination));
-                        if (remainingAllowedWarmupReplicas == 0) {
-                            return movements;
-                        }
-                    }
-                }
+            final ClientState destinationClientState = clientStates.get(movement.destination);
+            if (shouldAssignWarmupReplica(movement.task, destinationClientState, remainingWarmupReplicas, tasksToRemainingStandbys)) {
+                destinationClientState.assignStandby(movement.task);
+                clientsByTaskLoad.offer(movement.destination);
+                warmupReplicasAssigned = true;
             }
         }
-        return movements;
+        return warmupReplicasAssigned;
     }
 
-    private static boolean destinationClientIsCaughtUp(final TaskId task,
-                                                       final UUID destination,
-                                                       final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients) {
-        final Set<UUID> caughtUpClients = tasksToCaughtUpClients.get(task);
-        return caughtUpClients != null && caughtUpClients.contains(destination);
+    private static boolean shouldAssignWarmupReplica(final TaskId task,
+                                                     final ClientState destinationClientState,
+                                                     final AtomicInteger remainingWarmupReplicas,
+                                                     final Map<TaskId, Integer> tasksToRemainingStandbys) {
+        if (destinationClientState.previousAssignedTasks().contains(task) && tasksToRemainingStandbys.get(task) > 0) {
+            tasksToRemainingStandbys.compute(task, (t, numStandbys) -> numStandbys - 1);
+            return true;
+        } else {
+            return remainingWarmupReplicas.getAndDecrement() > 0;
+        }
     }
 
-    private static void decrementRemainingStandbys(final TaskId task, final Map<TaskId, Integer> tasksToRemainingStandbys) {
-        tasksToRemainingStandbys.compute(task, (t, numstandbys) -> numstandbys - 1);
-    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueue.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueue.java
new file mode 100644
index 0000000..8222dc7
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueue.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Set;
+import java.util.UUID;
+import java.util.function.BiFunction;
+import org.apache.kafka.streams.processor.TaskId;
+
+/**
+ * Wraps a priority queue of clients and returns the next valid candidate(s) based on the current task assignment.
+ * Clients are ordered by {@link ClientState#taskLoad()}, least loaded first, with ties broken by client UUID.
+ */
+class ValidClientsByTaskLoadQueue {
+
+    private final PriorityQueue<UUID> clientsByTaskLoad;
+    private final BiFunction<UUID, TaskId, Boolean> validClientCriteria;
+    // tracks which clients are currently queued so offer() re-positions a client instead of duplicating it
+    private final Set<UUID> uniqueClients = new HashSet<>();
+
+    ValidClientsByTaskLoadQueue(final Map<UUID, ClientState> clientStates,
+                                final BiFunction<UUID, TaskId, Boolean> validClientCriteria) {
+        this.validClientCriteria = validClientCriteria;
+
+        clientsByTaskLoad = new PriorityQueue<>(
+            (client, other) -> {
+                final double clientTaskLoad = clientStates.get(client).taskLoad();
+                final double otherTaskLoad = clientStates.get(other).taskLoad();
+                if (clientTaskLoad < otherTaskLoad) {
+                    return -1;
+                } else if (clientTaskLoad > otherTaskLoad) {
+                    return 1;
+                } else {
+                    return client.compareTo(other);
+                }
+            });
+    }
+
+    /**
+     * @return the next least loaded client that satisfies the given criteria, or null if none do. The returned
+     *         client is no longer queued; the caller should {@link #offer} it back once its assignment changes.
+     */
+    UUID poll(final TaskId task) {
+        final List<UUID> validClient = poll(task, 1);
+        return validClient.isEmpty() ? null : validClient.get(0);
+    }
+
+    /**
+     * @return the next N <= {@code numClients} clients in the underlying priority queue that are valid candidates
+     *         for the given task. Returned clients are removed from the queue; rejected ones are offered back.
+     */
+    List<UUID> poll(final TaskId task, final int numClients) {
+        final List<UUID> nextLeastLoadedValidClients = new LinkedList<>();
+        final Set<UUID> invalidPolledClients = new HashSet<>();
+        while (nextLeastLoadedValidClients.size() < numClients) {
+            final UUID candidateClient = pollNextClient();
+            if (candidateClient == null) {
+                // queue exhausted: fewer than numClients clients satisfy the criteria
+                break;
+            }
+            if (validClientCriteria.apply(candidateClient, task)) {
+                nextLeastLoadedValidClients.add(candidateClient);
+            } else {
+                invalidPolledClients.add(candidateClient);
+            }
+        }
+        offerAll(invalidPolledClients);
+        return nextLeastLoadedValidClients;
+    }
+
+    /** Offers each of the given clients; see {@link #offer(UUID)}. */
+    void offerAll(final Collection<UUID> clients) {
+        for (final UUID client : clients) {
+            offer(client);
+        }
+    }
+
+    /** Adds the client, or re-inserts it if already queued, so the ordering reflects its current task load. */
+    void offer(final UUID client) {
+        if (!uniqueClients.add(client)) {
+            clientsByTaskLoad.remove(client);
+        }
+        clientsByTaskLoad.offer(client);
+    }
+
+    private UUID pollNextClient() {
+        final UUID client = clientsByTaskLoad.poll();
+        // keep uniqueClients in sync with the queue's membership
+        uniqueClients.remove(client);
+        return client;
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
index 397ba51..d576b69d 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
@@ -546,7 +546,7 @@ public class StreamsPartitionAssignorTest {
         final Set<TaskId> prevTasks10 = mkSet(TASK_0_0);
         final Set<TaskId> prevTasks11 = mkSet(TASK_0_1);
         final Set<TaskId> prevTasks20 = mkSet(TASK_0_2);
-        final Set<TaskId> standbyTasks10 = mkSet(TASK_0_1);
+        final Set<TaskId> standbyTasks10 = EMPTY_TASKS;
         final Set<TaskId> standbyTasks11 = mkSet(TASK_0_2);
         final Set<TaskId> standbyTasks20 = mkSet(TASK_0_0);
 
@@ -986,7 +986,7 @@ public class StreamsPartitionAssignorTest {
         subscriptions.put("consumer10",
                           new Subscription(
                               topics,
-                              getInfo(UUID_1, prevTasks00, standbyTasks01, USER_END_POINT).encode()));
+                              getInfo(UUID_1, prevTasks00, EMPTY_TASKS, USER_END_POINT).encode()));
         subscriptions.put("consumer11",
                           new Subscription(
                               topics,
@@ -1611,79 +1611,6 @@ public class StreamsPartitionAssignorTest {
     }
 
     @Test
-    public void shouldReturnNormalAssignmentForOldAndFutureInstancesDuringVersionProbing() {
-        builder.addSource(null, "source1", null, null, null, "topic1");
-        builder.addProcessor("processor", new MockProcessorSupplier(), "source1");
-        builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor");
-
-        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
-
-        final Set<TaskId> activeTasks = mkSet(TASK_0_0, TASK_0_1);
-        final Set<TaskId> standbyTasks = mkSet(TASK_0_2);
-        final Map<TaskId, Set<TopicPartition>> standbyTaskMap = mkMap(
-            mkEntry(TASK_0_2, Collections.singleton(t1p2))
-        );
-        final Map<TaskId, Set<TopicPartition>> futureStandbyTaskMap = mkMap(
-            mkEntry(TASK_0_0, Collections.singleton(t1p0)),
-            mkEntry(TASK_0_1, Collections.singleton(t1p1))
-        );
-
-        createMockTaskManager(allTasks, allTasks);
-        createMockAdminClient(getTopicPartitionOffsetsMap(
-            singletonList(APPLICATION_ID + "-store1-changelog"),
-            singletonList(3))
-        );
-
-        configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1));
-
-        subscriptions.put("consumer1",
-                new Subscription(
-                        Collections.singletonList("topic1"),
-                        getInfo(UUID_1, activeTasks, standbyTasks).encode(),
-                        asList(t1p0, t1p1))
-        );
-        subscriptions.put("future-consumer",
-                          new Subscription(
-                              Collections.singletonList("topic1"),
-                              encodeFutureSubscription(),
-                              Collections.singletonList(t1p2))
-        );
-
-        final Map<String, Assignment> assignment = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
-
-        assertThat(assignment.size(), equalTo(2));
-
-        assertThat(assignment.get("consumer1").partitions(), equalTo(asList(t1p0, t1p1)));
-        assertThat(
-            AssignmentInfo.decode(assignment.get("consumer1").userData()),
-            equalTo(
-                new AssignmentInfo(
-                    LATEST_SUPPORTED_VERSION,
-                    new ArrayList<>(activeTasks),
-                    standbyTaskMap,
-                    emptyMap(),
-                    emptyMap(),
-                    0
-                )
-            )
-        );
-
-        assertThat(assignment.get("future-consumer").partitions(), equalTo(Collections.singletonList(t1p2)));
-        assertThat(
-            AssignmentInfo.decode(assignment.get("future-consumer").userData()),
-            equalTo(
-                new AssignmentInfo(
-                    LATEST_SUPPORTED_VERSION,
-                    Collections.singletonList(TASK_0_2),
-                    futureStandbyTaskMap,
-                    emptyMap(),
-                    emptyMap(),
-                    0)
-            )
-        );
-    }
-
-    @Test
     public void shouldReturnInterleavedAssignmentForOnlyFutureInstancesDuringVersionProbing() {
         builder.addSource(null, "source1", null, null, null, "topic1");
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java
index 085af0e..8021aab 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java
@@ -16,9 +16,11 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.emptySet;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
@@ -49,20 +51,28 @@ public class AssignmentTestUtils {
     public static final TaskId TASK_2_1 = new TaskId(2, 1);
     public static final TaskId TASK_2_2 = new TaskId(2, 2);
     public static final TaskId TASK_2_3 = new TaskId(2, 3);
-    public static final TaskId TASK_3_4 = new TaskId(3, 4);
 
     public static final Set<TaskId> EMPTY_TASKS = emptySet();
+    public static final List<TaskId> EMPTY_TASK_LIST = emptyList();
     public static final Map<TaskId, Long> EMPTY_TASK_OFFSET_SUMS = emptyMap();
     public static final Map<TopicPartition, Long> EMPTY_CHANGELOG_END_OFFSETS = new HashMap<>();
 
+
+    /** Maps the given states to clients {@code uuidForInt(1)}, {@code uuidForInt(2)}, ... in argument order. */
+    static Map<UUID, ClientState> getClientStatesMap(final ClientState... states) {
+        final Map<UUID, ClientState> statesByClient = new HashMap<>();
+        for (int i = 0; i < states.length; i++) {
+            // client ids are 1-based, so the first state belongs to uuidForInt(1)
+            statesByClient.put(uuidForInt(i + 1), states[i]);
+        }
+        return statesByClient;
+    }
+
     /**
      * Builds a UUID by repeating the given number n. For valid n, it is guaranteed that the returned UUIDs satisfy
      * the same relation relative to others as their parameter n does: iff n < m, then uuidForInt(n) < uuidForInt(m)
-     *
-     * @param n an integer between 1 and 7
-     * @return the UUID created by repeating the digit n in the UUID format
      */
-    static UUID uuidForInt(final Integer n) {
+    static UUID uuidForInt(final int n) {
         return new UUID(0, n);
     }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentUtilsTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentUtilsTest.java
new file mode 100644
index 0000000..0644b50
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentUtilsTest.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import static java.util.Collections.emptyMap;
+import static org.apache.kafka.common.utils.Utils.mkSortedSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentUtils.taskIsCaughtUpOnClientOrNoCaughtUpClientsExist;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.SortedSet;
+import java.util.UUID;
+import org.apache.kafka.streams.processor.TaskId;
+import org.junit.Test;
+
+public class AssignmentUtilsTest {
+    /** Returns a caught-up-clients map in which TASK_0_0 is caught up only on {@code client}. */
+    private static Map<TaskId, SortedSet<UUID>> caughtUpOnlyOn(final UUID client) {
+        final Map<TaskId, SortedSet<UUID>> caughtUpClientsByTask = new HashMap<>();
+        caughtUpClientsByTask.put(TASK_0_0, mkSortedSet(client));
+        return caughtUpClientsByTask;
+    }
+
+    @Test
+    public void shouldReturnTrueIfTaskHasNoCaughtUpClients() {
+        assertTrue(taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(TASK_0_0, UUID_1, emptyMap()));
+    }
+
+    @Test
+    public void shouldReturnTrueIfTaskIsCaughtUpOnClient() {
+        assertTrue(taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(TASK_0_0, UUID_1, caughtUpOnlyOn(UUID_1)));
+    }
+
+    @Test
+    public void shouldReturnFalseIfTaskWasNotCaughtUpOnClientButCaughtUpClientsExist() {
+        assertFalse(taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(TASK_0_0, UUID_1, caughtUpOnlyOn(UUID_2)));
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/DefaultStateConstrainedBalancedAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/DefaultStateConstrainedBalancedAssignorTest.java
deleted file mode 100644
index f292773..0000000
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/DefaultStateConstrainedBalancedAssignorTest.java
+++ /dev/null
@@ -1,978 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.streams.processor.internals.assignment;
-
-import java.util.UUID;
-import org.apache.kafka.streams.processor.TaskId;
-import org.apache.kafka.streams.processor.internals.Task;
-import org.junit.Test;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.SortedMap;
-import java.util.SortedSet;
-import java.util.TreeMap;
-import java.util.TreeSet;
-
-import static org.apache.kafka.common.utils.Utils.mkEntry;
-import static org.apache.kafka.common.utils.Utils.mkMap;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_2;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_3;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_3_4;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3;
-import static org.apache.kafka.streams.processor.internals.assignment.RankedClient.tasksToCaughtUpClients;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.is;
-
-public class DefaultStateConstrainedBalancedAssignorTest {
-
-    private static final Set<UUID> TWO_CLIENTS = new HashSet<>(Arrays.asList(UUID_1, UUID_2));
-    private static final Set<UUID> THREE_CLIENTS = new HashSet<>(Arrays.asList(UUID_1, UUID_2, UUID_3));
-
-    @Test
-    public void shouldAssignTaskToCaughtUpClient() {
-        final long rankOfClient1 = 0;
-        final long rankOfClient2 = Long.MAX_VALUE;
-        final int balanceFactor = 1;
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            oneStatefulTasksToTwoRankedClients(rankOfClient1, rankOfClient2),
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(oneStatefulTasksToTwoRankedClients(rankOfClient1, rankOfClient2))
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Collections.emptyList();
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldAssignTaskToPreviouslyHostingClient() {
-        final long rankOfClient1 = Long.MAX_VALUE;
-        final long rankOfClient2 = Task.LATEST_OFFSET;
-        final int balanceFactor = 1;
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            oneStatefulTasksToTwoRankedClients(rankOfClient1, rankOfClient2),
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(oneStatefulTasksToTwoRankedClients(rankOfClient1, rankOfClient2))
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.emptyList();
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_0_1);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldAssignTaskToPreviouslyHostingClientWhenOtherCaughtUpClientExists() {
-        final long rankOfClient1 = 0;
-        final long rankOfClient2 = Task.LATEST_OFFSET;
-        final int balanceFactor = 1;
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            oneStatefulTasksToTwoRankedClients(rankOfClient1, rankOfClient2),
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(oneStatefulTasksToTwoRankedClients(rankOfClient1, rankOfClient2))
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.emptyList();
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_0_1);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldAssignTaskToCaughtUpClientThatIsFirstInSortOrder() {
-        final long rankOfClient1 = 0;
-        final long rankOfClient2 = 0;
-        final int balanceFactor = 1;
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            oneStatefulTasksToTwoRankedClients(rankOfClient1, rankOfClient2),
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(oneStatefulTasksToTwoRankedClients(rankOfClient1, rankOfClient2))
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Collections.emptyList();
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldAssignTaskToMostCaughtUpClient() {
-        final long rankOfClient1 = 3;
-        final long rankOfClient2 = 5;
-        final int balanceFactor = 1;
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            oneStatefulTasksToTwoRankedClients(rankOfClient1, rankOfClient2),
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(oneStatefulTasksToTwoRankedClients(rankOfClient1, rankOfClient2))
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Collections.emptyList();
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldEvenlyDistributeTasksToCaughtUpClientsThatAreNotPreviousHosts() {
-        final long rankForTask01OnClient1 = 0;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask12OnClient1 = 0;
-        final long rankForTask12OnClient2 = 0;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_1_2);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldEvenlyDistributeTasksToCaughtUpClientsThatAreNotPreviousHostsEvenIfNotRequiredByBalanceFactor() {
-        final long rankForTask01OnClient1 = 0;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask12OnClient1 = 0;
-        final long rankForTask12OnClient2 = 0;
-        final int balanceFactor = 2;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_1_2);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldEvenlyDistributeTasksToCaughtUpClientsEvenIfOneClientIsPreviousHostOfAll() {
-        final long rankForTask01OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask01OnClient3 = 0;
-        final long rankForTask12OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask12OnClient2 = 0;
-        final long rankForTask12OnClient3 = 0;
-        final long rankForTask23OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask23OnClient2 = 0;
-        final long rankForTask23OnClient3 = 0;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            threeStatefulTasksToThreeRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask01OnClient3,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2,
-                rankForTask12OnClient3,
-                rankForTask23OnClient1,
-                rankForTask23OnClient2,
-                rankForTask23OnClient3
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            THREE_CLIENTS,
-            threeClientsToNumberOfStreamThreads(1, 1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_2_3);
-        final List<TaskId> assignedTasksForClient3 = Collections.singletonList(TASK_1_2);
-        assertThat(
-            assignment,
-            is(expectedAssignmentForThreeClients(assignedTasksForClient1, assignedTasksForClient2, assignedTasksForClient3))
-        );
-    }
-
-    @Test
-    public void shouldMoveTask01FromClient1ToEvenlyDistributeTasksToCaughtUpClientsEvenIfOneClientIsPreviousHostOfBoth() {
-        final long rankForTask01OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask12OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask12OnClient2 = 100;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_1_2);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_0_1);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldMoveTask12FromClient2ToEvenlyDistributeTasksToCaughtUpClientsEvenIfOneClientIsPreviousHostOfBoth() {
-        final long rankForTask01OnClient1 = 100;
-        final long rankForTask01OnClient2 = Task.LATEST_OFFSET;
-        final long rankForTask12OnClient1 = 0;
-        final long rankForTask12OnClient2 = Task.LATEST_OFFSET;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_1_2);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_0_1);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldAssignBothTasksToPreviousHostSinceBalanceFactorSatisfied() {
-        final long rankForTask01OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask12OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask12OnClient2 = 0;
-        final int balanceFactor = 2;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Arrays.asList(TASK_0_1, TASK_1_2);
-        final List<TaskId> assignedTasksForClient2 = Collections.emptyList();
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldAssignOneTaskToPreviousHostAndOtherTaskToMostCaughtUpClient() {
-        final long rankForTask01OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask12OnClient1 = 20;
-        final long rankForTask12OnClient2 = 10;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_1_2);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldAssignOneTaskToPreviousHostAndOtherTaskToLessCaughtUpClientDueToBalanceFactor() {
-        final long rankForTask01OnClient1 = 0;
-        final long rankForTask01OnClient2 = Task.LATEST_OFFSET;
-        final long rankForTask12OnClient1 = 20;
-        final long rankForTask12OnClient2 = 10;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_1_2);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_0_1);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldAssignBothTasksToSameClientSincePreviousHostAndMostCaughtUpAndBalanceFactorSatisfied() {
-        final long rankForTask01OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask12OnClient1 = 10;
-        final long rankForTask12OnClient2 = 20;
-        final int balanceFactor = 2;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Arrays.asList(TASK_0_1, TASK_1_2);
-        final List<TaskId> assignedTasksForClient2 = Collections.emptyList();
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldAssignTasksToMostCaughtUpClient() {
-        final long rankForTask01OnClient1 = 50;
-        final long rankForTask01OnClient2 = 20;
-        final long rankForTask12OnClient1 = 20;
-        final long rankForTask12OnClient2 = 50;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_1_2);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_0_1);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldEvenlyDistributeTasksEvenIfClientsAreNotMostCaughtUpDueToBalanceFactor() {
-        final long rankForTask01OnClient1 = 20;
-        final long rankForTask01OnClient2 = 50;
-        final long rankForTask01OnClient3 = 100;
-        final long rankForTask12OnClient1 = 20;
-        final long rankForTask12OnClient2 = 50;
-        final long rankForTask12OnClient3 = 100;
-        final long rankForTask23OnClient1 = 20;
-        final long rankForTask23OnClient2 = 50;
-        final long rankForTask23OnClient3 = 100;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            threeStatefulTasksToThreeRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask01OnClient3,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2,
-                rankForTask12OnClient3,
-                rankForTask23OnClient1,
-                rankForTask23OnClient2,
-                rankForTask23OnClient3
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            THREE_CLIENTS,
-            threeClientsToNumberOfStreamThreads(1, 1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_2_3);
-        final List<TaskId> assignedTasksForClient3 = Collections.singletonList(TASK_1_2);
-        assertThat(
-            assignment,
-            is(expectedAssignmentForThreeClients(assignedTasksForClient1, assignedTasksForClient2, assignedTasksForClient3))
-        );
-    }
-
-    @Test
-    public void shouldAssignBothTasksToSameMostCaughtUpClientSinceBalanceFactorSatisfied() {
-        final long rankForTask01OnClient1 = 40;
-        final long rankForTask01OnClient2 = 30;
-        final long rankForTask12OnClient1 = 20;
-        final long rankForTask12OnClient2 = 10;
-        final int balanceFactor = 2;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.emptyList();
-        final List<TaskId> assignedTasksForClient2 = Arrays.asList(TASK_0_1, TASK_1_2);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldEvenlyDistributeTasksOverClientsWithEqualRank() {
-        final long rankForTask01OnClient1 = 40;
-        final long rankForTask01OnClient2 = 40;
-        final long rankForTask12OnClient1 = 40;
-        final long rankForTask12OnClient2 = 40;
-        final int balanceFactor = 2;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_1_2);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    /**
-     * This test shows that in an assigment of one client the assumption that the set of tasks which are caught-up on
-     * the given client is followed by the set of tasks that are not caught-up on the given client does NOT hold.
-     * In fact, in this test, at some point during the execution of the algorithm the assignment for UUID_2
-     * contains TASK_3_4 followed by TASK_2_3. TASK_2_3 is caught-up on UUID_2 whereas TASK_3_4 is not.
-     */
-    @Test
-    public void shouldEvenlyDistributeTasksOrderOfCaughtUpAndNotCaughtUpTaskIsMixedUpInIntermediateResults() {
-        final long rankForTask01OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask01OnClient3 = 100;
-        final long rankForTask12OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask12OnClient2 = 0;
-        final long rankForTask12OnClient3 = 100;
-        final long rankForTask23OnClient1 = Task.LATEST_OFFSET;
-        final long rankForTask23OnClient2 = 0;
-        final long rankForTask23OnClient3 = 100;
-        final long rankForTask34OnClient1 = 50;
-        final long rankForTask34OnClient2 = 20;
-        final long rankForTask34OnClient3 = 100;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            fourStatefulTasksToThreeRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask01OnClient3,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2,
-                rankForTask12OnClient3,
-                rankForTask23OnClient1,
-                rankForTask23OnClient2,
-                rankForTask23OnClient3,
-                rankForTask34OnClient1,
-                rankForTask34OnClient2,
-                rankForTask34OnClient3
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            THREE_CLIENTS,
-            threeClientsToNumberOfStreamThreads(1, 1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Arrays.asList(TASK_0_1, TASK_1_2);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_2_3);
-        final List<TaskId> assignedTasksForClient3 = Collections.singletonList(TASK_3_4);
-        assertThat(
-            assignment,
-            is(expectedAssignmentForThreeClients(assignedTasksForClient1, assignedTasksForClient2, assignedTasksForClient3))
-        );
-    }
-
-    @Test
-    public void shouldAssignTasksToTheCaughtUpClientEvenIfTheAssignmentIsUnbalanced() {
-        final long rankForTask01OnClient1 = 60;
-        final long rankForTask01OnClient2 = 50;
-        final long rankForTask01OnClient3 = Task.LATEST_OFFSET;
-        final long rankForTask12OnClient1 = 40;
-        final long rankForTask12OnClient2 = 30;
-        final long rankForTask12OnClient3 = 0;
-        final long rankForTask23OnClient1 = 10;
-        final long rankForTask23OnClient2 = 20;
-        final long rankForTask23OnClient3 = Task.LATEST_OFFSET;
-        final long rankForTask34OnClient1 = 70;
-        final long rankForTask34OnClient2 = 80;
-        final long rankForTask34OnClient3 = 90;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            fourStatefulTasksToThreeRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask01OnClient3,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2,
-                rankForTask12OnClient3,
-                rankForTask23OnClient1,
-                rankForTask23OnClient2,
-                rankForTask23OnClient3,
-                rankForTask34OnClient1,
-                rankForTask34OnClient2,
-                rankForTask34OnClient3
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            THREE_CLIENTS,
-            threeClientsToNumberOfStreamThreads(1, 1, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_3_4);
-        final List<TaskId> assignedTasksForClient2 = Collections.emptyList();
-        final List<TaskId> assignedTasksForClient3 = Arrays.asList(TASK_0_1, TASK_2_3, TASK_1_2);
-        assertThat(
-            assignment,
-            is(expectedAssignmentForThreeClients(assignedTasksForClient1, assignedTasksForClient2, assignedTasksForClient3))
-        );
-    }
-
-    @Test
-    public void shouldEvenlyDistributeTasksOverSameNumberOfStreamThreads() {
-        final long rankForTask01OnClient1 = 0;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask12OnClient1 = 0;
-        final long rankForTask12OnClient2 = 0;
-        final long rankForTask23OnClient1 = 0;
-        final long rankForTask23OnClient2 = 0;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            threeStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2,
-                rankForTask23OnClient1,
-                rankForTask23OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 2),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Arrays.asList(TASK_1_2, TASK_2_3);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldEvenlyDistributeTasksOnUnderProvisionedStreamThreads() {
-        final long rankForTask01OnClient1 = 0;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask12OnClient1 = 0;
-        final long rankForTask12OnClient2 = 0;
-        final long rankForTask23OnClient1 = 0;
-        final long rankForTask23OnClient2 = 0;
-        final long rankForTask34OnClient1 = 0;
-        final long rankForTask34OnClient2 = 0;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            fourStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2,
-                rankForTask23OnClient1,
-                rankForTask23OnClient2,
-                rankForTask34OnClient1,
-                rankForTask34OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 2),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Arrays.asList(TASK_0_1, TASK_2_3);
-        final List<TaskId> assignedTasksForClient2 = Arrays.asList(TASK_1_2, TASK_3_4);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldDistributeTasksOverOverProvisionedStreamThreadsYieldingBalancedStreamThreadsAndClients() {
-        final long rankForTask01OnClient1 = 0;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask12OnClient1 = 0;
-        final long rankForTask12OnClient2 = 0;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            twoStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(2, 1),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Collections.singletonList(TASK_1_2);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    @Test
-    public void shouldDistributeTasksOverOverProvisionedStreamThreadsYieldingBalancedStreamThreadsButUnbalancedClients() {
-        final long rankForTask01OnClient1 = 0;
-        final long rankForTask01OnClient2 = 0;
-        final long rankForTask12OnClient1 = 0;
-        final long rankForTask12OnClient2 = 0;
-        final long rankForTask23OnClient1 = 0;
-        final long rankForTask23OnClient2 = 0;
-        final long rankForTask34OnClient1 = 0;
-        final long rankForTask34OnClient2 = 0;
-        final int balanceFactor = 1;
-
-        final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates =
-            fourStatefulTasksToTwoRankedClients(
-                rankForTask01OnClient1,
-                rankForTask01OnClient2,
-                rankForTask12OnClient1,
-                rankForTask12OnClient2,
-                rankForTask23OnClient1,
-                rankForTask23OnClient2,
-                rankForTask34OnClient1,
-                rankForTask34OnClient2
-            );
-
-        final Map<UUID, List<TaskId>> assignment = new DefaultStateConstrainedBalancedAssignor().assign(
-            statefulTasksToRankedCandidates,
-            balanceFactor,
-            TWO_CLIENTS,
-            twoClientsToNumberOfStreamThreads(1, 4),
-            tasksToCaughtUpClients(statefulTasksToRankedCandidates)
-        );
-
-        final List<TaskId> assignedTasksForClient1 = Collections.singletonList(TASK_0_1);
-        final List<TaskId> assignedTasksForClient2 = Arrays.asList(TASK_1_2, TASK_3_4, TASK_2_3);
-        assertThat(assignment, is(expectedAssignmentForTwoClients(assignedTasksForClient1, assignedTasksForClient2)));
-    }
-
-    private static Map<UUID, Integer> twoClientsToNumberOfStreamThreads(final int numberOfStreamThread1,
-                                                                        final int numberOfStreamThread2) {
-        return mkMap(
-            mkEntry(UUID_1, numberOfStreamThread1),
-            mkEntry(UUID_2, numberOfStreamThread2)
-        );
-    }
-
-    private static Map<UUID, Integer> threeClientsToNumberOfStreamThreads(final int numberOfStreamThread1,
-                                                                          final int numberOfStreamThread2,
-                                                                          final int numberOfStreamThread3) {
-        return mkMap(
-            mkEntry(UUID_1, numberOfStreamThread1),
-            mkEntry(UUID_2, numberOfStreamThread2),
-            mkEntry(UUID_3, numberOfStreamThread3)
-        );
-    }
-
-    private static SortedMap<TaskId, SortedSet<RankedClient>> oneStatefulTasksToTwoRankedClients(final long rankOfClient1,
-                                                                                                 final long rankOfClient2) {
-        final SortedSet<RankedClient> rankedClients01 = new TreeSet<>();
-        rankedClients01.add(new RankedClient(UUID_1, rankOfClient1));
-        rankedClients01.add(new RankedClient(UUID_2, rankOfClient2));
-        return new TreeMap<>(
-            mkMap(mkEntry(TASK_0_1, rankedClients01))
-        );
-    }
-
-    private static SortedMap<TaskId, SortedSet<RankedClient>> twoStatefulTasksToTwoRankedClients(final long rankForTask01OnClient1,
-                                                                                                 final long rankForTask01OnClient2,
-                                                                                                 final long rankForTask12OnClient1,
-                                                                                                 final long rankForTask12OnClient2) {
-        final SortedSet<RankedClient> rankedClients01 = new TreeSet<>();
-        rankedClients01.add(new RankedClient(UUID_1, rankForTask01OnClient1));
-        rankedClients01.add(new RankedClient(UUID_2, rankForTask01OnClient2));
-        final SortedSet<RankedClient> rankedClients12 = new TreeSet<>();
-        rankedClients12.add(new RankedClient(UUID_1, rankForTask12OnClient1));
-        rankedClients12.add(new RankedClient(UUID_2, rankForTask12OnClient2));
-        return new TreeMap<>(
-            mkMap(
-                mkEntry(TASK_0_1, rankedClients01),
-                mkEntry(TASK_1_2, rankedClients12)
-            )
-        );
-    }
-
-    private static SortedMap<TaskId, SortedSet<RankedClient>> threeStatefulTasksToTwoRankedClients(final long rankForTask01OnClient1,
-                                                                                                   final long rankForTask01OnClient2,
-                                                                                                   final long rankForTask12OnClient1,
-                                                                                                   final long rankForTask12OnClient2,
-                                                                                                   final long rankForTask23OnClient1,
-                                                                                                   final long rankForTask23OnClient2) {
-        final SortedSet<RankedClient> rankedClients01 = new TreeSet<>();
-        rankedClients01.add(new RankedClient(UUID_1, rankForTask01OnClient1));
-        rankedClients01.add(new RankedClient(UUID_2, rankForTask01OnClient2));
-        final SortedSet<RankedClient> rankedClients12 = new TreeSet<>();
-        rankedClients12.add(new RankedClient(UUID_1, rankForTask12OnClient1));
-        rankedClients12.add(new RankedClient(UUID_2, rankForTask12OnClient2));
-        final SortedSet<RankedClient> rankedClients23 = new TreeSet<>();
-        rankedClients23.add(new RankedClient(UUID_1, rankForTask23OnClient1));
-        rankedClients23.add(new RankedClient(UUID_2, rankForTask23OnClient2));
-        return new TreeMap<>(
-            mkMap(
-                mkEntry(TASK_0_1, rankedClients01),
-                mkEntry(TASK_1_2, rankedClients12),
-                mkEntry(TASK_2_3, rankedClients23)
-            )
-        );
-    }
-
-    private static SortedMap<TaskId, SortedSet<RankedClient>> threeStatefulTasksToThreeRankedClients(final long rankForTask01OnClient1,
-                                                                                                     final long rankForTask01OnClient2,
-                                                                                                     final long rankForTask01OnClient3,
-                                                                                                     final long rankForTask12OnClient1,
-                                                                                                     final long rankForTask12OnClient2,
-                                                                                                     final long rankForTask12OnClient3,
-                                                                                                     final long rankForTask23OnClient1,
-                                                                                                     final long rankForTask23OnClient2,
-                                                                                                     final long rankForTask23OnClient3) {
-        final SortedSet<RankedClient> rankedClients01 = new TreeSet<>();
-        rankedClients01.add(new RankedClient(UUID_1, rankForTask01OnClient1));
-        rankedClients01.add(new RankedClient(UUID_2, rankForTask01OnClient2));
-        rankedClients01.add(new RankedClient(UUID_3, rankForTask01OnClient3));
-        final SortedSet<RankedClient> rankedClients12 = new TreeSet<>();
-        rankedClients12.add(new RankedClient(UUID_1, rankForTask12OnClient1));
-        rankedClients12.add(new RankedClient(UUID_2, rankForTask12OnClient2));
-        rankedClients12.add(new RankedClient(UUID_3, rankForTask12OnClient3));
-        final SortedSet<RankedClient> rankedClients23 = new TreeSet<>();
-        rankedClients23.add(new RankedClient(UUID_1, rankForTask23OnClient1));
-        rankedClients23.add(new RankedClient(UUID_2, rankForTask23OnClient2));
-        rankedClients23.add(new RankedClient(UUID_3, rankForTask23OnClient3));
-        return new TreeMap<>(
-            mkMap(
-                mkEntry(TASK_0_1, rankedClients01),
-                mkEntry(TASK_1_2, rankedClients12),
-                mkEntry(TASK_2_3, rankedClients23)
-            )
-        );
-    }
-
-    private static SortedMap<TaskId, SortedSet<RankedClient>> fourStatefulTasksToTwoRankedClients(final long rankForTask01OnClient1,
-                                                                                                  final long rankForTask01OnClient2,
-                                                                                                  final long rankForTask12OnClient1,
-                                                                                                  final long rankForTask12OnClient2,
-                                                                                                  final long rankForTask23OnClient1,
-                                                                                                  final long rankForTask23OnClient2,
-                                                                                                  final long rankForTask34OnClient1,
-                                                                                                  final long rankForTask34OnClient2) {
-        final SortedSet<RankedClient> rankedClients01 = new TreeSet<>();
-        rankedClients01.add(new RankedClient(UUID_1, rankForTask01OnClient1));
-        rankedClients01.add(new RankedClient(UUID_2, rankForTask01OnClient2));
-        final SortedSet<RankedClient> rankedClients12 = new TreeSet<>();
-        rankedClients12.add(new RankedClient(UUID_1, rankForTask12OnClient1));
-        rankedClients12.add(new RankedClient(UUID_2, rankForTask12OnClient2));
-        final SortedSet<RankedClient> rankedClients23 = new TreeSet<>();
-        rankedClients23.add(new RankedClient(UUID_1, rankForTask23OnClient1));
-        rankedClients23.add(new RankedClient(UUID_2, rankForTask23OnClient2));
-        final SortedSet<RankedClient> rankedClients34 = new TreeSet<>();
-        rankedClients34.add(new RankedClient(UUID_1, rankForTask34OnClient1));
-        rankedClients34.add(new RankedClient(UUID_2, rankForTask34OnClient2));
-        return new TreeMap<>(
-            mkMap(
-                mkEntry(TASK_0_1, rankedClients01),
-                mkEntry(TASK_1_2, rankedClients12),
-                mkEntry(TASK_2_3, rankedClients23),
-                mkEntry(TASK_3_4, rankedClients34)
-            )
-        );
-    }
-
-    private static SortedMap<TaskId, SortedSet<RankedClient>> fourStatefulTasksToThreeRankedClients(final long rankForTask01OnClient1,
-                                                                                                    final long rankForTask01OnClient2,
-                                                                                                    final long rankForTask01OnClient3,
-                                                                                                    final long rankForTask12OnClient1,
-                                                                                                    final long rankForTask12OnClient2,
-                                                                                                    final long rankForTask12OnClient3,
-                                                                                                    final long rankForTask23OnClient1,
-                                                                                                    final long rankForTask23OnClient2,
-                                                                                                    final long rankForTask23OnClient3,
-                                                                                                    final long rankForTask34OnClient1,
-                                                                                                    final long rankForTask34OnClient2,
-                                                                                                    final long rankForTask34OnClient3) {
-        final SortedSet<RankedClient> rankedClients01 = new TreeSet<>();
-        rankedClients01.add(new RankedClient(UUID_1, rankForTask01OnClient1));
-        rankedClients01.add(new RankedClient(UUID_2, rankForTask01OnClient2));
-        rankedClients01.add(new RankedClient(UUID_3, rankForTask01OnClient3));
-        final SortedSet<RankedClient> rankedClients12 = new TreeSet<>();
-        rankedClients12.add(new RankedClient(UUID_1, rankForTask12OnClient1));
-        rankedClients12.add(new RankedClient(UUID_2, rankForTask12OnClient2));
-        rankedClients12.add(new RankedClient(UUID_3, rankForTask12OnClient3));
-        final SortedSet<RankedClient> rankedClients23 = new TreeSet<>();
-        rankedClients23.add(new RankedClient(UUID_1, rankForTask23OnClient1));
-        rankedClients23.add(new RankedClient(UUID_2, rankForTask23OnClient2));
-        rankedClients23.add(new RankedClient(UUID_3, rankForTask23OnClient3));
-        final SortedSet<RankedClient> rankedClients34 = new TreeSet<>();
-        rankedClients34.add(new RankedClient(UUID_1, rankForTask34OnClient1));
-        rankedClients34.add(new RankedClient(UUID_2, rankForTask34OnClient2));
-        rankedClients34.add(new RankedClient(UUID_3, rankForTask34OnClient3));
-        return new TreeMap<>(
-            mkMap(
-                mkEntry(TASK_0_1, rankedClients01),
-                mkEntry(TASK_1_2, rankedClients12),
-                mkEntry(TASK_2_3, rankedClients23),
-                mkEntry(TASK_3_4, rankedClients34)
-            )
-        );
-    }
-
-    private static Map<UUID, List<TaskId>> expectedAssignmentForTwoClients(final List<TaskId> assignedTasksForClient1,
-                                                                           final List<TaskId> assignedTasksForClient2) {
-        return mkMap(
-            mkEntry(UUID_1, assignedTasksForClient1),
-            mkEntry(UUID_2, assignedTasksForClient2)
-        );
-    }
-
-    private static Map<UUID, List<TaskId>> expectedAssignmentForThreeClients(final List<TaskId> assignedTasksForClient1,
-                                                                             final List<TaskId> assignedTasksForClient2,
-                                                                             final List<TaskId> assignedTasksForClient3) {
-        return mkMap(
-            mkEntry(UUID_1, assignedTasksForClient1),
-            mkEntry(UUID_2, assignedTasksForClient2),
-            mkEntry(UUID_3, assignedTasksForClient3)
-        );
-    }
-}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
index 1b00dc2..098c650 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
@@ -35,7 +35,7 @@ import static org.apache.kafka.streams.processor.internals.assignment.Assignment
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_3;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClientStatesMap;
 import static org.apache.kafka.streams.processor.internals.assignment.HighAvailabilityTaskAssignor.computeBalanceFactor;
 import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.replay;
@@ -163,50 +163,6 @@ public class HighAvailabilityTaskAssignorTest {
     }
 
     @Test
-    public void shouldReturnTrueIfTaskHasNoCaughtUpClients() {
-        client1 = EasyMock.createNiceMock(ClientState.class);
-        expect(client1.lagFor(TASK_0_0)).andReturn(500L);
-        replay(client1);
-        allTasks =  mkSet(TASK_0_0);
-        statefulTasks =  mkSet(TASK_0_0);
-        clientStates = singletonMap(UUID_1, client1);
-        createTaskAssignor();
-
-        assertTrue(taskAssignor.taskIsCaughtUpOnClient(TASK_0_0, UUID_1));
-    }
-
-    @Test
-    public void shouldReturnTrueIfTaskIsCaughtUpOnClient() {
-        client1 = EasyMock.createNiceMock(ClientState.class);
-        expect(client1.lagFor(TASK_0_0)).andReturn(0L);
-        allTasks =  mkSet(TASK_0_0);
-        statefulTasks =  mkSet(TASK_0_0);
-        clientStates = singletonMap(UUID_1, client1);
-        replay(client1);
-        createTaskAssignor();
-
-        assertTrue(taskAssignor.taskIsCaughtUpOnClient(TASK_0_0, UUID_1));
-    }
-
-    @Test
-    public void shouldReturnFalseIfTaskWasNotCaughtUpOnClientButCaughtUpClientsExist() {
-        client1 = EasyMock.createNiceMock(ClientState.class);
-        client2 = EasyMock.createNiceMock(ClientState.class);
-        expect(client1.lagFor(TASK_0_0)).andReturn(500L);
-        expect(client2.lagFor(TASK_0_0)).andReturn(0L);
-        replay(client1, client2);
-        allTasks =  mkSet(TASK_0_0);
-        statefulTasks =  mkSet(TASK_0_0);
-        clientStates = mkMap(
-            mkEntry(UUID_1, client1),
-            mkEntry(UUID_2, client2)
-        );
-        createTaskAssignor();
-
-        assertFalse(taskAssignor.taskIsCaughtUpOnClient(TASK_0_0, UUID_1));
-    }
-
-    @Test
     public void shouldComputeBalanceFactorAsDifferenceBetweenMostAndLeastLoadedClients() {
         client1 = EasyMock.createNiceMock(ClientState.class);
         client2 = EasyMock.createNiceMock(ClientState.class);
@@ -298,7 +254,7 @@ public class HighAvailabilityTaskAssignorTest {
         client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0));
         client2 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_1));
 
-        clientStates = getClientStatesWithTwoClients();
+        clientStates = getClientStatesMap(client1, client2);
         createTaskAssignor();
         taskAssignor.assign();
 
@@ -317,7 +273,7 @@ public class HighAvailabilityTaskAssignorTest {
         client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
         client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
 
-        clientStates = getClientStatesWithTwoClients();
+        clientStates = getClientStatesMap(client1, client2);
         createTaskAssignor();
         taskAssignor.assign();
 
@@ -333,7 +289,7 @@ public class HighAvailabilityTaskAssignorTest {
         client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1));
         client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
 
-        clientStates = getClientStatesWithTwoClients();
+        clientStates = getClientStatesMap(client1, client2);
         createTaskAssignor();
         taskAssignor.assign();
         
@@ -343,6 +299,8 @@ public class HighAvailabilityTaskAssignorTest {
         assertHasNoActiveTasks(client2);
     }
 
+
+
     @Test
     public void shouldNotAssignMoreThanMaxWarmupReplicas() {
         maxWarmupReplicas = 1;
@@ -351,7 +309,7 @@ public class HighAvailabilityTaskAssignorTest {
         client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
         client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
 
-        clientStates = getClientStatesWithTwoClients();
+        clientStates = getClientStatesMap(client1, client2);
         createTaskAssignor();
         taskAssignor.assign();
 
@@ -371,7 +329,7 @@ public class HighAvailabilityTaskAssignorTest {
         client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
         client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
 
-        clientStates = getClientStatesWithTwoClients();
+        clientStates = getClientStatesMap(client1, client2);
         createTaskAssignor();
         taskAssignor.assign();
 
@@ -388,7 +346,7 @@ public class HighAvailabilityTaskAssignorTest {
         statefulTasks = mkSet(TASK_0_0, TASK_0_1);
         client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1));
 
-        clientStates = getClientStatesWithOneClient();
+        clientStates = getClientStatesMap(client1);
         createTaskAssignor();
         taskAssignor.assign();
 
@@ -403,7 +361,7 @@ public class HighAvailabilityTaskAssignorTest {
         statefulTasks = mkSet(TASK_0_0, TASK_0_1);
         client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
 
-        clientStates = getClientStatesWithOneClient();
+        clientStates = getClientStatesMap(client1);
         createTaskAssignor();
         taskAssignor.assign();
 
@@ -421,7 +379,7 @@ public class HighAvailabilityTaskAssignorTest {
         client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
         client3 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
 
-        clientStates = getClientStatesWithThreeClients();
+        clientStates = getClientStatesMap(client1, client2, client3);
         createTaskAssignor();
         taskAssignor.assign();
 
@@ -441,7 +399,7 @@ public class HighAvailabilityTaskAssignorTest {
         client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
         client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
 
-        clientStates = getClientStatesWithTwoClients();
+        clientStates = getClientStatesMap(client1, client2);
         createTaskAssignor();
         taskAssignor.assign();
 
@@ -459,7 +417,7 @@ public class HighAvailabilityTaskAssignorTest {
         client2 = getMockClientWithPreviousCaughtUpTasks(allTasks).withCapacity(50);
         client3 = getMockClientWithPreviousCaughtUpTasks(allTasks).withCapacity(1);
 
-        clientStates = getClientStatesWithThreeClients();
+        clientStates = getClientStatesMap(client1, client2, client3);
         createTaskAssignor();
         taskAssignor.assign();
 
@@ -472,10 +430,10 @@ public class HighAvailabilityTaskAssignorTest {
     public void shouldReturnFalseIfPreviousAssignmentIsReused() {
         allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
         statefulTasks = new HashSet<>(allTasks);
-        client1 = getMockClientWithPreviousCaughtUpTasks(allTasks);
-        client2 = getMockClientWithPreviousCaughtUpTasks(allTasks);
+        client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_2));
+        client2 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_1, TASK_0_3));
 
-        clientStates = getClientStatesWithTwoClients();
+        clientStates = getClientStatesMap(client1, client2);
         createTaskAssignor();
         assertFalse(taskAssignor.assign());
 
@@ -490,7 +448,7 @@ public class HighAvailabilityTaskAssignorTest {
         client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
         client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
 
-        clientStates = getClientStatesWithTwoClients();
+        clientStates = getClientStatesMap(client1, client2);
         createTaskAssignor();
         assertFalse(taskAssignor.assign());
         assertHasNoStandbyTasks(client1, client2);
@@ -503,24 +461,12 @@ public class HighAvailabilityTaskAssignorTest {
         client1 = getMockClientWithPreviousCaughtUpTasks(allTasks);
         client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
 
-        clientStates = getClientStatesWithTwoClients();
+        clientStates = getClientStatesMap(client1, client2);
         createTaskAssignor();
         assertTrue(taskAssignor.assign());
         assertThat(client2.standbyTaskCount(), equalTo(1));
     }
 
-    private Map<UUID, ClientState> getClientStatesWithOneClient() {
-        return singletonMap(UUID_1, client1);
-    }
-
-    private Map<UUID, ClientState> getClientStatesWithTwoClients() {
-        return mkMap(mkEntry(UUID_1, client1), mkEntry(UUID_2, client2));
-    }
-
-    private Map<UUID, ClientState> getClientStatesWithThreeClients() {
-        return mkMap(mkEntry(UUID_1, client1), mkEntry(UUID_2, client2), mkEntry(UUID_3, client3));
-    }
-
     private static void assertHasNoActiveTasks(final ClientState... clients) {
         for (final ClientState client : clients) {
             assertTrue(client.activeTasks().isEmpty());
@@ -549,7 +495,7 @@ public class HighAvailabilityTaskAssignorTest {
         client.addPreviousActiveTasks(statefulActiveTasks);
         return client;
     }
-    
+
     static class MockClientState extends ClientState {
         private final Map<TaskId, Long> taskLagTotals;
 
@@ -558,7 +504,7 @@ public class HighAvailabilityTaskAssignorTest {
             super(capacity);
             this.taskLagTotals = taskLagTotals;
         }
-            
+
         @Override
         long lagFor(final TaskId task) {
             final Long totalLag = taskLagTotals.get(task);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
index c47fc37..7be6ee7 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
@@ -19,7 +19,6 @@ package org.apache.kafka.streams.processor.internals.assignment;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
 import org.hamcrest.MatcherAssert;
-import org.junit.Ignore;
 import org.junit.Test;
 
 import java.util.Map;
@@ -266,7 +265,6 @@ public class TaskAssignorConvergenceTest {
         verifyValidAssignment(numStandbyReplicas, harness);
     }
 
-    @Ignore // Adding this failing test before adding the code that fixes it
     @Test
     public void droppingNodesShouldConverge() {
         final int numStatelessTasks = 15;
@@ -290,7 +288,6 @@ public class TaskAssignorConvergenceTest {
         verifyValidAssignment(numStandbyReplicas, harness);
     }
 
-    @Ignore // Adding this failing test before adding the code that fixes it
     @Test
     public void randomClusterPerturbationsShouldConverge() {
         // do as many tests as we can in 10 seconds
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java
index ebc39c6..088f123 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java
@@ -17,13 +17,14 @@
 package org.apache.kafka.streams.processor.internals.assignment;
 
 import static java.util.Arrays.asList;
-import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.singleton;
+import static java.util.Collections.singletonList;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASKS;
+import static org.apache.kafka.common.utils.Utils.mkSortedSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASK_LIST;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
@@ -33,287 +34,258 @@ import static org.apache.kafka.streams.processor.internals.assignment.Assignment
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3;
-import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.getMovements;
-import static org.easymock.EasyMock.expect;
-import static org.easymock.EasyMock.replay;
-import static org.hamcrest.CoreMatchers.equalTo;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClientStatesMap;
+import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.assignTaskMovements;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertThrows;
+import static org.hamcrest.Matchers.equalTo;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.assertFalse;
 
-import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.LinkedList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Queue;
 import java.util.Set;
 import java.util.SortedSet;
-import java.util.TreeSet;
 import java.util.UUID;
 import java.util.stream.Collectors;
 import org.apache.kafka.streams.processor.TaskId;
-import org.easymock.EasyMock;
 import org.junit.Test;
 
 public class TaskMovementTest {
+    private final ClientState client1 = new ClientState(1);
+    private final ClientState client2 = new ClientState(1);
+    private final ClientState client3 = new ClientState(1);
 
-    private final ClientState client1 = EasyMock.createMock(ClientState.class);
-    private final ClientState client2 = EasyMock.createMock(ClientState.class);
-    private final ClientState client3 = EasyMock.createMock(ClientState.class);
+    private final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
+
+    private final Map<UUID, List<TaskId>> emptyWarmupAssignment = mkMap(
+        mkEntry(UUID_1, EMPTY_TASK_LIST),
+        mkEntry(UUID_2, EMPTY_TASK_LIST),
+        mkEntry(UUID_3, EMPTY_TASK_LIST)
+    );
 
     @Test
-    public void shouldGetMovementsFromStateConstrainedToBalancedAssignment() {
+    public void shouldAssignTasksToClientsAndReturnFalseWhenAllClientsCaughtUp() {
         final int maxWarmupReplicas = Integer.MAX_VALUE;
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
 
-        final Map<UUID, List<TaskId>> stateConstrainedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_1_2)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1, TASK_1_0)),
-            mkEntry(UUID_3, mkTaskList(TASK_0_2, TASK_1_1))
-        );
         final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_1_0)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1, TASK_1_1)),
-            mkEntry(UUID_3, mkTaskList(TASK_0_2, TASK_1_2))
-        );
-        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = getMapWithNoCaughtUpClients(
-            mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2)
+            mkEntry(UUID_1, asList(TASK_0_0, TASK_1_0)),
+            mkEntry(UUID_2, asList(TASK_0_1, TASK_1_1)),
+            mkEntry(UUID_3, asList(TASK_0_2, TASK_1_2))
         );
 
-        expectNoPreviousStandbys(client1, client2, client3);
-
-        final Queue<TaskMovement> expectedMovements = new LinkedList<>();
-        expectedMovements.add(new TaskMovement(TASK_1_2, UUID_1, UUID_3));
-        expectedMovements.add(new TaskMovement(TASK_1_0, UUID_2, UUID_1));
-        expectedMovements.add(new TaskMovement(TASK_1_1, UUID_3, UUID_2));
-
-        assertThat(
-            getMovements(
-                stateConstrainedAssignment,
+        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
+        for (final TaskId task : allTasks) {
+            tasksToCaughtUpClients.put(task, mkSortedSet(UUID_1, UUID_2, UUID_3));
+        }
+
+        assertFalse(
+            assignTaskMovements(
                 balancedAssignment,
                 tasksToCaughtUpClients,
-                getClientStatesWithThreeClients(),
+                clientStates,
                 getMapWithNumStandbys(allTasks, 1),
-                maxWarmupReplicas),
-            equalTo(expectedMovements));
+                maxWarmupReplicas)
+        );
+
+        verifyClientStateAssignments(balancedAssignment, emptyWarmupAssignment);
     }
 
     @Test
-    public void shouldImmediatelyMoveTasksWithCaughtUpDestinationClients() {
-        final int maxWarmupReplicas = Integer.MAX_VALUE;
+    public void shouldAssignAllTasksToClientsAndReturnFalseIfNoClientsAreCaughtUp() {
+        final int maxWarmupReplicas = 2;
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
 
-        final Map<UUID, List<TaskId>> stateConstrainedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_1_2)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1, TASK_1_0)),
-            mkEntry(UUID_3, mkTaskList(TASK_0_2, TASK_1_1))
-        );
         final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_1_0)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1, TASK_1_1)),
-            mkEntry(UUID_3, mkTaskList(TASK_0_2, TASK_1_2))
+            mkEntry(UUID_1, asList(TASK_0_0, TASK_1_0)),
+            mkEntry(UUID_2, asList(TASK_0_1, TASK_1_1)),
+            mkEntry(UUID_3, asList(TASK_0_2, TASK_1_2))
         );
 
-        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = getMapWithNoCaughtUpClients(allTasks);
-        tasksToCaughtUpClients.get(TASK_1_0).add(UUID_1);
-
-        expectNoPreviousStandbys(client1, client2, client3);
-
-        final Queue<TaskMovement> expectedMovements = new LinkedList<>();
-        expectedMovements.add(new TaskMovement(TASK_1_2, UUID_1, UUID_3));
-        expectedMovements.add(new TaskMovement(TASK_1_1, UUID_3, UUID_2));
-
-
-        assertThat(
-            getMovements(
-                stateConstrainedAssignment,
+        assertFalse(
+            assignTaskMovements(
                 balancedAssignment,
-                tasksToCaughtUpClients,
-                getClientStatesWithThreeClients(),
+                emptyMap(),
+                clientStates,
                 getMapWithNumStandbys(allTasks, 1),
-                maxWarmupReplicas),
-            equalTo(expectedMovements));
-
-        assertFalse(stateConstrainedAssignment.get(UUID_2).contains(TASK_1_0));
-        assertTrue(stateConstrainedAssignment.get(UUID_1).contains(TASK_1_0));
+                maxWarmupReplicas)
+        );
+        verifyClientStateAssignments(balancedAssignment, emptyWarmupAssignment);
     }
 
     @Test
-    public void shouldOnlyGetUpToMaxWarmupReplicaMovements() {
-        final int maxWarmupReplicas = 1;
-        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
+    public void shouldMoveTasksToCaughtUpClientsAndAssignWarmupReplicasInTheirPlace() {
+        final int maxWarmupReplicas = Integer.MAX_VALUE;
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
 
-        final Map<UUID, List<TaskId>> stateConstrainedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_1_2)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1, TASK_1_0)),
-            mkEntry(UUID_3, mkTaskList(TASK_0_2, TASK_1_1))
-        );
         final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_1_0)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1, TASK_1_1)),
-            mkEntry(UUID_3, mkTaskList(TASK_0_2, TASK_1_2))
+            mkEntry(UUID_1, singletonList(TASK_0_0)),
+            mkEntry(UUID_2, singletonList(TASK_0_1)),
+            mkEntry(UUID_3, singletonList(TASK_0_2))
         );
-        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = getMapWithNoCaughtUpClients(allTasks);
 
-        expectNoPreviousStandbys(client1, client2, client3);
+        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
+        tasksToCaughtUpClients.put(TASK_0_0, mkSortedSet(UUID_1));
+        tasksToCaughtUpClients.put(TASK_0_1, mkSortedSet(UUID_3));
+        tasksToCaughtUpClients.put(TASK_0_2, mkSortedSet(UUID_2));
+
+        final Map<UUID, List<TaskId>> expectedActiveTaskAssignment = mkMap(
+            mkEntry(UUID_1, singletonList(TASK_0_0)),
+            mkEntry(UUID_2, singletonList(TASK_0_2)),
+            mkEntry(UUID_3, singletonList(TASK_0_1))
+        );
 
-        final Queue<TaskMovement> expectedMovements = new LinkedList<>();
-        expectedMovements.add(new TaskMovement(TASK_1_2, UUID_1, UUID_3));
+        final Map<UUID, List<TaskId>> expectedWarmupTaskAssignment = mkMap(
+            mkEntry(UUID_1, EMPTY_TASK_LIST),
+            mkEntry(UUID_2, singletonList(TASK_0_1)),
+            mkEntry(UUID_3, singletonList(TASK_0_2))
+        );
 
-        assertThat(
-            getMovements(
-                stateConstrainedAssignment,
+        assertTrue(
+            assignTaskMovements(
                 balancedAssignment,
                 tasksToCaughtUpClients,
-                getClientStatesWithThreeClients(),
+                clientStates,
                 getMapWithNumStandbys(allTasks, 1),
-                maxWarmupReplicas),
-            equalTo(expectedMovements));
+                maxWarmupReplicas)
+        );
+        verifyClientStateAssignments(expectedActiveTaskAssignment, expectedWarmupTaskAssignment);
     }
 
     @Test
-    public void shouldNotCountPreviousStandbyTasksTowardsMaxWarmupReplicas() {
-        final int maxWarmupReplicas = 1;
+    public void shouldProduceBalancedAndStateConstrainedAssignment() {
+        final int maxWarmupReplicas = Integer.MAX_VALUE;
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
 
-        final Map<UUID, List<TaskId>> stateConstrainedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_1_2)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1, TASK_1_0)),
-            mkEntry(UUID_3, mkTaskList(TASK_0_2, TASK_1_1))
-        );
         final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_1_0)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1, TASK_1_1)),
-            mkEntry(UUID_3, mkTaskList(TASK_0_2, TASK_1_2))
+            mkEntry(UUID_1, asList(TASK_0_0, TASK_1_0)),
+            mkEntry(UUID_2, asList(TASK_0_1, TASK_1_1)),
+            mkEntry(UUID_3, asList(TASK_0_2, TASK_1_2))
         );
 
-        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = getMapWithNoCaughtUpClients(allTasks);
+        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
+        tasksToCaughtUpClients.put(TASK_0_0, mkSortedSet(UUID_2, UUID_3));  // needs to be warmed up
 
-        expectNoPreviousStandbys(client1, client2);
-        expect(client3.prevStandbyTasks()).andStubReturn(singleton(TASK_1_2));
-        replay(client3);
+        tasksToCaughtUpClients.put(TASK_0_1, mkSortedSet(UUID_1, UUID_3));  // needs to be warmed up
 
-        final Queue<TaskMovement> expectedMovements = new LinkedList<>();
-        expectedMovements.add(new TaskMovement(TASK_1_2, UUID_1, UUID_3));
-        expectedMovements.add(new TaskMovement(TASK_1_0, UUID_2, UUID_1));
+        tasksToCaughtUpClients.put(TASK_0_2, mkSortedSet(UUID_2));          // needs to be warmed up
 
-        assertThat(
-            getMovements(
-                stateConstrainedAssignment,
-                balancedAssignment,
-                tasksToCaughtUpClients,
-                getClientStatesWithThreeClients(),
-                getMapWithNumStandbys(allTasks, 1),
-                maxWarmupReplicas),
-            equalTo(expectedMovements));
-    }
+        tasksToCaughtUpClients.put(TASK_1_1, mkSortedSet(UUID_1));  // needs to be warmed up
 
-    @Test
-    public void shouldReturnEmptyMovementsWhenPassedEmptyTaskAssignments() {
-        final int maxWarmupReplicas = 2;
-        final Map<UUID, List<TaskId>> stateConstrainedAssignment = mkMap(
-            mkEntry(UUID_1, emptyList()),
-            mkEntry(UUID_2, emptyList())
+        final Map<UUID, List<TaskId>> expectedActiveTaskAssignment = mkMap(
+            mkEntry(UUID_1, asList(TASK_1_0, TASK_1_1)),
+            mkEntry(UUID_2, asList(TASK_0_2, TASK_0_0)),
+            mkEntry(UUID_3, asList(TASK_0_1, TASK_1_2))
         );
-        final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
-            mkEntry(UUID_1, emptyList()),
-            mkEntry(UUID_2, emptyList())
+
+        final Map<UUID, List<TaskId>> expectedWarmupTaskAssignment = mkMap(
+            mkEntry(UUID_1, singletonList(TASK_0_0)),
+            mkEntry(UUID_2, asList(TASK_0_1, TASK_1_1)),
+            mkEntry(UUID_3, singletonList(TASK_0_2))
         );
 
         assertTrue(
-            getMovements(
-                stateConstrainedAssignment,
+            assignTaskMovements(
                 balancedAssignment,
-                emptyMap(),
-                getClientStatesWithTwoClients(),
-                emptyMap(),
-                maxWarmupReplicas
-            ).isEmpty());
+                tasksToCaughtUpClients,
+                clientStates,
+                getMapWithNumStandbys(allTasks, 1),
+                maxWarmupReplicas)
+        );
+        verifyClientStateAssignments(expectedActiveTaskAssignment, expectedWarmupTaskAssignment);
     }
 
     @Test
-    public void shouldReturnEmptyMovementsWhenPassedIdenticalTaskAssignments() {
-        final int maxWarmupReplicas = 2;
-        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_1_0, TASK_1_1);
+    public void shouldOnlyGetUpToMaxWarmupReplicasAndReturnTrue() {
+        final int maxWarmupReplicas = 1;
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
 
-        final Map<UUID, List<TaskId>> stateConstrainedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_1_0)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1, TASK_1_1))
-        );
         final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_1_0)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1, TASK_1_1))
+            mkEntry(UUID_1, singletonList(TASK_0_0)),
+            mkEntry(UUID_2, singletonList(TASK_0_1)),
+            mkEntry(UUID_3, singletonList(TASK_0_2))
+        );
+
+        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
+        tasksToCaughtUpClients.put(TASK_0_0, mkSortedSet(UUID_1));
+        tasksToCaughtUpClients.put(TASK_0_1, mkSortedSet(UUID_3));
+        tasksToCaughtUpClients.put(TASK_0_2, mkSortedSet(UUID_2));
+
+        final Map<UUID, List<TaskId>> expectedActiveTaskAssignment = mkMap(
+            mkEntry(UUID_1, singletonList(TASK_0_0)),
+            mkEntry(UUID_2, singletonList(TASK_0_2)),
+            mkEntry(UUID_3, singletonList(TASK_0_1))
         );
 
+        final Map<UUID, List<TaskId>> expectedWarmupTaskAssignment = mkMap(
+            mkEntry(UUID_1, EMPTY_TASK_LIST),
+            mkEntry(UUID_2, singletonList(TASK_0_1)),
+            mkEntry(UUID_3, EMPTY_TASK_LIST)
+        );
         assertTrue(
-            getMovements(
-                stateConstrainedAssignment,
-                balancedAssignment,
-                getMapWithNoCaughtUpClients(allTasks),
-                getClientStatesWithTwoClients(),
-                getMapWithNumStandbys(allTasks, 1),
-                maxWarmupReplicas
-            ).isEmpty());
+            assignTaskMovements(
+                balancedAssignment,
+                tasksToCaughtUpClients,
+                clientStates,
+                getMapWithNumStandbys(allTasks, 1),
+                maxWarmupReplicas)
+        );
+
+        verifyClientStateAssignments(expectedActiveTaskAssignment, expectedWarmupTaskAssignment);
     }
 
     @Test
-    public void shouldThrowIllegalStateExceptionIfAssignmentsAreOfDifferentSize() {
-        final int maxWarmupReplicas = 2;
-        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_1_0, TASK_1_1);
+    public void shouldNotCountPreviousStandbyTasksTowardsMaxWarmupReplicas() {
+        final int maxWarmupReplicas = 1;
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
 
-        final Map<UUID, List<TaskId>> stateConstrainedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_0_1))
-        );
         final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_1_0)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1, TASK_1_1))
+            mkEntry(UUID_1, singletonList(TASK_0_0)),
+            mkEntry(UUID_2, singletonList(TASK_0_1)),
+            mkEntry(UUID_3, singletonList(TASK_0_2))
         );
 
-        assertThrows(
-            IllegalStateException.class,
-            () -> getMovements(
-                stateConstrainedAssignment,
-                balancedAssignment,
-                getMapWithNoCaughtUpClients(allTasks),
-                getClientStatesWithTwoClients(),
-                getMapWithNumStandbys(allTasks, 1),
-                maxWarmupReplicas)
+        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
+        tasksToCaughtUpClients.put(TASK_0_0, mkSortedSet(UUID_1));
+        tasksToCaughtUpClients.put(TASK_0_1, mkSortedSet(UUID_3));
+        tasksToCaughtUpClients.put(TASK_0_2, mkSortedSet(UUID_2));
+
+        final Map<UUID, List<TaskId>> expectedActiveTaskAssignment = mkMap(
+            mkEntry(UUID_1, singletonList(TASK_0_0)),
+            mkEntry(UUID_2, singletonList(TASK_0_2)),
+            mkEntry(UUID_3, singletonList(TASK_0_1))
         );
-    }
-
-    @Test
-    public void shouldThrowIllegalStateExceptionWhenTaskHasNoDestinationClient() {
-        final int maxWarmupReplicas = 2;
-        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_1_0);
 
-        final Map<UUID, List<TaskId>> stateConstrainedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0, TASK_0_1)),
-            mkEntry(UUID_2, mkTaskList(TASK_1_0))
+        final Map<UUID, List<TaskId>> expectedWarmupTaskAssignment = mkMap(
+            mkEntry(UUID_1, EMPTY_TASK_LIST),
+            mkEntry(UUID_2, singletonList(TASK_0_1)),
+            mkEntry(UUID_3, singletonList(TASK_0_2))
         );
-        final Map<UUID, List<TaskId>> balancedAssignment = mkMap(
-            mkEntry(UUID_1, mkTaskList(TASK_0_0)),
-            mkEntry(UUID_2, mkTaskList(TASK_0_1))
-        );
-        expectNoPreviousStandbys(client1, client2);
 
-        assertThrows(
-            IllegalStateException.class,
-            () -> getMovements(
-                stateConstrainedAssignment,
+        client3.addPreviousStandbyTasks(singleton(TASK_0_2));
+
+        assertTrue(
+            assignTaskMovements(
                 balancedAssignment,
-                getMapWithNoCaughtUpClients(allTasks),
-                getClientStatesWithTwoClients(),
+                tasksToCaughtUpClients,
+                clientStates,
                 getMapWithNumStandbys(allTasks, 1),
                 maxWarmupReplicas)
         );
+
+        verifyClientStateAssignments(expectedActiveTaskAssignment, expectedWarmupTaskAssignment);
     }
 
-    private static void expectNoPreviousStandbys(final ClientState... states) {
-        for (final ClientState state : states) {
-            expect(state.prevStandbyTasks()).andStubReturn(EMPTY_TASKS);
-            replay(state);
+    private void verifyClientStateAssignments(final Map<UUID, List<TaskId>> expectedActiveTaskAssignment,
+                                              final Map<UUID, List<TaskId>> expectedStandbyTaskAssignment) {
+        for (final Map.Entry<UUID, ClientState> clientEntry : clientStates.entrySet()) {
+            final UUID client = clientEntry.getKey();
+            final ClientState state = clientEntry.getValue();
+
+            assertThat(state.activeTasks(), equalTo(new HashSet<>(expectedActiveTaskAssignment.get(client))));
+            assertThat(state.standbyTasks(), equalTo(new HashSet<>(expectedStandbyTaskAssignment.get(client))));
         }
     }
 
@@ -321,30 +293,4 @@ public class TaskMovementTest {
         return tasks.stream().collect(Collectors.toMap(task -> task, t -> numStandbys));
     }
 
-    private Map<UUID, ClientState> getClientStatesWithTwoClients() {
-        return mkMap(
-            mkEntry(UUID_1, client1),
-            mkEntry(UUID_2, client2)
-        );
-    }
-
-    private Map<UUID, ClientState> getClientStatesWithThreeClients() {
-        return mkMap(
-            mkEntry(UUID_1, client1),
-            mkEntry(UUID_2, client2),
-            mkEntry(UUID_3, client3)
-        );
-    }
-
-    private static List<TaskId> mkTaskList(final TaskId... tasks) {
-        return new ArrayList<>(asList(tasks));
-    }
-
-    private static Map<TaskId, SortedSet<UUID>> getMapWithNoCaughtUpClients(final Set<TaskId> tasks) {
-        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
-        for (final TaskId task : tasks) {
-            tasksToCaughtUpClients.put(task, new TreeSet<>());
-        }
-        return tasksToCaughtUpClients;
-    }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueueTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueueTest.java
new file mode 100644
index 0000000..aff6153
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ValidClientsByTaskLoadQueueTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClientStatesMap;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertNull;
+
+import java.util.Map;
+import java.util.UUID;
+import java.util.function.BiFunction;
+import org.apache.kafka.streams.processor.TaskId;
+import org.junit.Test;
+
+public class ValidClientsByTaskLoadQueueTest {
+
+    private static final TaskId DUMMY_TASK = new TaskId(0, 0);
+
+    private final ClientState client1 = new ClientState(1);
+    private final ClientState client2 = new ClientState(1);
+    private final ClientState client3 = new ClientState(1);
+
+    private final BiFunction<UUID, TaskId, Boolean> alwaysTrue = (client, task) -> true;
+    private final BiFunction<UUID, TaskId, Boolean> alwaysFalse = (client, task) -> false;
+
+    private ValidClientsByTaskLoadQueue queue;
+
+    private Map<UUID, ClientState> clientStates;
+
+    @Test
+    public void shouldReturnOnlyClient() {
+        clientStates = getClientStatesMap(client1);
+        queue = new ValidClientsByTaskLoadQueue(clientStates, alwaysTrue);
+        queue.offerAll(clientStates.keySet());
+
+        assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1));
+    }
+
+    @Test
+    public void shouldReturnNull() {
+        clientStates = getClientStatesMap(client1);
+        queue = new ValidClientsByTaskLoadQueue(clientStates, alwaysFalse);
+        queue.offerAll(clientStates.keySet());
+
+        assertNull(queue.poll(DUMMY_TASK));
+    }
+
+    @Test
+    public void shouldReturnLeastLoadedClient() {
+        clientStates = getClientStatesMap(client1, client2, client3);
+        queue = new ValidClientsByTaskLoadQueue(clientStates, alwaysTrue);
+
+        client1.assignActive(TASK_0_0);
+        client2.assignActiveTasks(asList(TASK_0_1, TASK_1_1));
+        client3.assignActiveTasks(asList(TASK_0_2, TASK_1_2, TASK_2_2));
+
+        queue.offerAll(clientStates.keySet());
+
+        assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1));
+        assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_2));
+        assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_3));
+    }
+
+    @Test
+    public void shouldNotRetainDuplicates() {
+        clientStates = getClientStatesMap(client1);
+        queue = new ValidClientsByTaskLoadQueue(clientStates, alwaysTrue);
+
+        queue.offerAll(clientStates.keySet());
+        queue.offer(UUID_1);
+
+        assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1));
+        assertNull(queue.poll(DUMMY_TASK));
+    }
+
+    @Test
+    public void shouldOnlyReturnValidClients() {
+        clientStates = getClientStatesMap(client1, client2);
+        queue = new ValidClientsByTaskLoadQueue(clientStates, (client, task) -> client.equals(UUID_1));
+
+        queue.offerAll(clientStates.keySet());
+
+        assertThat(queue.poll(DUMMY_TASK, 2), equalTo(singletonList(UUID_1)));
+    }
+
+    @Test
+    public void shouldReturnUpToNumClients() {
+        clientStates = getClientStatesMap(client1, client2, client3);
+        queue = new ValidClientsByTaskLoadQueue(clientStates, alwaysTrue);
+
+        client1.assignActive(TASK_0_0);
+        client2.assignActiveTasks(asList(TASK_0_1, TASK_1_1));
+        client3.assignActiveTasks(asList(TASK_0_2, TASK_1_2, TASK_2_2));
+
+        queue.offerAll(clientStates.keySet());
+
+        assertThat(queue.poll(DUMMY_TASK, 2), equalTo(asList(UUID_1, UUID_2)));
+    }
+}