You are viewing a plain-text version of this content. (The link to the canonical version was lost in this plain-text rendering.)
Posted to commits@kafka.apache.org by vv...@apache.org on 2020/04/28 20:58:02 UTC

[kafka] branch trunk updated: KAFKA-6145: KIP-441: Add TaskAssignor class config (#8541)

This is an automated email from the ASF dual-hosted git repository.

vvcephei pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 5bb3415  KAFKA-6145: KIP-441: Add TaskAssignor class config (#8541)
5bb3415 is described below

commit 5bb3415c77cc61b7d1591ccfe028d10bbf9f2a7a
Author: John Roesler <vv...@users.noreply.github.com>
AuthorDate: Tue Apr 28 15:57:11 2020 -0500

    KAFKA-6145: KIP-441: Add TaskAssignor class config (#8541)
    
    * add a config to set the TaskAssignor
    * set the default assignor to HighAvailabilityTaskAssignor
    * fix broken tests (with some TODOs in the system tests)
    
    Implements: KIP-441
    Reviewers: Bruno Cadonna <br...@confluent.io>, A. Sophie Blee-Goldman <so...@confluent.io>
---
 build.gradle                                       |   4 +-
 .../java/org/apache/kafka/common/utils/Utils.java  |   9 +
 .../org/apache/kafka/common/utils/UtilsTest.java   |  15 +
 .../test/java/org/apache/kafka/test/TestUtils.java |   4 +-
 .../org/apache/kafka/streams/StreamsConfig.java    |   5 +
 .../internals/StreamsPartitionAssignor.java        |  62 +--
 .../assignment/AssignorConfiguration.java          |  24 +-
 .../internals/assignment/ClientState.java          |  22 +-
 .../assignment/FallbackPriorTaskAssignor.java      |  49 ++
 .../assignment/HighAvailabilityTaskAssignor.java   |  60 +--
 .../internals/assignment/StickyTaskAssignor.java   |  36 +-
 .../internals/assignment/TaskAssignor.java         |  13 +-
 .../streams/integration/EosIntegrationTest.java    | 121 +++--
 .../integration/LagFetchIntegrationTest.java       |  47 +-
 ...ghAvailabilityStreamsPartitionAssignorTest.java | 332 ++++++++++++++
 .../internals/StreamsPartitionAssignorTest.java    | 162 ++-----
 .../internals/assignment/ClientStateTest.java      |  33 +-
 .../assignment/FallbackPriorTaskAssignorTest.java  |  74 ++++
 .../HighAvailabilityTaskAssignorTest.java          | 491 ++++++++++-----------
 .../assignment/StickyTaskAssignorTest.java         | 325 +++++++-------
 .../assignment/TaskAssignorConvergenceTest.java    |   7 +-
 tests/kafkatest/services/streams.py                |  12 +-
 .../streams/streams_broker_down_resilience_test.py |  12 +-
 .../tests/streams/streams_standby_replica_test.py  |  11 +-
 .../tests/streams/streams_upgrade_test.py          |   4 +
 25 files changed, 1233 insertions(+), 701 deletions(-)

diff --git a/build.gradle b/build.gradle
index 644d3eb..b8a8862 100644
--- a/build.gradle
+++ b/build.gradle
@@ -236,8 +236,10 @@ subprojects {
     def logStreams = new HashMap<String, FileOutputStream>()
     beforeTest { TestDescriptor td ->
       def tid = testId(td)
+      // truncate the file name if it's too long
       def logFile = new File(
-          "${projectDir}/build/reports/testOutput/${tid}.test.stdout")
+              "${projectDir}/build/reports/testOutput/${tid.substring(0, Math.min(tid.size(),240))}.test.stdout"
+      )
       logFile.parentFile.mkdirs()
       logFiles.put(tid, logFile)
       logStreams.put(tid, new FileOutputStream(logFile))
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
index ee627c9..87c749a 100755
--- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
@@ -1146,4 +1146,13 @@ public final class Utils {
             }
         };
     }
+
+    @SafeVarargs
+    public static <E> Set<E> union(final Supplier<Set<E>> constructor, final Set<E>... set) {
+        final Set<E> result = constructor.get();
+        for (final Set<E> s : set) {
+            result.addAll(s);
+        }
+        return result;
+    }
 }
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java
index c0e5fc8..0744f77 100755
--- a/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java
@@ -37,6 +37,7 @@ import java.util.Map;
 import java.util.Properties;
 import java.util.Random;
 import java.util.Set;
+import java.util.TreeSet;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -47,7 +48,11 @@ import static org.apache.kafka.common.utils.Utils.getHost;
 import static org.apache.kafka.common.utils.Utils.getPort;
 import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.common.utils.Utils.murmur2;
+import static org.apache.kafka.common.utils.Utils.union;
 import static org.apache.kafka.common.utils.Utils.validHostPattern;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -582,4 +587,14 @@ public class UtilsTest {
         } catch (IllegalArgumentException e) {
         }
     }
+
+    @Test
+    public void testUnion() {
+        final Set<String> oneSet = mkSet("a", "b", "c");
+        final Set<String> anotherSet = mkSet("c", "d", "e");
+        final Set<String> union = union(TreeSet::new, oneSet, anotherSet);
+
+        assertThat(union, is(mkSet("a", "b", "c", "d", "e")));
+        assertThat(union.getClass(), equalTo(TreeSet.class));
+    }
 }
