Posted to commits@kafka.apache.org by sh...@apache.org on 2022/03/03 03:39:41 UTC

[kafka] branch trunk updated: KAFKA-6718 / Rack aware standby task assignor (#10851)

This is an automated email from the ASF dual-hosted git repository.

showuon pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 62e6466  KAFKA-6718 / Rack aware standby task assignor (#10851)
62e6466 is described below

commit 62e646619b208098c07e4d044c1ed7fcd44bf6f9
Author: Levani Kokhreidze <le...@gmail.com>
AuthorDate: Thu Mar 3 05:37:26 2022 +0200

    KAFKA-6718 / Rack aware standby task assignor (#10851)
    
    This PR is part of KIP-708 and adds rack aware standby task assignment logic.
    
    Reviewers: Bruno Cadonna <ca...@apache.org>, Luke Chen <sh...@gmail.com>, Vladimir Sitnikov <vladimirsitnikov@apache.org>
---
 .../assignment/AssignorConfiguration.java          |   8 +-
 .../internals/assignment/ClientState.java          |  14 +-
 .../ClientTagAwareStandbyTaskAssignor.java         | 337 +++++++++++
 .../assignment/DefaultStandbyTaskAssignor.java     |  70 +++
 .../assignment/HighAvailabilityTaskAssignor.java   |  75 ++-
 .../assignment/StandbyTaskAssignmentUtils.java     |  59 ++
 .../internals/assignment/StandbyTaskAssignor.java} |  22 +-
 .../internals/assignment/AssignmentTestUtils.java  |   4 +
 .../assignment/AssignorConfigurationTest.java      |   3 +-
 .../internals/assignment/ClientStateTest.java      |  14 +
 .../ClientTagAwareStandbyTaskAssignorTest.java     | 641 +++++++++++++++++++++
 .../assignment/FallbackPriorTaskAssignorTest.java  |   3 +-
 .../HighAvailabilityTaskAssignorTest.java          | 173 +++---
 .../assignment/StandbyTaskAssignmentUtilsTest.java | 126 ++++
 .../assignment/StickyTaskAssignorTest.java         |   5 +-
 .../assignment/TaskAssignorConvergenceTest.java    |  13 +-
 16 files changed, 1428 insertions(+), 139 deletions(-)
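
For context: KIP-708 drives the new standby task assignor through client tags. Each Kafka Streams instance labels itself with free-form key/value tags, and the assignor then spreads a task's standby replicas across distinct tag values. Below is a minimal configuration sketch, assuming the `client.tag.` prefix and the `rack.aware.assignment.tags` key exactly as they appear in the code comments further down; note that in this commit the StreamsConfig-based AssignmentConfigs constructor still hard-codes rackAwareAssignmentTags to an empty list, so the end-to-end config wiring presumably lands in a follow-up KIP-708 change.

    import java.util.Properties;

    import org.apache.kafka.streams.StreamsConfig;

    public class RackAwareConfigSketch {
        public static Properties streamsProps() {
            final Properties props = new Properties();
            props.put(StreamsConfig.APPLICATION_ID_CONFIG, "rack-aware-app");
            props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
            // Keep one standby copy of every stateful task.
            props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
            // Tag this instance's location; keys and values here are hypothetical examples.
            props.put("client.tag.zone", "eu-central-1a");
            props.put("client.tag.cluster", "k8s-cluster1");
            // The tag keys the assignor should consider when distributing standbys.
            props.put("rack.aware.assignment.tags", "zone,cluster");
            return props;
        }
    }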

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
index c14ab70..65cc7ae 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java
@@ -29,6 +29,8 @@ import org.apache.kafka.streams.processor.internals.ClientUtils;
 import org.apache.kafka.streams.processor.internals.InternalTopicManager;
 import org.slf4j.Logger;
 
+import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 
 import static org.apache.kafka.common.utils.Utils.getHost;
@@ -241,22 +243,26 @@ public final class AssignorConfiguration {
         public final int maxWarmupReplicas;
         public final int numStandbyReplicas;
         public final long probingRebalanceIntervalMs;
+        public final List<String> rackAwareAssignmentTags;
 
         private AssignmentConfigs(final StreamsConfig configs) {
             acceptableRecoveryLag = configs.getLong(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG);
             maxWarmupReplicas = configs.getInt(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG);
             numStandbyReplicas = configs.getInt(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG);
             probingRebalanceIntervalMs = configs.getLong(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG);
+            rackAwareAssignmentTags = Collections.emptyList();
         }
 
         AssignmentConfigs(final Long acceptableRecoveryLag,
                           final Integer maxWarmupReplicas,
                           final Integer numStandbyReplicas,
-                          final Long probingRebalanceIntervalMs) {
+                          final Long probingRebalanceIntervalMs,
+                          final List<String> rackAwareAssignmentTags) {
             this.acceptableRecoveryLag = validated(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG, acceptableRecoveryLag);
             this.maxWarmupReplicas = validated(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG, maxWarmupReplicas);
             this.numStandbyReplicas = validated(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, numStandbyReplicas);
             this.probingRebalanceIntervalMs = validated(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, probingRebalanceIntervalMs);
+            this.rackAwareAssignmentTags = rackAwareAssignmentTags;
         }
 
         private static <T> T validated(final String configKey, final T value) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
index 46f107e..5ee0e93 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
@@ -23,6 +23,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.Collection;
+import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashSet;
 import java.util.Map;
@@ -46,6 +47,7 @@ public class ClientState {
     private static final Logger LOG = LoggerFactory.getLogger(ClientState.class);
     public static final Comparator<TopicPartition> TOPIC_PARTITION_COMPARATOR = comparing(TopicPartition::topic).thenComparing(TopicPartition::partition);
 
+    private final Map<String, String> clientTags;
     private final Map<TaskId, Long> taskOffsetSums; // contains only stateful tasks we previously owned
     private final Map<TaskId, Long> taskLagTotals;  // contains lag for all stateful tasks in the app topology
     private final Map<TopicPartition, String> ownedPartitions = new TreeMap<>(TOPIC_PARTITION_COMPARATOR);
@@ -64,24 +66,30 @@ public class ClientState {
     }
 
     ClientState(final int capacity) {
+        this(capacity, Collections.emptyMap());
+    }
+
+    ClientState(final int capacity, final Map<String, String> clientTags) {
         previousStandbyTasks.taskIds(new TreeSet<>());
         previousActiveTasks.taskIds(new TreeSet<>());
-
         taskOffsetSums = new TreeMap<>();
         taskLagTotals = new TreeMap<>();
         this.capacity = capacity;
+        this.clientTags = unmodifiableMap(clientTags);
     }
 
     // For testing only
     public ClientState(final Set<TaskId> previousActiveTasks,
                        final Set<TaskId> previousStandbyTasks,
                        final Map<TaskId, Long> taskLagTotals,
+                       final Map<String, String> clientTags,
                        final int capacity) {
         this.previousStandbyTasks.taskIds(unmodifiableSet(new TreeSet<>(previousStandbyTasks)));
         this.previousActiveTasks.taskIds(unmodifiableSet(new TreeSet<>(previousActiveTasks)));
         taskOffsetSums = emptyMap();
         this.taskLagTotals = unmodifiableMap(taskLagTotals);
         this.capacity = capacity;
+        this.clientTags = unmodifiableMap(clientTags);
     }
 
     int capacity() {
@@ -266,6 +274,10 @@ public class ClientState {
         return ownedPartitions.get(partition);
     }
 
+    public Map<String, String> clientTags() {
+        return clientTags;
+    }
+
     public void addOwnedPartitions(final Collection<TopicPartition> ownedPartitions, final String consumer) {
         for (final TopicPartition tp : ownedPartitions) {
             this.ownedPartitions.put(tp, consumer);
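
ClientState now carries the owning client's tags and exposes them through clientTags() as an unmodifiable view, so the later balancing phases can read but never mutate them. A tiny standalone illustration of that defensive wrapping (sketch code, not part of the commit):

    import java.util.Collections;
    import java.util.HashMap;
    import java.util.Map;

    public class ImmutableTagsSketch {
        public static void main(final String[] args) {
            final Map<String, String> tags = new HashMap<>();
            tags.put("zone", "zone1");
            // Same wrapping as the ClientState constructors: unmodifiableMap(clientTags).
            final Map<String, String> view = Collections.unmodifiableMap(tags);
            try {
                view.put("zone", "zone2");
            } catch (final UnsupportedOperationException expected) {
                System.out.println("clientTags() consumers cannot mutate the tags");
            }
        }
    }
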
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignor.java
new file mode 100644
index 0000000..c7399d7
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignor.java
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.UUID;
+
+import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.computeTasksToRemainingStandbys;
+import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.createLeastLoadedPrioritySetConstrainedByAssignedTask;
+import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.pollClientAndMaybeAssignAndUpdateRemainingStandbyTasks;
+
+/**
+ * Distributes standby tasks over different tag dimensions. Standby task distribution is on a best-effort basis.
+ * If rack aware standby task assignment is not possible, the implementation falls back to distributing standby tasks to the least-loaded clients.
+ *
+ * @see DefaultStandbyTaskAssignor
+ */
+class ClientTagAwareStandbyTaskAssignor implements StandbyTaskAssignor {
+    private static final Logger log = LoggerFactory.getLogger(ClientTagAwareStandbyTaskAssignor.class);
+
+    /**
+     * The algorithm distributes standby tasks for the {@code statefulTaskIds} over different tag dimensions.
+     * For each stateful task, the number of standby tasks will be assigned based on the configured {@link AssignmentConfigs#numStandbyReplicas}.
+     * Rack aware standby task distribution only takes into account tags specified via {@link AssignmentConfigs#rackAwareAssignmentTags}.
+     * Ideally, all standby tasks for any given stateful task will be located on different tag dimensions to have the best possible distribution.
+     * However, if the ideal (or partially ideal) distribution is impossible, the algorithm will fall back to the least-loaded clients without taking rack awareness constraints into consideration.
+     * The least-loaded clients are determined based on the total number of tasks (active and standby tasks) assigned to the client.
+     */
+    @Override
+    public boolean assign(final Map<UUID, ClientState> clients,
+                          final Set<TaskId> allTaskIds,
+                          final Set<TaskId> statefulTaskIds,
+                          final AssignorConfiguration.AssignmentConfigs configs) {
+        final int numStandbyReplicas = configs.numStandbyReplicas;
+        final Set<String> rackAwareAssignmentTags = new HashSet<>(configs.rackAwareAssignmentTags);
+
+        final Map<TaskId, Integer> tasksToRemainingStandbys = computeTasksToRemainingStandbys(
+            numStandbyReplicas,
+            allTaskIds
+        );
+
+        final Map<String, Set<String>> tagKeyToValues = new HashMap<>();
+        final Map<TagEntry, Set<UUID>> tagEntryToClients = new HashMap<>();
+
+        fillClientsTagStatistics(clients, tagEntryToClients, tagKeyToValues);
+
+        final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = createLeastLoadedPrioritySetConstrainedByAssignedTask(clients);
+
+        final Map<TaskId, UUID> pendingStandbyTasksToClientId = new HashMap<>();
+
+        for (final TaskId statefulTaskId : statefulTaskIds) {
+            for (final Map.Entry<UUID, ClientState> entry : clients.entrySet()) {
+                final UUID clientId = entry.getKey();
+                final ClientState clientState = entry.getValue();
+
+                if (clientState.activeTasks().contains(statefulTaskId)) {
+                    assignStandbyTasksToClientsWithDifferentTags(
+                        standbyTaskClientsByTaskLoad,
+                        statefulTaskId,
+                        clientId,
+                        rackAwareAssignmentTags,
+                        clients,
+                        tasksToRemainingStandbys,
+                        tagKeyToValues,
+                        tagEntryToClients,
+                        pendingStandbyTasksToClientId
+                    );
+                }
+            }
+        }
+
+        if (!tasksToRemainingStandbys.isEmpty()) {
+            log.debug("Rack aware standby task assignment was not able to assign all standby tasks. " +
+                      "tasksToRemainingStandbys=[{}], pendingStandbyTasksToClientId=[{}]. " +
+                      "Will distribute the remaining standby tasks to least loaded clients.",
+                      tasksToRemainingStandbys, pendingStandbyTasksToClientId);
+
+            assignPendingStandbyTasksToLeastLoadedClients(clients,
+                                                          numStandbyReplicas,
+                                                          rackAwareAssignmentTags,
+                                                          standbyTaskClientsByTaskLoad,
+                                                          tasksToRemainingStandbys,
+                                                          pendingStandbyTasksToClientId);
+        }
+
+        // returning false, because standby task assignment will never require a follow-up probing rebalance.
+        return false;
+    }
+
+    private static void assignPendingStandbyTasksToLeastLoadedClients(final Map<UUID, ClientState> clients,
+                                                                      final int numStandbyReplicas,
+                                                                      final Set<String> rackAwareAssignmentTags,
+                                                                      final ConstrainedPrioritySet standbyTaskClientsByTaskLoad,
+                                                                      final Map<TaskId, Integer> pendingStandbyTaskToNumberRemainingStandbys,
+                                                                      final Map<TaskId, UUID> pendingStandbyTaskToClientId) {
+        // We need to re-offer all the clients to find the least-loaded ones.
+        standbyTaskClientsByTaskLoad.offerAll(clients.keySet());
+
+        for (final Entry<TaskId, Integer> pendingStandbyTaskAssignmentEntry : pendingStandbyTaskToNumberRemainingStandbys.entrySet()) {
+            final TaskId activeTaskId = pendingStandbyTaskAssignmentEntry.getKey();
+            final UUID clientId = pendingStandbyTaskToClientId.get(activeTaskId);
+
+            final int numberOfRemainingStandbys = pollClientAndMaybeAssignAndUpdateRemainingStandbyTasks(
+                clients,
+                pendingStandbyTaskToNumberRemainingStandbys,
+                standbyTaskClientsByTaskLoad,
+                activeTaskId
+            );
+
+            if (numberOfRemainingStandbys > 0) {
+                log.warn("Unable to assign {} of {} standby tasks for task [{}] with client tags [{}]. " +
+                         "There is not enough available capacity. You should " +
+                         "increase the number of application instances " +
+                         "on different client tag dimensions " +
+                         "to maintain the requested number of standby replicas. " +
+                         "Rack awareness is configured with [{}] tags.",
+                         numberOfRemainingStandbys, numStandbyReplicas, activeTaskId,
+                         clients.get(clientId).clientTags(), rackAwareAssignmentTags);
+            }
+        }
+    }
+
+    @Override
+    public boolean isAllowedTaskMovement(final ClientState source, final ClientState destination) {
+        final Map<String, String> sourceClientTags = source.clientTags();
+        final Map<String, String> destinationClientTags = destination.clientTags();
+
+        for (final Entry<String, String> sourceClientTagEntry : sourceClientTags.entrySet()) {
+            if (!sourceClientTagEntry.getValue().equals(destinationClientTags.get(sourceClientTagEntry.getKey()))) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    // Visible for testing
+    static void fillClientsTagStatistics(final Map<UUID, ClientState> clientStates,
+                                         final Map<TagEntry, Set<UUID>> tagEntryToClients,
+                                         final Map<String, Set<String>> tagKeyToValues) {
+        for (final Entry<UUID, ClientState> clientStateEntry : clientStates.entrySet()) {
+            final UUID clientId = clientStateEntry.getKey();
+            final ClientState clientState = clientStateEntry.getValue();
+
+            clientState.clientTags().forEach((tagKey, tagValue) -> {
+                tagKeyToValues.computeIfAbsent(tagKey, ignored -> new HashSet<>()).add(tagValue);
+                tagEntryToClients.computeIfAbsent(new TagEntry(tagKey, tagValue), ignored -> new HashSet<>()).add(clientId);
+            });
+        }
+    }
+
+    // Visible for testing
+    static void assignStandbyTasksToClientsWithDifferentTags(final ConstrainedPrioritySet standbyTaskClientsByTaskLoad,
+                                                             final TaskId activeTaskId,
+                                                             final UUID activeTaskClient,
+                                                             final Set<String> rackAwareAssignmentTags,
+                                                             final Map<UUID, ClientState> clientStates,
+                                                             final Map<TaskId, Integer> tasksToRemainingStandbys,
+                                                             final Map<String, Set<String>> tagKeyToValues,
+                                                             final Map<TagEntry, Set<UUID>> tagEntryToClients,
+                                                             final Map<TaskId, UUID> pendingStandbyTasksToClientId) {
+        standbyTaskClientsByTaskLoad.offerAll(clientStates.keySet());
+
+        // We set countOfUsedClients to 1 because the client hosting the active task has to be considered used.
+        int countOfUsedClients = 1;
+        int numRemainingStandbys = tasksToRemainingStandbys.get(activeTaskId);
+
+        final Map<TagEntry, Set<UUID>> tagEntryToUsedClients = new HashMap<>();
+
+        UUID lastUsedClient = activeTaskClient;
+        do {
+            updateClientsOnAlreadyUsedTagEntries(
+                lastUsedClient,
+                countOfUsedClients,
+                rackAwareAssignmentTags,
+                clientStates,
+                tagEntryToClients,
+                tagKeyToValues,
+                tagEntryToUsedClients
+            );
+
+            final UUID clientOnUnusedTagDimensions = standbyTaskClientsByTaskLoad.poll(
+                activeTaskId, uuid -> !isClientUsedOnAnyOfTheTagEntries(uuid, tagEntryToUsedClients)
+            );
+
+            if (clientOnUnusedTagDimensions == null) {
+                break;
+            }
+
+            clientStates.get(clientOnUnusedTagDimensions).assignStandby(activeTaskId);
+
+            countOfUsedClients++;
+            numRemainingStandbys--;
+
+            lastUsedClient = clientOnUnusedTagDimensions;
+        } while (numRemainingStandbys > 0);
+
+        if (numRemainingStandbys > 0) {
+            pendingStandbyTasksToClientId.put(activeTaskId, activeTaskClient);
+            tasksToRemainingStandbys.put(activeTaskId, numRemainingStandbys);
+        } else {
+            tasksToRemainingStandbys.remove(activeTaskId);
+        }
+    }
+
+    private static boolean isClientUsedOnAnyOfTheTagEntries(final UUID client,
+                                                            final Map<TagEntry, Set<UUID>> tagEntryToUsedClients) {
+        return tagEntryToUsedClients.values().stream().anyMatch(usedClients -> usedClients.contains(client));
+    }
+
+    private static void updateClientsOnAlreadyUsedTagEntries(final UUID usedClient,
+                                                             final int countOfUsedClients,
+                                                             final Set<String> rackAwareAssignmentTags,
+                                                             final Map<UUID, ClientState> clientStates,
+                                                             final Map<TagEntry, Set<UUID>> tagEntryToClients,
+                                                             final Map<String, Set<String>> tagKeyToValues,
+                                                             final Map<TagEntry, Set<UUID>> tagEntryToUsedClients) {
+        final Map<String, String> usedClientTags = clientStates.get(usedClient).clientTags();
+
+        for (final Entry<String, String> usedClientTagEntry : usedClientTags.entrySet()) {
+            final String tagKey = usedClientTagEntry.getKey();
+
+            if (!rackAwareAssignmentTags.contains(tagKey)) {
+                log.warn("Client tag with key [{}] will be ignored when computing rack aware standby " +
+                         "task assignment because it is not part of the configured rack awareness [{}].",
+                         tagKey, rackAwareAssignmentTags);
+                continue;
+            }
+
+            final Set<String> allTagValues = tagKeyToValues.get(tagKey);
+
+            // Consider the following client setup where we need to distribute 2 standby tasks for each stateful task.
+            //
+            // # Kafka Streams Client 1
+            // client.tag.zone: eu-central-1a
+            // client.tag.cluster: k8s-cluster1
+            // rack.aware.assignment.tags: zone,cluster
+            //
+            // # Kafka Streams Client 2
+            // client.tag.zone: eu-central-1b
+            // client.tag.cluster: k8s-cluster1
+            // rack.aware.assignment.tags: zone,cluster
+            //
+            // # Kafka Streams Client 3
+            // client.tag.zone: eu-central-1c
+            // client.tag.cluster: k8s-cluster1
+            // rack.aware.assignment.tags: zone,cluster
+            //
+            // # Kafka Streams Client 4
+            // client.tag.zone: eu-central-1a
+            // client.tag.cluster: k8s-cluster2
+            // rack.aware.assignment.tags: zone,cluster
+            //
+            // # Kafka Streams Client 5
+            // client.tag.zone: eu-central-1b
+            // client.tag.cluster: k8s-cluster2
+            // rack.aware.assignment.tags: zone,cluster
+            //
+            // # Kafka Streams Client 6
+            // client.tag.zone: eu-central-1c
+            // client.tag.cluster: k8s-cluster2
+            // rack.aware.assignment.tags: zone,cluster
+            //
+            // Since we have only two unique `cluster` tag values,
+            // we can only achieve "ideal" distribution on the 1st standby task assignment.
+            // Ideal distribution for the 1st standby task can be achieved because we can assign the standby task
+            // to a client located on a different cluster and zone than the active task.
+            // We can't consider the `cluster` tag for the 2nd standby task assignment because the 1st standby
+            // task would already be assigned on a different cluster than the active one, which means
+            // we have already used all the available cluster tag values. Taking the `cluster` tag into consideration
+            // for the 2nd standby task assignment would effectively mean excluding all the clients.
+            // Instead, for the 2nd standby task, we can only achieve partial rack awareness based on the `zone` tag.
+            // Since the `cluster` tag is not considered for the 2nd standby task assignment, partial rack awareness
+            // can be satisfied by placing the 2nd standby on a different `zone` than both the active task and the 1st standby.
+            // Zones in either `cluster` are valid candidates, as the goal is to distribute clients over different `zone` values.
+
+            // This statement checks whether we have already used at least as many clients as there are
+            // unique values for the given tag, and if so, removes those tag entries from the tagEntryToUsedClients map.
+            if (allTagValues.size() <= countOfUsedClients) {
+                allTagValues.forEach(tagValue -> tagEntryToUsedClients.remove(new TagEntry(tagKey, tagValue)));
+            } else {
+                final String tagValue = usedClientTagEntry.getValue();
+                final TagEntry tagEntry = new TagEntry(tagKey, tagValue);
+                final Set<UUID> clientsOnUsedTagValue = tagEntryToClients.get(tagEntry);
+                tagEntryToUsedClients.put(tagEntry, clientsOnUsedTagValue);
+            }
+        }
+    }
+
+    // Visible for testing
+    static final class TagEntry {
+        private final String tagKey;
+        private final String tagValue;
+
+        TagEntry(final String tagKey, final String tagValue) {
+            this.tagKey = tagKey;
+            this.tagValue = tagValue;
+        }
+
+        @Override
+        public boolean equals(final Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            final TagEntry that = (TagEntry) o;
+            return Objects.equals(tagKey, that.tagKey) && Objects.equals(tagValue, that.tagValue);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(tagKey, tagValue);
+        }
+    }
+}
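
The two lookup maps built by fillClientsTagStatistics carry all the state the assignor needs: every value observed per tag key, and every client per (key, value) pair. Their shape is easy to see in isolation; here is a self-contained sketch using plain standalone types rather than the package-private Kafka classes above (Java 16 record and Map.of used for brevity):

    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Map;
    import java.util.Set;
    import java.util.UUID;

    public class TagStatisticsSketch {
        // Stands in for ClientTagAwareStandbyTaskAssignor.TagEntry: a (key, value) pair used as a map key.
        record TagEntry(String tagKey, String tagValue) {}

        public static void main(final String[] args) {
            final Map<UUID, Map<String, String>> clientTags = new HashMap<>();
            clientTags.put(UUID.randomUUID(), Map.of("zone", "zone1", "cluster", "cluster1"));
            clientTags.put(UUID.randomUUID(), Map.of("zone", "zone2", "cluster", "cluster1"));

            final Map<String, Set<String>> tagKeyToValues = new HashMap<>();
            final Map<TagEntry, Set<UUID>> tagEntryToClients = new HashMap<>();

            // One pass over all clients, mirroring fillClientsTagStatistics.
            clientTags.forEach((clientId, tags) -> tags.forEach((tagKey, tagValue) -> {
                tagKeyToValues.computeIfAbsent(tagKey, ignored -> new HashSet<>()).add(tagValue);
                tagEntryToClients.computeIfAbsent(new TagEntry(tagKey, tagValue), ignored -> new HashSet<>()).add(clientId);
            }));

            System.out.println(tagKeyToValues);    // e.g. {zone=[zone1, zone2], cluster=[cluster1]}
            System.out.println(tagEntryToClients); // e.g. {TagEntry[tagKey=zone, tagValue=zone1]=[...], ...}
        }
    }
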
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultStandbyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultStandbyTaskAssignor.java
new file mode 100644
index 0000000..db6cb4e
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultStandbyTaskAssignor.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+
+import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.computeTasksToRemainingStandbys;
+import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.createLeastLoadedPrioritySetConstrainedByAssignedTask;
+import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.pollClientAndMaybeAssignAndUpdateRemainingStandbyTasks;
+
+/**
+ * Default standby task assignor that distributes standby tasks to the least-loaded clients.
+ *
+ * @see ClientTagAwareStandbyTaskAssignor
+ */
+class DefaultStandbyTaskAssignor implements StandbyTaskAssignor {
+    private static final Logger log = LoggerFactory.getLogger(DefaultStandbyTaskAssignor.class);
+
+    @Override
+    public boolean assign(final Map<UUID, ClientState> clients,
+                          final Set<TaskId> allTaskIds,
+                          final Set<TaskId> statefulTaskIds,
+                          final AssignorConfiguration.AssignmentConfigs configs) {
+        final int numStandbyReplicas = configs.numStandbyReplicas;
+        final Map<TaskId, Integer> tasksToRemainingStandbys = computeTasksToRemainingStandbys(numStandbyReplicas,
+                                                                                              statefulTaskIds);
+
+        final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = createLeastLoadedPrioritySetConstrainedByAssignedTask(clients);
+
+        standbyTaskClientsByTaskLoad.offerAll(clients.keySet());
+
+        for (final TaskId task : statefulTaskIds) {
+            final int numRemainingStandbys = pollClientAndMaybeAssignAndUpdateRemainingStandbyTasks(clients,
+                                                                                                    tasksToRemainingStandbys,
+                                                                                                    standbyTaskClientsByTaskLoad,
+                                                                                                    task);
+
+            if (numRemainingStandbys > 0) {
+                log.warn("Unable to assign {} of {} standby tasks for task [{}]. " +
+                         "There is not enough available capacity. You should " +
+                         "increase the number of application instances " +
+                         "to maintain the requested number of standby replicas.",
+                         numRemainingStandbys, numStandbyReplicas, task);
+            }
+        }
+
+        // returning false, because standby task assignment will never require a follow-up probing rebalance.
+        return false;
+    }
+}
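
The default assignor's behaviour — each task's standbys go to the least-loaded clients that do not already host a copy of that task — can be reproduced without the Kafka internals. Below is a rough standalone sketch that ranks clients by plain task count (the real ConstrainedPrioritySet also factors in per-client capacity via assignedTaskLoad, and the real code logs a warning on shortfall):

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.Comparator;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Map;
    import java.util.Set;

    public class LeastLoadedStandbySketch {
        public static void main(final String[] args) {
            final int numStandbyReplicas = 1;
            // client -> tasks it already hosts (the active copies).
            final Map<String, Set<String>> assignment = new HashMap<>();
            assignment.put("client-1", new HashSet<>(Arrays.asList("0_0", "0_1")));
            assignment.put("client-2", new HashSet<>(Collections.singletonList("0_2")));
            assignment.put("client-3", new HashSet<>());

            for (final String task : Arrays.asList("0_0", "0_1", "0_2")) {
                for (int i = 0; i < numStandbyReplicas; i++) {
                    // Pick the least-loaded client not already hosting this task,
                    // which is roughly what ConstrainedPrioritySet.poll(task) does here.
                    assignment.entrySet().stream()
                              .filter(e -> !e.getValue().contains(task))
                              .min(Comparator.comparingInt(e -> e.getValue().size()))
                              .ifPresent(e -> e.getValue().add(task));
                }
            }
            System.out.println(assignment); // active plus standby tasks per client
        }
    }
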
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
index f6464f8..d0bb50b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
@@ -33,8 +33,8 @@ import java.util.TreeSet;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
 import java.util.function.Function;
-import java.util.stream.Collectors;
 
 import static org.apache.kafka.common.utils.Utils.diff;
 import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.assignActiveTaskMovements;
@@ -55,8 +55,9 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
 
         assignStandbyReplicaTasks(
             clientStates,
+            allTaskIds,
             statefulTasks,
-            configs.numStandbyReplicas
+            configs
         );
 
         final AtomicInteger remainingWarmupReplicas = new AtomicInteger(configs.maxWarmupReplicas);
@@ -93,10 +94,10 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
         final boolean probingRebalanceNeeded = neededActiveTaskMovements + neededStandbyTaskMovements > 0;
 
         log.info("Decided on assignment: " +
-                     clientStates +
-                     " with" +
-                     (probingRebalanceNeeded ? "" : " no") +
-                     " followup probing rebalance.");
+                 clientStates +
+                 " with" +
+                 (probingRebalanceNeeded ? "" : " no") +
+                 " followup probing rebalance.");
 
         return probingRebalanceNeeded;
     }
@@ -115,55 +116,46 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
             clientStates,
             ClientState::activeTasks,
             ClientState::unassignActive,
-            ClientState::assignActive
+            ClientState::assignActive,
+            (source, destination) -> true
         );
     }
 
-    private static void assignStandbyReplicaTasks(final TreeMap<UUID, ClientState> clientStates,
-                                                  final Set<TaskId> statefulTasks,
-                                                  final int numStandbyReplicas) {
-        final Map<TaskId, Integer> tasksToRemainingStandbys =
-            statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> numStandbyReplicas));
+    private void assignStandbyReplicaTasks(final TreeMap<UUID, ClientState> clientStates,
+                                           final Set<TaskId> allTaskIds,
+                                           final Set<TaskId> statefulTasks,
+                                           final AssignmentConfigs configs) {
+        if (configs.numStandbyReplicas == 0) {
+            return;
+        }
 
-        final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = new ConstrainedPrioritySet(
-            (client, task) -> !clientStates.get(client).hasAssignedTask(task),
-            client -> clientStates.get(client).assignedTaskLoad()
-        );
-        standbyTaskClientsByTaskLoad.offerAll(clientStates.keySet());
+        final StandbyTaskAssignor standbyTaskAssignor = createStandbyTaskAssignor(configs);
 
-        for (final TaskId task : statefulTasks) {
-            int numRemainingStandbys = tasksToRemainingStandbys.get(task);
-            while (numRemainingStandbys > 0) {
-                final UUID client = standbyTaskClientsByTaskLoad.poll(task);
-                if (client == null) {
-                    break;
-                }
-                clientStates.get(client).assignStandby(task);
-                numRemainingStandbys--;
-                standbyTaskClientsByTaskLoad.offer(client);
-            }
-
-            if (numRemainingStandbys > 0) {
-                log.warn("Unable to assign {} of {} standby tasks for task [{}]. " +
-                             "There is not enough available capacity. You should " +
-                             "increase the number of application instances " +
-                             "to maintain the requested number of standby replicas.",
-                         numRemainingStandbys, numStandbyReplicas, task);
-            }
-        }
+        standbyTaskAssignor.assign(clientStates, allTaskIds, statefulTasks, configs);
 
         balanceTasksOverThreads(
             clientStates,
             ClientState::standbyTasks,
             ClientState::unassignStandby,
-            ClientState::assignStandby
+            ClientState::assignStandby,
+            standbyTaskAssignor::isAllowedTaskMovement
         );
     }
 
+    // Visible for testing
+    static StandbyTaskAssignor createStandbyTaskAssignor(final AssignmentConfigs configs) {
+        if (!configs.rackAwareAssignmentTags.isEmpty()) {
+            return new ClientTagAwareStandbyTaskAssignor();
+        } else {
+            return new DefaultStandbyTaskAssignor();
+        }
+    }
+
     private static void balanceTasksOverThreads(final SortedMap<UUID, ClientState> clientStates,
                                                 final Function<ClientState, Set<TaskId>> currentAssignmentAccessor,
                                                 final BiConsumer<ClientState, TaskId> taskUnassignor,
-                                                final BiConsumer<ClientState, TaskId> taskAssignor) {
+                                                final BiConsumer<ClientState, TaskId> taskAssignor,
+                                                final BiPredicate<ClientState, ClientState> taskMovementAttemptPredicate) {
         boolean keepBalancing = true;
         while (keepBalancing) {
             keepBalancing = false;
@@ -182,7 +174,10 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
                     final Iterator<TaskId> sourceIterator = sourceTasks.iterator();
                     while (shouldMoveATask(sourceClientState, destinationClientState) && sourceIterator.hasNext()) {
                         final TaskId taskToMove = sourceIterator.next();
-                        final boolean canMove = !destinationClientState.hasAssignedTask(taskToMove);
+                        final boolean canMove = !destinationClientState.hasAssignedTask(taskToMove)
+                                                // When ClientTagAwareStandbyTaskAssignor is used, we need to make sure that
+                                                // the source client's tags match the destination client's tags.
+                                                && taskMovementAttemptPredicate.test(sourceClientState, destinationClientState);
                         if (canMove) {
                             taskUnassignor.accept(sourceClientState, taskToMove);
                             taskAssignor.accept(destinationClientState, taskToMove);
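
The refactor threads a BiPredicate into balanceTasksOverThreads so the thread-level balancing cannot undo rack-aware placement: active-task balancing passes (source, destination) -> true, while standby balancing delegates to the assignor's isAllowedTaskMovement. A tiny standalone illustration with hypothetical types; when both clients define the same tag keys, the per-entry check in isAllowedTaskMovement reduces to plain map equality:

    import java.util.Collections;
    import java.util.Map;
    import java.util.function.BiPredicate;

    public class MovementGateSketch {
        static class Client {
            final Map<String, String> tags;
            Client(final Map<String, String> tags) { this.tags = tags; }
        }

        public static void main(final String[] args) {
            // Standby tasks may only move between clients whose tags match.
            final BiPredicate<Client, Client> sameTags =
                (source, destination) -> source.tags.equals(destination.tags);

            final Client zone1 = new Client(Collections.singletonMap("zone", "zone1"));
            final Client zone2 = new Client(Collections.singletonMap("zone", "zone2"));

            System.out.println(sameTags.test(zone1, zone1)); // true: move keeps the distribution intact
            System.out.println(sameTags.test(zone1, zone2)); // false: move would break rack awareness
        }
    }
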
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignmentUtils.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignmentUtils.java
new file mode 100644
index 0000000..7ed6f5d
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignmentUtils.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.function.Function;
+
+import static java.util.stream.Collectors.toMap;
+
+final class StandbyTaskAssignmentUtils {
+    private StandbyTaskAssignmentUtils() {}
+
+    static ConstrainedPrioritySet createLeastLoadedPrioritySetConstrainedByAssignedTask(final Map<UUID, ClientState> clients) {
+        return new ConstrainedPrioritySet((client, t) -> !clients.get(client).hasAssignedTask(t),
+                                          client -> clients.get(client).assignedTaskLoad());
+    }
+
+    static int pollClientAndMaybeAssignAndUpdateRemainingStandbyTasks(final Map<UUID, ClientState> clients,
+                                                                      final Map<TaskId, Integer> tasksToRemainingStandbys,
+                                                                      final ConstrainedPrioritySet standbyTaskClientsByTaskLoad,
+                                                                      final TaskId activeTaskId) {
+        int numRemainingStandbys = tasksToRemainingStandbys.get(activeTaskId);
+        while (numRemainingStandbys > 0) {
+            final UUID client = standbyTaskClientsByTaskLoad.poll(activeTaskId);
+            if (client == null) {
+                break;
+            }
+            clients.get(client).assignStandby(activeTaskId);
+            numRemainingStandbys--;
+            standbyTaskClientsByTaskLoad.offer(client);
+            tasksToRemainingStandbys.put(activeTaskId, numRemainingStandbys);
+        }
+
+        return numRemainingStandbys;
+    }
+
+    static Map<TaskId, Integer> computeTasksToRemainingStandbys(final int numStandbyReplicas,
+                                                                final Set<TaskId> statefulTaskIds) {
+        return statefulTaskIds.stream().collect(toMap(Function.identity(), t -> numStandbyReplicas));
+    }
+}
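
The contract of these helpers is easiest to see with concrete numbers: computeTasksToRemainingStandbys seeds every task with the full replica count, each successful placement in the poll loop decrements it, and whatever stays positive is exactly the shortfall the callers warn about. A quick sketch of that bookkeeping with hypothetical task names:

    import java.util.Map;
    import java.util.TreeMap;

    public class RemainingStandbysSketch {
        public static void main(final String[] args) {
            final int numStandbyReplicas = 2;
            // Seeding, as in computeTasksToRemainingStandbys: every task starts at the full count.
            final Map<String, Integer> tasksToRemainingStandbys = new TreeMap<>();
            for (final String task : new String[] {"0_0", "0_1"}) {
                tasksToRemainingStandbys.put(task, numStandbyReplicas);
            }

            // One successful standby placement for 0_0 decrements its counter, mirroring
            // the loop body of pollClientAndMaybeAssignAndUpdateRemainingStandbyTasks.
            tasksToRemainingStandbys.computeIfPresent("0_0", (task, remaining) -> remaining - 1);

            System.out.println(tasksToRemainingStandbys); // {0_0=1, 0_1=2}
        }
    }
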
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfigurationTest.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignor.java
similarity index 57%
copy from streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfigurationTest.java
copy to streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignor.java
index 868c303..3b2ce99 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfigurationTest.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignor.java
@@ -16,22 +16,8 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
-import org.apache.kafka.common.config.ConfigException;
-import org.junit.Test;
-
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.containsString;
-import static org.junit.Assert.assertThrows;
-
-public class AssignorConfigurationTest {
-
-    @Test
-    public void configsShouldRejectZeroWarmups() {
-        final ConfigException exception = assertThrows(
-            ConfigException.class,
-            () -> new AssignorConfiguration.AssignmentConfigs(1L, 0, 1, 1L)
-        );
-
-        assertThat(exception.getMessage(), containsString("Invalid value 0 for configuration max.warmup.replicas"));
+interface StandbyTaskAssignor extends TaskAssignor {
+    default boolean isAllowedTaskMovement(final ClientState source, final ClientState destination) {
+        return true;
     }
-}
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java
index 38669da..627114a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java
@@ -17,6 +17,8 @@
 package org.apache.kafka.streams.processor.internals.assignment;
 
 import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
 import java.util.Map.Entry;
 import org.apache.kafka.clients.admin.AdminClient;
 import org.apache.kafka.clients.admin.ListOffsetsResult;
@@ -102,6 +104,8 @@ public final class AssignmentTestUtils {
 
     public static final Set<TaskId> EMPTY_TASKS = emptySet();
     public static final Map<TopicPartition, Long> EMPTY_CHANGELOG_END_OFFSETS = new HashMap<>();
+    public static final List<String> EMPTY_RACK_AWARE_ASSIGNMENT_TAGS = Collections.emptyList();
+    public static final Map<String, String> EMPTY_CLIENT_TAGS = Collections.emptyMap();
 
     private AssignmentTestUtils() {}
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfigurationTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfigurationTest.java
index 868c303..9ff53b5 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfigurationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfigurationTest.java
@@ -19,6 +19,7 @@ package org.apache.kafka.streams.processor.internals.assignment;
 import org.apache.kafka.common.config.ConfigException;
 import org.junit.Test;
 
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_RACK_AWARE_ASSIGNMENT_TAGS;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.containsString;
 import static org.junit.Assert.assertThrows;
@@ -29,7 +30,7 @@ public class AssignorConfigurationTest {
     public void configsShouldRejectZeroWarmups() {
         final ConfigException exception = assertThrows(
             ConfigException.class,
-            () -> new AssignorConfiguration.AssignmentConfigs(1L, 0, 1, 1L)
+            () -> new AssignorConfiguration.AssignmentConfigs(1L, 0, 1, 1L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertThat(exception.getMessage(), containsString("Invalid value 0 for configuration max.warmup.replicas"));
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
index 928129c..940c693 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java
@@ -33,6 +33,7 @@ import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.common.utils.Utils.mkSortedSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_CLIENT_TAGS;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T0_0_0;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T1_0_0;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
@@ -52,6 +53,7 @@ import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
@@ -66,6 +68,7 @@ public class ClientStateTest {
             mkSet(TASK_0_0, TASK_0_1),
             mkSet(TASK_0_2, TASK_0_3),
             mkMap(mkEntry(TASK_0_0, 5L), mkEntry(TASK_0_2, -1L)),
+            EMPTY_CLIENT_TAGS,
             4
         );
 
@@ -531,4 +534,15 @@ public class ClientStateTest {
         assertThrows(IllegalStateException.class, () -> client.assignActiveToConsumer(TASK_0_0, "c1"));
     }
 
+    @Test
+    public void shouldReturnClientTags() {
+        final Map<String, String> clientTags = mkMap(mkEntry("k1", "v1"));
+        assertEquals(clientTags, new ClientState(0, clientTags).clientTags());
+    }
+
+    @Test
+    public void shouldReturnEmptyClientTagsMapByDefault() {
+        assertTrue(new ClientState().clientTags().isEmpty());
+    }
+
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignorTest.java
new file mode 100644
index 0000000..8a983de
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientTagAwareStandbyTaskAssignorTest.java
@@ -0,0 +1,641 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
+import org.apache.kafka.streams.processor.internals.assignment.ClientTagAwareStandbyTaskAssignor.TagEntry;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.UUID;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_0;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.uuidForInt;
+import static org.apache.kafka.streams.processor.internals.assignment.ClientTagAwareStandbyTaskAssignor.assignStandbyTasksToClientsWithDifferentTags;
+import static org.apache.kafka.streams.processor.internals.assignment.ClientTagAwareStandbyTaskAssignor.fillClientsTagStatistics;
+import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.computeTasksToRemainingStandbys;
+import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.createLeastLoadedPrioritySetConstrainedByAssignedTask;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class ClientTagAwareStandbyTaskAssignorTest {
+    private static final String ZONE_TAG = "zone";
+    private static final String CLUSTER_TAG = "cluster";
+
+    private static final String ZONE_1 = "zone1";
+    private static final String ZONE_2 = "zone2";
+    private static final String ZONE_3 = "zone3";
+
+    private static final String CLUSTER_1 = "cluster1";
+    private static final String CLUSTER_2 = "cluster2";
+    private static final String CLUSTER_3 = "cluster3";
+
+    private static final UUID UUID_1 = uuidForInt(1);
+    private static final UUID UUID_2 = uuidForInt(2);
+    private static final UUID UUID_3 = uuidForInt(3);
+    private static final UUID UUID_4 = uuidForInt(4);
+    private static final UUID UUID_5 = uuidForInt(5);
+    private static final UUID UUID_6 = uuidForInt(6);
+    private static final UUID UUID_7 = uuidForInt(7);
+    private static final UUID UUID_8 = uuidForInt(8);
+    private static final UUID UUID_9 = uuidForInt(9);
+
+    private StandbyTaskAssignor standbyTaskAssignor;
+
+    @Before
+    public void setup() {
+        standbyTaskAssignor = new ClientTagAwareStandbyTaskAssignor();
+    }
+
+    @Test
+    public void shouldRemoveClientToRemainingStandbysAndNotPopulatePendingStandbyTasksToClientIdWhenAllStandbyTasksWereAssigned() {
+        final Set<String> rackAwareAssignmentTags = mkSet(ZONE_TAG, CLUSTER_TAG);
+        final Map<UUID, ClientState> clientStates = mkMap(
+            mkEntry(UUID_1, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)),
+            mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)), TASK_0_1)),
+            mkEntry(UUID_3, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)), TASK_0_2))
+        );
+
+        final ConstrainedPrioritySet constrainedPrioritySet = createLeastLoadedPrioritySetConstrainedByAssignedTask(clientStates);
+        final Set<TaskId> allActiveTasks = findAllActiveTasks(clientStates);
+        final Map<TaskId, UUID> taskToClientId = mkMap(mkEntry(TASK_0_0, UUID_1),
+                                                       mkEntry(TASK_0_1, UUID_2),
+                                                       mkEntry(TASK_0_2, UUID_3));
+
+        final Map<String, Set<String>> tagKeyToValues = new HashMap<>();
+        final Map<TagEntry, Set<UUID>> tagEntryToClients = new HashMap<>();
+
+        fillClientsTagStatistics(clientStates, tagEntryToClients, tagKeyToValues);
+
+        final Map<TaskId, UUID> pendingStandbyTasksToClientId = new HashMap<>();
+        final Map<TaskId, Integer> tasksToRemainingStandbys = computeTasksToRemainingStandbys(2, allActiveTasks);
+
+        for (final TaskId activeTaskId : allActiveTasks) {
+            assignStandbyTasksToClientsWithDifferentTags(
+                constrainedPrioritySet,
+                activeTaskId,
+                taskToClientId.get(activeTaskId),
+                rackAwareAssignmentTags,
+                clientStates,
+                tasksToRemainingStandbys,
+                tagKeyToValues,
+                tagEntryToClients,
+                pendingStandbyTasksToClientId
+            );
+        }
+
+        assertTrue(tasksToRemainingStandbys.isEmpty());
+        assertTrue(pendingStandbyTasksToClientId.isEmpty());
+    }
+
+    @Test
+    public void shouldUpdateClientToRemainingStandbysAndPendingStandbyTasksToClientIdWhenNotAllStandbyTasksWereAssigned() {
+        final Set<String> rackAwareAssignmentTags = mkSet(ZONE_TAG, CLUSTER_TAG);
+        final Map<UUID, ClientState> clientStates = mkMap(
+            mkEntry(UUID_1, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)),
+            mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)), TASK_0_1)),
+            mkEntry(UUID_3, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)), TASK_0_2))
+        );
+
+        final ConstrainedPrioritySet constrainedPrioritySet = createLeastLoadedPrioritySetConstrainedByAssignedTask(clientStates);
+        final Set<TaskId> allActiveTasks = findAllActiveTasks(clientStates);
+        final Map<TaskId, UUID> taskToClientId = mkMap(mkEntry(TASK_0_0, UUID_1),
+                                                       mkEntry(TASK_0_1, UUID_2),
+                                                       mkEntry(TASK_0_2, UUID_3));
+
+        final Map<String, Set<String>> tagKeyToValues = new HashMap<>();
+        final Map<TagEntry, Set<UUID>> tagEntryToClients = new HashMap<>();
+
+        fillClientsTagStatistics(clientStates, tagEntryToClients, tagKeyToValues);
+
+        final Map<TaskId, UUID> pendingStandbyTasksToClientId = new HashMap<>();
+        final Map<TaskId, Integer> tasksToRemainingStandbys = computeTasksToRemainingStandbys(3, allActiveTasks);
+
+        for (final TaskId activeTaskId : allActiveTasks) {
+            assignStandbyTasksToClientsWithDifferentTags(
+                constrainedPrioritySet,
+                activeTaskId,
+                taskToClientId.get(activeTaskId),
+                rackAwareAssignmentTags,
+                clientStates,
+                tasksToRemainingStandbys,
+                tagKeyToValues,
+                tagEntryToClients,
+                pendingStandbyTasksToClientId
+            );
+        }
+
+        allActiveTasks.forEach(
+            activeTaskId -> assertEquals(String.format("Active task with id [%s] didn't match expected number " +
+                                                       "of remaining standbys value.", activeTaskId),
+                                         1,
+                                         tasksToRemainingStandbys.get(activeTaskId).longValue())
+        );
+
+        allActiveTasks.forEach(
+            activeTaskId -> assertEquals(String.format("Active task with id [%s] didn't match expected " +
+                                                       "client ID value.", activeTaskId),
+                                         taskToClientId.get(activeTaskId),
+                                         pendingStandbyTasksToClientId.get(activeTaskId))
+        );
+    }
+
+    @Test
+    public void shouldPermitTaskMovementWhenClientTagsMatch() {
+        final ClientState source = createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)));
+        final ClientState destination = createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)));
+
+        assertTrue(standbyTaskAssignor.isAllowedTaskMovement(source, destination));
+    }
+
+    @Test
+    public void shouldDeclineTaskMovementWhenClientTagsDoNotMatch() {
+        final ClientState source = createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)));
+        final ClientState destination = createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)));
+
+        assertFalse(standbyTaskAssignor.isAllowedTaskMovement(source, destination));
+    }
+
+    @Test
+    public void shouldDistributeStandbyTasksWhenActiveTasksAreLocatedOnSameZone() {
+        final Map<UUID, ClientState> clientStates = mkMap(
+            mkEntry(UUID_1, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0, TASK_1_0)),
+            mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)))),
+            mkEntry(UUID_3, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)))),
+
+            mkEntry(UUID_4, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)), TASK_0_1, TASK_1_1)),
+            mkEntry(UUID_5, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_2)))),
+            mkEntry(UUID_6, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_2)))),
+
+            mkEntry(UUID_7, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)), TASK_0_2, TASK_1_2)),
+            mkEntry(UUID_8, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_3)))),
+            mkEntry(UUID_9, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_3))))
+        );
+
+        final Set<TaskId> allActiveTasks = findAllActiveTasks(clientStates);
+        final AssignmentConfigs assignmentConfigs = newAssignmentConfigs(2, ZONE_TAG, CLUSTER_TAG);
+
+        standbyTaskAssignor.assign(clientStates, allActiveTasks, allActiveTasks, assignmentConfigs);
+
+        assertTrue(clientStates.values().stream().allMatch(ClientState::reachedCapacity));
+
+        Stream.of(UUID_1, UUID_4, UUID_7).forEach(client -> assertStandbyTaskCountForClientEqualsTo(clientStates, client, 0));
+        Stream.of(UUID_2, UUID_3, UUID_5, UUID_6, UUID_8, UUID_9).forEach(client -> assertStandbyTaskCountForClientEqualsTo(clientStates, client, 2));
+        assertTotalNumberOfStandbyTasksEqualsTo(clientStates, 12);
+
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_0_0,
+                clientStates,
+                asList(
+                    mkSet(UUID_9, UUID_5), mkSet(UUID_6, UUID_8)
+                )
+            )
+        );
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_1_0,
+                clientStates,
+                asList(
+                    mkSet(UUID_9, UUID_5), mkSet(UUID_6, UUID_8)
+                )
+            )
+        );
+
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_0_1,
+                clientStates,
+                asList(
+                    mkSet(UUID_2, UUID_9), mkSet(UUID_3, UUID_8)
+                )
+            )
+        );
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_1_1,
+                clientStates,
+                asList(
+                    mkSet(UUID_2, UUID_9), mkSet(UUID_3, UUID_8)
+                )
+            )
+        );
+
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_0_2,
+                clientStates,
+                asList(
+                    mkSet(UUID_5, UUID_3), mkSet(UUID_2, UUID_6)
+                )
+            )
+        );
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_1_2,
+                clientStates,
+                asList(
+                    mkSet(UUID_5, UUID_3), mkSet(UUID_2, UUID_6)
+                )
+            )
+        );
+    }
+
+    @Test
+    public void shouldDistributeStandbyTasksWhenActiveTasksAreLocatedOnSameCluster() {
+        final Map<UUID, ClientState> clientStates = mkMap(
+            mkEntry(UUID_1, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0, TASK_1_0)),
+            mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_1, TASK_1_1)),
+            mkEntry(UUID_3, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_2, TASK_1_2)),
+
+            mkEntry(UUID_4, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_2)))),
+            mkEntry(UUID_5, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_2)))),
+            mkEntry(UUID_6, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_2)))),
+
+            mkEntry(UUID_7, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_3)))),
+            mkEntry(UUID_8, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_3)))),
+            mkEntry(UUID_9, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_3))))
+        );
+
+        final Set<TaskId> allActiveTasks = findAllActiveTasks(clientStates);
+        final AssignmentConfigs assignmentConfigs = newAssignmentConfigs(2, ZONE_TAG, CLUSTER_TAG);
+
+        standbyTaskAssignor.assign(clientStates, allActiveTasks, allActiveTasks, assignmentConfigs);
+
+        assertTrue(clientStates.values().stream().allMatch(ClientState::reachedCapacity));
+
+        Stream.of(UUID_1, UUID_2, UUID_3).forEach(client -> assertStandbyTaskCountForClientEqualsTo(clientStates, client, 0));
+        Stream.of(UUID_4, UUID_5, UUID_6, UUID_7, UUID_8, UUID_9).forEach(client -> assertStandbyTaskCountForClientEqualsTo(clientStates, client, 2));
+        assertTotalNumberOfStandbyTasksEqualsTo(clientStates, 12);
+
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_0_0,
+                clientStates,
+                asList(
+                    mkSet(UUID_9, UUID_5), mkSet(UUID_6, UUID_8)
+                )
+            )
+        );
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_1_0,
+                clientStates,
+                asList(
+                    mkSet(UUID_9, UUID_5), mkSet(UUID_6, UUID_8)
+                )
+            )
+        );
+
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_0_1,
+                clientStates,
+                asList(
+                    mkSet(UUID_4, UUID_9), mkSet(UUID_6, UUID_7)
+                )
+            )
+        );
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_1_1,
+                clientStates,
+                asList(
+                    mkSet(UUID_4, UUID_9), mkSet(UUID_6, UUID_7)
+                )
+            )
+        );
+
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_0_2,
+                clientStates,
+                asList(
+                    mkSet(UUID_5, UUID_7), mkSet(UUID_4, UUID_8)
+                )
+            )
+        );
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_1_2,
+                clientStates,
+                asList(
+                    mkSet(UUID_5, UUID_7), mkSet(UUID_4, UUID_8)
+                )
+            )
+        );
+    }
+
+    @Test
+    public void shouldDoPartialRackAwareness() {
+        final Map<UUID, ClientState> clientStates = mkMap(
+            mkEntry(UUID_1, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0)),
+            mkEntry(UUID_2, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_2)))),
+            mkEntry(UUID_3, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_3)))),
+
+            mkEntry(UUID_4, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_1)))),
+            mkEntry(UUID_5, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_2)))),
+            mkEntry(UUID_6, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_3)), TASK_1_0))
+        );
+
+        final Set<TaskId> allActiveTasks = findAllActiveTasks(clientStates);
+        final AssignmentConfigs assignmentConfigs = newAssignmentConfigs(2, CLUSTER_TAG, ZONE_TAG);
+
+        standbyTaskAssignor.assign(clientStates, allActiveTasks, allActiveTasks, assignmentConfigs);
+
+        // We need to distribute 2 standby tasks (in addition to the 1 active task).
+        // Since there are only two unique `cluster` tag values,
+        // we can only achieve the "ideal" distribution on the 1st standby task assignment.
+        // We can't consider the `cluster` tag for the 2nd standby task assignment because the 1st standby
+        // task is already assigned on a different cluster than the active one, which means
+        // we have already used all the available `cluster` tag values. Taking the `cluster` tag into consideration
+        // for the 2nd standby task assignment would effectively mean excluding all the clients.
+        // Instead, for the 2nd standby task, we can only achieve partial rack awareness based on the `zone` tag.
+        // Since the `cluster` tag is no longer considered, partial rack awareness
+        // is satisfied by placing the 2nd standby on a client whose `zone` tag differs from both the active task and the 1st standby task.
+        // The `zone` tags within either `cluster` are valid candidates, as the goal is to distribute clients across different `zone` tags.
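+        // Concretely, for TASK_0_0 (active on UUID_1: cluster-1/zone-1): the 1st standby can land on
+        // UUID_5 (cluster-2/zone-2), which differs on both tag dimensions. With both `cluster` values
+        // then in use, the 2nd standby only needs an unused `zone` value, so UUID_3 (cluster-1/zone-3)
+        // and UUID_6 (cluster-2/zone-3) are equally valid, as the assertions below verify.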
+
+        Stream.of(UUID_2, UUID_5).forEach(client -> assertStandbyTaskCountForClientEqualsTo(clientStates, client, 1));
+        // There's no strong guarantee where the 2nd standby task will end up.
+        Stream.of(UUID_1, UUID_3, UUID_4, UUID_6).forEach(client -> assertStandbyTaskCountForClientEqualsTo(clientStates, client, 0, 1));
+        assertTotalNumberOfStandbyTasksEqualsTo(clientStates, 4);
+
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_0_0,
+                clientStates,
+                asList(
+                    // Since it's located on different `cluster` and `zone` tag dimensions,
+                    // `UUID_5` is the "ideal" destination for the 1st standby task assignment.
+                    // For the 2nd standby, either `UUID_3` or `UUID_6` is a valid destination, as
+                    // we need to distribute the clients across different `zone`
+                    // tags without considering the `cluster` tag value.
+                    mkSet(UUID_5, UUID_3),
+                    mkSet(UUID_5, UUID_6)
+                )
+            )
+        );
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_1_0,
+                clientStates,
+                asList(
+                    // The same reasoning as above applies here too.
+                    // `UUID_2` is the ideal destination, differing on both the `cluster`
+                    // and `zone` tag dimensions. In contrast, `UUID_4` and `UUID_1`
+                    // satisfy only partial rack awareness, as they differ only on the `zone` tag dimension.
+                    mkSet(UUID_2, UUID_4),
+                    mkSet(UUID_2, UUID_1)
+                )
+            )
+        );
+    }
+
+    @Test
+    public void shouldDistributeClientsOnDifferentZoneTagsEvenWhenClientsReachedCapacity() {
+        final Map<UUID, ClientState> clientStates = mkMap(
+            mkEntry(UUID_1, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)),
+            mkEntry(UUID_2, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_1)),
+            mkEntry(UUID_3, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_2)),
+            mkEntry(UUID_4, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_1_0)),
+            mkEntry(UUID_5, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_2), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_1_1)),
+            mkEntry(UUID_6, createClientStateWithCapacity(1, mkMap(mkEntry(ZONE_TAG, ZONE_3), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_1_2))
+        );
+
+        final Set<TaskId> allActiveTasks = findAllActiveTasks(clientStates);
+        final AssignmentConfigs assignmentConfigs = newAssignmentConfigs(1, ZONE_TAG, CLUSTER_TAG);
+
+        standbyTaskAssignor.assign(clientStates, allActiveTasks, allActiveTasks, assignmentConfigs);
+
+        clientStates.keySet().forEach(client -> assertStandbyTaskCountForClientEqualsTo(clientStates, client, 1));
+        assertTotalNumberOfStandbyTasksEqualsTo(clientStates, 6);
+
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_0_0,
+                clientStates,
+                asList(
+                    mkSet(UUID_2), mkSet(UUID_5), mkSet(UUID_3), mkSet(UUID_6)
+                )
+            )
+        );
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_1_0,
+                clientStates,
+                asList(
+                    mkSet(UUID_2), mkSet(UUID_5), mkSet(UUID_3), mkSet(UUID_6)
+                )
+            )
+        );
+
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_0_1,
+                clientStates,
+                asList(
+                    mkSet(UUID_1), mkSet(UUID_4), mkSet(UUID_3), mkSet(UUID_6)
+                )
+            )
+        );
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_1_1,
+                clientStates,
+                asList(
+                    mkSet(UUID_1), mkSet(UUID_4), mkSet(UUID_3), mkSet(UUID_6)
+                )
+            )
+        );
+
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_0_2,
+                clientStates,
+                asList(
+                    mkSet(UUID_1), mkSet(UUID_4), mkSet(UUID_2), mkSet(UUID_5)
+                )
+            )
+        );
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_1_2,
+                clientStates,
+                asList(
+                    mkSet(UUID_1), mkSet(UUID_4), mkSet(UUID_2), mkSet(UUID_5)
+                )
+            )
+        );
+    }
+
+    @Test
+    public void shouldIgnoreTagsThatAreNotPresentInRackAwareness() {
+        final Map<UUID, ClientState> clientStates = mkMap(
+            mkEntry(UUID_1, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0)),
+            mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_2)))),
+
+            mkEntry(UUID_3, createClientStateWithCapacity(1, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_2), mkEntry(ZONE_TAG, ZONE_1))))
+        );
+
+        final Set<TaskId> allActiveTasks = findAllActiveTasks(clientStates);
+        final AssignmentConfigs assignmentConfigs = newAssignmentConfigs(1, CLUSTER_TAG);
+
+        standbyTaskAssignor.assign(clientStates, allActiveTasks, allActiveTasks, assignmentConfigs);
+
+        assertTotalNumberOfStandbyTasksEqualsTo(clientStates, 1);
+        assertEquals(1, clientStates.get(UUID_3).standbyTaskCount());
+    }
+
+    @Test
+    public void shouldHandleOverlappingTagValuesBetweenDifferentTagKeys() {
+        final Map<UUID, ClientState> clientStates = mkMap(
+            mkEntry(UUID_1, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, ZONE_1), mkEntry(CLUSTER_TAG, CLUSTER_1)), TASK_0_0)),
+            mkEntry(UUID_2, createClientStateWithCapacity(2, mkMap(mkEntry(ZONE_TAG, CLUSTER_1), mkEntry(CLUSTER_TAG, CLUSTER_3))))
+        );
+
+        final Set<TaskId> allActiveTasks = findAllActiveTasks(clientStates);
+        final AssignmentConfigs assignmentConfigs = newAssignmentConfigs(1, ZONE_TAG, CLUSTER_TAG);
+
+        standbyTaskAssignor.assign(clientStates, allActiveTasks, allActiveTasks, assignmentConfigs);
+
+        assertTotalNumberOfStandbyTasksEqualsTo(clientStates, 1);
+        assertTrue(
+            standbyClientsHonorRackAwareness(
+                TASK_0_0,
+                clientStates,
+                singletonList(
+                    mkSet(UUID_2)
+                )
+            )
+        );
+    }
+
+    @Test
+    public void shouldDistributeStandbyTasksOnLeastLoadedClientsWhenClientsAreNotOnDifferentTagDimensions() {
+        final Map<UUID, ClientState> clientStates = mkMap(
+            mkEntry(UUID_1, createClientStateWithCapacity(3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0)),
+            mkEntry(UUID_2, createClientStateWithCapacity(3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_1)),
+            mkEntry(UUID_3, createClientStateWithCapacity(3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_2)),
+            mkEntry(UUID_4, createClientStateWithCapacity(3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_1_0))
+        );
+
+        final Set<TaskId> allActiveTasks = findAllActiveTasks(clientStates);
+        final AssignmentConfigs assignmentConfigs = newAssignmentConfigs(1, CLUSTER_TAG, ZONE_TAG);
+
+        standbyTaskAssignor.assign(clientStates, allActiveTasks, allActiveTasks, assignmentConfigs);
+
+        assertTotalNumberOfStandbyTasksEqualsTo(clientStates, 4);
+        assertEquals(1, clientStates.get(UUID_1).standbyTaskCount());
+        assertEquals(1, clientStates.get(UUID_2).standbyTaskCount());
+        assertEquals(1, clientStates.get(UUID_3).standbyTaskCount());
+        assertEquals(1, clientStates.get(UUID_4).standbyTaskCount());
+    }
+
+    @Test
+    public void shouldNotAssignStandbyTasksIfThereAreNotEnoughClients() {
+        final Map<UUID, ClientState> clientStates = mkMap(
+            mkEntry(UUID_1, createClientStateWithCapacity(3, mkMap(mkEntry(CLUSTER_TAG, CLUSTER_1), mkEntry(ZONE_TAG, ZONE_1)), TASK_0_0))
+        );
+
+        final Set<TaskId> allActiveTasks = findAllActiveTasks(clientStates);
+        final AssignmentConfigs assignmentConfigs = newAssignmentConfigs(1, CLUSTER_TAG, ZONE_TAG);
+
+        standbyTaskAssignor.assign(clientStates, allActiveTasks, allActiveTasks, assignmentConfigs);
+
+        assertTotalNumberOfStandbyTasksEqualsTo(clientStates, 0);
+        assertEquals(0, clientStates.get(UUID_1).standbyTaskCount());
+    }
+
+    private static void assertTotalNumberOfStandbyTasksEqualsTo(final Map<UUID, ClientState> clientStates, final int expectedTotalNumberOfStandbyTasks) {
+        final int actualTotalNumberOfStandbyTasks = clientStates.values().stream().map(ClientState::standbyTaskCount).reduce(0, Integer::sum);
+        assertEquals(expectedTotalNumberOfStandbyTasks, actualTotalNumberOfStandbyTasks);
+    }
+
+    private static void assertStandbyTaskCountForClientEqualsTo(final Map<UUID, ClientState> clientStates,
+                                                                final UUID client,
+                                                                final int... expectedStandbyTaskCounts) {
+        final int standbyTaskCount = clientStates.get(client).standbyTaskCount();
+        final String msg = String.format("Client [%s] doesn't have expected number of standby tasks. " +
+                                         "Expected any of %s, actual [%s]",
+                                         client, Arrays.toString(expectedStandbyTaskCounts), standbyTaskCount);
+
+        assertTrue(msg, Arrays.stream(expectedStandbyTaskCounts).anyMatch(expectedStandbyTaskCount -> expectedStandbyTaskCount == standbyTaskCount));
+    }
+
+    private static boolean standbyClientsHonorRackAwareness(final TaskId activeTaskId,
+                                                            final Map<UUID, ClientState> clientStates,
+                                                            final List<Set<UUID>> validClientIdsBasedOnRackAwareAssignmentTags) {
+        final Set<UUID> standbyTaskClientIds = findAllStandbyTaskClients(clientStates, activeTaskId);
+
+        return validClientIdsBasedOnRackAwareAssignmentTags.stream()
+                                                           .filter(it -> it.equals(standbyTaskClientIds))
+                                                           .count() == 1;
+    }
+
+    private static Set<UUID> findAllStandbyTaskClients(final Map<UUID, ClientState> clientStates, final TaskId task) {
+        return clientStates.keySet()
+                           .stream()
+                           .filter(clientId -> clientStates.get(clientId).standbyTasks().contains(task))
+                           .collect(Collectors.toSet());
+    }
+
+    private static AssignmentConfigs newAssignmentConfigs(final int numStandbyReplicas,
+                                                          final String... rackAwareAssignmentTags) {
+        return new AssignmentConfigs(0L,
+                                     1,
+                                     numStandbyReplicas,
+                                     60_000L,
+                                     asList(rackAwareAssignmentTags));
+    }
+
+    private static ClientState createClientStateWithCapacity(final int capacity,
+                                                             final Map<String, String> clientTags,
+                                                             final TaskId... tasks) {
+        final ClientState clientState = new ClientState(capacity, clientTags);
+
+        Optional.ofNullable(tasks).ifPresent(t -> clientState.assignActiveTasks(asList(t)));
+
+        return clientState;
+    }
+
+    private static Set<TaskId> findAllActiveTasks(final Map<UUID, ClientState> clientStates) {
+        return clientStates.entrySet()
+                           .stream()
+                           .flatMap(clientStateEntry -> clientStateEntry.getValue().activeTasks().stream())
+                           .collect(Collectors.toSet());
+    }
+}
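The partial rack awareness test above hinges on a single rule: a tag key stops constraining standby placement once every one of its distinct values is already taken for the active task at hand. A minimal sketch of that rule follows; the helper name and signature are illustrative only, as the actual logic lives in ClientTagAwareStandbyTaskAssignor and may be shaped differently:

    import java.util.Map;
    import java.util.Set;

    static boolean tagKeyIsExhausted(final String tagKey,
                                     final Map<String, Set<String>> tagKeyToValues,
                                     final Set<String> valuesUsedByActiveAndStandbys) {
        // Once the active task and the standbys assigned so far cover every distinct
        // value of this key, the key can no longer exclude any client, so it is
        // dropped and the remaining tags provide partial rack awareness.
        return valuesUsedByActiveAndStandbys.containsAll(tagKeyToValues.get(tagKey));
    }

In the test above there are only two `cluster` values, so the `cluster` key is exhausted as soon as the 1st standby is placed, leaving only the `zone` key to constrain the 2nd.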
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java
index 491fff1..0473d9b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java
@@ -28,6 +28,7 @@ import java.util.UUID;
 
 import static java.util.Arrays.asList;
 import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_RACK_AWARE_ASSIGNMENT_TAGS;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
@@ -53,7 +54,7 @@ public class FallbackPriorTaskAssignorTest {
             clients,
             new HashSet<>(taskIds),
             new HashSet<>(taskIds),
-            new AssignorConfiguration.AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignorConfiguration.AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
         assertThat(probingRebalanceNeeded, is(true));
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
index a2d4716..bf78db6 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
@@ -20,8 +20,10 @@ import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
 import org.junit.Test;
 
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
@@ -29,10 +31,13 @@ import java.util.stream.Collectors;
 
 import static java.util.Collections.emptySet;
 import static java.util.Collections.singleton;
+import static java.util.Collections.singletonList;
 import static java.util.Collections.singletonMap;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_CLIENT_TAGS;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_RACK_AWARE_ASSIGNMENT_TAGS;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASKS;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
@@ -63,6 +68,7 @@ import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 public class HighAvailabilityTaskAssignorTest {
@@ -70,22 +76,24 @@ public class HighAvailabilityTaskAssignorTest {
         /*acceptableRecoveryLag*/ 100L,
         /*maxWarmupReplicas*/ 2,
         /*numStandbyReplicas*/ 0,
-        /*probingRebalanceIntervalMs*/ 60 * 1000L
+        /*probingRebalanceIntervalMs*/ 60 * 1000L,
+        /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS
     );
 
     private final AssignmentConfigs configWithStandbys = new AssignmentConfigs(
         /*acceptableRecoveryLag*/ 100L,
         /*maxWarmupReplicas*/ 2,
         /*numStandbyReplicas*/ 1,
-        /*probingRebalanceIntervalMs*/ 60 * 1000L
+        /*probingRebalanceIntervalMs*/ 60 * 1000L,
+        /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS
     );
 
     @Test
     public void shouldBeStickyForActiveAndStandbyTasksWhileWarmingUp() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
-        final ClientState clientState1 = new ClientState(allTaskIds, emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)), 1);
-        final ClientState clientState2 = new ClientState(emptySet(), allTaskIds, allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)), 1);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), 1);
+        final ClientState clientState1 = new ClientState(allTaskIds, emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)), EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState2 = new ClientState(emptySet(), allTaskIds, allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)), EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), EMPTY_CLIENT_TAGS, 1);
 
         final Map<UUID, ClientState> clientStates = mkMap(
             mkEntry(UUID_1, clientState1),
@@ -97,7 +105,7 @@ public class HighAvailabilityTaskAssignorTest {
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(11L, 2, 1, 60_000L)
+            new AssignmentConfigs(11L, 2, 1, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertThat(clientState1, hasAssignedTasks(allTaskIds.size()));
@@ -112,9 +120,9 @@ public class HighAvailabilityTaskAssignorTest {
     @Test
     public void shouldSkipWarmupsWhenAcceptableLagIsMax() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
-        final ClientState clientState1 = new ClientState(allTaskIds, emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)), 1);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), 1);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), 1);
+        final ClientState clientState1 = new ClientState(allTaskIds, emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)), EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), EMPTY_CLIENT_TAGS, 1);
 
         final Map<UUID, ClientState> clientStates = mkMap(
             mkEntry(UUID_1, clientState1),
@@ -126,7 +134,7 @@ public class HighAvailabilityTaskAssignorTest {
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(Long.MAX_VALUE, 1, 1, 60_000L)
+            new AssignmentConfigs(Long.MAX_VALUE, 1, 1, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertThat(clientState1, hasAssignedTasks(6));
@@ -139,15 +147,15 @@ public class HighAvailabilityTaskAssignorTest {
     public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfClientsIntegralDivisorOfNumberOfTasks() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 1);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 1);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
         assertThat(unstable, is(false));
         assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
@@ -160,15 +168,15 @@ public class HighAvailabilityTaskAssignorTest {
     public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfThreadsIntegralDivisorOfNumberOfTasks() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 3);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 3);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 3);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
         assertThat(unstable, is(false));
         assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
@@ -181,14 +189,14 @@ public class HighAvailabilityTaskAssignorTest {
     public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfClientsNotIntegralDivisorOfNumberOfTasks() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 1);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2);
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertThat(unstable, is(false));
@@ -202,15 +210,15 @@ public class HighAvailabilityTaskAssignorTest {
     public void shouldAssignActiveStatefulTasksEvenlyOverUnevenlyDistributedStreamThreads() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 2);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 3);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 2);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertThat(unstable, is(false));
@@ -231,15 +239,15 @@ public class HighAvailabilityTaskAssignorTest {
     public void shouldAssignActiveStatefulTasksEvenlyOverClientsWithMoreClientsThanTasks() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 1);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 1);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertThat(unstable, is(false));
@@ -253,15 +261,15 @@ public class HighAvailabilityTaskAssignorTest {
     public void shouldAssignActiveStatefulTasksEvenlyOverClientsAndStreamThreadsWithEqualStreamThreadsPerClientAsTasks() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 9);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 9);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 9);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 9);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 9);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 9);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertThat(unstable, is(false));
@@ -277,16 +285,16 @@ public class HighAvailabilityTaskAssignorTest {
         final Map<TaskId, Long> lagsForCaughtUpClient = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L));
         final Map<TaskId, Long> lagsForNotCaughtUpClient =
             allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE));
-        final ClientState caughtUpClientState = new ClientState(allTaskIds, emptySet(), lagsForCaughtUpClient, 5);
-        final ClientState notCaughtUpClientState1 = new ClientState(emptySet(), emptySet(), lagsForNotCaughtUpClient, 5);
-        final ClientState notCaughtUpClientState2 = new ClientState(emptySet(), emptySet(), lagsForNotCaughtUpClient, 5);
+        final ClientState caughtUpClientState = new ClientState(allTaskIds, emptySet(), lagsForCaughtUpClient, EMPTY_CLIENT_TAGS, 5);
+        final ClientState notCaughtUpClientState1 = new ClientState(emptySet(), emptySet(), lagsForNotCaughtUpClient, EMPTY_CLIENT_TAGS, 5);
+        final ClientState notCaughtUpClientState2 = new ClientState(emptySet(), emptySet(), lagsForNotCaughtUpClient, EMPTY_CLIENT_TAGS, 5);
         final Map<UUID, ClientState> clientStates =
             getClientStatesMap(caughtUpClientState, notCaughtUpClientState1, notCaughtUpClientState2);
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, allTaskIds.size() / 3 + 1, 0, 60_000L)
+            new AssignmentConfigs(0L, allTaskIds.size() / 3 + 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertThat(unstable, is(true));
@@ -307,16 +315,16 @@ public class HighAvailabilityTaskAssignorTest {
         final Map<TaskId, Long> lagsForWarmedUpClient2 =
             allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE));
         lagsForWarmedUpClient2.put(TASK_1_0, 0L);
-        final ClientState caughtUpClientState = new ClientState(allTaskIds, emptySet(), lagsForCaughtUpClient, 5);
-        final ClientState warmedUpClientState1 = new ClientState(emptySet(), warmedUpTaskIds1, lagsForWarmedUpClient1, 5);
-        final ClientState warmedUpClientState2 = new ClientState(emptySet(), warmedUpTaskIds2, lagsForWarmedUpClient2, 5);
+        final ClientState caughtUpClientState = new ClientState(allTaskIds, emptySet(), lagsForCaughtUpClient, EMPTY_CLIENT_TAGS, 5);
+        final ClientState warmedUpClientState1 = new ClientState(emptySet(), warmedUpTaskIds1, lagsForWarmedUpClient1, EMPTY_CLIENT_TAGS, 5);
+        final ClientState warmedUpClientState2 = new ClientState(emptySet(), warmedUpTaskIds2, lagsForWarmedUpClient2, EMPTY_CLIENT_TAGS, 5);
         final Map<UUID, ClientState> clientStates =
             getClientStatesMap(caughtUpClientState, warmedUpClientState1, warmedUpClientState2);
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, allTaskIds.size() / 3 + 1, 0, 60_000L)
+            new AssignmentConfigs(0L, allTaskIds.size() / 3 + 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertThat(unstable, is(false));
@@ -330,14 +338,14 @@ public class HighAvailabilityTaskAssignorTest {
     public void shouldAssignActiveStatefulTasksEvenlyOverStreamThreadsButBestEffortOverClients() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 6);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 3);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 6);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2);
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertThat(unstable, is(false));
@@ -351,7 +359,7 @@ public class HighAvailabilityTaskAssignorTest {
     @Test
     public void shouldComputeNewAssignmentIfThereAreUnassignedActiveTasks() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
-        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), 1);
+        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1);
         final Map<UUID, ClientState> clientStates = singletonMap(UUID_1, client1);
 
         final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(clientStates,
@@ -373,8 +381,8 @@ public class HighAvailabilityTaskAssignorTest {
     public void shouldComputeNewAssignmentIfThereAreUnassignedStandbyTasks() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0);
-        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), 1);
-        final ClientState client2 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 0L), 1);
+        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1);
         final Map<UUID, ClientState> clientStates = mkMap(mkEntry(UUID_1, client1), mkEntry(UUID_2, client2));
 
         final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(clientStates,
@@ -394,8 +402,8 @@ public class HighAvailabilityTaskAssignorTest {
     public void shouldComputeNewAssignmentIfActiveTasksWasNotOnCaughtUpClient() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0);
-        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 500L), 1);
-        final ClientState client2 = new ClientState(singleton(TASK_0_1), emptySet(), singletonMap(TASK_0_0, 0L), 1);
+        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 500L), EMPTY_CLIENT_TAGS, 1);
+        final ClientState client2 = new ClientState(singleton(TASK_0_1), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1);
         final Map<UUID, ClientState> clientStates = mkMap(
             mkEntry(UUID_1, client1),
             mkEntry(UUID_2, client2)
@@ -490,7 +498,8 @@ public class HighAvailabilityTaskAssignorTest {
                 /*acceptableRecoveryLag*/ 100L,
                 /*maxWarmupReplicas*/ 1,
                 /*numStandbyReplicas*/ 0,
-                /*probingRebalanceIntervalMs*/ 60 * 1000L
+                /*probingRebalanceIntervalMs*/ 60 * 1000L,
+                /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS
             )
         );
 
@@ -518,7 +527,8 @@ public class HighAvailabilityTaskAssignorTest {
                 /*acceptableRecoveryLag*/ 100L,
                 /*maxWarmupReplicas*/ 1,
                 /*numStandbyReplicas*/ 1,
-                /*probingRebalanceIntervalMs*/ 60 * 1000L
+                /*probingRebalanceIntervalMs*/ 60 * 1000L,
+                /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS
             )
         );
 
@@ -621,9 +631,9 @@ public class HighAvailabilityTaskAssignorTest {
             mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3, TASK_2_0); // 9 total
         final Map<TaskId, Long> allTaskLags = allTasks.stream().collect(Collectors.toMap(t -> t, t -> 0L));
         final Set<TaskId> statefulTasks = new HashSet<>(allTasks);
-        final ClientState client1 = new ClientState(emptySet(), emptySet(), allTaskLags, 100);
-        final ClientState client2 = new ClientState(emptySet(), emptySet(), allTaskLags, 50);
-        final ClientState client3 = new ClientState(emptySet(), emptySet(), allTaskLags, 1);
+        final ClientState client1 = new ClientState(emptySet(), emptySet(), allTaskLags, EMPTY_CLIENT_TAGS, 100);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), allTaskLags, EMPTY_CLIENT_TAGS, 50);
+        final ClientState client3 = new ClientState(emptySet(), emptySet(), allTaskLags, EMPTY_CLIENT_TAGS, 1);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
@@ -687,9 +697,9 @@ public class HighAvailabilityTaskAssignorTest {
         final Set<TaskId> statelessTasks = new HashSet<>(allTasks);
 
         final Map<TaskId, Long> taskLags = new HashMap<>();
-        final ClientState client1 = new ClientState(emptySet(), emptySet(), taskLags, 7);
-        final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, 7);
-        final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, 7);
+        final ClientState client1 = new ClientState(emptySet(), emptySet(), taskLags, EMPTY_CLIENT_TAGS, 7);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, EMPTY_CLIENT_TAGS, 7);
+        final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, EMPTY_CLIENT_TAGS, 7);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
@@ -697,7 +707,7 @@ public class HighAvailabilityTaskAssignorTest {
             clientStates,
             allTasks,
             statefulTasks,
-            new AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertValidAssignment(
@@ -718,9 +728,9 @@ public class HighAvailabilityTaskAssignorTest {
         final Set<TaskId> statelessTasks = new HashSet<>(allTasks);
 
         final Map<TaskId, Long> taskLags = new HashMap<>();
-        final ClientState client1 = new ClientState(emptySet(), emptySet(), taskLags, 2);
-        final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, 2);
-        final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, 2);
+        final ClientState client1 = new ClientState(emptySet(), emptySet(), taskLags, EMPTY_CLIENT_TAGS, 2);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, EMPTY_CLIENT_TAGS, 2);
+        final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, EMPTY_CLIENT_TAGS, 2);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
@@ -728,7 +738,7 @@ public class HighAvailabilityTaskAssignorTest {
             clientStates,
             allTasks,
             statefulTasks,
-            new AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertValidAssignment(
@@ -749,9 +759,9 @@ public class HighAvailabilityTaskAssignorTest {
         final Set<TaskId> statelessTasks = new HashSet<>(allTasks);
 
         final Map<TaskId, Long> taskLags = new HashMap<>();
-        final ClientState client1 = new ClientState(emptySet(), emptySet(), taskLags, 1);
-        final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, 2);
-        final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, 3);
+        final ClientState client1 = new ClientState(emptySet(), emptySet(), taskLags, EMPTY_CLIENT_TAGS, 1);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, EMPTY_CLIENT_TAGS, 2);
+        final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, EMPTY_CLIENT_TAGS, 3);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
@@ -759,7 +769,7 @@ public class HighAvailabilityTaskAssignorTest {
             clientStates,
             allTasks,
             statefulTasks,
-            new AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertValidAssignment(
@@ -780,9 +790,9 @@ public class HighAvailabilityTaskAssignorTest {
         final Set<TaskId> statelessTasks = new HashSet<>(allTasks);
 
         final Map<TaskId, Long> taskLags = new HashMap<>();
-        final ClientState client1 = new ClientState(statelessTasks, emptySet(), taskLags, 3);
-        final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, 3);
-        final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, 3);
+        final ClientState client1 = new ClientState(statelessTasks, emptySet(), taskLags, EMPTY_CLIENT_TAGS, 3);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, EMPTY_CLIENT_TAGS, 3);
+        final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, EMPTY_CLIENT_TAGS, 3);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
@@ -790,7 +800,7 @@ public class HighAvailabilityTaskAssignorTest {
             clientStates,
             allTasks,
             statefulTasks,
-            new AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
         assertValidAssignment(
@@ -804,6 +814,27 @@ public class HighAvailabilityTaskAssignorTest {
         assertThat(probingRebalanceNeeded, is(false));
     }
 
+    @Test
+    public void shouldReturnClientTagAwareStandbyTaskAssignorWhenRackAwareAssignmentTagsIsSet() {
+        final StandbyTaskAssignor standbyTaskAssignor = HighAvailabilityTaskAssignor.createStandbyTaskAssignor(newAssignmentConfigs(1, singletonList("az")));
+        assertTrue(standbyTaskAssignor instanceof ClientTagAwareStandbyTaskAssignor);
+    }
+
+    @Test
+    public void shouldReturnDefaultStandbyTaskAssignorWhenRackAwareAssignmentTagsIsEmpty() {
+        final StandbyTaskAssignor standbyTaskAssignor = HighAvailabilityTaskAssignor.createStandbyTaskAssignor(newAssignmentConfigs(1, Collections.emptyList()));
+        assertTrue(standbyTaskAssignor instanceof DefaultStandbyTaskAssignor);
+    }
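+
+    // Sketch of the factory behavior asserted by the two tests above (inferred from
+    // the assertions, not copied from HighAvailabilityTaskAssignor):
+    //   createStandbyTaskAssignor(configs) ->
+    //       configs.rackAwareAssignmentTags.isEmpty()
+    //           ? new DefaultStandbyTaskAssignor()
+    //           : new ClientTagAwareStandbyTaskAssignor();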
+
+    private static AssignorConfiguration.AssignmentConfigs newAssignmentConfigs(final int numStandbyReplicas,
+                                                                                final List<String> rackAwareAssignmentTags) {
+        return new AssignorConfiguration.AssignmentConfigs(0L,
+                                                           1,
+                                                           numStandbyReplicas,
+                                                           60_000L,
+                                                           rackAwareAssignmentTags);
+    }
+
     private static void assertHasNoActiveTasks(final ClientState... clients) {
         for (final ClientState client : clients) {
             assertThat(client.activeTasks(), is(empty()));
@@ -829,6 +860,6 @@ public class HighAvailabilityTaskAssignorTest {
                 taskLags.put(task, Long.MAX_VALUE);
             }
         }
-        return new ClientState(statefulActiveTasks, emptySet(), taskLags, 1);
+        return new ClientState(statefulActiveTasks, emptySet(), taskLags, EMPTY_CLIENT_TAGS, 1);
     }
 }
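
The two new tests above pin down the factory behavior this change introduces:
a non-empty rack-aware assignment tag list selects the tag-aware assignor, an
empty list falls back to the default. As a rough sketch of that selection
(assuming the factory simply branches on the configured tag list and that both
assignors have no-arg constructors; the authoritative version lives in
HighAvailabilityTaskAssignor):

    // Sketch only: choose the standby assignor based on the configured tags.
    static StandbyTaskAssignor createStandbyTaskAssignor(final AssignmentConfigs configs) {
        if (!configs.rackAwareAssignmentTags.isEmpty()) {
            // Spread standby replicas across the configured client tag dimensions.
            return new ClientTagAwareStandbyTaskAssignor();
        }
        // No tags configured: keep the previous least-loaded placement.
        return new DefaultStandbyTaskAssignor();
    }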
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignmentUtilsTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignmentUtilsTest.java
new file mode 100644
index 0000000..1abf1b9
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignmentUtilsTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClientStatesMap;
+import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.computeTasksToRemainingStandbys;
+import static org.apache.kafka.streams.processor.internals.assignment.StandbyTaskAssignmentUtils.pollClientAndMaybeAssignAndUpdateRemainingStandbyTasks;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+public class StandbyTaskAssignmentUtilsTest {
+    private static final Set<TaskId> ACTIVE_TASKS = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
+
+    private Map<UUID, ClientState> clients;
+    private ConstrainedPrioritySet clientsByTaskLoad;
+
+    @Before
+    public void setup() {
+        clients = getClientStatesMap(ACTIVE_TASKS.stream().map(StandbyTaskAssignmentUtilsTest::mkState).toArray(ClientState[]::new));
+        clientsByTaskLoad = new ConstrainedPrioritySet(
+            (client, task) -> !clients.get(client).hasAssignedTask(task),
+            client -> clients.get(client).assignedTaskLoad()
+        );
+        clientsByTaskLoad.offerAll(clients.keySet());
+    }
+
+    @Test
+    public void shouldReturnNumberOfStandbyTasksThatWereNotAssigned() {
+        final Map<TaskId, Integer> tasksToRemainingStandbys = computeTasksToRemainingStandbys(3, ACTIVE_TASKS);
+
+        assertTrue(tasksToRemainingStandbys.keySet()
+                                           .stream()
+                                           .map(taskId -> pollClientAndMaybeAssignAndUpdateRemainingStandbyTasks(
+                                               clients,
+                                               tasksToRemainingStandbys,
+                                               clientsByTaskLoad,
+                                               taskId
+                                           ))
+                                           .allMatch(numRemainingStandbys -> numRemainingStandbys == 1));
+
+        assertTrue(ACTIVE_TASKS.stream().allMatch(activeTask -> tasksToRemainingStandbys.get(activeTask) == 1));
+        assertTrue(areStandbyTasksPresentForAllActiveTasks(2));
+    }
+
+    @Test
+    public void shouldReturnZeroWhenAllStandbyTasksWereSuccessfullyAssigned() {
+        final Map<TaskId, Integer> tasksToRemainingStandbys = computeTasksToRemainingStandbys(1, ACTIVE_TASKS);
+
+        assertTrue(tasksToRemainingStandbys.keySet()
+                                           .stream()
+                                           .map(taskId -> pollClientAndMaybeAssignAndUpdateRemainingStandbyTasks(
+                                               clients,
+                                               tasksToRemainingStandbys,
+                                               clientsByTaskLoad,
+                                               taskId
+                                           ))
+                                           .allMatch(numRemainingStandbys -> numRemainingStandbys == 0));
+
+        assertTrue(ACTIVE_TASKS.stream().allMatch(activeTask -> tasksToRemainingStandbys.get(activeTask) == 0));
+        assertTrue(areStandbyTasksPresentForAllActiveTasks(1));
+    }
+
+    @Test
+    public void shouldComputeTasksToRemainingStandbys() {
+        assertThat(
+            computeTasksToRemainingStandbys(0, ACTIVE_TASKS),
+            equalTo(
+                ACTIVE_TASKS.stream().collect(Collectors.toMap(Function.identity(), it -> 0))
+            )
+        );
+        assertThat(
+            computeTasksToRemainingStandbys(5, ACTIVE_TASKS),
+            equalTo(
+                ACTIVE_TASKS.stream().collect(Collectors.toMap(Function.identity(), it -> 5))
+            )
+        );
+    }
+
+    private boolean areStandbyTasksPresentForAllActiveTasks(final int expectedNumberOfStandbyTasks) {
+        return ACTIVE_TASKS.stream().allMatch(taskId -> clients.values()
+                                                               .stream()
+                                                               .filter(client -> client.hasStandbyTask(taskId))
+                                                               .count() == expectedNumberOfStandbyTasks);
+    }
+
+    private static ClientState mkState(final TaskId... activeTasks) {
+        return mkState(1, activeTasks);
+    }
+
+    private static ClientState mkState(final int capacity, final TaskId... activeTasks) {
+        final ClientState clientState = new ClientState(capacity);
+        for (final TaskId activeTask : activeTasks) {
+            clientState.assignActive(activeTask);
+        }
+        return clientState;
+    }
+}
\ No newline at end of file
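
A condensed usage sketch of the two helpers the new test above exercises,
mirroring the test's own setup (clients and clientsByTaskLoad are built exactly
as in setup(); numStandbyReplicas and activeTasks stand in for the literals the
test uses):

    // Seed each active task with the number of standbys still to be placed.
    final Map<TaskId, Integer> tasksToRemainingStandbys =
        computeTasksToRemainingStandbys(numStandbyReplicas, activeTasks);

    for (final TaskId task : activeTasks) {
        // Polls least-loaded clients and assigns standbys for the task until the
        // requested count is reached or no eligible client remains; returns how
        // many standbys could not be placed (0 on full success, as asserted above).
        final int numRemainingStandbys = pollClientAndMaybeAssignAndUpdateRemainingStandbyTasks(
            clients, tasksToRemainingStandbys, clientsByTaskLoad, task);
    }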
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
index 7536ad2..8c1347f 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
@@ -33,6 +33,7 @@ import java.util.UUID;
 import static java.util.Arrays.asList;
 import static java.util.Collections.singleton;
 import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_RACK_AWARE_ASSIGNMENT_TAGS;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2;
@@ -676,7 +677,7 @@ public class StickyTaskAssignorTest {
             clients,
             new HashSet<>(taskIds),
             new HashSet<>(taskIds),
-            new AssignorConfiguration.AssignmentConfigs(0L, 1, 0, 60_000L)
+            new AssignorConfiguration.AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
         assertThat(probingRebalanceNeeded, is(false));
 
@@ -695,7 +696,7 @@ public class StickyTaskAssignorTest {
             clients,
             new HashSet<>(taskIds),
             new HashSet<>(taskIds),
-            new AssignorConfiguration.AssignmentConfigs(0L, 1, numStandbys, 60_000L)
+            new AssignorConfiguration.AssignmentConfigs(0L, 1, numStandbys, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
     }
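
The StickyTaskAssignorTest updates above, like the TaskAssignorConvergenceTest
updates below, are mechanical: every AssignmentConfigs construction gains the
trailing rackAwareAssignmentTags argument. For reference, a hand-built example
of the non-empty case (the tag values are hypothetical; asList is the static
import this test file already uses):

    final AssignorConfiguration.AssignmentConfigs configs =
        new AssignorConfiguration.AssignmentConfigs(
            0L,                          // acceptableRecoveryLag
            1,                           // maxWarmupReplicas
            2,                           // numStandbyReplicas
            60_000L,                     // probingRebalanceIntervalMs
            asList("zone", "cluster"));  // rackAwareAssignmentTags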
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
index 68c9dfe..c1be5f3 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
@@ -29,6 +29,7 @@ import java.util.TreeSet;
 import java.util.UUID;
 import java.util.function.Supplier;
 
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_RACK_AWARE_ASSIGNMENT_TAGS;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.appendClientStates;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedActiveAssignment;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedStatefulAssignment;
@@ -231,7 +232,8 @@ public class TaskAssignorConvergenceTest {
         final AssignmentConfigs configs = new AssignmentConfigs(100L,
                                                                 2,
                                                                 0,
-                                                                60_000L);
+                                                                60_000L,
+                                                                EMPTY_RACK_AWARE_ASSIGNMENT_TAGS);
 
         final Harness harness = Harness.initializeCluster(1, 1, 1, () -> 1);
 
@@ -250,7 +252,8 @@ public class TaskAssignorConvergenceTest {
         final AssignmentConfigs configs = new AssignmentConfigs(100L,
                                                                 maxWarmupReplicas,
                                                                 numStandbyReplicas,
-                                                                60_000L);
+                                                                60_000L,
+                                                                EMPTY_RACK_AWARE_ASSIGNMENT_TAGS);
 
         final Harness harness = Harness.initializeCluster(numStatelessTasks, numStatefulTasks, 1, () -> 5);
         testForConvergence(harness, configs, 1);
@@ -272,7 +275,8 @@ public class TaskAssignorConvergenceTest {
         final AssignmentConfigs configs = new AssignmentConfigs(100L,
                                                                 maxWarmupReplicas,
                                                                 numStandbyReplicas,
-                                                                60_000L);
+                                                                60_000L,
+                                                                EMPTY_RACK_AWARE_ASSIGNMENT_TAGS);
 
         final Harness harness = Harness.initializeCluster(numStatelessTasks, numStatefulTasks, 7, () -> 5);
         testForConvergence(harness, configs, 1);
@@ -313,7 +317,8 @@ public class TaskAssignorConvergenceTest {
             final AssignmentConfigs configs = new AssignmentConfigs(100L,
                                                                     maxWarmupReplicas,
                                                                     numStandbyReplicas,
-                                                                    60_000L);
+                                                                    60_000L,
+                                                                    EMPTY_RACK_AWARE_ASSIGNMENT_TAGS);
 
             harness = Harness.initializeCluster(
                 numStatelessTasks,