Posted to commits@kafka.apache.org by ca...@apache.org on 2022/07/21 13:11:48 UTC

[kafka] branch trunk updated: KAFKA-10199: Cleanup TaskManager and Task interfaces (#12397)

This is an automated email from the ASF dual-hosted git repository.

cadonna pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new c9b6e19b3b KAFKA-10199: Cleanup TaskManager and Task interfaces (#12397)
c9b6e19b3b is described below

commit c9b6e19b3b37499de17da19e82b1b98e3b9f6b5c
Author: Guozhang Wang <wa...@gmail.com>
AuthorDate: Thu Jul 21 06:11:40 2022 -0700

    KAFKA-10199: Cleanup TaskManager and Task interfaces (#12397)
    
    In order to integrate with the state updater, we need to refactor the TaskManager and Task interfaces. This PR achieves the following:
    
        Separate active and standby tasks in the Tasks container, and add pendingActiveTasks and pendingStandbyTasks to Tasks. The exposed active/standby tasks from the Tasks set are only mutated by a single thread, and the pending task maps hold those tasks that are assigned but cannot be actively managed yet. For now these include two scenarios: a) tasks from unknown sub-topologies, which hence cannot be initialized, b) tasks that are pending for being recycled from active to s [...]
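
        As a rough sketch of that bookkeeping (illustrative only: PendingTaskRegistry is a made-up name, but the field and method names follow this patch):

            // Hypothetical, trimmed-down view of the pending-task bookkeeping in Tasks:
            // stash assigned-but-not-yet-manageable tasks, and purge them on reassignment.
            import java.util.HashMap;
            import java.util.Map;
            import java.util.Set;

            import org.apache.kafka.common.TopicPartition;
            import org.apache.kafka.streams.processor.TaskId;

            class PendingTaskRegistry {
                private final Map<TaskId, Set<TopicPartition>> pendingActiveTasks = new HashMap<>();
                private final Map<TaskId, Set<TopicPartition>> pendingStandbyTasks = new HashMap<>();

                // stash tasks that are assigned but cannot be actively managed yet
                void addActivePendingTasks(final Map<TaskId, Set<TopicPartition>> pending) {
                    pendingActiveTasks.putAll(pending);
                }

                void addStandbyPendingTasks(final Map<TaskId, Set<TopicPartition>> pending) {
                    pendingStandbyTasks.putAll(pending);
                }

                // on each new assignment, drop pending tasks that have been revoked
                void purgePendingTasks(final Set<TaskId> assignedActive, final Set<TaskId> assignedStandby) {
                    pendingActiveTasks.keySet().retainAll(assignedActive);
                    pendingStandbyTasks.keySet().retainAll(assignedStandby);
                }
            }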
    
        Extract any logic that mutates a task out of Tasks / the TaskCreators. Tasks should only be a place for maintaining the set of tasks, not for manipulating a task; and the TaskCreators should only be used for creating tasks, not for anything else. This logic is all migrated into TaskManager.
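
        A rough sketch of that division of labor (hypothetical, simplified shapes, not the actual interfaces):

            // Tasks is pure bookkeeping and the TaskCreators only build tasks;
            // TaskManager is the single place where a task is mutated.
            import java.util.Collection;
            import java.util.Collections;
            import java.util.Map;
            import java.util.TreeMap;

            import org.apache.kafka.streams.processor.TaskId;

            interface Task {                   // stand-in for the real Task interface
                TaskId id();
                void suspend();
                void closeDirty();
            }

            class Tasks {                      // container only: add, remove, look up
                private final Map<TaskId, Task> tasksPerId = new TreeMap<>();

                void addTask(final Task task) {
                    tasksPerId.put(task.id(), task);
                }

                void removeTask(final Task task) {
                    tasksPerId.remove(task.id());
                }

                Collection<Task> allTasks() {
                    return Collections.unmodifiableCollection(tasksPerId.values());
                }
            }

            class TaskManager {                // the only place task state changes
                private final Tasks tasks = new Tasks();

                void closeTaskDirty(final Task task) {
                    task.suspend();            // lifecycle transitions happen here ...
                    task.closeDirty();
                    tasks.removeTask(task);    // ... Tasks is merely told to forget it
                }
            }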
    
        While doing 2) I noticed we have a couple of minor issues in the code where we duplicate the closing logic, so I also cleaned them up in the following way:
        a) When closing a task, we first trigger the corresponding closeClean/Dirty function; then we remove the task from the Tasks bookkeeping, and for an active task we also remove its task producer if EOS-V1 is used.
        b) For closing dirty, we swallow the exceptions from the close call and the remove-task-producer call; for closing clean, we store the thrown exception from either the close call or the remove-task-producer call, and then rethrow it at the end in the caller. The difference is that for an exception from the close call we need to retry closing the task dirty, whereas for an exception from the remove-task-producer call we do not need to re-close it dirty.
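
        In code, a) and b) look roughly like the following (a sketch of the reworked completeTaskCloseClean/closeTaskDirty in TaskManager, refining the closeTaskDirty from the stand-in sketch above and assuming the Task stand-in also has a closeClean() method):

            // clean close: an exception from closeClean() propagates so the caller
            // can retry closing the task dirty; an exception from removing the task
            // (which covers removing the task producer under EOS-V1) is only returned
            // and rethrown at the end, with no dirty re-close needed
            private RuntimeException completeTaskCloseClean(final Task task) {
                task.closeClean();             // may throw -> caller re-closes dirty
                try {
                    tasks.removeTask(task);
                } catch (final RuntimeException e) {
                    return e;                  // stored and rethrown by the caller
                }
                return null;
            }

            // dirty close: swallow the exceptions, there is nothing left to retry
            private void closeTaskDirty(final Task task) {
                try {
                    task.suspend();
                } catch (final RuntimeException swallow) {
                    // log and swallow
                }
                task.closeDirty();
                try {
                    tasks.removeTask(task);
                } catch (final RuntimeException swallow) {
                    // log and swallow
                }
            }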
    
    Reviewer: Bruno Cadonna <ca...@apache.org>
---
 .../streams/processor/internals/AbstractTask.java  |  17 +-
 .../processor/internals/ActiveTaskCreator.java     |  93 +++----
 .../processor/internals/PartitionGroup.java        |   3 +-
 .../processor/internals/ProcessorStateManager.java |   7 +-
 .../streams/processor/internals/StandbyTask.java   |  41 ++-
 .../processor/internals/StandbyTaskCreator.java    |  45 +---
 .../streams/processor/internals/StreamTask.java    |  44 +++-
 .../streams/processor/internals/StreamThread.java  |   9 +-
 .../processor/internals/TaskExecutionMetadata.java |  10 +
 .../streams/processor/internals/TaskExecutor.java  |  17 +-
 .../streams/processor/internals/TaskManager.java   | 179 +++++++------
 .../kafka/streams/processor/internals/Tasks.java   | 276 +++++++++++----------
 .../processor/internals/StreamThreadTest.java      |  33 +--
 .../processor/internals/TaskExecutorTest.java      |   4 +-
 .../processor/internals/TaskManagerTest.java       |  60 +++--
 .../streams/processor/internals/TasksTest.java     |  65 -----
 16 files changed, 452 insertions(+), 451 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
index c64fadfe5c..f8476b3e8b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
@@ -20,6 +20,7 @@ import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyConfig.TaskConfig;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.StateStore;
@@ -27,6 +28,7 @@ import org.apache.kafka.streams.processor.TaskId;
 import org.slf4j.Logger;
 
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -55,25 +57,25 @@ public abstract class AbstractTask implements Task {
     protected Map<TopicPartition, Long> offsetSnapshotSinceLastFlush = null;
 
     protected final TaskId id;
+    protected final TaskConfig config;
     protected final ProcessorTopology topology;
     protected final StateDirectory stateDirectory;
     protected final ProcessorStateManager stateMgr;
-    private final long taskTimeoutMs;
 
     AbstractTask(final TaskId id,
                  final ProcessorTopology topology,
                  final StateDirectory stateDirectory,
                  final ProcessorStateManager stateMgr,
                  final Set<TopicPartition> inputPartitions,
-                 final long taskTimeoutMs,
+                 final TaskConfig config,
                  final String taskType,
                  final Class<? extends AbstractTask> clazz) {
         this.id = id;
         this.stateMgr = stateMgr;
         this.topology = topology;
+        this.config = config;
         this.inputPartitions = inputPartitions;
         this.stateDirectory = stateDirectory;
-        this.taskTimeoutMs = taskTimeoutMs;
 
         final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName());
         logPrefix = threadIdPrefix + String.format("%s [%s] ", taskType, id);
@@ -106,7 +108,7 @@ public abstract class AbstractTask implements Task {
 
     @Override
     public Set<TopicPartition> inputPartitions() {
-        return inputPartitions;
+        return Collections.unmodifiableSet(inputPartitions);
     }
 
     @Override
@@ -151,7 +153,8 @@ public abstract class AbstractTask implements Task {
 
     @Override
     public void updateInputPartitions(final Set<TopicPartition> topicPartitions, final Map<String, List<String>> allTopologyNodesToSourceTopics) {
-        this.inputPartitions = topicPartitions;
+        this.inputPartitions.clear();
+        this.inputPartitions.addAll(topicPartitions);
         topology.updateSourceTopics(allTopologyNodesToSourceTopics);
     }
 
@@ -159,12 +162,12 @@ public abstract class AbstractTask implements Task {
     public void maybeInitTaskTimeoutOrThrow(final long currentWallClockMs,
                                             final Exception cause) {
         if (deadlineMs == NO_DEADLINE) {
-            deadlineMs = currentWallClockMs + taskTimeoutMs;
+            deadlineMs = currentWallClockMs + config.taskTimeoutMs;
         } else if (currentWallClockMs > deadlineMs) {
             final String errorMessage = String.format(
                 "Task %s did not make progress within %d ms. Adjust `%s` if needed.",
                 id,
-                currentWallClockMs - deadlineMs + taskTimeoutMs,
+                currentWallClockMs - deadlineMs + config.taskTimeoutMs,
                 StreamsConfig.TASK_TIMEOUT_MS_CONFIG
             );
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
index e7832df6b1..ef21a51c99 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
@@ -28,7 +28,6 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode;
 import org.apache.kafka.streams.processor.TaskId;
-import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
 import org.apache.kafka.streams.state.internals.ThreadCache;
@@ -44,7 +43,6 @@ import java.util.Set;
 import java.util.UUID;
 import java.util.stream.Collectors;
 
-import static org.apache.kafka.common.utils.Utils.filterMap;
 import static org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode.EXACTLY_ONCE_ALPHA;
 import static org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode.EXACTLY_ONCE_V2;
 import static org.apache.kafka.streams.internals.StreamsConfigUtils.eosEnabled;
@@ -68,11 +66,6 @@ class ActiveTaskCreator {
     private final Map<TaskId, StreamsProducer> taskProducers;
     private final ProcessingMode processingMode;
 
-    // Tasks may have been assigned for a NamedTopology that is not yet known by this host. When that occurs we stash
-    // these unknown tasks until either the corresponding NamedTopology is added and we can create them at last, or
-    // we receive a new assignment and they are revoked from the thread.
-    private final Map<TaskId, Set<TopicPartition>> unknownTasksToBeCreated = new HashMap<>();
-
     ActiveTaskCreator(final TopologyMetadata topologyMetadata,
                       final StreamsConfig applicationConfig,
                       final StreamsMetricsImpl streamsMetrics,
@@ -142,33 +135,16 @@ class ActiveTaskCreator {
         return threadProducer;
     }
 
-    void removeRevokedUnknownTasks(final Set<TaskId> assignedTasks) {
-        unknownTasksToBeCreated.keySet().retainAll(assignedTasks);
-    }
-
-    Map<TaskId, Set<TopicPartition>> uncreatedTasksForTopologies(final Set<String> currentTopologies) {
-        return filterMap(unknownTasksToBeCreated, t -> currentTopologies.contains(t.getKey().topologyName()));
-    }
-
     // TODO: change return type to `StreamTask`
     public Collection<Task> createTasks(final Consumer<byte[], byte[]> consumer,
                                  final Map<TaskId, Set<TopicPartition>> tasksToBeCreated) {
         // TODO: change type to `StreamTask`
         final List<Task> createdTasks = new ArrayList<>();
-        final Map<TaskId, Set<TopicPartition>> newUnknownTasks = new HashMap<>();
 
         for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions : tasksToBeCreated.entrySet()) {
             final TaskId taskId = newTaskAndPartitions.getKey();
-            final Set<TopicPartition> partitions = newTaskAndPartitions.getValue();
-
             final LogContext logContext = getLogContext(taskId);
-
-            // task belongs to a named topology that hasn't been added yet, wait until it has to create this
-            if (taskId.topologyName() != null && !topologyMetadata.namedTopologiesView().contains(taskId.topologyName())) {
-                newUnknownTasks.put(taskId, partitions);
-                continue;
-            }
-
+            final Set<TopicPartition> partitions = newTaskAndPartitions.getValue();
             final ProcessorTopology topology = topologyMetadata.buildSubtopology(taskId);
 
             final ProcessorStateManager stateManager = new ProcessorStateManager(
@@ -182,7 +158,7 @@ class ActiveTaskCreator {
                 partitions
             );
 
-            final InternalProcessorContext context = new ProcessorContextImpl(
+            final InternalProcessorContext<Object, Object> context = new ProcessorContextImpl(
                 taskId,
                 applicationConfig,
                 stateManager,
@@ -201,44 +177,13 @@ class ActiveTaskCreator {
                     context
                 )
             );
-            unknownTasksToBeCreated.remove(taskId);
-        }
-        if (!newUnknownTasks.isEmpty()) {
-            log.info("Delaying creation of tasks not yet known by this instance: {}", newUnknownTasks.keySet());
-            unknownTasksToBeCreated.putAll(newUnknownTasks);
         }
         return createdTasks;
     }
 
-
-    StreamTask createActiveTaskFromStandby(final StandbyTask standbyTask,
-                                           final Set<TopicPartition> inputPartitions,
-                                           final Consumer<byte[], byte[]> consumer) {
-        final InternalProcessorContext context = standbyTask.processorContext();
-        final ProcessorStateManager stateManager = standbyTask.stateMgr;
-        final LogContext logContext = getLogContext(standbyTask.id);
-
-        standbyTask.closeCleanAndRecycleState();
-        stateManager.transitionTaskType(TaskType.ACTIVE, logContext);
-
-        return createActiveTask(
-            standbyTask.id,
-            inputPartitions,
-            consumer,
-            logContext,
-            topologyMetadata.buildSubtopology(standbyTask.id),
-            stateManager,
-            context
-        );
-    }
-
-    private StreamTask createActiveTask(final TaskId taskId,
-                                        final Set<TopicPartition> inputPartitions,
-                                        final Consumer<byte[], byte[]> consumer,
-                                        final LogContext logContext,
-                                        final ProcessorTopology topology,
-                                        final ProcessorStateManager stateManager,
-                                        final InternalProcessorContext context) {
+    private RecordCollector createRecordCollector(final TaskId taskId,
+                                                  final LogContext logContext,
+                                                  final ProcessorTopology topology) {
         final StreamsProducer streamsProducer;
         if (processingMode == ProcessingMode.EXACTLY_ONCE_ALPHA) {
             log.info("Creating producer client for task {}", taskId);
@@ -249,13 +194,14 @@ class ActiveTaskCreator {
                 taskId,
                 null,
                 logContext,
-                time);
+                time
+            );
             taskProducers.put(taskId, streamsProducer);
         } else {
             streamsProducer = threadProducer;
         }
 
-        final RecordCollector recordCollector = new RecordCollectorImpl(
+        return new RecordCollectorImpl(
             logContext,
             taskId,
             streamsProducer,
@@ -263,6 +209,27 @@ class ActiveTaskCreator {
             streamsMetrics,
             topology
         );
+    }
+
+    StreamTask createActiveTaskFromStandby(final StandbyTask standbyTask,
+                                           final Set<TopicPartition> inputPartitions,
+                                           final Consumer<byte[], byte[]> consumer) {
+        final RecordCollector recordCollector = createRecordCollector(standbyTask.id, getLogContext(standbyTask.id), standbyTask.topology);
+        final StreamTask task = standbyTask.recycle(time, cache, recordCollector, inputPartitions, consumer);
+
+        log.trace("Created active task {} with assigned partitions {}", task.id, inputPartitions);
+        createTaskSensor.record();
+        return task;
+    }
+
+    private StreamTask createActiveTask(final TaskId taskId,
+                                        final Set<TopicPartition> inputPartitions,
+                                        final Consumer<byte[], byte[]> consumer,
+                                        final LogContext logContext,
+                                        final ProcessorTopology topology,
+                                        final ProcessorStateManager stateManager,
+                                        final InternalProcessorContext context) {
+        final RecordCollector recordCollector = createRecordCollector(taskId, logContext, topology);
 
         final StreamTask task = new StreamTask(
             taskId,
@@ -280,7 +247,7 @@ class ActiveTaskCreator {
             logContext
         );
 
-        log.trace("Created task {} with assigned partitions {}", taskId, inputPartitions);
+        log.trace("Created active task {} with assigned partitions {}", taskId, inputPartitions);
         createTaskSensor.record();
         return task;
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
index f6b5d800fd..7ce538e66b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
@@ -216,8 +216,9 @@ public class PartitionGroup {
     }
 
     // creates queues for new partitions, removes old queues, saves cached records for previously assigned partitions
-    void updatePartitions(final Set<TopicPartition> newInputPartitions, final Function<TopicPartition, RecordQueue> recordQueueCreator) {
+    void updatePartitions(final Set<TopicPartition> inputPartitions, final Function<TopicPartition, RecordQueue> recordQueueCreator) {
         final Set<TopicPartition> removedPartitions = new HashSet<>();
+        final Set<TopicPartition> newInputPartitions = new HashSet<>(inputPartitions);
         final Iterator<Map.Entry<TopicPartition, RecordQueue>> queuesIterator = partitionQueues.entrySet().iterator();
         while (queuesIterator.hasNext()) {
             final Map.Entry<TopicPartition, RecordQueue> queueEntry = queuesIterator.next();
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
index 3c8c40fc32..6efd3124ef 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
@@ -576,17 +576,12 @@ public class ProcessorStateManager implements StateManager {
         changelogReader.unregister(allChangelogs);
     }
 
-    void transitionTaskType(final TaskType newType, final LogContext logContext) {
+    void transitionTaskType(final TaskType newType) {
         if (taskType.equals(newType)) {
             throw new IllegalStateException("Tried to recycle state for task type conversion but new type was the same.");
         }
 
-        final TaskType oldType = taskType;
         taskType = newType;
-        log = logContext.logger(ProcessorStateManager.class);
-        logPrefix = logContext.logPrefix();
-
-        log.debug("Transitioning state manager for {} task {} to {}", oldType, taskId, newType);
     }
 
     @Override
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
index 670c0c4beb..102a767495 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
@@ -16,10 +16,12 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsMetrics;
 import org.apache.kafka.streams.errors.StreamsException;
@@ -68,7 +70,7 @@ public class StandbyTask extends AbstractTask implements Task {
             stateDirectory,
             stateMgr,
             inputPartitions,
-            config.taskTimeoutMs,
+            config,
             "standby-task",
             StandbyTask.class
         );
@@ -234,6 +236,43 @@ public class StandbyTask extends AbstractTask implements Task {
         log.info("Closed clean and recycled state");
     }
 
+    /**
+     * Create an active task from this standby task without closing and re-initializing the state stores.
+     * The task should be in the suspended state when this function is called.
+     *
+     * TODO: in the future we should not need the input partitions as a parameter and can always reuse
+     *       the task's own input partitions, once we have a fixed partitions -> tasks mapping
+     */
+    public StreamTask recycle(final Time time,
+                              final ThreadCache cache,
+                              final RecordCollector recordCollector,
+                              final Set<TopicPartition> inputPartitions,
+                              final Consumer<byte[], byte[]> mainConsumer) {
+        if (!inputPartitions.equals(this.inputPartitions)) {
+            log.warn("Detected unmatched input partitions for task {} when recycling it from active to standby", id);
+        }
+
+        stateMgr.transitionTaskType(TaskType.ACTIVE);
+
+        log.debug("Recycling standby task {} to active", id);
+
+        return new StreamTask(
+            id,
+            inputPartitions,
+            topology,
+            mainConsumer,
+            config,
+            streamsMetrics,
+            stateDirectory,
+            cache,
+            time,
+            stateMgr,
+            recordCollector,
+            processorContext,
+            logContext
+        );
+    }
+
     private void close(final boolean clean) {
         switch (state()) {
             case SUSPENDED:
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
index 43ebd40a35..59e2aa0285 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
@@ -21,7 +21,6 @@ import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.processor.TaskId;
-import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
 import org.apache.kafka.streams.state.internals.ThreadCache;
@@ -29,12 +28,10 @@ import org.slf4j.Logger;
 
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
-import static org.apache.kafka.common.utils.Utils.filterMap;
 import static org.apache.kafka.streams.internals.StreamsConfigUtils.eosEnabled;
 
 class StandbyTaskCreator {
@@ -47,9 +44,6 @@ class StandbyTaskCreator {
     private final Logger log;
     private final Sensor createTaskSensor;
 
-    // tasks may be assigned for a NamedTopology that is not yet known by this host, and saved for later creation
-    private final Map<TaskId, Set<TopicPartition>> unknownTasksToBeCreated = new HashMap<>();
-
     StandbyTaskCreator(final TopologyMetadata topologyMetadata,
                        final StreamsConfig applicationConfig,
                        final StreamsMetricsImpl streamsMetrics,
@@ -73,30 +67,14 @@ class StandbyTaskCreator {
         );
     }
 
-    void removeRevokedUnknownTasks(final Set<TaskId> assignedTasks) {
-        unknownTasksToBeCreated.keySet().retainAll(assignedTasks);
-    }
-
-    Map<TaskId, Set<TopicPartition>> uncreatedTasksForTopologies(final Set<String> currentTopologies) {
-        return filterMap(unknownTasksToBeCreated, t -> currentTopologies.contains(t.getKey().topologyName()));
-    }
-
     // TODO: change return type to `StandbyTask`
     Collection<Task> createTasks(final Map<TaskId, Set<TopicPartition>> tasksToBeCreated) {
         // TODO: change type to `StandbyTask`
         final List<Task> createdTasks = new ArrayList<>();
-        final Map<TaskId, Set<TopicPartition>>  newUnknownTasks = new HashMap<>();
 
         for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions : tasksToBeCreated.entrySet()) {
             final TaskId taskId = newTaskAndPartitions.getKey();
             final Set<TopicPartition> partitions = newTaskAndPartitions.getValue();
-
-            // task belongs to a named topology that hasn't been added yet, wait until it has to create this
-            if (taskId.topologyName() != null && !topologyMetadata.namedTopologiesView().contains(taskId.topologyName())) {
-                newUnknownTasks.put(taskId, partitions);
-                continue;
-            }
-
             final ProcessorTopology topology = topologyMetadata.buildSubtopology(taskId);
 
             if (topology.hasStateWithChangelogs()) {
@@ -111,7 +89,7 @@ class StandbyTaskCreator {
                     partitions
                 );
 
-                final InternalProcessorContext context = new ProcessorContextImpl(
+                final InternalProcessorContext<Object, Object> context = new ProcessorContextImpl(
                     taskId,
                     applicationConfig,
                     stateManager,
@@ -127,30 +105,17 @@ class StandbyTaskCreator {
                     taskId, partitions
                 );
             }
-            unknownTasksToBeCreated.remove(taskId);
-        }
-        if (!newUnknownTasks.isEmpty()) {
-            log.info("Delaying creation of tasks not yet known by this instance: {}", newUnknownTasks.keySet());
-            unknownTasksToBeCreated.putAll(newUnknownTasks);
         }
         return createdTasks;
     }
 
     StandbyTask createStandbyTaskFromActive(final StreamTask streamTask,
                                             final Set<TopicPartition> inputPartitions) {
-        final InternalProcessorContext context = streamTask.processorContext();
-        final ProcessorStateManager stateManager = streamTask.stateMgr;
-
-        streamTask.closeCleanAndRecycleState();
-        stateManager.transitionTaskType(TaskType.STANDBY, getLogContext(streamTask.id()));
+        final StandbyTask task = streamTask.recycle(inputPartitions);
 
-        return createStandbyTask(
-            streamTask.id(),
-            inputPartitions,
-            topologyMetadata.buildSubtopology(streamTask.id),
-            stateManager,
-            context
-        );
+        log.trace("Created task {} with assigned partitions {}", task.id, inputPartitions);
+        createTaskSensor.record();
+        return task;
     }
 
     StandbyTask createStandbyTask(final TaskId taskId,
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index 5b62e3f579..0c7ea49e76 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -127,7 +127,7 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
             stateDirectory,
             stateMgr,
             inputPartitions,
-            config.taskTimeoutMs,
+            config,
             "task",
             StreamTask.class
         );
@@ -292,7 +292,6 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
                     partitionGroup.clear();
                 } finally {
                     transitToSuspend();
-                    log.info("Suspended running");
                 }
 
                 break;
@@ -573,6 +572,45 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         log.info("Closed clean and recycled state");
     }
 
+    /**
+     * Create a standby task from this active task without closing and re-initializing the state stores.
+     * The task should be in the suspended state when this function is called.
+     *
+     * TODO: in the future we should not need the input partitions as a parameter and can always reuse
+     *       the task's own input partitions, once we have a fixed partitions -> tasks mapping
+     */
+    public StandbyTask recycle(final Set<TopicPartition> inputPartitions) {
+        if (state() != Task.State.CLOSED) {
+            throw new IllegalStateException("Attempted to convert an active task that's not closed: " + id);
+        }
+
+        if (!inputPartitions.equals(this.inputPartitions)) {
+            log.warn("Detected unmatched input partitions for task {} when recycling it from active to standby", id);
+        }
+
+        stateMgr.transitionTaskType(TaskType.STANDBY);
+
+        final ThreadCache dummyCache = new ThreadCache(
+            new LogContext(String.format("stream-thread [%s] ", Thread.currentThread().getName())),
+            0,
+            streamsMetrics
+        );
+
+        log.debug("Recycling active task {} to standby", id);
+
+        return new StandbyTask(
+            id,
+            inputPartitions,
+            topology,
+            config,
+            streamsMetrics,
+            stateMgr,
+            stateDirectory,
+            dummyCache,
+            processorContext
+        );
+    }
+
     /**
      * The following exceptions maybe thrown from the state manager flushing call
      *
@@ -1217,7 +1255,7 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     }
 
     private void transitToSuspend() {
-        log.info("Suspended {}", state());
+        log.info("Suspended from {}", state());
         transitionTo(State.SUSPENDED);
         timeCurrentIdlingStarted = Optional.of(System.currentTimeMillis());
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 8b6820fc22..74a49a7e8e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -389,7 +389,6 @@ public class StreamThread extends Thread {
             changelogReader,
             processId,
             logPrefix,
-            streamsMetrics,
             activeTaskCreator,
             standbyTaskCreator,
             topologyMetadata,
@@ -864,7 +863,7 @@ public class StreamThread extends Thread {
             if (taskManager.tryToCompleteRestoration(now, partitions -> resetOffsets(partitions, null))) {
                 changelogReader.transitToUpdateStandby();
                 log.info("Restoration took {} ms for all tasks {}", time.milliseconds() - lastPartitionAssignedMs,
-                    taskManager.tasks().keySet());
+                    taskManager.allTasks().keySet());
                 setState(State.RUNNING);
             }
 
@@ -1064,7 +1063,7 @@ public class StreamThread extends Thread {
             }
 
             committed = taskManager.commit(
-                taskManager.tasks()
+                taskManager.allTasks()
                     .values()
                     .stream()
                     .filter(t -> t.state() == Task.State.RUNNING || t.state() == Task.State.RESTORING)
@@ -1124,7 +1123,7 @@ public class StreamThread extends Thread {
         // intentionally do not check the returned flag
         setState(State.PENDING_SHUTDOWN);
 
-        log.info("Shutting down");
+        log.info("Shutting down {}", cleanRun ? "clean" : "unclean");
 
         try {
             taskManager.shutdown(cleanRun);
@@ -1232,7 +1231,7 @@ public class StreamThread extends Thread {
     }
 
     public Map<TaskId, Task> allTasks() {
-        return taskManager.tasks();
+        return taskManager.allTasks();
     }
 
     /**
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadata.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadata.java
index 310cdef66e..48ea76b84b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadata.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutionMetadata.java
@@ -59,6 +59,16 @@ public class TaskExecutionMetadata {
         }
     }
 
+    public boolean canPunctuateTask(final Task task) {
+        final String topologyName = task.id().topologyName();
+
+        if (topologyName == null) {
+            return !pausedTopologies.contains(UNNAMED_TOPOLOGY);
+        } else {
+            return !pausedTopologies.contains(topologyName);
+        }
+    }
+
     public void registerTaskError(final Task task, final Throwable t, final long now) {
         if (hasNamedTopologies) {
             final String topologyName = task.id().topologyName();
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java
index bab8a75149..5aa6eb1fe2 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java
@@ -82,7 +82,7 @@ public class TaskExecutor {
                 }
             } catch (final Throwable t) {
                 taskExecutionMetadata.registerTaskError(task, t, now);
-                tasks.removeTaskFromCuccessfullyProcessedBeforeClosing(lastProcessed);
+                tasks.removeTaskFromSuccessfullyProcessedBeforeClosing(lastProcessed);
                 commitSuccessfullyProcessedTasks();
                 throw t;
             }
@@ -163,6 +163,7 @@ public class TaskExecutor {
                 task.postCommit(false);
             }
         }
+
         return committed;
     }
 
@@ -275,13 +276,15 @@ public class TaskExecutor {
     int punctuate() {
         int punctuated = 0;
 
-        for (final Task task : tasks.notPausedActiveTasks()) {
+        for (final Task task : tasks.activeTasks()) {
             try {
-                if (task.maybePunctuateStreamTime()) {
-                    punctuated++;
-                }
-                if (task.maybePunctuateSystemTime()) {
-                    punctuated++;
+                if (taskExecutionMetadata.canPunctuateTask(task)) {
+                    if (task.maybePunctuateStreamTime()) {
+                        punctuated++;
+                    }
+                    if (task.maybePunctuateSystemTime()) {
+                        punctuated++;
+                    }
                 }
             } catch (final TaskMigratedException e) {
                 log.info("Failed to punctuate stream task {} since it got migrated to another thread already. " +
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index d5a0ee3d03..4ea419ab9c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -38,7 +38,6 @@ import org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory;
 import org.apache.kafka.streams.processor.internals.Task.State;
-import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
 
 import org.slf4j.Logger;
@@ -74,14 +73,15 @@ public class TaskManager {
     // by QueryableState
     private final Logger log;
     private final Time time;
-    private final ChangelogReader changelogReader;
+    private final Tasks tasks;
     private final UUID processId;
     private final String logPrefix;
-    private final TopologyMetadata topologyMetadata;
     private final Admin adminClient;
     private final StateDirectory stateDirectory;
     private final ProcessingMode processingMode;
-    private final Tasks tasks;
+    private final ChangelogReader changelogReader;
+    private final TopologyMetadata topologyMetadata;
+
     private final TaskExecutor taskExecutor;
 
     private Consumer<byte[], byte[]> mainConsumer;
@@ -97,7 +97,6 @@ public class TaskManager {
                 final ChangelogReader changelogReader,
                 final UUID processId,
                 final String logPrefix,
-                final StreamsMetricsImpl streamsMetrics,
                 final ActiveTaskCreator activeTaskCreator,
                 final StandbyTaskCreator standbyTaskCreator,
                 final TopologyMetadata topologyMetadata,
@@ -115,7 +114,7 @@ public class TaskManager {
         final LogContext logContext = new LogContext(logPrefix);
         this.log = logContext.logger(getClass());
 
-        this.tasks = new Tasks(logContext, topologyMetadata, activeTaskCreator, standbyTaskCreator);
+        this.tasks = new Tasks(logContext, activeTaskCreator, standbyTaskCreator);
         this.taskExecutor = new TaskExecutor(
             tasks,
             topologyMetadata.taskExecutionMetadata(),
@@ -186,7 +185,7 @@ public class TaskManager {
 
         // We need to commit before closing the corrupted active tasks since this will force the ongoing txn to abort
         try {
-            final Collection<Task> tasksToCommit = tasks()
+            final Collection<Task> tasksToCommit = allTasks()
                 .values()
                 .stream()
                 .filter(t -> t.state() == Task.State.RUNNING || t.state() == Task.State.RESTORING)
@@ -285,34 +284,52 @@ public class TaskManager {
         final LinkedHashMap<TaskId, RuntimeException> taskCloseExceptions = new LinkedHashMap<>();
         final Map<TaskId, Set<TopicPartition>> activeTasksToCreate = new HashMap<>(activeTasks);
         final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate = new HashMap<>(standbyTasks);
+        final Map<Task, Set<TopicPartition>> tasksToRecycle = new HashMap<>();
         final Comparator<Task> byId = Comparator.comparing(Task::id);
-        final Set<Task> tasksToRecycle = new TreeSet<>(byId);
         final Set<Task> tasksToCloseClean = new TreeSet<>(byId);
         final Set<Task> tasksToCloseDirty = new TreeSet<>(byId);
 
-        // first rectify all existing tasks
+        tasks.purgePendingTasks(activeTasks.keySet(), standbyTasks.keySet());
+
+        // first rectify all existing tasks:
+        // 1. for tasks that are already owned, just update input partitions / resume and skip re-creating them
+        // 2. for tasks that have changed active/standby status, just recycle and skip re-creating them
+        // 3. otherwise, close them since they are no longer owned
         for (final Task task : tasks.allTasks()) {
-            if (activeTasks.containsKey(task.id()) && task.isActive()) {
-                tasks.updateInputPartitionsAndResume(task, activeTasks.get(task.id()));
-                activeTasksToCreate.remove(task.id());
-            } else if (standbyTasks.containsKey(task.id()) && !task.isActive()) {
-                tasks.updateInputPartitionsAndResume(task, standbyTasks.get(task.id()));
-                standbyTasksToCreate.remove(task.id());
-            } else if (activeTasks.containsKey(task.id()) || standbyTasks.containsKey(task.id())) {
-                // check for tasks that were owned previously but have changed active/standby status
-                tasksToRecycle.add(task);
+            final TaskId taskId = task.id();
+            if (activeTasksToCreate.containsKey(taskId)) {
+                if (task.isActive()) {
+                    final Set<TopicPartition> topicPartitions = activeTasksToCreate.get(taskId);
+                    if (tasks.updateActiveTaskInputPartitions(task, topicPartitions)) {
+                        task.updateInputPartitions(topicPartitions, topologyMetadata.nodeToSourceTopics(task.id()));
+                    }
+                    task.resume();
+                } else {
+                    tasksToRecycle.put(task, activeTasksToCreate.get(taskId));
+                }
+                activeTasksToCreate.remove(taskId);
+            } else if (standbyTasksToCreate.containsKey(taskId)) {
+                if (!task.isActive()) {
+                    final Set<TopicPartition> topicPartitions = standbyTasksToCreate.get(taskId);
+                    task.updateInputPartitions(topicPartitions, topologyMetadata.nodeToSourceTopics(task.id()));
+                    task.resume();
+                } else {
+                    tasksToRecycle.put(task, standbyTasksToCreate.get(taskId));
+                }
+                standbyTasksToCreate.remove(taskId);
             } else {
                 tasksToCloseClean.add(task);
             }
         }
 
+        tasks.addActivePendingTasks(pendingTasksToCreate(activeTasksToCreate));
+        tasks.addStandbyPendingTasks(pendingTasksToCreate(standbyTasksToCreate));
+
         // close and recycle those tasks
-        handleCloseAndRecycle(
+        closeAndRecycleTasks(
             tasksToRecycle,
             tasksToCloseClean,
             tasksToCloseDirty,
-            activeTasksToCreate,
-            standbyTasksToCreate,
             taskCloseExceptions
         );
 
@@ -346,22 +363,34 @@ public class TaskManager {
             throw first.getValue();
         }
 
-        tasks.handleNewAssignmentAndCreateTasks(activeTasksToCreate, standbyTasksToCreate, activeTasks.keySet(), standbyTasks.keySet());
+        tasks.createTasks(activeTasksToCreate, standbyTasksToCreate);
+    }
+
+    private Map<TaskId, Set<TopicPartition>> pendingTasksToCreate(final Map<TaskId, Set<TopicPartition>> tasksToCreate) {
+        final Map<TaskId, Set<TopicPartition>> pendingTasks = new HashMap<>();
+        final Iterator<Map.Entry<TaskId, Set<TopicPartition>>> iter = tasksToCreate.entrySet().iterator();
+        while (iter.hasNext()) {
+            final Map.Entry<TaskId, Set<TopicPartition>> entry = iter.next();
+            final TaskId taskId = entry.getKey();
+            if (taskId.topologyName() != null && !topologyMetadata.namedTopologiesView().contains(taskId.topologyName())) {
+                pendingTasks.put(taskId, entry.getValue());
+                iter.remove();
+            }
+        }
+        return pendingTasks;
     }
 
-    private void handleCloseAndRecycle(final Set<Task> tasksToRecycle,
-                                       final Set<Task> tasksToCloseClean,
-                                       final Set<Task> tasksToCloseDirty,
-                                       final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
-                                       final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate,
-                                       final LinkedHashMap<TaskId, RuntimeException> taskCloseExceptions) {
+    private void closeAndRecycleTasks(final Map<Task, Set<TopicPartition>> tasksToRecycle,
+                                      final Set<Task> tasksToCloseClean,
+                                      final Set<Task> tasksToCloseDirty,
+                                      final LinkedHashMap<TaskId, RuntimeException> taskCloseExceptions) {
         if (!tasksToCloseDirty.isEmpty()) {
             throw new IllegalArgumentException("Tasks to close-dirty should be empty");
         }
 
         // for all tasks to close or recycle, we should first write a checkpoint as in post-commit
         final List<Task> tasksToCheckpoint = new ArrayList<>(tasksToCloseClean);
-        tasksToCheckpoint.addAll(tasksToRecycle);
+        tasksToCheckpoint.addAll(tasksToRecycle.keySet());
         for (final Task task : tasksToCheckpoint) {
             try {
                 // Note that we are not actually committing here but just check if we need to write checkpoint file:
@@ -399,29 +428,29 @@ public class TaskManager {
         tasksToCloseClean.removeAll(tasksToCloseDirty);
         for (final Task task : tasksToCloseClean) {
             try {
-                completeTaskCloseClean(task);
-                if (task.isActive()) {
-                    tasks.cleanUpTaskProducerAndRemoveTask(task.id(), taskCloseExceptions);
+                final RuntimeException removeTaskException = completeTaskCloseClean(task);
+                if (removeTaskException != null) {
+                    taskCloseExceptions.putIfAbsent(task.id(), removeTaskException);
                 }
-            } catch (final RuntimeException e) {
+            } catch (final RuntimeException closeTaskException) {
                 final String uncleanMessage = String.format(
                         "Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:",
                         task.id());
-                log.error(uncleanMessage, e);
-                taskCloseExceptions.putIfAbsent(task.id(), e);
+                log.error(uncleanMessage, closeTaskException);
+                taskCloseExceptions.putIfAbsent(task.id(), closeTaskException);
                 tasksToCloseDirty.add(task);
             }
         }
 
-        tasksToRecycle.removeAll(tasksToCloseDirty);
-        for (final Task oldTask : tasksToRecycle) {
+        tasksToRecycle.keySet().removeAll(tasksToCloseDirty);
+        for (final Map.Entry<Task, Set<TopicPartition>> entry : tasksToRecycle.entrySet()) {
+            final Task oldTask = entry.getKey();
             try {
+                oldTask.closeCleanAndRecycleState();
                 if (oldTask.isActive()) {
-                    final Set<TopicPartition> partitions = standbyTasksToCreate.remove(oldTask.id());
-                    tasks.convertActiveToStandby((StreamTask) oldTask, partitions, taskCloseExceptions);
+                    tasks.convertActiveToStandby((StreamTask) oldTask, entry.getValue(), taskCloseExceptions);
                 } else {
-                    final Set<TopicPartition> partitions = activeTasksToCreate.remove(oldTask.id());
-                    tasks.convertStandbyToActive((StandbyTask) oldTask, partitions);
+                    tasks.convertStandbyToActive((StandbyTask) oldTask, entry.getValue());
                 }
             } catch (final RuntimeException e) {
                 final String uncleanMessage = String.format("Failed to recycle task %s cleanly. Attempting to close remaining tasks before re-throwing:", oldTask.id());
@@ -434,7 +463,6 @@ public class TaskManager {
         // for tasks that cannot be cleanly closed or recycled, close them dirty
         for (final Task task : tasksToCloseDirty) {
             closeTaskDirty(task);
-            tasks.cleanUpTaskProducerAndRemoveTask(task.id(), taskCloseExceptions);
         }
     }
 
@@ -646,14 +674,12 @@ public class TaskManager {
     void handleLostAll() {
         log.debug("Closing lost active tasks as zombies.");
 
-        final Set<Task> allTask = new HashSet<>(tasks.allTasks());
+        final Set<Task> allTask = tasks.allTasks();
         for (final Task task : allTask) {
             // Even though we've apparently dropped out of the group, we can continue safely to maintain our
             // standby tasks while we rejoin.
             if (task.isActive()) {
                 closeTaskDirty(task);
-
-                tasks.cleanUpTaskProducerAndRemoveTask(task.id(), new HashMap<>());
             }
         }
 
@@ -673,7 +699,7 @@ public class TaskManager {
         // Not all tasks will create directories, and there may be directories for tasks we don't currently own,
         // so we consider all tasks that are either owned or on disk. This includes stateless tasks, which should
         // just have an empty changelogOffsets map.
-        for (final TaskId id : union(HashSet::new, lockedTaskDirectories, tasks.tasksPerId().keySet())) {
+        for (final TaskId id : union(HashSet::new, lockedTaskDirectories, tasks.allTaskIds())) {
             final Task task = tasks.owned(id) ? tasks.task(id) : null;
             // Closed and uninitialized tasks don't have any offsets so we should read directly from the checkpoint
             if (task != null && task.state() != State.CREATED && task.state() != State.CLOSED) {
@@ -796,15 +822,27 @@ public class TaskManager {
         try {
             task.suspend();
         } catch (final RuntimeException swallow) {
-            log.error("Error suspending dirty task {} ", task.id(), swallow);
+            log.error("Error suspending dirty task {}: {}", task.id(), swallow.getMessage());
         }
-        tasks.removeTaskBeforeClosing(task.id());
+
         task.closeDirty();
+
+        try {
+            tasks.removeTask(task);
+        } catch (final RuntimeException swallow) {
+            log.error("Error removing dirty task {}: {}", task.id(), swallow.getMessage());
+        }
     }
 
-    private void completeTaskCloseClean(final Task task) {
-        tasks.removeTaskBeforeClosing(task.id());
+    private RuntimeException completeTaskCloseClean(final Task task) {
         task.closeClean();
+        try {
+            tasks.removeTask(task);
+        } catch (final RuntimeException e) {
+            log.error("Error removing active task {}: {}", task.id(), e.getMessage());
+            return e;
+        }
+        return null;
     }
 
     void shutdown(final boolean clean) {
@@ -859,16 +897,6 @@ public class TaskManager {
             closeTaskDirty(task);
         }
 
-        // TODO: change type to `StreamTask`
-        for (final Task activeTask : activeTasks) {
-            executeAndMaybeSwallow(
-                clean,
-                () -> tasks.closeAndRemoveTaskProducerIfNeeded(activeTask),
-                e -> firstException.compareAndSet(null, e),
-                e -> log.warn("Ignoring an exception while closing task " + activeTask.id() + " producer.", e)
-            );
-        }
-
         final RuntimeException exception = firstException.get();
         if (exception != null) {
             throw exception;
@@ -957,7 +985,10 @@ public class TaskManager {
         for (final Task task : tasksToCloseClean) {
             try {
                 task.suspend();
-                completeTaskCloseClean(task);
+                final RuntimeException exception = completeTaskCloseClean(task);
+                if (exception != null) {
+                    firstException.compareAndSet(null, exception);
+                }
             } catch (final StreamsException e) {
                 log.error("Exception caught while clean-closing task " + task.id(), e);
                 e.setTaskId(task.id());
@@ -988,7 +1019,10 @@ public class TaskManager {
                 task.prepareCommit();
                 task.postCommit(true);
                 task.suspend();
-                completeTaskCloseClean(task);
+                final RuntimeException exception = completeTaskCloseClean(task);
+                if (exception != null) {
+                    maybeWrapAndSetFirstException(firstException, exception, task.id());
+                }
             } catch (final TaskMigratedException e) {
                 // just ignore the exception as it doesn't matter during shutdown
                 tasksToCloseDirty.add(task);
@@ -1012,14 +1046,16 @@ public class TaskManager {
             .collect(Collectors.toSet());
     }
 
-    Map<TaskId, Task> tasks() {
+    Map<TaskId, Task> allTasks() {
         // not bothering with an unmodifiable map, since the tasks themselves are mutable, but
         // if any outside code modifies the map or the tasks, it would be a severe transgression.
-        return tasks.tasksPerId();
+        return tasks.allTasksPerId();
     }
 
     Map<TaskId, Task> notPausedTasks() {
-        return Collections.unmodifiableMap(tasks.notPausedTasks().stream()
+        return Collections.unmodifiableMap(tasks.allTasks()
+            .stream()
+            .filter(t -> !topologyMetadata.isPaused(t.id().topologyName()))
             .collect(Collectors.toMap(Task::id, v -> v)));
     }
 
@@ -1049,7 +1085,7 @@ public class TaskManager {
 
     // For testing only.
     int commitAll() {
-        return commit(new HashSet<>(tasks.allTasks()));
+        return commit(tasks.allTasks());
     }
 
     /**
@@ -1135,7 +1171,7 @@ public class TaskManager {
      */
     void handleTopologyUpdates() {
         topologyMetadata.executeTopologyUpdatesAndBumpThreadVersion(
-            tasks::maybeCreateTasksFromNewTopologies,
+            tasks::createPendingTasks,
             this::maybeCloseTasksFromRemovedTopologies
         );
 
@@ -1161,11 +1197,10 @@ public class TaskManager {
                 }
             }
 
-            final Set<TaskId> allRemovedTasks =
-                union(HashSet::new, activeTasksToRemove, standbyTasksToRemove).stream().map(Task::id).collect(Collectors.toSet());
+            final Set<Task> allTasksToRemove = union(HashSet::new, activeTasksToRemove, standbyTasksToRemove);
             closeAndCleanUpTasks(activeTasksToRemove, standbyTasksToRemove, true);
-            allRemovedTasks.forEach(tasks::removeTaskBeforeClosing);
-            releaseLockedDirectoriesForTasks(allRemovedTasks);
+            allTasksToRemove.forEach(tasks::removeTask);
+            releaseLockedDirectoriesForTasks(allTasksToRemove.stream().map(Task::id).collect(Collectors.toSet()));
         } catch (final Exception e) {
             // TODO KAFKA-12648: for now just swallow the exception to avoid interfering with the other topologies
             //  that are running alongside, but eventually we should be able to rethrow up to the handler to inform
@@ -1300,7 +1335,7 @@ public class TaskManager {
     }
 
     boolean needsInitializationOrRestoration() {
-        return tasks().values().stream().anyMatch(Task::needsInitializationOrRestoration);
+        return activeTaskIterable().stream().anyMatch(Task::needsInitializationOrRestoration);
     }
 
     // for testing only
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
index ca5481b67b..fbb45c5940 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
@@ -24,7 +24,6 @@ import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.processor.TaskId;
 import org.slf4j.Logger;
 
-import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -34,29 +33,40 @@ import java.util.Set;
 import java.util.TreeMap;
 import java.util.stream.Collectors;
 
+import static org.apache.kafka.common.utils.Utils.filterMap;
+import static org.apache.kafka.common.utils.Utils.union;
+
+/**
+ * All tasks contained by the Streams instance.
+ *
+ * Note that these tasks are shared between the TaskManager (stream thread) and the StateUpdater (restore thread),
+ * i.e. all running active tasks are processed by the former and all restoring active tasks and standby tasks are
+ * processed by the latter.
+ */
 class Tasks {
     private final Logger log;
-    private final TopologyMetadata topologyMetadata;
-
-    private final Map<TaskId, Task> allTasksPerId = Collections.synchronizedSortedMap(new TreeMap<>());
-    private final Map<TaskId, Task> readOnlyTasksPerId = Collections.unmodifiableMap(allTasksPerId);
-    private final Collection<Task> readOnlyTasks = Collections.unmodifiableCollection(allTasksPerId.values());
 
     // TODO: change type to `StreamTask`
     private final Map<TaskId, Task> activeTasksPerId = new TreeMap<>();
+    // TODO: change type to `StandbyTask`
+    private final Map<TaskId, Task> standbyTasksPerId = new TreeMap<>();
+
+    // Tasks may have been assigned but not yet created because:
+    // 1. They are for a NamedTopology that is yet known by this host.
+    // 2. They are to be recycled from an existing restoring task yet to be returned from the state updater.
+    //
+    // When that occurs we stash these pending tasks until either they are finally clear to be created,
+    // or they are revoked from a new assignment.
+    private final Map<TaskId, Set<TopicPartition>> pendingActiveTasks = new HashMap<>();
+    private final Map<TaskId, Set<TopicPartition>> pendingStandbyTasks = new HashMap<>();
+
     // TODO: change type to `StreamTask`
     private final Map<TopicPartition, Task> activeTasksPerPartition = new HashMap<>();
-    // TODO: change type to `StreamTask`
-    private final Map<TaskId, Task> readOnlyActiveTasksPerId = Collections.unmodifiableMap(activeTasksPerId);
-    private final Set<TaskId> readOnlyActiveTaskIds = Collections.unmodifiableSet(activeTasksPerId.keySet());
-    // TODO: change type to `StreamTask`
-    private final Collection<Task> readOnlyActiveTasks = Collections.unmodifiableCollection(activeTasksPerId.values());
 
-    // TODO: change type to `StandbyTask`
-    private final Map<TaskId, Task> standbyTasksPerId = new TreeMap<>();
-    // TODO: change type to `StandbyTask`
-    private final Map<TaskId, Task> readOnlyStandbyTasksPerId = Collections.unmodifiableMap(standbyTasksPerId);
-    private final Set<TaskId> readOnlyStandbyTaskIds = Collections.unmodifiableSet(standbyTasksPerId.keySet());
     private final Collection<Task> successfullyProcessed = new HashSet<>();
 
     private final ActiveTaskCreator activeTaskCreator;
@@ -65,13 +75,10 @@ class Tasks {
     private Consumer<byte[], byte[]> mainConsumer;
 
     Tasks(final LogContext logContext,
-          final TopologyMetadata topologyMetadata,
           final ActiveTaskCreator activeTaskCreator,
           final StandbyTaskCreator standbyTaskCreator) {
 
-        log = logContext.logger(getClass());
-
-        this.topologyMetadata = topologyMetadata;
+        this.log = logContext.logger(getClass());
         this.activeTaskCreator = activeTaskCreator;
         this.standbyTaskCreator = standbyTaskCreator;
     }
@@ -80,88 +87,131 @@ class Tasks {
         this.mainConsumer = mainConsumer;
     }
 
-    void handleNewAssignmentAndCreateTasks(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
-                                           final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate,
-                                           final Set<TaskId> assignedActiveTasks,
-                                           final Set<TaskId> assignedStandbyTasks) {
-        activeTaskCreator.removeRevokedUnknownTasks(assignedActiveTasks);
-        standbyTaskCreator.removeRevokedUnknownTasks(assignedStandbyTasks);
-        createTasks(activeTasksToCreate, standbyTasksToCreate);
+    void purgePendingTasks(final Set<TaskId> assignedActiveTasks, final Set<TaskId> assignedStandbyTasks) {
+        pendingActiveTasks.keySet().retainAll(assignedActiveTasks);
+        pendingStandbyTasks.keySet().retainAll(assignedStandbyTasks);
     }
 
-    void maybeCreateTasksFromNewTopologies(final Set<String> currentNamedTopologies) {
+    void addActivePendingTasks(final Map<TaskId, Set<TopicPartition>> pendingTasks) {
+        pendingActiveTasks.putAll(pendingTasks);
+    }
+
+    void addStandbyPendingTasks(final Map<TaskId, Set<TopicPartition>> pendingTasks) {
+        pendingStandbyTasks.putAll(pendingTasks);
+    }
+
+    void createPendingTasks(final Set<String> currentNamedTopologies) {
         createTasks(
-            activeTaskCreator.uncreatedTasksForTopologies(currentNamedTopologies),
-            standbyTaskCreator.uncreatedTasksForTopologies(currentNamedTopologies)
+            pendingActiveTasksForTopologies(currentNamedTopologies),
+            pendingStandbyTasksForTopologies(currentNamedTopologies)
         );
     }
 
-    double totalProducerBlockedTime() {
-        return activeTaskCreator.totalProducerBlockedTime();
+    private Map<TaskId, Set<TopicPartition>> pendingActiveTasksForTopologies(final Set<String> currentTopologies) {
+        return filterMap(pendingActiveTasks, t -> currentTopologies.contains(t.getKey().topologyName()));
+    }
+
+    private Map<TaskId, Set<TopicPartition>> pendingStandbyTasksForTopologies(final Set<String> currentTopologies) {
+        return filterMap(pendingStandbyTasks, t -> currentTopologies.contains(t.getKey().topologyName()));
     }
 
     void createTasks(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate,
                      final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate) {
+        createActiveTasks(activeTasksToCreate);
+        createStandbyTasks(standbyTasksToCreate);
+    }
+
+    private void createActiveTasks(final Map<TaskId, Set<TopicPartition>> activeTasksToCreate) {
         for (final Map.Entry<TaskId, Set<TopicPartition>> taskToBeCreated : activeTasksToCreate.entrySet()) {
             final TaskId taskId = taskToBeCreated.getKey();
 
             if (activeTasksPerId.containsKey(taskId)) {
                 throw new IllegalStateException("Attempted to create an active task that we already own: " + taskId);
             }
-        }
-
-        for (final Map.Entry<TaskId, Set<TopicPartition>> taskToBeCreated : standbyTasksToCreate.entrySet()) {
-            final TaskId taskId = taskToBeCreated.getKey();
 
-            if (standbyTasksPerId.containsKey(taskId)) {
-                throw new IllegalStateException("Attempted to create a standby task that we already own: " + taskId);
+            if (pendingStandbyTasks.containsKey(taskId)) {
+                throw new IllegalStateException("Attempted to create an active task while we already own its standby: " + taskId);
             }
         }
 
-        // keep this check to simplify testing (ie, no need to mock `activeTaskCreator`)
         if (!activeTasksToCreate.isEmpty()) {
-            // TODO: change type to `StreamTask`
             for (final Task activeTask : activeTaskCreator.createTasks(mainConsumer, activeTasksToCreate)) {
                 activeTasksPerId.put(activeTask.id(), activeTask);
-                allTasksPerId.put(activeTask.id(), activeTask);
+                pendingActiveTasks.remove(activeTask.id());
                 for (final TopicPartition topicPartition : activeTask.inputPartitions()) {
                     activeTasksPerPartition.put(topicPartition, activeTask);
                 }
             }
         }
+    }
+
+    private void createStandbyTasks(final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate) {
+        for (final Map.Entry<TaskId, Set<TopicPartition>> taskToBeCreated : standbyTasksToCreate.entrySet()) {
+            final TaskId taskId = taskToBeCreated.getKey();
+
+            if (standbyTasksPerId.containsKey(taskId)) {
+                throw new IllegalStateException("Attempted to create an active task that we already own: " + taskId);
+            }
+
+            if (pendingActiveTasks.containsKey(taskId)) {
+                throw new IllegalStateException("Attempted to create an active task while we already own its standby: " + taskId);
+            }
+        }
 
-        // keep this check to simplify testing (ie, no need to mock `standbyTaskCreator`)
         if (!standbyTasksToCreate.isEmpty()) {
-            // TODO: change type to `StandbyTask`
             for (final Task standbyTask : standbyTaskCreator.createTasks(standbyTasksToCreate)) {
                 standbyTasksPerId.put(standbyTask.id(), standbyTask);
-                allTasksPerId.put(standbyTask.id(), standbyTask);
+                pendingStandbyTasks.remove(standbyTask.id());
             }
         }
     }
 
+    void removeTask(final Task taskToRemove) {
+        final TaskId taskId = taskToRemove.id();
+
+        if (taskToRemove.state() != Task.State.CLOSED) {
+            throw new IllegalStateException("Attempted to remove a task that is not closed: " + taskId);
+        }
+
+        if (taskToRemove.isActive()) {
+            if (activeTasksPerId.remove(taskId) == null) {
+                throw new IllegalArgumentException("Attempted to remove an active task that is not owned: " + taskId);
+            }
+            activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId);
+            removePartitionsForActiveTask(taskId);
+            pendingActiveTasks.remove(taskId);
+        } else {
+            if (standbyTasksPerId.remove(taskId) == null) {
+                throw new IllegalArgumentException("Attempted to remove a standby task that is not owned: " + taskId);
+            }
+            pendingStandbyTasks.remove(taskId);
+        }
+    }
+
     void convertActiveToStandby(final StreamTask activeTask,
                                 final Set<TopicPartition> partitions,
                                 final Map<TaskId, RuntimeException> taskCloseExceptions) {
-        if (activeTasksPerId.remove(activeTask.id()) == null) {
-            throw new IllegalStateException("Attempted to convert unknown active task to standby task: " + activeTask.id());
+        final TaskId taskId = activeTask.id();
+        if (activeTasksPerId.remove(taskId) == null) {
+            throw new IllegalStateException("Attempted to convert unknown active task to standby task: " + taskId);
         }
-        final Set<TopicPartition> toBeRemoved = activeTasksPerPartition.entrySet().stream()
-            .filter(e -> e.getValue().id().equals(activeTask.id()))
-            .map(Map.Entry::getKey)
-            .collect(Collectors.toSet());
-        toBeRemoved.forEach(activeTasksPerPartition::remove);
+        removePartitionsForActiveTask(taskId);
 
-        cleanUpTaskProducerAndRemoveTask(activeTask.id(), taskCloseExceptions);
+        try {
+            activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId);
+        } catch (final RuntimeException e) {
+            taskCloseExceptions.putIfAbsent(taskId, e);
+        }
 
         final StandbyTask standbyTask = standbyTaskCreator.createStandbyTaskFromActive(activeTask, partitions);
         standbyTasksPerId.put(standbyTask.id(), standbyTask);
-        allTasksPerId.put(standbyTask.id(), standbyTask);
     }
 
-    void convertStandbyToActive(final StandbyTask standbyTask, final Set<TopicPartition> partitions) {
-        if (standbyTasksPerId.remove(standbyTask.id()) == null) {
-            throw new IllegalStateException("Attempted to convert unknown standby task to stream task: " + standbyTask.id());
+    void convertStandbyToActive(final StandbyTask standbyTask,
+                                final Set<TopicPartition> partitions) {
+        final TaskId taskId = standbyTask.id();
+        if (standbyTasksPerId.remove(taskId) == null) {
+            throw new IllegalStateException("Attempted to convert unknown standby task to stream task: " + taskId);
         }
 
         final StreamTask activeTask = activeTaskCreator.createActiveTaskFromStandby(standbyTask, partitions, mainConsumer);
@@ -169,36 +219,23 @@ class Tasks {
         for (final TopicPartition topicPartition : activeTask.inputPartitions()) {
             activeTasksPerPartition.put(topicPartition, activeTask);
         }
-        allTasksPerId.put(activeTask.id(), activeTask);
     }
 
-    void updateInputPartitionsAndResume(final Task task, final Set<TopicPartition> topicPartitions) {
+    boolean updateActiveTaskInputPartitions(final Task task, final Set<TopicPartition> topicPartitions) {
         final boolean requiresUpdate = !task.inputPartitions().equals(topicPartitions);
         if (requiresUpdate) {
             log.debug("Update task {} inputPartitions: current {}, new {}", task, task.inputPartitions(), topicPartitions);
-            for (final TopicPartition inputPartition : task.inputPartitions()) {
-                activeTasksPerPartition.remove(inputPartition);
-            }
             if (task.isActive()) {
+                for (final TopicPartition inputPartition : task.inputPartitions()) {
+                    activeTasksPerPartition.remove(inputPartition);
+                }
                 for (final TopicPartition topicPartition : topicPartitions) {
                     activeTasksPerPartition.put(topicPartition, task);
                 }
             }
-            task.updateInputPartitions(topicPartitions, topologyMetadata.nodeToSourceTopics(task.id()));
         }
-        task.resume();
-    }
 
-    void cleanUpTaskProducerAndRemoveTask(final TaskId taskId,
-                                          final Map<TaskId, RuntimeException> taskCloseExceptions) {
-        try {
-            activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId);
-        } catch (final RuntimeException e) {
-            final String uncleanMessage = String.format("Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:", taskId);
-            log.error(uncleanMessage, e);
-            taskCloseExceptions.putIfAbsent(taskId, e);
-        }
-        removeTaskBeforeClosing(taskId);
+        return requiresUpdate;
     }
 
     void reInitializeThreadProducer() {
@@ -209,27 +246,18 @@ class Tasks {
         activeTaskCreator.closeThreadProducerIfNeeded();
     }
 
-    // TODO: change type to `StreamTask`
-    void closeAndRemoveTaskProducerIfNeeded(final Task activeTask) {
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(activeTask.id());
-    }
-
-    void removeTaskBeforeClosing(final TaskId taskId) {
-        activeTasksPerId.remove(taskId);
+    private void removePartitionsForActiveTask(final TaskId taskId) {
         final Set<TopicPartition> toBeRemoved = activeTasksPerPartition.entrySet().stream()
             .filter(e -> e.getValue().id().equals(taskId))
             .map(Map.Entry::getKey)
             .collect(Collectors.toSet());
         toBeRemoved.forEach(activeTasksPerPartition::remove);
-        standbyTasksPerId.remove(taskId);
-        allTasksPerId.remove(taskId);
     }
 
     void clear() {
         activeTasksPerId.clear();
-        activeTasksPerPartition.clear();
         standbyTasksPerId.clear();
-        allTasksPerId.clear();
+        activeTasksPerPartition.clear();
     }
 
     // TODO: change return type to `StreamTask`
@@ -237,19 +265,23 @@ class Tasks {
         return activeTasksPerPartition.get(partition);
     }
 
-    // TODO: change return type to `StandbyTask`
-    Task standbyTask(final TaskId taskId) {
-        if (!standbyTasksPerId.containsKey(taskId)) {
-            throw new IllegalStateException("Standby task unknown: " + taskId);
+    private Task getTask(final TaskId taskId) {
+        if (activeTasksPerId.containsKey(taskId)) {
+            return activeTasksPerId.get(taskId);
+        }
+        if (standbyTasksPerId.containsKey(taskId)) {
+            return standbyTasksPerId.get(taskId);
         }
-        return standbyTasksPerId.get(taskId);
+        return null;
     }
 
     Task task(final TaskId taskId) {
-        if (!allTasksPerId.containsKey(taskId)) {
+        final Task task = getTask(taskId);
+
+        if (task == null) {
             throw new IllegalStateException("Task unknown: " + taskId);
         }
-        return allTasksPerId.get(taskId);
+        return task;
     }
 
     Collection<Task> tasks(final Collection<TaskId> taskIds) {
@@ -262,51 +294,30 @@ class Tasks {
 
     // TODO: change return type to `StreamTask`
     Collection<Task> activeTasks() {
-        return readOnlyActiveTasks;
-    }
-
-    Collection<Task> allTasks() {
-        return readOnlyTasks;
-    }
-
-    Collection<Task> notPausedActiveTasks() {
-        return new ArrayList<>(readOnlyActiveTasks)
-            .stream()
-            .filter(t -> !topologyMetadata.isPaused(t.id().topologyName()))
-            .collect(Collectors.toList());
+        return Collections.unmodifiableCollection(activeTasksPerId.values());
     }
 
-    Collection<Task> notPausedTasks() {
-        return new ArrayList<>(readOnlyTasks)
-            .stream()
-            .filter(t -> !topologyMetadata.isPaused(t.id().topologyName()))
-            .collect(Collectors.toList());
+    /**
+     * All tasks returned by any of the getters are read-only and should NOT be modified;
+     * note, however, that a returned task may still be mutated concurrently by other threads.
+     */
+    Set<Task> allTasks() {
+        return union(HashSet::new, new HashSet<>(activeTasksPerId.values()), new HashSet<>(standbyTasksPerId.values()));
     }
 
-    Set<TaskId> activeTaskIds() {
-        return readOnlyActiveTaskIds;
+    Set<TaskId> allTaskIds() {
+        return union(HashSet::new, activeTasksPerId.keySet(), standbyTasksPerId.keySet());
     }
 
-    Set<TaskId> standbyTaskIds() {
-        return readOnlyStandbyTaskIds;
-    }
-
-    // TODO: change return type to `StreamTask`
-    Map<TaskId, Task> activeTaskMap() {
-        return readOnlyActiveTasksPerId;
-    }
-
-    // TODO: change return type to `StandbyTask`
-    Map<TaskId, Task> standbyTaskMap() {
-        return readOnlyStandbyTasksPerId;
-    }
-
-    Map<TaskId, Task> tasksPerId() {
-        return readOnlyTasksPerId;
+    Map<TaskId, Task> allTasksPerId() {
+        final Map<TaskId, Task> ret = new HashMap<>();
+        ret.putAll(activeTasksPerId);
+        ret.putAll(standbyTasksPerId);
+        return ret;
     }
 
     boolean owned(final TaskId taskId) {
-        return allTasksPerId.containsKey(taskId);
+        return getTask(taskId) != null;
     }
 
     StreamsProducer streamsProducerForTask(final TaskId taskId) {
@@ -337,7 +348,7 @@ class Tasks {
         successfullyProcessed.add(task);
     }
 
-    void removeTaskFromCuccessfullyProcessedBeforeClosing(final Task task) {
+    void removeTaskFromSuccessfullyProcessedBeforeClosing(final Task task) {
         successfullyProcessed.remove(task);
     }
 
@@ -345,6 +356,10 @@ class Tasks {
         successfullyProcessed.clear();
     }
 
+    double totalProducerBlockedTime() {
+        return activeTaskCreator.totalProducerBlockedTime();
+    }
+
     // for testing only
     void addTask(final Task task) {
         if (task.isActive()) {
@@ -352,6 +367,5 @@ class Tasks {
         } else {
             standbyTasksPerId.put(task.id(), task);
         }
-        allTasksPerId.put(task.id(), task);
     }
 }
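
Taken together, the new pending-task bookkeeping above is meant to be driven from the TaskManager roughly as follows; the Tasks methods are the ones added in this diff, while the surrounding variable names are assumed for illustration:

    // On a new assignment: drop pending tasks that were revoked, stash tasks that
    // cannot be created yet (e.g. unknown NamedTopology), and create the rest.
    tasks.purgePendingTasks(assignedActiveTaskIds, assignedStandbyTaskIds);
    tasks.addActivePendingTasks(unknownActiveTasks);   // Map<TaskId, Set<TopicPartition>>
    tasks.addStandbyPendingTasks(unknownStandbyTasks);
    tasks.createTasks(creatableActiveTasks, creatableStandbyTasks);

    // Later, once the missing NamedTopology has been registered, the stashed
    // tasks for it can finally be created:
    tasks.createPendingTasks(currentNamedTopologies);  // Set<String> of topology names
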
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index a43b0793a2..2c096789f2 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -754,7 +754,6 @@ public class StreamThreadTest {
             null,
             null,
             null,
-            null,
             topologyMetadata,
             null,
             null
@@ -839,12 +838,8 @@ public class StreamThreadTest {
         final ActiveTaskCreator activeTaskCreator = mock(ActiveTaskCreator.class);
         expect(activeTaskCreator.createTasks(anyObject(), anyObject())).andStubReturn(Collections.singleton(task));
         expect(activeTaskCreator.producerClientIds()).andStubReturn(Collections.singleton("producerClientId"));
-        expect(activeTaskCreator.uncreatedTasksForTopologies(anyObject())).andStubReturn(emptyMap());
-        activeTaskCreator.removeRevokedUnknownTasks(singleton(task1));
 
         final StandbyTaskCreator standbyTaskCreator = mock(StandbyTaskCreator.class);
-        expect(standbyTaskCreator.uncreatedTasksForTopologies(anyObject())).andStubReturn(emptyMap());
-        standbyTaskCreator.removeRevokedUnknownTasks(emptySet());
 
         EasyMock.replay(consumer, consumerGroupMetadata, task, activeTaskCreator, standbyTaskCreator);
 
@@ -858,7 +853,6 @@ public class StreamThreadTest {
             null,
             null,
             null,
-            null,
             activeTaskCreator,
             standbyTaskCreator,
             topologyMetadata,
@@ -1047,19 +1041,11 @@ public class StreamThreadTest {
 
         final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(configProps(true)), true);
 
-        thread.start();
-        TestUtils.waitForCondition(
-            () -> thread.state() == StreamThread.State.STARTING,
-            10 * 1000,
-            "Thread never started.");
-
-        thread.rebalanceListener().onPartitionsRevoked(Collections.emptyList());
         thread.taskManager().handleRebalanceStart(Collections.singleton(topic1));
 
+        // assign single partition
         final Map<TaskId, Set<TopicPartition>> activeTasks = new HashMap<>();
         final List<TopicPartition> assignedPartitions = new ArrayList<>();
-
-        // assign single partition
         assignedPartitions.add(t1p1);
         assignedPartitions.add(t1p2);
         activeTasks.put(task1, Collections.singleton(t1p1));
@@ -1067,11 +1053,18 @@ public class StreamThreadTest {
 
         thread.taskManager().handleAssignment(activeTasks, emptyMap());
 
+        thread.start();
+        TestUtils.waitForCondition(
+                () -> thread.state() == StreamThread.State.STARTING,
+                10 * 1000,
+                "Thread never started.");
+
         thread.shutdown();
 
         // even if thread is no longer running, it should still be polling
         // as long as the rebalance is still ongoing
         assertFalse(thread.isRunning());
+        assertTrue(thread.isAlive());
 
         Thread.sleep(1000);
         assertEquals(Utils.mkSet(task1, task2), thread.taskManager().activeTaskIds());
@@ -2657,7 +2650,7 @@ public class StreamThreadTest {
         expect(task3.state()).andReturn(Task.State.CREATED).anyTimes();
         expect(task3.id()).andReturn(taskId3).anyTimes();
 
-        expect(taskManager.tasks()).andReturn(mkMap(
+        expect(taskManager.allTasks()).andReturn(mkMap(
             mkEntry(taskId1, task1),
             mkEntry(taskId2, task2),
             mkEntry(taskId3, task3)
@@ -2922,7 +2915,7 @@ public class StreamThreadTest {
 
         expect(runningTask.state()).andReturn(Task.State.RUNNING).anyTimes();
         expect(runningTask.id()).andReturn(taskId).anyTimes();
-        expect(taskManager.tasks())
+        expect(taskManager.allTasks())
                 .andReturn(Collections.singletonMap(taskId, runningTask)).anyTimes();
         expect(taskManager.commit(Collections.singleton(runningTask))).andReturn(1).anyTimes();
         taskManager.maybePurgeCommittedRecords();
@@ -2940,7 +2933,7 @@ public class StreamThreadTest {
 
         expect(runningTask.state()).andReturn(Task.State.RUNNING).anyTimes();
         expect(runningTask.id()).andReturn(taskId).anyTimes();
-        expect(taskManager.tasks())
+        expect(taskManager.allTasks())
             .andReturn(Collections.singletonMap(taskId, runningTask)).times(numberOfCommits);
         expect(taskManager.commit(Collections.singleton(runningTask))).andReturn(commits).times(numberOfCommits);
         EasyMock.replay(taskManager, runningTask);
@@ -2997,7 +2990,7 @@ public class StreamThreadTest {
     }
 
     StreamTask activeTask(final TaskManager taskManager, final TopicPartition partition) {
-        final Stream<Task> standbys = taskManager.tasks().values().stream().filter(t -> t.isActive());
+        final Stream<Task> standbys = taskManager.allTasks().values().stream().filter(Task::isActive);
         for (final Task task : (Iterable<Task>) standbys::iterator) {
             if (task.inputPartitions().contains(partition)) {
                 return (StreamTask) task;
@@ -3006,7 +2999,7 @@ public class StreamThreadTest {
         return null;
     }
     StandbyTask standbyTask(final TaskManager taskManager, final TopicPartition partition) {
-        final Stream<Task> standbys = taskManager.tasks().values().stream().filter(t -> !t.isActive());
+        final Stream<Task> standbys = taskManager.allTasks().values().stream().filter(t -> !t.isActive());
         for (final Task task : (Iterable<Task>) standbys::iterator) {
             if (task.inputPartitions().contains(partition)) {
                 return (StandbyTask) task;
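
One naming subtlety exercised by the test changes above: TaskManager#allTasks() still exposes a Map<TaskId, Task> (hence the .values() calls), while the new Tasks#allTasks() returns a Set<Task>. The presumed delegation, assuming TaskManager builds its map view from Tasks:

    // Presumed shape of the TaskManager accessor; Tasks#allTasksPerId() is the
    // map-building method added in this diff.
    Map<TaskId, Task> allTasks() {
        return tasks.allTasksPerId();
    }
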
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutorTest.java
index a44970238a..88ee70e57e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskExecutorTest.java
@@ -22,7 +22,6 @@ import org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode;
 import org.junit.Test;
 
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
 
 public class TaskExecutorTest {
@@ -35,7 +34,6 @@ public class TaskExecutorTest {
             new TaskExecutor(tasks, metadata, ProcessingMode.AT_LEAST_ONCE, false, new LogContext());
 
         taskExecutor.punctuate();
-        verify(tasks).notPausedActiveTasks();
-        verify(tasks, never()).notPausedTasks();
+        verify(tasks).activeTasks();
     }
 }
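
With notPausedActiveTasks() gone from Tasks, punctuation is expected to iterate the plain active-task view, with pause filtering handled before tasks ever reach the executor. A sketch of the presumed loop, using the existing Task#maybePunctuate* methods:

    // Approximate shape of TaskExecutor#punctuate() after this change;
    // exact error handling omitted.
    int punctuate() {
        int punctuated = 0;
        for (final Task task : tasks.activeTasks()) {
            if (task.maybePunctuateStreamTime()) {
                punctuated++;
            }
            if (task.maybePunctuateSystemTime()) {
                punctuated++;
            }
        }
        return punctuated;
    }
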
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 073dede23f..b3ffb29b1a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -33,10 +33,9 @@ import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.internals.KafkaFutureImpl;
 import org.apache.kafka.common.metrics.KafkaMetric;
 import org.apache.kafka.common.metrics.Measurable;
-import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
-import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyConfig;
 import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
@@ -47,7 +46,6 @@ import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory;
 import org.apache.kafka.streams.processor.internals.Task.State;
-import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.testutil.DummyStreamsConfig;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
@@ -93,6 +91,7 @@ import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.common.utils.Utils.union;
+import static org.apache.kafka.streams.processor.internals.TopologyMetadata.UNNAMED_TOPOLOGY;
 import static org.easymock.EasyMock.anyObject;
 import static org.easymock.EasyMock.anyString;
 import static org.easymock.EasyMock.eq;
@@ -172,6 +171,7 @@ public class TaskManagerTest {
     private Admin adminClient;
 
     private TaskManager taskManager;
+    private TopologyMetadata topologyMetadata;
     private final Time time = new MockTime();
 
     @Rule
@@ -183,25 +183,22 @@ public class TaskManagerTest {
     }
 
     private void setUpTaskManager(final StreamsConfigUtils.ProcessingMode processingMode) {
+        topologyMetadata = new TopologyMetadata(topologyBuilder, new DummyStreamsConfig(processingMode));
         taskManager = new TaskManager(
             time,
             changeLogReader,
             UUID.randomUUID(),
             "taskManagerTest",
-            new StreamsMetricsImpl(new Metrics(), "clientId", StreamsConfig.METRICS_LATEST, time),
             activeTaskCreator,
             standbyTaskCreator,
-            new TopologyMetadata(topologyBuilder, new DummyStreamsConfig(processingMode)),
+            topologyMetadata,
             adminClient,
             stateDirectory
         );
         taskManager.setMainConsumer(consumer);
         reset(topologyBuilder);
         expect(topologyBuilder.hasNamedTopology()).andStubReturn(false);
-        activeTaskCreator.removeRevokedUnknownTasks(anyObject());
-        expectLastCall().asStub();
-        standbyTaskCreator.removeRevokedUnknownTasks(anyObject());
-        expectLastCall().asStub();
+        expect(topologyBuilder.nodeToSourceTopics()).andStubReturn(emptyMap());
     }
 
     @Test
@@ -1192,7 +1189,9 @@ public class TaskManagerTest {
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
         expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andReturn(singletonList(task01));
-        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
+        topologyBuilder.addSubscribedTopicsFromAssignment(eq(asList(t1p0)), anyString());
+        expectLastCall().anyTimes();
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader, topologyBuilder);
 
         taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -1518,8 +1517,10 @@ public class TaskManagerTest {
 
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))).andReturn(singleton(task00));
         expect(standbyTaskCreator.createTasks(eq(assignmentStandby))).andReturn(singletonList(task10));
+        topologyBuilder.addSubscribedTopicsFromAssignment(eq(asList(t1p0)), anyString());
+        expectLastCall().anyTimes();
 
-        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader, topologyBuilder);
 
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
@@ -1665,14 +1666,8 @@ public class TaskManagerTest {
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment)))
             .andStubReturn(asList(task00, task01, task02, task03));
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
-        expectLastCall();
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId01));
-        expectLastCall();
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId02));
-        expectLastCall();
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId03));
-        expectLastCall();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(anyObject());
+        expectLastCall().times(4);
         activeTaskCreator.closeThreadProducerIfNeeded();
         expectLastCall();
         expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andStubReturn(emptyList());
@@ -1857,7 +1852,7 @@ public class TaskManagerTest {
         assertThat(task01.state(), is(Task.State.CLOSED));
 
         // All the tasks involving in the commit should already be removed.
-        assertThat(taskManager.tasks(), is(Collections.singletonMap(taskId00, task00)));
+        assertThat(taskManager.allTasks(), is(Collections.singletonMap(taskId00, task00)));
     }
 
     @Test
@@ -1923,12 +1918,8 @@ public class TaskManagerTest {
         resetToStrict(changeLogReader);
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(asList(task00, task01, task02));
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
-        expectLastCall().andThrow(new RuntimeException("whatever 0"));
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId01));
-        expectLastCall().andThrow(new RuntimeException("whatever 1"));
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId02));
-        expectLastCall().andThrow(new RuntimeException("whatever 2"));
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(anyObject());
+        expectLastCall().andThrow(new RuntimeException("whatever")).times(3);
         activeTaskCreator.closeThreadProducerIfNeeded();
         expectLastCall().andThrow(new RuntimeException("whatever all"));
         expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andStubReturn(emptyList());
@@ -3204,6 +3195,8 @@ public class TaskManagerTest {
             .andReturn(singletonList(activeTask));
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
         expectLastCall().anyTimes();
+        activeTask.closeCleanAndRecycleState();
+        expectLastCall().once();
 
         expect(standbyTaskCreator.createStandbyTaskFromActive(anyObject(), eq(taskId00Partitions)))
             .andReturn(standbyTask);
@@ -3226,6 +3219,8 @@ public class TaskManagerTest {
         expectLastCall().anyTimes();
         standbyTask.postCommit(true);
         expectLastCall().anyTimes();
+        standbyTask.closeCleanAndRecycleState();
+        expectLastCall().once();
 
         final StreamTask activeTask = mock(StreamTask.class);
         expect(activeTask.id()).andStubReturn(taskId00);
@@ -3245,6 +3240,17 @@ public class TaskManagerTest {
         verify(standbyTaskCreator, activeTaskCreator);
     }
 
+    @Test
+    public void shouldListNotPausedTasks() {
+        handleAssignment(taskId00Assignment, taskId01Assignment, emptyMap());
+
+        assertEquals(2, taskManager.notPausedTasks().size());
+
+        topologyMetadata.pauseTopology(UNNAMED_TOPOLOGY);
+
+        assertEquals(0, taskManager.notPausedTasks().size());
+    }
+
     private static void expectRestoreToBeCompleted(final Consumer<byte[], byte[]> consumer,
                                                    final ChangelogReader changeLogReader) {
         expectRestoreToBeCompleted(consumer, changeLogReader, true);
@@ -3315,7 +3321,7 @@ public class TaskManagerTest {
                          final Set<TopicPartition> partitions,
                          final boolean active,
                          final ProcessorStateManager processorStateManager) {
-            super(id, null, null, processorStateManager, partitions, 0L, "test-task", StateMachineTask.class);
+            super(id, null, null, processorStateManager, partitions, (new TopologyConfig(new DummyStreamsConfig())).getTaskConfig(), "test-task", StateMachineTask.class);
             this.active = active;
         }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
deleted file mode 100644
index ad701f8ca4..0000000000
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.kafka.streams.processor.internals;
-
-import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.streams.processor.TaskId;
-import org.junit.Assert;
-import org.junit.Test;
-
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-public class TasksTest {
-    @Test
-    public void testNotPausedTasks() {
-        final TopologyMetadata topologyMetadata = mock(TopologyMetadata.class);
-        final String unnamedTopologyName = null;
-        when(topologyMetadata.isPaused(unnamedTopologyName))
-            .thenReturn(false)
-            .thenReturn(false).thenReturn(false)
-            .thenReturn(true)
-            .thenReturn(true).thenReturn(true);
-
-        final Tasks tasks = new Tasks(
-            new LogContext(),
-            topologyMetadata,
-            mock(ActiveTaskCreator.class),
-            mock(StandbyTaskCreator.class)
-        );
-
-        final TaskId taskId1 = new TaskId(0, 1);
-        final TaskId taskId2 = new TaskId(0, 2);
-
-        final StreamTask streamTask = mock(StreamTask.class);
-        when(streamTask.isActive()).thenReturn(true);
-        when(streamTask.id()).thenReturn(taskId1);
-
-        final StandbyTask standbyTask1 = mock(StandbyTask.class);
-        when(standbyTask1.isActive()).thenReturn(false);
-        when(standbyTask1.id()).thenReturn(taskId2);
-
-        tasks.addTask(streamTask);
-        tasks.addTask(standbyTask1);
-        Assert.assertEquals(tasks.notPausedActiveTasks().size(), 1);
-        Assert.assertEquals(tasks.notPausedTasks().size(), 2);
-
-        Assert.assertEquals(tasks.notPausedActiveTasks().size(), 0);
-        Assert.assertEquals(tasks.notPausedTasks().size(), 0);
-    }
-}
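
The deleted TasksTest coverage is not lost: pause filtering now lives in TaskManager (Tasks no longer holds a TopologyMetadata reference), and is covered by the new TaskManagerTest#shouldListNotPausedTasks above. The presumed shape of the relocated filter:

    // Assumed TaskManager implementation; topologyMetadata is the field the
    // refactored tests now construct explicitly.
    Collection<Task> notPausedTasks() {
        return tasks.allTasks()
            .stream()
            .filter(t -> !topologyMetadata.isPaused(t.id().topologyName()))
            .collect(Collectors.toList());
    }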