diff --git a/clients/src/test/java/org/apache/kafka/test/TestUtils.java b/clients/src/test/java/org/apache/kafka/test/TestUtils.java
index f9be363..23fd5ed 100644
--- a/clients/src/test/java/org/apache/kafka/test/TestUtils.java
+++ b/clients/src/test/java/org/apache/kafka/test/TestUtils.java
@@ -375,9 +375,9 @@ public class TestUtils {
      * avoid transient failures due to slow or overloaded machines.
      */
     public static void waitForCondition(final TestCondition testCondition, final long maxWaitMs, Supplier<String> conditionDetailsSupplier) throws InterruptedException {
-        String conditionDetailsSupplied = conditionDetailsSupplier != null ? conditionDetailsSupplier.get() : null;
-        String conditionDetails = conditionDetailsSupplied != null ? conditionDetailsSupplied : "";
         retryOnExceptionWithTimeout(maxWaitMs, () -> {
+            String conditionDetailsSupplied = conditionDetailsSupplier != null ? conditionDetailsSupplier.get() : null;
+            String conditionDetails = conditionDetailsSupplied != null ? conditionDetailsSupplied : "";
             assertThat("Condition not met within timeout " + maxWaitMs + ". " + conditionDetails,
                 testCondition.conditionMet());
         });
diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
index c441dee..1df6839 100644
--- a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
+++ b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
@@ -868,6 +868,11 @@ public class StreamsConfig extends AbstractConfig {
     }
 
     public static class InternalConfig {
+        // This is settable in the main Streams config, but it's a private API for now
+        public static final String INTERNAL_TASK_ASSIGNOR_CLASS = "internal.task.assignor.class";
+
+        // These are not settable in the main Streams config; they are set by the StreamThread to pass internal
+        // state into the assignor.
         public static final String TASK_MANAGER_FOR_PARTITION_ASSIGNOR = "__task.manager.instance__";
         public static final String STREAMS_METADATA_STATE_FOR_PARTITION_ASSIGNOR = "__streams.metadata.state.instance__";
         public static final String STREAMS_ADMIN_CLIENT = "__streams.admin.client.instance__";
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
index d285e31..666da21 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
@@ -39,7 +39,7 @@ import org.apache.kafka.streams.processor.internals.assignment.AssignorConfigura
 import org.apache.kafka.streams.processor.internals.assignment.AssignorError;
 import org.apache.kafka.streams.processor.internals.assignment.ClientState;
 import org.apache.kafka.streams.processor.internals.assignment.CopartitionedTopicsEnforcer;
-import org.apache.kafka.streams.processor.internals.assignment.HighAvailabilityTaskAssignor;
+import org.apache.kafka.streams.processor.internals.assignment.FallbackPriorTaskAssignor;
 import org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor;
 import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
 import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor;
@@ -64,6 +64,7 @@ import java.util.TreeMap;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 import static java.util.UUID.randomUUID;
@@ -171,7 +172,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
     private CopartitionedTopicsEnforcer copartitionedTopicsEnforcer;
     private RebalanceProtocol rebalanceProtocol;
 
-    private boolean highAvailabilityEnabled;
+    private Supplier<TaskAssignor> taskAssignorSupplier;
 
     /**
      * We need to have the PartitionAssignor and its StreamThread to be mutually accessible since the former needs
@@ -201,7 +202,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         internalTopicManager = assignorConfiguration.getInternalTopicManager();
         copartitionedTopicsEnforcer = assignorConfiguration.getCopartitionedTopicsEnforcer();
         rebalanceProtocol = assignorConfiguration.rebalanceProtocol();
-        highAvailabilityEnabled = assignorConfiguration.isHighAvailabilityEnabled();
+        taskAssignorSupplier = assignorConfiguration::getTaskAssignor;
     }
 
     @Override
@@ -361,7 +362,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         final Map<TaskId, Set<TopicPartition>> partitionsForTask =
             partitionGrouper.partitionGroups(sourceTopicsByGroup, fullMetadata);
 
-        final boolean followupRebalanceNeeded =
+        final boolean probingRebalanceNeeded =
             assignTasksToClients(allSourceTopics, partitionsForTask, topicGroups, clientMetadataMap, fullMetadata);
 
         // ---------------- Step Three ---------------- //
@@ -399,7 +400,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                 allOwnedPartitions,
                 minReceivedMetadataVersion,
                 minSupportedMetadataVersion,
-                followupRebalanceNeeded
+                probingRebalanceNeeded
             );
         }
 
@@ -688,7 +689,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
     /**
      * Assigns a set of tasks to each client (Streams instance) using the configured task assignor
-     * @return true if a followup rebalance should be triggered
+     * @return true if a probing rebalance should be triggered
      */
     private boolean assignTasksToClients(final Set<String> allSourceTopics,
                                          final Map<TaskId, Set<TopicPartition>> partitionsForTask,
@@ -712,29 +713,32 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         log.debug("Assigning tasks {} to clients {} with number of replicas {}",
             allTasks, clientStates, numStandbyReplicas());
 
-        final TaskAssignor taskAssignor;
-        if (highAvailabilityEnabled) {
-            if (lagComputationSuccessful) {
-                taskAssignor = new HighAvailabilityTaskAssignor(
-                    clientStates,
-                    allTasks,
-                    statefulTasks,
-                    assignmentConfigs);
-            } else {
-                log.info("Failed to fetch end offsets for changelogs, will return previous assignment to clients and "
-                             + "trigger another rebalance to retry.");
-                setAssignmentErrorCode(AssignorError.REBALANCE_NEEDED.code());
-                taskAssignor = new StickyTaskAssignor(clientStates, allTasks, statefulTasks, assignmentConfigs, true);
-            }
-        } else {
-            taskAssignor = new StickyTaskAssignor(clientStates, allTasks, statefulTasks, assignmentConfigs, false);
-        }
-        final boolean followupRebalanceNeeded = taskAssignor.assign();
+        final TaskAssignor taskAssignor = createTaskAssignor(lagComputationSuccessful);
+
+        final boolean probingRebalanceNeeded = taskAssignor.assign(clientStates,
+                                                                   allTasks,
+                                                                   statefulTasks,
+                                                                   assignmentConfigs);
 
         log.info("Assigned tasks to clients as {}{}.",
             Utils.NL, clientStates.entrySet().stream().map(Map.Entry::toString).collect(Collectors.joining(Utils.NL)));
 
-        return followupRebalanceNeeded;
+        return probingRebalanceNeeded;
+    }
+
+    private TaskAssignor createTaskAssignor(final boolean lagComputationSuccessful) {
+        final TaskAssignor taskAssignor = taskAssignorSupplier.get();
+        if (taskAssignor instanceof StickyTaskAssignor) {
+            // special case: to preserve pre-existing behavior, we invoke the StickyTaskAssignor
+            // whether or not lag computation failed.
+            return taskAssignor;
+        } else if (lagComputationSuccessful) {
+            return taskAssignor;
+        } else {
+            log.info("Failed to fetch end offsets for changelogs, will return previous assignment to clients and "
+                         + "trigger another rebalance to retry.");
+            return new FallbackPriorTaskAssignor();
+        }
     }
 
     /**
@@ -968,9 +972,9 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                                       final int minUserMetadataVersion,
                                       final int minSupportedMetadataVersion,
                                       final boolean versionProbing,
-                                      final boolean followupRebalanceNeeded) {
-        boolean encodeNextRebalanceTime = followupRebalanceNeeded;
-        boolean stableAssignment = !followupRebalanceNeeded && !versionProbing;
+                                      final boolean probingRebalanceNeeded) {
+        boolean encodeNextRebalanceTime = probingRebalanceNeeded;
+        boolean stableAssignment = !probingRebalanceNeeded && !versionProbing;
 
         // Loop through the consumers and build their assignment
         for (final String consumer : clientMetadata.consumers) {
@@ -1025,7 +1029,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
         if (stableAssignment) {
             log.info("Finished stable assignment of tasks, no followup rebalances required.");
         } else {
-            log.info("Finished unstable assignment of tasks, a followup rebalance will be triggered.");
+            log.info("Finished unstable assignment of tasks, a followup probing rebalance will be triggered.");
         }
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
index d12640a..2a5d1d4 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import java.util.concurrent.atomic.AtomicLong;
 import org.apache.kafka.clients.CommonClientConfigs;
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.admin.AdminClientConfig;
@@ -25,6 +24,7 @@ import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.config.ConfigException;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsConfig.InternalConfig;
 import org.apache.kafka.streams.internals.QuietStreamsConfig;
@@ -35,14 +35,15 @@ import org.slf4j.Logger;
 
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
 
 import static org.apache.kafka.common.utils.Utils.getHost;
 import static org.apache.kafka.common.utils.Utils.getPort;
+import static org.apache.kafka.streams.StreamsConfig.InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION;
 
 public final class AssignorConfiguration {
-    public static final String HIGH_AVAILABILITY_ENABLED_CONFIG = "internal.high.availability.enabled";
-    private final boolean highAvailabilityEnabled;
+    private final String taskAssignorClass;
 
     private final String logPrefix;
     private final Logger log;
@@ -162,11 +163,11 @@ public final class AssignorConfiguration {
         copartitionedTopicsEnforcer = new CopartitionedTopicsEnforcer(logPrefix);
 
         {
-            final Object o = configs.get(HIGH_AVAILABILITY_ENABLED_CONFIG);
+            final String o = (String) configs.get(INTERNAL_TASK_ASSIGNOR_CLASS);
             if (o == null) {
-                highAvailabilityEnabled = false;
+                taskAssignorClass = HighAvailabilityTaskAssignor.class.getName();
             } else {
-                highAvailabilityEnabled = (Boolean) o;
+                taskAssignorClass = o;
             }
         }
     }
@@ -328,8 +329,15 @@ public final class AssignorConfiguration {
         return assignmentConfigs;
     }
 
-    public boolean isHighAvailabilityEnabled() {
-        return highAvailabilityEnabled;
+    public TaskAssignor getTaskAssignor() {
+        try {
+            return Utils.newInstance(taskAssignorClass, TaskAssignor.class);
+        } catch (final ClassNotFoundException e) {
+            throw new IllegalArgumentException(
+                "Expected an instantiable class name for " + INTERNAL_TASK_ASSIGNOR_CLASS,
+                e
+            );
+        }
     }
 
     public static class AssignmentConfigs {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
index 5b8857c..3f64592 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
@@ -29,6 +29,10 @@ import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
 
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.unmodifiableMap;
+import static java.util.Collections.unmodifiableSet;
+import static org.apache.kafka.common.utils.Utils.union;
 import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
 
 public class ClientState {
@@ -86,6 +90,22 @@ public class ClientState {
         this.capacity = capacity;
     }
 
+    public ClientState(final Set<TaskId> previousActiveTasks,
+                       final Set<TaskId> previousStandbyTasks,
+                       final Map<TaskId, Long> taskLagTotals,
+                       final int capacity) {
+        activeTasks = new HashSet<>();
+        standbyTasks = new HashSet<>();
+        assignedTasks = new HashSet<>();
+        prevActiveTasks = unmodifiableSet(new HashSet<>(previousActiveTasks));
+        prevStandbyTasks = unmodifiableSet(new HashSet<>(previousStandbyTasks));
+        prevAssignedTasks = unmodifiableSet(union(HashSet::new, previousActiveTasks, previousStandbyTasks));
+        ownedPartitions = emptyMap();
+        taskOffsetSums = emptyMap();
+        this.taskLagTotals = unmodifiableMap(taskLagTotals);
+        this.capacity = capacity;
+    }
+
     public ClientState copy() {
         return new ClientState(
             new HashSet<>(activeTasks),
@@ -258,7 +278,7 @@ public class ClientState {
     }
 
     boolean hasMoreAvailableCapacityThan(final ClientState other) {
-        if (this.capacity <= 0) {
+        if (capacity <= 0) {
             throw new IllegalStateException("Capacity of this ClientState must be greater than 0.");
         }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java
new file mode 100644
index 0000000..b17b25c
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
+
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+
+/**
+ * A special task assignor implementation to be used as a fallback in case the
+ * configured assignor couldn't be invoked.
+ *
+ * Specifically, this assignor must:
+ * 1. ignore the task lags in the ClientState map
+ * 2. always return true, indicating that a follow-up rebalance is needed
+ */
+public class FallbackPriorTaskAssignor implements TaskAssignor {
+    private final StickyTaskAssignor delegate;
+
+    public FallbackPriorTaskAssignor() {
+        delegate = new StickyTaskAssignor(true);
+    }
+
+    @Override
+    public boolean assign(final Map<UUID, ClientState> clients,
+                          final Set<TaskId> allTaskIds,
+                          final Set<TaskId> standbyTaskIds,
+                          final AssignmentConfigs configs) {
+        delegate.assign(clients, allTaskIds, standbyTaskIds, configs);
+        return true;
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
index b1570fb..3253c26 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
@@ -16,49 +16,50 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentUtils.taskIsCaughtUpOnClientOrNoCaughtUpClientsExist;
-import static org.apache.kafka.streams.processor.internals.assignment.RankedClient.buildClientRankingsByTask;
-import static org.apache.kafka.streams.processor.internals.assignment.RankedClient.tasksToCaughtUpClients;
-import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.assignTaskMovements;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
 import java.util.SortedMap;
 import java.util.SortedSet;
 import java.util.TreeSet;
 import java.util.UUID;
 import java.util.stream.Collectors;
-import org.apache.kafka.streams.processor.TaskId;
-import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
-import java.util.Map;
-import java.util.Set;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentUtils.taskIsCaughtUpOnClientOrNoCaughtUpClientsExist;
+import static org.apache.kafka.streams.processor.internals.assignment.RankedClient.buildClientRankingsByTask;
+import static org.apache.kafka.streams.processor.internals.assignment.RankedClient.tasksToCaughtUpClients;
+import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.assignTaskMovements;
 
 public class HighAvailabilityTaskAssignor implements TaskAssignor {
     private static final Logger log = LoggerFactory.getLogger(HighAvailabilityTaskAssignor.class);
 
-    private final Map<UUID, ClientState> clientStates;
-    private final Map<UUID, Integer> clientsToNumberOfThreads;
-    private final SortedSet<UUID> sortedClients;
+    private Map<UUID, ClientState> clientStates;
+    private Map<UUID, Integer> clientsToNumberOfThreads;
+    private SortedSet<UUID> sortedClients;
 
-    private final Set<TaskId> allTasks;
-    private final SortedSet<TaskId> statefulTasks;
-    private final SortedSet<TaskId> statelessTasks;
+    private Set<TaskId> allTasks;
+    private SortedSet<TaskId> statefulTasks;
+    private SortedSet<TaskId> statelessTasks;
 
-    private final AssignmentConfigs configs;
+    private AssignmentConfigs configs;
 
-    private final SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates;
-    private final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients;
+    private SortedMap<TaskId, SortedSet<RankedClient>> statefulTasksToRankedCandidates;
+    private Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients;
 
-    public HighAvailabilityTaskAssignor(final Map<UUID, ClientState> clientStates,
-                                        final Set<TaskId> allTasks,
-                                        final Set<TaskId> statefulTasks,
-                                        final AssignmentConfigs configs) {
+    @Override
+    public boolean assign(final Map<UUID, ClientState> clientStates,
+                          final Set<TaskId> allTasks,
+                          final Set<TaskId> statefulTasks,
+                          final AssignmentConfigs configs) {
         this.configs = configs;
         this.clientStates = clientStates;
         this.allTasks = allTasks;
@@ -77,10 +78,8 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
         statefulTasksToRankedCandidates =
             buildClientRankingsByTask(statefulTasks, clientStates, configs.acceptableRecoveryLag);
         tasksToCaughtUpClients = tasksToCaughtUpClients(statefulTasksToRankedCandidates);
-    }
 
-    @Override
-    public boolean assign() {
+
         if (shouldUsePreviousAssignment()) {
             assignPreviousTasksToClientStates();
             return false;
@@ -89,13 +88,18 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
         final Map<TaskId, Integer> tasksToRemainingStandbys =
             statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> configs.numStandbyReplicas));
 
-        final boolean followupRebalanceNeeded = assignStatefulActiveTasks(tasksToRemainingStandbys);
+        final boolean probingRebalanceNeeded = assignStatefulActiveTasks(tasksToRemainingStandbys);
 
         assignStandbyReplicaTasks(tasksToRemainingStandbys);
 
         assignStatelessActiveTasks();
 
-        return followupRebalanceNeeded;
+        log.info("Decided on assignment: " +
+                     clientStates +
+                     " with " +
+                     (probingRebalanceNeeded ? "" : "no") +
+                     " followup probing rebalance.");
+        return probingRebalanceNeeded;
     }
 
     private boolean assignStatefulActiveTasks(final Map<TaskId, Integer> tasksToRemainingStandbys) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
index 2f2c77e..50b9381 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import java.util.UUID;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
 import org.slf4j.Logger;
@@ -32,40 +31,43 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.UUID;
 
 public class StickyTaskAssignor implements TaskAssignor {
 
     private static final Logger log = LoggerFactory.getLogger(StickyTaskAssignor.class);
-    private final Map<UUID, ClientState> clients;
-    private final Set<TaskId> allTaskIds;
-    private final Set<TaskId> standbyTaskIds;
+    private Map<UUID, ClientState> clients;
+    private Set<TaskId> allTaskIds;
+    private Set<TaskId> standbyTaskIds;
     private final Map<TaskId, UUID> previousActiveTaskAssignment = new HashMap<>();
     private final Map<TaskId, Set<UUID>> previousStandbyTaskAssignment = new HashMap<>();
-    private final TaskPairs taskPairs;
-    private final int numStandbyReplicas;
+    private TaskPairs taskPairs;
 
     private final boolean mustPreserveActiveTaskAssignment;
 
-    public StickyTaskAssignor(final Map<UUID, ClientState> clients,
-                              final Set<TaskId> allTaskIds,
-                              final Set<TaskId> standbyTaskIds,
-                              final AssignmentConfigs configs,
-                              final boolean mustPreserveActiveTaskAssignment) {
+    public StickyTaskAssignor() {
+        this(false);
+    }
+
+    StickyTaskAssignor(final boolean mustPreserveActiveTaskAssignment) {
+        this.mustPreserveActiveTaskAssignment = mustPreserveActiveTaskAssignment;
+    }
+
+    @Override
+    public boolean assign(final Map<UUID, ClientState> clients,
+                          final Set<TaskId> allTaskIds,
+                          final Set<TaskId> standbyTaskIds,
+                          final AssignmentConfigs configs) {
         this.clients = clients;
         this.allTaskIds = allTaskIds;
         this.standbyTaskIds = standbyTaskIds;
-        numStandbyReplicas = configs.numStandbyReplicas;
-        this.mustPreserveActiveTaskAssignment = mustPreserveActiveTaskAssignment;
 
         final int maxPairs = allTaskIds.size() * (allTaskIds.size() - 1) / 2;
         taskPairs = new TaskPairs(maxPairs);
         mapPreviousTaskAssignment(clients);
-    }
 
-    @Override
-    public boolean assign() {
         assignActive();
-        assignStandby(numStandbyReplicas);
+        assignStandby(configs.numStandbyReplicas);
         return false;
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
index cbecc24..485bd81 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
@@ -16,9 +16,18 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import org.apache.kafka.streams.processor.TaskId;
+
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+
 public interface TaskAssignor {
     /**
-     * @return whether the generated assignment requires a followup rebalance to satisfy all conditions
+     * @return whether the generated assignment requires a followup probing rebalance to satisfy all conditions
      */
-    boolean assign();
+    boolean assign(Map<UUID, ClientState> clients, // per-client state keyed by UUID; NOTE(review): presumably the assignor records its result into these (the return value is only a flag) -- confirm
+                   Set<TaskId> allTaskIds, // the full task set to distribute
+                   Set<TaskId> standbyTaskIds, // tasks that may also receive standby replicas (assumed to be the stateful subset -- confirm)
+                   AssignorConfiguration.AssignmentConfigs configs); // assignment settings, e.g. numStandbyReplicas
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
index 747ac40..e8ab668 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
@@ -59,6 +59,8 @@ import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameter;
 import org.junit.runners.Parameterized.Parameters;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.File;
 import java.util.ArrayList;
@@ -74,6 +76,7 @@ import java.util.Properties;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
@@ -89,6 +92,7 @@ import static org.junit.Assert.fail;
 @RunWith(Parameterized.class)
 @Category({IntegrationTest.class})
 public class EosIntegrationTest {
+    private static final Logger LOG = LoggerFactory.getLogger(EosIntegrationTest.class);
     private static final int NUM_BROKERS = 3;
     private static final int MAX_POLL_INTERVAL_MS = 5 * 1000;
     private static final int MAX_WAIT_TIME_MS = 60 * 1000;
@@ -111,8 +115,9 @@ public class EosIntegrationTest {
     private final String storeName = "store";
 
     private AtomicBoolean errorInjected;
-    private AtomicBoolean gcInjected;
-    private volatile boolean doGC = true;
+    private AtomicBoolean stallInjected;
+    private AtomicReference<String> stallingHost;
+    private volatile boolean doStall = true;
     private AtomicInteger commitRequested;
     private Throwable uncaughtException;
 
@@ -382,7 +387,7 @@ public class EosIntegrationTest {
         // -> the failure only kills one thread
         // after fail over, we should read 40 committed records (even if 50 record got written)
 
-        try (final KafkaStreams streams = getKafkaStreams(false, "appDir", 2, eosConfig)) {
+        try (final KafkaStreams streams = getKafkaStreams("dummy", false, "appDir", 2, eosConfig)) {
             startKafkaStreamsAndWaitForRunningState(streams, MAX_WAIT_TIME_MS);
 
             final List<KeyValue<Long, Long>> committedDataBeforeFailure = prepareData(0L, 10L, 0L, 1L);
@@ -450,7 +455,7 @@ public class EosIntegrationTest {
         // after fail over, we should read 40 committed records and the state stores should contain the correct sums
         // per key (even if some records got processed twice)
 
-        try (final KafkaStreams streams = getKafkaStreams(true, "appDir", 2, eosConfig)) {
+        try (final KafkaStreams streams = getKafkaStreams("dummy", true, "appDir", 2, eosConfig)) {
             startKafkaStreamsAndWaitForRunningState(streams, MAX_WAIT_TIME_MS);
 
             final List<KeyValue<Long, Long>> committedDataBeforeFailure = prepareData(0L, 10L, 0L, 1L);
@@ -515,84 +520,114 @@ public class EosIntegrationTest {
         // the app is supposed to copy all 60 records into the output topic
         // the app commits after each 10 records per partition, and thus will have 2*5 uncommitted writes
         //
-        // a GC pause gets inject after 20 committed and 30 uncommitted records got received
-        // -> the GC pause only affects one thread and should trigger a rebalance
+        // a stall gets injected after 20 committed and 30 uncommitted records got received
+        // -> the stall only affects one thread and should trigger a rebalance
         // after rebalancing, we should read 40 committed records (even if 50 record got written)
         //
         // afterwards, the "stalling" thread resumes, and another rebalance should get triggered
         // we write the remaining 20 records and verify to read 60 result records
 
         try (
-            final KafkaStreams streams1 = getKafkaStreams(false, "appDir1", 1, eosConfig);
-            final KafkaStreams streams2 = getKafkaStreams(false, "appDir2", 1, eosConfig)
+            final KafkaStreams streams1 = getKafkaStreams("streams1", false, "appDir1", 1, eosConfig);
+            final KafkaStreams streams2 = getKafkaStreams("streams2", false, "appDir2", 1, eosConfig)
         ) {
             startKafkaStreamsAndWaitForRunningState(streams1, MAX_WAIT_TIME_MS);
             startKafkaStreamsAndWaitForRunningState(streams2, MAX_WAIT_TIME_MS);
 
-            final List<KeyValue<Long, Long>> committedDataBeforeGC = prepareData(0L, 10L, 0L, 1L);
-            final List<KeyValue<Long, Long>> uncommittedDataBeforeGC = prepareData(10L, 15L, 0L, 1L);
+            final List<KeyValue<Long, Long>> committedDataBeforeStall = prepareData(0L, 10L, 0L, 1L);
+            final List<KeyValue<Long, Long>> uncommittedDataBeforeStall = prepareData(10L, 15L, 0L, 1L);
 
-            final List<KeyValue<Long, Long>> dataBeforeGC = new ArrayList<>();
-            dataBeforeGC.addAll(committedDataBeforeGC);
-            dataBeforeGC.addAll(uncommittedDataBeforeGC);
+            final List<KeyValue<Long, Long>> dataBeforeStall = new ArrayList<>();
+            dataBeforeStall.addAll(committedDataBeforeStall);
+            dataBeforeStall.addAll(uncommittedDataBeforeStall);
 
             final List<KeyValue<Long, Long>> dataToTriggerFirstRebalance = prepareData(15L, 20L, 0L, 1L);
 
             final List<KeyValue<Long, Long>> dataAfterSecondRebalance = prepareData(20L, 30L, 0L, 1L);
 
-            writeInputData(committedDataBeforeGC);
+            writeInputData(committedDataBeforeStall);
 
             waitForCondition(
                 () -> commitRequested.get() == 2, MAX_WAIT_TIME_MS,
                 "SteamsTasks did not request commit.");
 
-            writeInputData(uncommittedDataBeforeGC);
+            writeInputData(uncommittedDataBeforeStall);
 
-            final List<KeyValue<Long, Long>> uncommittedRecords = readResult(dataBeforeGC.size(), null);
-            final List<KeyValue<Long, Long>> committedRecords = readResult(committedDataBeforeGC.size(), CONSUMER_GROUP_ID);
+            final List<KeyValue<Long, Long>> uncommittedRecords = readResult(dataBeforeStall.size(), null);
+            final List<KeyValue<Long, Long>> committedRecords = readResult(committedDataBeforeStall.size(), CONSUMER_GROUP_ID);
 
-            checkResultPerKey(committedRecords, committedDataBeforeGC);
-            checkResultPerKey(uncommittedRecords, dataBeforeGC);
+            checkResultPerKey(committedRecords, committedDataBeforeStall);
+            checkResultPerKey(uncommittedRecords, dataBeforeStall);
 
-            gcInjected.set(true);
+            LOG.info("Injecting Stall");
+            stallInjected.set(true);
             writeInputData(dataToTriggerFirstRebalance);
+            LOG.info("Input Data Written");
+            waitForCondition(
+                () -> stallingHost.get() != null,
+                MAX_WAIT_TIME_MS,
+                "Expected a host to start stalling"
+            );
+            final String observedStallingHost = stallingHost.get();
+            final KafkaStreams stallingInstance;
+            final KafkaStreams remainingInstance;
+            if ("streams1".equals(observedStallingHost)) {
+                stallingInstance = streams1;
+                remainingInstance = streams2;
+            } else if ("streams2".equals(observedStallingHost)) {
+                stallingInstance = streams2;
+                remainingInstance = streams1;
+            } else {
+                throw new IllegalArgumentException("unexpected host name: " + observedStallingHost);
+            }
 
+            // the stalling instance won't have an updated view, and it doesn't matter what it thinks
+            // the assignment is. We only really care that the remaining instance only sees one host
+            // that owns both partitions.
             waitForCondition(
-                () -> streams1.allMetadata().size() == 1
-                    && streams2.allMetadata().size() == 1
-                    && (streams1.allMetadata().iterator().next().topicPartitions().size() == 2
-                        || streams2.allMetadata().iterator().next().topicPartitions().size() == 2),
-                MAX_WAIT_TIME_MS, "Should have rebalanced.");
+                () -> stallingInstance.allMetadata().size() == 2
+                    && remainingInstance.allMetadata().size() == 1
+                    && remainingInstance.allMetadata().iterator().next().topicPartitions().size() == 2,
+                MAX_WAIT_TIME_MS,
+                () -> "Should have rebalanced.\n" +
+                    "Streams1[" + streams1.allMetadata() + "]\n" +
+                    "Streams2[" + streams2.allMetadata() + "]");
 
             final List<KeyValue<Long, Long>> committedRecordsAfterRebalance = readResult(
-                uncommittedDataBeforeGC.size() + dataToTriggerFirstRebalance.size(),
+                uncommittedDataBeforeStall.size() + dataToTriggerFirstRebalance.size(),
                 CONSUMER_GROUP_ID);
 
             final List<KeyValue<Long, Long>> expectedCommittedRecordsAfterRebalance = new ArrayList<>();
-            expectedCommittedRecordsAfterRebalance.addAll(uncommittedDataBeforeGC);
+            expectedCommittedRecordsAfterRebalance.addAll(uncommittedDataBeforeStall);
             expectedCommittedRecordsAfterRebalance.addAll(dataToTriggerFirstRebalance);
 
             checkResultPerKey(committedRecordsAfterRebalance, expectedCommittedRecordsAfterRebalance);
 
-            doGC = false;
+            LOG.info("Releasing Stall");
+            doStall = false;
+            // Once the stalling host rejoins the group, we expect both instances to see both instances.
+            // It doesn't really matter what the assignment is, but we might as well also assert that they
+            // both see both partitions assigned exactly once
             waitForCondition(
-                () -> streams1.allMetadata().size() == 1
-                    && streams2.allMetadata().size() == 1
-                    && streams1.allMetadata().iterator().next().topicPartitions().size() == 1
-                    && streams2.allMetadata().iterator().next().topicPartitions().size() == 1,
+                () -> streams1.allMetadata().size() == 2
+                    && streams2.allMetadata().size() == 2
+                    && streams1.allMetadata().stream().mapToLong(meta -> meta.topicPartitions().size()).sum() == 2
+                    && streams2.allMetadata().stream().mapToLong(meta -> meta.topicPartitions().size()).sum() == 2,
                 MAX_WAIT_TIME_MS,
-                "Should have rebalanced.");
+                () -> "Should have rebalanced.\n" +
+                    "Streams1[" + streams1.allMetadata() + "]\n" +
+                    "Streams2[" + streams2.allMetadata() + "]");
 
             writeInputData(dataAfterSecondRebalance);
 
             final List<KeyValue<Long, Long>> allCommittedRecords = readResult(
-                committedDataBeforeGC.size() + uncommittedDataBeforeGC.size()
+                committedDataBeforeStall.size() + uncommittedDataBeforeStall.size()
                 + dataToTriggerFirstRebalance.size() + dataAfterSecondRebalance.size(),
                 CONSUMER_GROUP_ID + "_ALL");
 
             final List<KeyValue<Long, Long>> allExpectedCommittedRecordsAfterRecovery = new ArrayList<>();
-            allExpectedCommittedRecordsAfterRecovery.addAll(committedDataBeforeGC);
-            allExpectedCommittedRecordsAfterRecovery.addAll(uncommittedDataBeforeGC);
+            allExpectedCommittedRecordsAfterRecovery.addAll(committedDataBeforeStall);
+            allExpectedCommittedRecordsAfterRecovery.addAll(uncommittedDataBeforeStall);
             allExpectedCommittedRecordsAfterRecovery.addAll(dataToTriggerFirstRebalance);
             allExpectedCommittedRecordsAfterRecovery.addAll(dataAfterSecondRebalance);
 
@@ -614,13 +649,15 @@ public class EosIntegrationTest {
         return data;
     }
 
-    private KafkaStreams getKafkaStreams(final boolean withState,
+    private KafkaStreams getKafkaStreams(final String dummyHostName,
+                                         final boolean withState,
                                          final String appDir,
                                          final int numberOfStreamsThreads,
                                          final String eosConfig) {
         commitRequested = new AtomicInteger(0);
         errorInjected = new AtomicBoolean(false);
-        gcInjected = new AtomicBoolean(false);
+        stallInjected = new AtomicBoolean(false);
+        stallingHost = new AtomicReference<>();
         final StreamsBuilder builder = new StreamsBuilder();
 
         String[] storeNames = new String[0];
@@ -653,8 +690,10 @@ public class EosIntegrationTest {
 
                     @Override
                     public KeyValue<Long, Long> transform(final Long key, final Long value) {
-                        if (gcInjected.compareAndSet(true, false)) {
-                            while (doGC) {
+                        if (stallInjected.compareAndSet(true, false)) {
+                            LOG.info(dummyHostName + " is executing the injected stall");
+                            stallingHost.set(dummyHostName);
+                            while (doStall) {
                                 final StreamThread thread = (StreamThread) Thread.currentThread();
                                 if (thread.isInterrupted() || !thread.isRunning()) {
                                     throw new RuntimeException("Detected we've been interrupted.");
@@ -714,7 +753,7 @@ public class EosIntegrationTest {
         properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), MAX_POLL_INTERVAL_MS);
         properties.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0);
         properties.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath() + File.separator + appDir);
-        properties.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "dummy:2142");
+        properties.put(StreamsConfig.APPLICATION_SERVER_CONFIG, dummyHostName + ":2142");
 
         final Properties config = StreamsTestUtils.getStreamsConfig(
             applicationId,
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/LagFetchIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/LagFetchIntegrationTest.java
index 14c8e4a..f5143e1 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/LagFetchIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/LagFetchIntegrationTest.java
@@ -16,27 +16,6 @@
  */
 package org.apache.kafka.streams.integration;
 
-import static org.apache.kafka.common.utils.Utils.mkSet;
-import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.core.IsEqual.equalTo;
-import static org.junit.Assert.assertTrue;
-
-import java.io.File;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Properties;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.CyclicBarrier;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
 import kafka.utils.MockTime;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.common.TopicPartition;
@@ -57,6 +36,7 @@ import org.apache.kafka.streams.kstream.KTable;
 import org.apache.kafka.streams.kstream.Materialized;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.processor.internals.StreamThread;
+import org.apache.kafka.streams.processor.internals.assignment.FallbackPriorTaskAssignor;
 import org.apache.kafka.test.IntegrationTest;
 import org.apache.kafka.test.TestUtils;
 import org.junit.After;
@@ -69,6 +49,28 @@ import org.junit.rules.TestName;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.File;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.IsEqual.equalTo;
+import static org.junit.Assert.assertTrue;
+
 @Category({IntegrationTest.class})
 public class LagFetchIntegrationTest {
 
@@ -147,6 +149,9 @@ public class LagFetchIntegrationTest {
         // create stream threads
         for (int i = 0; i < 2; i++) {
             final Properties props = (Properties) streamsConfiguration.clone();
+            // this test relies on the second instance getting the standby, so we specify
+            // an assignor with this contract.
+            props.put(StreamsConfig.InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS, FallbackPriorTaskAssignor.class.getName());
             props.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost:" + i);
             props.put(StreamsConfig.CLIENT_ID_CONFIG, "instance-" + i);
             props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION, optimization);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/HighAvailabilityStreamsPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/HighAvailabilityStreamsPartitionAssignorTest.java
new file mode 100644
index 0000000..2c27d11
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/HighAvailabilityStreamsPartitionAssignorTest.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.admin.AdminClient;
+import org.apache.kafka.clients.admin.ListOffsetsResult;
+import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
+import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Assignment;
+import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.GroupSubscription;
+import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription;
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.internals.KafkaFutureImpl;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.StreamsConfig.InternalConfig;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
+import org.apache.kafka.streams.processor.internals.assignment.AssignorError;
+import org.apache.kafka.streams.processor.internals.assignment.HighAvailabilityTaskAssignor;
+import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
+import org.apache.kafka.test.MockClientSupplier;
+import org.apache.kafka.test.MockInternalTopicManager;
+import org.apache.kafka.test.MockKeyValueStoreBuilder;
+import org.apache.kafka.test.MockProcessorSupplier;
+import org.easymock.EasyMock;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.UUID;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.stream.Collectors;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.emptySet;
+import static java.util.Collections.singletonList;
+import static java.util.Collections.singletonMap;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_CHANGELOG_END_OFFSETS;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASKS;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
+import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION;
+import static org.easymock.EasyMock.anyObject;
+import static org.easymock.EasyMock.expect;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.is;
+
+public class HighAvailabilityStreamsPartitionAssignorTest {
+
+    private final List<PartitionInfo> infos = asList(
+        new PartitionInfo("topic1", 0, Node.noNode(), new Node[0], new Node[0]),
+        new PartitionInfo("topic1", 1, Node.noNode(), new Node[0], new Node[0]),
+        new PartitionInfo("topic1", 2, Node.noNode(), new Node[0], new Node[0]),
+        new PartitionInfo("topic2", 0, Node.noNode(), new Node[0], new Node[0]),
+        new PartitionInfo("topic2", 1, Node.noNode(), new Node[0], new Node[0]),
+        new PartitionInfo("topic2", 2, Node.noNode(), new Node[0], new Node[0]),
+        new PartitionInfo("topic3", 0, Node.noNode(), new Node[0], new Node[0]),
+        new PartitionInfo("topic3", 1, Node.noNode(), new Node[0], new Node[0]),
+        new PartitionInfo("topic3", 2, Node.noNode(), new Node[0], new Node[0]),
+        new PartitionInfo("topic3", 3, Node.noNode(), new Node[0], new Node[0])
+    );
+
+    private final Cluster metadata = new Cluster(
+        "cluster",
+        singletonList(Node.noNode()),
+        infos,
+        emptySet(),
+        emptySet());
+
+    private final StreamsPartitionAssignor partitionAssignor = new StreamsPartitionAssignor();
+    private final MockClientSupplier mockClientSupplier = new MockClientSupplier();
+    private static final String USER_END_POINT = "localhost:8080";
+    private static final String APPLICATION_ID = "stream-partition-assignor-test";
+
+    private TaskManager taskManager;
+    private Admin adminClient;
+    private final InternalTopologyBuilder builder = new InternalTopologyBuilder();
+    private final StreamsMetadataState streamsMetadataState = EasyMock.createNiceMock(StreamsMetadataState.class);
+    private final Map<String, Subscription> subscriptions = new HashMap<>();
+    private final AtomicInteger assignmentError = new AtomicInteger();
+    private final AtomicLong nextProbingRebalanceMs = new AtomicLong(Long.MAX_VALUE);
+    private final MockTime time = new MockTime();
+    // Must stay last: configProps() reads the fields above, and instance initializers run in textual order.
+    private StreamsConfig streamsConfig = new StreamsConfig(configProps());
+
+    private Map<String, Object> configProps() { // builds a fresh, mutable base config; configurePartitionAssignorWith() layers overrides on top via putAll
+        final Map<String, Object> configurationMap = new HashMap<>();
+        configurationMap.put(StreamsConfig.APPLICATION_ID_CONFIG, APPLICATION_ID);
+        configurationMap.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, USER_END_POINT);
+        configurationMap.put(InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR, taskManager); // read at call time: null until createMockTaskManager() has run
+        configurationMap.put(InternalConfig.STREAMS_METADATA_STATE_FOR_PARTITION_ASSIGNOR, streamsMetadataState);
+        configurationMap.put(InternalConfig.STREAMS_ADMIN_CLIENT, adminClient); // read at call time: set by createMockAdminClient() or by the test itself
+        configurationMap.put(InternalConfig.ASSIGNMENT_ERROR_CODE, assignmentError);
+        configurationMap.put(InternalConfig.NEXT_PROBING_REBALANCE_MS, nextProbingRebalanceMs);
+        configurationMap.put(InternalConfig.TIME, time);
+        configurationMap.put(InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS, HighAvailabilityTaskAssignor.class.getName()); // pins the TaskAssignor implementation this test class exercises
+        return configurationMap;
+    }
+
+    // Applies props over configProps(), configures the assignor, then replays the TaskManager/AdminClient mocks -- record all their expectations before calling this
+    private void configurePartitionAssignorWith(final Map<String, Object> props) {
+        final Map<String, Object> configMap = configProps();
+        configMap.putAll(props);
+
+        streamsConfig = new StreamsConfig(configMap);
+        partitionAssignor.configure(configMap);
+        EasyMock.replay(taskManager, adminClient); // after replay, no further expectations can be recorded on these mocks
+
+        overwriteInternalTopicManagerWithMock(); // must follow the streamsConfig reassignment above, which it reads
+    }
+
+    // Convenience overload for tests that don't care about the task offset sums; derives them via getTaskOffsetSums(activeTasks)
+    private void createMockTaskManager(final Set<TaskId> activeTasks) {
+        createMockTaskManager(getTaskOffsetSums(activeTasks));
+    }
+
+    private void createMockTaskManager(final Map<TaskId, Long> taskOffsetSums) { // records expectations only; the mock is replayed later by configurePartitionAssignorWith()
+        taskManager = EasyMock.createNiceMock(TaskManager.class);
+        expect(taskManager.builder()).andReturn(builder).anyTimes();
+        expect(taskManager.getTaskOffsetSums()).andReturn(taskOffsetSums).anyTimes();
+        expect(taskManager.processId()).andReturn(UUID_1).anyTimes(); // this client always reports UUID_1 as its process id
+        builder.setApplicationId(APPLICATION_ID);
+        builder.buildTopology();
+    }
+
+    // Stubs adminClient.listOffsets() to return the given changelog end offsets (adminClient itself is replayed later, in configurePartitionAssignorWith).
+    // If per-partition offsets don't matter, build the input map for all partitions with getTopicPartitionOffsetsMap.
+    private void createMockAdminClient(final Map<TopicPartition, Long> changelogEndOffsets) {
+        adminClient = EasyMock.createMock(AdminClient.class);
+
+        final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class);
+        final KafkaFutureImpl<Map<TopicPartition, ListOffsetsResultInfo>> allFuture = new KafkaFutureImpl<>();
+        allFuture.complete(changelogEndOffsets.entrySet().stream().collect(Collectors.toMap(
+            Entry::getKey,
+            t -> {
+                final ListOffsetsResultInfo info = EasyMock.createNiceMock(ListOffsetsResultInfo.class);
+                expect(info.offset()).andStubReturn(t.getValue());
+                EasyMock.replay(info);
+                return info;
+            }))
+        );
+
+        expect(adminClient.listOffsets(anyObject())).andStubReturn(result);
+        // stub rather than a one-shot expectation: repeated all() calls must not fall back to the nice mock's null
+        expect(result.all()).andStubReturn(allFuture);
+        EasyMock.replay(result);
+    }
+
+    private void overwriteInternalTopicManagerWithMock() { // builds the mock from the current streamsConfig, so call it after streamsConfig is (re)assigned
+        final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer);
+        partitionAssignor.setInternalTopicManager(mockInternalTopicManager);
+    }
+
+    @Before
+    public void setUp() {
+        createMockAdminClient(EMPTY_CHANGELOG_END_OFFSETS); // default admin mock; individual tests may replace adminClient before configuring the assignor
+    }
+
+
+    @Test // when the end-offset fetch fails, all active tasks stay with their prior owner and a rebalance is scheduled
+    public void shouldReturnAllActiveTasksToPreviousOwnerRegardlessOfBalanceAndTriggerRebalanceIfEndOffsetFetchFailsAndHighAvailabilityEnabled() {
+        final long rebalanceInterval = 5 * 60 * 1000L;
+
+        builder.addSource(null, "source1", null, null, null, "topic1");
+        builder.addProcessor("processor1", new MockProcessorSupplier<>(), "source1");
+        builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1");
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
+
+        createMockTaskManager(allTasks);
+        adminClient = EasyMock.createMock(AdminClient.class); // replaces the default admin mock created in setUp()
+        expect(adminClient.listOffsets(anyObject())).andThrow(new StreamsException("Should be handled")); // simulate a failed end-offset fetch
+        configurePartitionAssignorWith(singletonMap(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, rebalanceInterval));
+
+        final String firstConsumer = "consumer1";
+        final String newConsumer = "consumer2";
+
+        subscriptions.put(firstConsumer,
+                          new Subscription(
+                              singletonList("source1"),
+                              getInfo(UUID_1, allTasks).encode()
+                          ));
+        subscriptions.put(newConsumer,
+                          new Subscription(
+                              singletonList("source1"),
+                              getInfo(UUID_2, EMPTY_TASKS).encode()
+                          ));
+
+        final Map<String, Assignment> assignments = partitionAssignor
+            .assign(metadata, new GroupSubscription(subscriptions))
+            .groupAssignment();
+
+        final AssignmentInfo firstConsumerUserData = AssignmentInfo.decode(assignments.get(firstConsumer).userData());
+        final List<TaskId> firstConsumerActiveTasks = firstConsumerUserData.activeTasks();
+        final AssignmentInfo newConsumerUserData = AssignmentInfo.decode(assignments.get(newConsumer).userData());
+        final List<TaskId> newConsumerActiveTasks = newConsumerUserData.activeTasks();
+
+        // The tasks were returned to their prior owner (NOTE(review): equalTo relies on allTasks' Set iteration order matching activeTasks() order)
+        assertThat(firstConsumerActiveTasks, equalTo(new ArrayList<>(allTasks)));
+        assertThat(newConsumerActiveTasks, empty());
+
+        // A rebalance is scheduled one interval from now on at least one consumer (NOTE(review): the expected value is passed as assertThat's "actual")
+        assertThat(
+            time.milliseconds() + rebalanceInterval,
+            anyOf(
+                is(firstConsumerUserData.nextRebalanceMs()),
+                is(newConsumerUserData.nextRebalanceMs())
+            )
+        );
+    }
+
+    @Test
+    public void shouldScheduleProbingRebalanceOnThisClientIfWarmupTasksRequired() {
+        final long rebalanceInterval = 5 * 60 * 1000L;
+
+        // Topology: one source feeding one processor that is backed by the store "store1"
+        builder.addSource(null, "source1", null, null, null, "topic1");
+        builder.addProcessor("processor1", new MockProcessorSupplier<>(), "source1");
+        builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1");
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
+
+        // The store's changelog topic is mocked with 3 partitions
+        createMockTaskManager(allTasks);
+        createMockAdminClient(getTopicPartitionOffsetsMap(
+            singletonList(APPLICATION_ID + "-store1-changelog"),
+            singletonList(3)));
+        configurePartitionAssignorWith(singletonMap(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, rebalanceInterval));
+
+        // consumer1 previously owned every task; consumer2 joins with no previous tasks
+        final String firstConsumer = "consumer1";
+        final String newConsumer = "consumer2";
+        subscriptions.put(firstConsumer, new Subscription(singletonList("source1"), getInfo(UUID_1, allTasks).encode()));
+        subscriptions.put(newConsumer, new Subscription(singletonList("source1"), getInfo(UUID_2, EMPTY_TASKS).encode()));
+
+        final Map<String, Assignment> assignments = partitionAssignor
+            .assign(metadata, new GroupSubscription(subscriptions))
+            .groupAssignment();
+
+        // Decode each consumer's assignment once and reuse it for every assertion below
+        final AssignmentInfo firstConsumerInfo = AssignmentInfo.decode(assignments.get(firstConsumer).userData());
+        final AssignmentInfo newConsumerInfo = AssignmentInfo.decode(assignments.get(newConsumer).userData());
+
+        // Every active task stays with its previous owner, and no assignment error is reported
+        assertThat(firstConsumerInfo.activeTasks(), equalTo(new ArrayList<>(allTasks)));
+        assertThat(newConsumerInfo.activeTasks(), empty());
+        assertThat(assignmentError.get(), equalTo(AssignorError.NONE.code()));
+
+        // The probing rebalance is scheduled on the first consumer's client only
+        assertThat(firstConsumerInfo.nextRebalanceMs(), equalTo(time.milliseconds() + rebalanceInterval));
+        assertThat(newConsumerInfo.nextRebalanceMs(), equalTo(Long.MAX_VALUE));
+    }
+
+
+    /**
+     * Helper for building the input to createMockAdminClient in cases where we don't care about the actual offsets.
+     *
+     * @param changelogTopics The names of all changelog topics in the topology
+     * @param topicsNumPartitions The number of partitions for the corresponding changelog topic, such that the number
+     *            of partitions of the ith topic in changelogTopics is given by the ith element of topicsNumPartitions
+     * @return a map from every partition of every listed changelog topic to an end offset of {@code Long.MAX_VALUE}
+     * @throws IllegalArgumentException if the two lists do not have the same size
+     */
+    private static Map<TopicPartition, Long> getTopicPartitionOffsetsMap(final List<String> changelogTopics,
+                                                                         final List<Integer> topicsNumPartitions) {
+        if (changelogTopics.size() != topicsNumPartitions.size()) {
+            // Mismatched lists are a mistake in the caller's arguments, not an invalid object state
+            throw new IllegalArgumentException("Passed in " + changelogTopics.size() + " changelog topic names, but " +
+                                                   topicsNumPartitions.size() + " different numPartitions for the topics");
+        }
+        final Map<TopicPartition, Long> changelogEndOffsets = new HashMap<>();
+        for (int i = 0; i < changelogTopics.size(); ++i) {
+            final String topic = changelogTopics.get(i);
+            final int numPartitions = topicsNumPartitions.get(i);
+            for (int partition = 0; partition < numPartitions; ++partition) {
+                changelogEndOffsets.put(new TopicPartition(topic, partition), Long.MAX_VALUE);
+            }
+        }
+        return changelogEndOffsets;
+    }
+
+    private static SubscriptionInfo getInfo(final UUID processId,
+                                            final Set<TaskId> prevTasks) {
+        return new SubscriptionInfo(
+            LATEST_SUPPORTED_VERSION, LATEST_SUPPORTED_VERSION, processId, null, getTaskOffsetSums(prevTasks));
+    }
+
+    // Stub offset sums for when we only care about the set of previously active tasks, not the actual offsets:
+    // every task in activeTasks is reported with Task.LATEST_OFFSET.
+    private static Map<TaskId, Long> getTaskOffsetSums(final Set<TaskId> activeTasks) {
+        return activeTasks.stream().collect(Collectors.toMap(t -> t, t -> Task.LATEST_OFFSET));
+    }
+
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
index d576b69d..814bbcf 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
@@ -16,8 +16,6 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
-import java.util.Map.Entry;
-import java.util.concurrent.atomic.AtomicLong;
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.admin.AdminClient;
 import org.apache.kafka.clients.admin.AdminClientConfig;
@@ -40,7 +38,6 @@ import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsConfig.InternalConfig;
 import org.apache.kafka.streams.TopologyWrapper;
-import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.kstream.JoinWindows;
 import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.kstream.KTable;
@@ -52,7 +49,11 @@ import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorError;
 import org.apache.kafka.streams.processor.internals.assignment.ClientState;
+import org.apache.kafka.streams.processor.internals.assignment.HighAvailabilityTaskAssignor;
+import org.apache.kafka.streams.processor.internals.assignment.FallbackPriorTaskAssignor;
+import org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor;
 import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
+import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor;
 import org.apache.kafka.streams.state.HostInfo;
 import org.apache.kafka.test.MockClientSupplier;
 import org.apache.kafka.test.MockInternalTopicManager;
@@ -61,6 +62,8 @@ import org.apache.kafka.test.MockProcessorSupplier;
 import org.easymock.Capture;
 import org.easymock.EasyMock;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
@@ -70,12 +73,12 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.stream.Collectors;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
 
 import static java.time.Duration.ofMillis;
 import static java.util.Arrays.asList;
@@ -83,17 +86,11 @@ import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.emptySet;
 import static java.util.Collections.singletonList;
-import static java.util.Collections.singletonMap;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASKS;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_1;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_2;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_3;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
-import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_CHANGELOG_END_OFFSETS;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASKS;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASK_OFFSET_SUMS;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
@@ -104,13 +101,18 @@ import static org.apache.kafka.streams.processor.internals.assignment.Assignment
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_2;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_3;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_0;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_3;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
 import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION;
 import static org.easymock.EasyMock.anyObject;
 import static org.easymock.EasyMock.expect;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.not;
-import static org.hamcrest.Matchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
@@ -192,11 +194,10 @@ public class StreamsPartitionAssignorTest {
 
     private TaskManager taskManager;
     private Admin adminClient;
-    private StreamsConfig streamsConfig = new StreamsConfig(configProps());
     private InternalTopologyBuilder builder = new InternalTopologyBuilder();
     private StreamsMetadataState streamsMetadataState = EasyMock.createNiceMock(StreamsMetadataState.class);
     private final Map<String, Subscription> subscriptions = new HashMap<>();
-    private final boolean highAvailabilityEnabled;
+    // TaskAssignor implementation under test, injected per run by the parameterized constructor
+    private final Class<? extends TaskAssignor> taskAssignor;
 
     private final AtomicInteger assignmentError = new AtomicInteger();
     private final AtomicLong nextProbingRebalanceMs = new AtomicLong(Long.MAX_VALUE);
@@ -206,13 +207,13 @@ public class StreamsPartitionAssignorTest {
         final Map<String, Object> configurationMap = new HashMap<>();
         configurationMap.put(StreamsConfig.APPLICATION_ID_CONFIG, APPLICATION_ID);
         configurationMap.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, USER_END_POINT);
-        configurationMap.put(StreamsConfig.InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR, taskManager);
-        configurationMap.put(StreamsConfig.InternalConfig.STREAMS_METADATA_STATE_FOR_PARTITION_ASSIGNOR, streamsMetadataState);
-        configurationMap.put(StreamsConfig.InternalConfig.STREAMS_ADMIN_CLIENT, adminClient);
-        configurationMap.put(StreamsConfig.InternalConfig.ASSIGNMENT_ERROR_CODE, assignmentError);
-        configurationMap.put(StreamsConfig.InternalConfig.NEXT_PROBING_REBALANCE_MS, nextProbingRebalanceMs);
+        configurationMap.put(InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR, taskManager);
+        configurationMap.put(InternalConfig.STREAMS_METADATA_STATE_FOR_PARTITION_ASSIGNOR, streamsMetadataState);
+        configurationMap.put(InternalConfig.STREAMS_ADMIN_CLIENT, adminClient);
+        configurationMap.put(InternalConfig.ASSIGNMENT_ERROR_CODE, assignmentError);
+        configurationMap.put(InternalConfig.NEXT_PROBING_REBALANCE_MS, nextProbingRebalanceMs);
         configurationMap.put(InternalConfig.TIME, time);
-        configurationMap.put(AssignorConfiguration.HIGH_AVAILABILITY_ENABLED_CONFIG, highAvailabilityEnabled);
+        configurationMap.put(InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS, taskAssignor.getName());
         return configurationMap;
     }
 
@@ -231,7 +232,6 @@ public class StreamsPartitionAssignorTest {
         final Map<String, Object> configMap = configProps();
         configMap.putAll(props);
 
-        streamsConfig = new StreamsConfig(configMap);
         partitionAssignor.configure(configMap);
         EasyMock.replay(taskManager, adminClient);
 
@@ -282,21 +282,23 @@ public class StreamsPartitionAssignorTest {
     }
 
+    // Replaces the assignor's internal topic manager with a mock and returns that mock.
     private MockInternalTopicManager overwriteInternalTopicManagerWithMock() {
-        final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager(streamsConfig, mockClientSupplier.restoreConsumer);
+        // A fresh StreamsConfig is built from configProps() on each call, since the streamsConfig field is gone.
+        // NOTE(review): extra props merged in when configuring the assignor (configMap.putAll(props)) are not
+        // visible to this mock's config -- confirm no test relies on them here.
+        final MockInternalTopicManager mockInternalTopicManager =
+            new MockInternalTopicManager(new StreamsConfig(configProps()), mockClientSupplier.restoreConsumer);
         partitionAssignor.setInternalTopicManager(mockInternalTopicManager);
         return mockInternalTopicManager;
     }
 
-    @Parameterized.Parameters(name = "high availability enabled = {0}")
+    // Each test in this class runs once per TaskAssignor class returned here; the class is handed to the
+    // constructor and its name is forwarded to the assignor via InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS.
+    @Parameterized.Parameters(name = "task assignor = {0}")
     public static Collection<Object[]> parameters() {
         return asList(
-            new Object[]{true},
-            new Object[]{false}
+            new Object[]{HighAvailabilityTaskAssignor.class},
+            new Object[]{StickyTaskAssignor.class},
+            new Object[]{FallbackPriorTaskAssignor.class}
             );
     }
 
-    public StreamsPartitionAssignorTest(final boolean highAvailabilityEnabled) {
-        this.highAvailabilityEnabled = highAvailabilityEnabled;
+    // Called by the Parameterized runner with one TaskAssignor class from parameters().
+    public StreamsPartitionAssignorTest(final Class<? extends TaskAssignor> taskAssignor) {
+        this.taskAssignor = taskAssignor;
         createMockAdminClient(EMPTY_CHANGELOG_END_OFFSETS);
     }
 
@@ -1433,7 +1435,7 @@ public class StreamsPartitionAssignorTest {
     @Test
     public void shouldThrowKafkaExceptionIfTaskMangerNotConfigured() {
         final Map<String, Object> config = configProps();
-        config.remove(StreamsConfig.InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR);
+        config.remove(InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR);
 
         try {
             partitionAssignor.configure(config);
@@ -1446,7 +1448,7 @@ public class StreamsPartitionAssignorTest {
     @Test
     public void shouldThrowKafkaExceptionIfTaskMangerConfigIsNotTaskManagerInstance() {
         final Map<String, Object> config = configProps();
-        config.put(StreamsConfig.InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR, "i am not a task manager");
+        config.put(InternalConfig.TASK_MANAGER_FOR_PARTITION_ASSIGNOR, "i am not a task manager");
 
         try {
             partitionAssignor.configure(config);
@@ -1461,7 +1463,7 @@ public class StreamsPartitionAssignorTest {
     public void shouldThrowKafkaExceptionAssignmentErrorCodeNotConfigured() {
         createDefaultMockTaskManager();
         final Map<String, Object> config = configProps();
-        config.remove(StreamsConfig.InternalConfig.ASSIGNMENT_ERROR_CODE);
+        config.remove(InternalConfig.ASSIGNMENT_ERROR_CODE);
 
         try {
             partitionAssignor.configure(config);
@@ -1475,7 +1477,7 @@ public class StreamsPartitionAssignorTest {
     public void shouldThrowKafkaExceptionIfVersionProbingFlagConfigIsNotAtomicInteger() {
         createDefaultMockTaskManager();
         final Map<String, Object> config = configProps();
-        config.put(StreamsConfig.InternalConfig.ASSIGNMENT_ERROR_CODE, "i am not an AtomicInteger");
+        config.put(InternalConfig.ASSIGNMENT_ERROR_CODE, "i am not an AtomicInteger");
 
         try {
             partitionAssignor.configure(config);
@@ -1832,102 +1834,6 @@ public class StreamsPartitionAssignorTest {
         assertThrows(IllegalStateException.class, () -> partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)));
     }
 
-    @Test
-    public void shouldReturnAllActiveTasksToPreviousOwnerRegardlessOfBalanceAndTriggerRebalanceIfEndOffsetFetchFailsAndHighAvailabilityEnabled() {
-        if (highAvailabilityEnabled) {
-            builder.addSource(null, "source1", null, null, null, "topic1");
-            builder.addProcessor("processor1", new MockProcessorSupplier<>(), "source1");
-            builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1");
-            final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
-
-            createMockTaskManager(allTasks, EMPTY_TASKS);
-            adminClient = EasyMock.createMock(AdminClient.class);
-            expect(adminClient.listOffsets(anyObject())).andThrow(new StreamsException("Should be handled"));
-            configureDefaultPartitionAssignor();
-
-            final String firstConsumer = "consumer1";
-            final String newConsumer = "consumer2";
-
-            subscriptions.put(firstConsumer,
-                new Subscription(
-                    singletonList("source1"),
-                    getInfo(UUID_1, allTasks, EMPTY_TASKS).encode()
-                ));
-            subscriptions.put(newConsumer,
-                new Subscription(
-                    singletonList("source1"),
-                    getInfo(UUID_2, EMPTY_TASKS, EMPTY_TASKS).encode()
-                ));
-
-            final Map<String, Assignment> assignments = partitionAssignor
-                                                            .assign(metadata, new GroupSubscription(subscriptions))
-                                                            .groupAssignment();
-
-            final List<TaskId> firstConsumerActiveTasks =
-                AssignmentInfo.decode(assignments.get(firstConsumer).userData()).activeTasks();
-            final List<TaskId> newConsumerActiveTasks =
-                AssignmentInfo.decode(assignments.get(newConsumer).userData()).activeTasks();
-
-            assertThat(firstConsumerActiveTasks, equalTo(new ArrayList<>(allTasks)));
-            assertTrue(newConsumerActiveTasks.isEmpty());
-            assertThat(assignmentError.get(), equalTo(AssignorError.REBALANCE_NEEDED.code()));
-        }
-    }
-
-    @Test
-    public void shouldScheduleProbingRebalanceOnThisClientIfWarmupTasksRequired() {
-        if (highAvailabilityEnabled) {
-            final long rebalanceInterval =  5 * 60 * 1000L;
-
-            builder.addSource(null, "source1", null, null, null, "topic1");
-            builder.addProcessor("processor1", new MockProcessorSupplier<>(), "source1");
-            builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1");
-            final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
-
-            createMockTaskManager(allTasks, EMPTY_TASKS);
-            createMockAdminClient(getTopicPartitionOffsetsMap(
-                singletonList(APPLICATION_ID + "-store1-changelog"),
-                singletonList(3)));
-            configurePartitionAssignorWith(singletonMap(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, rebalanceInterval));
-
-            final String firstConsumer = "consumer1";
-            final String newConsumer = "consumer2";
-
-            subscriptions.put(firstConsumer,
-                new Subscription(
-                    singletonList("source1"),
-                    getInfo(UUID_1, allTasks, EMPTY_TASKS).encode()
-                ));
-            subscriptions.put(newConsumer,
-                new Subscription(
-                    singletonList("source1"),
-                    getInfo(UUID_2, EMPTY_TASKS, EMPTY_TASKS).encode()
-                ));
-
-            final Map<String, Assignment> assignments = partitionAssignor
-                                                            .assign(metadata, new GroupSubscription(subscriptions))
-                                                            .groupAssignment();
-
-            final List<TaskId> firstConsumerActiveTasks =
-                AssignmentInfo.decode(assignments.get(firstConsumer).userData()).activeTasks();
-            final List<TaskId> newConsumerActiveTasks =
-                AssignmentInfo.decode(assignments.get(newConsumer).userData()).activeTasks();
-
-            assertThat(firstConsumerActiveTasks, equalTo(new ArrayList<>(allTasks)));
-            assertTrue(newConsumerActiveTasks.isEmpty());
-
-            assertThat(assignmentError.get(), equalTo(AssignorError.NONE.code()));
-
-            final long nextScheduledRebalanceOnThisClient =
-                AssignmentInfo.decode(assignments.get(firstConsumer).userData()).nextRebalanceMs();
-            final long nextScheduledRebalanceOnOtherClient =
-                AssignmentInfo.decode(assignments.get(newConsumer).userData()).nextRebalanceMs();
-
-            assertThat(nextScheduledRebalanceOnThisClient, equalTo(time.milliseconds() + rebalanceInterval));
-            assertThat(nextScheduledRebalanceOnOtherClient, equalTo(Long.MAX_VALUE));
-        }
-    }
-
     private static ByteBuffer encodeFutureSubscription() {
         final ByteBuffer buf = ByteBuffer.allocate(4 /* used version */ + 4 /* supported version */);
         buf.putInt(LATEST_SUPPORTED_VERSION + 1);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
index cb32155..ac9dafe 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
@@ -16,22 +16,26 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import java.util.Map;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.Task;
 import org.junit.Test;
 
 import java.util.Collections;
+import java.util.Map;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_3;
 import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
@@ -42,6 +46,33 @@ public class ClientStateTest {
     private final ClientState zeroCapacityClient = new ClientState(0);
 
     @Test
+    public void previousStateConstructorShouldCreateAValidObject() {
+        final int capacity = 4;
+        final Map<TaskId, Long> taskLags = mkMap(mkEntry(TASK_0_0, 5L), mkEntry(TASK_0_2, -1L));
+        final ClientState clientState = new ClientState(
+            mkSet(TASK_0_0, TASK_0_1),
+            mkSet(TASK_0_2, TASK_0_3),
+            taskLags,
+            capacity
+        );
+
+        // Nothing has been assigned yet, so every view of the upcoming assignment is empty
+        assertThat(clientState.activeTaskCount(), is(0));
+        assertThat(clientState.activeTaskLoad(), is(0.0));
+        assertThat(clientState.activeTasks(), is(empty()));
+        assertThat(clientState.standbyTaskCount(), is(0));
+        assertThat(clientState.standbyTasks(), is(empty()));
+        assertThat(clientState.assignedTaskCount(), is(0));
+        assertThat(clientState.assignedTasks(), is(empty()));
+
+        // The previous-assignment views reflect exactly what was handed to the constructor
+        assertThat(clientState.prevActiveTasks(), is(mkSet(TASK_0_0, TASK_0_1)));
+        assertThat(clientState.prevStandbyTasks(), is(mkSet(TASK_0_2, TASK_0_3)));
+        assertThat(clientState.previousAssignedTasks(), is(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3)));
+        assertThat(clientState.capacity(), is(capacity));
+        assertThat(clientState.lagFor(TASK_0_0), is(5L));
+        assertThat(clientState.lagFor(TASK_0_2), is(-1L));
+    }
+
+    @Test
     public void shouldHaveNotReachedCapacityWhenAssignedTasksLessThanCapacity() {
+        // NOTE(review): relies on the shared 'client' fixture starting below its capacity -- confirm its definition
         assertFalse(client.reachedCapacity());
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java
new file mode 100644
index 0000000..687e5b6
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.UUID;
+
+import static java.util.Arrays.asList;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Unit tests for {@link FallbackPriorTaskAssignor}.
+ */
+public class FallbackPriorTaskAssignorTest {
+
+    // Clients keyed by process id; a TreeMap keeps their iteration order sorted and therefore deterministic
+    private final Map<UUID, ClientState> clients = new TreeMap<>();
+
+    @Test
+    public void shouldViolateBalanceToPreserveActiveTaskStickiness() {
+        // Client 1 previously owned every task; client 2 owns nothing
+        final ClientState previousOwner =
+            createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_1, TASK_0_2);
+        final ClientState newClient = createClient(UUID_2, 1);
+
+        final List<TaskId> tasks = asList(TASK_0_0, TASK_0_1, TASK_0_2);
+        Collections.shuffle(tasks);
+        final FallbackPriorTaskAssignor assignor = new FallbackPriorTaskAssignor();
+        final boolean probingRebalanceNeeded = assignor.assign(
+            clients,
+            new HashSet<>(tasks),
+            new HashSet<>(tasks),
+            new AssignorConfiguration.AssignmentConfigs(0L, 0, 0, 0, 0L)
+        );
+
+        // Stickiness wins over balance: the previous owner keeps everything and a probing rebalance is requested
+        assertThat(probingRebalanceNeeded, is(true));
+        assertThat(previousOwner.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2)));
+        assertThat(newClient.activeTasks(), empty());
+    }
+
+    /** Registers a client with the given capacity and no previously active tasks. */
+    private ClientState createClient(final UUID processId, final int capacity) {
+        return createClientWithPreviousActiveTasks(processId, capacity);
+    }
+
+    /** Registers a client with the given capacity whose previous assignment contained the given active tasks. */
+    private ClientState createClientWithPreviousActiveTasks(final UUID processId, final int capacity, final TaskId... taskIds) {
+        final ClientState client = new ClientState(capacity);
+        client.addPreviousActiveTasks(mkSet(taskIds));
+        clients.put(processId, client);
+        return client;
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
index 098c650..17d7c17 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
@@ -16,6 +16,19 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.stream.Collectors;
+
+import static java.util.Collections.emptySet;
 import static java.util.Collections.singleton;
 import static java.util.Collections.singletonMap;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
@@ -41,132 +54,107 @@ import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.replay;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Set;
-import java.util.UUID;
-import org.apache.kafka.streams.processor.TaskId;
-import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
-import org.easymock.EasyMock;
-import org.junit.Test;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
 
 public class HighAvailabilityTaskAssignorTest {
-    private long acceptableRecoveryLag = 100L;
-    private int balanceFactor = 1;
-    private int maxWarmupReplicas = 2;
-    private int numStandbyReplicas = 0;
-    private long probingRebalanceInterval = 60 * 1000L;
-
-    private Map<UUID, ClientState> clientStates = new HashMap<>();
-    private Set<TaskId> allTasks = new HashSet<>();
-    private Set<TaskId> statefulTasks = new HashSet<>();
-
-    private ClientState client1;
-    private ClientState client2;
-    private ClientState client3;
-    
-    private HighAvailabilityTaskAssignor taskAssignor;
-
-    private void createTaskAssignor() {
-        final AssignmentConfigs configs = new AssignmentConfigs(
-            acceptableRecoveryLag,
-            balanceFactor,
-            maxWarmupReplicas,
-            numStandbyReplicas,
-            probingRebalanceInterval
-        );
-        taskAssignor = new HighAvailabilityTaskAssignor(
-            clientStates,
-            allTasks,
-            statefulTasks,
-            configs);
-    }
+    private final AssignmentConfigs configWithoutStandbys = new AssignmentConfigs(
+        /*acceptableRecoveryLag*/ 100L,
+        /*balanceFactor*/ 1,
+        /*maxWarmupReplicas*/ 2,
+        /*numStandbyReplicas*/ 0,
+        /*probingRebalanceIntervalMs*/ 60_000L
+    );
+
+    private final AssignmentConfigs configWithStandbys = new AssignmentConfigs(
+        /*acceptableRecoveryLag*/ 100L,
+        /*balanceFactor*/ 1,
+        /*maxWarmupReplicas*/ 2,
+        /*numStandbyReplicas*/ 1,
+        /*probingRebalanceIntervalMs*/ 60_000L
+    );
 
-    @Test
-    public void shouldDecidePreviousAssignmentIsInvalidIfThereAreUnassignedActiveTasks() {
-        client1 = EasyMock.createNiceMock(ClientState.class);
-        expect(client1.prevActiveTasks()).andReturn(singleton(TASK_0_0));
-        expect(client1.prevStandbyTasks()).andStubReturn(EMPTY_TASKS);
-        replay(client1);
-        allTasks =  mkSet(TASK_0_0, TASK_0_1);
-        clientStates = singletonMap(UUID_1, client1);
-        createTaskAssignor();
 
-        assertFalse(taskAssignor.previousAssignmentIsValid());
+    @Test
+    public void shouldComputeNewAssignmentIfThereAreUnassignedActiveTasks() {
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
+        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), 1);
+        final Map<UUID, ClientState> clientStates = singletonMap(UUID_1, client1);
+
+        final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(clientStates,
+                                                                                         allTasks,
+                                                                                         singleton(TASK_0_0),
+                                                                                         configWithoutStandbys);
+
+        assertThat(clientStates.get(UUID_1).activeTasks(), is(allTasks)); // sole client must own every active task
+        assertThat(clientStates.get(UUID_1).standbyTasks(), empty());
+        assertThat(probingRebalanceNeeded, is(false));
     }
 
     @Test
-    public void shouldDecidePreviousAssignmentIsInvalidIfThereAreUnassignedStandbyTasks() {
-        client1 = EasyMock.createNiceMock(ClientState.class);
-        expect(client1.prevActiveTasks()).andStubReturn(singleton(TASK_0_0));
-        expect(client1.prevStandbyTasks()).andReturn(EMPTY_TASKS);
-        replay(client1);
-        allTasks =  mkSet(TASK_0_0);
-        statefulTasks =  mkSet(TASK_0_0);
-        clientStates = singletonMap(UUID_1, client1);
-        numStandbyReplicas = 1;
-        createTaskAssignor();
-
-        assertFalse(taskAssignor.previousAssignmentIsValid());
+    public void shouldComputeNewAssignmentIfThereAreUnassignedStandbyTasks() {
+        final Set<TaskId> allTasks = mkSet(TASK_0_0);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0);
+        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), 1);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 0L), 1);
+        final Map<UUID, ClientState> clientStates = mkMap(mkEntry(UUID_1, client1), mkEntry(UUID_2, client2));
+
+        final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(clientStates,
+                                                                                         allTasks,
+                                                                                         statefulTasks,
+                                                                                         configWithStandbys);
+
+        assertThat(clientStates.get(UUID_2).standbyTasks(), is(singleton(TASK_0_0))); // the one stateful task's replica
+        assertThat(probingRebalanceNeeded, is(false));
     }
 
     @Test
-    public void shouldDecidePreviousAssignmentIsInvalidIfActiveTasksWasNotOnCaughtUpClient() {
-        client1 = EasyMock.createNiceMock(ClientState.class);
-        client2 = EasyMock.createNiceMock(ClientState.class);
-        expect(client1.prevStandbyTasks()).andStubReturn(EMPTY_TASKS);
-        expect(client2.prevStandbyTasks()).andStubReturn(EMPTY_TASKS);
-
-        expect(client1.prevActiveTasks()).andReturn(singleton(TASK_0_0));
-        expect(client2.prevActiveTasks()).andReturn(singleton(TASK_0_1));
-        expect(client1.lagFor(TASK_0_0)).andReturn(500L);
-        expect(client2.lagFor(TASK_0_0)).andReturn(0L);
-        replay(client1, client2);
-
-        allTasks =  mkSet(TASK_0_0, TASK_0_1);
-        statefulTasks =  mkSet(TASK_0_0);
-        clientStates = mkMap(
+    public void shouldComputeNewAssignmentIfActiveTasksWasNotOnCaughtUpClient() {
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0);
+        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 500L), 1);
+        final ClientState client2 = new ClientState(singleton(TASK_0_1), emptySet(), singletonMap(TASK_0_0, 0L), 1);
+        final Map<UUID, ClientState> clientStates = mkMap(
             mkEntry(UUID_1, client1),
             mkEntry(UUID_2, client2)
         );
-        createTaskAssignor();
 
-        assertFalse(taskAssignor.previousAssignmentIsValid());
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+
+        assertThat(clientStates.get(UUID_1).activeTasks(), is(singleton(TASK_0_1)));
+        assertThat(clientStates.get(UUID_2).activeTasks(), is(singleton(TASK_0_0)));
+        // we'll warm up task 0_0 on client1 because it's first in sorted order,
+        // although this isn't an optimal convergence
+        assertThat(probingRebalanceNeeded, is(true));
     }
 
     @Test
-    public void shouldDecidePreviousAssignmentIsValid() {
-        client1 = EasyMock.createNiceMock(ClientState.class);
-        client2 = EasyMock.createNiceMock(ClientState.class);
-        expect(client1.prevStandbyTasks()).andStubReturn(EMPTY_TASKS);
-        expect(client2.prevStandbyTasks()).andStubReturn(EMPTY_TASKS);
-
-        expect(client1.prevActiveTasks()).andReturn(singleton(TASK_0_0));
-        expect(client2.prevActiveTasks()).andReturn(singleton(TASK_0_1));
-        expect(client1.lagFor(TASK_0_0)).andReturn(0L);
-        expect(client2.lagFor(TASK_0_0)).andReturn(0L);
-        replay(client1, client2);
-
-        allTasks =  mkSet(TASK_0_0, TASK_0_1);
-        statefulTasks =  mkSet(TASK_0_0);
-        clientStates = mkMap(
+    public void shouldReusePreviousAssignmentIfItIsAlreadyBalanced() {
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0);
+        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), 1);
+        final ClientState client2 =
+            new ClientState(singleton(TASK_0_1), emptySet(), mkMap(mkEntry(TASK_0_0, 0L), mkEntry(TASK_0_1, 0L)), 1);
+        final Map<UUID, ClientState> clientStates = mkMap(
             mkEntry(UUID_1, client1),
             mkEntry(UUID_2, client2)
         );
-        createTaskAssignor();
 
-        assertTrue(taskAssignor.previousAssignmentIsValid());
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+
+        assertThat(clientStates.get(UUID_1).activeTasks(), is(singleton(TASK_0_0)));
+        assertThat(clientStates.get(UUID_2).activeTasks(), is(singleton(TASK_0_1)));
+        assertThat(probingRebalanceNeeded, is(false));
     }
 
     @Test
     public void shouldComputeBalanceFactorAsDifferenceBetweenMostAndLeastLoadedClients() {
-        client1 = EasyMock.createNiceMock(ClientState.class);
-        client2 = EasyMock.createNiceMock(ClientState.class);
-        client3 = EasyMock.createNiceMock(ClientState.class);
+        final ClientState client1 = EasyMock.createNiceMock(ClientState.class);
+        final ClientState client2 = EasyMock.createNiceMock(ClientState.class);
+        final ClientState client3 = EasyMock.createNiceMock(ClientState.class);
         final Set<ClientState> states = mkSet(client1, client2, client3);
         final Set<TaskId> statefulTasks =
             mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_2_0, TASK_2_1, TASK_2_3);
@@ -186,9 +174,9 @@ public class HighAvailabilityTaskAssignorTest {
 
     @Test
     public void shouldComputeBalanceFactorWithDifferentClientCapacities() {
-        client1 = EasyMock.createNiceMock(ClientState.class);
-        client2 = EasyMock.createNiceMock(ClientState.class);
-        client3 = EasyMock.createNiceMock(ClientState.class);
+        final ClientState client1 = EasyMock.createNiceMock(ClientState.class);
+        final ClientState client2 = EasyMock.createNiceMock(ClientState.class);
+        final ClientState client3 = EasyMock.createNiceMock(ClientState.class);
         final Set<ClientState> states = mkSet(client1, client2, client3);
         final Set<TaskId> statefulTasks =
             mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_2_0, TASK_2_1, TASK_2_3);
@@ -211,9 +199,9 @@ public class HighAvailabilityTaskAssignorTest {
 
     @Test
     public void shouldComputeBalanceFactorBasedOnStatefulTasksOnly() {
-        client1 = EasyMock.createNiceMock(ClientState.class);
-        client2 = EasyMock.createNiceMock(ClientState.class);
-        client3 = EasyMock.createNiceMock(ClientState.class);
+        final ClientState client1 = EasyMock.createNiceMock(ClientState.class);
+        final ClientState client2 = EasyMock.createNiceMock(ClientState.class);
+        final ClientState client3 = EasyMock.createNiceMock(ClientState.class);
         final Set<ClientState> states = mkSet(client1, client2, client3);
 
         // 0_0 and 0_1 are stateless
@@ -238,7 +226,7 @@ public class HighAvailabilityTaskAssignorTest {
     @Test
     public void shouldComputeBalanceFactorOfZeroWithOnlyOneClient() {
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        client1 = EasyMock.createNiceMock(ClientState.class);
+        final ClientState client1 = EasyMock.createNiceMock(ClientState.class);
         expect(client1.capacity()).andReturn(1);
         expect(client1.prevActiveTasks()).andReturn(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
         replay(client1);
@@ -247,239 +235,268 @@ public class HighAvailabilityTaskAssignorTest {
 
     @Test
     public void shouldAssignStandbysForStatefulTasks() {
-        numStandbyReplicas = 1;
-        allTasks = mkSet(TASK_0_0, TASK_0_1);
-        statefulTasks = mkSet(TASK_0_0, TASK_0_1);
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
+
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0), statefulTasks);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_1), statefulTasks);
 
-        client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0));
-        client2 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_1));
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
 
-        clientStates = getClientStatesMap(client1, client2);
-        createTaskAssignor();
-        taskAssignor.assign();
 
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0)));
         assertThat(client2.activeTasks(), equalTo(mkSet(TASK_0_1)));
         assertThat(client1.standbyTasks(), equalTo(mkSet(TASK_0_1)));
         assertThat(client2.standbyTasks(), equalTo(mkSet(TASK_0_0)));
+        assertThat(probingRebalanceNeeded, is(false));
     }
 
     @Test
     public void shouldNotAssignStandbysForStatelessTasks() {
-        numStandbyReplicas = 1;
-        allTasks = mkSet(TASK_0_0, TASK_0_1);
-        statefulTasks = EMPTY_TASKS;
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
+        final Set<TaskId> statefulTasks = EMPTY_TASKS;
 
-        client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
-        client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
 
-        clientStates = getClientStatesMap(client1, client2);
-        createTaskAssignor();
-        taskAssignor.assign();
 
         assertThat(client1.activeTaskCount(), equalTo(1));
         assertThat(client2.activeTaskCount(), equalTo(1));
         assertHasNoStandbyTasks(client1, client2);
+        assertThat(probingRebalanceNeeded, is(false));
     }
 
     @Test
     public void shouldAssignWarmupReplicasEvenIfNoStandbyReplicasConfigured() {
-        allTasks = mkSet(TASK_0_0, TASK_0_1);
-        statefulTasks = mkSet(TASK_0_0, TASK_0_1);
-        client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1));
-        client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
-
-        clientStates = getClientStatesMap(client1, client2);
-        createTaskAssignor();
-        taskAssignor.assign();
-        
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1), statefulTasks);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+
+
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1)));
         assertThat(client2.standbyTaskCount(), equalTo(1));
         assertHasNoStandbyTasks(client1);
         assertHasNoActiveTasks(client2);
+        assertThat(probingRebalanceNeeded, is(true));
     }
 
 
-
     @Test
     public void shouldNotAssignMoreThanMaxWarmupReplicas() {
-        maxWarmupReplicas = 1;
-        allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
-        client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3), statefulTasks);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+        final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(
+            clientStates,
+            allTasks,
+            statefulTasks,
+            new AssignmentConfigs(
+                /*acceptableRecoveryLag*/ 100L,
+                /*balanceFactor*/ 1,
+                /*maxWarmupReplicas*/ 1,
+                /*numStandbyReplicas*/ 0,
+                /*probingRebalanceIntervalMs*/ 60 * 1000L
+            )
+        );
 
-        clientStates = getClientStatesMap(client1, client2);
-        createTaskAssignor();
-        taskAssignor.assign();
 
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3)));
         assertThat(client2.standbyTaskCount(), equalTo(1));
         assertHasNoStandbyTasks(client1);
         assertHasNoActiveTasks(client2);
+        assertThat(probingRebalanceNeeded, is(true));
     }
 
     @Test
     public void shouldNotAssignWarmupAndStandbyToTheSameClient() {
-        numStandbyReplicas = 1;
-        maxWarmupReplicas = 1;
-
-        allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
-        client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3), statefulTasks);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
 
-        clientStates = getClientStatesMap(client1, client2);
-        createTaskAssignor();
-        taskAssignor.assign();
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+        final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(
+            clientStates,
+            allTasks,
+            statefulTasks,
+            new AssignmentConfigs(
+                /*acceptableRecoveryLag*/ 100L,
+                /*balanceFactor*/ 1,
+                /*maxWarmupReplicas*/ 1,
+                /*numStandbyReplicas*/ 1,
+                /*probingRebalanceIntervalMs*/ 60 * 1000L
+            )
+        );
 
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3)));
         assertThat(client2.standbyTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3)));
         assertHasNoStandbyTasks(client1);
         assertHasNoActiveTasks(client2);
+        assertThat(probingRebalanceNeeded, is(true));
     }
 
     @Test
     public void shouldNotAssignAnyStandbysWithInsufficientCapacity() {
-        numStandbyReplicas = 1;
-        allTasks = mkSet(TASK_0_0, TASK_0_1);
-        statefulTasks = mkSet(TASK_0_0, TASK_0_1);
-        client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1));
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1), statefulTasks);
 
-        clientStates = getClientStatesMap(client1);
-        createTaskAssignor();
-        taskAssignor.assign();
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1);
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
 
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1)));
         assertHasNoStandbyTasks(client1);
+        assertThat(probingRebalanceNeeded, is(false));
     }
 
     @Test
     public void shouldAssignActiveTasksToNotCaughtUpClientIfNoneExist() {
-        numStandbyReplicas = 1;
-        allTasks = mkSet(TASK_0_0, TASK_0_1);
-        statefulTasks = mkSet(TASK_0_0, TASK_0_1);
-        client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
 
-        clientStates = getClientStatesMap(client1);
-        createTaskAssignor();
-        taskAssignor.assign();
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1);
 
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1)));
         assertHasNoStandbyTasks(client1);
