You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by ca...@apache.org on 2022/09/14 07:03:57 UTC

[kafka] branch trunk updated: KAFKA-10199: Suspend tasks in the state updater on revocation (#12600)

This is an automated email from the ASF dual-hosted git repository.

cadonna pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 1ab4596ee67 KAFKA-10199: Suspend tasks in the state updater on revocation (#12600)
1ab4596ee67 is described below

commit 1ab4596ee67f99cfee3dab2451223fd0c8ee7d30
Author: Bruno Cadonna <ca...@apache.org>
AuthorDate: Wed Sep 14 09:03:43 2022 +0200

    KAFKA-10199: Suspend tasks in the state updater on revocation (#12600)
    
    In the first attempt to handle revoked tasks in the state updater,
    we removed the revoked tasks from the state updater and added them to
    the set of pending tasks to close cleanly. This is not correct since
    a revoked task that is immediately reassigned to the same stream thread
    would neither be re-added to the state updater nor be created again.
    Also, a revoked active task might be added to more than one bookkeeping
    set in the tasks registry, since it might still be returned from
    stateUpdater.getTasks() after it was removed from the state updater.
    The reason is that the removal from the state updater is done
    asynchronously.
    
    This PR solves this issue by introducing a new bookkeeping set
    in the tasks registry to track revoked active tasks (more precisely,
    active tasks that are pending suspension).
    
    Additionally, this PR closes some gaps in test coverage around the
    modified code.
    
    Reviewers: Guozhang Wang <wa...@gmail.com>, Hao Li <11...@users.noreply.github.com>
---
 .../processor/internals/PendingUpdateAction.java   |  80 +++
 .../streams/processor/internals/TaskManager.java   | 146 +++--
 .../kafka/streams/processor/internals/Tasks.java   |  55 +-
 .../streams/processor/internals/TasksRegistry.java |   4 +
 .../processor/internals/TaskManagerTest.java       | 671 ++++++++++++++++++---
 .../streams/processor/internals/TasksTest.java     | 146 ++++-
 6 files changed, 931 insertions(+), 171 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PendingUpdateAction.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PendingUpdateAction.java
new file mode 100644
index 00000000000..c51921d8aea
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PendingUpdateAction.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.common.TopicPartition;
+
+import java.util.Objects;
+import java.util.Set;
+
+public class PendingUpdateAction {
+    // Operations that TaskManager can record for a task and apply later, once it gets the task back from the state updater.
+    enum Action {
+        UPDATE_INPUT_PARTITIONS, // replace the task's input partitions
+        RECYCLE,                 // recycle the task because its active/standby status changed
+        SUSPEND,                 // suspend the active task and add it to the tasks registry
+        CLOSE_DIRTY,             // close the task dirty
+        CLOSE_CLEAN              // close the task cleanly
+    }
+    // Non-null only for UPDATE_INPUT_PARTITIONS and RECYCLE (the factories enforce this); null for all other actions.
+    private final Set<TopicPartition> inputPartitions;
+    private final Action action;
+
+    private PendingUpdateAction(final Action action, final Set<TopicPartition> inputPartitions) {
+        this.action = action;
+        this.inputPartitions = inputPartitions;
+    }
+
+    private PendingUpdateAction(final Action action) {
+        this(action, null);
+    }
+    /** Creates an UPDATE_INPUT_PARTITIONS action; throws NullPointerException if {@code inputPartitions} is null. */
+    public static PendingUpdateAction createUpdateInputPartition(final Set<TopicPartition> inputPartitions) {
+        Objects.requireNonNull(inputPartitions, "Set of input partitions to update is null!");
+        return new PendingUpdateAction(Action.UPDATE_INPUT_PARTITIONS, inputPartitions);
+    }
+    /** Creates a RECYCLE action carrying the partitions of the recycled task; throws NullPointerException if null. */
+    public static PendingUpdateAction createRecycleTask(final Set<TopicPartition> inputPartitions) {
+        Objects.requireNonNull(inputPartitions, "Set of input partitions to update is null!");
+        return new PendingUpdateAction(Action.RECYCLE, inputPartitions);
+    }
+    /** Creates a SUSPEND action, which carries no input partitions. */
+    public static PendingUpdateAction createSuspend() {
+        return new PendingUpdateAction(Action.SUSPEND);
+    }
+    /** Creates a CLOSE_DIRTY action, which carries no input partitions. */
+    public static PendingUpdateAction createCloseDirty() {
+        return new PendingUpdateAction(Action.CLOSE_DIRTY);
+    }
+    /** Creates a CLOSE_CLEAN action, which carries no input partitions. */
+    public static PendingUpdateAction createCloseClean() {
+        return new PendingUpdateAction(Action.CLOSE_CLEAN);
+    }
+    /** Returns the input partitions; throws IllegalStateException unless the action is UPDATE_INPUT_PARTITIONS or RECYCLE. */
+    public Set<TopicPartition> getInputPartitions() {
+        if (action != Action.UPDATE_INPUT_PARTITIONS && action != Action.RECYCLE) {
+            throw new IllegalStateException("Action type " + action + " does not have a set of input partitions!");
+        }
+        return inputPartitions;
+    }
+    /** Returns the kind of this pending action. */
+    public Action getAction() {
+        return action;
+    }
+
+
+}
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index d9d05391e9c..31eba40ae61 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -317,9 +317,9 @@ public class TaskManager {
         // 2. for tasks that have changed active/standby status, just recycle and skip re-creating them
         // 3. otherwise, close them since they are no longer owned
         if (stateUpdater == null) {
-            classifyTasksWithoutStateUpdater(activeTasksToCreate, standbyTasksToCreate, tasksToRecycle, tasksToCloseClean);
+            handleTasksWithoutStateUpdater(activeTasksToCreate, standbyTasksToCreate, tasksToRecycle, tasksToCloseClean);
         } else {
-            classifyTasksWithStateUpdater(activeTasksToCreate, standbyTasksToCreate, tasksToRecycle, tasksToCloseClean);
+            handleTasksWithStateUpdater(activeTasksToCreate, standbyTasksToCreate, tasksToRecycle, tasksToCloseClean);
         }
 
         final Map<TaskId, RuntimeException> taskCloseExceptions = closeAndRecycleTasks(tasksToRecycle, tasksToCloseClean);
@@ -389,10 +389,10 @@ public class TaskManager {
         }
     }
 
-    private void classifyTasksWithoutStateUpdater(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
-                                                  final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate,
-                                                  final Map<Task, Set<TopicPartition>> tasksToRecycle,
-                                                  final Set<Task> tasksToCloseClean) {
+    private void handleTasksWithoutStateUpdater(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
+                                                final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate,
+                                                final Map<Task, Set<TopicPartition>> tasksToRecycle,
+                                                final Set<Task> tasksToCloseClean) {
         for (final Task task : tasks.allTasks()) {
             final TaskId taskId = task.id();
             if (activeTasksToCreate.containsKey(taskId)) {
@@ -421,29 +421,28 @@ public class TaskManager {
         }
     }
 
-    private void classifyRunningTasks(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
-                                      final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate,
-                                      final Map<Task, Set<TopicPartition>> tasksToRecycle,
-                                      final Set<Task> tasksToCloseClean) {
+    private void handleTasksWithStateUpdater(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
+                                             final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate,
+                                             final Map<Task, Set<TopicPartition>> tasksToRecycle,
+                                             final Set<Task> tasksToCloseClean) {
+        handleRunningAndSuspendedTasks(activeTasksToCreate, standbyTasksToCreate, tasksToRecycle, tasksToCloseClean);
+        handleTasksInStateUpdater(activeTasksToCreate, standbyTasksToCreate);
+    }
+
+    private void handleRunningAndSuspendedTasks(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
+                                                final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate,
+                                                final Map<Task, Set<TopicPartition>> tasksToRecycle,
+                                                final Set<Task> tasksToCloseClean) {
         for (final Task task : tasks.allTasks()) {
+            if (!task.isActive()) {
+                throw new IllegalStateException("Standby tasks should only be managed by the state updater");
+            }
             final TaskId taskId = task.id();
             if (activeTasksToCreate.containsKey(taskId)) {
-                if (task.isActive()) {
-                    final Set<TopicPartition> topicPartitions = activeTasksToCreate.get(taskId);
-                    if (tasks.updateActiveTaskInputPartitions(task, topicPartitions)) {
-                        task.updateInputPartitions(topicPartitions, topologyMetadata.nodeToSourceTopics(task.id()));
-                    }
-                    task.resume();
-                } else {
-                    throw new IllegalStateException("Standby tasks should only be managed by the state updater");
-                }
+                handleReAssignedActiveTask(task, activeTasksToCreate.get(taskId));
                 activeTasksToCreate.remove(taskId);
             } else if (standbyTasksToCreate.containsKey(taskId)) {
-                if (!task.isActive()) {
-                    throw new IllegalStateException("Standby tasks should only be managed by the state updater");
-                } else {
-                    tasksToRecycle.put(task, standbyTasksToCreate.get(taskId));
-                }
+                tasksToRecycle.put(task, standbyTasksToCreate.get(taskId));
                 standbyTasksToCreate.remove(taskId);
             } else {
                 tasksToCloseClean.add(task);
@@ -451,43 +450,67 @@ public class TaskManager {
         }
     }
 
-    private void classifyTasksWithStateUpdater(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
-                                               final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate,
-                                               final Map<Task, Set<TopicPartition>> tasksToRecycle,
-                                               final Set<Task> tasksToCloseClean) {
-        classifyRunningTasks(activeTasksToCreate, standbyTasksToCreate, tasksToRecycle, tasksToCloseClean);
+    private void handleReAssignedActiveTask(final Task task,
+                                            final Set<TopicPartition> inputPartitions) {
+        if (tasks.updateActiveTaskInputPartitions(task, inputPartitions)) {
+            task.updateInputPartitions(inputPartitions, topologyMetadata.nodeToSourceTopics(task.id()));
+        }
+        if (task.state() == State.SUSPENDED) {
+            task.resume();
+            moveTaskFromTasksRegistryToStateUpdater(task);
+        }
+    }
+
+    private void moveTaskFromTasksRegistryToStateUpdater(final Task task) {
+        tasks.removeTask(task);
+        stateUpdater.add(task);
+    }
+
+    private void handleTasksInStateUpdater(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
+                                           final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate) {
         for (final Task task : stateUpdater.getTasks()) {
             final TaskId taskId = task.id();
-            final Set<TopicPartition> topicPartitions = activeTasksToCreate.get(taskId);
             if (activeTasksToCreate.containsKey(taskId)) {
+                final Set<TopicPartition> inputPartitions = activeTasksToCreate.get(taskId);
                 if (task.isActive()) {
-                    if (!task.inputPartitions().equals(topicPartitions)) {
-                        stateUpdater.remove(taskId);
-                        tasks.addPendingTaskToUpdateInputPartitions(taskId, topicPartitions);
-                    }
+                    updateInputPartitionsOrRemoveTaskFromTasksToSuspend(task, inputPartitions);
                 } else {
-                    stateUpdater.remove(taskId);
-                    tasks.addPendingTaskToRecycle(taskId, topicPartitions);
+                    removeTaskToRecycleFromStateUpdater(taskId, inputPartitions);
                 }
                 activeTasksToCreate.remove(taskId);
             } else if (standbyTasksToCreate.containsKey(taskId)) {
-                if (!task.isActive()) {
-                    if (!task.inputPartitions().equals(topicPartitions)) {
-                        stateUpdater.remove(taskId);
-                        tasks.addPendingTaskToUpdateInputPartitions(taskId, topicPartitions);
-                    }
-                } else {
-                    stateUpdater.remove(taskId);
-                    tasks.addPendingTaskToRecycle(taskId, topicPartitions);
+                if (task.isActive()) {
+                    removeTaskToRecycleFromStateUpdater(taskId, standbyTasksToCreate.get(taskId));
                 }
                 standbyTasksToCreate.remove(taskId);
             } else {
-                stateUpdater.remove(taskId);
-                tasks.addPendingTaskToCloseClean(taskId);
+                removeUnusedTaskFromStateUpdater(taskId);
             }
         }
     }
 
+    private void updateInputPartitionsOrRemoveTaskFromTasksToSuspend(final Task task,
+                                                                     final Set<TopicPartition> inputPartitions) {
+        final TaskId taskId = task.id();
+        if (!task.inputPartitions().equals(inputPartitions)) {
+            stateUpdater.remove(taskId);
+            tasks.addPendingTaskToUpdateInputPartitions(taskId, inputPartitions);
+        } else {
+            tasks.removePendingActiveTaskToSuspend(taskId);
+        }
+    }
+
+    private void removeTaskToRecycleFromStateUpdater(final TaskId taskId,
+                                                     final Set<TopicPartition> inputPartitions) {
+        stateUpdater.remove(taskId);
+        tasks.addPendingTaskToRecycle(taskId, inputPartitions);
+    }
+
+    private void removeUnusedTaskFromStateUpdater(final TaskId taskId) {
+        stateUpdater.remove(taskId);
+        tasks.addPendingTaskToCloseClean(taskId);
+    }
+
     private Map<TaskId, Set<TopicPartition>> pendingTasksToCreate(final Map<TaskId, Set<TopicPartition>> tasksToCreate) {
         final Map<TaskId, Set<TopicPartition>> pendingTasks = new HashMap<>();
         final Iterator<Map.Entry<TaskId, Set<TopicPartition>>> iter = tasksToCreate.entrySet().iterator();
@@ -685,10 +708,10 @@ public class TaskManager {
         return !stateUpdater.restoresActiveTasks();
     }
 
-    private void recycleTask(final Task task,
-                             final Set<TopicPartition> inputPartitions,
-                             final Set<Task> tasksToCloseDirty,
-                             final Map<TaskId, RuntimeException> taskExceptions) {
+    private void recycleTaskFromStateUpdater(final Task task,
+                                             final Set<TopicPartition> inputPartitions,
+                                             final Set<Task> tasksToCloseDirty,
+                                             final Map<TaskId, RuntimeException> taskExceptions) {
         Task newTask = null;
         try {
             task.suspend();
@@ -786,7 +809,7 @@ public class TaskManager {
         for (final Task task : stateUpdater.drainRemovedTasks()) {
             Set<TopicPartition> inputPartitions;
             if ((inputPartitions = tasks.removePendingTaskToRecycle(task.id())) != null) {
-                recycleTask(task, inputPartitions, tasksToCloseDirty, taskExceptions);
+                recycleTaskFromStateUpdater(task, inputPartitions, tasksToCloseDirty, taskExceptions);
             } else if (tasks.removePendingTaskToCloseClean(task.id())) {
                 closeTaskClean(task, tasksToCloseDirty, taskExceptions);
             } else if (tasks.removePendingTaskToCloseDirty(task.id())) {
@@ -794,9 +817,12 @@ public class TaskManager {
             } else if ((inputPartitions = tasks.removePendingTaskToUpdateInputPartitions(task.id())) != null) {
                 task.updateInputPartitions(inputPartitions, topologyMetadata.nodeToSourceTopics(task.id()));
                 stateUpdater.add(task);
+            } else if (tasks.removePendingActiveTaskToSuspend(task.id())) {
+                task.suspend();
+                tasks.addTask(task);
             } else {
                 throw new IllegalStateException("Got a removed task " + task.id() + " from the state updater " +
-                    " that is not for recycle, closing, or updating input partitions; this should not happen");
+                    "that is not for recycle, closing, or updating input partitions; this should not happen");
             }
         }
 
@@ -808,7 +834,7 @@ public class TaskManager {
         maybeThrowTaskExceptions(taskExceptions);
     }
 
-    private boolean handleRestoredTasksFromStateUpdater(final long now,
+    private void handleRestoredTasksFromStateUpdater(final long now,
                                                         final java.util.function.Consumer<Set<TopicPartition>> offsetResetter) {
         final Map<TaskId, RuntimeException> taskExceptions = new LinkedHashMap<>();
         final Set<Task> tasksToCloseDirty = new TreeSet<>(Comparator.comparing(Task::id));
@@ -817,7 +843,7 @@ public class TaskManager {
         for (final Task task : stateUpdater.drainRestoredActiveTasks(timeout)) {
             Set<TopicPartition> inputPartitions;
             if ((inputPartitions = tasks.removePendingTaskToRecycle(task.id())) != null) {
-                recycleTask(task, inputPartitions, tasksToCloseDirty, taskExceptions);
+                recycleTaskFromStateUpdater(task, inputPartitions, tasksToCloseDirty, taskExceptions);
             } else if (tasks.removePendingTaskToCloseClean(task.id())) {
                 closeTaskClean(task, tasksToCloseDirty, taskExceptions);
             } else if (tasks.removePendingTaskToCloseDirty(task.id())) {
@@ -825,6 +851,9 @@ public class TaskManager {
             } else if ((inputPartitions = tasks.removePendingTaskToUpdateInputPartitions(task.id())) != null) {
                 task.updateInputPartitions(inputPartitions, topologyMetadata.nodeToSourceTopics(task.id()));
                 transitRestoredTaskToRunning(task, now, offsetResetter);
+            } else if (tasks.removePendingActiveTaskToSuspend(task.id())) {
+                task.suspend();
+                tasks.addTask(task);
             } else {
                 transitRestoredTaskToRunning(task, now, offsetResetter);
             }
@@ -835,8 +864,6 @@ public class TaskManager {
         }
 
         maybeThrowTaskExceptions(taskExceptions);
-
-        return !stateUpdater.restoresActiveTasks();
     }
 
     /**
@@ -868,7 +895,7 @@ public class TaskManager {
             }
         }
 
-        removeRevokedTasksFromStateUpdater(remainingRevokedPartitions);
+        addRevokedTasksInStateUpdaterToPendingTasksToSuspend(remainingRevokedPartitions);
 
         if (!remainingRevokedPartitions.isEmpty()) {
             log.debug("The following revoked partitions {} are missing from the current task partitions. It could "
@@ -955,13 +982,12 @@ public class TaskManager {
         }
     }
 
-    private void removeRevokedTasksFromStateUpdater(final Set<TopicPartition> remainingRevokedPartitions) {
+    private void addRevokedTasksInStateUpdaterToPendingTasksToSuspend(final Set<TopicPartition> remainingRevokedPartitions) {
         if (stateUpdater != null) {
             for (final Task restoringTask : stateUpdater.getTasks()) {
                 if (restoringTask.isActive()) {
                     if (remainingRevokedPartitions.containsAll(restoringTask.inputPartitions())) {
-                        tasks.addPendingTaskToCloseClean(restoringTask.id());
-                        stateUpdater.remove(restoringTask.id());
+                        tasks.addPendingActiveTaskToSuspend(restoringTask.id());
                         remainingRevokedPartitions.removeAll(restoringTask.inputPartitions());
                     }
                 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
index 7630eaff31c..b29ec9ff051 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
@@ -19,6 +19,7 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.PendingUpdateAction.Action;
 import org.slf4j.Logger;
 
 import java.util.Collection;
@@ -52,11 +53,8 @@ class Tasks implements TasksRegistry {
     // we receive a new assignment and they are revoked from the thread.
     private final Map<TaskId, Set<TopicPartition>> pendingActiveTasksToCreate = new HashMap<>();
     private final Map<TaskId, Set<TopicPartition>> pendingStandbyTasksToCreate = new HashMap<>();
-    private final Map<TaskId, Set<TopicPartition>> pendingTasksToRecycle = new HashMap<>();
-    private final Map<TaskId, Set<TopicPartition>> pendingTasksToUpdateInputPartitions = new HashMap<>();
     private final Set<Task> pendingTasksToInit = new HashSet<>();
-    private final Set<TaskId> pendingTasksToCloseClean = new HashSet<>();
-    private final Set<TaskId> pendingTasksToCloseDirty = new HashSet<>();
+    private final Map<TaskId, PendingUpdateAction> pendingUpdateActions = new HashMap<>();
 
     // TODO: convert to Stream/StandbyTask when we remove TaskManager#StateMachineTask with mocks
     private final Map<TopicPartition, Task> activeTasksPerPartition = new HashMap<>();
@@ -103,42 +101,75 @@ class Tasks implements TasksRegistry {
 
     @Override
     public Set<TopicPartition> removePendingTaskToRecycle(final TaskId taskId) {
-        return pendingTasksToRecycle.remove(taskId);
+        if (containsTaskIdWithAction(taskId, Action.RECYCLE)) {
+            return pendingUpdateActions.remove(taskId).getInputPartitions();
+        }
+        return null;
     }
 
     @Override
     public void addPendingTaskToRecycle(final TaskId taskId, final Set<TopicPartition> inputPartitions) {
-        pendingTasksToRecycle.put(taskId, inputPartitions);
+        pendingUpdateActions.put(taskId, PendingUpdateAction.createRecycleTask(inputPartitions));
     }
 
     @Override
     public Set<TopicPartition> removePendingTaskToUpdateInputPartitions(final TaskId taskId) {
-        return pendingTasksToUpdateInputPartitions.remove(taskId);
+        if (containsTaskIdWithAction(taskId, Action.UPDATE_INPUT_PARTITIONS)) {
+            return pendingUpdateActions.remove(taskId).getInputPartitions();
+        }
+        return null;
     }
 
     @Override
     public void addPendingTaskToUpdateInputPartitions(final TaskId taskId, final Set<TopicPartition> inputPartitions) {
-        pendingTasksToUpdateInputPartitions.put(taskId, inputPartitions);
+        pendingUpdateActions.put(taskId, PendingUpdateAction.createUpdateInputPartition(inputPartitions));
     }
 
     @Override
     public boolean removePendingTaskToCloseDirty(final TaskId taskId) {
-        return pendingTasksToCloseDirty.remove(taskId);
+        if (containsTaskIdWithAction(taskId, Action.CLOSE_DIRTY)) {
+            pendingUpdateActions.remove(taskId);
+            return true;
+        }
+        return false;
     }
 
     @Override
     public void addPendingTaskToCloseDirty(final TaskId taskId) {
-        pendingTasksToCloseDirty.add(taskId);
+        pendingUpdateActions.put(taskId, PendingUpdateAction.createCloseDirty());
     }
 
     @Override
     public boolean removePendingTaskToCloseClean(final TaskId taskId) {
-        return pendingTasksToCloseClean.remove(taskId);
+        if (containsTaskIdWithAction(taskId, Action.CLOSE_CLEAN)) {
+            pendingUpdateActions.remove(taskId);
+            return true;
+        }
+        return false;
     }
 
     @Override
     public void addPendingTaskToCloseClean(final TaskId taskId) {
-        pendingTasksToCloseClean.add(taskId);
+        pendingUpdateActions.put(taskId, PendingUpdateAction.createCloseClean());
+    }
+
+    @Override
+    public boolean removePendingActiveTaskToSuspend(final TaskId taskId) {
+        if (containsTaskIdWithAction(taskId, Action.SUSPEND)) {
+            pendingUpdateActions.remove(taskId);
+            return true;
+        }
+        return false;
+    }
+
+    @Override
+    public void addPendingActiveTaskToSuspend(final TaskId taskId) {
+        pendingUpdateActions.put(taskId, PendingUpdateAction.createSuspend());
+    }
+
+    private boolean containsTaskIdWithAction(final TaskId taskId, final Action action) {
+        final PendingUpdateAction pendingUpdateAction = pendingUpdateActions.get(taskId);
+        return !(pendingUpdateAction == null || pendingUpdateAction.getAction() != action);
     }
 
     @Override
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java
index 91330fc4d33..c93c4f145c6 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java
@@ -55,6 +55,10 @@ public interface TasksRegistry {
 
     void addPendingTaskToInit(final Collection<Task> tasks);
 
+    boolean removePendingActiveTaskToSuspend(final TaskId taskId);
+
+    void addPendingActiveTaskToSuspend(final TaskId taskId);
+
     void addActiveTasks(final Collection<Task> tasks);
 
     void addStandbyTasks(final Collection<Task> tasks);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 5335d4dbbef..82683ec4488 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -250,32 +250,472 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldClassifyExistingTasksWithStateUpdater() {
-        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, true);
-        final StandbyTask standbyTaskToRecycle = standbyTask(taskId02, mkSet(t2p2)).build();
-        final StandbyTask standbyTaskToClose = standbyTask(taskId04, mkSet(t2p0)).build();
-        final StreamTask restoringActiveTaskToRecycle = statefulTask(taskId03, mkSet(t1p3)).build();
-        final StreamTask restoringActiveTaskToClose = statefulTask(taskId01, mkSet(t1p1)).build();
-        final Map<TaskId, Set<TopicPartition>> standbyTasks =
-            mkMap(mkEntry(standbyTaskToRecycle.id(), standbyTaskToRecycle.changelogPartitions()));
-        final Map<TaskId, Set<TopicPartition>> restoringActiveTasks = mkMap(
-            mkEntry(restoringActiveTaskToRecycle.id(), restoringActiveTaskToRecycle.changelogPartitions())
+    public void shouldPrepareActiveTaskInStateUpdaterToBeRecycled() {
+        final StreamTask activeTaskToRecycle = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RESTORING).withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        when(stateUpdater.getTasks()).thenReturn(mkSet(activeTaskToRecycle));
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            Collections.emptyMap(),
+            mkMap(mkEntry(activeTaskToRecycle.id(), activeTaskToRecycle.inputPartitions()))
         );
-        when(stateUpdater.getTasks()).thenReturn(mkSet(
-            standbyTaskToRecycle,
-            restoringActiveTaskToRecycle,
-            restoringActiveTaskToClose,
-            standbyTaskToClose
-        ));
-        handleAssignment(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
 
-        taskManager.handleAssignment(standbyTasks, restoringActiveTasks);
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(stateUpdater).remove(activeTaskToRecycle.id());
+        Mockito.verify(tasks).addPendingTaskToRecycle(activeTaskToRecycle.id(), activeTaskToRecycle.inputPartitions());
+    }
+
+    @Test
+    public void shouldPrepareStandbyTaskInStateUpdaterToBeRecycled() {
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        final StandbyTask standbyTaskToRecycle = standbyTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId03Partitions).build();
+        when(stateUpdater.getTasks()).thenReturn(mkSet(standbyTaskToRecycle));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            mkMap(mkEntry(standbyTaskToRecycle.id(), standbyTaskToRecycle.inputPartitions())),
+            Collections.emptyMap()
+        );
 
-        Mockito.verify(stateUpdater).getTasks();
+        verify(activeTaskCreator, standbyTaskCreator);
         Mockito.verify(stateUpdater).remove(standbyTaskToRecycle.id());
+        Mockito.verify(registry).addPendingTaskToRecycle(standbyTaskToRecycle.id(), standbyTaskToRecycle.inputPartitions());
+    }
+
+    @Test
+    public void shouldRemoveUnusedActiveTaskFromStateUpdater() {
+        final StreamTask unusedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RESTORING)
+            .withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        when(stateUpdater.getTasks()).thenReturn(mkSet(unusedActiveTask));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(Collections.emptyMap(), Collections.emptyMap());
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(stateUpdater).remove(unusedActiveTask.id());
+        Mockito.verify(registry).addPendingTaskToCloseClean(unusedActiveTask.id());
+    }
+
+    @Test
+    public void shouldRemoveUnusedStandbyTaskFromStateUpdater() {
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        final StandbyTask standbyTaskToClose = standbyTask(taskId02, taskId02ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId02Partitions).build();
+        when(stateUpdater.getTasks()).thenReturn(mkSet(standbyTaskToClose));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(Collections.emptyMap(), Collections.emptyMap());
+
+        verify(activeTaskCreator, standbyTaskCreator);
         Mockito.verify(stateUpdater).remove(standbyTaskToClose.id());
-        Mockito.verify(stateUpdater).remove(restoringActiveTaskToRecycle.id());
-        Mockito.verify(stateUpdater).remove(restoringActiveTaskToClose.id());
+        Mockito.verify(registry).addPendingTaskToCloseClean(standbyTaskToClose.id());
+    }
+
+    @Test
+    public void shouldUpdateInputPartitionOfActiveTaskInStateUpdater() {
+        final Set<TopicPartition> reassignedPartitions = taskId02Partitions;
+        final StreamTask restoringTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RESTORING)
+            .withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        when(stateUpdater.getTasks()).thenReturn(mkSet(restoringTask));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            mkMap(mkEntry(restoringTask.id(), reassignedPartitions)),
+            Collections.emptyMap()
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(stateUpdater).remove(restoringTask.id());
+        Mockito.verify(registry).addPendingTaskToUpdateInputPartitions(restoringTask.id(), reassignedPartitions);
+    }
+
+    @Test
+    public void shouldKeepReAssignedActiveTaskInStateUpdater() {
+        final StreamTask reassignedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RESTORING).withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        when(stateUpdater.getTasks()).thenReturn(mkSet(reassignedActiveTask));
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            mkMap(mkEntry(reassignedActiveTask.id(), reassignedActiveTask.inputPartitions())),
+            Collections.emptyMap()
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(stateUpdater, never()).remove(reassignedActiveTask.id()); // reassigned as active with unchanged input partitions
+    }
+
+    @Test
+    public void shouldRemoveReAssignedRevokedActiveTaskInStateUpdaterFromPendingTaskToSuspend() {
+        final StreamTask revokedTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RESTORING)
+            .withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        when(stateUpdater.getTasks()).thenReturn(mkSet(revokedTask));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            mkMap(mkEntry(revokedTask.id(), revokedTask.inputPartitions())),
+            Collections.emptyMap()
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(registry).removePendingActiveTaskToSuspend(revokedTask.id());
+    }
+
+    @Test
+    public void shouldNeverUpdateInputPartitionsOfStandbyTaskInStateUpdater() {
+        final Set<TopicPartition> otherPartitions = taskId03Partitions;
+        final StandbyTask runningStandbyTask = standbyTask(taskId02, taskId02ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId02Partitions).build();
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        when(stateUpdater.getTasks()).thenReturn(mkSet(runningStandbyTask));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            Collections.emptyMap(),
+            mkMap(mkEntry(runningStandbyTask.id(), otherPartitions))
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(stateUpdater, never()).remove(runningStandbyTask.id());
+        Mockito.verify(registry, never())
+            .addPendingTaskToUpdateInputPartitions(runningStandbyTask.id(), otherPartitions);
+    }
+
+    @Test
+    public void shouldKeepReAssignedStandbyTaskInStateUpdater() {
+        final StandbyTask reAssignedStandbyTask = standbyTask(taskId02, taskId02ChangelogPartitions)
+            .inState(State.RUNNING).withInputPartitions(taskId02Partitions).build();
+        final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        when(stateUpdater.getTasks()).thenReturn(mkSet(reAssignedStandbyTask));
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            Collections.emptyMap(),
+            mkMap(mkEntry(reAssignedStandbyTask.id(), reAssignedStandbyTask.inputPartitions()))
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(stateUpdater, never()).remove(reAssignedStandbyTask.id()); // reassigned as standby, so it stays put
+    }
+
+    @Test
+    public void shouldAssignMultipleTasksInStateUpdater() {
+        final StreamTask unassignedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RESTORING)
+            .withInputPartitions(taskId03Partitions).build();
+        final StandbyTask standbyReassignedAsActive = standbyTask(taskId02, taskId02ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId02Partitions).build();
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        when(stateUpdater.getTasks()).thenReturn(mkSet(unassignedActiveTask, standbyReassignedAsActive));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            mkMap(mkEntry(standbyReassignedAsActive.id(), standbyReassignedAsActive.inputPartitions())),
+            Collections.emptyMap()
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(stateUpdater).remove(unassignedActiveTask.id());
+        Mockito.verify(stateUpdater).remove(standbyReassignedAsActive.id());
+        Mockito.verify(registry).addPendingTaskToCloseClean(unassignedActiveTask.id());
+        Mockito.verify(registry).addPendingTaskToRecycle(standbyReassignedAsActive.id(), standbyReassignedAsActive.inputPartitions());
+    }
+
+    @Test
+    public void shouldCreateActiveTaskDuringAssignment() {
+        final StreamTask newActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.CREATED)
+            .withInputPartitions(taskId03Partitions).build();
+        final Set<Task> tasksFromCreator = mkSet(newActiveTask);
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, mkMap(
+            mkEntry(newActiveTask.id(), newActiveTask.inputPartitions())))
+        ).andReturn(tasksFromCreator);
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            mkMap(mkEntry(newActiveTask.id(), newActiveTask.inputPartitions())),
+            Collections.emptyMap()
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(registry).addPendingTaskToInit(tasksFromCreator);
+    }
+
+    @Test
+    public void shouldCreateStandbyTaskDuringAssignment() {
+        final StandbyTask newStandbyTask = standbyTask(taskId02, taskId02ChangelogPartitions)
+            .inState(State.CREATED)
+            .withInputPartitions(taskId02Partitions).build();
+        final Set<Task> tasksFromCreator = mkSet(newStandbyTask);
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(mkMap(
+            mkEntry(newStandbyTask.id(), newStandbyTask.inputPartitions())))
+        ).andReturn(tasksFromCreator);
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            Collections.emptyMap(),
+            mkMap(mkEntry(newStandbyTask.id(), newStandbyTask.inputPartitions()))
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(registry).addPendingTaskToInit(tasksFromCreator);
+    }
+
+    @Test
+    public void shouldAssignActiveTaskInTasksRegistryToBeRecycledWithStateUpdaterEnabled() {
+        final StandbyTask standbyFromActive = standbyTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.CREATED)
+            .withInputPartitions(taskId03Partitions).build();
+        final StreamTask suspendedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.SUSPENDED)
+            .withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        when(registry.allTasks()).thenReturn(mkSet(suspendedActiveTask));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(standbyTaskCreator.createStandbyTaskFromActive(suspendedActiveTask, suspendedActiveTask.inputPartitions()))
+            .andReturn(standbyFromActive);
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(suspendedActiveTask.id());
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            Collections.emptyMap(),
+            mkMap(mkEntry(suspendedActiveTask.id(), suspendedActiveTask.inputPartitions()))
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(suspendedActiveTask).prepareCommit();
+        Mockito.verify(registry).replaceActiveWithStandby(standbyFromActive);
+    }
+
+    @Test
+    public void shouldThrowDuringAssignmentIfStandbyTaskToRecycleIsFoundInTasksRegistryWithStateUpdaterEnabled() {
+        final StandbyTask standbyTaskToRecycle = standbyTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
+        when(tasks.allTasks()).thenReturn(mkSet(standbyTaskToRecycle));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        final IllegalStateException illegalStateException = assertThrows(
+            IllegalStateException.class,
+            () -> taskManager.handleAssignment(
+                mkMap(mkEntry(standbyTaskToRecycle.id(), standbyTaskToRecycle.inputPartitions())),
+                Collections.emptyMap()
+            )
+        );
+
+        assertEquals("Standby tasks should only be managed by the state updater", illegalStateException.getMessage());
+        verify(activeTaskCreator, standbyTaskCreator);
+    }
+
+    @Test
+    public void shouldAssignActiveTaskInTasksRegistryToBeClosedCleanlyWithStateUpdaterEnabled() {
+        final StreamTask unassignedRunningTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        when(registry.allTasks()).thenReturn(mkSet(unassignedRunningTask));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(unassignedRunningTask.id());
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(Collections.emptyMap(), Collections.emptyMap());
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(unassignedRunningTask).prepareCommit();
+        Mockito.verify(unassignedRunningTask).closeClean();
+        Mockito.verify(registry).removeTask(unassignedRunningTask);
+    }
+
+    @Test
+    public void shouldThrowDuringAssignmentIfStandbyTaskToCloseIsFoundInTasksRegistryWithStateUpdaterEnabled() {
+        final StandbyTask standbyTaskToClose = standbyTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        when(tasks.allTasks()).thenReturn(mkSet(standbyTaskToClose));
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        final IllegalStateException illegalStateException = assertThrows(
+            IllegalStateException.class,
+            () -> taskManager.handleAssignment(Collections.emptyMap(), Collections.emptyMap())
+        );
+
+        assertEquals("Standby tasks should only be managed by the state updater", illegalStateException.getMessage());
+        verify(activeTaskCreator, standbyTaskCreator);
+    }
+
+    @Test
+    public void shouldAssignActiveTaskInTasksRegistryToUpdateInputPartitionsWithStateUpdaterEnabled() {
+        final Set<TopicPartition> updatedPartitions = taskId02Partitions;
+        final StreamTask runningActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        when(registry.allTasks()).thenReturn(mkSet(runningActiveTask));
+        when(registry.updateActiveTaskInputPartitions(runningActiveTask, updatedPartitions)).thenReturn(true);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            mkMap(mkEntry(runningActiveTask.id(), updatedPartitions)),
+            Collections.emptyMap()
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(runningActiveTask).updateInputPartitions(Mockito.eq(updatedPartitions), any());
+    }
+
+    @Test
+    public void shouldResumeActiveRunningTaskInTasksRegistryWithStateUpdaterEnabled() {
+        final StreamTask runningActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        when(registry.allTasks()).thenReturn(mkSet(runningActiveTask));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            mkMap(mkEntry(runningActiveTask.id(), runningActiveTask.inputPartitions())),
+            Collections.emptyMap()
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+    }
+
+    @Test
+    public void shouldResumeActiveSuspendedTaskInTasksRegistryAndAddToStateUpdater() {
+        final StreamTask suspendedActiveTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.SUSPENDED)
+            .withInputPartitions(taskId03Partitions).build();
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        when(registry.allTasks()).thenReturn(mkSet(suspendedActiveTask));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        expect(activeTaskCreator.createTasks(consumer, Collections.emptyMap())).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            mkMap(mkEntry(suspendedActiveTask.id(), suspendedActiveTask.inputPartitions())),
+            Collections.emptyMap()
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(suspendedActiveTask).resume();
+        Mockito.verify(stateUpdater).add(suspendedActiveTask);
+        Mockito.verify(registry).removeTask(suspendedActiveTask);
+    }
+
+    @Test
+    public void shouldThrowDuringAssignmentIfStandbyTaskToUpdateInputPartitionsIsFoundInTasksRegistryWithStateUpdaterEnabled() {
+        final StandbyTask standbyTaskToUpdateInputPartitions = standbyTask(taskId02, taskId02ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId02Partitions).build();
+        final Set<TopicPartition> newInputPartitions = taskId03Partitions;
+        final TasksRegistry tasks = Mockito.mock(TasksRegistry.class);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        when(tasks.allTasks()).thenReturn(mkSet(standbyTaskToUpdateInputPartitions));
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        final IllegalStateException illegalStateException = assertThrows(
+            IllegalStateException.class,
+            () -> taskManager.handleAssignment(
+                Collections.emptyMap(),
+                mkMap(mkEntry(standbyTaskToUpdateInputPartitions.id(), newInputPartitions))
+            )
+        );
+
+        assertEquals("Standby tasks should only be managed by the state updater", illegalStateException.getMessage());
+        verify(activeTaskCreator, standbyTaskCreator);
+    }
+
+    @Test
+    public void shouldAssignMultipleTasksInTasksRegistryWithStateUpdaterEnabled() {
+        final StreamTask unassignedRunningTask = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId03Partitions).build();
+        final StreamTask newlyAssignedTask = statefulTask(taskId02, taskId02ChangelogPartitions)
+            .inState(State.CREATED)
+            .withInputPartitions(taskId02Partitions).build();
+        final TasksRegistry registry = Mockito.mock(TasksRegistry.class);
+        when(registry.allTasks()).thenReturn(mkSet(unassignedRunningTask));
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, registry, true);
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(unassignedRunningTask.id());
+        expect(activeTaskCreator.createTasks(
+            consumer,
+            mkMap(mkEntry(newlyAssignedTask.id(), newlyAssignedTask.inputPartitions()))
+        )).andReturn(emptySet());
+        expect(standbyTaskCreator.createTasks(Collections.emptyMap())).andReturn(emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
+
+        taskManager.handleAssignment(
+            mkMap(mkEntry(newlyAssignedTask.id(), newlyAssignedTask.inputPartitions())),
+            Collections.emptyMap()
+        );
+
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(unassignedRunningTask).closeClean();
     }
 
     @Test
@@ -290,7 +730,7 @@ public class TaskManagerTest {
         when(tasks.drainPendingTaskToInit()).thenReturn(mkSet(task00, task01));
         taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, tasks, true);
 
-        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         Mockito.verify(task00).initializeIfNeeded();
         Mockito.verify(task01).initializeIfNeeded();
@@ -299,7 +739,7 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldHandleRemovedTasksToRecycleFromStateUpdater() {
+    public void shouldRecycleTasksRemovedFromStateUpdater() {
         final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions)
             .withInputPartitions(taskId00Partitions)
             .inState(State.RESTORING).build();
@@ -323,7 +763,7 @@ public class TaskManagerTest {
             .andStubReturn(task00Converted);
         replay(activeTaskCreator, standbyTaskCreator);
 
-        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         verify(activeTaskCreator, standbyTaskCreator);
         Mockito.verify(task00).suspend();
@@ -335,7 +775,7 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldHandleRemovedTasksToCloseFromStateUpdater() {
+    public void shouldCloseTasksRemovedFromStateUpdater() {
         final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions)
             .withInputPartitions(taskId00Partitions)
             .inState(State.RESTORING).build();
@@ -352,7 +792,7 @@ public class TaskManagerTest {
         expectLastCall().once();
         replay(activeTaskCreator);
 
-        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         verify(activeTaskCreator);
         Mockito.verify(task00).suspend();
@@ -362,7 +802,7 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldHandleRemovedTasksToUpdateInputPartitionsFromStateUpdater() {
+    public void shouldUpdateInputPartitionsOfTasksRemovedFromStateUpdater() {
         final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions)
             .withInputPartitions(taskId00Partitions)
             .inState(State.RESTORING).build();
@@ -377,7 +817,7 @@ public class TaskManagerTest {
         taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, tasks, true);
         replay(topologyBuilder);
 
-        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         Mockito.verify(task00).updateInputPartitions(Mockito.eq(taskId02Partitions), anyMap());
         Mockito.verify(task00, never()).closeDirty();
@@ -390,7 +830,78 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldRemoveStatefulTaskWithRevokedInputPartitionsFromStateUpdaterOnRevocation() {
+    public void shouldSuspendRevokedTaskRemovedFromStateUpdater() {
+        final StreamTask revokedTask = statefulTask(taskId00, taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions).inState(State.RESTORING).build();
+        when(stateUpdater.drainRemovedTasks()).thenReturn(mkSet(revokedTask));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.removePendingTaskToRecycle(revokedTask.id())).thenReturn(null);
+        when(tasks.removePendingTaskToUpdateInputPartitions(revokedTask.id())).thenReturn(null);
+        when(tasks.removePendingActiveTaskToSuspend(revokedTask.id())).thenReturn(true);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        replay(consumer);
+
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
+
+        verify(consumer);
+        Mockito.verify(revokedTask).suspend();
+        Mockito.verify(tasks).addTask(revokedTask);
+    }
+
+    @Test
+    public void shouldHandleMultipleRemovedTasksFromStateUpdater() {
+        final StreamTask taskToRecycle0 = statefulTask(taskId00, taskId00ChangelogPartitions)
+            .inState(State.RESTORING)
+            .withInputPartitions(taskId00Partitions).build();
+        final StandbyTask taskToRecycle1 = standbyTask(taskId01, taskId01ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId01Partitions).build();
+        final StandbyTask convertedTask0 = standbyTask(taskId00, taskId00ChangelogPartitions).build();
+        final StreamTask convertedTask1 = statefulTask(taskId01, taskId01ChangelogPartitions).build();
+        final StreamTask taskToClose = statefulTask(taskId02, taskId02ChangelogPartitions)
+            .inState(State.RESTORING)
+            .withInputPartitions(taskId02Partitions).build();
+        final StreamTask taskToUpdateInputPartitions = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RESTORING)
+            .withInputPartitions(taskId03Partitions).build();
+        when(stateUpdater.drainRemovedTasks())
+            .thenReturn(mkSet(taskToRecycle0, taskToRecycle1, taskToClose, taskToUpdateInputPartitions));
+        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
+        expect(activeTaskCreator.createActiveTaskFromStandby(eq(taskToRecycle1), eq(taskId01Partitions), eq(consumer)))
+            .andStubReturn(convertedTask1);
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(anyObject());
+        expectLastCall().times(2);
+        expect(standbyTaskCreator.createStandbyTaskFromActive(eq(taskToRecycle0), eq(taskId00Partitions)))
+            .andStubReturn(convertedTask0);
+        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
+        consumer.resume(anyObject());
+        expectLastCall().anyTimes();
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.removePendingTaskToCloseClean(taskToClose.id())).thenReturn(true);
+        when(tasks.removePendingTaskToRecycle(taskToRecycle0.id())).thenReturn(taskId00Partitions);
+        when(tasks.removePendingTaskToRecycle(taskToRecycle1.id())).thenReturn(taskId01Partitions);
+        when(tasks.removePendingTaskToRecycle(
+            argThat(taskId -> !taskId.equals(taskToRecycle0.id()) && !taskId.equals(taskToRecycle1.id())))
+        ).thenReturn(null);
+        when(tasks.removePendingTaskToUpdateInputPartitions(taskToUpdateInputPartitions.id())).thenReturn(taskId04Partitions);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        taskManager.setMainConsumer(consumer);
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);
+
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
+
+        verify(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);
+        Mockito.verify(convertedTask0).initializeIfNeeded();
+        Mockito.verify(convertedTask1).initializeIfNeeded();
+        Mockito.verify(stateUpdater).add(convertedTask0);
+        Mockito.verify(stateUpdater).add(convertedTask1);
+        Mockito.verify(taskToClose).closeClean();
+        Mockito.verify(taskToUpdateInputPartitions).updateInputPartitions(Mockito.eq(taskId04Partitions), anyMap());
+        Mockito.verify(stateUpdater).add(taskToUpdateInputPartitions);
+    }
+
+    @Test
+    public void shouldAddActiveTaskWithRevokedInputPartitionsInStateUpdaterToPendingTasksToSuspend() {
         final StreamTask task = statefulTask(taskId00, taskId00ChangelogPartitions)
             .inState(State.RESTORING)
             .withInputPartitions(taskId00Partitions).build();
@@ -400,11 +911,11 @@ public class TaskManagerTest {
 
         taskManager.handleRevocation(task.inputPartitions());
 
-        Mockito.verify(tasks).addPendingTaskToCloseClean(task.id());
-        Mockito.verify(stateUpdater).remove(task.id());
+        Mockito.verify(tasks).addPendingActiveTaskToSuspend(task.id());
+        Mockito.verify(stateUpdater, never()).remove(task.id());
     }
 
-    public void shouldRemoveMultipleStatefulTaskWithRevokedInputPartitionsFromStateUpdaterOnRevocation() {
+    @Test public void shouldAddMultipleActiveTasksWithRevokedInputPartitionsInStateUpdaterToPendingTasksToSuspend() {
         final StreamTask task1 = statefulTask(taskId00, taskId00ChangelogPartitions)
             .inState(State.RESTORING)
             .withInputPartitions(taskId00Partitions).build();
@@ -416,14 +927,12 @@ public class TaskManagerTest {
 
         taskManager.handleRevocation(union(HashSet::new, taskId00Partitions, taskId01Partitions));
 
-        Mockito.verify(tasks).addPendingTaskToCloseClean(task1.id());
-        Mockito.verify(tasks).addPendingTaskToCloseClean(task2.id());
-        Mockito.verify(stateUpdater).remove(task1.id());
-        Mockito.verify(stateUpdater).remove(task2.id());
+        Mockito.verify(tasks).addPendingActiveTaskToSuspend(task1.id());
+        Mockito.verify(tasks).addPendingActiveTaskToSuspend(task2.id());
     }
 
     @Test
-    public void shouldNotRemoveStatefulTaskWithoutRevokedInputPartitionsFromStateUpdaterOnRevocation() {
+    public void shouldNotAddActiveTaskWithoutRevokedInputPartitionsInStateUpdaterToPendingTasksToSuspend() {
         final StreamTask task = statefulTask(taskId00, taskId00ChangelogPartitions)
             .inState(State.RESTORING)
             .withInputPartitions(taskId00Partitions).build();
@@ -433,11 +942,11 @@ public class TaskManagerTest {
         taskManager.handleRevocation(taskId01Partitions);
 
         Mockito.verify(stateUpdater, never()).remove(task.id());
-        Mockito.verify(tasks, never()).addPendingTaskToCloseClean(task.id());
+        Mockito.verify(tasks, never()).addPendingActiveTaskToSuspend(task.id());
     }
 
     @Test
-    public void shouldNotRemoveStandbyTaskFromStateUpdaterOnRevocation() {
+    public void shouldNotRevokeStandbyTaskInStateUpdaterOnRevocation() {
         final StandbyTask task = standbyTask(taskId00, taskId00ChangelogPartitions)
             .inState(State.RESTORING)
             .withInputPartitions(taskId00Partitions).build();
@@ -447,7 +956,7 @@ public class TaskManagerTest {
         taskManager.handleRevocation(taskId00Partitions);
 
         Mockito.verify(stateUpdater, never()).remove(task.id());
-        Mockito.verify(tasks, never()).addPendingTaskToCloseClean(task.id());
+        Mockito.verify(tasks, never()).addPendingActiveTaskToSuspend(task.id());
     }
 
     @Test
@@ -483,58 +992,6 @@ public class TaskManagerTest {
         return taskManager;
     }
 
-    @Test
-    public void shouldHandleRemovedTasksFromStateUpdater() {
-        final StreamTask taskToRecycle0 = statefulTask(taskId00, taskId00ChangelogPartitions)
-            .inState(State.RESTORING)
-            .withInputPartitions(taskId00Partitions).build();
-        final StandbyTask taskToRecycle1 = standbyTask(taskId01, taskId01ChangelogPartitions)
-            .inState(State.RUNNING)
-            .withInputPartitions(taskId01Partitions).build();
-        final StandbyTask convertedTask0 = standbyTask(taskId00, taskId00ChangelogPartitions).build();
-        final StreamTask convertedTask1 = statefulTask(taskId01, taskId01ChangelogPartitions).build();
-        final StreamTask taskToClose = statefulTask(taskId02, taskId02ChangelogPartitions)
-            .inState(State.RESTORING)
-            .withInputPartitions(taskId02Partitions).build();
-        final StreamTask taskToUpdateInputPartitions = statefulTask(taskId03, taskId03ChangelogPartitions)
-            .inState(State.RESTORING)
-            .withInputPartitions(taskId03Partitions).build();
-        when(stateUpdater.drainRemovedTasks())
-            .thenReturn(mkSet(taskToRecycle0, taskToRecycle1, taskToClose, taskToUpdateInputPartitions));
-        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
-        expect(activeTaskCreator.createActiveTaskFromStandby(eq(taskToRecycle1), eq(taskId01Partitions), eq(consumer)))
-            .andStubReturn(convertedTask1);
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(anyObject());
-        expectLastCall().times(2);
-        expect(standbyTaskCreator.createStandbyTaskFromActive(eq(taskToRecycle0), eq(taskId00Partitions)))
-            .andStubReturn(convertedTask0);
-        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
-        consumer.resume(anyObject());
-        expectLastCall().anyTimes();
-        final TasksRegistry tasks = mock(TasksRegistry.class);
-        when(tasks.removePendingTaskToCloseClean(taskToClose.id())).thenReturn(true);
-        when(tasks.removePendingTaskToRecycle(taskToRecycle0.id())).thenReturn(taskId00Partitions);
-        when(tasks.removePendingTaskToRecycle(taskToRecycle1.id())).thenReturn(taskId01Partitions);
-        when(tasks.removePendingTaskToRecycle(
-            argThat(taskId -> !taskId.equals(taskToRecycle0.id()) && !taskId.equals(taskToRecycle1.id())))
-        ).thenReturn(null);
-        when(tasks.removePendingTaskToUpdateInputPartitions(taskToUpdateInputPartitions.id())).thenReturn(taskId04Partitions);
-        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        taskManager.setMainConsumer(consumer);
-        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);
-
-        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
-
-        verify(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);
-        Mockito.verify(convertedTask0).initializeIfNeeded();
-        Mockito.verify(convertedTask1).initializeIfNeeded();
-        Mockito.verify(stateUpdater).add(convertedTask0);
-        Mockito.verify(stateUpdater).add(convertedTask1);
-        Mockito.verify(taskToClose).closeClean();
-        Mockito.verify(taskToUpdateInputPartitions).updateInputPartitions(Mockito.eq(taskId04Partitions), anyMap());
-        Mockito.verify(stateUpdater).add(taskToUpdateInputPartitions);
-    }
-
     @Test
     public void shouldTransitRestoredTaskToRunning() {
         final StreamTask task = statefulTask(taskId00, taskId00ChangelogPartitions)
@@ -774,13 +1231,37 @@ public class TaskManagerTest {
         when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(statefulTask));
         when(stateUpdater.restoresActiveTasks()).thenReturn(true);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
-        replay(topologyBuilder);
+        consumer.resume(statefulTask.inputPartitions());
+        replay(consumer, topologyBuilder);
 
         taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
+        verify(consumer);
         Mockito.verify(statefulTask).updateInputPartitions(Mockito.eq(taskId01Partitions), anyMap());
-        Mockito.verify(statefulTask, never()).closeDirty();
-        Mockito.verify(statefulTask, never()).closeClean();
+        Mockito.verify(statefulTask).completeRestoration(noOpResetter);
+        Mockito.verify(statefulTask).clearTaskTimeout();
+        Mockito.verify(tasks).addTask(statefulTask);
+    }
+
+    @Test
+    public void shouldSuspendRestoredTaskIfRevoked() {
+        final StreamTask restoredTask = statefulTask(taskId00, taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RESTORING).build();
+        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
+        when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(restoredTask));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.removePendingTaskToRecycle(restoredTask.id())).thenReturn(null);
+        when(tasks.removePendingTaskToUpdateInputPartitions(restoredTask.id())).thenReturn(null);
+        when(tasks.removePendingActiveTaskToSuspend(restoredTask.id())).thenReturn(true);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        replay(consumer);
+
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
+
+        verify(consumer);
+        Mockito.verify(tasks).addTask(restoredTask);
+        Mockito.verify(restoredTask).suspend();
     }
 
     @Test
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
index be1f5c4972f..d303eb4f603 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
@@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test;
 
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.Set;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
@@ -31,6 +32,9 @@ import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.standbyTask;
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statefulTask;
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statelessTask;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class TasksTest {
@@ -43,11 +47,10 @@ public class TasksTest {
     private final static TaskId TASK_0_1 = new TaskId(0, 1);
     private final static TaskId TASK_1_0 = new TaskId(1, 0);
 
-    private final LogContext logContext = new LogContext();
+    private final Tasks tasks = new Tasks(new LogContext());
 
     @Test
     public void shouldKeepAddedTasks() {
-        final Tasks tasks = new Tasks(logContext);
         final StreamTask statefulTask = statefulTask(TASK_0_0, mkSet(TOPIC_PARTITION_A_0)).build();
         final StandbyTask standbyTask = standbyTask(TASK_0_1, mkSet(TOPIC_PARTITION_A_1)).build();
         final StreamTask statelessTask = statelessTask(TASK_1_0).build();
@@ -77,8 +80,6 @@ public class TasksTest {
 
     @Test
     public void shouldDrainPendingTasksToCreate() {
-        final Tasks tasks = new Tasks(logContext);
-
         tasks.addPendingActiveTasksToCreate(mkMap(
             mkEntry(new TaskId(0, 0, "A"), mkSet(TOPIC_PARTITION_A_0)),
             mkEntry(new TaskId(0, 1, "A"), mkSet(TOPIC_PARTITION_A_1)),
@@ -108,4 +109,141 @@ public class TasksTest {
         assertEquals(Collections.emptyMap(), tasks.drainPendingActiveTasksForTopologies(mkSet("B")));
         assertEquals(Collections.emptyMap(), tasks.drainPendingStandbyTasksForTopologies(mkSet("B")));
     }
+
+    @Test
+    public void shouldAddAndRemovePendingTaskToRecycle() {
+        final TaskId taskId = TASK_0_0;
+        final Set<TopicPartition> partitions = mkSet(TOPIC_PARTITION_A_0);
+        assertNull(tasks.removePendingTaskToRecycle(taskId));
+
+        tasks.addPendingTaskToRecycle(taskId, partitions);
+
+        assertEquals(partitions, tasks.removePendingTaskToRecycle(taskId));
+        assertNull(tasks.removePendingTaskToRecycle(taskId));
+    }
+
+    @Test
+    public void shouldAddAndRemovePendingTaskToUpdateInputPartitions() {
+        final TaskId taskId = TASK_0_0;
+        final Set<TopicPartition> partitions = mkSet(TOPIC_PARTITION_A_0);
+        assertNull(tasks.removePendingTaskToUpdateInputPartitions(taskId));
+
+        tasks.addPendingTaskToUpdateInputPartitions(taskId, partitions);
+
+        assertEquals(partitions, tasks.removePendingTaskToUpdateInputPartitions(taskId));
+        assertNull(tasks.removePendingTaskToUpdateInputPartitions(taskId));
+    }
+
+    @Test
+    public void shouldAddAndRemovePendingTaskToCloseClean() {
+        final TaskId taskId = TASK_0_0;
+        assertFalse(tasks.removePendingTaskToCloseClean(taskId));
+        tasks.addPendingTaskToCloseClean(taskId);
+
+        assertTrue(tasks.removePendingTaskToCloseClean(taskId));
+        assertFalse(tasks.removePendingTaskToCloseClean(taskId));
+    }
+
+    @Test
+    public void shouldAddAndRemovePendingTaskToCloseDirty() {
+        final TaskId taskId = TASK_0_0;
+        assertFalse(tasks.removePendingTaskToCloseDirty(taskId));
+        tasks.addPendingTaskToCloseDirty(taskId);
+
+        assertTrue(tasks.removePendingTaskToCloseDirty(taskId));
+        assertFalse(tasks.removePendingTaskToCloseDirty(taskId));
+    }
+
+    @Test
+    public void shouldAddAndRemovePendingTaskToSuspend() {
+        final TaskId taskId = TASK_0_0;
+        assertFalse(tasks.removePendingActiveTaskToSuspend(taskId));
+        tasks.addPendingActiveTaskToSuspend(taskId);
+
+        assertTrue(tasks.removePendingActiveTaskToSuspend(taskId));
+        assertFalse(tasks.removePendingActiveTaskToSuspend(taskId));
+    }
+
+    @Test
+    public void onlyRemovePendingTaskToRecycleShouldRemoveTaskFromPendingUpdateActions() {
+        final TaskId taskId = TASK_0_0;
+        tasks.addPendingTaskToRecycle(taskId, mkSet(TOPIC_PARTITION_A_0));
+        assertFalse(tasks.removePendingTaskToCloseDirty(taskId));
+        assertFalse(tasks.removePendingTaskToCloseClean(taskId));
+        assertFalse(tasks.removePendingActiveTaskToSuspend(taskId));
+        assertNull(tasks.removePendingTaskToUpdateInputPartitions(taskId));
+        assertNotNull(tasks.removePendingTaskToRecycle(taskId));
+    }
+
+    @Test
+    public void onlyRemovePendingTaskToUpdateInputPartitionsShouldRemoveTaskFromPendingUpdateActions() {
+        final TaskId taskId = TASK_0_0;
+        tasks.addPendingTaskToUpdateInputPartitions(taskId, mkSet(TOPIC_PARTITION_A_0));
+        assertFalse(tasks.removePendingTaskToCloseDirty(taskId));
+        assertFalse(tasks.removePendingTaskToCloseClean(taskId));
+        assertFalse(tasks.removePendingActiveTaskToSuspend(taskId));
+        assertNull(tasks.removePendingTaskToRecycle(taskId));
+        assertNotNull(tasks.removePendingTaskToUpdateInputPartitions(taskId));
+    }
+
+    @Test
+    public void onlyRemovePendingTaskToCloseCleanShouldRemoveTaskFromPendingUpdateActions() {
+        final TaskId taskId = TASK_0_0;
+        tasks.addPendingTaskToCloseClean(taskId);
+        assertFalse(tasks.removePendingTaskToCloseDirty(taskId));
+        assertFalse(tasks.removePendingActiveTaskToSuspend(taskId));
+        assertNull(tasks.removePendingTaskToRecycle(taskId));
+        assertNull(tasks.removePendingTaskToUpdateInputPartitions(taskId));
+        assertTrue(tasks.removePendingTaskToCloseClean(taskId));
+    }
+
+    @Test
+    public void onlyRemovePendingTaskToCloseDirtyShouldRemoveTaskFromPendingUpdateActions() {
+        final TaskId taskId = TASK_0_0;
+        tasks.addPendingTaskToCloseDirty(taskId);
+        assertFalse(tasks.removePendingTaskToCloseClean(taskId));
+        assertFalse(tasks.removePendingActiveTaskToSuspend(taskId));
+        assertNull(tasks.removePendingTaskToRecycle(taskId));
+        assertNull(tasks.removePendingTaskToUpdateInputPartitions(taskId));
+        assertTrue(tasks.removePendingTaskToCloseDirty(taskId));
+    }
+
+    @Test
+    public void onlyRemovePendingTaskToSuspendShouldRemoveTaskFromPendingUpdateActions() {
+        final TaskId taskId = TASK_0_0;
+        tasks.addPendingActiveTaskToSuspend(taskId);
+        assertFalse(tasks.removePendingTaskToCloseClean(taskId));
+        assertFalse(tasks.removePendingTaskToCloseDirty(taskId));
+        assertNull(tasks.removePendingTaskToRecycle(taskId));
+        assertNull(tasks.removePendingTaskToUpdateInputPartitions(taskId));
+        assertTrue(tasks.removePendingActiveTaskToSuspend(taskId));
+    }
+
+    @Test
+    public void shouldOnlyKeepLastUpdateAction() { // each add* call replaces the action registered just before it
+        tasks.addPendingTaskToRecycle(TASK_0_0, mkSet(TOPIC_PARTITION_A_0));
+        tasks.addPendingTaskToUpdateInputPartitions(TASK_0_0, mkSet(TOPIC_PARTITION_A_0)); // supersedes the pending recycle
+        assertNull(tasks.removePendingTaskToRecycle(TASK_0_0));
+        assertNotNull(tasks.removePendingTaskToUpdateInputPartitions(TASK_0_0));
+
+        tasks.addPendingTaskToUpdateInputPartitions(TASK_0_0, mkSet(TOPIC_PARTITION_A_0));
+        tasks.addPendingTaskToCloseClean(TASK_0_0); // supersedes the pending input-partition update
+        assertNull(tasks.removePendingTaskToUpdateInputPartitions(TASK_0_0));
+        assertTrue(tasks.removePendingTaskToCloseClean(TASK_0_0));
+
+        tasks.addPendingTaskToCloseClean(TASK_0_0);
+        tasks.addPendingTaskToCloseDirty(TASK_0_0); // supersedes the pending clean close
+        assertFalse(tasks.removePendingTaskToCloseClean(TASK_0_0));
+        assertTrue(tasks.removePendingTaskToCloseDirty(TASK_0_0));
+
+        tasks.addPendingTaskToCloseDirty(TASK_0_0);
+        tasks.addPendingActiveTaskToSuspend(TASK_0_0); // supersedes the pending dirty close
+        assertFalse(tasks.removePendingTaskToCloseDirty(TASK_0_0));
+        assertTrue(tasks.removePendingActiveTaskToSuspend(TASK_0_0));
+
+        tasks.addPendingActiveTaskToSuspend(TASK_0_0);
+        tasks.addPendingTaskToRecycle(TASK_0_0, mkSet(TOPIC_PARTITION_A_0)); // supersedes the pending suspension
+        assertFalse(tasks.removePendingActiveTaskToSuspend(TASK_0_0));
+        assertNotNull(tasks.removePendingTaskToRecycle(TASK_0_0));
+    }
 }
\ No newline at end of file