You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2022/07/28 00:29:13 UTC

[kafka] branch trunk updated: KAFKA-10199: Further refactor task lifecycle management (#12439)

This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 06f47c3b51 KAFKA-10199: Further refactor task lifecycle management (#12439)
06f47c3b51 is described below

commit 06f47c3b517f5c96c1bd981369843f3483947197
Author: Guozhang Wang <wa...@gmail.com>
AuthorDate: Wed Jul 27 17:29:05 2022 -0700

    KAFKA-10199: Further refactor task lifecycle management (#12439)
    
    1. Consolidate the task recycle procedure into a single function within the task. The current procedure now becomes: a) task.recycleStateAndConvert, at the end of which the task is closed while its stateManager is retained, and the manager type has been converted; b) create the new task with the old task's fields and the stateManager inside the creators.
    2. Move the task execution related metadata into the corresponding TaskExecutionMetadata class, including the task idle related metadata (e.g. successfully processed tasks); reduce the number of params needed for TaskExecutor as well as Tasks.
    3. Move the task execution related fields (embedded producer and consumer) and task creators out of Tasks and migrated into TaskManager. Now the Tasks is only a bookkeeping place without any task mutation logic.
    4. When adding tests, I realized that we should not add task to state updater right after creation, since it was not initialized yet, while state updater would validate that the task's state is already restoring / running. So I updated that logic while adding unit tests.
    
    Reviewers: Bruno Cadonna <ca...@apache.org>
---
 .../streams/processor/internals/AbstractTask.java  |   5 +-
 .../processor/internals/ActiveTaskCreator.java     |  36 ++-
 .../streams/processor/internals/StandbyTask.java   |  51 +---
 .../processor/internals/StandbyTaskCreator.java    |  32 ++-
 .../streams/processor/internals/StreamTask.java    |  52 +---
 .../kafka/streams/processor/internals/Task.java    |   5 +-
 .../processor/internals/TaskExecutionMetadata.java |  36 ++-
 .../streams/processor/internals/TaskExecutor.java  |  49 ++--
 .../streams/processor/internals/TaskManager.java   | 188 ++++++++-----
 .../kafka/streams/processor/internals/Tasks.java   | 203 ++++----------
 .../processor/internals/TopologyMetadata.java      |   4 +-
 .../processor/internals/StandbyTaskTest.java       |   6 +-
 .../processor/internals/StreamTaskTest.java        |  12 +-
 .../processor/internals/StreamThreadTest.java      |   1 +
 .../internals/TaskExecutionMetadataTest.java       |  11 +-
 .../processor/internals/TaskExecutorTest.java      |   5 +-
 .../processor/internals/TaskManagerTest.java       | 293 +++++++++++++--------
 .../streams/processor/internals/TasksTest.java     |  76 +-----
 18 files changed, 512 insertions(+), 553 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
index 8ffbf9cd2e..a88d89fc33 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
@@ -43,10 +43,11 @@ public abstract class AbstractTask implements Task {
     private Task.State state = CREATED;
     private long deadlineMs = NO_DEADLINE;
 
-    protected Set<TopicPartition> inputPartitions;
     protected final Logger log;
-    protected final LogContext logContext;
     protected final String logPrefix;
+    protected final LogContext logContext;
+
+    protected Set<TopicPartition> inputPartitions;
 
     /**
      * If the checkpoint has not been loaded from the file yet (null), then we should not overwrite the checkpoint;
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
index ef21a51c99..46455111db 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
@@ -135,10 +135,8 @@ class ActiveTaskCreator {
         return threadProducer;
     }
 
-    // TODO: change return type to `StreamTask`
     public Collection<Task> createTasks(final Consumer<byte[], byte[]> consumer,
-                                 final Map<TaskId, Set<TopicPartition>> tasksToBeCreated) {
-        // TODO: change type to `StreamTask`
+                                        final Map<TaskId, Set<TopicPartition>> tasksToBeCreated) {
         final List<Task> createdTasks = new ArrayList<>();
 
         for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions : tasksToBeCreated.entrySet()) {
@@ -211,13 +209,39 @@ class ActiveTaskCreator {
         );
     }
 
+    /*
+     * TODO: we pass in the new input partitions to validate if they still match;
+     *       in the future, when we have a fixed partitions -> tasks mapping,
+     *       we should always reuse the input partitions and hence need no validation
+     */
     StreamTask createActiveTaskFromStandby(final StandbyTask standbyTask,
                                            final Set<TopicPartition> inputPartitions,
                                            final Consumer<byte[], byte[]> consumer) {
+        if (!inputPartitions.equals(standbyTask.inputPartitions)) {
+            log.warn("Detected unmatched input partitions for task {} when recycling it from standby to active", standbyTask.id);
+        }
+
+        standbyTask.prepareRecycle();
+        standbyTask.stateMgr.transitionTaskType(Task.TaskType.ACTIVE);
+
         final RecordCollector recordCollector = createRecordCollector(standbyTask.id, getLogContext(standbyTask.id), standbyTask.topology);
-        final StreamTask task = standbyTask.recycle(time, cache, recordCollector, inputPartitions, consumer);
+        final StreamTask task = new StreamTask(
+            standbyTask.id,
+            inputPartitions,
+            standbyTask.topology,
+            consumer,
+            standbyTask.config,
+            streamsMetrics,
+            stateDirectory,
+            cache,
+            time,
+            standbyTask.stateMgr,
+            recordCollector,
+            standbyTask.processorContext,
+            standbyTask.logContext
+        );
 
-        log.trace("Created active task {} with assigned partitions {}", task.id, inputPartitions);
+        log.trace("Created active task {} from recycled standby task with assigned partitions {}", task.id, inputPartitions);
         createTaskSensor.record();
         return task;
     }
@@ -228,7 +252,7 @@ class ActiveTaskCreator {
                                         final LogContext logContext,
                                         final ProcessorTopology topology,
                                         final ProcessorStateManager stateManager,
-                                        final InternalProcessorContext context) {
+                                        final InternalProcessorContext<Object, Object> context) {
         final RecordCollector recordCollector = createRecordCollector(taskId, logContext, topology);
 
         final StreamTask task = new StreamTask(
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
index 102a767495..bb7aef1dcd 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
@@ -16,12 +16,10 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
-import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.metrics.Sensor;
-import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsMetrics;
 import org.apache.kafka.streams.errors.StreamsException;
@@ -41,11 +39,13 @@ import java.util.Set;
  * A StandbyTask
  */
 public class StandbyTask extends AbstractTask implements Task {
-    private final Sensor closeTaskSensor;
     private final boolean eosEnabled;
-    private final InternalProcessorContext processorContext;
+    private final Sensor closeTaskSensor;
     private final StreamsMetricsImpl streamsMetrics;
 
+    @SuppressWarnings("rawtypes")
+    protected final InternalProcessorContext processorContext;
+
     /**
      * @param id              the ID of this task
      * @param inputPartitions input topic partitions, used for thread metadata only
@@ -55,6 +55,7 @@ public class StandbyTask extends AbstractTask implements Task {
      * @param stateMgr        the {@link ProcessorStateManager} for this task
      * @param stateDirectory  the {@link StateDirectory} created by the thread
      */
+    @SuppressWarnings("rawtypes")
     StandbyTask(final TaskId id,
                 final Set<TopicPartition> inputPartitions,
                 final ProcessorTopology topology,
@@ -222,7 +223,7 @@ public class StandbyTask extends AbstractTask implements Task {
     }
 
     @Override
-    public void closeCleanAndRecycleState() {
+    public void prepareRecycle() {
         streamsMetrics.removeAllTaskLevelSensors(Thread.currentThread().getName(), id.toString());
         if (state() == State.SUSPENDED) {
             stateMgr.recycle();
@@ -233,44 +234,7 @@ public class StandbyTask extends AbstractTask implements Task {
         closeTaskSensor.record();
         transitionTo(State.CLOSED);
 
-        log.info("Closed clean and recycled state");
-    }
-
-    /**
-     * Create an active task from this standby task without closing and re-initializing the state stores.
-     * The task should have been in suspended state when calling this function
-     *
-     * TODO: we should be able to not need the input partitions as input param in future but always reuse
-     *       the task's input partitions when we have fixed partitions -> tasks mapping
-     */
-    public StreamTask recycle(final Time time,
-                              final ThreadCache cache,
-                              final RecordCollector recordCollector,
-                              final Set<TopicPartition> inputPartitions,
-                              final Consumer<byte[], byte[]> mainConsumer) {
-        if (!inputPartitions.equals(this.inputPartitions)) {
-            log.warn("Detected unmatched input partitions for task {} when recycling it from active to standby", id);
-        }
-
-        stateMgr.transitionTaskType(TaskType.ACTIVE);
-
-        log.debug("Recycling standby task {} to active", id);
-
-        return new StreamTask(
-            id,
-            inputPartitions,
-            topology,
-            mainConsumer,
-            config,
-            streamsMetrics,
-            stateDirectory,
-            cache,
-            time,
-            stateMgr,
-            recordCollector,
-            processorContext,
-            logContext
-        );
+        log.info("Closed and recycled state, and converted type to active");
     }
 
     private void close(final boolean clean) {
@@ -342,6 +306,7 @@ public class StandbyTask extends AbstractTask implements Task {
         throw new IllegalStateException("Attempted to add records to task " + id() + " for invalid input partition " + partition);
     }
 
+    @SuppressWarnings("rawtypes")
     InternalProcessorContext processorContext() {
         return processorContext;
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
index 59e2aa0285..2f48cdb67f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
@@ -67,9 +67,7 @@ class StandbyTaskCreator {
         );
     }
 
-    // TODO: change return type to `StandbyTask`
     Collection<Task> createTasks(final Map<TaskId, Set<TopicPartition>> tasksToBeCreated) {
-        // TODO: change type to `StandbyTask`
         final List<Task> createdTasks = new ArrayList<>();
 
         for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions : tasksToBeCreated.entrySet()) {
@@ -109,11 +107,33 @@ class StandbyTaskCreator {
         return createdTasks;
     }
 
+    /*
+     * TODO: we pass in the new input partitions to validate if they still match;
+     *       in the future, when we have a fixed partitions -> tasks mapping,
+     *       we should always reuse the input partitions and hence need no validation
+     */
     StandbyTask createStandbyTaskFromActive(final StreamTask streamTask,
                                             final Set<TopicPartition> inputPartitions) {
-        final StandbyTask task = streamTask.recycle(inputPartitions);
+        if (!inputPartitions.equals(streamTask.inputPartitions)) {
+            log.warn("Detected unmatched input partitions for task {} when recycling it from active to standby", streamTask.id);
+        }
+
+        streamTask.prepareRecycle();
+        streamTask.stateMgr.transitionTaskType(Task.TaskType.STANDBY);
+
+        final StandbyTask task = new StandbyTask(
+            streamTask.id,
+            inputPartitions,
+            streamTask.topology,
+            streamTask.config,
+            streamsMetrics,
+            streamTask.stateMgr,
+            stateDirectory,
+            dummyCache,
+            streamTask.processorContext
+        );
 
-        log.trace("Created task {} with assigned partitions {}", task.id, inputPartitions);
+        log.trace("Created standby task {} from recycled active task with assigned partitions {}", task.id, inputPartitions);
         createTaskSensor.record();
         return task;
     }
@@ -122,7 +142,7 @@ class StandbyTaskCreator {
                                   final Set<TopicPartition> inputPartitions,
                                   final ProcessorTopology topology,
                                   final ProcessorStateManager stateManager,
-                                  final InternalProcessorContext context) {
+                                  final InternalProcessorContext<Object, Object> context) {
         final StandbyTask task = new StandbyTask(
             taskId,
             inputPartitions,
@@ -135,7 +155,7 @@ class StandbyTaskCreator {
             context
         );
 
-        log.trace("Created task {} with assigned partitions {}", taskId, inputPartitions);
+        log.trace("Created standby task {} with assigned partitions {}", taskId, inputPartitions);
         createTaskSensor.record();
         return task;
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index 0c7ea49e76..07c4494225 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -83,7 +83,6 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     private final Map<TopicPartition, Long> committedOffsets;
     private final Map<TopicPartition, Long> highWatermark;
     private final Set<TopicPartition> resetOffsetsForPartitions;
-    private Optional<Long> timeCurrentIdlingStarted;
     private final PunctuationQueue streamTimePunctuationQueue;
     private final PunctuationQueue systemTimePunctuationQueue;
     private final StreamsMetricsImpl streamsMetrics;
@@ -97,15 +96,16 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     private final Sensor bufferedRecordsSensor;
     private final Map<String, Sensor> e2eLatencySensors = new HashMap<>();
 
-    @SuppressWarnings("rawtypes")
-    private final InternalProcessorContext processorContext;
-
     private final RecordQueueCreator recordQueueCreator;
 
+    @SuppressWarnings("rawtypes")
+    protected final InternalProcessorContext processorContext;
+
     private StampedRecord record;
     private boolean commitNeeded = false;
     private boolean commitRequested = false;
     private boolean hasPendingTxCommit = false;
+    private Optional<Long> timeCurrentIdlingStarted;
 
     @SuppressWarnings("rawtypes")
     public StreamTask(final TaskId id,
@@ -544,7 +544,7 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     }
 
     @Override
-    public void closeCleanAndRecycleState() {
+    public void prepareRecycle() {
         validateClean();
         removeAllSensors();
         clearCommitStatuses();
@@ -566,49 +566,9 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         }
 
         closeTaskSensor.record();
-
         transitionTo(State.CLOSED);
 
-        log.info("Closed clean and recycled state");
-    }
-
-    /**
-     * Create a standby task from this active task without closing and re-initializing the state stores.
-     * The task should have been in suspended state when calling this function
-     *
-     * TODO: we should be able to not need the input partitions as input param in future but always reuse
-     *       the task's input partitions when we have fixed partitions -> tasks mapping
-     */
-    public StandbyTask recycle(final Set<TopicPartition> inputPartitions) {
-        if (state() != Task.State.CLOSED) {
-            throw new IllegalStateException("Attempted to convert an active task that's not closed: " + id);
-        }
-
-        if (!inputPartitions.equals(this.inputPartitions)) {
-            log.warn("Detected unmatched input partitions for task {} when recycling it from active to standby", id);
-        }
-
-        stateMgr.transitionTaskType(TaskType.STANDBY);
-
-        final ThreadCache dummyCache = new ThreadCache(
-            new LogContext(String.format("stream-thread [%s] ", Thread.currentThread().getName())),
-            0,
-            streamsMetrics
-        );
-
-        log.debug("Recycling active task {} to standby", id);
-
-        return new StandbyTask(
-            id,
-            inputPartitions,
-            topology,
-            config,
-            streamsMetrics,
-            stateMgr,
-            stateDirectory,
-            dummyCache,
-            processorContext
-        );
+        log.info("Closed and recycled state, and converted type to standby");
     }
 
     /**
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
index 5402c8e761..20f0343b2a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
@@ -155,10 +155,9 @@ public interface Task {
     void revive();
 
     /**
-     * Attempt a clean close but do not close the underlying state
+     * Close the task except the state, so that the states can be later recycled
      */
-    void closeCleanAndRecycleState();
-
+    void prepareRecycle();
 
     // runtime methods (using in RUNNING state)
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadata.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadata.java
index 48ea76b84b..bd1515c7bc 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadata.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadata.java
@@ -17,8 +17,11 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode;
 import org.apache.kafka.streams.processor.TaskId;
 
+import java.util.Collection;
+import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
@@ -36,12 +39,25 @@ public class TaskExecutionMetadata {
 
     private final boolean hasNamedTopologies;
     private final Set<String> pausedTopologies;
+    private final ProcessingMode processingMode;
+    private final Collection<Task> successfullyProcessed = new HashSet<>();
     // map of topologies experiencing errors/currently under backoff
     private final ConcurrentHashMap<String, NamedTopologyMetadata> topologyNameToErrorMetadata = new ConcurrentHashMap<>();
 
-    public TaskExecutionMetadata(final Set<String> allTopologyNames, final Set<String> pausedTopologies) {
+    public TaskExecutionMetadata(final Set<String> allTopologyNames,
+                                 final Set<String> pausedTopologies,
+                                 final ProcessingMode processingMode) {
         this.hasNamedTopologies = !(allTopologyNames.size() == 1 && allTopologyNames.contains(UNNAMED_TOPOLOGY));
         this.pausedTopologies = pausedTopologies;
+        this.processingMode = processingMode;
+    }
+
+    public boolean hasNamedTopologies() {
+        return hasNamedTopologies;
+    }
+
+    public ProcessingMode processingMode() {
+        return processingMode;
     }
 
     public boolean canProcessTask(final Task task, final long now) {
@@ -77,6 +93,22 @@ public class TaskExecutionMetadata {
         }
     }
 
+    Collection<Task> successfullyProcessed() {
+        return successfullyProcessed;
+    }
+
+    void addToSuccessfullyProcessed(final Task task) {
+        successfullyProcessed.add(task);
+    }
+
+    void removeTaskFromSuccessfullyProcessedBeforeClosing(final Task task) {
+        successfullyProcessed.remove(task);
+    }
+
+    void clearSuccessfullyProcessed() {
+        successfullyProcessed.clear();
+    }
+
     private class NamedTopologyMetadata {
         private final Logger log;
         private final Map<TaskId, Long> tasksToErrorTime = new ConcurrentHashMap<>();
@@ -109,7 +141,7 @@ public class TaskExecutionMetadata {
         }
 
         public synchronized void registerTaskError(final Task task, final Throwable t, final long now) {
-            log.info("Begin backoff for unhealthy task {} at t={}", task.id(), now);
+            log.info("Begin backoff for unhealthy task {} at t={} due to {}", task.id(), now, t.getClass().getName());
             tasksToErrorTime.put(task.id(), now);
         }
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java
index 5aa6eb1fe2..2827fe1e90 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java
@@ -26,7 +26,6 @@ import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
-import org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode;
 import org.apache.kafka.streams.processor.TaskId;
 
 import java.util.Collection;
@@ -47,21 +46,17 @@ import static org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMo
 public class TaskExecutor {
 
     private final Logger log;
-
-    private final boolean hasNamedTopologies;
-    private final ProcessingMode processingMode;
     private final Tasks tasks;
-    private final TaskExecutionMetadata taskExecutionMetadata;
+    private final TaskManager taskManager;
+    private final TaskExecutionMetadata executionMetadata;
 
     public TaskExecutor(final Tasks tasks,
-                        final TaskExecutionMetadata taskExecutionMetadata,
-                        final ProcessingMode processingMode,
-                        final boolean hasNamedTopologies,
+                        final TaskManager taskManager,
+                        final TaskExecutionMetadata executionMetadata,
                         final LogContext logContext) {
         this.tasks = tasks;
-        this.taskExecutionMetadata = taskExecutionMetadata;
-        this.processingMode = processingMode;
-        this.hasNamedTopologies = hasNamedTopologies;
+        this.taskManager = taskManager;
+        this.executionMetadata = executionMetadata;
         this.log = logContext.logger(getClass());
     }
 
@@ -76,13 +71,13 @@ public class TaskExecutor {
         for (final Task task : tasks.activeTasks()) {
             final long now = time.milliseconds();
             try {
-                if (taskExecutionMetadata.canProcessTask(task, now)) {
+                if (executionMetadata.canProcessTask(task, now)) {
                     lastProcessed = task;
                     totalProcessed += processTask(task, maxNumRecords, now, time);
                 }
             } catch (final Throwable t) {
-                taskExecutionMetadata.registerTaskError(task, t, now);
-                tasks.removeTaskFromSuccessfullyProcessedBeforeClosing(lastProcessed);
+                executionMetadata.registerTaskError(task, t, now);
+                executionMetadata.removeTaskFromSuccessfullyProcessedBeforeClosing(lastProcessed);
                 commitSuccessfullyProcessedTasks();
                 throw t;
             }
@@ -102,9 +97,9 @@ public class TaskExecutor {
                 processed++;
             }
             // TODO: enable regardless of whether using named topologies
-            if (processed > 0 && hasNamedTopologies && processingMode != EXACTLY_ONCE_V2) {
+            if (processed > 0 && executionMetadata.hasNamedTopologies() && executionMetadata.processingMode() != EXACTLY_ONCE_V2) {
                 log.trace("Successfully processed task {}", task.id());
-                tasks.addToSuccessfullyProcessed(task);
+                executionMetadata.addToSuccessfullyProcessed(task);
             }
         } catch (final TimeoutException timeoutException) {
             // TODO consolidate TimeoutException retries with general error handling
@@ -181,12 +176,12 @@ public class TaskExecutor {
         final Set<TaskId> corruptedTasks = new HashSet<>();
 
         if (!offsetsPerTask.isEmpty()) {
-            if (processingMode == EXACTLY_ONCE_ALPHA) {
+            if (executionMetadata.processingMode() == EXACTLY_ONCE_ALPHA) {
                 for (final Map.Entry<Task, Map<TopicPartition, OffsetAndMetadata>> taskToCommit : offsetsPerTask.entrySet()) {
                     final Task task = taskToCommit.getKey();
                     try {
-                        tasks.streamsProducerForTask(task.id())
-                            .commitTransaction(taskToCommit.getValue(), tasks.mainConsumer().groupMetadata());
+                        taskManager.streamsProducerForTask(task.id())
+                            .commitTransaction(taskToCommit.getValue(), taskManager.mainConsumer().groupMetadata());
                         updateTaskCommitMetadata(taskToCommit.getValue());
                     } catch (final TimeoutException timeoutException) {
                         log.error(
@@ -200,9 +195,9 @@ public class TaskExecutor {
                 final Map<TopicPartition, OffsetAndMetadata> allOffsets = offsetsPerTask.values().stream()
                     .flatMap(e -> e.entrySet().stream()).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
 
-                if (processingMode == EXACTLY_ONCE_V2) {
+                if (executionMetadata.processingMode() == EXACTLY_ONCE_V2) {
                     try {
-                        tasks.threadProducer().commitTransaction(allOffsets, tasks.mainConsumer().groupMetadata());
+                        taskManager.threadProducer().commitTransaction(allOffsets, taskManager.mainConsumer().groupMetadata());
                         updateTaskCommitMetadata(allOffsets);
                     } catch (final TimeoutException timeoutException) {
                         log.error(
@@ -220,7 +215,7 @@ public class TaskExecutor {
                     }
                 } else {
                     try {
-                        tasks.mainConsumer().commitSync(allOffsets);
+                        taskManager.mainConsumer().commitSync(allOffsets);
                         updateTaskCommitMetadata(allOffsets);
                     } catch (final CommitFailedException error) {
                         throw new TaskMigratedException("Consumer committing offsets failed, " +
@@ -261,13 +256,13 @@ public class TaskExecutor {
     }
 
     private void commitSuccessfullyProcessedTasks() {
-        if (!tasks.successfullyProcessed().isEmpty()) {
+        if (!executionMetadata.successfullyProcessed().isEmpty()) {
             log.info("Streams encountered an error when processing tasks." +
                 " Will commit all previously successfully processed tasks {}",
-                tasks.successfullyProcessed().stream().map(Task::id));
-            commitTasksAndMaybeUpdateCommittableOffsets(tasks.successfullyProcessed(), new HashMap<>());
+                executionMetadata.successfullyProcessed().stream().map(Task::id));
+            commitTasksAndMaybeUpdateCommittableOffsets(executionMetadata.successfullyProcessed(), new HashMap<>());
         }
-        tasks.clearSuccessfullyProcessed();
+        executionMetadata.clearSuccessfullyProcessed();
     }
 
     /**
@@ -278,7 +273,7 @@ public class TaskExecutor {
 
         for (final Task task : tasks.activeTasks()) {
             try {
-                if (taskExecutionMetadata.canPunctuateTask(task)) {
+                if (executionMetadata.canPunctuateTask(task)) {
                     if (task.maybePunctuateStreamTime()) {
                         punctuated++;
                     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index f827cb5f68..fe6282d5c9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -95,9 +95,10 @@ public class TaskManager {
     // includes assigned & initialized tasks and unassigned tasks we locked temporarily during rebalance
     private final Set<TaskId> lockedTaskDirectories = new HashSet<>();
 
+    private final ActiveTaskCreator activeTaskCreator;
+    private final StandbyTaskCreator standbyTaskCreator;
     private final StateUpdater stateUpdater;
 
-
     TaskManager(final Time time,
                 final ChangelogReader changelogReader,
                 final UUID processId,
@@ -109,12 +110,14 @@ public class TaskManager {
                 final StateDirectory stateDirectory,
                 final StreamsConfig streamsConfig) {
         this.time = time;
-        this.changelogReader = changelogReader;
         this.processId = processId;
         this.logPrefix = logPrefix;
-        this.topologyMetadata = topologyMetadata;
         this.adminClient = adminClient;
         this.stateDirectory = stateDirectory;
+        this.changelogReader = changelogReader;
+        this.topologyMetadata = topologyMetadata;
+        this.activeTaskCreator = activeTaskCreator;
+        this.standbyTaskCreator = standbyTaskCreator;
         this.processingMode = topologyMetadata.processingMode();
 
         final LogContext logContext = new LogContext(logPrefix);
@@ -128,23 +131,21 @@ public class TaskManager {
         } else {
             stateUpdater = null;
         }
-        this.tasks = new Tasks(logContext, activeTaskCreator, standbyTaskCreator, stateUpdater);
+        this.tasks = new Tasks(logContext);
         this.taskExecutor = new TaskExecutor(
             tasks,
+            this,
             topologyMetadata.taskExecutionMetadata(),
-            processingMode,
-            topologyMetadata.hasNamedTopologies(),
             logContext
         );
     }
 
     void setMainConsumer(final Consumer<byte[], byte[]> mainConsumer) {
         this.mainConsumer = mainConsumer;
-        tasks.setMainConsumer(mainConsumer);
     }
 
     public double totalProducerBlockedTime() {
-        return tasks.totalProducerBlockedTime();
+        return activeTaskCreator.totalProducerBlockedTime();
     }
 
     public UUID processId() {
@@ -155,6 +156,18 @@ public class TaskManager {
         return topologyMetadata;
     }
 
+    Consumer<byte[], byte[]> mainConsumer() {
+        return mainConsumer;
+    }
+
+    StreamsProducer streamsProducerForTask(final TaskId taskId) {
+        return activeTaskCreator.streamsProducerForTask(taskId);
+    }
+
+    StreamsProducer threadProducer() {
+        return activeTaskCreator.threadProducer();
+    }
+
     boolean isRebalanceInProgress() {
         return rebalanceInProgress;
     }
@@ -336,8 +349,8 @@ public class TaskManager {
             }
         }
 
-        tasks.addActivePendingTasks(pendingTasksToCreate(activeTasksToCreate));
-        tasks.addStandbyPendingTasks(pendingTasksToCreate(standbyTasksToCreate));
+        tasks.addPendingActiveTasks(pendingTasksToCreate(activeTasksToCreate));
+        tasks.addPendingStandbyTasks(pendingTasksToCreate(standbyTasksToCreate));
 
         // close and recycle those tasks
         closeAndRecycleTasks(
@@ -377,7 +390,21 @@ public class TaskManager {
             throw first.getValue();
         }
 
-        tasks.createTasks(activeTasksToCreate, standbyTasksToCreate);
+        createNewTasks(activeTasksToCreate, standbyTasksToCreate);
+    }
+
+    private void createNewTasks(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
+                                final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate) {
+        final Collection<Task> newActiveTasks = activeTaskCreator.createTasks(mainConsumer, activeTasksToCreate);
+        final Collection<Task> newStandbyTask = standbyTaskCreator.createTasks(standbyTasksToCreate);
+
+        if (stateUpdater == null) {
+            tasks.addNewActiveTasks(newActiveTasks);
+            tasks.addNewStandbyTasks(newStandbyTask);
+        } else {
+            tasks.addPendingTaskToRestore(newActiveTasks);
+            tasks.addPendingTaskToRestore(newStandbyTask);
+        }
     }
 
     private Map<TaskId, Set<TopicPartition>> pendingTasksToCreate(final Map<TaskId, Set<TopicPartition>> tasksToCreate) {
@@ -460,14 +487,14 @@ public class TaskManager {
         for (final Map.Entry<Task, Set<TopicPartition>> entry : tasksToRecycle.entrySet()) {
             final Task oldTask = entry.getKey();
             try {
-                oldTask.closeCleanAndRecycleState();
                 if (oldTask.isActive()) {
-                    tasks.convertActiveToStandby((StreamTask) oldTask, entry.getValue(), taskCloseExceptions);
+                    convertActiveToStandby((StreamTask) oldTask, entry.getValue());
                 } else {
-                    tasks.convertStandbyToActive((StandbyTask) oldTask, entry.getValue());
+                    convertStandbyToActive((StandbyTask) oldTask, entry.getValue());
                 }
             } catch (final RuntimeException e) {
-                final String uncleanMessage = String.format("Failed to recycle task %s cleanly. Attempting to close remaining tasks before re-throwing:", oldTask.id());
+                final String uncleanMessage = String.format("Failed to recycle task %s cleanly. " +
+                    "Attempting to close remaining tasks before re-throwing:", oldTask.id());
                 log.error(uncleanMessage, e);
                 taskCloseExceptions.putIfAbsent(oldTask.id(), e);
                 tasksToCloseDirty.add(oldTask);
@@ -480,6 +507,19 @@ public class TaskManager {
         }
     }
 
+    private void convertActiveToStandby(final StreamTask activeTask,
+                                        final Set<TopicPartition> partitions) {
+        final StandbyTask standbyTask = standbyTaskCreator.createStandbyTaskFromActive(activeTask, partitions);
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(activeTask.id());
+        tasks.replaceActiveWithStandby(standbyTask);
+    }
+
+    private void convertStandbyToActive(final StandbyTask standbyTask,
+                                        final Set<TopicPartition> partitions) {
+        final StreamTask activeTask = activeTaskCreator.createActiveTaskFromStandby(standbyTask, partitions, mainConsumer);
+        tasks.replaceStandbyWithActive(activeTask);
+    }
+
     /**
      * Tries to initialize any new or still-uninitialized tasks, then checks if they can/have completed restoration.
      *
@@ -490,53 +530,61 @@ public class TaskManager {
     boolean tryToCompleteRestoration(final long now, final java.util.function.Consumer<Set<TopicPartition>> offsetResetter) {
         boolean allRunning = true;
 
-        final List<Task> activeTasks = new LinkedList<>();
-        for (final Task task : tasks.allTasks()) {
-            try {
-                task.initializeIfNeeded();
-                task.clearTaskTimeout();
-            } catch (final LockException lockException) {
-                // it is possible that if there are multiple threads within the instance that one thread
-                // trying to grab the task from the other, while the other has not released the lock since
-                // it did not participate in the rebalance. In this case we can just retry in the next iteration
-                log.debug("Could not initialize task {} since: {}; will retry", task.id(), lockException.getMessage());
-                allRunning = false;
-            } catch (final TimeoutException timeoutException) {
-                task.maybeInitTaskTimeoutOrThrow(now, timeoutException);
-                allRunning = false;
-            }
+        if (stateUpdater == null) {
+            final List<Task> activeTasks = new LinkedList<>();
+            for (final Task task : tasks.allTasks()) {
+                try {
+                    task.initializeIfNeeded();
+                    task.clearTaskTimeout();
+                } catch (final LockException lockException) {
+                    // it is possible that if there are multiple threads within the instance that one thread
+                    // trying to grab the task from the other, while the other has not released the lock since
+                    // it did not participate in the rebalance. In this case we can just retry in the next iteration
+                    log.debug("Could not initialize task {} since: {}; will retry", task.id(), lockException.getMessage());
+                    allRunning = false;
+                } catch (final TimeoutException timeoutException) {
+                    task.maybeInitTaskTimeoutOrThrow(now, timeoutException);
+                    allRunning = false;
+                }
 
-            if (task.isActive()) {
-                activeTasks.add(task);
+                if (task.isActive()) {
+                    activeTasks.add(task);
+                }
             }
-        }
-
-        if (allRunning && !activeTasks.isEmpty()) {
-
-            final Set<TopicPartition> restored = changelogReader.completedChangelogs();
-
-            for (final Task task : activeTasks) {
-                if (restored.containsAll(task.changelogPartitions())) {
-                    try {
-                        task.completeRestoration(offsetResetter);
-                        task.clearTaskTimeout();
-                    } catch (final TimeoutException timeoutException) {
-                        task.maybeInitTaskTimeoutOrThrow(now, timeoutException);
-                        log.debug(
-                            String.format(
-                                "Could not complete restoration for %s due to the following exception; will retry",
-                                task.id()),
-                            timeoutException
-                        );
 
+            if (allRunning && !activeTasks.isEmpty()) {
+
+                final Set<TopicPartition> restored = changelogReader.completedChangelogs();
+
+                for (final Task task : activeTasks) {
+                    if (restored.containsAll(task.changelogPartitions())) {
+                        try {
+                            task.completeRestoration(offsetResetter);
+                            task.clearTaskTimeout();
+                        } catch (final TimeoutException timeoutException) {
+                            task.maybeInitTaskTimeoutOrThrow(now, timeoutException);
+                            log.debug(
+                                String.format(
+                                    "Could not complete restoration for %s due to the following exception; will retry",
+                                    task.id()),
+                                timeoutException
+                            );
+
+                            allRunning = false;
+                        }
+                    } else {
+                        // we found a restoring task that isn't done restoring, which is evidence that
+                        // not all tasks are running
                         allRunning = false;
                     }
-                } else {
-                    // we found a restoring task that isn't done restoring, which is evidence that
-                    // not all tasks are running
-                    allRunning = false;
                 }
             }
+        } else {
+            for (final Task task : tasks.drainPendingTaskToRestore()) {
+                stateUpdater.add(task);
+            }
+
+            // TODO: should add logic for checking and resuming when all active tasks have been restored
         }
 
         if (allRunning) {
@@ -698,7 +746,7 @@ public class TaskManager {
         }
 
         if (processingMode == EXACTLY_ONCE_V2) {
-            tasks.reInitializeThreadProducer();
+            activeTaskCreator.reInitializeThreadProducer();
         }
     }
 
@@ -843,6 +891,10 @@ public class TaskManager {
 
         try {
             tasks.removeTask(task);
+
+            if (task.isActive()) {
+                activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
+            }
         } catch (final RuntimeException swallow) {
             log.error("Error removing dirty task {}: {}", task.id(), swallow.getMessage());
         }
@@ -852,6 +904,10 @@ public class TaskManager {
         task.closeClean();
         try {
             tasks.removeTask(task);
+
+            if (task.isActive()) {
+                activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
+            }
         } catch (final RuntimeException e) {
             log.error("Error removing active task {}: {}", task.id(), e.getMessage());
             return e;
@@ -875,7 +931,7 @@ public class TaskManager {
 
         executeAndMaybeSwallow(
             clean,
-            tasks::closeThreadProducerIfNeeded,
+            activeTaskCreator::closeThreadProducerIfNeeded,
             e -> firstException.compareAndSet(null, e),
             e -> log.warn("Ignoring an exception while closing thread producer.", e)
         );
@@ -1185,7 +1241,7 @@ public class TaskManager {
      */
     void handleTopologyUpdates() {
         topologyMetadata.executeTopologyUpdatesAndBumpThreadVersion(
-            tasks::createPendingTasks,
+            this::createPendingTasks,
             this::maybeCloseTasksFromRemovedTopologies
         );
 
@@ -1213,7 +1269,6 @@ public class TaskManager {
 
             final Set<Task> allTasksToRemove = union(HashSet::new, activeTasksToRemove, standbyTasksToRemove);
             closeAndCleanUpTasks(activeTasksToRemove, standbyTasksToRemove, true);
-            allTasksToRemove.forEach(tasks::removeTask);
             releaseLockedDirectoriesForTasks(allTasksToRemove.stream().map(Task::id).collect(Collectors.toSet()));
         } catch (final Exception e) {
             // TODO KAFKA-12648: for now just swallow the exception to avoid interfering with the other topologies
@@ -1223,6 +1278,13 @@ public class TaskManager {
         }
     }
 
+    void createPendingTasks(final Set<String> currentNamedTopologies) {
+        final Map<TaskId, Set<TopicPartition>> activeTasksToCreate = tasks.pendingActiveTasksForTopologies(currentNamedTopologies);
+        final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate = tasks.pendingStandbyTasksForTopologies(currentNamedTopologies);
+
+        createNewTasks(activeTasksToCreate, standbyTasksToCreate);
+    }
+
     /**
      * @throws TaskMigratedException if the task producer got fenced (EOS only)
      * @throws StreamsException      if any task threw an exception while processing
@@ -1298,11 +1360,11 @@ public class TaskManager {
     }
 
     Map<MetricName, Metric> producerMetrics() {
-        return tasks.producerMetrics();
+        return activeTaskCreator.producerMetrics();
     }
 
     Set<String> producerClientIds() {
-        return tasks.producerClientIds();
+        return activeTaskCreator.producerClientIds();
     }
 
     Set<TaskId> lockedTaskDirectories() {
@@ -1356,4 +1418,8 @@ public class TaskManager {
     void addTask(final Task task) {
         tasks.addTask(task);
     }
+
+    StateUpdater stateUpdater() {
+        return stateUpdater;
+    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
index 2904eca068..d7c2cc5037 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
@@ -16,9 +16,6 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
-import org.apache.kafka.clients.consumer.Consumer;
-import org.apache.kafka.common.Metric;
-import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.processor.TaskId;
@@ -61,118 +58,86 @@ class Tasks {
     //
     // When that occurs we stash these pending tasks until either they are finally clear to be created,
     // or they are revoked from a new assignment.
-    private final Map<TaskId, Set<TopicPartition>> pendingActiveTasks = new HashMap<>();
-    private final Map<TaskId, Set<TopicPartition>> pendingStandbyTasks = new HashMap<>();
+    private final Map<TaskId, Set<TopicPartition>> pendingActiveTasksToCreate = new HashMap<>();
+    private final Map<TaskId, Set<TopicPartition>> pendingStandbyTasksToCreate = new HashMap<>();
+
+    private final Set<Task> pendingTasksToRestore = new HashSet<>();
 
     // TODO: change type to `StreamTask`
     private final Map<TopicPartition, Task> activeTasksPerPartition = new HashMap<>();
 
-    private final Collection<Task> successfullyProcessed = new HashSet<>();
-
-    private final ActiveTaskCreator activeTaskCreator;
-    private final StandbyTaskCreator standbyTaskCreator;
-    private final StateUpdater stateUpdater;
-
-    private Consumer<byte[], byte[]> mainConsumer;
-
-    Tasks(final LogContext logContext,
-          final ActiveTaskCreator activeTaskCreator,
-          final StandbyTaskCreator standbyTaskCreator,
-          final StateUpdater stateUpdater) {
-
+    Tasks(final LogContext logContext) {
         this.log = logContext.logger(getClass());
-        this.activeTaskCreator = activeTaskCreator;
-        this.standbyTaskCreator = standbyTaskCreator;
-        this.stateUpdater = stateUpdater;
-    }
-
-    void setMainConsumer(final Consumer<byte[], byte[]> mainConsumer) {
-        this.mainConsumer = mainConsumer;
     }
 
     void purgePendingTasks(final Set<TaskId> assignedActiveTasks, final Set<TaskId> assignedStandbyTasks) {
-        pendingActiveTasks.keySet().retainAll(assignedActiveTasks);
-        pendingStandbyTasks.keySet().retainAll(assignedStandbyTasks);
+        pendingActiveTasksToCreate.keySet().retainAll(assignedActiveTasks);
+        pendingStandbyTasksToCreate.keySet().retainAll(assignedStandbyTasks);
     }
 
-    void addActivePendingTasks(final Map<TaskId, Set<TopicPartition>> pendingTasks) {
-        pendingActiveTasks.putAll(pendingTasks);
+    void addPendingActiveTasks(final Map<TaskId, Set<TopicPartition>> pendingTasks) {
+        pendingActiveTasksToCreate.putAll(pendingTasks);
     }
 
-    void addStandbyPendingTasks(final Map<TaskId, Set<TopicPartition>> pendingTasks) {
-        pendingStandbyTasks.putAll(pendingTasks);
+    void addPendingStandbyTasks(final Map<TaskId, Set<TopicPartition>> pendingTasks) {
+        pendingStandbyTasksToCreate.putAll(pendingTasks);
     }
 
-    void createPendingTasks(final Set<String> currentNamedTopologies) {
-        createTasks(
-            pendingActiveTasksForTopologies(currentNamedTopologies),
-            pendingStandbyTasksForTopologies(currentNamedTopologies)
-        );
+    void addPendingTaskToRestore(final Collection<Task> tasks) {
+        pendingTasksToRestore.addAll(tasks);
     }
 
-    private Map<TaskId, Set<TopicPartition>> pendingActiveTasksForTopologies(final Set<String> currentTopologies) {
-        return filterMap(pendingActiveTasks, t -> currentTopologies.contains(t.getKey().topologyName()));
+    Set<Task> drainPendingTaskToRestore() {
+        final Set<Task> result = new HashSet<>(pendingTasksToRestore);
+        pendingTasksToRestore.clear();
+        return result;
     }
 
-    private Map<TaskId, Set<TopicPartition>> pendingStandbyTasksForTopologies(final Set<String> currentTopologies) {
-        return filterMap(pendingStandbyTasks, t -> currentTopologies.contains(t.getKey().topologyName()));
+    Map<TaskId, Set<TopicPartition>> pendingActiveTasksForTopologies(final Set<String> currentTopologies) {
+        return filterMap(pendingActiveTasksToCreate, t -> currentTopologies.contains(t.getKey().topologyName()));
     }
 
-    void createTasks(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
-                     final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate) {
-        createActiveTasks(activeTasksToCreate);
-        createStandbyTasks(standbyTasksToCreate);
+    Map<TaskId, Set<TopicPartition>> pendingStandbyTasksForTopologies(final Set<String> currentTopologies) {
+        return filterMap(pendingStandbyTasksToCreate, t -> currentTopologies.contains(t.getKey().topologyName()));
     }
 
-    private void createActiveTasks(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate) {
-        for (final Map.Entry<TaskId, Set<TopicPartition>> taskToBeCreated : activeTasksToCreate.entrySet()) {
-            final TaskId taskId = taskToBeCreated.getKey();
+    void addNewActiveTasks(final Collection<Task> newTasks) {
+        if (!newTasks.isEmpty()) {
+            for (final Task activeTask : newTasks) {
+                final TaskId taskId = activeTask.id();
 
-            if (activeTasksPerId.containsKey(taskId)) {
-                throw new IllegalStateException("Attempted to create an active task that we already own: " + taskId);
-            }
+                if (activeTasksPerId.containsKey(taskId)) {
+                    throw new IllegalStateException("Attempted to create an active task that we already own: " + taskId);
+                }
 
-            if (pendingStandbyTasks.containsKey(taskId)) {
-                throw new IllegalStateException("Attempted to create an active task while we already own its standby: " + taskId);
-            }
-        }
+                if (standbyTasksPerId.containsKey(taskId)) {
+                    throw new IllegalStateException("Attempted to create an active task while we already own its standby: " + taskId);
+                }
 
-        if (!activeTasksToCreate.isEmpty()) {
-            for (final Task activeTask : activeTaskCreator.createTasks(mainConsumer, activeTasksToCreate)) {
-                if (stateUpdater != null) {
-                    stateUpdater.add(activeTask);
-                } else {
-                    activeTasksPerId.put(activeTask.id(), activeTask);
-                    pendingActiveTasks.remove(activeTask.id());
-                    for (final TopicPartition topicPartition : activeTask.inputPartitions()) {
-                        activeTasksPerPartition.put(topicPartition, activeTask);
-                    }
+                activeTasksPerId.put(activeTask.id(), activeTask);
+                pendingActiveTasksToCreate.remove(activeTask.id());
+                for (final TopicPartition topicPartition : activeTask.inputPartitions()) {
+                    activeTasksPerPartition.put(topicPartition, activeTask);
                 }
             }
         }
     }
 
-    private void createStandbyTasks(final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate) {
-        for (final Map.Entry<TaskId, Set<TopicPartition>> taskToBeCreated : standbyTasksToCreate.entrySet()) {
-            final TaskId taskId = taskToBeCreated.getKey();
+    void addNewStandbyTasks(final Collection<Task> newTasks) {
+        if (!newTasks.isEmpty()) {
+            for (final Task standbyTask : newTasks) {
+                final TaskId taskId = standbyTask.id();
 
-            if (standbyTasksPerId.containsKey(taskId)) {
-                throw new IllegalStateException("Attempted to create an active task that we already own: " + taskId);
-            }
-
-            if (pendingActiveTasks.containsKey(taskId)) {
-                throw new IllegalStateException("Attempted to create an active task while we already own its standby: " + taskId);
-            }
-        }
+                if (standbyTasksPerId.containsKey(taskId)) {
+                    throw new IllegalStateException("Attempted to create a standby task that we already own: " + taskId);
+                }
 
-        if (!standbyTasksToCreate.isEmpty()) {
-            for (final Task standbyTask : standbyTaskCreator.createTasks(standbyTasksToCreate)) {
-                if (stateUpdater != null) {
-                    stateUpdater.add(standbyTask);
-                } else {
-                    standbyTasksPerId.put(standbyTask.id(), standbyTask);
-                    pendingActiveTasks.remove(standbyTask.id());
+                if (activeTasksPerId.containsKey(taskId)) {
+                    throw new IllegalStateException("Attempted to create a standby task while we already own its active: " + taskId);
                 }
+
+                standbyTasksPerId.put(standbyTask.id(), standbyTask);
+                pendingStandbyTasksToCreate.remove(standbyTask.id());
             }
         }
     }
@@ -188,44 +153,32 @@ class Tasks {
             if (activeTasksPerId.remove(taskId) == null) {
                 throw new IllegalArgumentException("Attempted to remove an active task that is not owned: " + taskId);
             }
-            activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId);
             removePartitionsForActiveTask(taskId);
-            pendingActiveTasks.remove(taskId);
+            pendingActiveTasksToCreate.remove(taskId);
         } else {
             if (standbyTasksPerId.remove(taskId) == null) {
                 throw new IllegalArgumentException("Attempted to remove a standby task that is not owned: " + taskId);
             }
-            pendingStandbyTasks.remove(taskId);
+            pendingStandbyTasksToCreate.remove(taskId);
         }
     }
 
-    void convertActiveToStandby(final StreamTask activeTask,
-                                final Set<TopicPartition> partitions,
-                                final Map<TaskId, RuntimeException> taskCloseExceptions) {
-        final TaskId taskId = activeTask.id();
+    void replaceActiveWithStandby(final StandbyTask standbyTask) {
+        final TaskId taskId = standbyTask.id();
         if (activeTasksPerId.remove(taskId) == null) {
-            throw new IllegalStateException("Attempted to convert unknown active task to standby task: " + taskId);
+            throw new IllegalStateException("Attempted to replace unknown active task with standby task: " + taskId);
         }
         removePartitionsForActiveTask(taskId);
 
-        try {
-            activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId);
-        } catch (final RuntimeException e) {
-            taskCloseExceptions.putIfAbsent(taskId, e);
-        }
-
-        final StandbyTask standbyTask = standbyTaskCreator.createStandbyTaskFromActive(activeTask, partitions);
         standbyTasksPerId.put(standbyTask.id(), standbyTask);
     }
 
-    void convertStandbyToActive(final StandbyTask standbyTask,
-                                final Set<TopicPartition> partitions) {
-        final TaskId taskId = standbyTask.id();
+    void replaceStandbyWithActive(final StreamTask activeTask) {
+        final TaskId taskId = activeTask.id();
         if (standbyTasksPerId.remove(taskId) == null) {
             throw new IllegalStateException("Attempted to convert unknown standby task to stream task: " + taskId);
         }
 
-        final StreamTask activeTask = activeTaskCreator.createActiveTaskFromStandby(standbyTask, partitions, mainConsumer);
         activeTasksPerId.put(activeTask.id(), activeTask);
         for (final TopicPartition topicPartition : activeTask.inputPartitions()) {
             activeTasksPerPartition.put(topicPartition, activeTask);
@@ -249,14 +202,6 @@ class Tasks {
         return requiresUpdate;
     }
 
-    void reInitializeThreadProducer() {
-        activeTaskCreator.reInitializeThreadProducer();
-    }
-
-    void closeThreadProducerIfNeeded() {
-        activeTaskCreator.closeThreadProducerIfNeeded();
-    }
-
     private void removePartitionsForActiveTask(final TaskId taskId) {
         final Set<TopicPartition> toBeRemoved = activeTasksPerPartition.entrySet().stream()
             .filter(e -> e.getValue().id().equals(taskId))
@@ -331,46 +276,6 @@ class Tasks {
         return getTask(taskId) != null;
     }
 
-    StreamsProducer streamsProducerForTask(final TaskId taskId) {
-        return activeTaskCreator.streamsProducerForTask(taskId);
-    }
-
-    StreamsProducer threadProducer() {
-        return activeTaskCreator.threadProducer();
-    }
-
-    Map<MetricName, Metric> producerMetrics() {
-        return activeTaskCreator.producerMetrics();
-    }
-
-    Set<String> producerClientIds() {
-        return activeTaskCreator.producerClientIds();
-    }
-
-    Consumer<byte[], byte[]> mainConsumer() {
-        return mainConsumer;
-    }
-
-    Collection<Task> successfullyProcessed() {
-        return successfullyProcessed;
-    }
-
-    void addToSuccessfullyProcessed(final Task task) {
-        successfullyProcessed.add(task);
-    }
-
-    void removeTaskFromSuccessfullyProcessedBeforeClosing(final Task task) {
-        successfullyProcessed.remove(task);
-    }
-
-    void clearSuccessfullyProcessed() {
-        successfullyProcessed.clear();
-    }
-
-    double totalProducerBlockedTime() {
-        return activeTaskCreator.totalProducerBlockedTime();
-    }
-
     // for testing only
     void addTask(final Task task) {
         if (task.isActive()) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TopologyMetadata.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TopologyMetadata.java
index e3c057b9a0..0e5d68cd75 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TopologyMetadata.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TopologyMetadata.java
@@ -113,7 +113,7 @@ public class TopologyMetadata {
         } else {
             builders.put(UNNAMED_TOPOLOGY, builder);
         }
-        this.taskExecutionMetadata = new TaskExecutionMetadata(builders.keySet(), pausedTopologies);
+        this.taskExecutionMetadata = new TaskExecutionMetadata(builders.keySet(), pausedTopologies, processingMode);
     }
 
     public TopologyMetadata(final ConcurrentNavigableMap<String, InternalTopologyBuilder> builders,
@@ -128,7 +128,7 @@ public class TopologyMetadata {
         if (builders.isEmpty()) {
             log.info("Created an empty KafkaStreams app with no topology");
         }
-        this.taskExecutionMetadata = new TaskExecutionMetadata(builders.keySet(), pausedTopologies);
+        this.taskExecutionMetadata = new TaskExecutionMetadata(builders.keySet(), pausedTopologies, processingMode);
     }
 
     // Need to (re)set the log here to pick up the `processId` part of the clientId in the prefix
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index dd6bb6cafc..02d742d8ab 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -571,13 +571,13 @@ public class StandbyTaskTest {
         EasyMock.replay(stateManager);
 
         task = createStandbyTask();
-        assertThrows(IllegalStateException.class, () -> task.closeCleanAndRecycleState()); // CREATED
+        assertThrows(IllegalStateException.class, () -> task.prepareRecycle()); // CREATED
 
         task.initializeIfNeeded();
-        assertThrows(IllegalStateException.class, () -> task.closeCleanAndRecycleState()); // RUNNING
+        assertThrows(IllegalStateException.class, () -> task.prepareRecycle()); // RUNNING
 
         task.suspend();
-        task.closeCleanAndRecycleState(); // SUSPENDED
+        task.prepareRecycle(); // SUSPENDED
 
         // Currently, there are no metrics registered for standby tasks.
         // This is a regression test so that, if we add some, we will be sure to deregister them.
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index 8f18c01e3d..26efc1126b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -2187,7 +2187,7 @@ public class StreamTaskTest {
 
         task.suspend();
         assertThat(getTaskMetrics(), not(empty()));
-        task.closeCleanAndRecycleState();
+        task.prepareRecycle();
         assertThat(getTaskMetrics(), empty());
     }
 
@@ -2266,7 +2266,7 @@ public class StreamTaskTest {
         task.process(0L);
         assertTrue(task.commitNeeded());
 
-        assertThrows(TaskMigratedException.class, () -> task.closeCleanAndRecycleState());
+        assertThrows(TaskMigratedException.class, () -> task.prepareRecycle());
     }
 
     @Test
@@ -2277,16 +2277,16 @@ public class StreamTaskTest {
         EasyMock.replay(stateManager, recordCollector);
 
         task = createStatefulTask(createConfig("100"), true);
-        assertThrows(IllegalStateException.class, () -> task.closeCleanAndRecycleState()); // CREATED
+        assertThrows(IllegalStateException.class, () -> task.prepareRecycle()); // CREATED
 
         task.initializeIfNeeded();
-        assertThrows(IllegalStateException.class, () -> task.closeCleanAndRecycleState()); // RESTORING
+        assertThrows(IllegalStateException.class, () -> task.prepareRecycle()); // RESTORING
 
         task.completeRestoration(noOpResetter -> { });
-        assertThrows(IllegalStateException.class, () -> task.closeCleanAndRecycleState()); // RUNNING
+        assertThrows(IllegalStateException.class, () -> task.prepareRecycle()); // RUNNING
 
         task.suspend();
-        task.closeCleanAndRecycleState(); // SUSPENDED
+        task.prepareRecycle(); // SUSPENDED
 
         EasyMock.verify(stateManager, recordCollector);
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index cff917c64e..d3bbe9fe4f 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -841,6 +841,7 @@ public class StreamThreadTest {
         expect(activeTaskCreator.producerClientIds()).andStubReturn(Collections.singleton("producerClientId"));
 
         final StandbyTaskCreator standbyTaskCreator = mock(StandbyTaskCreator.class);
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
 
         EasyMock.replay(consumer, consumerGroupMetadata, task, activeTaskCreator, standbyTaskCreator);
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadataTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadataTest.java
index 7fcb5e30b4..127999ad51 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadataTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadataTest.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode;
 import org.apache.kafka.streams.processor.TaskId;
 import org.junit.Assert;
 import org.junit.Test;
@@ -41,7 +42,7 @@ public class TaskExecutionMetadataTest {
         final Set<String> topologies = Collections.singleton(UNNAMED_TOPOLOGY);
         final Set<String> pausedTopologies = new HashSet<>();
 
-        final TaskExecutionMetadata metadata = new TaskExecutionMetadata(topologies, pausedTopologies);
+        final TaskExecutionMetadata metadata = new TaskExecutionMetadata(topologies, pausedTopologies, ProcessingMode.AT_LEAST_ONCE);
 
         final Task mockTask = createMockTask(UNNAMED_TOPOLOGY);
 
@@ -55,7 +56,7 @@ public class TaskExecutionMetadataTest {
     @Test
     public void testNamedTopologiesCanBePausedIndependently() {
         final Set<String> pausedTopologies = new HashSet<>();
-        final TaskExecutionMetadata metadata = new TaskExecutionMetadata(NAMED_TOPOLOGIES, pausedTopologies);
+        final TaskExecutionMetadata metadata = new TaskExecutionMetadata(NAMED_TOPOLOGIES, pausedTopologies, ProcessingMode.AT_LEAST_ONCE);
 
         final Task mockTask1 = createMockTask(TOPOLOGY1);
         final Task mockTask2 = createMockTask(TOPOLOGY2);
@@ -77,8 +78,7 @@ public class TaskExecutionMetadataTest {
         final Set<String> pausedTopologies = new HashSet<>();
         pausedTopologies.add(TOPOLOGY1);
 
-        final TaskExecutionMetadata metadata = new TaskExecutionMetadata(NAMED_TOPOLOGIES,
-            pausedTopologies);
+        final TaskExecutionMetadata metadata = new TaskExecutionMetadata(NAMED_TOPOLOGIES, pausedTopologies, ProcessingMode.AT_LEAST_ONCE);
 
         final Task mockTask1 = createMockTask(TOPOLOGY1);
         final Task mockTask2 = createMockTask(TOPOLOGY2);
@@ -95,8 +95,7 @@ public class TaskExecutionMetadataTest {
     public void testNamedTopologiesCanBackoff() {
         final Set<String> pausedTopologies = new HashSet<>();
 
-        final TaskExecutionMetadata metadata = new TaskExecutionMetadata(NAMED_TOPOLOGIES,
-            pausedTopologies);
+        final TaskExecutionMetadata metadata = new TaskExecutionMetadata(NAMED_TOPOLOGIES, pausedTopologies, ProcessingMode.AT_LEAST_ONCE);
 
         final Task mockTask1 = createMockTask(TOPOLOGY1);
         final Task mockTask2 = createMockTask(TOPOLOGY2);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutorTest.java
index 88ee70e57e..131a68044b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutorTest.java
@@ -18,7 +18,6 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode;
 import org.junit.Test;
 
 import static org.mockito.Mockito.mock;
@@ -28,10 +27,10 @@ public class TaskExecutorTest {
     @Test
     public void testPunctuateWithPause() {
         final Tasks tasks = mock(Tasks.class);
+        final TaskManager taskManager = mock(TaskManager.class);
         final TaskExecutionMetadata metadata = mock(TaskExecutionMetadata.class);
 
-        final TaskExecutor taskExecutor =
-            new TaskExecutor(tasks, metadata, ProcessingMode.AT_LEAST_ONCE, false, new LogContext());
+        final TaskExecutor taskExecutor = new TaskExecutor(tasks, taskManager, metadata, new LogContext());
 
         taskExecutor.punctuate();
         verify(tasks).activeTasks();
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index e943938fec..35e0283d86 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -75,6 +75,7 @@ import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Properties;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -99,7 +100,6 @@ import static org.easymock.EasyMock.anyString;
 import static org.easymock.EasyMock.eq;
 import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.expectLastCall;
-import static org.easymock.EasyMock.mock;
 import static org.easymock.EasyMock.replay;
 import static org.easymock.EasyMock.reset;
 import static org.easymock.EasyMock.resetToStrict;
@@ -112,10 +112,15 @@ import static org.hamcrest.Matchers.is;
 import static org.hamcrest.core.IsEqual.equalTo;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.mock;
 
 @RunWith(EasyMockRunner.class)
 public class TaskManagerTest {
@@ -204,17 +209,63 @@ public class TaskManagerTest {
         expect(topologyBuilder.nodeToSourceTopics()).andStubReturn(emptyMap());
     }
 
+    @Test
+    public void shouldCreateStateUpdaterAndUpdateOnTaskInitialization() {
+        final StreamTask task00 = mock(StreamTask.class);
+        final StandbyTask task01 = mock(StandbyTask.class);
+        when(task00.id()).thenReturn(taskId00);
+        when(task01.id()).thenReturn(taskId01);
+        when(task00.inputPartitions()).thenReturn(taskId00Partitions);
+        when(task01.inputPartitions()).thenReturn(taskId01Partitions);
+        when(task00.isActive()).thenReturn(true);
+        when(task01.isActive()).thenReturn(false);
+        when(task00.state()).thenReturn(State.RESTORING);
+        when(task01.state()).thenReturn(State.RUNNING);
+        expect(changeLogReader.completedChangelogs()).andReturn(emptySet()).anyTimes();
+        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
+        consumer.resume(anyObject());
+        expectLastCall().anyTimes();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andStubReturn(singletonList(task01));
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
+
+        final Properties prop = getStreamsConfig("test-app");
+        prop.put(StreamsConfig.InternalConfig.STATE_UPDATER_ENABLED, true);
+        taskManager = new TaskManager(
+            time,
+            changeLogReader,
+            UUID.randomUUID(),
+            "taskManagerTest",
+            activeTaskCreator,
+            standbyTaskCreator,
+            topologyMetadata,
+            adminClient,
+            stateDirectory,
+            new StreamsConfig(prop)
+        );
+        taskManager.setMainConsumer(consumer);
+
+        assertNotNull(taskManager.stateUpdater());
+
+        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
+        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter -> { });
+
+        assertEquals(Collections.singleton(task00), taskManager.stateUpdater().getActiveTasks());
+        assertEquals(Collections.singleton(task01), taskManager.stateUpdater().getStandbyTasks());
+    }
+
     @Test
     public void shouldIdempotentlyUpdateSubscriptionFromActiveAssignment() {
         final TopicPartition newTopicPartition = new TopicPartition("topic2", 1);
         final Map<TaskId, Set<TopicPartition>> assignment = mkMap(mkEntry(taskId01, mkSet(t1p1, newTopicPartition)));
 
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(emptyList());
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
 
         topologyBuilder.addSubscribedTopicsFromAssignment(eq(asList(t1p1, newTopicPartition)), anyString());
         expectLastCall();
 
-        replay(activeTaskCreator, topologyBuilder);
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder);
 
         taskManager.handleAssignment(assignment, emptyMap());
 
@@ -381,7 +432,8 @@ public class TaskManagerTest {
         taskManager.handleRebalanceStart(singleton("topic"));
         final StateMachineTask uninitializedTask = new StateMachineTask(taskId00, taskId00Partitions, true);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singleton(uninitializedTask));
-        replay(activeTaskCreator);
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
 
         assertThat(uninitializedTask.state(), is(State.CREATED));
@@ -406,7 +458,8 @@ public class TaskManagerTest {
 
         taskManager.handleRebalanceStart(singleton("topic"));
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singleton(closedTask));
-        replay(activeTaskCreator);
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator);
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
 
         closedTask.suspend();
@@ -658,10 +711,11 @@ public class TaskManagerTest {
         // `handleAssignment`
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
         expect(consumer.assignment()).andReturn(taskId00Partitions);
-        replay(activeTaskCreator, topologyBuilder, consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), tp -> assertThat(tp, is(empty()))), is(true));
@@ -697,10 +751,11 @@ public class TaskManagerTest {
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
         expect(consumer.assignment()).andReturn(taskId00Partitions);
-        replay(activeTaskCreator, topologyBuilder, consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), tp -> assertThat(tp, is(empty()))), is(true));
@@ -733,6 +788,8 @@ public class TaskManagerTest {
         // `handleAssignment`
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment)))
             .andStubReturn(asList(corruptedTask, nonCorruptedTask));
+        expect(standbyTaskCreator.createTasks(anyObject()))
+            .andStubReturn(Collections.emptySet());
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
         expectRestoreToBeCompleted(consumer, changeLogReader);
@@ -740,7 +797,7 @@ public class TaskManagerTest {
         // check that we should not commit empty map either
         consumer.commitSync(eq(emptyMap()));
         expectLastCall().andStubThrow(new AssertionError("should not invoke commitSync when offset map is empty"));
-        replay(activeTaskCreator, topologyBuilder, consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
 
         taskManager.handleAssignment(assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), tp -> assertThat(tp, is(empty()))), is(true));
@@ -775,10 +832,12 @@ public class TaskManagerTest {
         // `handleAssignment`
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment)))
             .andStubReturn(asList(corruptedTask, nonRunningNonCorruptedTask));
+        expect(standbyTaskCreator.createTasks(anyObject()))
+            .andStubReturn(Collections.emptySet());
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
         expect(consumer.assignment()).andReturn(taskId00Partitions);
-        replay(activeTaskCreator, topologyBuilder, consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
 
         taskManager.handleAssignment(assignment, emptyMap());
 
@@ -853,6 +912,7 @@ public class TaskManagerTest {
         assignment.putAll(taskId00Assignment);
         assignment.putAll(taskId01Assignment);
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(asList(corruptedActive, uncorruptedActive));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
         topologyBuilder.addSubscribedTopicsFromMetadata(eq(singleton(topic1)), anyObject());
@@ -908,6 +968,7 @@ public class TaskManagerTest {
         assignment.putAll(taskId00Assignment);
         assignment.putAll(taskId01Assignment);
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(asList(corruptedActive, uncorruptedActive));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
 
@@ -985,6 +1046,7 @@ public class TaskManagerTest {
         assignment.putAll(taskId00Assignment);
         assignment.putAll(taskId01Assignment);
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(asList(corruptedActiveTask, uncorruptedActiveTask));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
 
@@ -992,12 +1054,12 @@ public class TaskManagerTest {
 
         final ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata("appId");
         expect(consumer.groupMetadata()).andReturn(groupMetadata);
-        producer.commitTransaction(offsets, groupMetadata);
-        expectLastCall().andThrow(new TimeoutException());
+
+        doThrow(new TimeoutException()).when(producer).commitTransaction(offsets, groupMetadata);
 
         expect(consumer.assignment()).andStubReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions));
 
-        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader, stateManager, producer);
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader, stateManager);
 
         taskManager.handleAssignment(assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -1068,6 +1130,7 @@ public class TaskManagerTest {
         expectRestoreToBeCompleted(consumer, changeLogReader);
 
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))).andReturn(asList(revokedActiveTask, unrevokedActiveTaskWithCommitNeeded, unrevokedActiveTaskWithoutCommitNeeded));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
         expectLastCall();
         consumer.commitSync(expectedCommittedOffsets);
@@ -1131,17 +1194,18 @@ public class TaskManagerTest {
         expectRestoreToBeCompleted(consumer, changeLogReader);
 
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))).andReturn(asList(revokedActiveTask, unrevokedActiveTask, unrevokedActiveTaskWithoutCommitNeeded));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
         expectLastCall();
 
         final ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata("appId");
         expect(consumer.groupMetadata()).andReturn(groupMetadata);
-        producer.commitTransaction(expectedCommittedOffsets, groupMetadata);
-        expectLastCall().andThrow(new TimeoutException());
+
+        doThrow(new TimeoutException()).when(producer).commitTransaction(expectedCommittedOffsets, groupMetadata);
 
         expect(consumer.assignment()).andStubReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions, taskId02Partitions));
 
-        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader, producer, stateManager);
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader, stateManager);
 
         taskManager.handleAssignment(assignmentActive, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -1168,6 +1232,8 @@ public class TaskManagerTest {
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(standbyTaskCreator.createTasks(eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        expect(activeTaskCreator.createTasks(anyObject(), anyObject())).andStubReturn(Collections.emptySet());
+        expect(standbyTaskCreator.createTasks(eq(Collections.emptyMap()))).andStubReturn(Collections.emptySet());
         consumer.commitSync(Collections.emptyMap());
         expectLastCall();
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
@@ -1191,7 +1257,9 @@ public class TaskManagerTest {
         // expect these calls twice (because we're going to tryToCompleteRestoration twice)
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
-        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andReturn(singletonList(task01));
+        expect(activeTaskCreator.createTasks(anyObject(), eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
+        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andReturn(singletonList(task01)).anyTimes();
+        expect(standbyTaskCreator.createTasks(eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
         topologyBuilder.addSubscribedTopicsFromAssignment(eq(asList(t1p0)), anyString());
         expectLastCall().anyTimes();
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader, topologyBuilder);
@@ -1217,8 +1285,9 @@ public class TaskManagerTest {
         // expect these calls twice (because we're going to tryToCompleteRestoration twice)
         expectRestoreToBeCompleted(consumer, changeLogReader, false);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
-        replay(activeTaskCreator, consumer, changeLogReader);
-
+        expect(activeTaskCreator.createTasks(anyObject(), eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -1355,10 +1424,11 @@ public class TaskManagerTest {
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
         consumer.commitSync(offsets);
         expectLastCall();
 
-        replay(activeTaskCreator, consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -1519,7 +1589,9 @@ public class TaskManagerTest {
         expectRestoreToBeCompleted(consumer, changeLogReader);
 
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))).andReturn(singleton(task00));
+        expect(activeTaskCreator.createTasks(anyObject(), eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
         expect(standbyTaskCreator.createTasks(eq(assignmentStandby))).andReturn(singletonList(task10));
+        expect(standbyTaskCreator.createTasks(eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
         topologyBuilder.addSubscribedTopicsFromAssignment(eq(asList(t1p0)), anyString());
         expectLastCall().anyTimes();
 
@@ -1551,7 +1623,9 @@ public class TaskManagerTest {
         expectRestoreToBeCompleted(consumer, changeLogReader);
 
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))).andReturn(singleton(task00));
+        expect(activeTaskCreator.createTasks(anyObject(), eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
         expect(standbyTaskCreator.createTasks(eq(assignmentStandby))).andReturn(singletonList(task10));
+        expect(standbyTaskCreator.createTasks(eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
 
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
@@ -1570,8 +1644,11 @@ public class TaskManagerTest {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
 
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expectLastCall().once();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(task00.state(), is(Task.State.CREATED));
@@ -1595,8 +1672,8 @@ public class TaskManagerTest {
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
-
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -1968,6 +2045,7 @@ public class TaskManagerTest {
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, false);
 
         // `handleAssignment`
+        expect(activeTaskCreator.createTasks(anyObject(), anyObject())).andStubReturn(Collections.emptySet());
         expect(standbyTaskCreator.createTasks(eq(assignment))).andStubReturn(singletonList(task00));
 
         // `tryToCompleteRestoration`
@@ -2006,7 +2084,8 @@ public class TaskManagerTest {
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
             .andStubReturn(singletonList(task00));
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2023,10 +2102,10 @@ public class TaskManagerTest {
         final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false);
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment)))
-            .andStubReturn(singletonList(task01));
+        expect(activeTaskCreator.createTasks(anyObject(), anyObject())).andStubReturn(Collections.emptySet());
+        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andStubReturn(singletonList(task01));
 
-        replay(standbyTaskCreator, consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(emptyMap(), taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2137,8 +2216,8 @@ public class TaskManagerTest {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, false);
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(standbyTaskCreator.createTasks(eq(taskId00Assignment)))
-            .andStubReturn(singletonList(task00));
+        expect(activeTaskCreator.createTasks(anyObject(), anyObject())).andStubReturn(Collections.emptySet());
+        expect(standbyTaskCreator.createTasks(eq(taskId00Assignment))).andStubReturn(singletonList(task00));
         expectLastCall();
 
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
@@ -2210,7 +2289,7 @@ public class TaskManagerTest {
 
     @Test
     public void shouldCommitViaProducerIfEosAlphaEnabled() {
-        final StreamsProducer producer = mock(StreamsProducer.class);
+        final StreamsProducer producer = EasyMock.mock(StreamsProducer.class);
         expect(activeTaskCreator.streamsProducerForTask(anyObject(TaskId.class)))
             .andReturn(producer)
             .andReturn(producer);
@@ -2228,7 +2307,7 @@ public class TaskManagerTest {
 
     @Test
     public void shouldCommitViaProducerIfEosV2Enabled() {
-        final StreamsProducer producer = mock(StreamsProducer.class);
+        final StreamsProducer producer = EasyMock.mock(StreamsProducer.class);
         expect(activeTaskCreator.threadProducer()).andReturn(producer);
 
         final Map<TopicPartition, OffsetAndMetadata> offsetsT01 = singletonMap(t1p1, new OffsetAndMetadata(0L, null));
@@ -2277,10 +2356,9 @@ public class TaskManagerTest {
         };
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
-            .andStubReturn(singletonList(task00));
-
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2304,10 +2382,10 @@ public class TaskManagerTest {
         };
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment)))
-            .andStubReturn(singletonList(task01));
+        expect(activeTaskCreator.createTasks(anyObject(), anyObject())).andStubReturn(Collections.emptySet());
+        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andStubReturn(singletonList(task01));
 
-        replay(standbyTaskCreator, consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(emptyMap(), taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2339,10 +2417,10 @@ public class TaskManagerTest {
         };
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
-            .andStubReturn(singletonList(task00));
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
 
-        replay(activeTaskCreator, consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2375,10 +2453,9 @@ public class TaskManagerTest {
         };
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
-            .andStubReturn(singletonList(task00));
-
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2402,8 +2479,6 @@ public class TaskManagerTest {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
-            .andStubReturn(singletonList(task00));
 
         final KafkaFutureImpl<DeletedRecords> futureDeletedRecords = new KafkaFutureImpl<>();
         final DeleteRecordsResult deleteRecordsResult = new DeleteRecordsResult(singletonMap(t1p1, futureDeletedRecords));
@@ -2412,7 +2487,7 @@ public class TaskManagerTest {
 
         replay(activeTaskCreator, adminClient, consumer, changeLogReader);
 
-        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+        taskManager.addTask(task00);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
 
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -2499,10 +2574,9 @@ public class TaskManagerTest {
         assignment.put(taskId01, taskId01Partitions);
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment)))
-            .andStubReturn(Arrays.asList(task00, task01));
-
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(Arrays.asList(task00, task01));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2613,10 +2687,9 @@ public class TaskManagerTest {
         };
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
-            .andStubReturn(singletonList(task00));
-
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2641,8 +2714,8 @@ public class TaskManagerTest {
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
             .andStubReturn(singletonList(task00));
-
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2668,10 +2741,9 @@ public class TaskManagerTest {
         };
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
-            .andStubReturn(singletonList(task00));
-
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2691,10 +2763,9 @@ public class TaskManagerTest {
         };
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
-            .andStubReturn(singletonList(task00));
-
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2721,8 +2792,9 @@ public class TaskManagerTest {
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
             .andStubReturn(singletonList(task00));
-
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expect(standbyTaskCreator.createTasks(anyObject()))
+            .andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -2743,10 +2815,9 @@ public class TaskManagerTest {
         };
 
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
-        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
-            .andStubReturn(singletonList(task00));
-
-        replay(activeTaskCreator, changeLogReader, consumer);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, changeLogReader, consumer);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(false));
@@ -2764,10 +2835,11 @@ public class TaskManagerTest {
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
         consumer.commitSync(offsets);
         expectLastCall();
 
-        replay(activeTaskCreator, consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(TaskManager.class)) {
             LogCaptureAppender.setClassLoggerToDebug(TaskManager.class);
@@ -3039,15 +3111,14 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> offsetsT00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         final Map<TopicPartition, OffsetAndMetadata> offsetsT01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null));
 
-        producer.commitTransaction(offsetsT00, null);
-        expectLastCall().andThrow(new TimeoutException("KABOOM!"));
-        producer.commitTransaction(offsetsT00, null);
-        expectLastCall();
-
-        producer.commitTransaction(offsetsT01, null);
-        expectLastCall();
-        producer.commitTransaction(offsetsT01, null);
-        expectLastCall();
+        doThrow(new TimeoutException("KABOOM!"))
+            .doNothing()
+            .doNothing()
+            .doNothing()
+            .when(producer).commitTransaction(offsetsT00, null);
+        doNothing()
+            .doNothing()
+            .when(producer).commitTransaction(offsetsT01, null);
 
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
         task00.setCommittableOffsetsAndMetadata(offsetsT00);
@@ -3056,7 +3127,7 @@ public class TaskManagerTest {
         final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true);
 
         expect(consumer.groupMetadata()).andStubReturn(null);
-        replay(producer, activeTaskCreator, consumer);
+        replay(activeTaskCreator, consumer);
 
         task00.setCommitNeeded();
         task01.setCommitNeeded();
@@ -3085,10 +3156,7 @@ public class TaskManagerTest {
         final Map<TopicPartition, OffsetAndMetadata> allOffsets = new HashMap<>(offsetsT00);
         allOffsets.putAll(offsetsT01);
 
-        producer.commitTransaction(allOffsets, null);
-        expectLastCall().andThrow(new TimeoutException("KABOOM!"));
-        producer.commitTransaction(allOffsets, null);
-        expectLastCall();
+        doThrow(new TimeoutException("KABOOM!")).doNothing().when(producer).commitTransaction(allOffsets, null);
 
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
         task00.setCommittableOffsetsAndMetadata(offsetsT00);
@@ -3097,7 +3165,7 @@ public class TaskManagerTest {
         final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true);
 
         expect(consumer.groupMetadata()).andStubReturn(null);
-        replay(producer, activeTaskCreator, consumer);
+        replay(activeTaskCreator, consumer);
 
         task00.setCommitNeeded();
         task01.setCommitNeeded();
@@ -3168,9 +3236,9 @@ public class TaskManagerTest {
 
         final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>(taskId00Assignment);
         assignment.putAll(taskId01Assignment);
-        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment)))
-            .andReturn(asList(task00, task01));
-        replay(activeTaskCreator, consumer);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(asList(task00, task01));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        replay(activeTaskCreator, standbyTaskCreator, consumer);
 
         taskManager.handleAssignment(assignment, Collections.emptyMap());
 
@@ -3185,24 +3253,23 @@ public class TaskManagerTest {
 
     @Test
     public void shouldConvertActiveTaskToStandbyTask() {
-        final StreamTask activeTask = mock(StreamTask.class);
+        final StreamTask activeTask = EasyMock.mock(StreamTask.class);
         expect(activeTask.id()).andStubReturn(taskId00);
         expect(activeTask.inputPartitions()).andStubReturn(taskId00Partitions);
         expect(activeTask.isActive()).andStubReturn(true);
         expect(activeTask.prepareCommit()).andStubReturn(Collections.emptyMap());
 
-        final StandbyTask standbyTask = mock(StandbyTask.class);
+        final StandbyTask standbyTask = EasyMock.mock(StandbyTask.class);
         expect(standbyTask.id()).andStubReturn(taskId00);
 
-        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
-            .andReturn(singletonList(activeTask));
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(activeTask));
+        expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(Collections.emptySet());
+        activeTask.prepareRecycle();
+        expectLastCall().once();
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
         expectLastCall().anyTimes();
-        activeTask.closeCleanAndRecycleState();
-        expectLastCall().once();
-
-        expect(standbyTaskCreator.createStandbyTaskFromActive(anyObject(), eq(taskId00Partitions)))
-            .andReturn(standbyTask);
+        expect(standbyTaskCreator.createStandbyTaskFromActive(anyObject(), eq(taskId00Partitions))).andReturn(standbyTask);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
 
         replay(activeTask, standbyTask, activeTaskCreator, standbyTaskCreator, consumer);
 
@@ -3215,27 +3282,21 @@ public class TaskManagerTest {
     @Test
     public void shouldConvertStandbyTaskToActiveTask() {
         final StandbyTask standbyTask = mock(StandbyTask.class);
-        expect(standbyTask.id()).andStubReturn(taskId00);
-        expect(standbyTask.isActive()).andStubReturn(false);
-        expect(standbyTask.prepareCommit()).andStubReturn(Collections.emptyMap());
-        standbyTask.suspend();
-        expectLastCall().anyTimes();
-        standbyTask.postCommit(true);
-        expectLastCall().anyTimes();
-        standbyTask.closeCleanAndRecycleState();
-        expectLastCall().once();
+        when(standbyTask.id()).thenReturn(taskId00);
+        when(standbyTask.isActive()).thenReturn(false);
+        when(standbyTask.prepareCommit()).thenReturn(Collections.emptyMap());
 
         final StreamTask activeTask = mock(StreamTask.class);
-        expect(activeTask.id()).andStubReturn(taskId00);
-        expect(activeTask.inputPartitions()).andStubReturn(taskId00Partitions);
-
-        expect(standbyTaskCreator.createTasks(eq(taskId00Assignment)))
-            .andReturn(singletonList(standbyTask));
+        when(activeTask.id()).thenReturn(taskId00);
+        when(activeTask.inputPartitions()).thenReturn(taskId00Partitions);
 
-        expect(activeTaskCreator.createActiveTaskFromStandby(anyObject(), eq(taskId00Partitions), anyObject()))
-            .andReturn(activeTask);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
+        expect(standbyTaskCreator.createTasks(eq(taskId00Assignment))).andReturn(singletonList(standbyTask));
+        expect(activeTaskCreator.createActiveTaskFromStandby(eq(standbyTask), eq(taskId00Partitions), anyObject())).andReturn(activeTask);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
+        expect(standbyTaskCreator.createTasks(eq(Collections.emptyMap()))).andReturn(Collections.emptySet());
 
-        replay(standbyTask, activeTask, standbyTaskCreator, activeTaskCreator, consumer);
+        replay(standbyTaskCreator, activeTaskCreator, consumer);
 
         taskManager.handleAssignment(Collections.emptyMap(), taskId00Assignment);
         taskManager.handleAssignment(taskId00Assignment, Collections.emptyMap());
@@ -3434,7 +3495,7 @@ public class TaskManagerTest {
         }
 
         @Override
-        public void closeCleanAndRecycleState() {
+        public void prepareRecycle() {
             transitionTo(State.CLOSED);
         }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
index 09699776d0..756aa53f86 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
@@ -16,31 +16,19 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
-import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.processor.TaskId;
 import org.junit.jupiter.api.Test;
 
-import java.util.Arrays;
 import java.util.Collections;
-import java.util.Map;
-import java.util.Set;
 
-import static org.apache.kafka.common.utils.Utils.mkEntry;
-import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.standbyTask;
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statefulTask;
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statelessTask;
 import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
 
 public class TasksTest {
 
@@ -51,69 +39,16 @@ public class TasksTest {
     private final static TaskId TASK_1_0 = new TaskId(1, 0);
 
     private final LogContext logContext = new LogContext();
-    private final ActiveTaskCreator activeTaskCreator = mock(ActiveTaskCreator.class);
-    private final StandbyTaskCreator standbyTaskCreator = mock(StandbyTaskCreator.class);
-    private final StateUpdater stateUpdater = mock(StateUpdater.class);
-
-    private Consumer<byte[], byte[]> mainConsumer = null;
-
-    @Test
-    public void shouldCreateTasksWithStateUpdater() {
-        final Tasks tasks = new Tasks(logContext, activeTaskCreator, standbyTaskCreator, stateUpdater);
-        tasks.setMainConsumer(mainConsumer);
-        final StreamTask statefulTask = statefulTask(TASK_0_0, mkSet(TOPIC_PARTITION_A_0)).build();
-        final StandbyTask standbyTask = standbyTask(TASK_0_1, mkSet(TOPIC_PARTITION_A_1)).build();
-        final StreamTask statelessTask = statelessTask(TASK_1_0).build();
-        final Map<TaskId, Set<TopicPartition>> activeTasks = mkMap(
-            mkEntry(statefulTask.id(), statefulTask.changelogPartitions()),
-            mkEntry(statelessTask.id(), statelessTask.changelogPartitions())
-        );
-        final Map<TaskId, Set<TopicPartition>> standbyTasks =
-            mkMap(mkEntry(standbyTask.id(), standbyTask.changelogPartitions()));
-        when(activeTaskCreator.createTasks(mainConsumer, activeTasks)).thenReturn(Arrays.asList(statefulTask, statelessTask));
-        when(standbyTaskCreator.createTasks(standbyTasks)).thenReturn(Collections.singletonList(standbyTask));
-
-        tasks.createTasks(activeTasks, standbyTasks);
-
-        final Exception exceptionForStatefulTaskOnTask = assertThrows(IllegalStateException.class, () -> tasks.task(statefulTask.id()));
-        assertEquals("Task unknown: " + statefulTask.id(), exceptionForStatefulTaskOnTask.getMessage());
-        assertFalse(tasks.activeTasks().contains(statefulTask));
-        assertFalse(tasks.allTasks().contains(statefulTask));
-        final Exception exceptionForStatefulTaskOnTasks = assertThrows(IllegalStateException.class, () -> tasks.tasks(mkSet(statefulTask.id())));
-        assertEquals("Task unknown: " + statefulTask.id(), exceptionForStatefulTaskOnTasks.getMessage());
-        final Exception exceptionForStatelessTaskOnTask = assertThrows(IllegalStateException.class, () -> tasks.task(statelessTask.id()));
-        assertEquals("Task unknown: " + statelessTask.id(), exceptionForStatelessTaskOnTask.getMessage());
-        assertFalse(tasks.activeTasks().contains(statelessTask));
-        assertFalse(tasks.allTasks().contains(statelessTask));
-        final Exception exceptionForStatelessTaskOnTasks = assertThrows(IllegalStateException.class, () -> tasks.tasks(mkSet(statelessTask.id())));
-        assertEquals("Task unknown: " + statelessTask.id(), exceptionForStatelessTaskOnTasks.getMessage());
-        final Exception exceptionForStandbyTaskOnTask = assertThrows(IllegalStateException.class, () -> tasks.task(standbyTask.id()));
-        assertEquals("Task unknown: " + standbyTask.id(), exceptionForStandbyTaskOnTask.getMessage());
-        assertFalse(tasks.allTasks().contains(standbyTask));
-        final Exception exceptionForStandByTaskOnTasks = assertThrows(IllegalStateException.class, () -> tasks.tasks(mkSet(standbyTask.id())));
-        assertEquals("Task unknown: " + standbyTask.id(), exceptionForStandByTaskOnTasks.getMessage());
-        verify(activeTaskCreator).createTasks(mainConsumer, activeTasks);
-        verify(standbyTaskCreator).createTasks(standbyTasks);
-        verify(stateUpdater).add(statefulTask);
-    }
 
     @Test
-    public void shouldCreateTasksWithoutStateUpdater() {
-        final Tasks tasks = new Tasks(logContext, activeTaskCreator, standbyTaskCreator, null);
-        tasks.setMainConsumer(mainConsumer);
+    public void shouldCreateTasks() {
+        final Tasks tasks = new Tasks(logContext);
         final StreamTask statefulTask = statefulTask(TASK_0_0, mkSet(TOPIC_PARTITION_A_0)).build();
         final StandbyTask standbyTask = standbyTask(TASK_0_1, mkSet(TOPIC_PARTITION_A_1)).build();
         final StreamTask statelessTask = statelessTask(TASK_1_0).build();
-        final Map<TaskId, Set<TopicPartition>> activeTasks = mkMap(
-            mkEntry(statefulTask.id(), statefulTask.changelogPartitions()),
-            mkEntry(statelessTask.id(), statelessTask.changelogPartitions())
-        );
-        final Map<TaskId, Set<TopicPartition>> standbyTasks =
-            mkMap(mkEntry(standbyTask.id(), standbyTask.changelogPartitions()));
-        when(activeTaskCreator.createTasks(mainConsumer, activeTasks)).thenReturn(Arrays.asList(statefulTask, statelessTask));
-        when(standbyTaskCreator.createTasks(standbyTasks)).thenReturn(Collections.singletonList(standbyTask));
 
-        tasks.createTasks(activeTasks, standbyTasks);
+        tasks.addNewActiveTasks(mkSet(statefulTask, statelessTask));
+        tasks.addNewStandbyTasks(Collections.singletonList(standbyTask));
 
         assertEquals(statefulTask, tasks.task(statefulTask.id()));
         assertTrue(tasks.activeTasks().contains(statefulTask));
@@ -126,8 +61,5 @@ public class TasksTest {
         assertEquals(standbyTask, tasks.task(standbyTask.id()));
         assertTrue(tasks.allTasks().contains(standbyTask));
         assertTrue(tasks.tasks(mkSet(standbyTask.id())).contains(standbyTask));
-        verify(activeTaskCreator).createTasks(mainConsumer, activeTasks);
-        verify(standbyTaskCreator).createTasks(standbyTasks);
-        verify(stateUpdater, never()).add(statefulTask);
     }
 }
\ No newline at end of file