+        assertThat(probingRebalanceNeeded, is(false));
     }
 
     @Test
     public void shouldNotAssignMoreThanMaxWarmupReplicasWithStandbys() {
-        numStandbyReplicas = 1;
-
-        allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
-        client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
-        client3 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(statefulTasks, statefulTasks);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client3 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
 
-        clientStates = getClientStatesMap(client1, client2, client3);
-        createTaskAssignor();
-        taskAssignor.assign();
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
         assertThat(client1.activeTaskCount(), equalTo(4));
         assertThat(client2.standbyTaskCount(), equalTo(3)); // 1
         assertThat(client3.standbyTaskCount(), equalTo(3));
         assertHasNoStandbyTasks(client1);
         assertHasNoActiveTasks(client2, client3);
+        assertThat(probingRebalanceNeeded, is(true));
     }
 
     @Test
     public void shouldDistributeStatelessTasksToBalanceTotalTaskLoad() {
-        numStandbyReplicas = 1;
-        allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2);
-        statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
 
-        client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3));
-        client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(statefulTasks, statefulTasks);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
 
-        clientStates = getClientStatesMap(client1, client2);
-        createTaskAssignor();
-        taskAssignor.assign();
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
 
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_2)));
         assertHasNoStandbyTasks(client1);
         assertThat(client2.activeTasks(), equalTo(mkSet(TASK_1_1)));
         assertThat(client2.standbyTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3)));
