Posted to commits@kafka.apache.org by vv...@apache.org on 2020/03/21 18:41:03 UTC

[kafka] branch trunk updated: KAFKA-6145: Pt 2.5 Compute overall task lag per client (#8252)

This is an automated email from the ASF dual-hosted git repository.

vvcephei pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 6cf27c9  KAFKA-6145: Pt 2.5 Compute overall task lag per client (#8252)
6cf27c9 is described below

commit 6cf27c9c771900baf43cc47f9b010dbf7a86fa22
Author: A. Sophie Blee-Goldman <so...@confluent.io>
AuthorDate: Sat Mar 21 11:40:34 2020 -0700

    KAFKA-6145: Pt 2.5 Compute overall task lag per client (#8252)
    
    Once we have encoded the offset sums per task for each client, we can compute the overall lag during assign by fetching the end offsets for all changelogs and subtracting each client's encoded offset sums.
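    Concretely, the per-task lag is the sum of the fetched changelog end offsets minus the offset sum the client encoded in its subscription. A minimal sketch of that rule, assuming the Task.LATEST_OFFSET sentinel used in the diff below (the real logic lives in ClientState#computeTaskLags; this helper is hypothetical, not part of the commit):
    
        // Hypothetical helper illustrating the lag rule described above.
        static long taskLag(final long endOffsetSum, final long reportedOffsetSum) {
            if (reportedOffsetSum == Task.LATEST_OFFSET) {
                // An active running task is by definition fully caught up.
                return Task.LATEST_OFFSET;
            }
            // Total lag summed across all of the task's changelog partitions.
            return endOffsetSum - reportedOffsetSum;
        }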
    
    If the listOffsets request fails, we simply return a "completely sticky" assignment, i.e. all active tasks are given back to their previous owners regardless of balance.
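    In the assignor this fallback amounts to flipping a flag on the sticky assignor before running it, mirroring StreamsPartitionAssignor#assignTasksToClients in the diff below:
    
        final StickyTaskAssignor<UUID> taskAssignor = new StickyTaskAssignor<>(clientStates, allTasks, standbyTasks);
        if (!lagComputationSuccessful) {
            // Hand every previously-owned active task straight back to its old owner.
            taskAssignor.preservePreviousTaskAssignment();
        }
        taskAssignor.assign(numStandbyReplicas());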
    
    Builds (but does not yet use) the statefulTasksToRankedCandidates map with the following ranking (sketched in code after the list):
    Rank -1: active running task
    Rank 0: standby or restoring task whose overall lag is within acceptableRecoveryLag
    Rank 1: tasks whose lag is unknown (e.g. during version probing)
    Rank 1+: all other tasks are ranked according to their actual total lag
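    A minimal sketch of this ranking rule, assuming the sentinels used in the diff below (Task.LATEST_OFFSET for a previously active task, SubscriptionInfo.UNKNOWN_OFFSET_SUM for an undecodable lag); it mirrors the rank computation in StreamsPartitionAssignor#buildClientRankingsByTask:
    
        static long rank(final long taskLag, final long acceptableRecoveryLag) {
            if (taskLag == Task.LATEST_OFFSET) {
                return Task.LATEST_OFFSET;                 // most caught up: active running task
            } else if (taskLag == SubscriptionInfo.UNKNOWN_OFFSET_SUM) {
                return 1L;                                 // lag unknown, e.g. during version probing
            } else if (taskLag <= acceptableRecoveryLag) {
                return 0L;                                 // close enough to treat as caught up
            } else {
                return taskLag;                            // all others: actual total lag
            }
        }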
    
    Implements: KIP-441
    Reviewers: Bruno Cadonna <br...@confluent.io>, John Roesler <vv...@apache.org>
---
 checkstyle/suppressions.xml                        |   2 +-
 .../java/org/apache/kafka/common/utils/Utils.java  |  18 +-
 .../org/apache/kafka/streams/KafkaStreams.java     |  38 +-
 .../internals/StreamsPartitionAssignor.java        | 210 ++++++++--
 .../assignment/AssignorConfiguration.java          |  15 +-
 .../internals/assignment/ClientState.java          | 152 +++++--
 .../internals/assignment/StickyTaskAssignor.java   |  19 +-
 .../internals/assignment/SubscriptionInfo.java     |   4 +-
 .../internals/assignment/TaskAssignor.java         |   2 +-
 .../org/apache/kafka/streams/KafkaStreamsTest.java |  64 +++
 .../internals/StreamsPartitionAssignorTest.java    | 440 +++++++++++++++++----
 .../internals/assignment/ClientStateTest.java      | 207 +++++++---
 .../assignment/StickyTaskAssignorTest.java         |  13 +
 13 files changed, 940 insertions(+), 244 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 4a06262..2f3186c 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -12,7 +12,7 @@
     <suppress checks="CyclomaticComplexity|BooleanExpressionComplexity"
               files="(SchemaGenerator|MessageDataGenerator|FieldSpec).java"/>
     <suppress checks="NPathComplexity"
-              files="(MessageDataGenerator|FieldSpec).java"/>
+              files="(MessageDataGenerator|FieldSpec|AssignorConfiguration).java"/>
     <suppress checks="JavaNCSS"
               files="(ApiMessageType).java|MessageDataGenerator.java"/>
     <suppress checks="MethodLength"
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
index e9d4cc4..682ccf5 100755
--- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.common.utils;
 
+import java.util.SortedSet;
+import java.util.TreeSet;
 import org.apache.kafka.common.KafkaException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -660,7 +662,7 @@ public final class Utils {
         return existingBuffer;
     }
 
-    /*
+    /**
      * Creates a set
      * @param elems the elements
      * @param <T> the type of element
@@ -675,6 +677,20 @@ public final class Utils {
     }
 
     /**
+     * Creates a sorted set
+     * @param elems the elements
+     * @param <T> the type of element, must be comparable
+     * @return SortedSet
+     */
+    @SafeVarargs
+    public static <T extends Comparable<T>> SortedSet<T> mkSortedSet(T... elems) {
+        SortedSet<T> result = new TreeSet<>();
+        for (T elem : elems)
+            result.add(elem);
+        return result;
+    }
+
+    /**
      * Creates a map entry (for use with {@link Utils#mkMap(java.util.Map.Entry[])})
      *
      * @param k   The key
diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
index 2807e62..8487268 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java
@@ -19,6 +19,7 @@ package org.apache.kafka.streams;
 import java.util.LinkedList;
 import java.util.TreeMap;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import org.apache.kafka.clients.admin.Admin;
@@ -27,6 +28,7 @@ import org.apache.kafka.clients.admin.OffsetSpec;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.ProducerConfig;
+import org.apache.kafka.common.KafkaFuture;
 import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.TopicPartition;
@@ -1217,17 +1219,9 @@ public class KafkaStreams implements AutoCloseable {
         }
 
         log.debug("Current changelog positions: {}", allChangelogPositions);
-        final Map<TopicPartition, ListOffsetsResultInfo> allEndOffsets;
-        try {
-            allEndOffsets = adminClient.listOffsets(
-                allPartitions.stream()
-                    .collect(Collectors.toMap(Function.identity(), tp -> OffsetSpec.latest()))
-            ).all().get();
-        } catch (final RuntimeException | InterruptedException | ExecutionException e) {
-            throw new StreamsException("Unable to obtain end offsets from kafka", e);
-        }
-
+        final Map<TopicPartition, ListOffsetsResultInfo> allEndOffsets = fetchEndOffsetsWithoutTimeout(allPartitions, adminClient);
         log.debug("Current end offsets :{}", allEndOffsets);
+
         for (final Map.Entry<TopicPartition, ListOffsetsResultInfo> entry : allEndOffsets.entrySet()) {
             // Avoiding an extra admin API lookup by computing lags for not-yet-started restorations
             // from zero instead of the real "earliest offset" for the changelog.
@@ -1244,4 +1238,28 @@ public class KafkaStreams implements AutoCloseable {
 
         return Collections.unmodifiableMap(localStorePartitionLags);
     }
+
+    static Map<TopicPartition, ListOffsetsResultInfo> fetchEndOffsetsWithoutTimeout(final Collection<TopicPartition> partitions,
+                                                                                    final Admin adminClient) {
+        return fetchEndOffsets(partitions, adminClient, null);
+    }
+
+    public static Map<TopicPartition, ListOffsetsResultInfo> fetchEndOffsets(final Collection<TopicPartition> partitions,
+                                                                             final Admin adminClient,
+                                                                             final Duration timeout) {
+        final Map<TopicPartition, ListOffsetsResultInfo> endOffsets;
+        try {
+            final KafkaFuture<Map<TopicPartition, ListOffsetsResultInfo>> future =  adminClient.listOffsets(
+                partitions.stream().collect(Collectors.toMap(Function.identity(), tp -> OffsetSpec.latest())))
+                                                                                        .all();
+            if (timeout == null) {
+                endOffsets = future.get();
+            } else {
+                endOffsets = future.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
+            }
+        } catch (final TimeoutException | RuntimeException | InterruptedException | ExecutionException e) {
+            throw new StreamsException("Unable to obtain end offsets from kafka", e);
+        }
+        return endOffsets;
+    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
index 2740dee..2d1d3de 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
@@ -16,7 +16,12 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.time.Duration;
 import java.util.Objects;
+import java.util.SortedSet;
+import java.util.TreeSet;
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
 import org.apache.kafka.clients.consumer.ConsumerGroupMetadata;
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor;
 import org.apache.kafka.common.Cluster;
@@ -30,6 +35,7 @@ import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskAssignmentException;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder.TopicsInfo;
 import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
@@ -60,9 +66,11 @@ import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 
 import static java.util.UUID.randomUUID;
+import static org.apache.kafka.streams.KafkaStreams.fetchEndOffsets;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.EARLIEST_PROBEABLE_VERSION;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.UNKNOWN;
+import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
 
 public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Configurable {
 
@@ -124,9 +132,8 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
             state.addOwnedPartitions(ownedPartitions, consumerMemberId);
         }
 
-        void addPreviousTasks(final SubscriptionInfo info) {
-            state.addPreviousActiveTasks(info.prevTasks());
-            state.addPreviousStandbyTasks(info.standbyTasks());
+        void addPreviousTasksAndOffsetSums(final Map<TaskId, Long> taskOffsetSums) {
+            state.addPreviousTasksAndOffsetSums(taskOffsetSums);
         }
 
         @Override
@@ -203,6 +210,8 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
     protected int usedSubscriptionMetadataVersion = LATEST_SUPPORTED_VERSION;
 
+    private Admin adminClient;
+    private int adminClientTimeout;
     private InternalTopicManager internalTopicManager;
     private CopartitionedTopicsEnforcer copartitionedTopicsEnforcer;
     private RebalanceProtocol rebalanceProtocol;
@@ -228,6 +237,8 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         assignmentConfigs = assignorConfiguration.getAssignmentConfigs();
         partitionGrouper = assignorConfiguration.getPartitionGrouper();
         userEndPoint = assignorConfiguration.getUserEndPoint();
+        adminClient = assignorConfiguration.getAdminClient();
+        adminClientTimeout = assignorConfiguration.getAdminClientTimeout();
         internalTopicManager = assignorConfiguration.getInternalTopicManager();
         copartitionedTopicsEnforcer = assignorConfiguration.getCopartitionedTopicsEnforcer();
         rebalanceProtocol = assignorConfiguration.rebalanceProtocol();
@@ -350,7 +361,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
             // add the consumer and any info in its subscription to the client
             clientMetadata.addConsumer(consumerId, subscription.ownedPartitions());
             allOwnedPartitions.addAll(subscription.ownedPartitions());
-            clientMetadata.addPreviousTasks(info);
+            clientMetadata.addPreviousTasksAndOffsetSums(info.taskOffsetSums());
         }
 
         final boolean versionProbing =
@@ -363,7 +374,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         // parse the topology to determine the repartition source topics,
         // making sure they are created with the number of partitions as
         // the maximum of the depending sub-topologies source topics' number of partitions
-        final Map<Integer, InternalTopologyBuilder.TopicsInfo> topicGroups = taskManager.builder().topicGroups();
+        final Map<Integer, TopicsInfo> topicGroups = taskManager.builder().topicGroups();
 
         final Map<TopicPartition, PartitionInfo> allRepartitionTopicPartitions;
         try {
@@ -385,7 +396,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
         final Set<String> allSourceTopics = new HashSet<>();
         final Map<Integer, Set<String>> sourceTopicsByGroup = new HashMap<>();
-        for (final Map.Entry<Integer, InternalTopologyBuilder.TopicsInfo> entry : topicGroups.entrySet()) {
+        for (final Map.Entry<Integer, TopicsInfo> entry : topicGroups.entrySet()) {
             allSourceTopics.addAll(entry.getValue().sourceTopics);
             sourceTopicsByGroup.put(entry.getKey(), entry.getValue().sourceTopics);
         }
@@ -394,7 +405,6 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         final Map<TaskId, Set<TopicPartition>> partitionsForTask =
             partitionGrouper.partitionGroups(sourceTopicsByGroup, fullMetadata);
 
-
         assignTasksToClients(allSourceTopics, partitionsForTask, topicGroups, clientMetadataMap, fullMetadata);
 
         // ---------------- Step Three ---------------- //
@@ -482,10 +492,10 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
      * @return a map of repartition topics and their metadata
      * @throws TaskAssignmentException if there is incomplete source topic metadata due to missing source topic(s)
      */
-    private Map<String, InternalTopicConfig> computeRepartitionTopicMetadata(final Map<Integer, InternalTopologyBuilder.TopicsInfo> topicGroups,
+    private Map<String, InternalTopicConfig> computeRepartitionTopicMetadata(final Map<Integer, TopicsInfo> topicGroups,
                                                                              final Cluster metadata) throws TaskAssignmentException {
         final Map<String, InternalTopicConfig> repartitionTopicMetadata = new HashMap<>();
-        for (final InternalTopologyBuilder.TopicsInfo topicsInfo : topicGroups.values()) {
+        for (final TopicsInfo topicsInfo : topicGroups.values()) {
             for (final String topic : topicsInfo.sourceTopics) {
                 if (!topicsInfo.repartitionSourceTopics.keySet().contains(topic) &&
                         !metadata.topics().contains(topic)) {
@@ -507,7 +517,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
      *
      * @return map from repartition topic to its partition info
      */
-    private Map<TopicPartition, PartitionInfo> prepareRepartitionTopics(final Map<Integer, InternalTopologyBuilder.TopicsInfo> topicGroups,
+    private Map<TopicPartition, PartitionInfo> prepareRepartitionTopics(final Map<Integer, TopicsInfo> topicGroups,
                                                                            final Cluster metadata) {
         final Map<String, InternalTopicConfig> repartitionTopicMetadata = computeRepartitionTopicMetadata(topicGroups, metadata);
 
@@ -543,13 +553,13 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
      * Computes the number of partitions and sets it for each repartition topic in repartitionTopicMetadata
      */
     private void setRepartitionTopicMetadataNumberOfPartitions(final Map<String, InternalTopicConfig> repartitionTopicMetadata,
-                                                               final Map<Integer, InternalTopologyBuilder.TopicsInfo> topicGroups,
+                                                               final Map<Integer, TopicsInfo> topicGroups,
                                                                final Cluster metadata) {
         boolean numPartitionsNeeded;
         do {
             numPartitionsNeeded = false;
 
-            for (final InternalTopologyBuilder.TopicsInfo topicsInfo : topicGroups.values()) {
+            for (final TopicsInfo topicsInfo : topicGroups.values()) {
                 for (final String topicName : topicsInfo.repartitionSourceTopics.keySet()) {
                     final Optional<Integer> maybeNumPartitions = repartitionTopicMetadata.get(topicName)
                                                                      .numberOfPartitions();
@@ -557,7 +567,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
                     if (!maybeNumPartitions.isPresent()) {
                         // try set the number of partitions for this repartition topic if it is not set yet
-                        for (final InternalTopologyBuilder.TopicsInfo otherTopicsInfo : topicGroups.values()) {
+                        for (final TopicsInfo otherTopicsInfo : topicGroups.values()) {
                             final Set<String> otherSinkTopics = otherTopicsInfo.sinkTopics;
 
                             if (otherSinkTopics.contains(topicName)) {
@@ -670,17 +680,17 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
     /**
      * Resolve changelog topic metadata and create them if necessary.
      *
-     * @return set of standby task ids (any task that is stateful and has logging enabled)
+     * @return mapping of stateful tasks to their set of changelog topics
      */
-    private Set<TaskId> prepareChangelogTopics(final Map<Integer, InternalTopologyBuilder.TopicsInfo> topicGroups,
-                                               final Map<Integer, Set<TaskId>> tasksForTopicGroup) {
-        final Set<TaskId> standbyTaskIds = new HashSet<>();
+    private Map<TaskId, Set<TopicPartition>> prepareChangelogTopics(final Map<Integer, TopicsInfo> topicGroups,
+                                                                    final Map<Integer, Set<TaskId>> tasksForTopicGroup) {
+        final Map<TaskId, Set<TopicPartition>> changelogsByStatefulTask = new HashMap<>();
 
         // add tasks to state change log topic subscribers
         final Map<String, InternalTopicConfig> changelogTopicMetadata = new HashMap<>();
-        for (final Map.Entry<Integer, InternalTopologyBuilder.TopicsInfo> entry : topicGroups.entrySet()) {
+        for (final Map.Entry<Integer, TopicsInfo> entry : topicGroups.entrySet()) {
             final int topicGroupId = entry.getKey();
-            final InternalTopologyBuilder.TopicsInfo topicsInfo = entry.getValue();
+            final TopicsInfo topicsInfo = entry.getValue();
 
             final Set<TaskId> topicGroupTasks = tasksForTopicGroup.get(topicGroupId);
             if (topicGroupTasks == null) {
@@ -690,7 +700,15 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                 continue;
             }
 
-            standbyTaskIds.addAll(topicGroupTasks);
+            for (final TaskId task : topicGroupTasks) {
+                changelogsByStatefulTask.put(
+                    task,
+                    topicsInfo.stateChangelogTopics
+                        .keySet()
+                        .stream()
+                        .map(topic -> new TopicPartition(topic, task.partition))
+                        .collect(Collectors.toSet()));
+            }
 
             for (final InternalTopicConfig topicConfig : topicsInfo.nonSourceChangelogTopics()) {
                  // the expected number of partitions is the max value of TaskId.partition + 1
@@ -707,33 +725,92 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
         prepareTopic(changelogTopicMetadata);
         log.debug("Created state changelog topics {} from the parsed topology.", changelogTopicMetadata.values());
-        return standbyTaskIds;
+        return changelogsByStatefulTask;
     }
 
     /**
-     * Assigns a set of tasks to each client (Streams instance) using the sticky assignor
+     * Assigns a set of tasks to each client (Streams instance) using the sticky assignor to prioritize clients
+     * based on the previous state and overall lag.
      */
     private void assignTasksToClients(final Set<String> allSourceTopics,
                                       final Map<TaskId, Set<TopicPartition>> partitionsForTask,
-                                      final Map<Integer, InternalTopologyBuilder.TopicsInfo> topicGroups,
+                                      final Map<Integer, TopicsInfo> topicGroups,
                                       final Map<UUID, ClientMetadata> clientMetadataMap,
                                       final Cluster fullMetadata) {
         final Map<TopicPartition, TaskId> taskForPartition = new HashMap<>();
         final Map<Integer, Set<TaskId>> tasksForTopicGroup = new HashMap<>();
         populateTasksForMaps(taskForPartition, tasksForTopicGroup, allSourceTopics, partitionsForTask, fullMetadata);
 
-        final Set<TaskId> standbyTaskIds = prepareChangelogTopics(topicGroups, tasksForTopicGroup);
+        final Map<TaskId, Set<TopicPartition>> changelogsByStatefulTask =
+            prepareChangelogTopics(topicGroups, tasksForTopicGroup);
+
+        final Map<UUID, ClientState> clientStates = new HashMap<>();
+        final boolean lagComputationSuccessful =
+            populateClientStatesMap(clientStates, clientMetadataMap, taskForPartition, changelogsByStatefulTask);
+
+        // assign tasks to clients
+        final Set<TaskId> allTasks = partitionsForTask.keySet();
+        final Set<TaskId> standbyTasks = changelogsByStatefulTask.keySet();
+
+        if (lagComputationSuccessful) {
+            final Map<TaskId, SortedSet<RankedClient<UUID>>> statefulTasksToRankedCandidates =
+                buildClientRankingsByTask(standbyTasks, clientStates, acceptableRecoveryLag());
+            log.trace("Computed statefulTasksToRankedCandidates map as {}", statefulTasksToRankedCandidates);
+        }
+
+        log.debug("Assigning tasks {} to clients {} with number of replicas {}",
+            allTasks, clientStates, numStandbyReplicas());
+
+        final StickyTaskAssignor<UUID> taskAssignor = new StickyTaskAssignor<>(clientStates, allTasks, standbyTasks);
+        if (!lagComputationSuccessful) {
+            taskAssignor.preservePreviousTaskAssignment();
+        }
+        taskAssignor.assign(numStandbyReplicas());
+
+        log.info("Assigned tasks to clients as {}{}.",
+            Utils.NL, clientStates.entrySet().stream().map(Map.Entry::toString).collect(Collectors.joining(Utils.NL)));
+    }
+
+    /**
+     * Builds a map from client to state, and readies each ClientState for assignment by adding any missing prev tasks
+     * and computing the per-task overall lag based on the fetched end offsets for each changelog.
+     *
+     * @param clientStates a map from each client to its state, including offset lags. Populated by this method.
+     * @param clientMetadataMap a map from each client to its full metadata
+     * @param taskForPartition map from topic partition to its corresponding task
+     * @param changelogsByStatefulTask map from each stateful task to its set of changelog topic partitions
+     *
+     * @return whether we were able to successfully fetch the changelog end offsets and compute each client's lag
+     */
+    private boolean populateClientStatesMap(final Map<UUID, ClientState> clientStates,
+                                            final Map<UUID, ClientMetadata> clientMetadataMap,
+                                            final Map<TopicPartition, TaskId> taskForPartition,
+                                            final Map<TaskId, Set<TopicPartition>> changelogsByStatefulTask) {
+        boolean fetchEndOffsetsSuccessful;
+        Map<TaskId, Long> allTaskEndOffsetSums;
+        try {
+            final Collection<TopicPartition> allChangelogPartitions = changelogsByStatefulTask.values().stream()
+                                                                          .flatMap(Collection::stream)
+                                                                          .collect(Collectors.toList());
+            final Map<TopicPartition, ListOffsetsResultInfo> endOffsets =
+                fetchEndOffsets(allChangelogPartitions, adminClient, Duration.ofMillis(adminClientTimeout));
+            allTaskEndOffsetSums = computeEndOffsetSumsByTask(endOffsets, changelogsByStatefulTask);
+            fetchEndOffsetsSuccessful = true;
+        } catch (final StreamsException e) {
+            allTaskEndOffsetSums = null;
+            fetchEndOffsetsSuccessful = false;
+            setAssignmentErrorCode(AssignorError.REBALANCE_NEEDED.code());
+        }
 
-        final Map<UUID, ClientState> states = new HashMap<>();
         for (final Map.Entry<UUID, ClientMetadata> entry : clientMetadataMap.entrySet()) {
             final UUID uuid = entry.getKey();
             final ClientState state = entry.getValue().state;
-            states.put(uuid, state);
 
-            // there are two cases where we need to construct the prevTasks from the ownedPartitions:
-            // 1) COOPERATIVE clients on version 2.4-2.5 do not encode active tasks and rely on ownedPartitions instead
+            // there are three cases where we need to construct some or all of the prevTasks from the ownedPartitions:
+            // 1) COOPERATIVE clients on version 2.4-2.5 do not encode active tasks at all and rely on ownedPartitions
             // 2) future client during version probing, when we can't decode the future subscription info's prev tasks
-            if (!state.ownedPartitions().isEmpty() && (uuid == FUTURE_ID || state.prevActiveTasks().isEmpty())) {
+            // 3) stateless tasks are not encoded in the task lags, and must be figured out from the ownedPartitions
+            if (!state.ownedPartitions().isEmpty()) {
                 final Set<TaskId> previousActiveTasks = new HashSet<>();
                 for (final Map.Entry<TopicPartition, String> partitionEntry : state.ownedPartitions().entrySet()) {
                     final TopicPartition tp = partitionEntry.getKey();
@@ -746,18 +823,75 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                 }
                 state.addPreviousActiveTasks(previousActiveTasks);
             }
-        }
 
-        log.debug("Assigning tasks {} to clients {} with number of replicas {}",
-            partitionsForTask.keySet(), states, assignmentConfigs.numStandbyReplicas);
+            if (fetchEndOffsetsSuccessful) {
+                state.computeTaskLags(allTaskEndOffsetSums);
+            }
+            clientStates.put(uuid, state);
+        }
+        return fetchEndOffsetsSuccessful;
+    }
 
-        // assign tasks to clients
-        final StickyTaskAssignor<UUID> taskAssignor =
-            new StickyTaskAssignor<>(states, partitionsForTask.keySet(), standbyTaskIds);
-        taskAssignor.assign(assignmentConfigs.numStandbyReplicas);
+    /**
+     * @param endOffsets the listOffsets result from the adminClient, or null if the request failed
+     * @param changelogsByStatefulTask map from stateful task to its set of changelog topic partitions
+     *
+     * @return Map from stateful task to its total end offset summed across all changelog partitions
+     */
+    private Map<TaskId, Long> computeEndOffsetSumsByTask(final Map<TopicPartition, ListOffsetsResultInfo> endOffsets,
+                                                         final Map<TaskId, Set<TopicPartition>> changelogsByStatefulTask) {
+        final Map<TaskId, Long> taskEndOffsetSums = new HashMap<>();
+        for (final Map.Entry<TaskId, Set<TopicPartition>> taskEntry : changelogsByStatefulTask.entrySet()) {
+            final TaskId task = taskEntry.getKey();
+            final Set<TopicPartition> changelogs = taskEntry.getValue();
+
+            taskEndOffsetSums.put(task, 0L);
+            for (final TopicPartition changelog : changelogs) {
+                final ListOffsetsResultInfo offsetResult = endOffsets.get(changelog);
+                if (offsetResult == null) {
+                    log.debug("Fetched end offsets did not contain the changelog {} of task {}", changelog, task);
+                    throw new IllegalStateException("Could not get end offset for " + changelog);
+                }
+                taskEndOffsetSums.computeIfPresent(task, (id, curOffsetSum) -> curOffsetSum + offsetResult.offset());
+            }
+        }
+        return taskEndOffsetSums;
+    }
 
-        log.info("Assigned tasks to clients as {}{}.", Utils.NL, states.entrySet().stream()
-                                                                     .map(Map.Entry::toString).collect(Collectors.joining(Utils.NL)));
+    /**
+     * Rankings are computed as follows, with lower being more caught up:
+     *      Rank -1: active running task
+     *      Rank 0: standby or restoring task whose overall lag is within the acceptableRecoveryLag bounds
+     *      Rank 1: tasks whose lag is unknown, eg because it was not encoded in an older version subscription
+     *      Rank 1+: all other tasks are ranked according to their actual total lag
+     * @return Sorted set of all client candidates for each stateful task, ranked by their overall lag
+     */
+    static Map<TaskId, SortedSet<RankedClient<UUID>>> buildClientRankingsByTask(final Set<TaskId> statefulTasks,
+                                                                                final Map<UUID, ClientState> states,
+                                                                                final long acceptableRecoveryLag) {
+        final Map<TaskId, SortedSet<RankedClient<UUID>>> statefulTasksToRankedCandidates = new TreeMap<>();
+
+        for (final TaskId task : statefulTasks) {
+            final SortedSet<RankedClient<UUID>> rankedClientCandidates = new TreeSet<>();
+            statefulTasksToRankedCandidates.put(task, rankedClientCandidates);
+
+            for (final Map.Entry<UUID, ClientState> clientEntry : states.entrySet()) {
+                final UUID clientId = clientEntry.getKey();
+                final long taskLag = clientEntry.getValue().lagFor(task);
+                final long clientRank;
+                if (taskLag == Task.LATEST_OFFSET) {
+                    clientRank = Task.LATEST_OFFSET;
+                } else if (taskLag == UNKNOWN_OFFSET_SUM) {
+                    clientRank = 1L;
+                } else if (taskLag <= acceptableRecoveryLag) {
+                    clientRank = 0L;
+                } else {
+                    clientRank = taskLag;
+                }
+                rankedClientCandidates.add(new RankedClient<>(clientId, clientRank));
+            }
+        }
+        return statefulTasksToRankedCandidates;
     }
 
     /**
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
index 6aba395..9086479 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
@@ -18,6 +18,7 @@ package org.apache.kafka.streams.processor.internals.assignment;
 
 import org.apache.kafka.clients.CommonClientConfigs;
 import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.admin.AdminClientConfig;
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.RebalanceProtocol;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.config.ConfigException;
@@ -45,6 +46,8 @@ public final class AssignorConfiguration {
     private final String userEndPoint;
     private final TaskManager taskManager;
     private final StreamsMetadataState streamsMetadataState;
+    private final Admin adminClient;
+    private final int adminClientTimeout;
     private final InternalTopicManager internalTopicManager;
     private final CopartitionedTopicsEnforcer copartitionedTopicsEnforcer;
     private final StreamsConfig streamsConfig;
@@ -144,9 +147,11 @@ public final class AssignorConfiguration {
                 throw fatalException;
             }
 
-            internalTopicManager = new InternalTopicManager((Admin) o, streamsConfig);
+            adminClient = (Admin) o;
+            internalTopicManager = new InternalTopicManager(adminClient, streamsConfig);
         }
 
+        adminClientTimeout = streamsConfig.getInt(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG);
 
         copartitionedTopicsEnforcer = new CopartitionedTopicsEnforcer(logPrefix);
     }
@@ -250,6 +255,14 @@ public final class AssignorConfiguration {
         return userEndPoint;
     }
 
+    public Admin getAdminClient() {
+        return adminClient;
+    }
+
+    public int getAdminClientTimeout() {
+        return adminClientTimeout;
+    }
+
     public InternalTopicManager getInternalTopicManager() {
         return internalTopicManager;
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
index df42b14..9a522ed 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
+
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.streams.processor.TaskId;
 
@@ -24,6 +26,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
+import org.apache.kafka.streams.processor.internals.Task;
 
 public class ClientState {
     private final Set<TaskId> activeTasks;
@@ -34,6 +37,8 @@ public class ClientState {
     private final Set<TaskId> prevAssignedTasks;
 
     private final Map<TopicPartition, String> ownedPartitions;
+    private final Map<TaskId, Long> taskOffsetSums; // contains only stateful tasks we previously owned
+    private final Map<TaskId, Long> taskLagTotals;  // contains lag for all stateful tasks in the app topology
 
     private int capacity;
 
@@ -49,7 +54,9 @@ public class ClientState {
              new HashSet<>(),
              new HashSet<>(),
              new HashMap<>(),
-             capacity);
+             new HashMap<>(),
+             new HashMap<>(),
+            capacity);
     }
 
     private ClientState(final Set<TaskId> activeTasks,
@@ -59,6 +66,8 @@ public class ClientState {
                         final Set<TaskId> prevStandbyTasks,
                         final Set<TaskId> prevAssignedTasks,
                         final Map<TopicPartition, String> ownedPartitions,
+                        final Map<TaskId, Long> taskOffsetSums,
+                        final Map<TaskId, Long> taskLagTotals,
                         final int capacity) {
         this.activeTasks = activeTasks;
         this.standbyTasks = standbyTasks;
@@ -67,6 +76,8 @@ public class ClientState {
         this.prevStandbyTasks = prevStandbyTasks;
         this.prevAssignedTasks = prevAssignedTasks;
         this.ownedPartitions = ownedPartitions;
+        this.taskOffsetSums = taskOffsetSums;
+        this.taskLagTotals = taskLagTotals;
         this.capacity = capacity;
     }
 
@@ -79,17 +90,29 @@ public class ClientState {
             new HashSet<>(prevStandbyTasks),
             new HashSet<>(prevAssignedTasks),
             new HashMap<>(ownedPartitions),
+            new HashMap<>(taskOffsetSums),
+            new HashMap<>(taskLagTotals),
             capacity);
     }
 
-    public void assign(final TaskId taskId, final boolean active) {
-        if (active) {
-            activeTasks.add(taskId);
-        } else {
-            standbyTasks.add(taskId);
-        }
+    void assignActive(final TaskId task) {
+        activeTasks.add(task);
+        assignedTasks.add(task);
+    }
 
-        assignedTasks.add(taskId);
+    void assignStandby(final TaskId task) {
+        standbyTasks.add(task);
+        assignedTasks.add(task);
+    }
+
+    public void assignActiveTasks(final Collection<TaskId> tasks) {
+        activeTasks.addAll(tasks);
+        assignedTasks.addAll(tasks);
+    }
+
+    void assignStandbyTasks(final Collection<TaskId> tasks) {
+        standbyTasks.addAll(tasks);
+        assignedTasks.addAll(tasks);
     }
 
     public Set<TaskId> activeTasks() {
@@ -100,11 +123,11 @@ public class ClientState {
         return standbyTasks;
     }
 
-    public Set<TaskId> prevActiveTasks() {
+    Set<TaskId> prevActiveTasks() {
         return prevActiveTasks;
     }
 
-    public Set<TaskId> prevStandbyTasks() {
+    Set<TaskId> prevStandbyTasks() {
         return prevStandbyTasks;
     }
 
@@ -129,9 +152,10 @@ public class ClientState {
     public void addPreviousActiveTasks(final Set<TaskId> prevTasks) {
         prevActiveTasks.addAll(prevTasks);
         prevAssignedTasks.addAll(prevTasks);
+        prevStandbyTasks.removeAll(prevTasks);
     }
 
-    public void addPreviousStandbyTasks(final Set<TaskId> standbyTasks) {
+    void addPreviousStandbyTasks(final Set<TaskId> standbyTasks) {
         prevStandbyTasks.addAll(standbyTasks);
         prevAssignedTasks.addAll(standbyTasks);
     }
@@ -142,28 +166,82 @@ public class ClientState {
         }
     }
 
+    public void addPreviousTasksAndOffsetSums(final Map<TaskId, Long> taskOffsetSums) {
+        for (final Map.Entry<TaskId, Long> taskEntry : taskOffsetSums.entrySet()) {
+            final TaskId id = taskEntry.getKey();
+            final long offsetSum = taskEntry.getValue();
+            if (offsetSum == Task.LATEST_OFFSET) {
+                prevActiveTasks.add(id);
+            } else {
+                prevStandbyTasks.add(id);
+            }
+            prevAssignedTasks.add(id);
+        }
+        this.taskOffsetSums.putAll(taskOffsetSums);
+    }
+
+    /**
+     * Compute the lag for each stateful task, including tasks this client did not previously have.
+     */
+    public void computeTaskLags(final Map<TaskId, Long> allTaskEndOffsetSums) {
+        if (!taskLagTotals.isEmpty()) {
+            throw new IllegalStateException("Already computed task lags for this client.");
+        }
+
+        for (final Map.Entry<TaskId, Long> taskEntry : allTaskEndOffsetSums.entrySet()) {
+            final TaskId task = taskEntry.getKey();
+            final Long endOffsetSum = taskEntry.getValue();
+            final Long offsetSum = taskOffsetSums.getOrDefault(task, 0L);
+
+            if (endOffsetSum < offsetSum) {
+                throw new IllegalStateException("Task " + task + " had endOffsetSum=" + endOffsetSum +
+                                                    " smaller than offsetSum=" + offsetSum);
+            }
+
+            if (offsetSum == Task.LATEST_OFFSET) {
+                taskLagTotals.put(task, Task.LATEST_OFFSET);
+            } else if (offsetSum == UNKNOWN_OFFSET_SUM) {
+                taskLagTotals.put(task, UNKNOWN_OFFSET_SUM);
+            } else {
+                taskLagTotals.put(task, endOffsetSum - offsetSum);
+            }
+        }
+    }
+
+    /**
+     * Returns the total lag across all logged stores in the task. Equal to the end offset sum if this client
+     * did not have any state for this task on disk.
+     *
+     * @return  end offset sum - offset sum
+     *          Task.LATEST_OFFSET if this was previously an active running task on this client
+     */
+    public long lagFor(final TaskId task) {
+        final Long totalLag = taskLagTotals.get(task);
+
+        if (totalLag == null) {
+            throw new IllegalStateException("Tried to lookup lag for unknown task " + task);
+        } else {
+            return totalLag;
+        }
+    }
+
     public void removeFromAssignment(final TaskId task) {
         activeTasks.remove(task);
         assignedTasks.remove(task);
     }
 
-    @Override
-    public String toString() {
-        return "[activeTasks: (" + activeTasks +
-                ") standbyTasks: (" + standbyTasks +
-                ") assignedTasks: (" + assignedTasks +
-                ") prevActiveTasks: (" + prevActiveTasks +
-                ") prevStandbyTasks: (" + prevStandbyTasks +
-                ") prevAssignedTasks: (" + prevAssignedTasks +
-                ") prevOwnedPartitionsByConsumerId: (" + ownedPartitions.keySet() +
-                ") capacity: " + capacity +
-                "]";
-    }
-
     boolean reachedCapacity() {
         return assignedTasks.size() >= capacity;
     }
 
+    int capacity() {
+        return capacity;
+    }
+
+    boolean hasUnfulfilledQuota(final int tasksPerThread) {
+        return activeTasks.size() < capacity * tasksPerThread;
+    }
+
     boolean hasMoreAvailableCapacityThan(final ClientState other) {
         if (this.capacity <= 0) {
             throw new IllegalStateException("Capacity of this ClientState must be greater than 0.");
@@ -189,6 +267,20 @@ public class ClientState {
         return assignedTasks.contains(taskId);
     }
 
+    @Override
+    public String toString() {
+        return "[activeTasks: (" + activeTasks +
+                   ") standbyTasks: (" + standbyTasks +
+                   ") assignedTasks: (" + assignedTasks +
+                   ") prevActiveTasks: (" + prevActiveTasks +
+                   ") prevStandbyTasks: (" + prevStandbyTasks +
+                   ") prevAssignedTasks: (" + prevAssignedTasks +
+                   ") prevOwnedPartitionsByConsumerId: (" + ownedPartitions.keySet() +
+                   ") changelogOffsetTotalsByTask: (" + taskOffsetSums.entrySet() +
+                   ") capacity: " + capacity +
+                   "]";
+    }
+
     // Visible for testing
     Set<TaskId> assignedTasks() {
         return assignedTasks;
@@ -198,16 +290,4 @@ public class ClientState {
         return prevAssignedTasks;
     }
 
-    int capacity() {
-        return capacity;
-    }
-
-    boolean hasUnfulfilledQuota(final int tasksPerThread) {
-        return activeTasks.size() < capacity * tasksPerThread;
-    }
-
-    // the following methods are used for testing only
-    public void assignActiveTasks(final Collection<TaskId> tasks) {
-        activeTasks.addAll(tasks);
-    }
 }
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
index 64ff5e6..f40ff85 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
@@ -31,7 +31,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 
-public class StickyTaskAssignor<ID> implements TaskAssignor<ID, TaskId> {
+public class StickyTaskAssignor<ID> implements TaskAssignor<ID> {
 
     private static final Logger log = LoggerFactory.getLogger(StickyTaskAssignor.class);
     private final Map<ID, ClientState> clients;
@@ -41,12 +41,15 @@ public class StickyTaskAssignor<ID> implements TaskAssignor<ID, TaskId> {
     private final Map<TaskId, Set<ID>> previousStandbyTaskAssignment = new HashMap<>();
     private final TaskPairs taskPairs;
 
+    private boolean mustPreserveActiveTaskAssignment;
+
     public StickyTaskAssignor(final Map<ID, ClientState> clients,
                               final Set<TaskId> allTaskIds,
                               final Set<TaskId> standbyTaskIds) {
         this.clients = clients;
         this.allTaskIds = allTaskIds;
         this.standbyTaskIds = standbyTaskIds;
+        this.mustPreserveActiveTaskAssignment = false;
 
         final int maxPairs = allTaskIds.size() * (allTaskIds.size() - 1) / 2;
         taskPairs = new TaskPairs(maxPairs);
@@ -59,6 +62,10 @@ public class StickyTaskAssignor<ID> implements TaskAssignor<ID, TaskId> {
         assignStandby(numStandbyReplicas);
     }
 
+    public void preservePreviousTaskAssignment() {
+        mustPreserveActiveTaskAssignment = true;
+    }
+
     private void assignStandby(final int numStandbyReplicas) {
         for (final TaskId taskId : standbyTaskIds) {
             for (int i = 0; i < numStandbyReplicas; i++) {
@@ -88,7 +95,7 @@ public class StickyTaskAssignor<ID> implements TaskAssignor<ID, TaskId> {
             final TaskId taskId = entry.getKey();
             if (allTaskIds.contains(taskId)) {
                 final ClientState client = clients.get(entry.getValue());
-                if (client.hasUnfulfilledQuota(tasksPerThread)) {
+                if (mustPreserveActiveTaskAssignment || client.hasUnfulfilledQuota(tasksPerThread)) {
                     assignTaskToClient(assigned, taskId, client);
                 }
             }
@@ -125,12 +132,16 @@ public class StickyTaskAssignor<ID> implements TaskAssignor<ID, TaskId> {
     private void allocateTaskWithClientCandidates(final TaskId taskId, final Set<ID> clientsWithin, final boolean active) {
         final ClientState client = findClient(taskId, clientsWithin);
         taskPairs.addPairs(taskId, client.assignedTasks());
-        client.assign(taskId, active);
+        if (active) {
+            client.assignActive(taskId);
+        } else {
+            client.assignStandby(taskId);
+        }
     }
 
     private void assignTaskToClient(final Set<TaskId> assigned, final TaskId taskId, final ClientState client) {
         taskPairs.addPairs(taskId, client.assignedTasks());
-        client.assign(taskId, true);
+        client.assignActive(taskId);
         assigned.add(taskId);
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
index 411ff04..49496cf 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
@@ -46,7 +46,7 @@ public class SubscriptionInfo {
 
     static final int UNKNOWN = -1;
     static final int MIN_VERSION_OFFSET_SUM_SUBSCRIPTION = 7;
-    static final long UNKNOWN_OFFSET_SUM = -3L;
+    public static final long UNKNOWN_OFFSET_SUM = -3L;
 
     private final SubscriptionInfoData data;
     private Set<TaskId> prevTasksCache = null;
@@ -198,7 +198,7 @@ public class SubscriptionInfo {
         return standbyTasksCache;
     }
 
-    Map<TaskId, Long> taskOffsetSums() {
+    public Map<TaskId, Long> taskOffsetSums() {
         if (taskOffsetSumsCache == null) {
             taskOffsetSumsCache = new HashMap<>();
             if (data.version() >= MIN_VERSION_OFFSET_SUM_SUBSCRIPTION) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
index 162add0..679e416 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
@@ -16,6 +16,6 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-public interface TaskAssignor<C, T extends Comparable<T>> {
+public interface TaskAssignor<C> {
     void assign(int numStandbyReplicas);
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
index 7087085..bc1ca1d 100644
--- a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java
@@ -16,13 +16,17 @@
  */
 package org.apache.kafka.streams;
 
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeoutException;
 import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.admin.AdminClient;
 import org.apache.kafka.clients.admin.ListOffsetsResult;
 import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
 import org.apache.kafka.clients.admin.MockAdminClient;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.producer.MockProducer;
 import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.KafkaFuture;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.internals.KafkaFutureImpl;
 import org.apache.kafka.common.metrics.MetricConfig;
@@ -33,6 +37,7 @@ import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.serialization.StringSerializer;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.internals.metrics.ClientMetrics;
 import org.apache.kafka.streams.kstream.Materialized;
 import org.apache.kafka.streams.processor.AbstractProcessor;
@@ -83,16 +88,21 @@ import java.util.concurrent.atomic.AtomicReference;
 
 import static java.util.Collections.emptyList;
 import static java.util.Collections.singletonList;
+import static org.apache.kafka.streams.KafkaStreams.fetchEndOffsets;
+import static org.apache.kafka.streams.KafkaStreams.fetchEndOffsetsWithoutTimeout;
 import static org.easymock.EasyMock.anyInt;
 import static org.easymock.EasyMock.anyLong;
 import static org.easymock.EasyMock.anyObject;
 import static org.easymock.EasyMock.anyString;
 import static org.easymock.EasyMock.capture;
+import static org.easymock.EasyMock.replay;
+import static org.easymock.EasyMock.verify;
 import static org.hamcrest.CoreMatchers.hasItem;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
@@ -883,6 +893,60 @@ public class KafkaStreamsTest {
         startStreamsAndCheckDirExists(topology, true);
     }
 
+    @Test
+    public void fetchEndOffsetsShouldRethrowRuntimeExceptionAsStreamsException() {
+        final Admin adminClient = EasyMock.createMock(AdminClient.class);
+        EasyMock.expect(adminClient.listOffsets(EasyMock.anyObject())).andThrow(new RuntimeException());
+        replay(adminClient);
+        assertThrows(StreamsException.class, () ->  fetchEndOffsetsWithoutTimeout(emptyList(), adminClient));
+        verify(adminClient);
+    }
+
+    @Test
+    public void fetchEndOffsetsShouldRethrowInterruptedExceptionAsStreamsException() throws InterruptedException, ExecutionException {
+        final Admin adminClient = EasyMock.createMock(AdminClient.class);
+        final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class);
+        final KafkaFuture<Map<TopicPartition, ListOffsetsResultInfo>> allFuture = EasyMock.createMock(KafkaFuture.class);
+
+        EasyMock.expect(adminClient.listOffsets(EasyMock.anyObject())).andStubReturn(result);
+        EasyMock.expect(result.all()).andStubReturn(allFuture);
+        EasyMock.expect(allFuture.get()).andThrow(new InterruptedException());
+        replay(adminClient, result, allFuture);
+
+        assertThrows(StreamsException.class, () -> fetchEndOffsetsWithoutTimeout(emptyList(), adminClient));
+        verify(adminClient);
+    }
+
+    @Test
+    public void fetchEndOffsetsShouldRethrowExecutionExceptionAsStreamsException() throws InterruptedException, ExecutionException {
+        final Admin adminClient = EasyMock.createMock(AdminClient.class);
+        final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class);
+        final KafkaFuture<Map<TopicPartition, ListOffsetsResultInfo>> allFuture = EasyMock.createMock(KafkaFuture.class);
+
+        EasyMock.expect(adminClient.listOffsets(EasyMock.anyObject())).andStubReturn(result);
+        EasyMock.expect(result.all()).andStubReturn(allFuture);
+        EasyMock.expect(allFuture.get()).andThrow(new ExecutionException(new RuntimeException()));
+        replay(adminClient, result, allFuture);
+
+        assertThrows(StreamsException.class, () -> fetchEndOffsetsWithoutTimeout(emptyList(), adminClient));
+        verify(adminClient);
+    }
+
+    @Test
+    public void fetchEndOffsetsWithTimeoutShouldRethrowTimeoutExceptionAsStreamsException() throws InterruptedException, ExecutionException, TimeoutException {
+        final Admin adminClient = EasyMock.createMock(AdminClient.class);
+        final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class);
+        final KafkaFuture<Map<TopicPartition, ListOffsetsResultInfo>> allFuture = EasyMock.createMock(KafkaFuture.class);
+
+        EasyMock.expect(adminClient.listOffsets(EasyMock.anyObject())).andStubReturn(result);
+        EasyMock.expect(result.all()).andStubReturn(allFuture);
+        EasyMock.expect(allFuture.get(1L, TimeUnit.MILLISECONDS)).andThrow(new TimeoutException());
+        replay(adminClient, result, allFuture);
+
+        assertThrows(StreamsException.class, () -> fetchEndOffsets(emptyList(), adminClient, Duration.ofMillis(1)));
+        verify(adminClient);
+    }
+
     @SuppressWarnings("unchecked")
     private Topology getStatefulTopology(final String inputTopic,
                                          final String outputTopic,
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
index 820ead8..418e13f 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
@@ -16,7 +16,13 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.util.Map.Entry;
+import java.util.SortedSet;
 import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.admin.AdminClient;
+import org.apache.kafka.clients.admin.AdminClientConfig;
+import org.apache.kafka.clients.admin.ListOffsetsResult;
+import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Assignment;
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.GroupSubscription;
 import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.RebalanceProtocol;
@@ -27,10 +33,12 @@ import org.apache.kafka.common.Node;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.config.ConfigException;
+import org.apache.kafka.common.internals.KafkaFutureImpl;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.kstream.JoinWindows;
 import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.kstream.KTable;
@@ -38,7 +46,9 @@ import org.apache.kafka.streams.kstream.KeyValueMapper;
 import org.apache.kafka.streams.kstream.Materialized;
 import org.apache.kafka.streams.kstream.ValueJoiner;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor.RankedClient;
 import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
+import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorError;
 import org.apache.kafka.streams.processor.internals.assignment.ClientState;
 import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
@@ -49,6 +59,7 @@ import org.apache.kafka.test.MockKeyValueStoreBuilder;
 import org.apache.kafka.test.MockProcessorSupplier;
 import org.easymock.Capture;
 import org.easymock.EasyMock;
+import org.junit.Before;
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
@@ -69,11 +80,18 @@ import static java.util.Arrays.asList;
 import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.emptySet;
-import static java.util.Collections.singletonMap;
+import static java.util.Collections.singleton;
+import static java.util.Collections.singletonList;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.common.utils.Utils.mkSortedSet;
+import static org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor.buildClientRankingsByTask;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION;
+import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
+import static org.easymock.EasyMock.anyObject;
+import static org.easymock.EasyMock.expect;
+import static org.easymock.EasyMock.replay;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.Matchers.is;
@@ -156,9 +174,11 @@ public class StreamsPartitionAssignorTest {
 
     private final Set<TaskId> emptyTasks = emptySet();
     private final Map<TaskId, Long> emptyTaskOffsetSums = emptyMap();
+    private final Map<TopicPartition, Long> emptyChangelogEndOffsets = new HashMap<>();
+
     private final UUID uuid1 = UUID.randomUUID();
     private final UUID uuid2 = UUID.randomUUID();
-
+    private final UUID uuid3 = UUID.randomUUID();
     private final SubscriptionInfo defaultSubscriptionInfo = getInfo(uuid1, emptyTasks, emptyTasks);
 
     private final Cluster metadata = new Cluster(
@@ -170,12 +190,14 @@ public class StreamsPartitionAssignorTest {
 
     private final StreamsPartitionAssignor partitionAssignor = new StreamsPartitionAssignor();
     private final MockClientSupplier mockClientSupplier = new MockClientSupplier();
-    private final StreamsConfig streamsConfig = new StreamsConfig(configProps());
     private static final String USER_END_POINT = "localhost:8080";
     private static final String OTHER_END_POINT = "other:9090";
     private static final String APPLICATION_ID = "stream-partition-assignor-test";
+    private static final long ACCEPTABLE_RECOVERY_LAG = 100L;
 
     private TaskManager taskManager;
+    private Admin adminClient;
+    private StreamsConfig streamsConfig = new StreamsConfig(configProps());
     private InternalTopologyBuilder builder = new InternalTopologyBuilder();
     private StreamsMetadataState streamsMetadataState = EasyMock.createNiceMock(StreamsMetadataState.class);
     private final Map<String, Subscription> subscriptions = new HashMap<>();
@@ -186,28 +208,31 @@ public class StreamsPartitionAssignorTest {
         configurationMap.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, USER_END_POINT);
         configurationMap.put(StreamsConfig.InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR, taskManager);
         configurationMap.put(StreamsConfig.InternalConfig.STREAMS_METADATA_STATE_FOR_PARTITION_ASSIGNOR, streamsMetadataState);
-        configurationMap.put(StreamsConfig.InternalConfig.STREAMS_ADMIN_CLIENT, EasyMock.createNiceMock(Admin.class));
+        configurationMap.put(StreamsConfig.InternalConfig.STREAMS_ADMIN_CLIENT, adminClient);
         configurationMap.put(StreamsConfig.InternalConfig.ASSIGNMENT_ERROR_CODE, new AtomicInteger());
         return configurationMap;
     }
 
-    private void configureDefault() {
+    private MockInternalTopicManager configureDefault() {
         createDefaultMockTaskManager();
-        configureDefaultPartitionAssignor();
+        return configureDefaultPartitionAssignor();
     }
 
-    // TaskManager must be created first
-    private void configureDefaultPartitionAssignor() {
-        partitionAssignor.configure(configProps());
-        EasyMock.replay(taskManager);
+    // Make sure to complete setting up any mocks (such as TaskManager or AdminClient) before configuring the assignor
+    private MockInternalTopicManager configureDefaultPartitionAssignor() {
+        return configurePartitionAssignorWith(emptyMap());
     }
 
-    // TaskManager must be created first
-    private void configurePartitionAssignorWith(final Map<String, Object> props) {
-        final Map<String, Object> configurationMap = configProps();
-        configurationMap.putAll(props);
-        partitionAssignor.configure(configurationMap);
-        EasyMock.replay(taskManager);
+    // Make sure to complete setting up any mocks (such as TaskManager or AdminClient) before configuring the assignor
+    private MockInternalTopicManager configurePartitionAssignorWith(final Map<String, Object> props) {
+        final Map<String, Object> configMap = configProps();
+        configMap.putAll(props);
+
+        streamsConfig = new StreamsConfig(configMap);
+        partitionAssignor.configure(configMap);
+        EasyMock.replay(taskManager, adminClient);
+
+        return overwriteInternalTopicManagerWithMock();
     }
 
     private void createDefaultMockTaskManager() {
@@ -223,13 +248,47 @@ public class StreamsPartitionAssignorTest {
     private void createMockTaskManager(final Map<TaskId, Long> taskOffsetSums,
                                        final UUID processId) {
         taskManager = EasyMock.createNiceMock(TaskManager.class);
-        EasyMock.expect(taskManager.builder()).andReturn(builder).anyTimes();
-        EasyMock.expect(taskManager.getTaskOffsetSums()).andReturn(taskOffsetSums).anyTimes();
-        EasyMock.expect(taskManager.processId()).andReturn(processId).anyTimes();
+        expect(taskManager.builder()).andReturn(builder).anyTimes();
+        expect(taskManager.getTaskOffsetSums()).andReturn(taskOffsetSums).anyTimes();
+        expect(taskManager.processId()).andReturn(processId).anyTimes();
         builder.setApplicationId(APPLICATION_ID);
         builder.buildTopology();
     }
 
+    // If you don't care about setting the end offsets for each specific topic partition, the helper method
+    // getTopicPartitionOffsetsMap is useful for building this input map for all partitions
+    private void createMockAdminClient(final Map<TopicPartition, Long> changelogEndOffsets) {
+        adminClient = EasyMock.createMock(AdminClient.class);
+
+        final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class);
+        final KafkaFutureImpl<Map<TopicPartition, ListOffsetsResultInfo>> allFuture = new KafkaFutureImpl<>();
+        allFuture.complete(changelogEndOffsets.entrySet().stream().collect(Collectors.toMap(
+            Entry::getKey,
+            t -> {
+                final ListOffsetsResultInfo info = EasyMock.createNiceMock(ListOffsetsResultInfo.class);
+                expect(info.offset()).andStubReturn(t.getValue());
+                EasyMock.replay(info);
+                return info;
+            }))
+        );
+
+        expect(adminClient.listOffsets(anyObject())).andStubReturn(result);
+        expect(result.all()).andStubReturn(allFuture);
+
+        EasyMock.replay(result);
+    }
+
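+    // For illustration, a hypothetical test body (not one of the tests in this class)
+    // wires these helpers together in the required order -- mocks first, then configure:
+    //
+    //   createMockTaskManager(prevActiveTasks, prevStandbyTasks);
+    //   createMockAdminClient(getTopicPartitionOffsetsMap(
+    //       singletonList(APPLICATION_ID + "-store1-changelog"), singletonList(3)));
+    //   final MockInternalTopicManager topicManager =
+    //       configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1));
+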
+    private MockInternalTopicManager overwriteInternalTopicManagerWithMock() {
+        final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer);
+        partitionAssignor.setInternalTopicManager(mockInternalTopicManager);
+        return mockInternalTopicManager;
+    }
+
+    @Before
+    public void setUp() {
+        createMockAdminClient(emptyChangelogEndOffsets);
+    }
+
     @Test
     public void shouldUseEagerRebalancingProtocol() {
         createDefaultMockTaskManager();
@@ -478,8 +537,6 @@ public class StreamsPartitionAssignorTest {
         createMockTaskManager(prevTasks10, standbyTasks10);
         configureDefaultPartitionAssignor();
 
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer));
-
         subscriptions.put("consumer10",
                           new Subscription(
                               topics,
@@ -566,8 +623,6 @@ public class StreamsPartitionAssignorTest {
 
         configureDefault();
 
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer));
-
         subscriptions.put("consumer10",
                           new Subscription(
                               topics,
@@ -611,10 +666,12 @@ public class StreamsPartitionAssignorTest {
         final Set<TaskId> allTasks = mkSet(task0_0, task0_1, task0_2);
 
         createDefaultMockTaskManager();
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            singletonList(APPLICATION_ID + "-store1-changelog"),
+            singletonList(3))
+        );
         configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.PARTITION_GROUPER_CLASS_CONFIG, SingleGroupPartitionGrouperStub.class));
 
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer));
-
         // will throw exception if it fails
         subscriptions.put("consumer10",
                           new Subscription(
@@ -703,8 +760,6 @@ public class StreamsPartitionAssignorTest {
         createMockTaskManager(prevTasks10, emptyTasks);
         configureDefaultPartitionAssignor();
 
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer));
-
         subscriptions.put("consumer10",
                           new Subscription(
                               topics,
@@ -761,10 +816,14 @@ public class StreamsPartitionAssignorTest {
         final TaskId task12 = new TaskId(1, 2);
         final List<TaskId> tasks = asList(task00, task01, task02, task10, task11, task12);
 
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            asList(APPLICATION_ID + "-store1-changelog",
+                   APPLICATION_ID + "-store2-changelog",
+                   APPLICATION_ID + "-store3-changelog"),
+            asList(3, 3, 3))
+        );
         configureDefault();
 
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer));
-
         subscriptions.put("consumer10",
                           new Subscription(topics, defaultSubscriptionInfo.encode()));
         subscriptions.put("consumer11",
@@ -823,10 +882,6 @@ public class StreamsPartitionAssignorTest {
 
     @Test
     public void testAssignWithStandbyReplicasAndStatelessTasks() {
-        final Map<String, Object> props = configProps();
-        props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, "1");
-        final StreamsConfig streamsConfig = new StreamsConfig(props);
-
         builder.addSource(null, "source1", null, null, null, "topic1", "topic2");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source1");
 
@@ -834,7 +889,6 @@ public class StreamsPartitionAssignorTest {
 
         createMockTaskManager(mkSet(task0_0), emptySet());
         configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1));
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer));
 
         subscriptions.put("consumer10",
             new Subscription(
@@ -857,10 +911,6 @@ public class StreamsPartitionAssignorTest {
 
     @Test
     public void testAssignWithStandbyReplicasAndLoggingDisabled() {
-        final Map<String, Object> props = configProps();
-        props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, "1");
-        final StreamsConfig streamsConfig = new StreamsConfig(props);
-
         builder.addSource(null, "source1", null, null, null, "topic1", "topic2");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source1");
         builder.addStateStore(new MockKeyValueStoreBuilder("store1", false).withLoggingDisabled(), "processor");
@@ -869,7 +919,6 @@ public class StreamsPartitionAssignorTest {
 
         createMockTaskManager(mkSet(task0_0), emptySet());
         configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1));
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer));
 
         subscriptions.put("consumer10",
             new Subscription(
@@ -892,10 +941,6 @@ public class StreamsPartitionAssignorTest {
 
     @Test
     public void testAssignWithStandbyReplicas() {
-        final Map<String, Object> props = configProps();
-        props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, "1");
-        final StreamsConfig streamsConfig = new StreamsConfig(props);
-
         builder.addSource(null, "source1", null, null, null, "topic1");
         builder.addSource(null, "source2", null, null, null, "topic2");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2");
@@ -917,10 +962,12 @@ public class StreamsPartitionAssignorTest {
         final Set<TaskId> standbyTasks02 = mkSet(task0_2);
 
         createMockTaskManager(prevTasks00, standbyTasks01);
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            singletonList(APPLICATION_ID + "-store1-changelog"),
+            singletonList(3))
+        );
         configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1));
 
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer));
-
         subscriptions.put("consumer10",
                           new Subscription(
                               topics,
@@ -1036,17 +1083,14 @@ public class StreamsPartitionAssignorTest {
         final List<String> topics = asList("topic1", APPLICATION_ID + "-topicX");
         final Set<TaskId> allTasks = mkSet(task0_0, task0_1, task0_2);
 
-        configureDefault();
-
-        final MockInternalTopicManager internalTopicManager = new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer);
-        partitionAssignor.setInternalTopicManager(internalTopicManager);
+        final MockInternalTopicManager internalTopicManager = configureDefault();
 
         subscriptions.put("consumer10",
                           new Subscription(
                               topics,
                               defaultSubscriptionInfo.encode())
         );
-        partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
+        partitionAssignor.assign(metadata, new GroupSubscription(subscriptions));
 
         // check prepared internal topics
         assertEquals(1, internalTopicManager.readyTopics.size());
@@ -1067,18 +1111,14 @@ public class StreamsPartitionAssignorTest {
         final List<String> topics = asList("topic1", APPLICATION_ID + "-topicX", APPLICATION_ID + "-topicZ");
         final Set<TaskId> allTasks = mkSet(task0_0, task0_1, task0_2);
 
-        configureDefault();
-
-        final MockInternalTopicManager internalTopicManager =
-            new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer);
-        partitionAssignor.setInternalTopicManager(internalTopicManager);
+        final MockInternalTopicManager internalTopicManager = configureDefault();
 
         subscriptions.put("consumer10",
                           new Subscription(
                               topics,
                               defaultSubscriptionInfo.encode())
         );
-        partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
+        partitionAssignor.assign(metadata, new GroupSubscription(subscriptions));
 
         // check prepared internal topics
         assertEquals(2, internalTopicManager.readyTopics.size());
@@ -1112,12 +1152,13 @@ public class StreamsPartitionAssignorTest {
         final String client = "client1";
         builder = TopologyWrapper.getInternalTopologyBuilder(streamsBuilder.build());
 
-        configureDefault();
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            asList(APPLICATION_ID + "-topic3-STATE-STORE-0000000002-changelog",
+                   APPLICATION_ID + "-KTABLE-AGGREGATE-STATE-STORE-0000000006-changelog"),
+            asList(4, 4))
+        );
 
-        final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager(
-            streamsConfig,
-            mockClientSupplier.restoreConsumer);
-        partitionAssignor.setInternalTopicManager(mockInternalTopicManager);
+        final MockInternalTopicManager mockInternalTopicManager = configureDefault();
 
         subscriptions.put(client,
                           new Subscription(
@@ -1125,8 +1166,7 @@ public class StreamsPartitionAssignorTest {
                               defaultSubscriptionInfo.encode())
         );
         final Map<String, Assignment> assignment =
-            partitionAssignor.assign(metadata, new GroupSubscription(subscriptions))
-                             .groupAssignment();
+            partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
 
         final Map<String, Integer> expectedCreatedInternalTopics = new HashMap<>();
         expectedCreatedInternalTopics.put(APPLICATION_ID + "-KTABLE-AGGREGATE-STATE-STORE-0000000006-repartition", 4);
@@ -1187,8 +1227,6 @@ public class StreamsPartitionAssignorTest {
         createDefaultMockTaskManager();
         configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, USER_END_POINT));
 
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer));
-
         subscriptions.put("consumer1",
                           new Subscription(
                               topics,
@@ -1263,12 +1301,7 @@ public class StreamsPartitionAssignorTest {
 
         builder = TopologyWrapper.getInternalTopologyBuilder(streamsBuilder.build());
 
-        configureDefault();
-
-        final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager(
-            streamsConfig,
-            mockClientSupplier.restoreConsumer);
-        partitionAssignor.setInternalTopicManager(mockInternalTopicManager);
+        final MockInternalTopicManager mockInternalTopicManager = configureDefault();
 
         subscriptions.put(client,
                           new Subscription(
@@ -1287,7 +1320,7 @@ public class StreamsPartitionAssignorTest {
         final Map<HostInfo, Set<TopicPartition>> initialHostState = mkMap(
             mkEntry(new HostInfo("localhost", 9090), mkSet(t1p0, t1p1)),
             mkEntry(new HostInfo("otherhost", 9090), mkSet(t2p0, t2p1))
-            );
+        );
 
         final Map<HostInfo, Set<TopicPartition>> newHostState = mkMap(
             mkEntry(new HostInfo("localhost", 9090), mkSet(t1p0, t1p1)),
@@ -1322,7 +1355,7 @@ public class StreamsPartitionAssignorTest {
         );
 
         createDefaultMockTaskManager();
-        configurePartitionAssignorWith(singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, "newhost:9090"));
+        configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, "newhost:9090"));
 
         partitionAssignor.onAssignment(createAssignment(oldHostState), null);
 
@@ -1343,14 +1376,16 @@ public class StreamsPartitionAssignorTest {
         builder = TopologyWrapper.getInternalTopologyBuilder(streamsBuilder.build());
 
         createDefaultMockTaskManager();
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            singletonList(APPLICATION_ID + "-KSTREAM-AGGREGATE-STATE-STORE-0000000001-changelog"),
+            singletonList(3))
+        );
+
         final Map<String, Object> props = new HashMap<>();
         props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
         props.put(StreamsConfig.APPLICATION_SERVER_CONFIG, USER_END_POINT);
-        configurePartitionAssignorWith(props);
 
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(
-            streamsConfig,
-            mockClientSupplier.restoreConsumer));
+        configurePartitionAssignorWith(props);
 
         subscriptions.put("consumer1",
                           new Subscription(
@@ -1575,10 +1610,6 @@ public class StreamsPartitionAssignorTest {
 
     @Test
     public void shouldReturnNormalAssignmentForOldAndFutureInstancesDuringVersionProbing() {
-        final Map<String, Object> props = configProps();
-        props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, "1");
-        final StreamsConfig streamsConfig = new StreamsConfig(props);
-
         builder.addSource(null, "source1", null, null, null, "topic1");
         builder.addProcessor("processor", new MockProcessorSupplier(), "source1");
         builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor");
@@ -1596,8 +1627,12 @@ public class StreamsPartitionAssignorTest {
         );
 
         createMockTaskManager(allTasks, allTasks);
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            singletonList(APPLICATION_ID + "-store1-changelog"),
+            singletonList(3))
+        );
+
         configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1));
-        partitionAssignor.setInternalTopicManager(new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer));
 
         subscriptions.put("consumer1",
                 new Subscription(
@@ -1754,10 +1789,6 @@ public class StreamsPartitionAssignorTest {
 
         configureDefault();
 
-        final MockInternalTopicManager internalTopicManager =
-            new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer);
-        partitionAssignor.setInternalTopicManager(internalTopicManager);
-
         subscriptions.put("consumer10",
             new Subscription(
                 topics,
@@ -1795,6 +1826,218 @@ public class StreamsPartitionAssignorTest {
         assertThat(partitionAssignor.probingRebalanceIntervalMs(), equalTo(55 * 60 * 1000L));
     }
 
+    @Test
+    public void shouldSetAdminClientTimeout() {
+        createDefaultMockTaskManager();
+
+        final Map<String, Object> props = configProps();
+        props.put(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, 2 * 60 * 1000);
+        final AssignorConfiguration assignorConfiguration = new AssignorConfiguration(props);
+
+        assertThat(assignorConfiguration.getAdminClientTimeout(), is(2 * 60 * 1000));
+    }
+
+    @Test
+    public void shouldRankPreviousClientAboveEquallyCaughtUpClient() {
+        final ClientState client1 = EasyMock.createMock(ClientState.class);
+        final ClientState client2 = EasyMock.createMock(ClientState.class);
+
+        expect(client1.lagFor(task0_0)).andReturn(Task.LATEST_OFFSET);
+        expect(client2.lagFor(task0_0)).andReturn(0L);
+
+        final SortedSet<RankedClient<UUID>> expectedClientRanking = mkSortedSet(
+            new RankedClient<>(uuid1, Task.LATEST_OFFSET),
+            new RankedClient<>(uuid2, 0L)
+        );
+
+        replay(client1, client2);
+
+        final Map<UUID, ClientState> states = mkMap(
+            mkEntry(uuid1, client1),
+            mkEntry(uuid2, client2)
+        );
+
+        final Map<TaskId, SortedSet<RankedClient<UUID>>> statefulTasksToRankedCandidates =
+            buildClientRankingsByTask(singleton(task0_0), states, ACCEPTABLE_RECOVERY_LAG);
+
+        final SortedSet<RankedClient<UUID>> clientRanking = statefulTasksToRankedCandidates.get(task0_0);
+
+        EasyMock.verify(client1, client2);
+        assertThat(clientRanking, equalTo(expectedClientRanking));
+    }
+
+    @Test
+    public void shouldRankTaskWithUnknownOffsetSumBelowCaughtUpClientButAboveClientWithLargeLag() {
+        final ClientState client1 = EasyMock.createMock(ClientState.class);
+        final ClientState client2 = EasyMock.createMock(ClientState.class);
+        final ClientState client3 = EasyMock.createMock(ClientState.class);
+
+        expect(client1.lagFor(task0_0)).andReturn(UNKNOWN_OFFSET_SUM);
+        expect(client2.lagFor(task0_0)).andReturn(50L);
+        expect(client3.lagFor(task0_0)).andReturn(500L);
+
+        final SortedSet<RankedClient<UUID>> expectedClientRanking = mkSortedSet(
+            new RankedClient<>(uuid2, 0L),
+            new RankedClient<>(uuid1, 1L),
+            new RankedClient<>(uuid3, 500L)
+        );
+
+        replay(client1, client2, client3);
+
+        final Map<UUID, ClientState> states = mkMap(
+            mkEntry(uuid1, client1),
+            mkEntry(uuid2, client2),
+            mkEntry(uuid3, client3)
+        );
+
+        final Map<TaskId, SortedSet<RankedClient<UUID>>> statefulTasksToRankedCandidates =
+            buildClientRankingsByTask(singleton(task0_0), states, ACCEPTABLE_RECOVERY_LAG);
+
+        final SortedSet<RankedClient<UUID>> clientRanking = statefulTasksToRankedCandidates.get(task0_0);
+
+        EasyMock.verify(client1, client2, client3);
+        assertThat(clientRanking, equalTo(expectedClientRanking));
+    }
+
+    @Test
+    public void shouldRankAllClientsWithinAcceptableRecoveryLagWithRank0() {
+        final ClientState client1 = EasyMock.createMock(ClientState.class);
+        final ClientState client2 = EasyMock.createMock(ClientState.class);
+
+        expect(client1.lagFor(task0_0)).andReturn(100L);
+        expect(client2.lagFor(task0_0)).andReturn(0L);
+
+        final SortedSet<RankedClient<UUID>> expectedClientRanking = mkSortedSet(
+            new RankedClient<>(uuid1, 0L),
+            new RankedClient<>(uuid2, 0L)
+        );
+
+        replay(client1, client2);
+
+        final Map<UUID, ClientState> states = mkMap(
+            mkEntry(uuid1, client1),
+            mkEntry(uuid2, client2)
+        );
+
+        final Map<TaskId, SortedSet<RankedClient<UUID>>> statefulTasksToRankedCandidates =
+            buildClientRankingsByTask(singleton(task0_0), states, ACCEPTABLE_RECOVERY_LAG);
+
+        EasyMock.verify(client1, client2);
+        assertThat(statefulTasksToRankedCandidates.get(task0_0), equalTo(expectedClientRanking));
+    }
+
+    @Test
+    public void shouldRankNotCaughtUpClientsAccordingToLag() {
+        final ClientState client1 = EasyMock.createMock(ClientState.class);
+        final ClientState client2 = EasyMock.createMock(ClientState.class);
+        final ClientState client3 = EasyMock.createMock(ClientState.class);
+
+        expect(client1.lagFor(task0_0)).andReturn(900L);
+        expect(client2.lagFor(task0_0)).andReturn(800L);
+        expect(client3.lagFor(task0_0)).andReturn(500L);
+
+        final SortedSet<RankedClient<UUID>> expectedClientRanking = mkSortedSet(
+            new RankedClient<>(uuid3, 500L),
+            new RankedClient<>(uuid2, 800L),
+            new RankedClient<>(uuid1, 900L)
+        );
+
+        replay(client1, client2, client3);
+
+        final Map<UUID, ClientState> states = mkMap(
+            mkEntry(uuid1, client1),
+            mkEntry(uuid2, client2),
+            mkEntry(uuid3, client3)
+        );
+
+        final Map<TaskId, SortedSet<RankedClient<UUID>>> statefulTasksToRankedCandidates =
+            buildClientRankingsByTask(singleton(task0_0), states, ACCEPTABLE_RECOVERY_LAG);
+
+        EasyMock.verify(client1, client2, client3);
+        assertThat(statefulTasksToRankedCandidates.get(task0_0), equalTo(expectedClientRanking));
+    }
+
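+    // Note: taken together, the four ranking tests above assert that buildClientRankingsByTask
+    // preserves Task.LATEST_OFFSET as the rank of a previous active owner, collapses any lag
+    // within ACCEPTABLE_RECOVERY_LAG to rank 0, maps UNKNOWN_OFFSET_SUM to rank 1, and ranks
+    // all remaining clients by their actual lag.
+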
+    @Test
+    public void shouldThrowIllegalStateExceptionIfAnyPartitionsMissingFromChangelogEndOffsets() {
+        final int changelogNumPartitions = 3;
+        builder.addSource(null, "source1", null, null, null, "topic1");
+        builder.addProcessor("processor1", new MockProcessorSupplier(), "source1");
+        builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1");
+
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            singletonList(APPLICATION_ID + "-store1-changelog"),
+            singletonList(changelogNumPartitions - 1))
+        );
+
+        configureDefault();
+
+        subscriptions.put("consumer10",
+            new Subscription(
+                singletonList("topic1"),
+                defaultSubscriptionInfo.encode()
+            ));
+        assertThrows(IllegalStateException.class, () -> partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)));
+    }
+
+    @Test
+    public void shouldThrowIllegalStateExceptionIfAnyTopicsMissingFromChangelogEndOffsets() {
+        builder.addSource(null, "source1", null, null, null, "topic1");
+        builder.addProcessor("processor1", new MockProcessorSupplier(), "source1");
+        builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1");
+        builder.addStateStore(new MockKeyValueStoreBuilder("store2", false), "processor1");
+
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            singletonList(APPLICATION_ID + "-store1-changelog"),
+            singletonList(3))
+        );
+
+        configureDefault();
+
+        subscriptions.put("consumer10",
+            new Subscription(
+                singletonList("topic1"),
+                defaultSubscriptionInfo.encode()
+            ));
+        assertThrows(IllegalStateException.class, () -> partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)));
+    }
+
+    @Test
+    public void shouldReturnAllActiveTasksToPreviousOwnerRegardlessOfBalanceIfEndOffsetFetchFails() {
+        builder.addSource(null, "source1", null, null, null, "topic1");
+        builder.addProcessor("processor1", new MockProcessorSupplier(), "source1");
+        builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1");
+        final Set<TaskId> allTasks = mkSet(task0_0, task0_1, task0_2);
+
+        createMockTaskManager(allTasks, emptyTasks);
+        adminClient = EasyMock.createMock(AdminClient.class);
+        expect(adminClient.listOffsets(anyObject())).andThrow(new StreamsException("Should be handled"));
+        configureDefaultPartitionAssignor();
+
+        final String firstConsumer = "consumer1";
+        final String newConsumer = "consumer2";
+
+        subscriptions.put(firstConsumer,
+                          new Subscription(
+                              singletonList("source1"),
+                              getInfo(uuid1, allTasks, emptyTasks).encode()
+                          ));
+        subscriptions.put(newConsumer,
+                          new Subscription(
+                              singletonList("source1"),
+                              getInfo(uuid2, emptyTasks, emptyTasks).encode()
+                          ));
+
+        final Map<String, Assignment> assignments = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
+
+        final List<TaskId> firstConsumerActiveTasks =
+            AssignmentInfo.decode(assignments.get(firstConsumer).userData()).activeTasks();
+        final List<TaskId> newConsumerActiveTasks =
+            AssignmentInfo.decode(assignments.get(newConsumer).userData()).activeTasks();
+
+        assertThat(firstConsumerActiveTasks, equalTo(new ArrayList<>(allTasks)));
+        assertTrue(newConsumerActiveTasks.isEmpty());
+    }
+
     private static ByteBuffer encodeFutureSubscription() {
         final ByteBuffer buf = ByteBuffer.allocate(4 /* used version */ + 4 /* supported version */);
         buf.putInt(LATEST_SUPPORTED_VERSION + 1);
@@ -1883,6 +2126,29 @@ public class StreamsPartitionAssignorTest {
         }
     }
 
+    /**
+     * Helper for building the input to createMockAdminClient in cases where we don't care about the actual offsets.
+     * @param changelogTopics The names of all changelog topics in the topology
+     * @param topicsNumPartitions The partition count of each changelog topic: the ith element gives the number of
+     *            partitions of the ith topic in changelogTopics
+     */
+    private static Map<TopicPartition, Long> getTopicPartitionOffsetsMap(final List<String> changelogTopics,
+                                                                         final List<Integer> topicsNumPartitions) {
+        if (changelogTopics.size() != topicsNumPartitions.size()) {
+            throw new IllegalStateException("Passed in " + changelogTopics.size() + " changelog topic names, but " +
+                                               topicsNumPartitions.size() + " different numPartitions for the topics");
+        }
+        final Map<TopicPartition, Long> changelogEndOffsets = new HashMap<>();
+        for (int i = 0; i < changelogTopics.size(); ++i) {
+            final String topic = changelogTopics.get(i);
+            final int numPartitions = topicsNumPartitions.get(i);
+            for (int partition = 0; partition < numPartitions; ++partition) {
+                changelogEndOffsets.put(new TopicPartition(topic, partition), 0L);
+            }
+        }
+        return changelogEndOffsets;
+    }
+
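+    // For example (hypothetical topic names), getTopicPartitionOffsetsMap(
+    //     asList("appId-A-changelog", "appId-B-changelog"), asList(2, 1))
+    // yields end offsets of 0L for appId-A-changelog partitions 0 and 1, and for
+    // appId-B-changelog partition 0.
+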
     private static SubscriptionInfo getInfo(final UUID processId,
                                             final Set<TaskId> prevTasks,
                                             final Set<TaskId> standbyTasks) {
@@ -1907,7 +2173,7 @@ public class StreamsPartitionAssignorTest {
     }
 
     // Stub offset sums for when we only care about the prev/standby task sets, not the actual offsets
-    static Map<TaskId, Long> getTaskOffsetSums(final Set<TaskId> activeTasks, final Set<TaskId> standbyTasks) {
+    private static Map<TaskId, Long> getTaskOffsetSums(final Set<TaskId> activeTasks, final Set<TaskId> standbyTasks) {
         final Map<TaskId, Long> taskOffsetSums = activeTasks.stream().collect(Collectors.toMap(t -> t, t -> Task.LATEST_OFFSET));
         taskOffsetSums.putAll(standbyTasks.stream().collect(Collectors.toMap(t -> t, t -> 0L)));
         return taskOffsetSums;
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
index 1443edf..f08ae1a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
@@ -16,20 +16,31 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import java.util.Map;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.Task;
 import org.junit.Test;
 
 import java.util.Collections;
 
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 public class ClientStateTest {
 
     private final ClientState client = new ClientState(1);
+    private final ClientState zeroCapacityClient = new ClientState(0);
+
+    private final TaskId taskId01 = new TaskId(0, 1);
+    private final TaskId taskId02 = new TaskId(0, 2);
 
     @Test
     public void shouldHaveNotReachedCapacityWhenAssignedTasksLessThanCapacity() {
@@ -38,142 +49,212 @@ public class ClientStateTest {
 
     @Test
     public void shouldHaveReachedCapacityWhenAssignedTasksGreaterThanOrEqualToCapacity() {
-        client.assign(new TaskId(0, 1), true);
+        client.assignActive(taskId01);
         assertTrue(client.reachedCapacity());
     }
 
-
     @Test
     public void shouldAddActiveTasksToBothAssignedAndActive() {
-        final TaskId tid = new TaskId(0, 1);
-
-        client.assign(tid, true);
-        assertThat(client.activeTasks(), equalTo(Collections.singleton(tid)));
-        assertThat(client.assignedTasks(), equalTo(Collections.singleton(tid)));
+        client.assignActive(taskId01);
+        assertThat(client.activeTasks(), equalTo(Collections.singleton(taskId01)));
+        assertThat(client.assignedTasks(), equalTo(Collections.singleton(taskId01)));
         assertThat(client.assignedTaskCount(), equalTo(1));
         assertThat(client.standbyTasks().size(), equalTo(0));
     }
 
     @Test
-    public void shouldAddStandbyTasksToBothStandbyAndActive() {
-        final TaskId tid = new TaskId(0, 1);
-
-        client.assign(tid, false);
-        assertThat(client.assignedTasks(), equalTo(Collections.singleton(tid)));
-        assertThat(client.standbyTasks(), equalTo(Collections.singleton(tid)));
+    public void shouldAddStandbyTasksToBothStandbyAndAssigned() {
+        client.assignStandby(taskId01);
+        assertThat(client.assignedTasks(), equalTo(Collections.singleton(taskId01)));
+        assertThat(client.standbyTasks(), equalTo(Collections.singleton(taskId01)));
         assertThat(client.assignedTaskCount(), equalTo(1));
         assertThat(client.activeTasks().size(), equalTo(0));
     }
 
     @Test
     public void shouldAddPreviousActiveTasksToPreviousAssignedAndPreviousActive() {
-        final TaskId tid1 = new TaskId(0, 1);
-        final TaskId tid2 = new TaskId(0, 2);
-
-        client.addPreviousActiveTasks(Utils.mkSet(tid1, tid2));
-        assertThat(client.prevActiveTasks(), equalTo(Utils.mkSet(tid1, tid2)));
-        assertThat(client.previousAssignedTasks(), equalTo(Utils.mkSet(tid1, tid2)));
+        client.addPreviousActiveTasks(Utils.mkSet(taskId01, taskId02));
+        assertThat(client.prevActiveTasks(), equalTo(Utils.mkSet(taskId01, taskId02)));
+        assertThat(client.previousAssignedTasks(), equalTo(Utils.mkSet(taskId01, taskId02)));
     }
 
     @Test
-    public void shouldAddPreviousStandbyTasksToPreviousAssigned() {
-        final TaskId tid1 = new TaskId(0, 1);
-        final TaskId tid2 = new TaskId(0, 2);
-
-        client.addPreviousStandbyTasks(Utils.mkSet(tid1, tid2));
+    public void shouldAddPreviousStandbyTasksToPreviousAssignedAndPreviousStandby() {
+        client.addPreviousStandbyTasks(Utils.mkSet(taskId01, taskId02));
+        assertThat(client.prevStandbyTasks(), equalTo(Utils.mkSet(taskId01, taskId02)));
         assertThat(client.prevActiveTasks().size(), equalTo(0));
-        assertThat(client.previousAssignedTasks(), equalTo(Utils.mkSet(tid1, tid2)));
+        assertThat(client.previousAssignedTasks(), equalTo(Utils.mkSet(taskId01, taskId02)));
     }
 
     @Test
     public void shouldHaveAssignedTaskIfActiveTaskAssigned() {
-        final TaskId tid = new TaskId(0, 2);
-
-        client.assign(tid, true);
-        assertTrue(client.hasAssignedTask(tid));
+        client.assignActive(taskId01);
+        assertTrue(client.hasAssignedTask(taskId01));
     }
 
     @Test
     public void shouldHaveAssignedTaskIfStandbyTaskAssigned() {
-        final TaskId tid = new TaskId(0, 2);
-
-        client.assign(tid, false);
-        assertTrue(client.hasAssignedTask(tid));
+        client.assignStandby(taskId01);
+        assertTrue(client.hasAssignedTask(taskId01));
     }
 
     @Test
     public void shouldNotHaveAssignedTaskIfTaskNotAssigned() {
-
-        client.assign(new TaskId(0, 2), true);
-        assertFalse(client.hasAssignedTask(new TaskId(0, 3)));
+        client.assignActive(taskId01);
+        assertFalse(client.hasAssignedTask(taskId02));
     }
 
     @Test
     public void shouldHaveMoreAvailableCapacityWhenCapacityTheSameButFewerAssignedTasks() {
-        final ClientState c2 = new ClientState(1);
-        client.assign(new TaskId(0, 1), true);
-        assertTrue(c2.hasMoreAvailableCapacityThan(client));
-        assertFalse(client.hasMoreAvailableCapacityThan(c2));
+        final ClientState otherClient = new ClientState(1);
+        client.assignActive(taskId01);
+        assertTrue(otherClient.hasMoreAvailableCapacityThan(client));
+        assertFalse(client.hasMoreAvailableCapacityThan(otherClient));
     }
 
     @Test
     public void shouldHaveMoreAvailableCapacityWhenCapacityHigherAndSameAssignedTaskCount() {
-        final ClientState c2 = new ClientState(2);
-        assertTrue(c2.hasMoreAvailableCapacityThan(client));
-        assertFalse(client.hasMoreAvailableCapacityThan(c2));
+        final ClientState otherClient = new ClientState(2);
+        assertTrue(otherClient.hasMoreAvailableCapacityThan(client));
+        assertFalse(client.hasMoreAvailableCapacityThan(otherClient));
     }
 
     @Test
     public void shouldUseMultiplesOfCapacityToDetermineClientWithMoreAvailableCapacity() {
-        final ClientState c2 = new ClientState(2);
+        final ClientState otherClient = new ClientState(2);
 
         for (int i = 0; i < 7; i++) {
-            c2.assign(new TaskId(0, i), true);
+            otherClient.assignActive(new TaskId(0, i));
         }
 
         for (int i = 7; i < 11; i++) {
-            client.assign(new TaskId(0, i), true);
+            client.assignActive(new TaskId(0, i));
         }
 
-        assertTrue(c2.hasMoreAvailableCapacityThan(client));
+        assertTrue(otherClient.hasMoreAvailableCapacityThan(client));
     }
 
     @Test
     public void shouldHaveMoreAvailableCapacityWhenCapacityIsTheSameButAssignedTasksIsLess() {
-        final ClientState c1 = new ClientState(3);
-        final ClientState c2 = new ClientState(3);
+        final ClientState client = new ClientState(3);
+        final ClientState otherClient = new ClientState(3);
         for (int i = 0; i < 4; i++) {
-            c1.assign(new TaskId(0, i), true);
-            c2.assign(new TaskId(0, i), true);
+            client.assignActive(new TaskId(0, i));
+            otherClient.assignActive(new TaskId(0, i));
         }
-        c2.assign(new TaskId(0, 5), true);
-        assertTrue(c1.hasMoreAvailableCapacityThan(c2));
+        otherClient.assignActive(new TaskId(0, 5));
+        assertTrue(client.hasMoreAvailableCapacityThan(otherClient));
     }
 
-    @Test(expected = IllegalStateException.class)
+    @Test
     public void shouldThrowIllegalStateExceptionIfCapacityOfThisClientStateIsZero() {
-        final ClientState c1 = new ClientState(0);
-        c1.hasMoreAvailableCapacityThan(new ClientState(1));
+        assertThrows(IllegalStateException.class, () -> zeroCapacityClient.hasMoreAvailableCapacityThan(client));
     }
 
-    @Test(expected = IllegalStateException.class)
+    @Test
     public void shouldThrowIllegalStateExceptionIfCapacityOfOtherClientStateIsZero() {
-        final ClientState c1 = new ClientState(1);
-        c1.hasMoreAvailableCapacityThan(new ClientState(0));
+        assertThrows(IllegalStateException.class, () -> client.hasMoreAvailableCapacityThan(zeroCapacityClient));
     }
 
     @Test
     public void shouldHaveUnfulfilledQuotaWhenActiveTaskSizeLessThanCapacityTimesTasksPerThread() {
-        final ClientState client = new ClientState(1);
-        client.assign(new TaskId(0, 1), true);
+        client.assignActive(new TaskId(0, 1));
         assertTrue(client.hasUnfulfilledQuota(2));
     }
 
     @Test
     public void shouldNotHaveUnfulfilledQuotaWhenActiveTaskSizeGreaterEqualThanCapacityTimesTasksPerThread() {
-        final ClientState client = new ClientState(1);
-        client.assign(new TaskId(0, 1), true);
+        client.assignActive(new TaskId(0, 1));
         assertFalse(client.hasUnfulfilledQuota(1));
     }
 
+    @Test
+    public void shouldAddTasksWithLatestOffsetToPrevActiveTasks() {
+        final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(taskId01, Task.LATEST_OFFSET);
+        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        assertThat(client.prevActiveTasks(), equalTo(Collections.singleton(taskId01)));
+        assertThat(client.previousAssignedTasks(), equalTo(Collections.singleton(taskId01)));
+        assertTrue(client.prevStandbyTasks().isEmpty());
+    }
+
+    @Test
+    public void shouldAddTasksInOffsetSumsMapToPrevStandbyTasks() {
+        final Map<TaskId, Long> taskOffsetSums = mkMap(
+            mkEntry(taskId01, 0L),
+            mkEntry(taskId02, 100L)
+        );
+        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        assertThat(client.prevStandbyTasks(), equalTo(mkSet(taskId01, taskId02)));
+        assertThat(client.previousAssignedTasks(), equalTo(mkSet(taskId01, taskId02)));
+        assertTrue(client.prevActiveTasks().isEmpty());
+    }
+
+    @Test
+    public void shouldComputeTaskLags() {
+        final Map<TaskId, Long> taskOffsetSums = mkMap(
+            mkEntry(taskId01, 0L),
+            mkEntry(taskId02, 100L)
+        );
+        final Map<TaskId, Long> allTaskEndOffsetSums = mkMap(
+            mkEntry(taskId01, 500L),
+            mkEntry(taskId02, 100L)
+        );
+        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.computeTaskLags(allTaskEndOffsetSums);
+
+        assertThat(client.lagFor(taskId01), equalTo(500L));
+        assertThat(client.lagFor(taskId02), equalTo(0L));
+    }
+
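+    // Taken together, the lag tests above and below pin down the semantics of
+    // computeTaskLags/lagFor: lag(task) = endOffsetSum - prevOffsetSum, except that the
+    // sentinels Task.LATEST_OFFSET (a previously running active task) and UNKNOWN_OFFSET_SUM
+    // pass through unchanged, and a task this client never owned reports the full
+    // endOffsetSum as its lag.
+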
+    @Test
+    public void shouldReturnEndOffsetSumForLagOfTaskWeDidNotPreviouslyOwn() {
+        final Map<TaskId, Long> taskOffsetSums = Collections.emptyMap();
+        final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(taskId01, 500L);
+        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.computeTaskLags(allTaskEndOffsetSums);
+        assertThat(client.lagFor(taskId01), equalTo(500L));
+    }
+
+    @Test
+    public void shouldReturnLatestOffsetForLagOfPreviousActiveRunningTask() {
+        final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(taskId01, Task.LATEST_OFFSET);
+        final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(taskId01, 500L);
+        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.computeTaskLags(allTaskEndOffsetSums);
+        assertThat(client.lagFor(taskId01), equalTo(Task.LATEST_OFFSET));
+    }
+
+    @Test
+    public void shouldReturnUnknownOffsetSumForLagOfTaskWithUnknownOffset() {
+        final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(taskId01, UNKNOWN_OFFSET_SUM);
+        final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(taskId01, 500L);
+        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.computeTaskLags(allTaskEndOffsetSums);
+        assertThat(client.lagFor(taskId01), equalTo(UNKNOWN_OFFSET_SUM));
+    }
+
+    @Test
+    public void shouldThrowIllegalStateExceptionIfOffsetSumIsGreaterThanEndOffsetSum() {
+        final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(taskId01, 5L);
+        final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(taskId01, 1L);
+        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        assertThrows(IllegalStateException.class, () -> client.computeTaskLags(allTaskEndOffsetSums));
+    }
+
+    @Test
+    public void shouldThrowIllegalStateExceptionIfTaskLagsMapIsNotEmpty() {
+        final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(taskId01, 5L);
+        final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(taskId01, 1L);
+        client.computeTaskLags(taskOffsetSums);
+        assertThrows(IllegalStateException.class, () -> client.computeTaskLags(allTaskEndOffsetSums));
+    }
+
+    @Test
+    public void shouldThrowIllegalStateExceptionOnLagForUnknownTask() {
+        final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(taskId01, 0L);
+        final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(taskId01, 500L);
+        client.addPreviousTasksAndOffsetSums(taskOffsetSums);
+        client.computeTaskLags(allTaskEndOffsetSums);
+        assertThrows(IllegalStateException.class, () -> client.lagFor(taskId02));
+    }
+
 }
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
index 3fb284a..47254bb 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
@@ -674,6 +674,19 @@ public class StickyTaskAssignorTest {
         assertThat(newClient.activeTaskCount(), equalTo(2));
     }
 
+    @Test
+    public void shouldViolateBalanceToPreserveActiveTaskStickiness() {
+        final ClientState c1 = createClientWithPreviousActiveTasks(p1, 1, task00, task01, task02);
+        final ClientState c2 = createClient(p2, 1);
+
+        final StickyTaskAssignor<Integer> taskAssignor = createTaskAssignor(task00, task01, task02);
+        taskAssignor.preservePreviousTaskAssignment();
+        taskAssignor.assign(0);
+
+        assertThat(c1.activeTasks(), equalTo(Utils.mkSet(task00, task01, task02)));
+        assertTrue(c2.activeTasks().isEmpty());
+    }
+
     private StickyTaskAssignor<Integer> createTaskAssignor(final TaskId... tasks) {
         final List<TaskId> taskIds = Arrays.asList(tasks);
         Collections.shuffle(taskIds);