Posted to commits@kafka.apache.org by ca...@apache.org on 2022/03/28 14:59:12 UTC

[kafka] branch 3.2 updated: KAFKA-13600: Kafka Streams - Fall back to most caught up client if no caught up clients exist (#11760)

This is an automated email from the ASF dual-hosted git repository.

cadonna pushed a commit to branch 3.2
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.2 by this push:
     new ac11383  KAFKA-13600: Kafka Streams - Fall back to most caught up client if no caught up clients exist (#11760)
ac11383 is described below

commit ac11383365771895d8ca4bce8fb8f0882621a38f
Author: Tim Patterson <tp...@sailthru.com>
AuthorDate: Tue Mar 29 03:48:39 2022 +1300

    KAFKA-13600: Kafka Streams - Fall back to most caught up client if no caught up clients exist (#11760)
    
    The task assignor is modified to consider the Streams client with the most caught-up states if no caught-up Streams client exists, i.e., no client whose lag for those states is less than the acceptable recovery lag.
    
    Adds a unit test for the case of task assignment where no caught-up nodes exist.
    Existing unit and integration tests verify that no other behaviour has been changed.
    
    Co-authored-by: Bruno Cadonna <ca...@apache.org>
    
    Reviewer: Bruno Cadonna <ca...@apache.org>
---
 .../assignment/HighAvailabilityTaskAssignor.java   |  17 ++
 .../internals/assignment/TaskMovement.java         | 178 +++++++++++++++------
 .../HighAvailabilityTaskAssignorTest.java          |  31 ++++
 .../internals/assignment/TaskMovementTest.java     | 144 ++++++++++++++---
 4 files changed, 303 insertions(+), 67 deletions(-)
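
For illustration, a minimal sketch of the fallback rule described in the commit message.
This is not code from the commit; the helper name `candidates` and the strict
"lag below the acceptable recovery lag" reading of "caught up" are assumptions.

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map;
    import java.util.UUID;

    public class FallbackSketch {
        // Candidate owners for one task: every caught-up client if any exist,
        // otherwise the single client with the smallest lag for the task.
        static List<UUID> candidates(final Map<UUID, Long> lagByClient,
                                     final long acceptableRecoveryLag) {
            final List<UUID> caughtUp = new ArrayList<>();
            for (final Map.Entry<UUID, Long> e : lagByClient.entrySet()) {
                if (e.getValue() < acceptableRecoveryLag) {
                    caughtUp.add(e.getKey());
                }
            }
            if (!caughtUp.isEmpty()) {
                return caughtUp;
            }
            final UUID mostCaughtUp = lagByClient.entrySet().stream()
                .min(Map.Entry.comparingByValue())
                .get()
                .getKey();
            return List.of(mostCaughtUp);
        }

        public static void main(final String[] args) {
            final UUID u1 = new UUID(0, 1);
            final UUID u2 = new UUID(0, 2);
            // No client is caught up, so the least-lagging client u2 is chosen.
            System.out.println(candidates(Map.of(u1, 5000L, u2, 1000L), 100L));
        }
    }

Before this change, an empty caught-up set made every client acceptable for a task; with
this change the assignor still prefers genuinely caught-up clients and only then falls back.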

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
index 7111ae2..c54199a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
@@ -22,6 +22,7 @@ import org.apache.kafka.streams.processor.internals.assignment.AssignorConfigura
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
@@ -68,6 +69,8 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
             configs.acceptableRecoveryLag
         );
 
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = tasksToClientByLag(statefulTasks, clientStates);
+
         // We temporarily need to know which standby tasks were intended as warmups
         // for active tasks, so that we don't move them (again) when we plan standby
         // task movements. We can then immediately treat warmups exactly the same as
@@ -77,6 +80,7 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
 
         final int neededActiveTaskMovements = assignActiveTaskMovements(
             tasksToCaughtUpClients,
+            tasksToClientByLag,
             clientStates,
             warmups,
             remainingWarmupReplicas
@@ -84,6 +88,7 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
 
         final int neededStandbyTaskMovements = assignStandbyTaskMovements(
             tasksToCaughtUpClients,
+            tasksToClientByLag,
             clientStates,
             remainingWarmupReplicas,
             warmups
@@ -238,6 +243,18 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
         return taskToCaughtUpClients;
     }
 
+    private static Map<TaskId, SortedSet<UUID>> tasksToClientByLag(final Set<TaskId> statefulTasks,
+                                                                   final Map<UUID, ClientState> clientStates) {
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = new HashMap<>();
+        for (final TaskId task : statefulTasks) {
+            final SortedSet<UUID> clientLag = new TreeSet<>(Comparator.<UUID>comparingLong(a ->
+                    clientStates.get(a).lagFor(task)).thenComparing(a -> a));
+            clientLag.addAll(clientStates.keySet());
+            tasksToClientByLag.put(task, clientLag);
+        }
+        return tasksToClientByLag;
+    }
+
     private static boolean unbounded(final long acceptableRecoveryLag) {
         return acceptableRecoveryLag == Long.MAX_VALUE;
     }
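
The ordering built by tasksToClientByLag above can be seen in isolation in the following
self-contained sketch (assumed plain Java, not part of the commit). The `.thenComparing(a -> a)`
tie-break matters: a TreeSet treats comparator-equal elements as duplicates, so without it
clients with identical lag would collapse into a single entry.

    import java.util.Comparator;
    import java.util.Map;
    import java.util.SortedSet;
    import java.util.TreeSet;
    import java.util.UUID;

    public class LagOrderingSketch {
        public static void main(final String[] args) {
            final UUID a = new UUID(0, 1);
            final UUID b = new UUID(0, 2);
            final UUID c = new UUID(0, 3);
            // Hypothetical lags of three clients for a single task.
            final Map<UUID, Long> lagForTask = Map.of(a, 500L, b, 500L, c, 100L);

            // Primary key: the client's lag for the task; tie-break: the UUID itself.
            final SortedSet<UUID> byLag = new TreeSet<>(
                Comparator.<UUID>comparingLong(lagForTask::get).thenComparing(u -> u)
            );
            byLag.addAll(lagForTask.keySet());

            System.out.println(byLag.first().equals(c)); // true: c has the least lag
            System.out.println(byLag.size());            // 3: equal-lag a and b both kept
        }
    }
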
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
index cbfa3da..38e6427 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java
@@ -29,6 +29,7 @@ import java.util.TreeSet;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiFunction;
+import java.util.function.Function;
 
 import static java.util.Arrays.asList;
 import static java.util.Objects.requireNonNull;
@@ -42,10 +43,6 @@ final class TaskMovement {
         this.task = task;
         this.destination = destination;
         this.caughtUpClients = caughtUpClients;
-
-        if (caughtUpClients == null || caughtUpClients.isEmpty()) {
-            throw new IllegalStateException("Should not attempt to move a task if no caught up clients exist");
-        }
     }
 
     private TaskId task() {
@@ -56,25 +53,34 @@ final class TaskMovement {
         return caughtUpClients.size();
     }
 
-    private static boolean taskIsNotCaughtUpOnClientAndOtherCaughtUpClientsExist(final TaskId task,
-                                                                                 final UUID client,
-                                                                                 final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients) {
-        return !taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients);
+    private static boolean taskIsNotCaughtUpOnClientAndOtherMoreCaughtUpClientsExist(final TaskId task,
+                                                                                     final UUID client,
+                                                                                     final Map<UUID, ClientState> clientStates,
+                                                                                     final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
+                                                                                     final Map<TaskId, SortedSet<UUID>> tasksToClientByLag) {
+        final SortedSet<UUID> taskClients = requireNonNull(tasksToClientByLag.get(task), "uninitialized set");
+        if (taskIsCaughtUpOnClient(task, client, tasksToCaughtUpClients)) {
+            return false;
+        }
+        final long mostCaughtUpLag = clientStates.get(taskClients.first()).lagFor(task);
+        final long clientLag = clientStates.get(client).lagFor(task);
+        return mostCaughtUpLag < clientLag;
     }
 
-    private static boolean taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(final TaskId task,
-                                                                          final UUID client,
-                                                                          final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients) {
+    private static boolean taskIsCaughtUpOnClient(final TaskId task,
+                                                  final UUID client,
+                                                  final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients) {
         final Set<UUID> caughtUpClients = requireNonNull(tasksToCaughtUpClients.get(task), "uninitialized set");
-        return caughtUpClients.isEmpty() || caughtUpClients.contains(client);
+        return caughtUpClients.contains(client);
     }
 
     static int assignActiveTaskMovements(final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
+                                         final Map<TaskId, SortedSet<UUID>> tasksToClientByLag,
                                          final Map<UUID, ClientState> clientStates,
                                          final Map<UUID, Set<TaskId>> warmups,
                                          final AtomicInteger remainingWarmupReplicas) {
         final BiFunction<UUID, TaskId, Boolean> caughtUpPredicate =
-            (client, task) -> taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients);
+            (client, task) -> taskIsCaughtUpOnClient(task, client, tasksToCaughtUpClients);
 
         final ConstrainedPrioritySet caughtUpClientsByTaskLoad = new ConstrainedPrioritySet(
             caughtUpPredicate,
@@ -89,10 +95,10 @@ final class TaskMovement {
             final UUID client = clientStateEntry.getKey();
             final ClientState state = clientStateEntry.getValue();
             for (final TaskId task : state.activeTasks()) {
-                // if the desired client is not caught up, and there is another client that _is_ caught up, then
-                // we schedule a movement, so we can move the active task to the caught-up client. We'll try to
+                // if the desired client is not caught up, and there is another client that _is_ more caught up, then
+                // we schedule a movement, so we can move the active task to a more caught-up client. We'll try to
                 // assign a warm-up to the desired client so that we can move it later on.
-                if (taskIsNotCaughtUpOnClientAndOtherCaughtUpClientsExist(task, client, tasksToCaughtUpClients)) {
+                if (taskIsNotCaughtUpOnClientAndOtherMoreCaughtUpClientsExist(task, client, clientStates, tasksToCaughtUpClients, tasksToClientByLag)) {
                     taskMovements.add(new TaskMovement(task, client, tasksToCaughtUpClients.get(task)));
                 }
             }
@@ -102,33 +108,14 @@ final class TaskMovement {
         final int movementsNeeded = taskMovements.size();
 
         for (final TaskMovement movement : taskMovements) {
-            final UUID standbySourceClient = caughtUpClientsByTaskLoad.poll(
-                movement.task,
-                c -> clientStates.get(c).hasStandbyTask(movement.task)
-            );
-            if (standbySourceClient == null) {
-                // there's not a caught-up standby available to take over the task, so we'll schedule a warmup instead
-                final UUID sourceClient = requireNonNull(
-                    caughtUpClientsByTaskLoad.poll(movement.task),
-                    "Tried to move task to caught-up client but none exist"
-                );
-
-                moveActiveAndTryToWarmUp(
-                    remainingWarmupReplicas,
-                    movement.task,
-                    clientStates.get(sourceClient),
-                    clientStates.get(movement.destination),
-                    warmups.computeIfAbsent(movement.destination, x -> new TreeSet<>())
-                );
-                caughtUpClientsByTaskLoad.offerAll(asList(sourceClient, movement.destination));
-            } else {
-                // we found a candidate to trade standby/active state with our destination, so we don't need a warmup
-                swapStandbyAndActive(
-                    movement.task,
-                    clientStates.get(standbySourceClient),
-                    clientStates.get(movement.destination)
-                );
-                caughtUpClientsByTaskLoad.offerAll(asList(standbySourceClient, movement.destination));
+            // Attempt to find a caught up standby, otherwise find any caught up client, failing that use the most
+            // caught up client.
+            final boolean moved = tryToSwapStandbyAndActiveOnCaughtUpClient(clientStates, caughtUpClientsByTaskLoad, movement) ||
+                    tryToMoveActiveToCaughtUpClientAndTryToWarmUp(clientStates, warmups, remainingWarmupReplicas, caughtUpClientsByTaskLoad, movement) ||
+                    tryToMoveActiveToMostCaughtUpClient(tasksToClientByLag, clientStates, warmups, remainingWarmupReplicas, caughtUpClientsByTaskLoad, movement);
+
+            if (!moved) {
+                throw new IllegalStateException("Tried to move task to more caught-up client as scheduled before but none exist");
             }
         }
 
@@ -136,11 +123,12 @@ final class TaskMovement {
     }
 
     static int assignStandbyTaskMovements(final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients,
+                                          final Map<TaskId, SortedSet<UUID>> tasksToClientByLag,
                                           final Map<UUID, ClientState> clientStates,
                                           final AtomicInteger remainingWarmupReplicas,
                                           final Map<UUID, Set<TaskId>> warmups) {
         final BiFunction<UUID, TaskId, Boolean> caughtUpPredicate =
-            (client, task) -> taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients);
+            (client, task) -> taskIsCaughtUpOnClient(task, client, tasksToCaughtUpClients);
 
         final ConstrainedPrioritySet caughtUpClientsByTaskLoad = new ConstrainedPrioritySet(
             caughtUpPredicate,
@@ -157,8 +145,8 @@ final class TaskMovement {
             for (final TaskId task : state.standbyTasks()) {
                 if (warmups.getOrDefault(destination, Collections.emptySet()).contains(task)) {
                     // this is a warmup, so we won't move it.
-                } else if (taskIsNotCaughtUpOnClientAndOtherCaughtUpClientsExist(task, destination, tasksToCaughtUpClients)) {
-                    // if the desired client is not caught up, and there is another client that _is_ caught up, then
+                } else if (taskIsNotCaughtUpOnClientAndOtherMoreCaughtUpClientsExist(task, destination, clientStates, tasksToCaughtUpClients, tasksToClientByLag)) {
+                    // if the desired client is not caught up, and there is another client that _is_ more caught up, then
                     // we schedule a movement, so we can move the active task to the caught-up client. We'll try to
                     // assign a warm-up to the desired client so that we can move it later on.
                     taskMovements.add(new TaskMovement(task, destination, tasksToCaughtUpClients.get(task)));
@@ -170,12 +158,18 @@ final class TaskMovement {
         int movementsNeeded = 0;
 
         for (final TaskMovement movement : taskMovements) {
-            final UUID sourceClient = caughtUpClientsByTaskLoad.poll(
+            final Function<UUID, Boolean> eligibleClientPredicate =
+                    clientId -> !clientStates.get(clientId).hasAssignedTask(movement.task);
+            UUID sourceClient = caughtUpClientsByTaskLoad.poll(
                 movement.task,
-                clientId -> !clientStates.get(clientId).hasAssignedTask(movement.task)
+                eligibleClientPredicate
             );
 
             if (sourceClient == null) {
+                sourceClient = mostCaughtUpEligibleClient(tasksToClientByLag, eligibleClientPredicate, movement.task, movement.destination);
+            }
+
+            if (sourceClient == null) {
                 // then there's no caught-up client that doesn't already have a copy of this task, so there's
                 // nowhere to move it.
             } else {
@@ -193,6 +187,74 @@ final class TaskMovement {
         return movementsNeeded;
     }
 
+    private static boolean tryToSwapStandbyAndActiveOnCaughtUpClient(final Map<UUID, ClientState> clientStates,
+                                                                     final ConstrainedPrioritySet caughtUpClientsByTaskLoad,
+                                                                     final TaskMovement movement) {
+        final UUID caughtUpStandbySourceClient = caughtUpClientsByTaskLoad.poll(
+                movement.task,
+                c -> clientStates.get(c).hasStandbyTask(movement.task)
+        );
+        if (caughtUpStandbySourceClient != null) {
+            swapStandbyAndActive(
+                    movement.task,
+                    clientStates.get(caughtUpStandbySourceClient),
+                    clientStates.get(movement.destination)
+            );
+            caughtUpClientsByTaskLoad.offerAll(asList(caughtUpStandbySourceClient, movement.destination));
+            return true;
+        }
+        return false;
+    }
+
+    private static boolean tryToMoveActiveToCaughtUpClientAndTryToWarmUp(final Map<UUID, ClientState> clientStates,
+                                                                         final Map<UUID, Set<TaskId>> warmups,
+                                                                         final AtomicInteger remainingWarmupReplicas,
+                                                                         final ConstrainedPrioritySet caughtUpClientsByTaskLoad,
+                                                                         final TaskMovement movement) {
+        final UUID caughtUpSourceClient = caughtUpClientsByTaskLoad.poll(movement.task);
+        if (caughtUpSourceClient != null) {
+            moveActiveAndTryToWarmUp(
+                    remainingWarmupReplicas,
+                    movement.task,
+                    clientStates.get(caughtUpSourceClient),
+                    clientStates.get(movement.destination),
+                    warmups.computeIfAbsent(movement.destination, x -> new TreeSet<>())
+            );
+            caughtUpClientsByTaskLoad.offerAll(asList(caughtUpSourceClient, movement.destination));
+            return true;
+        }
+        return false;
+    }
+
+    private static boolean tryToMoveActiveToMostCaughtUpClient(final Map<TaskId, SortedSet<UUID>> tasksToClientByLag,
+                                                               final Map<UUID, ClientState> clientStates,
+                                                               final Map<UUID, Set<TaskId>> warmups,
+                                                               final AtomicInteger remainingWarmupReplicas,
+                                                               final ConstrainedPrioritySet caughtUpClientsByTaskLoad,
+                                                               final TaskMovement movement) {
+        final UUID mostCaughtUpSourceClient = mostCaughtUpEligibleClient(tasksToClientByLag, movement.task, movement.destination);
+        if (mostCaughtUpSourceClient != null) {
+            if (clientStates.get(mostCaughtUpSourceClient).hasStandbyTask(movement.task)) {
+                swapStandbyAndActive(
+                        movement.task,
+                        clientStates.get(mostCaughtUpSourceClient),
+                        clientStates.get(movement.destination)
+                );
+            } else {
+                moveActiveAndTryToWarmUp(
+                        remainingWarmupReplicas,
+                        movement.task,
+                        clientStates.get(mostCaughtUpSourceClient),
+                        clientStates.get(movement.destination),
+                        warmups.computeIfAbsent(movement.destination, x -> new TreeSet<>())
+                );
+            }
+            caughtUpClientsByTaskLoad.offerAll(asList(mostCaughtUpSourceClient, movement.destination));
+            return true;
+        }
+        return false;
+    }
+
     private static void moveActiveAndTryToWarmUp(final AtomicInteger remainingWarmupReplicas,
                                                  final TaskId task,
                                                  final ClientState sourceClientState,
@@ -235,4 +297,24 @@ final class TaskMovement {
         destinationClientState.assignStandby(task);
     }
 
+    private static UUID mostCaughtUpEligibleClient(final Map<TaskId, SortedSet<UUID>> tasksToClientByLag,
+                                                   final TaskId task,
+                                                   final UUID destinationClient) {
+        return mostCaughtUpEligibleClient(tasksToClientByLag, client -> true, task, destinationClient);
+    }
+
+    private static UUID mostCaughtUpEligibleClient(final Map<TaskId, SortedSet<UUID>> tasksToClientByLag,
+                                                   final Function<UUID, Boolean> constraint,
+                                                   final TaskId task,
+                                                   final UUID destinationClient) {
+        for (final UUID client : tasksToClientByLag.get(task)) {
+            if (destinationClient.equals(client)) {
+                break;
+            } else if (constraint.apply(client)) {
+                return client;
+            }
+        }
+        return null;
+    }
+
 }
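
To make the new selection logic above concrete, here is a standalone copy of the scan in
mostCaughtUpEligibleClient with a hypothetical driver (not code from the commit). Clients
are walked in lag order and the scan stops at the destination, so a null result means the
destination is already the most caught-up eligible client for the task.

    import java.util.List;
    import java.util.SortedSet;
    import java.util.TreeSet;
    import java.util.UUID;
    import java.util.function.Function;

    public class MostCaughtUpScanSketch {
        static UUID scan(final SortedSet<UUID> clientsByLag,
                         final Function<UUID, Boolean> constraint,
                         final UUID destination) {
            for (final UUID client : clientsByLag) {
                if (destination.equals(client)) {
                    break;              // nobody ahead of the destination qualified
                } else if (constraint.apply(client)) {
                    return client;      // most caught-up client passing the constraint
                }
            }
            return null;
        }

        public static void main(final String[] args) {
            final UUID a = new UUID(0, 1);
            final UUID b = new UUID(0, 2);
            final UUID c = new UUID(0, 3);
            // Natural UUID order stands in for "ordered by lag" here: a, b, c.
            final SortedSet<UUID> byLag = new TreeSet<>(List.of(a, b, c));

            System.out.println(scan(byLag, u -> true, c));          // a: least lag wins
            System.out.println(scan(byLag, u -> !u.equals(a), c));  // b: a fails the constraint
            System.out.println(scan(byLag, u -> true, a));          // null: destination is first
        }
    }
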
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
index 36ae42f..90e0fed 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
@@ -420,6 +420,37 @@ public class HighAvailabilityTaskAssignorTest {
     }
 
     @Test
+    public void shouldAssignToMostCaughtUpIfActiveTasksWasNotOnCaughtUpClient() {
+        final Set<TaskId> allTasks = mkSet(TASK_0_0);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0);
+        final ClientState client1 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, Long.MAX_VALUE), EMPTY_CLIENT_TAGS, 1);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 1000L), EMPTY_CLIENT_TAGS, 1);
+        final ClientState client3 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 500L), EMPTY_CLIENT_TAGS, 1);
+        final Map<UUID, ClientState> clientStates = mkMap(
+                mkEntry(UUID_1, client1),
+                mkEntry(UUID_2, client2),
+                mkEntry(UUID_3, client3)
+        );
+
+        final boolean probingRebalanceNeeded =
+                new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
+
+        assertThat(clientStates.get(UUID_1).activeTasks(), is(emptySet()));
+        assertThat(clientStates.get(UUID_2).activeTasks(), is(emptySet()));
+        assertThat(clientStates.get(UUID_3).activeTasks(), is(singleton(TASK_0_0)));
+
+        assertThat(clientStates.get(UUID_1).standbyTasks(), is(singleton(TASK_0_0))); // warm up
+        assertThat(clientStates.get(UUID_2).standbyTasks(), is(singleton(TASK_0_0))); // standby
+        assertThat(clientStates.get(UUID_3).standbyTasks(), is(emptySet()));
+
+        assertThat(probingRebalanceNeeded, is(true));
+        assertValidAssignment(1, 1, allTasks, emptySet(), clientStates, new StringBuilder());
+        assertBalancedActiveAssignment(clientStates, new StringBuilder());
+        assertBalancedStatefulAssignment(allTasks, clientStates, new StringBuilder());
+        assertBalancedTasks(clientStates);
+    }
+
+    @Test
     public void shouldAssignStandbysForStatefulTasks() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java
index 9b58d18..baf6d18 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java
@@ -19,19 +19,21 @@ package org.apache.kafka.streams.processor.internals.assignment;
 import org.apache.kafka.streams.processor.TaskId;
 import org.junit.Test;
 
-import java.util.Collection;
+import java.util.Comparator;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.SortedSet;
 import java.util.TreeMap;
+import java.util.TreeSet;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import static java.util.Arrays.asList;
-import static java.util.Collections.emptyList;
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.emptySet;
 import static java.util.Collections.emptySortedSet;
-import static java.util.Collections.singletonList;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
@@ -58,17 +60,20 @@ public class TaskMovementTest {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
 
         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = new HashMap<>();
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = new HashMap<>();
         for (final TaskId task : allTasks) {
             tasksToCaughtUpClients.put(task, mkSortedSet(UUID_1, UUID_2, UUID_3));
+            tasksToClientByLag.put(task, mkOrderedSet(UUID_1, UUID_2, UUID_3));
         }
 
-        final ClientState client1 = getClientStateWithActiveAssignment(asList(TASK_0_0, TASK_1_0));
-        final ClientState client2 = getClientStateWithActiveAssignment(asList(TASK_0_1, TASK_1_1));
-        final ClientState client3 = getClientStateWithActiveAssignment(asList(TASK_0_2, TASK_1_2));
+        final ClientState client1 = getClientStateWithActiveAssignment(mkSet(TASK_0_0, TASK_1_0), allTasks, allTasks);
+        final ClientState client2 = getClientStateWithActiveAssignment(mkSet(TASK_0_1, TASK_1_1), allTasks, allTasks);
+        final ClientState client3 = getClientStateWithActiveAssignment(mkSet(TASK_0_2, TASK_1_2), allTasks, allTasks);
 
         assertThat(
             assignActiveTaskMovements(
                 tasksToCaughtUpClients,
+                tasksToClientByLag,
                 getClientStatesMap(client1, client2, client3),
                 new TreeMap<>(),
                 new AtomicInteger(maxWarmupReplicas)
@@ -80,10 +85,11 @@ public class TaskMovementTest {
     @Test
     public void shouldAssignAllTasksToClientsAndReturnFalseIfNoClientsAreCaughtUp() {
         final int maxWarmupReplicas = Integer.MAX_VALUE;
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
 
-        final ClientState client1 = getClientStateWithActiveAssignment(asList(TASK_0_0, TASK_1_0));
-        final ClientState client2 = getClientStateWithActiveAssignment(asList(TASK_0_1, TASK_1_1));
-        final ClientState client3 = getClientStateWithActiveAssignment(asList(TASK_0_2, TASK_1_2));
+        final ClientState client1 = getClientStateWithActiveAssignment(mkSet(TASK_0_0, TASK_1_0), mkSet(), allTasks);
+        final ClientState client2 = getClientStateWithActiveAssignment(mkSet(TASK_0_1, TASK_1_1), mkSet(), allTasks);
+        final ClientState client3 = getClientStateWithActiveAssignment(mkSet(TASK_0_2, TASK_1_2), mkSet(), allTasks);
 
         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
             mkEntry(TASK_0_0, emptySortedSet()),
@@ -93,9 +99,18 @@ public class TaskMovementTest {
             mkEntry(TASK_1_1, emptySortedSet()),
             mkEntry(TASK_1_2, emptySortedSet())
         );
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = mkMap(
+            mkEntry(TASK_0_0, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_0_1, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_0_2, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_1_0, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_1_1, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_1_2, mkOrderedSet(UUID_1, UUID_2, UUID_3))
+        );
         assertThat(
             assignActiveTaskMovements(
                 tasksToCaughtUpClients,
+                tasksToClientByLag,
                 getClientStatesMap(client1, client2, client3),
                 new TreeMap<>(),
                 new AtomicInteger(maxWarmupReplicas)
@@ -107,9 +122,10 @@ public class TaskMovementTest {
     @Test
     public void shouldMoveTasksToCaughtUpClientsAndAssignWarmupReplicasInTheirPlace() {
         final int maxWarmupReplicas = Integer.MAX_VALUE;
-        final ClientState client1 = getClientStateWithActiveAssignment(singletonList(TASK_0_0));
-        final ClientState client2 = getClientStateWithActiveAssignment(singletonList(TASK_0_1));
-        final ClientState client3 = getClientStateWithActiveAssignment(singletonList(TASK_0_2));
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
+        final ClientState client1 = getClientStateWithActiveAssignment(mkSet(TASK_0_0), mkSet(TASK_0_0), allTasks);
+        final ClientState client2 = getClientStateWithActiveAssignment(mkSet(TASK_0_1), mkSet(TASK_0_2), allTasks);
+        final ClientState client3 = getClientStateWithActiveAssignment(mkSet(TASK_0_2), mkSet(TASK_0_1), allTasks);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
@@ -117,11 +133,17 @@ public class TaskMovementTest {
             mkEntry(TASK_0_1, mkSortedSet(UUID_3)),
             mkEntry(TASK_0_2, mkSortedSet(UUID_2))
         );
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = mkMap(
+            mkEntry(TASK_0_0, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_0_1, mkOrderedSet(UUID_3, UUID_1, UUID_2)),
+            mkEntry(TASK_0_2, mkOrderedSet(UUID_2, UUID_1, UUID_3))
+        );
 
         assertThat(
             "should have assigned movements",
             assignActiveTaskMovements(
                 tasksToCaughtUpClients,
+                tasksToClientByLag,
                 clientStates,
                 new TreeMap<>(),
                 new AtomicInteger(maxWarmupReplicas)
@@ -140,11 +162,59 @@ public class TaskMovementTest {
     }
 
     @Test
+    public void shouldMoveTasksToMostCaughtUpClientsAndAssignWarmupReplicasInTheirPlace() {
+        final int maxWarmupReplicas = Integer.MAX_VALUE;
+        final Map<TaskId, Long> client1Lags = mkMap(mkEntry(TASK_0_0, 10000L), mkEntry(TASK_0_1, 20000L), mkEntry(TASK_0_2, 30000L));
+        final Map<TaskId, Long> client2Lags = mkMap(mkEntry(TASK_0_2, 10000L), mkEntry(TASK_0_0, 20000L), mkEntry(TASK_0_1, 30000L));
+        final Map<TaskId, Long> client3Lags = mkMap(mkEntry(TASK_0_1, 10000L), mkEntry(TASK_0_2, 20000L), mkEntry(TASK_0_0, 30000L));
+
+        final ClientState client1 = getClientStateWithLags(mkSet(TASK_0_0), client1Lags);
+        final ClientState client2 = getClientStateWithLags(mkSet(TASK_0_1), client2Lags);
+        final ClientState client3 = getClientStateWithLags(mkSet(TASK_0_2), client3Lags);
+        // To test when the task is already a standby on the most caught up node
+        client3.assignStandby(TASK_0_1);
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
+
+        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
+                mkEntry(TASK_0_0, mkSortedSet()),
+                mkEntry(TASK_0_1, mkSortedSet()),
+                mkEntry(TASK_0_2, mkSortedSet())
+        );
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = mkMap(
+                mkEntry(TASK_0_0, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+                mkEntry(TASK_0_1, mkOrderedSet(UUID_3, UUID_1, UUID_2)),
+                mkEntry(TASK_0_2, mkOrderedSet(UUID_2, UUID_3, UUID_1))
+        );
+
+        assertThat(
+                "should have assigned movements",
+                assignActiveTaskMovements(
+                        tasksToCaughtUpClients,
+                        tasksToClientByLag,
+                        clientStates,
+                        new TreeMap<>(),
+                        new AtomicInteger(maxWarmupReplicas)
+                ),
+                is(2)
+        );
+        // The active tasks have changed to the ones that each client is most caught up on
+        assertThat(client1, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_0)));
+        assertThat(client2, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_2)));
+        assertThat(client3, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_1)));
+
+        // we assigned warmups to migrate to the input active assignment
+        assertThat(client1, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet()));
+        assertThat(client2, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_1)));
+        assertThat(client3, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_2)));
+    }
+
+    @Test
     public void shouldOnlyGetUpToMaxWarmupReplicasAndReturnTrue() {
         final int maxWarmupReplicas = 1;
-        final ClientState client1 = getClientStateWithActiveAssignment(singletonList(TASK_0_0));
-        final ClientState client2 = getClientStateWithActiveAssignment(singletonList(TASK_0_1));
-        final ClientState client3 = getClientStateWithActiveAssignment(singletonList(TASK_0_2));
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
+        final ClientState client1 = getClientStateWithActiveAssignment(mkSet(TASK_0_0), mkSet(TASK_0_0), allTasks);
+        final ClientState client2 = getClientStateWithActiveAssignment(mkSet(TASK_0_1), mkSet(TASK_0_2), allTasks);
+        final ClientState client3 = getClientStateWithActiveAssignment(mkSet(TASK_0_2), mkSet(TASK_0_1), allTasks);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
@@ -152,11 +222,17 @@ public class TaskMovementTest {
             mkEntry(TASK_0_1, mkSortedSet(UUID_3)),
             mkEntry(TASK_0_2, mkSortedSet(UUID_2))
         );
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = mkMap(
+            mkEntry(TASK_0_0, mkOrderedSet(UUID_1, UUID_2, UUID_3)),
+            mkEntry(TASK_0_1, mkOrderedSet(UUID_3, UUID_1, UUID_2)),
+            mkEntry(TASK_0_2, mkOrderedSet(UUID_2, UUID_1, UUID_3))
+        );
 
         assertThat(
             "should have assigned movements",
             assignActiveTaskMovements(
                 tasksToCaughtUpClients,
+                tasksToClientByLag,
                 clientStates,
                 new TreeMap<>(),
                 new AtomicInteger(maxWarmupReplicas)
@@ -182,19 +258,24 @@ public class TaskMovementTest {
     @Test
     public void shouldNotCountPreviousStandbyTasksTowardsMaxWarmupReplicas() {
         final int maxWarmupReplicas = 0;
-        final ClientState client1 = getClientStateWithActiveAssignment(emptyList());
+        final Set<TaskId> allTasks = mkSet(TASK_0_0);
+        final ClientState client1 = getClientStateWithActiveAssignment(mkSet(), mkSet(TASK_0_0), allTasks);
         client1.assignStandby(TASK_0_0);
-        final ClientState client2 = getClientStateWithActiveAssignment(singletonList(TASK_0_0));
+        final ClientState client2 = getClientStateWithActiveAssignment(mkSet(TASK_0_0), mkSet(), allTasks);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
 
         final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
             mkEntry(TASK_0_0, mkSortedSet(UUID_1))
         );
+        final Map<TaskId, SortedSet<UUID>> tasksToClientByLag = mkMap(
+            mkEntry(TASK_0_0, mkOrderedSet(UUID_1, UUID_2))
+        );
 
         assertThat(
             "should have assigned movements",
             assignActiveTaskMovements(
                 tasksToCaughtUpClients,
+                tasksToClientByLag,
                 clientStates,
                 new TreeMap<>(),
                 new AtomicInteger(maxWarmupReplicas)
@@ -215,10 +296,35 @@ public class TaskMovementTest {
 
     }
 
-    private static ClientState getClientStateWithActiveAssignment(final Collection<TaskId> activeTasks) {
-        final ClientState client1 = new ClientState(1);
+    private static ClientState getClientStateWithActiveAssignment(final Set<TaskId> activeTasks,
+                                                                  final Set<TaskId> caughtUpTasks,
+                                                                  final Set<TaskId> allTasks) {
+        final Map<TaskId, Long> lags = new HashMap<>();
+        for (final TaskId task : allTasks) {
+            if (caughtUpTasks.contains(task)) {
+                lags.put(task, 0L);
+            } else {
+                lags.put(task, 10000L);
+            }
+        }
+        return getClientStateWithLags(activeTasks, lags);
+    }
+
+    private static ClientState getClientStateWithLags(final Set<TaskId> activeTasks,
+                                                      final Map<TaskId, Long> taskLags) {
+        final ClientState client1 = new ClientState(activeTasks, emptySet(), taskLags, emptyMap(), 1);
         client1.assignActiveTasks(activeTasks);
         return client1;
     }
 
+    /**
+     * Creates a SortedSet with the sort order being the order of elements in the parameter list
+     */
+    private static SortedSet<UUID> mkOrderedSet(final UUID... clients) {
+        final List<UUID> clientList = asList(clients);
+        final SortedSet<UUID> set = new TreeSet<>(Comparator.comparing(clientList::indexOf));
+        set.addAll(clientList);
+        return set;
+    }
+
 }