+        assertThat(probingRebalanceNeeded, is(true));
     }
 
     @Test
     public void shouldDistributeStatefulActiveTasksToAllClients() {
-        allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3, TASK_2_0); // 9 total
-        statefulTasks = new HashSet<>(allTasks);
-        client1 = getMockClientWithPreviousCaughtUpTasks(allTasks).withCapacity(100);
-        client2 = getMockClientWithPreviousCaughtUpTasks(allTasks).withCapacity(50);
-        client3 = getMockClientWithPreviousCaughtUpTasks(allTasks).withCapacity(1);
-
-        clientStates = getClientStatesMap(client1, client2, client3);
-        createTaskAssignor();
-        taskAssignor.assign();
-
-        assertFalse(client1.activeTasks().isEmpty());
-        assertFalse(client2.activeTasks().isEmpty());
-        assertFalse(client3.activeTasks().isEmpty());
+        final Set<TaskId> allTasks =
+            mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3, TASK_2_0); // 9 total
+        final Map<TaskId, Long> allTaskLags = allTasks.stream().collect(Collectors.toMap(t -> t, t -> 0L));
+        final Set<TaskId> statefulTasks = new HashSet<>(allTasks);
+        final ClientState client1 = new ClientState(emptySet(), emptySet(), allTaskLags, 100);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), allTaskLags, 50);
+        final ClientState client3 = new ClientState(emptySet(), emptySet(), allTaskLags, 1);
+
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
+
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+
+        assertThat(client1.activeTasks(), not(empty()));
+        assertThat(client2.activeTasks(), not(empty()));
+        assertThat(client3.activeTasks(), not(empty()));
+        assertThat(probingRebalanceNeeded, is(false));
     }
 
     @Test
     public void shouldReturnFalseIfPreviousAssignmentIsReused() {
-        allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        statefulTasks = new HashSet<>(allTasks);
-        client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_2));
-        client2 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_1, TASK_0_3));
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
+        final Set<TaskId> statefulTasks = new HashSet<>(allTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_2), statefulTasks);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_1, TASK_0_3), statefulTasks);
 
-        clientStates = getClientStatesMap(client1, client2);
-        createTaskAssignor();
-        assertFalse(taskAssignor.assign());
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
 
+        assertThat(probingRebalanceNeeded, is(false));
         assertThat(client1.activeTasks(), equalTo(client1.prevActiveTasks()));
         assertThat(client2.activeTasks(), equalTo(client2.prevActiveTasks()));
     }
 
     @Test
     public void shouldReturnFalseIfNoWarmupTasksAreAssigned() {
-        allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        statefulTasks = EMPTY_TASKS;
-        client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
-        client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
-
-        clientStates = getClientStatesMap(client1, client2);
-        createTaskAssignor();
-        assertFalse(taskAssignor.assign());
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
+        final Set<TaskId> statefulTasks = EMPTY_TASKS;
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+        assertThat(probingRebalanceNeeded, is(false));
         assertHasNoStandbyTasks(client1, client2);
     }
 
     @Test
     public void shouldReturnTrueIfWarmupTasksAreAssigned() {
-        allTasks = mkSet(TASK_0_0, TASK_0_1);
-        statefulTasks = mkSet(TASK_0_0, TASK_0_1);
-        client1 = getMockClientWithPreviousCaughtUpTasks(allTasks);
-        client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS);
-
-        clientStates = getClientStatesMap(client1, client2);
-        createTaskAssignor();
-        assertTrue(taskAssignor.assign());
+        final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
+        final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(allTasks, statefulTasks);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+        final boolean probingRebalanceNeeded =
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+        assertThat(probingRebalanceNeeded, is(true));
         assertThat(client2.standbyTaskCount(), equalTo(1));
     }
 
     private static void assertHasNoActiveTasks(final ClientState... clients) {
         for (final ClientState client : clients) {
-            assertTrue(client.activeTasks().isEmpty());
+            assertThat(client.activeTasks(), is(empty()));
         }
     }
 
     private static void assertHasNoStandbyTasks(final ClientState... clients) {
         for (final ClientState client : clients) {
-            assertTrue(client.standbyTasks().isEmpty());
+            assertThat(client.standbyTasks(), is(empty()));
         }
     }
 
-    private MockClientState getMockClientWithPreviousCaughtUpTasks(final Set<TaskId> statefulActiveTasks) {
+    private static ClientState getMockClientWithPreviousCaughtUpTasks(final Set<TaskId> statefulActiveTasks,
+                                                                      final Set<TaskId> statefulTasks) {
         if (!statefulTasks.containsAll(statefulActiveTasks)) {
             throw new IllegalArgumentException("Need to initialize stateful tasks set before creating mock clients");
         }
@@ -491,32 +508,6 @@ public class HighAvailabilityTaskAssignorTest {
                 taskLags.put(task, Long.MAX_VALUE);
             }
         }
-        final MockClientState client = new MockClientState(1, taskLags);
-        client.addPreviousActiveTasks(statefulActiveTasks);
-        return client;
-    }
-
-    static class MockClientState extends ClientState {
-        private final Map<TaskId, Long> taskLagTotals;
-
-        private MockClientState(final int capacity,
-                                final Map<TaskId, Long> taskLagTotals) {
-            super(capacity);
-            this.taskLagTotals = taskLagTotals;
-        }
-
-        @Override
-        long lagFor(final TaskId task) {
-            final Long totalLag = taskLagTotals.get(task);
-            if (totalLag == null) {
-                return Long.MAX_VALUE;
-            } else {
-                return totalLag;
-            }
-        }
-
-        MockClientState withCapacity(final int capacity) {
-            return new MockClientState(capacity, taskLagTotals);
-        }
+        return new ClientState(statefulActiveTasks, emptySet(), taskLags, 1);
     }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
index d241a57..5203832 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
@@ -16,10 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import java.util.UUID;
-import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.processor.TaskId;
-import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
 import org.junit.Test;
 
 import java.util.ArrayList;
@@ -31,8 +28,11 @@ import java.util.Map;
 import java.util.Set;
 import java.util.TreeMap;
 import java.util.TreeSet;
+import java.util.UUID;
 
 import static java.util.Arrays.asList;
+import static java.util.Collections.singleton;
+import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
@@ -54,12 +54,15 @@ import static org.apache.kafka.streams.processor.internals.assignment.Assignment
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_4;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_5;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_6;
-import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.core.IsIterableContaining.hasItem;
-import static org.hamcrest.core.IsIterableContaining.hasItems;
-import static org.hamcrest.core.IsNot.not;
-import static org.junit.Assert.assertTrue;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.hasItem;
+import static org.hamcrest.Matchers.hasItems;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.hamcrest.Matchers.not;
 
 public class StickyTaskAssignorTest {
 
@@ -73,11 +76,11 @@ public class StickyTaskAssignorTest {
         createClient(UUID_2, 1);
         createClient(UUID_3, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2);
+        assertThat(probingRebalanceNeeded, is(false));
 
-        for (final UUID processId : clients.keySet()) {
-            assertThat(clients.get(processId).activeTaskCount(), equalTo(1));
+        for (final ClientState clientState : clients.values()) {
+            assertThat(clientState.activeTaskCount(), equalTo(1));
         }
     }
 
@@ -87,8 +90,9 @@ public class StickyTaskAssignorTest {
         createClient(UUID_2, 2);
         createClient(UUID_3, 2);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_1_0, TASK_1_1, TASK_2_2, TASK_2_0, TASK_2_1, TASK_1_2);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_1_0, TASK_1_1, TASK_2_2, TASK_2_0, TASK_2_1, TASK_1_2);
+        assertThat(probingRebalanceNeeded, is(false));
+
         assertActiveTaskTopicGroupIdsEvenlyDistributed();
     }
 
@@ -98,8 +102,9 @@ public class StickyTaskAssignorTest {
         createClient(UUID_2, 2);
         createClient(UUID_3, 2);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(1, TASK_2_0, TASK_1_1, TASK_1_2, TASK_1_0, TASK_2_1, TASK_2_2);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(1, TASK_2_0, TASK_1_1, TASK_1_2, TASK_1_0, TASK_2_1, TASK_2_2);
+        assertThat(probingRebalanceNeeded, is(false));
+
         assertActiveTaskTopicGroupIdsEvenlyDistributed();
     }
 
@@ -108,8 +113,7 @@ public class StickyTaskAssignorTest {
         createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0);
         createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_1);
 
-        final StickyTaskAssignor firstAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2);
-        firstAssignor.assign();
+        assertThat(assign(TASK_0_0, TASK_0_1, TASK_0_2), is(false));
 
         assertThat(clients.get(UUID_1).activeTasks(), hasItems(TASK_0_0));
         assertThat(clients.get(UUID_2).activeTasks(), hasItems(TASK_0_1));
@@ -121,8 +125,7 @@ public class StickyTaskAssignorTest {
         createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_1);
         createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_2);
 
-        final StickyTaskAssignor secondAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2);
-        secondAssignor.assign();
+        assertThat(assign(TASK_0_0, TASK_0_1, TASK_0_2), is(false));
 
         assertThat(clients.get(UUID_1).activeTasks(), hasItems(TASK_0_1));
         assertThat(clients.get(UUID_2).activeTasks(), hasItems(TASK_0_2));
@@ -135,11 +138,10 @@ public class StickyTaskAssignorTest {
         createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_1);
         createClient(UUID_3, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2);
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2);
 
-        taskAssignor.assign();
-
-        assertThat(clients.get(UUID_2).activeTasks(), equalTo(Collections.singleton(TASK_0_1)));
+        assertThat(probingRebalanceNeeded, is(false));
+        assertThat(clients.get(UUID_2).activeTasks(), equalTo(singleton(TASK_0_1)));
         assertThat(clients.get(UUID_1).activeTasks().size(), equalTo(1));
         assertThat(clients.get(UUID_3).activeTasks().size(), equalTo(1));
         assertThat(allActiveTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2)));
@@ -149,9 +151,9 @@ public class StickyTaskAssignorTest {
     public void shouldAssignBasedOnCapacity() {
         createClient(UUID_1, 1);
         createClient(UUID_2, 2);
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2);
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2);
 
-        taskAssignor.assign();
+        assertThat(probingRebalanceNeeded, is(false));
         assertThat(clients.get(UUID_1).activeTasks().size(), equalTo(1));
         assertThat(clients.get(UUID_2).activeTasks().size(), equalTo(2));
     }
@@ -162,31 +164,29 @@ public class StickyTaskAssignorTest {
 
         createClient(UUID_2, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_1_0, TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_0_4, TASK_0_5);
+        assertThat(assign(TASK_1_0, TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_0_4, TASK_0_5), is(false));
 
         final Set<TaskId> expectedClientITasks = new HashSet<>(asList(TASK_0_0, TASK_0_1, TASK_1_0, TASK_0_5));
         final Set<TaskId> expectedClientIITasks = new HashSet<>(asList(TASK_0_2, TASK_0_3, TASK_0_4));
 
-        taskAssignor.assign();
 
         assertThat(clients.get(UUID_1).activeTasks(), equalTo(expectedClientITasks));
         assertThat(clients.get(UUID_2).activeTasks(), equalTo(expectedClientIITasks));
     }
 
     @Test
-    public void shouldKeepActiveTaskStickynessWhenMoreClientThanActiveTasks() {
+    public void shouldKeepActiveTaskStickinessWhenMoreClientThanActiveTasks() {
         createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0);
         createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_2);
         createClientWithPreviousActiveTasks(UUID_3, 1, TASK_0_1);
         createClient(UUID_4, 1);
         createClient(UUID_5, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2);
-        taskAssignor.assign();
+        assertThat(assign(TASK_0_0, TASK_0_1, TASK_0_2), is(false));
 
-        assertThat(clients.get(UUID_1).activeTasks(), equalTo(Collections.singleton(TASK_0_0)));
-        assertThat(clients.get(UUID_2).activeTasks(), equalTo(Collections.singleton(TASK_0_2)));
-        assertThat(clients.get(UUID_3).activeTasks(), equalTo(Collections.singleton(TASK_0_1)));
+        assertThat(clients.get(UUID_1).activeTasks(), equalTo(singleton(TASK_0_0)));
+        assertThat(clients.get(UUID_2).activeTasks(), equalTo(singleton(TASK_0_2)));
+        assertThat(clients.get(UUID_3).activeTasks(), equalTo(singleton(TASK_0_1)));
 
         // change up the assignment and make sure it is still sticky
         clients.clear();
@@ -196,72 +196,72 @@ public class StickyTaskAssignorTest {
         createClientWithPreviousActiveTasks(UUID_4, 1, TASK_0_2);
         createClientWithPreviousActiveTasks(UUID_5, 1, TASK_0_1);
 
-        final StickyTaskAssignor secondAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2);
-        secondAssignor.assign();
+        assertThat(assign(TASK_0_0, TASK_0_1, TASK_0_2), is(false));
 
-        assertThat(clients.get(UUID_2).activeTasks(), equalTo(Collections.singleton(TASK_0_0)));
-        assertThat(clients.get(UUID_4).activeTasks(), equalTo(Collections.singleton(TASK_0_2)));
-        assertThat(clients.get(UUID_5).activeTasks(), equalTo(Collections.singleton(TASK_0_1)));
+        assertThat(clients.get(UUID_2).activeTasks(), equalTo(singleton(TASK_0_0)));
+        assertThat(clients.get(UUID_4).activeTasks(), equalTo(singleton(TASK_0_2)));
+        assertThat(clients.get(UUID_5).activeTasks(), equalTo(singleton(TASK_0_1)));
     }
 
     @Test
     public void shouldAssignTasksToClientWithPreviousStandbyTasks() {
         final ClientState client1 = createClient(UUID_1, 1);
-        client1.addPreviousStandbyTasks(Utils.mkSet(TASK_0_2));
+        client1.addPreviousStandbyTasks(mkSet(TASK_0_2));
         final ClientState client2 = createClient(UUID_2, 1);
-        client2.addPreviousStandbyTasks(Utils.mkSet(TASK_0_1));
+        client2.addPreviousStandbyTasks(mkSet(TASK_0_1));
         final ClientState client3 = createClient(UUID_3, 1);
-        client3.addPreviousStandbyTasks(Utils.mkSet(TASK_0_0));
+        client3.addPreviousStandbyTasks(mkSet(TASK_0_0));
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2);
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2);
 
-        taskAssignor.assign();
+        assertThat(probingRebalanceNeeded, is(false));
 
-        assertThat(clients.get(UUID_1).activeTasks(), equalTo(Collections.singleton(TASK_0_2)));
-        assertThat(clients.get(UUID_2).activeTasks(), equalTo(Collections.singleton(TASK_0_1)));
-        assertThat(clients.get(UUID_3).activeTasks(), equalTo(Collections.singleton(TASK_0_0)));
+        assertThat(clients.get(UUID_1).activeTasks(), equalTo(singleton(TASK_0_2)));
+        assertThat(clients.get(UUID_2).activeTasks(), equalTo(singleton(TASK_0_1)));
+        assertThat(clients.get(UUID_3).activeTasks(), equalTo(singleton(TASK_0_0)));
     }
 
     @Test
     public void shouldAssignBasedOnCapacityWhenMultipleClientHaveStandbyTasks() {
         final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0);
-        c1.addPreviousStandbyTasks(Utils.mkSet(TASK_0_1));
+        c1.addPreviousStandbyTasks(mkSet(TASK_0_1));
         final ClientState c2 = createClientWithPreviousActiveTasks(UUID_2, 2, TASK_0_2);
-        c2.addPreviousStandbyTasks(Utils.mkSet(TASK_0_1));
+        c2.addPreviousStandbyTasks(mkSet(TASK_0_1));
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2);
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2);
 
-        taskAssignor.assign();
+        assertThat(probingRebalanceNeeded, is(false));
 
-        assertThat(clients.get(UUID_1).activeTasks(), equalTo(Collections.singleton(TASK_0_0)));
-        assertThat(clients.get(UUID_2).activeTasks(), equalTo(Utils.mkSet(TASK_0_2, TASK_0_1)));
+        assertThat(clients.get(UUID_1).activeTasks(), equalTo(singleton(TASK_0_0)));
+        assertThat(clients.get(UUID_2).activeTasks(), equalTo(mkSet(TASK_0_2, TASK_0_1)));
     }
 
     @Test
-    public void shouldAssignStandbyTasksToDifferentClientThanCorrespondingActiveTaskIsAssingedTo() {
+    public void shouldAssignStandbyTasksToDifferentClientThanCorrespondingActiveTaskIsAssignedTo() {
         createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0);
         createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_1);
         createClientWithPreviousActiveTasks(UUID_3, 1, TASK_0_2);
         createClientWithPreviousActiveTasks(UUID_4, 1, TASK_0_3);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(1, TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(1, TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
+        assertThat(probingRebalanceNeeded, is(false));
+
 
         assertThat(clients.get(UUID_1).standbyTasks(), not(hasItems(TASK_0_0)));
-        assertTrue(clients.get(UUID_1).standbyTasks().size() <= 2);
+        assertThat(clients.get(UUID_1).standbyTasks().size(), lessThanOrEqualTo(2));
         assertThat(clients.get(UUID_2).standbyTasks(), not(hasItems(TASK_0_1)));
-        assertTrue(clients.get(UUID_2).standbyTasks().size() <= 2);
+        assertThat(clients.get(UUID_2).standbyTasks().size(), lessThanOrEqualTo(2));
         assertThat(clients.get(UUID_3).standbyTasks(), not(hasItems(TASK_0_2)));
-        assertTrue(clients.get(UUID_3).standbyTasks().size() <= 2);
+        assertThat(clients.get(UUID_3).standbyTasks().size(), lessThanOrEqualTo(2));
         assertThat(clients.get(UUID_4).standbyTasks(), not(hasItems(TASK_0_3)));
-        assertTrue(clients.get(UUID_4).standbyTasks().size() <= 2);
+        assertThat(clients.get(UUID_4).standbyTasks().size(), lessThanOrEqualTo(2));
 
         int nonEmptyStandbyTaskCount = 0;
-        for (final UUID client : clients.keySet()) {
-            nonEmptyStandbyTaskCount += clients.get(client).standbyTasks().isEmpty() ? 0 : 1;
+        for (final ClientState clientState : clients.values()) {
+            nonEmptyStandbyTaskCount += clientState.standbyTasks().isEmpty() ? 0 : 1;
         }
 
-        assertTrue(nonEmptyStandbyTaskCount >= 3);
+        assertThat(nonEmptyStandbyTaskCount, greaterThanOrEqualTo(3));
         assertThat(allStandbyTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3)));
     }
 
@@ -271,19 +271,19 @@ public class StickyTaskAssignorTest {
         createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_1);
         createClientWithPreviousActiveTasks(UUID_3, 1, TASK_0_2);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(2, TASK_0_0, TASK_0_1, TASK_0_2);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(2, TASK_0_0, TASK_0_1, TASK_0_2);
+        assertThat(probingRebalanceNeeded, is(false));
 
-        assertThat(clients.get(UUID_1).standbyTasks(), equalTo(Utils.mkSet(TASK_0_1, TASK_0_2)));
-        assertThat(clients.get(UUID_2).standbyTasks(), equalTo(Utils.mkSet(TASK_0_2, TASK_0_0)));
-        assertThat(clients.get(UUID_3).standbyTasks(), equalTo(Utils.mkSet(TASK_0_0, TASK_0_1)));
+        assertThat(clients.get(UUID_1).standbyTasks(), equalTo(mkSet(TASK_0_1, TASK_0_2)));
+        assertThat(clients.get(UUID_2).standbyTasks(), equalTo(mkSet(TASK_0_2, TASK_0_0)));
+        assertThat(clients.get(UUID_3).standbyTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1)));
     }
 
     @Test
     public void shouldNotAssignStandbyTaskReplicasWhenNoClientAvailableWithoutHavingTheTaskAssigned() {
         createClient(UUID_1, 1);
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(1, TASK_0_0);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(1, TASK_0_0);
+        assertThat(probingRebalanceNeeded, is(false));
         assertThat(clients.get(UUID_1).standbyTasks().size(), equalTo(0));
     }
 
@@ -293,8 +293,8 @@ public class StickyTaskAssignorTest {
         createClient(UUID_2, 1);
         createClient(UUID_3, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(1, TASK_0_0, TASK_0_1, TASK_0_2);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(1, TASK_0_0, TASK_0_1, TASK_0_2);
+        assertThat(probingRebalanceNeeded, is(false));
 
         assertThat(allActiveTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2)));
         assertThat(allStandbyTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2)));
@@ -306,8 +306,8 @@ public class StickyTaskAssignorTest {
         createClient(UUID_2, 1);
         createClient(UUID_3, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2);
+        assertThat(probingRebalanceNeeded, is(false));
         assertThat(clients.get(UUID_1).assignedTaskCount(), equalTo(1));
         assertThat(clients.get(UUID_2).assignedTaskCount(), equalTo(1));
         assertThat(clients.get(UUID_3).assignedTaskCount(), equalTo(1));
@@ -322,8 +322,8 @@ public class StickyTaskAssignorTest {
         createClient(UUID_5, 1);
         createClient(UUID_6, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2);
+        assertThat(probingRebalanceNeeded, is(false));
 
         assertThat(allActiveTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2)));
     }
@@ -337,8 +337,8 @@ public class StickyTaskAssignorTest {
         createClient(UUID_5, 1);
         createClient(UUID_6, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(1, TASK_0_0, TASK_0_1, TASK_0_2);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(1, TASK_0_0, TASK_0_1, TASK_0_2);
+        assertThat(probingRebalanceNeeded, is(false));
 
         for (final ClientState clientState : clients.values()) {
             assertThat(clientState.assignedTaskCount(), equalTo(1));
@@ -350,20 +350,22 @@ public class StickyTaskAssignorTest {
         createClient(UUID_2, 2);
         createClient(UUID_1, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0,
-                                                                            TASK_0_1,
-                                                                            TASK_0_2,
-                                                                            new TaskId(1, 0),
-                                                                            new TaskId(1, 1),
-                                                                            new TaskId(1, 2),
-                                                                            new TaskId(2, 0),
-                                                                            new TaskId(2, 1),
-                                                                            new TaskId(2, 2),
-                                                                            new TaskId(3, 0),
-                                                                            new TaskId(3, 1),
-                                                                            new TaskId(3, 2));
-
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(
+            TASK_0_0,
+            TASK_0_1,
+            TASK_0_2,
+            new TaskId(1, 0),
+            new TaskId(1, 1),
+            new TaskId(1, 2),
+            new TaskId(2, 0),
+            new TaskId(2, 1),
+            new TaskId(2, 2),
+            new TaskId(3, 0),
+            new TaskId(3, 1),
+            new TaskId(3, 2)
+        );
+
+        assertThat(probingRebalanceNeeded, is(false));
         assertThat(clients.get(UUID_2).assignedTaskCount(), equalTo(8));
         assertThat(clients.get(UUID_1).assignedTaskCount(), equalTo(4));
     }
@@ -387,8 +389,8 @@ public class StickyTaskAssignorTest {
         Collections.shuffle(taskIds);
         taskIds.toArray(taskIdArray);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(taskIdArray);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(taskIdArray);
+        assertThat(probingRebalanceNeeded, is(false));
 
         Collections.sort(taskIds);
         final Set<TaskId> expectedClientOneAssignment = getExpectedTaskIdAssignment(taskIds, 0, 4, 8, 12);
@@ -412,8 +414,8 @@ public class StickyTaskAssignorTest {
         createClient(UUID_3, 1);
         createClient(UUID_4, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(1, TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(1, TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
+        assertThat(probingRebalanceNeeded, is(false));
 
         for (final UUID uuid : allUUIDs) {
             final Set<TaskId> taskIds = clients.get(uuid).assignedTasks();
@@ -435,8 +437,8 @@ public class StickyTaskAssignorTest {
         createClientWithPreviousActiveTasks(UUID_3, 1, TASK_0_0);
         createClient(UUID_4, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(1, TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(1, TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
+        assertThat(probingRebalanceNeeded, is(false));
 
         for (final UUID uuid : allUUIDs) {
             final Set<TaskId> taskIds = clients.get(uuid).assignedTasks();
@@ -455,15 +457,15 @@ public class StickyTaskAssignorTest {
         final List<UUID> allUUIDs = asList(UUID_1, UUID_2, UUID_3, UUID_4);
 
         final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_1, TASK_0_2);
-        c1.addPreviousStandbyTasks(Utils.mkSet(TASK_0_3, TASK_0_0));
+        c1.addPreviousStandbyTasks(mkSet(TASK_0_3, TASK_0_0));
         final ClientState c2 = createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_3, TASK_0_0);
-        c2.addPreviousStandbyTasks(Utils.mkSet(TASK_0_1, TASK_0_2));
+        c2.addPreviousStandbyTasks(mkSet(TASK_0_1, TASK_0_2));
 
         createClient(UUID_3, 1);
         createClient(UUID_4, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(1, TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(1, TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
+        assertThat(probingRebalanceNeeded, is(false));
 
         for (final UUID uuid : allUUIDs) {
             final Set<TaskId> taskIds = clients.get(uuid).assignedTasks();
@@ -484,8 +486,8 @@ public class StickyTaskAssignorTest {
         createClient(UUID_2, 1);
         createClient(UUID_4, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
+        assertThat(probingRebalanceNeeded, is(false));
 
         assertThat(clients.get(UUID_1).assignedTaskCount(), equalTo(1));
         assertThat(clients.get(UUID_2).assignedTaskCount(), equalTo(1));
@@ -499,8 +501,8 @@ public class StickyTaskAssignorTest {
         createClient(UUID_1, 1);
         createClient(UUID_2, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
+        assertThat(probingRebalanceNeeded, is(false));
 
         assertThat(clients.get(UUID_3).assignedTaskCount(), equalTo(2));
         assertThat(clients.get(UUID_1).assignedTaskCount(), equalTo(1));
@@ -511,23 +513,23 @@ public class StickyTaskAssignorTest {
     public void shouldRebalanceTasksToClientsBasedOnCapacity() {
         createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_0, TASK_0_3, TASK_0_2);
         createClient(UUID_3, 2);
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_2, TASK_0_3);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_2, TASK_0_3);
+        assertThat(probingRebalanceNeeded, is(false));
         assertThat(clients.get(UUID_2).assignedTaskCount(), equalTo(1));
         assertThat(clients.get(UUID_3).assignedTaskCount(), equalTo(2));
     }
 
     @Test
     public void shouldMoveMinimalNumberOfTasksWhenPreviouslyAboveCapacityAndNewClientAdded() {
-        final Set<TaskId> p1PrevTasks = Utils.mkSet(TASK_0_0, TASK_0_2);
-        final Set<TaskId> p2PrevTasks = Utils.mkSet(TASK_0_1, TASK_0_3);
+        final Set<TaskId> p1PrevTasks = mkSet(TASK_0_0, TASK_0_2);
+        final Set<TaskId> p2PrevTasks = mkSet(TASK_0_1, TASK_0_3);
 
         createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_2);
         createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_1, TASK_0_3);
         createClientWithPreviousActiveTasks(UUID_3, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3);
+        assertThat(probingRebalanceNeeded, is(false));
 
         final Set<TaskId> p3ActiveTasks = clients.get(UUID_3).activeTasks();
         assertThat(p3ActiveTasks.size(), equalTo(1));
@@ -543,8 +545,8 @@ public class StickyTaskAssignorTest {
         createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_1);
         createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_2, TASK_0_3);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_3, TASK_0_1, TASK_0_4, TASK_0_2, TASK_0_0, TASK_0_5);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_3, TASK_0_1, TASK_0_4, TASK_0_2, TASK_0_0, TASK_0_5);
+        assertThat(probingRebalanceNeeded, is(false));
 
         assertThat(clients.get(UUID_1).activeTasks(), hasItems(TASK_0_0, TASK_0_1));
         assertThat(clients.get(UUID_2).activeTasks(), hasItems(TASK_0_2, TASK_0_3));
@@ -557,8 +559,8 @@ public class StickyTaskAssignorTest {
         createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_0, TASK_0_3);
         createClient(UUID_3, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_3, TASK_0_1, TASK_0_4, TASK_0_2, TASK_0_0, TASK_0_5);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_3, TASK_0_1, TASK_0_4, TASK_0_2, TASK_0_0, TASK_0_5);
+        assertThat(probingRebalanceNeeded, is(false));
 
         assertThat(clients.get(UUID_1).activeTasks(), hasItems(TASK_0_2, TASK_0_1));
         assertThat(clients.get(UUID_2).activeTasks(), hasItems(TASK_0_0, TASK_0_3));
@@ -568,51 +570,51 @@ public class StickyTaskAssignorTest {
     @Test
     public void shouldAssignTasksNotPreviouslyActiveToNewClient() {
         final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_1, TASK_1_2, TASK_1_3);
-        c1.addPreviousStandbyTasks(Utils.mkSet(TASK_0_0, TASK_1_1, TASK_2_0, TASK_2_1, TASK_2_3));
+        c1.addPreviousStandbyTasks(mkSet(TASK_0_0, TASK_1_1, TASK_2_0, TASK_2_1, TASK_2_3));
         final ClientState c2 = createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_0, TASK_1_1, TASK_2_2);
-        c2.addPreviousStandbyTasks(Utils.mkSet(TASK_0_1, TASK_1_0, TASK_0_2, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_3));
+        c2.addPreviousStandbyTasks(mkSet(TASK_0_1, TASK_1_0, TASK_0_2, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_3));
         final ClientState c3 = createClientWithPreviousActiveTasks(UUID_3, 1, TASK_2_0, TASK_2_1, TASK_2_3);
-        c3.addPreviousStandbyTasks(Utils.mkSet(TASK_0_2, TASK_1_2));
+        c3.addPreviousStandbyTasks(mkSet(TASK_0_2, TASK_1_2));
 
         final ClientState newClient = createClient(UUID_4, 1);
-        newClient.addPreviousStandbyTasks(Utils.mkSet(TASK_0_0, TASK_1_0, TASK_0_1, TASK_0_2, TASK_1_1, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_2, TASK_2_3));
+        newClient.addPreviousStandbyTasks(mkSet(TASK_0_0, TASK_1_0, TASK_0_1, TASK_0_2, TASK_1_1, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_2, TASK_2_3));
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_1_0, TASK_0_1, TASK_0_2, TASK_1_1, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_2, TASK_2_3);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_1_0, TASK_0_1, TASK_0_2, TASK_1_1, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_2, TASK_2_3);
+        assertThat(probingRebalanceNeeded, is(false));
 
-        assertThat(c1.activeTasks(), equalTo(Utils.mkSet(TASK_0_1, TASK_1_2, TASK_1_3)));
-        assertThat(c2.activeTasks(), equalTo(Utils.mkSet(TASK_0_0, TASK_1_1, TASK_2_2)));
-        assertThat(c3.activeTasks(), equalTo(Utils.mkSet(TASK_2_0, TASK_2_1, TASK_2_3)));
-        assertThat(newClient.activeTasks(), equalTo(Utils.mkSet(TASK_0_2, TASK_0_3, TASK_1_0)));
+        assertThat(c1.activeTasks(), equalTo(mkSet(TASK_0_1, TASK_1_2, TASK_1_3)));
+        assertThat(c2.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_1_1, TASK_2_2)));
+        assertThat(c3.activeTasks(), equalTo(mkSet(TASK_2_0, TASK_2_1, TASK_2_3)));
+        assertThat(newClient.activeTasks(), equalTo(mkSet(TASK_0_2, TASK_0_3, TASK_1_0)));
     }
 
     @Test
     public void shouldAssignTasksNotPreviouslyActiveToMultipleNewClients() {
         final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_1, TASK_1_2, TASK_1_3);
-        c1.addPreviousStandbyTasks(Utils.mkSet(TASK_0_0, TASK_1_1, TASK_2_0, TASK_2_1, TASK_2_3));
+        c1.addPreviousStandbyTasks(mkSet(TASK_0_0, TASK_1_1, TASK_2_0, TASK_2_1, TASK_2_3));
         final ClientState c2 = createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_0, TASK_1_1, TASK_2_2);
-        c2.addPreviousStandbyTasks(Utils.mkSet(TASK_0_1, TASK_1_0, TASK_0_2, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_3));
+        c2.addPreviousStandbyTasks(mkSet(TASK_0_1, TASK_1_0, TASK_0_2, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_3));
 
         final ClientState bounce1 = createClient(UUID_3, 1);
-        bounce1.addPreviousStandbyTasks(Utils.mkSet(TASK_2_0, TASK_2_1, TASK_2_3));
+        bounce1.addPreviousStandbyTasks(mkSet(TASK_2_0, TASK_2_1, TASK_2_3));
 
         final ClientState bounce2 = createClient(UUID_4, 1);
-        bounce2.addPreviousStandbyTasks(Utils.mkSet(TASK_0_2, TASK_0_3, TASK_1_0));
+        bounce2.addPreviousStandbyTasks(mkSet(TASK_0_2, TASK_0_3, TASK_1_0));
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_1_0, TASK_0_1, TASK_0_2, TASK_1_1, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_2, TASK_2_3);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_1_0, TASK_0_1, TASK_0_2, TASK_1_1, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_2, TASK_2_3);
+        assertThat(probingRebalanceNeeded, is(false));
 
-        assertThat(c1.activeTasks(), equalTo(Utils.mkSet(TASK_0_1, TASK_1_2, TASK_1_3)));
-        assertThat(c2.activeTasks(), equalTo(Utils.mkSet(TASK_0_0, TASK_1_1, TASK_2_2)));
-        assertThat(bounce1.activeTasks(), equalTo(Utils.mkSet(TASK_2_0, TASK_2_1, TASK_2_3)));
-        assertThat(bounce2.activeTasks(), equalTo(Utils.mkSet(TASK_0_2, TASK_0_3, TASK_1_0)));
+        assertThat(c1.activeTasks(), equalTo(mkSet(TASK_0_1, TASK_1_2, TASK_1_3)));
+        assertThat(c2.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_1_1, TASK_2_2)));
+        assertThat(bounce1.activeTasks(), equalTo(mkSet(TASK_2_0, TASK_2_1, TASK_2_3)));
+        assertThat(bounce2.activeTasks(), equalTo(mkSet(TASK_0_2, TASK_0_3, TASK_1_0)));
     }
 
     @Test
     public void shouldAssignTasksToNewClient() {
         createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_1, TASK_0_2);
         createClient(UUID_2, 1);
-        createTaskAssignor(TASK_0_1, TASK_0_2).assign();
+        assertThat(assign(TASK_0_1, TASK_0_2), is(false));
         assertThat(clients.get(UUID_1).activeTaskCount(), equalTo(1));
     }
 
@@ -622,8 +624,8 @@ public class StickyTaskAssignorTest {
         final ClientState c2 = createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_3, TASK_0_4, TASK_0_5);
         final ClientState newClient = createClient(UUID_3, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_0_4, TASK_0_5);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_0_4, TASK_0_5);
+        assertThat(probingRebalanceNeeded, is(false));
         assertThat(c1.activeTasks(), not(hasItem(TASK_0_3)));
         assertThat(c1.activeTasks(), not(hasItem(TASK_0_4)));
         assertThat(c1.activeTasks(), not(hasItem(TASK_0_5)));
@@ -639,11 +641,11 @@ public class StickyTaskAssignorTest {
     public void shouldAssignTasksToNewClientWithoutFlippingAssignmentBetweenExistingAndBouncedClients() {
         final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_6);
         final ClientState c2 = createClient(UUID_2, 1);
-        c2.addPreviousStandbyTasks(Utils.mkSet(TASK_0_3, TASK_0_4, TASK_0_5));
+        c2.addPreviousStandbyTasks(mkSet(TASK_0_3, TASK_0_4, TASK_0_5));
         final ClientState newClient = createClient(UUID_3, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_0_4, TASK_0_5, TASK_0_6);
-        taskAssignor.assign();
+        final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_0_4, TASK_0_5, TASK_0_6);
+        assertThat(probingRebalanceNeeded, is(false));
         assertThat(c1.activeTasks(), not(hasItem(TASK_0_3)));
         assertThat(c1.activeTasks(), not(hasItem(TASK_0_4)));
         assertThat(c1.activeTasks(), not(hasItem(TASK_0_5)));
@@ -660,32 +662,32 @@ public class StickyTaskAssignorTest {
         final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_1, TASK_0_2);
         final ClientState c2 = createClient(UUID_2, 1);
 
-        final StickyTaskAssignor taskAssignor = createTaskAssignor(0, true, TASK_0_0, TASK_0_1, TASK_0_2);
-        taskAssignor.assign();
+        final List<TaskId> taskIds = asList(TASK_0_0, TASK_0_1, TASK_0_2);
+        Collections.shuffle(taskIds);
+        final boolean probingRebalanceNeeded = new StickyTaskAssignor(true).assign(
+            clients,
+            new HashSet<>(taskIds),
+            new HashSet<>(taskIds),
+            new AssignorConfiguration.AssignmentConfigs(0L, 0, 0, 0, 0L)
+        );
+        assertThat(probingRebalanceNeeded, is(false));
 
-        assertThat(c1.activeTasks(), equalTo(Utils.mkSet(TASK_0_0, TASK_0_1, TASK_0_2)));
-        assertTrue(c2.activeTasks().isEmpty());
+        assertThat(c1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2)));
+        assertThat(c2.activeTasks(), empty());
     }
 
-    private StickyTaskAssignor createTaskAssignor(final TaskId... tasks) {
-        return createTaskAssignor(0, false, tasks);
-    }
-    
-    private StickyTaskAssignor createTaskAssignor(final int numStandbys, final TaskId... tasks) {
-        return createTaskAssignor(numStandbys, false, tasks);
+    private boolean assign(final TaskId... tasks) {
+        return assign(0, tasks);
     }
 
-    private StickyTaskAssignor createTaskAssignor(final int numStandbys,
-                                                  final boolean mustPreserveActiveTaskAssignment,
-                                                  final TaskId... tasks) {
+    private boolean assign(final int numStandbys, final TaskId... tasks) {
         final List<TaskId> taskIds = asList(tasks);
         Collections.shuffle(taskIds);
-        return new StickyTaskAssignor(
+        return new StickyTaskAssignor().assign(
             clients,
             new HashSet<>(taskIds),
             new HashSet<>(taskIds),
-            new AssignmentConfigs(0L, 0, 0, numStandbys, 0L),
-            mustPreserveActiveTaskAssignment
+            new AssignorConfiguration.AssignmentConfigs(0L, 0, 0, numStandbys, 0L)
         );
     }
 
@@ -713,7 +715,7 @@ public class StickyTaskAssignorTest {
 
     private ClientState createClientWithPreviousActiveTasks(final UUID processId, final int capacity, final TaskId... taskIds) {
         final ClientState clientState = new ClientState(capacity);
-        clientState.addPreviousActiveTasks(Utils.mkSet(taskIds));
+        clientState.addPreviousActiveTasks(mkSet(taskIds));
         clients.put(processId, clientState);
         return clientState;
     }
@@ -730,7 +732,7 @@ public class StickyTaskAssignorTest {
         }
     }
 
-    private Map<UUID, Set<TaskId>> sortClientAssignments(final Map<UUID, ClientState> clients) {
+    private static Map<UUID, Set<TaskId>> sortClientAssignments(final Map<UUID, ClientState> clients) {
         final Map<UUID, Set<TaskId>> sortedAssignments = new HashMap<>();
         for (final Map.Entry<UUID, ClientState> entry : clients.entrySet()) {
             final Set<TaskId> sorted = new TreeSet<>(entry.getValue().activeTasks());
@@ -739,12 +741,11 @@ public class StickyTaskAssignorTest {
         return sortedAssignments;
     }
 
-    private Set<TaskId> getExpectedTaskIdAssignment(final List<TaskId> tasks, final int... indices) {
+    private static Set<TaskId> getExpectedTaskIdAssignment(final List<TaskId> tasks, final int... indices) {
         final Set<TaskId> sortedAssignment = new TreeSet<>();
         for (final int index : indices) {
             sortedAssignment.add(tasks.get(index));
         }
         return sortedAssignment;
     }
-
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
index 7be6ee7..9517400 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
@@ -416,11 +416,12 @@ public class TaskAssignorConvergenceTest {
             iteration++;
             harness.prepareForNextRebalance();
             harness.recordBefore(iteration);
-            rebalancePending = new HighAvailabilityTaskAssignor(
-                harness.clientStates, allTasks,
+            rebalancePending = new HighAvailabilityTaskAssignor().assign(
+                harness.clientStates,
+                allTasks,
                 harness.statefulTaskEndOffsetSums.keySet(),
                 configs
-            ).assign();
+            );
             harness.recordAfter(iteration, rebalancePending);
         }
 
diff --git a/tests/kafkatest/services/streams.py b/tests/kafkatest/services/streams.py
index e878882..b5e2feb 100644
--- a/tests/kafkatest/services/streams.py
+++ b/tests/kafkatest/services/streams.py
@@ -477,6 +477,10 @@ class StreamsUpgradeTestJobRunnerService(StreamsTestBaseService):
                                                                  "")
         self.UPGRADE_FROM = None
         self.UPGRADE_TO = None
+        self.extra_properties = {}
+
+    def set_config(self, key, value):
+        self.extra_properties[key] = value
 
     def set_version(self, kafka_streams_version):
         self.KAFKA_STREAMS_VERSION = kafka_streams_version
@@ -488,8 +492,10 @@ class StreamsUpgradeTestJobRunnerService(StreamsTestBaseService):
         self.UPGRADE_TO = upgrade_to
 
     def prop_file(self):
-        properties = {streams_property.STATE_DIR: self.PERSISTENT_ROOT,
-                      streams_property.KAFKA_SERVERS: self.kafka.bootstrap_servers()}
+        properties = self.extra_properties.copy()
+        properties[streams_property.STATE_DIR] = self.PERSISTENT_ROOT
+        properties[streams_property.KAFKA_SERVERS] = self.kafka.bootstrap_servers()
+
         if self.UPGRADE_FROM is not None:
             properties['upgrade.from'] = self.UPGRADE_FROM
         if self.UPGRADE_TO == "future_version":
@@ -562,6 +568,8 @@ class StaticMemberTestService(StreamsTestBaseService):
                       consumer_property.SESSION_TIMEOUT_MS: 60000}
 
         properties['input.topic'] = self.INPUT_TOPIC
+        # TODO KIP-441: consider rewriting the test for HighAvailabilityTaskAssignor
+        properties['internal.task.assignor.class'] = "org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor"
 
         cfg = KafkaConfig(**properties)
         return cfg.render()
diff --git a/tests/kafkatest/tests/streams/streams_broker_down_resilience_test.py b/tests/kafkatest/tests/streams/streams_broker_down_resilience_test.py
index 58f3b18..8fcf14a 100644
--- a/tests/kafkatest/tests/streams/streams_broker_down_resilience_test.py
+++ b/tests/kafkatest/tests/streams/streams_broker_down_resilience_test.py
@@ -144,7 +144,11 @@ class StreamsBrokerDownResilience(BaseStreamsTest):
     def test_streams_should_scale_in_while_brokers_down(self):
         self.kafka.start()
 
-        configs = self.get_configs(extra_configs=",application.id=shutdown_with_broker_down")
+        # TODO KIP-441: consider rewriting the test for HighAvailabilityTaskAssignor
+        configs = self.get_configs(
+            extra_configs=",application.id=shutdown_with_broker_down" +
+                          ",internal.task.assignor.class=org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor"
+        )
 
         processor = StreamsBrokerDownResilienceService(self.test_context, self.kafka, configs)
         processor.start()
@@ -217,7 +221,11 @@ class StreamsBrokerDownResilience(BaseStreamsTest):
     def test_streams_should_failover_while_brokers_down(self):
         self.kafka.start()
 
-        configs = self.get_configs(extra_configs=",application.id=failover_with_broker_down")
+        # TODO KIP-441: consider rewriting the test for HighAvailabilityTaskAssignor
+        configs = self.get_configs(
+            extra_configs=",application.id=failover_with_broker_down" +
+                          ",internal.task.assignor.class=org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor"
+        )
 
         processor = StreamsBrokerDownResilienceService(self.test_context, self.kafka, configs)
         processor.start()
diff --git a/tests/kafkatest/tests/streams/streams_standby_replica_test.py b/tests/kafkatest/tests/streams/streams_standby_replica_test.py
index 310f8a5..e847c3e 100644
--- a/tests/kafkatest/tests/streams/streams_standby_replica_test.py
+++ b/tests/kafkatest/tests/streams/streams_standby_replica_test.py
@@ -44,9 +44,14 @@ class StreamsStandbyTask(BaseStreamsTest):
                                                  })
 
     def test_standby_tasks_rebalance(self):
-        configs = self.get_configs(",sourceTopic=%s,sinkTopic1=%s,sinkTopic2=%s" % (self.streams_source_topic,
-                                                                                    self.streams_sink_topic_1,
-                                                                                    self.streams_sink_topic_2))
+        # TODO KIP-441: consider rewriting the test for HighAvailabilityTaskAssignor
+        configs = self.get_configs(
+            ",sourceTopic=%s,sinkTopic1=%s,sinkTopic2=%s,internal.task.assignor.class=org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor" % (
+            self.streams_source_topic,
+            self.streams_sink_topic_1,
+            self.streams_sink_topic_2
+            )
+        )
 
         producer = self.get_producer(self.streams_source_topic, self.num_messages, throughput=15000, repeating_keys=6)
         producer.start()
diff --git a/tests/kafkatest/tests/streams/streams_upgrade_test.py b/tests/kafkatest/tests/streams/streams_upgrade_test.py
index 1d00d90..16b9d60 100644
--- a/tests/kafkatest/tests/streams/streams_upgrade_test.py
+++ b/tests/kafkatest/tests/streams/streams_upgrade_test.py
@@ -303,9 +303,13 @@ class StreamsUpgradeTest(Test):
 
         self.driver = StreamsSmokeTestDriverService(self.test_context, self.kafka)
         self.driver.disable_auto_terminate()
+        # TODO KIP-441: consider rewriting the test for HighAvailabilityTaskAssignor
         self.processor1 = StreamsUpgradeTestJobRunnerService(self.test_context, self.kafka)
+        self.processor1.set_config("internal.task.assignor.class", "org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor")
         self.processor2 = StreamsUpgradeTestJobRunnerService(self.test_context, self.kafka)
+        self.processor2.set_config("internal.task.assignor.class", "org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor")
         self.processor3 = StreamsUpgradeTestJobRunnerService(self.test_context, self.kafka)
+        self.processor3.set_config("internal.task.assignor.class", "org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor")
 
         self.driver.start()
         self.start_all_nodes_with("") # run with TRUNK