You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by vv...@apache.org on 2020/03/05 20:21:23 UTC

[kafka] branch trunk updated: KAFKA-9615: Clean up task/producer create and close (#8213)

This is an automated email from the ASF dual-hosted git repository.

vvcephei pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 78374a1  KAFKA-9615: Clean up task/producer create and close (#8213)
78374a1 is described below

commit 78374a15492cfb6df49353bd166d8c45ac9abdb2
Author: John Roesler <vv...@users.noreply.github.com>
AuthorDate: Thu Mar 5 14:20:46 2020 -0600

    KAFKA-9615: Clean up task/producer create and close (#8213)
    
    * Consolidate task/producer management. Now, exactly one component manages
      the creation and destruction of Producers, whether they are per-thread or per-task.
    * Add missing test coverage on TaskManagerTest
    
    Reviewers: Guozhang Wang <wa...@gmail.com>, Boyang Chen <bo...@confluent.io>
---
 checkstyle/suppressions.xml                        |   2 +
 .../processor/internals/ActiveTaskCreator.java     | 225 +++++++
 .../processor/internals/RecordCollector.java       |   2 +-
 .../processor/internals/RecordCollectorImpl.java   |   1 -
 .../processor/internals/StandbyTaskCreator.java    | 113 ++++
 .../streams/processor/internals/StreamThread.java  | 292 +--------
 .../processor/internals/StreamsProducer.java       | 121 ++--
 .../streams/processor/internals/TaskManager.java   | 111 ++--
 .../processor/internals/RecordCollectorTest.java   |  64 +-
 .../processor/internals/StreamThreadTest.java      | 177 +++---
 .../processor/internals/StreamsProducerTest.java   | 281 ++++-----
 .../processor/internals/TaskManagerTest.java       | 673 +++++++++++++++++++--
 .../streams/state/KeyValueStoreTestDriver.java     |   2 +-
 .../StreamThreadStateStoreProviderTest.java        |   8 +-
 .../apache/kafka/streams/TopologyTestDriver.java   |   6 +-
 15 files changed, 1362 insertions(+), 716 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index e21c115..f533179 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -174,6 +174,8 @@
               files="StreamsPartitionAssignor.java"/>
     <suppress checks="CyclomaticComplexity"
               files="StreamThread.java"/>
+    <suppress checks="CyclomaticComplexity"
+              files="TaskManager.java"/>
 
     <suppress checks="JavaNCSS"
               files="StreamsPartitionAssignor.java"/>
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
new file mode 100644
index 0000000..43ae0d4
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.clients.producer.ProducerConfig;
+import org.apache.kafka.common.Metric;
+import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.KafkaClientSupplier;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
+import org.apache.kafka.streams.state.internals.ThreadCache;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE;
+
+/**
+ * Creates the active {@link StreamTask}s of one stream thread and owns the producer clients they write through.
+ * Under exactly-once one transactional producer is created per task and tracked in {@code taskProducers};
+ * otherwise a single thread producer is created in the constructor and shared by all tasks.
+ */
+class ActiveTaskCreator {
+    private final String applicationId;
+    private final InternalTopologyBuilder builder;
+    private final StreamsConfig config;
+    private final StreamsMetricsImpl streamsMetrics;
+    private final StateDirectory stateDirectory;
+    private final ChangelogReader storeChangelogReader;
+    private final Time time;
+    private final Logger log;
+    private final String threadId;
+    private final ThreadCache cache;
+    // processing guarantee, resolved once; decides between the thread producer and per-task producers
+    private final boolean eosEnabled;
+    // null iff eosEnabled
+    private final Producer<byte[], byte[]> threadProducer;
+    private final KafkaClientSupplier clientSupplier;
+    // mutable registry of per-task producers under EOS; an empty immutable map otherwise
+    private final Map<TaskId, Producer<byte[], byte[]>> taskProducers;
+    private final Sensor createTaskSensor;
+
+    private static String getThreadProducerClientId(final String threadClientId) {
+        return threadClientId + "-producer";
+    }
+
+    private static String getTaskProducerClientId(final String threadClientId, final TaskId taskId) {
+        return threadClientId + "-" + taskId + "-producer";
+    }
+
+    ActiveTaskCreator(final InternalTopologyBuilder builder,
+                      final StreamsConfig config,
+                      final StreamsMetricsImpl streamsMetrics,
+                      final StateDirectory stateDirectory,
+                      final ChangelogReader storeChangelogReader,
+                      final ThreadCache cache,
+                      final Time time,
+                      final KafkaClientSupplier clientSupplier,
+                      final String threadId,
+                      final Logger log) {
+        applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG);
+        this.builder = builder;
+        this.config = config;
+        this.streamsMetrics = streamsMetrics;
+        this.stateDirectory = stateDirectory;
+        this.storeChangelogReader = storeChangelogReader;
+        this.time = time;
+        this.log = log;
+        this.cache = cache;
+        this.threadId = threadId;
+        this.clientSupplier = clientSupplier;
+
+        eosEnabled = EXACTLY_ONCE.equals(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG));
+        if (eosEnabled) {
+            threadProducer = null;
+            taskProducers = new HashMap<>();
+        } else {
+            final String threadProducerClientId = getThreadProducerClientId(threadId);
+            final Map<String, Object> producerConfigs = config.getProducerConfigs(threadProducerClientId);
+            log.info("Creating thread producer client");
+            threadProducer = clientSupplier.getProducer(producerConfigs);
+            taskProducers = Collections.emptyMap();
+        }
+
+        createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics);
+    }
+
+    /**
+     * Builds one active task per entry of {@code tasksToBeCreated}. Under exactly-once this also creates the
+     * task's transactional producer ({@code transactional.id = <applicationId>-<taskId>}) and registers it.
+     *
+     * @param consumer         the main consumer, passed to each task and its record collector
+     * @param tasksToBeCreated task ids mapped to their assigned input partitions
+     * @return the created tasks, in the iteration order of {@code tasksToBeCreated}
+     */
+    Collection<Task> createTasks(final Consumer<byte[], byte[]> consumer,
+                                 final Map<TaskId, Set<TopicPartition>> tasksToBeCreated) {
+        final List<Task> createdTasks = new ArrayList<>();
+        for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions : tasksToBeCreated.entrySet()) {
+            final TaskId taskId = newTaskAndPartitions.getKey();
+            final Set<TopicPartition> partitions = newTaskAndPartitions.getValue();
+
+            final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName());
+            final String logPrefix = threadIdPrefix + String.format("%s [%s] ", "task", taskId);
+            final LogContext logContext = new LogContext(logPrefix);
+
+            final ProcessorTopology topology = builder.buildSubtopology(taskId.topicGroupId);
+
+            final ProcessorStateManager stateManager = new ProcessorStateManager(
+                taskId,
+                partitions,
+                Task.TaskType.ACTIVE,
+                stateDirectory,
+                topology.storeToChangelogTopic(),
+                storeChangelogReader,
+                logContext
+            );
+
+            final StreamsProducer streamsProducer;
+            if (eosEnabled) {
+                final String taskProducerClientId = getTaskProducerClientId(threadId, taskId);
+                final Map<String, Object> producerConfigs = config.getProducerConfigs(taskProducerClientId);
+                producerConfigs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, applicationId + "-" + taskId);
+                log.info("Creating producer client for task {}", taskId);
+                final Producer<byte[], byte[]> taskProducer = clientSupplier.getProducer(producerConfigs);
+                taskProducers.put(taskId, taskProducer);
+                streamsProducer = new StreamsProducer(taskProducer, true, logContext, applicationId);
+            } else {
+                streamsProducer = new StreamsProducer(threadProducer, false, logContext, applicationId);
+            }
+
+            final RecordCollector recordCollector = new RecordCollectorImpl(
+                logContext,
+                taskId,
+                consumer,
+                streamsProducer,
+                config.defaultProductionExceptionHandler(),
+                eosEnabled,
+                streamsMetrics
+            );
+
+            final Task task = new StreamTask(
+                taskId,
+                partitions,
+                topology,
+                consumer,
+                config,
+                streamsMetrics,
+                stateDirectory,
+                cache,
+                time,
+                stateManager,
+                recordCollector
+            );
+
+            log.trace("Created task {} with assigned partitions {}", taskId, partitions);
+            createdTasks.add(task);
+            createTaskSensor.record();
+        }
+        return createdTasks;
+    }
+
+    /**
+     * Closes the shared thread producer; a no-op under exactly-once, where none exists.
+     *
+     * @throws StreamsException wrapping any {@link RuntimeException} thrown while closing the producer
+     */
+    void closeThreadProducerIfNeeded() {
+        if (threadProducer != null) {
+            try {
+                threadProducer.close();
+            } catch (final RuntimeException e) {
+                throw new StreamsException("Thread Producer encounter unexpected error trying to close", e);
+            }
+        }
+    }
+
+    /**
+     * Removes the producer registered for the given task and closes it; a no-op if the task has none.
+     *
+     * @throws StreamsException wrapping any {@link RuntimeException} thrown while closing the producer
+     */
+    void closeAndRemoveTaskProducerIfNeeded(final TaskId id) {
+        final Producer<byte[], byte[]> producer = taskProducers.remove(id);
+        if (producer != null) {
+            try {
+                producer.close();
+            } catch (final RuntimeException e) {
+                throw new StreamsException("[" + id + "] Producer encounter unexpected error trying to close", e);
+            }
+        }
+    }
+
+    /**
+     * Merges the metrics of every producer owned by this creator: the thread producer, or under exactly-once
+     * the producers of all currently registered tasks. A producer returning {@code null} metrics is skipped.
+     */
+    Map<MetricName, Metric> producerMetrics() {
+        final Map<MetricName, Metric> result = new LinkedHashMap<>();
+        if (threadProducer != null) {
+            final Map<MetricName, ? extends Metric> producerMetrics = threadProducer.metrics();
+            if (producerMetrics != null) {
+                result.putAll(producerMetrics);
+            }
+        } else {
+            // exactly-once: no thread producer exists, so merge the metrics of each task's own producer
+            for (final Producer<byte[], byte[]> taskProducer : taskProducers.values()) {
+                final Map<MetricName, ? extends Metric> taskProducerMetrics = taskProducer.metrics();
+                if (taskProducerMetrics != null) {
+                    result.putAll(taskProducerMetrics);
+                }
+            }
+        }
+        return result;
+    }
+
+    /**
+     * Returns the client ids of the producers owned by this creator, consistent with {@link #producerMetrics()}.
+     */
+    Set<String> producerClientIds() {
+        if (threadProducer != null) {
+            return Collections.singleton(getThreadProducerClientId(threadId));
+        } else {
+            return taskProducers.keySet()
+                                .stream()
+                                .map(taskId -> getTaskProducerClientId(threadId, taskId))
+                                .collect(Collectors.toSet());
+        }
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
index 5e8a073..9594679 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
@@ -25,7 +25,7 @@ import org.apache.kafka.streams.processor.StreamPartitioner;
 
 import java.util.Map;
 
-public interface RecordCollector extends AutoCloseable {
+public interface RecordCollector {
 
     <K, V> void send(final String topic,
                      final K key,
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
index 4fef0f3..c5ac440 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
@@ -262,7 +262,6 @@ public class RecordCollectorImpl implements RecordCollector {
         if (eosEnabled) {
             streamsProducer.abortTransaction();
         }
-        streamsProducer.close();
 
         checkForException();
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
new file mode 100644
index 0000000..fbebe72
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Builds the {@link StandbyTask}s of one stream thread, skipping sub-topologies without changelogged state.
+ */
+class StandbyTaskCreator {
+    private final InternalTopologyBuilder builder;
+    private final StreamsConfig config;
+    private final StreamsMetricsImpl streamsMetrics;
+    private final StateDirectory stateDirectory;
+    private final ChangelogReader storeChangelogReader;
+    private final Logger log;
+    private final Sensor createTaskSensor;
+
+    StandbyTaskCreator(final InternalTopologyBuilder builder,
+                       final StreamsConfig config,
+                       final StreamsMetricsImpl streamsMetrics,
+                       final StateDirectory stateDirectory,
+                       final ChangelogReader storeChangelogReader,
+                       final String threadId,
+                       final Logger log) {
+        this.builder = builder;
+        this.config = config;
+        this.streamsMetrics = streamsMetrics;
+        this.stateDirectory = stateDirectory;
+        this.storeChangelogReader = storeChangelogReader;
+        this.log = log;
+        this.createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics);
+    }
+
+    /**
+     * Creates a standby task for each entry whose sub-topology has changelogged state stores.
+     *
+     * @param tasksToBeCreated task ids mapped to their assigned partitions
+     * @return the tasks that were created; entries without changelogged state are omitted
+     */
+    Collection<Task> createTasks(final Map<TaskId, Set<TopicPartition>> tasksToBeCreated) {
+        final List<Task> created = new ArrayList<>();
+        tasksToBeCreated.forEach((id, assigned) -> {
+            final LogContext logContext = new LogContext(
+                String.format("stream-thread [%s] ", Thread.currentThread().getName())
+                    + String.format("%s [%s] ", "standby-task", id)
+            );
+            final ProcessorTopology subTopology = builder.buildSubtopology(id.topicGroupId);
+
+            if (!subTopology.hasStateWithChangelogs()) {
+                log.trace(
+                    "Skipped standby task {} with assigned partitions {} " +
+                        "since it does not have any state stores to materialize",
+                    id, assigned
+                );
+                return; // acts as `continue` inside the forEach lambda
+            }
+
+            final ProcessorStateManager stateManager = new ProcessorStateManager(
+                id,
+                assigned,
+                Task.TaskType.STANDBY,
+                stateDirectory,
+                subTopology.storeToChangelogTopic(),
+                storeChangelogReader,
+                logContext
+            );
+            final StandbyTask standby =
+                new StandbyTask(id, assigned, subTopology, config, streamsMetrics, stateManager, stateDirectory);
+
+            log.trace("Created task {} with assigned partitions {}", id, assigned);
+            created.add(standby);
+            createTaskSensor.record();
+        });
+        return created;
+    }
+
+    /** @return the topology builder this creator builds sub-topologies from */
+    public InternalTopologyBuilder builder() {
+        return builder;
+    }
+
+    /** @return the state directory handed to every created standby task */
+    public StateDirectory stateDirectory() {
+        return stateDirectory;
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 4092825..1465110 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -22,8 +22,6 @@ import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.InvalidOffsetException;
-import org.apache.kafka.clients.producer.Producer;
-import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
@@ -49,11 +47,8 @@ import org.apache.kafka.streams.state.internals.ThreadCache;
 import org.slf4j.Logger;
 
 import java.time.Duration;
-import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -62,8 +57,6 @@ import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 
-import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE;
-
 public class StreamThread extends Thread {
 
     private final Admin adminClient;
@@ -247,235 +240,6 @@ public class StreamThread extends Thread {
         return assignmentErrorCode.get();
     }
 
-    static abstract class AbstractTaskCreator<T extends Task> {
-        final String applicationId;
-        final InternalTopologyBuilder builder;
-        final StreamsConfig config;
-        final StreamsMetricsImpl streamsMetrics;
-        final StateDirectory stateDirectory;
-        final ChangelogReader storeChangelogReader;
-        final Time time;
-        final Logger log;
-
-        AbstractTaskCreator(final InternalTopologyBuilder builder,
-                            final StreamsConfig config,
-                            final StreamsMetricsImpl streamsMetrics,
-                            final StateDirectory stateDirectory,
-                            final ChangelogReader storeChangelogReader,
-                            final Time time,
-                            final Logger log) {
-            this.applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG);
-            this.builder = builder;
-            this.config = config;
-            this.streamsMetrics = streamsMetrics;
-            this.stateDirectory = stateDirectory;
-            this.storeChangelogReader = storeChangelogReader;
-            this.time = time;
-            this.log = log;
-        }
-
-        public InternalTopologyBuilder builder() {
-            return builder;
-        }
-
-        public StateDirectory stateDirectory() {
-            return stateDirectory;
-        }
-
-        Collection<T> createTasks(final Consumer<byte[], byte[]> consumer,
-                                  final Map<TaskId, Set<TopicPartition>> tasksToBeCreated) {
-            final List<T> createdTasks = new ArrayList<>();
-            for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions : tasksToBeCreated.entrySet()) {
-                final TaskId taskId = newTaskAndPartitions.getKey();
-                final Set<TopicPartition> partitions = newTaskAndPartitions.getValue();
-                final T task = createTask(consumer, taskId, partitions);
-                if (task != null) {
-                    log.trace("Created task {} with assigned partitions {}", taskId, partitions);
-                    createdTasks.add(task);
-                }
-
-            }
-            return createdTasks;
-        }
-
-        abstract T createTask(final Consumer<byte[], byte[]> consumer, final TaskId id, final Set<TopicPartition> partitions);
-
-        void close() {}
-    }
-
-    static class TaskCreator extends AbstractTaskCreator<StreamTask> {
-        private final String threadId;
-        private final ThreadCache cache;
-        private final Producer<byte[], byte[]> threadProducer;
-        private final KafkaClientSupplier clientSupplier;
-        final Map<TaskId, Producer<byte[], byte[]>> taskProducers;
-        private final Sensor createTaskSensor;
-
-        TaskCreator(final InternalTopologyBuilder builder,
-                    final StreamsConfig config,
-                    final StreamsMetricsImpl streamsMetrics,
-                    final StateDirectory stateDirectory,
-                    final ChangelogReader storeChangelogReader,
-                    final ThreadCache cache,
-                    final Time time,
-                    final KafkaClientSupplier clientSupplier,
-                    final Map<TaskId, Producer<byte[], byte[]>> taskProducers,
-                    final String threadId,
-                    final Logger log) {
-            super(
-                builder,
-                config,
-                streamsMetrics,
-                stateDirectory,
-                storeChangelogReader,
-                time,
-                log);
-
-            final boolean eosEnabled = EXACTLY_ONCE.equals(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG));
-            if (!eosEnabled) {
-                final Map<String, Object> producerConfigs = config.getProducerConfigs(getThreadProducerClientId(threadId));
-                log.info("Creating thread producer client");
-                this.threadProducer = clientSupplier.getProducer(producerConfigs);
-            } else {
-                this.threadProducer = null;
-            }
-            this.taskProducers = taskProducers;
-
-            this.cache = cache;
-            this.threadId = threadId;
-            this.clientSupplier = clientSupplier;
-
-            this.createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics);
-        }
-
-        @Override
-        StreamTask createTask(final Consumer<byte[], byte[]> mainConsumer,
-                              final TaskId taskId,
-                              final Set<TopicPartition> partitions) {
-            createTaskSensor.record();
-
-            final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName());
-            final String logPrefix = threadIdPrefix + String.format("%s [%s] ", "task", taskId);
-            final LogContext logContext = new LogContext(logPrefix);
-
-            final ProcessorTopology topology = builder.buildSubtopology(taskId.topicGroupId);
-
-            final ProcessorStateManager stateManager = new ProcessorStateManager(
-                taskId,
-                partitions,
-                Task.TaskType.ACTIVE,
-                stateDirectory,
-                topology.storeToChangelogTopic(),
-                storeChangelogReader,
-                logContext);
-
-            if (threadProducer == null) {
-                // create one producer per task for EOS
-                // TODO: after KIP-447 this would be removed
-                final Map<String, Object> producerConfigs = config.getProducerConfigs(getTaskProducerClientId(threadId, taskId));
-                producerConfigs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, applicationId + "-" + taskId);
-                log.info("Creating producer client for task {}", taskId);
-                taskProducers.put(taskId, clientSupplier.getProducer(producerConfigs));
-            }
-            final RecordCollector recordCollector = new RecordCollectorImpl(
-                logContext,
-                taskId,
-                mainConsumer,
-                threadProducer != null ?
-                    new StreamsProducer(logContext, threadProducer) :
-                    new StreamsProducer(logContext, taskProducers.get(taskId), applicationId, taskId),
-                config.defaultProductionExceptionHandler(),
-                EXACTLY_ONCE.equals(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG)),
-                streamsMetrics);
-
-            return new StreamTask(
-                taskId,
-                partitions,
-                topology,
-                mainConsumer,
-                config,
-                streamsMetrics,
-                stateDirectory,
-                cache,
-                time,
-                stateManager,
-                recordCollector);
-        }
-
-        public void close() {
-            if (threadProducer != null) {
-                try {
-                    threadProducer.close();
-                } catch (final Throwable e) {
-                    log.error("Failed to close producer due to the following error:", e);
-                }
-            }
-        }
-    }
-
-    static class StandbyTaskCreator extends AbstractTaskCreator<StandbyTask> {
-        private final Sensor createTaskSensor;
-
-        StandbyTaskCreator(final InternalTopologyBuilder builder,
-                           final StreamsConfig config,
-                           final StreamsMetricsImpl streamsMetrics,
-                           final StateDirectory stateDirectory,
-                           final ChangelogReader storeChangelogReader,
-                           final Time time,
-                           final String threadId,
-                           final Logger log) {
-            super(
-                builder,
-                config,
-                streamsMetrics,
-                stateDirectory,
-                storeChangelogReader,
-                time,
-                log);
-            createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics);
-        }
-
-        @Override
-        StandbyTask createTask(final Consumer<byte[], byte[]> consumer,
-                               final TaskId taskId,
-                               final Set<TopicPartition> partitions) {
-            createTaskSensor.record();
-
-            final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName());
-            final String logPrefix = threadIdPrefix + String.format("%s [%s] ", "standby-task", taskId);
-            final LogContext logContext = new LogContext(logPrefix);
-
-            final ProcessorTopology topology = builder.buildSubtopology(taskId.topicGroupId);
-
-            if (topology.hasStateWithChangelogs()) {
-                final ProcessorStateManager stateManager = new ProcessorStateManager(
-                    taskId,
-                    partitions,
-                    Task.TaskType.STANDBY,
-                    stateDirectory,
-                    topology.storeToChangelogTopic(),
-                    storeChangelogReader,
-                    logContext);
-
-                return new StandbyTask(
-                    taskId,
-                    partitions,
-                    topology,
-                    config,
-                    streamsMetrics,
-                    stateManager,
-                    stateDirectory);
-            } else {
-                log.trace(
-                    "Skipped standby task {} with assigned partitions {} " +
-                        "since it does not have any state stores to materialize",
-                    taskId, partitions
-                );
-                return null;
-            }
-        }
-    }
-
     private final Time time;
     private final Logger log;
     private final String logPrefix;
@@ -508,8 +272,6 @@ public class StreamThread extends Thread {
     final ConsumerRebalanceListener rebalanceListener;
     final Consumer<byte[], byte[]> mainConsumer;
     final Consumer<byte[], byte[]> restoreConsumer;
-    final Producer<byte[], byte[]> threadProducer;
-    final Map<TaskId, Producer<byte[], byte[]>> taskProducers;
     final InternalTopologyBuilder builder;
 
     public static StreamThread create(final InternalTopologyBuilder builder,
@@ -544,11 +306,7 @@ public class StreamThread extends Thread {
 
         final ThreadCache cache = new ThreadCache(logContext, cacheSizeBytes, streamsMetrics);
 
-        final Map<TaskId, Producer<byte[], byte[]>> taskProducers = new HashMap<>();
-
-        // TODO: refactor `TaskCreator` into `TaskManager`;
-        //  this will allow to reduce the surface area of `taskProducers` that is passed to many classes atm
-        final TaskCreator activeTaskCreator = new TaskCreator(
+        final ActiveTaskCreator activeTaskCreator = new ActiveTaskCreator(
             builder,
             config,
             streamsMetrics,
@@ -557,7 +315,6 @@ public class StreamThread extends Thread {
             cache,
             time,
             clientSupplier,
-            taskProducers,
             threadId,
             log);
         final StandbyTaskCreator standbyTaskCreator = new StandbyTaskCreator(
@@ -566,7 +323,6 @@ public class StreamThread extends Thread {
             streamsMetrics,
             stateDirectory,
             changelogReader,
-            time,
             threadId,
             log);
         final TaskManager taskManager = new TaskManager(
@@ -576,10 +332,9 @@ public class StreamThread extends Thread {
             streamsMetrics,
             activeTaskCreator,
             standbyTaskCreator,
-            taskProducers,
             builder,
-            adminClient
-        );
+            adminClient,
+            stateDirectory);
 
         log.info("Creating consumer client");
         final String applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG);
@@ -602,8 +357,6 @@ public class StreamThread extends Thread {
         final StreamThread streamThread = new StreamThread(
             time,
             config,
-            activeTaskCreator.threadProducer,
-            taskProducers,
             adminClient,
             mainConsumer,
             restoreConsumer,
@@ -621,8 +374,6 @@ public class StreamThread extends Thread {
 
     public StreamThread(final Time time,
                         final StreamsConfig config,
-                        final Producer<byte[], byte[]> threadProducer,
-                        final Map<TaskId, Producer<byte[], byte[]>> taskProducers,
                         final Admin adminClient,
                         final Consumer<byte[], byte[]> mainConsumer,
                         final Consumer<byte[], byte[]> restoreConsumer,
@@ -665,8 +416,6 @@ public class StreamThread extends Thread {
         this.taskManager = taskManager;
         this.restoreConsumer = restoreConsumer;
         this.mainConsumer = mainConsumer;
-        this.threadProducer = threadProducer;
-        this.taskProducers = taskProducers;
         this.changelogReader = changelogReader;
         this.originalReset = originalReset;
         this.assignmentErrorCode = assignmentErrorCode;
@@ -686,14 +435,6 @@ public class StreamThread extends Thread {
         }
     }
 
-    private static String getTaskProducerClientId(final String threadClientId, final TaskId taskId) {
-        return threadClientId + "-" + taskId + "-producer";
-    }
-
-    private static String getThreadProducerClientId(final String threadClientId) {
-        return threadClientId + "-producer";
-    }
-
     private static String getConsumerClientId(final String threadClientId) {
         return threadClientId + "-consumer";
     }
@@ -1127,9 +868,7 @@ public class StreamThread extends Thread {
             this.state().name(),
             getConsumerClientId(this.getName()),
             getRestoreConsumerClientId(this.getName()),
-            threadProducer == null ?
-                Collections.emptySet() :
-                Collections.singleton(getThreadProducerClientId(this.getName())),
+            taskManager.producerClientIds(),
             adminClientId,
             Collections.emptySet(),
             Collections.emptySet());
@@ -1139,11 +878,9 @@ public class StreamThread extends Thread {
 
     private void updateThreadMetadata(final Map<TaskId, Task> activeTasks,
                                       final Map<TaskId, Task> standbyTasks) {
-        final Set<String> producerClientIds = new HashSet<>();
         final Set<TaskMetadata> activeTasksMetadata = new HashSet<>();
         for (final Map.Entry<TaskId, Task> task : activeTasks.entrySet()) {
             activeTasksMetadata.add(new TaskMetadata(task.getKey().toString(), task.getValue().inputPartitions()));
-            producerClientIds.add(getTaskProducerClientId(getName(), task.getKey()));
         }
         final Set<TaskMetadata> standbyTasksMetadata = new HashSet<>();
         for (final Map.Entry<TaskId, Task> task : standbyTasks.entrySet()) {
@@ -1156,9 +893,7 @@ public class StreamThread extends Thread {
             this.state().name(),
             getConsumerClientId(this.getName()),
             getRestoreConsumerClientId(this.getName()),
-            threadProducer == null ?
-                producerClientIds :
-                Collections.singleton(getThreadProducerClientId(this.getName())),
+            taskManager.producerClientIds(),
             adminClientId,
             activeTasksMetadata,
             standbyTasksMetadata);
@@ -1198,22 +933,7 @@ public class StreamThread extends Thread {
     }
 
     public Map<MetricName, Metric> producerMetrics() {
-        final LinkedHashMap<MetricName, Metric> result = new LinkedHashMap<>();
-        if (threadProducer != null) {
-            final Map<MetricName, ? extends Metric> producerMetrics = threadProducer.metrics();
-            if (producerMetrics != null) {
-                result.putAll(producerMetrics);
-            }
-        } else {
-            // When EOS is turned on, each task will have its own producer client
-            // and the producer object passed in here will be null. We would then iterate through
-            // all the active tasks and add their metrics to the output metrics map.
-            for (final StreamTask task : taskManager.fixmeStreamTasks().values()) {
-                final Map<MetricName, ? extends Metric> taskProducerMetrics = taskProducers.get(task.id).metrics();
-                result.putAll(taskProducerMetrics);
-            }
-        }
-        return result;
+        return taskManager.producerMetrics();
     }
 
     public Map<MetricName, Metric> consumerMetrics() {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsProducer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsProducer.java
index db8d2bd..0324bf2 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsProducer.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsProducer.java
@@ -31,7 +31,6 @@ import org.apache.kafka.common.errors.UnknownProducerIdException;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
-import org.apache.kafka.streams.processor.TaskId;
 import org.slf4j.Logger;
 
 import java.util.List;
@@ -49,42 +48,29 @@ import java.util.concurrent.Future;
  */
 public class StreamsProducer {
     private final Logger log;
+    private final String logPrefix;
 
     private final Producer<byte[], byte[]> producer;
     private final String applicationId;
-    private final TaskId taskId;
-    private final String logMessage;
     private final boolean eosEnabled;
 
     private boolean transactionInFlight = false;
     private boolean transactionInitialized = false;
 
-    public StreamsProducer(final LogContext logContext,
-                           final Producer<byte[], byte[]> producer) {
-        this(logContext, producer, null, null);
-    }
-
-    public StreamsProducer(final LogContext logContext,
-                           final Producer<byte[], byte[]> producer,
-                           final String applicationId,
-                           final TaskId taskId) {
-        if ((applicationId != null && taskId == null) ||
-            (applicationId == null && taskId != null)) {
-            throw new IllegalArgumentException("applicationId and taskId must either be both null or both be not null");
-        }
-
-        this.log = logContext.logger(getClass());
+    public StreamsProducer(final Producer<byte[], byte[]> producer,
+                           final boolean eosEnabled,
+                           final LogContext logContext,
+                           final String applicationId) {
+        log = logContext.logger(getClass());
+        logPrefix = logContext.logPrefix().trim();
 
         this.producer = Objects.requireNonNull(producer, "producer cannot be null");
         this.applicationId = applicationId;
-        this.taskId = taskId;
-        if (taskId != null) {
-            logMessage = "task " + taskId.toString();
-            eosEnabled = true;
-        } else {
-            logMessage = "all owned active tasks";
-            eosEnabled = false;
-        }
+        this.eosEnabled = eosEnabled;
+    }
+
+    private String formatException(final String message) {
+        return message + " [" + logPrefix + ", " + (eosEnabled ? "eos" : "alo") + "]";
     }
 
     /**
@@ -92,7 +78,7 @@ public class StreamsProducer {
      */
     public void initTransaction() {
         if (!eosEnabled) {
-            throw new IllegalStateException("EOS is disabled");
+            throw new IllegalStateException(formatException("EOS is disabled"));
         }
         if (!transactionInitialized) {
             // initialize transactions if eos is turned on, which will block if the previous transaction has not
@@ -101,17 +87,22 @@ public class StreamsProducer {
                 producer.initTransactions();
                 transactionInitialized = true;
             } catch (final TimeoutException exception) {
-                log.warn("Timeout exception caught when initializing transactions for {}. " +
-                    "\nThe broker is either slow or in bad state (like not having enough replicas) in responding to the request, " +
-                    "or the connection to broker was interrupted sending the request or receiving the response. " +
-                    "Will retry initializing the task in the next loop. " +
-                    "\nConsider overwriting {} to a larger value to avoid timeout errors",
-                    logMessage,
-                    ProducerConfig.MAX_BLOCK_MS_CONFIG);
+                log.warn(
+                    "Timeout exception caught when initializing transactions. " +
+                        "The broker is either slow or in bad state (like not having enough replicas) in " +
+                        "responding to the request, or the connection to broker was interrupted sending " +
+                        "the request or receiving the response. " +
+                        "Will retry initializing the task in the next loop. " +
+                        "Consider overwriting {} to a larger value to avoid timeout errors",
+                    ProducerConfig.MAX_BLOCK_MS_CONFIG
+                );
 
                 throw exception;
             } catch (final KafkaException exception) {
-                throw new StreamsException("Error encountered while initializing transactions for " + logMessage, exception);
+                throw new StreamsException(
+                    formatException("Error encountered while initializing transactions"),
+                    exception
+                );
             }
         }
     }
@@ -122,9 +113,15 @@ public class StreamsProducer {
                 producer.beginTransaction();
                 transactionInFlight = true;
             } catch (final ProducerFencedException error) {
-                throw new TaskMigratedException("Producer get fenced trying to begin a new transaction", error);
+                throw new TaskMigratedException(
+                    formatException("Producer get fenced trying to begin a new transaction"),
+                    error
+                );
             } catch (final KafkaException error) {
-                throw new StreamsException("Producer encounter unexpected error trying to begin a new transaction for " + logMessage, error);
+                throw new StreamsException(
+                    formatException("Producer encounter unexpected error trying to begin a new transaction"),
+                    error
+                );
             }
         }
     }
@@ -137,15 +134,17 @@ public class StreamsProducer {
         } catch (final KafkaException uncaughtException) {
             if (isRecoverable(uncaughtException)) {
                 // producer.send() call may throw a KafkaException which wraps a FencedException,
-                // in this case we should throw its wrapped inner cause so that it can be captured and re-wrapped as TaskMigrationException
-                throw new TaskMigratedException("Producer cannot send records anymore since it got fenced", uncaughtException.getCause());
+                // in this case we should throw its wrapped inner cause so that it can be
+                // captured and re-wrapped as TaskMigrationException
+                throw new TaskMigratedException(
+                    formatException("Producer cannot send records anymore since it got fenced"),
+                    uncaughtException.getCause()
+                );
             } else {
-                final String errorMessage = String.format(
-                    "Error encountered sending record to topic %s%s due to:%n%s",
-                    record.topic(),
-                    taskId == null ? "" : " " + logMessage,
-                    uncaughtException.toString());
-                throw new StreamsException(errorMessage, uncaughtException);
+                throw new StreamsException(
+                    formatException(String.format("Error encountered sending record to topic %s", record.topic())),
+                    uncaughtException
+                );
             }
         }
     }
@@ -161,7 +160,7 @@ public class StreamsProducer {
      */
     public void commitTransaction(final Map<TopicPartition, OffsetAndMetadata> offsets) throws ProducerFencedException {
         if (!eosEnabled) {
-            throw new IllegalStateException("EOS is disabled");
+            throw new IllegalStateException(formatException("EOS is disabled"));
         }
         maybeBeginTransaction();
         try {
@@ -169,12 +168,18 @@ public class StreamsProducer {
             producer.commitTransaction();
             transactionInFlight = false;
         } catch (final ProducerFencedException error) {
-            throw new TaskMigratedException("Producer get fenced trying to commit a transaction", error);
+            throw new TaskMigratedException(
+                formatException("Producer get fenced trying to commit a transaction"),
+                error
+            );
         } catch (final TimeoutException error) {
             // TODO KIP-447: we can consider treating it as non-fatal and retry on the thread level
-            throw new StreamsException("Timed out while committing a transaction for " + logMessage, error);
+            throw new StreamsException(formatException("Timed out while committing a transaction"), error);
         } catch (final KafkaException error) {
-            throw new StreamsException("Producer encounter unexpected error trying to commit a transaction for " + logMessage, error);
+            throw new StreamsException(
+                formatException("Producer encounter unexpected error trying to commit a transaction"),
+                error
+            );
         }
     }
 
@@ -183,7 +188,7 @@ public class StreamsProducer {
      */
     public void abortTransaction() throws ProducerFencedException {
         if (!eosEnabled) {
-            throw new IllegalStateException("EOS is disabled");
+            throw new IllegalStateException(formatException("EOS is disabled"));
         }
         if (transactionInFlight) {
             try {
@@ -198,7 +203,10 @@ public class StreamsProducer {
 
                 // can be ignored: transaction got already aborted by brokers/transactional-coordinator if this happens
             } catch (final KafkaException error) {
-                throw new StreamsException("Producer encounter unexpected error trying to abort a transaction for " + logMessage, error);
+                throw new StreamsException(
+                    formatException("Producer encounter unexpected error trying to abort a transaction"),
+                    error
+                );
             }
             transactionInFlight = false;
         }
@@ -212,17 +220,6 @@ public class StreamsProducer {
         producer.flush();
     }
 
-    public void close() {
-        if (eosEnabled) {
-            try {
-                producer.close();
-            } catch (final KafkaException error) {
-                throw new StreamsException("Producer encounter unexpected " +
-                    "error trying to close" + (taskId == null ? "" : " " + logMessage), error);
-            }
-        }
-    }
-
     // for testing only
     Producer<byte[], byte[]> kafkaProducer() {
         return producer;
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index dc14d31..069da8c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -20,8 +20,9 @@ import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.admin.DeleteRecordsResult;
 import org.apache.kafka.clients.admin.RecordsToDelete;
 import org.apache.kafka.clients.consumer.Consumer;
-import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.Metric;
+import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.utils.LogContext;
@@ -61,11 +62,11 @@ public class TaskManager {
     private final UUID processId;
     private final String logPrefix;
     private final StreamsMetricsImpl streamsMetrics;
-    private final StreamThread.AbstractTaskCreator<? extends Task> activeTaskCreator;
-    private final StreamThread.AbstractTaskCreator<? extends Task> standbyTaskCreator;
-    private final Map<TaskId, Producer<byte[], byte[]>> taskProducers;
+    private final ActiveTaskCreator activeTaskCreator;
+    private final StandbyTaskCreator standbyTaskCreator;
     private final InternalTopologyBuilder builder;
     private final Admin adminClient;
+    private final StateDirectory stateDirectory;
 
     private final Map<TaskId, Task> tasks = new TreeMap<>();
     // materializing this relationship because the lookup is on the hot path
@@ -81,23 +82,23 @@ public class TaskManager {
                 final UUID processId,
                 final String logPrefix,
                 final StreamsMetricsImpl streamsMetrics,
-                final StreamThread.AbstractTaskCreator<? extends Task> activeTaskCreator,
-                final StreamThread.AbstractTaskCreator<? extends Task> standbyTaskCreator,
-                final Map<TaskId, Producer<byte[], byte[]>> taskProducers,
+                final ActiveTaskCreator activeTaskCreator,
+                final StandbyTaskCreator standbyTaskCreator,
                 final InternalTopologyBuilder builder,
-                final Admin adminClient) {
+                final Admin adminClient,
+                final StateDirectory stateDirectory) {
         this.changelogReader = changelogReader;
         this.processId = processId;
         this.logPrefix = logPrefix;
         this.streamsMetrics = streamsMetrics;
         this.activeTaskCreator = activeTaskCreator;
         this.standbyTaskCreator = standbyTaskCreator;
-        this.taskProducers = taskProducers;
         this.builder = builder;
         this.adminClient = adminClient;
+        this.stateDirectory = stateDirectory;
 
         final LogContext logContext = new LogContext(logPrefix);
-        this.log = logContext.logger(getClass());
+        log = logContext.logger(getClass());
     }
 
     void setMainConsumer(final Consumer<byte[], byte[]> mainConsumer) {
@@ -162,11 +163,11 @@ public class TaskManager {
     public void handleAssignment(final Map<TaskId, Set<TopicPartition>> activeTasks,
                                  final Map<TaskId, Set<TopicPartition>> standbyTasks) {
         log.info("Handle new assignment with:\n" +
-                "\tNew active tasks: {}\n" +
-                "\tNew standby tasks: {}\n" +
-                "\tExisting active tasks: {}\n" +
-                "\tExisting standby tasks: {}",
-            activeTasks.keySet(), standbyTasks.keySet(), activeTaskIds(), standbyTaskIds());
+                     "\tNew active tasks: {}\n" +
+                     "\tNew standby tasks: {}\n" +
+                     "\tExisting active tasks: {}\n" +
+                     "\tExisting standby tasks: {}",
+                 activeTasks.keySet(), standbyTasks.keySet(), activeTaskIds(), standbyTaskIds());
 
         final Map<TaskId, Set<TopicPartition>> activeTasksToCreate = new TreeMap<>(activeTasks);
         final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate = new TreeMap<>(standbyTasks);
@@ -188,13 +189,22 @@ public class TaskManager {
                 try {
                     task.closeClean();
                 } catch (final RuntimeException e) {
-                    log.error(String.format("Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:", task.id()), e);
+                    final String uncleanMessage = String.format("Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:", task.id());
+                    log.error(uncleanMessage, e);
                     taskCloseExceptions.put(task.id(), e);
                     // We've already recorded the exception (which is the point of clean).
                     // Now, we should go ahead and complete the close because a half-closed task is no good to anyone.
                     task.closeDirty();
                 } finally {
-                    taskProducers.remove(task.id());
+                    if (task.isActive()) {
+                        try {
+                            activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
+                        } catch (final RuntimeException e) {
+                            final String uncleanMessage = String.format("Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:", task.id());
+                            log.error(uncleanMessage, e);
+                            taskCloseExceptions.putIfAbsent(task.id(), e);
+                        }
+                    }
                 }
 
                 iterator.remove();
@@ -223,11 +233,15 @@ public class TaskManager {
         }
 
         if (!activeTasksToCreate.isEmpty()) {
-            activeTaskCreator.createTasks(mainConsumer, activeTasksToCreate).forEach(this::addNewTask);
+            for (final Task task : activeTaskCreator.createTasks(mainConsumer, activeTasksToCreate)) {
+                addNewTask(task);
+            }
         }
 
         if (!standbyTasksToCreate.isEmpty()) {
-            standbyTaskCreator.createTasks(mainConsumer, standbyTasksToCreate).forEach(this::addNewTask);
+            for (final Task task : standbyTaskCreator.createTasks(standbyTasksToCreate)) {
+                addNewTask(task);
+            }
         }
 
         builder.addSubscribedTopicsFromAssignment(
@@ -268,7 +282,7 @@ public class TaskManager {
                     // it is possible that if there are multiple threads within the instance that one thread
                     // trying to grab the task from the other, while the other has not released the lock since
                     // it did not participate in the rebalance. In this case we can just retry in the next iteration
-                    log.debug("Could not initialize {} due to {}; will retry", task.id(), e.toString());
+                    log.debug("Could not initialize {} due to {}; will retry", task.id(), e);
                     allRunning = false;
                 }
             }
@@ -285,7 +299,7 @@ public class TaskManager {
                     try {
                         task.completeRestoration();
                     } catch (final TimeoutException e) {
-                        log.debug("Could not complete restoration for {} due to {}; will retry", task.id(), e.toString());
+                        log.debug("Could not complete restoration for {} due to {}; will retry", task.id(), e);
 
                         allRunning = false;
                     }
@@ -320,8 +334,8 @@ public class TaskManager {
 
         if (!remainingPartitions.isEmpty()) {
             log.warn("The following partitions {} are missing from the task partitions. It could potentially " +
-                "due to race condition of consumer detecting the heartbeat failure, or the tasks " +
-                "have been cleaned up by the handleAssignment callback.", remainingPartitions);
+                         "due to race condition of consumer detecting the heartbeat failure, or the tasks " +
+                         "have been cleaned up by the handleAssignment callback.", remainingPartitions);
         }
     }
 
@@ -345,7 +359,11 @@ public class TaskManager {
                 cleanupTask(task);
                 task.closeDirty();
                 iterator.remove();
-                taskProducers.remove(task.id());
+                try {
+                    activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
+                } catch (final RuntimeException e) {
+                    log.warn("Error closing task producer for " + task.id() + " while handling lostAll", e);
+                }
             }
 
             for (final TopicPartition inputPartition : inputPartitions) {
@@ -366,7 +384,7 @@ public class TaskManager {
 
         final Set<TaskId> locallyStoredTasks = new HashSet<>();
 
-        final File[] stateDirs = activeTaskCreator.stateDirectory().listTaskDirectories();
+        final File[] stateDirs = stateDirectory.listTaskDirectories();
         if (stateDirs != null) {
             for (final File dir : stateDirs) {
                 try {
@@ -389,7 +407,9 @@ public class TaskManager {
         // 1. remove the changelog partitions from changelog reader;
         // 2. remove the input partitions from the materialized map;
         // 3. remove the task metrics from the metrics registry
-        changelogReader.remove(task.changelogPartitions());
+        if (!task.changelogPartitions().isEmpty()) {
+            changelogReader.remove(task.changelogPartitions());
+        }
 
         for (final TopicPartition inputPartition : task.inputPartitions()) {
             partitionToTask.remove(inputPartition);
@@ -419,14 +439,33 @@ public class TaskManager {
             } else {
                 task.closeDirty();
             }
+            if (task.isActive()) {
+                try {
+                    activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
+                } catch (final RuntimeException e) {
+                    if (clean) {
+                        firstException.compareAndSet(null, e);
+                    } else {
+                        log.warn("Ignoring an exception while closing task " + task.id() + " producer.", e);
+                    }
+                }
+            }
             iterator.remove();
         }
 
-        activeTaskCreator.close();
+        try {
+            activeTaskCreator.closeThreadProducerIfNeeded();
+        } catch (final RuntimeException e) {
+            if (clean) {
+                firstException.compareAndSet(null, e);
+            } else {
+                log.warn("Ignoring an exception while closing thread producer.", e);
+            }
+        }
 
         final RuntimeException fatalException = firstException.get();
         if (fatalException != null) {
-            throw fatalException;
+            throw new RuntimeException("Unexpected exception while closing task", fatalException);
         }
     }
 
@@ -614,19 +653,11 @@ public class TaskManager {
         return stringBuilder.toString();
     }
 
-    // below are for testing only
-    StandbyTask standbyTask(final TopicPartition partition) {
-        for (final Task task : (Iterable<Task>) standbyTaskStream()::iterator) {
-            if (task.inputPartitions().contains(partition)) {
-                return (StandbyTask) task;
-            }
-        }
-        return null;
+    Map<MetricName, Metric> producerMetrics() {
+        return activeTaskCreator.producerMetrics();
     }
 
-    // TODO K9113: this is used from StreamThread only for a hack to collect metrics from the record collectors inside of StreamTasks
-    // Instead, we should register and record the metrics properly inside of the record collector.
-    Map<TaskId, StreamTask> fixmeStreamTasks() {
-        return tasks.values().stream().filter(t -> t instanceof StreamTask).map(t -> (StreamTask) t).collect(Collectors.toMap(Task::id, t -> t));
+    Set<String> producerClientIds() {
+        return activeTaskCreator.producerClientIds();
     }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
index 7bb6b2c..7c6752d 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
@@ -104,7 +104,7 @@ public class RecordCollectorTest {
     private final MockConsumer<byte[], byte[]> mockConsumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
     private final MockProducer<byte[], byte[]> mockProducer = new MockProducer<>(
         cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer);
-    private final StreamsProducer streamsProducer = new StreamsProducer(logContext, mockProducer);
+    private final StreamsProducer streamsProducer = new StreamsProducer(mockProducer, false, logContext, null);
 
     private RecordCollectorImpl collector;
 
@@ -127,7 +127,7 @@ public class RecordCollectorTest {
 
     @Test
     public void shouldSendToSpecificPartition() {
-        final Headers headers = new RecordHeaders(new Header[]{new RecordHeader("key", "value".getBytes())});
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
 
         collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer);
         collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer);
@@ -157,7 +157,7 @@ public class RecordCollectorTest {
 
     @Test
     public void shouldSendWithPartitioner() {
-        final Headers headers = new RecordHeaders(new Header[]{new RecordHeader("key", "value".getBytes())});
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
 
         collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
         collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
@@ -183,7 +183,7 @@ public class RecordCollectorTest {
 
     @Test
     public void shouldSendWithNoPartition() {
-        final Headers headers = new RecordHeaders(new Header[]{new RecordHeader("key", "value".getBytes())});
+        final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
 
         collector.send(topic, "3", "0", headers, null, null, stringSerializer, stringSerializer);
         collector.send(topic, "9", "0", headers, null, null, stringSerializer, stringSerializer);
@@ -306,8 +306,6 @@ public class RecordCollectorTest {
     @Test
     public void shouldForwardCloseToTransactionManager() {
         final StreamsProducer streamsProducer = mock(StreamsProducer.class);
-        streamsProducer.close();
-        expectLastCall();
         replay(streamsProducer);
 
         final RecordCollector collector = new RecordCollectorImpl(
@@ -328,8 +326,6 @@ public class RecordCollectorTest {
     public void shouldAbortTxIfEosEnabled() {
         final StreamsProducer streamsProducer = mock(StreamsProducer.class);
         streamsProducer.abortTransaction();
-        streamsProducer.close();
-        expectLastCall();
         replay(streamsProducer);
 
         final RecordCollector collector = new RecordCollectorImpl(
@@ -354,7 +350,6 @@ public class RecordCollectorTest {
             taskId,
             mockConsumer,
             new StreamsProducer(
-                logContext,
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
                     public synchronized Future<RecordMetadata> send(final ProducerRecord<byte[], byte[]> record, final Callback callback) {
@@ -362,8 +357,10 @@ public class RecordCollectorTest {
                         return null;
                     }
                 },
-                "appId",
-                taskId),
+                true,
+                logContext,
+                "appId"
+            ),
             productionExceptionHandler,
             true,
             streamsMetrics
@@ -396,14 +393,17 @@ public class RecordCollectorTest {
             taskId,
             mockConsumer,
             new StreamsProducer(
-                logContext,
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
                     public synchronized Future<RecordMetadata> send(final ProducerRecord<byte[], byte[]> record, final Callback callback) {
                         callback.onCompletion(null, exception);
                         return null;
                     }
-                }),
+                },
+                false,
+                logContext,
+                null
+            ),
             productionExceptionHandler,
             false,
             streamsMetrics
@@ -435,14 +435,17 @@ public class RecordCollectorTest {
             taskId,
             mockConsumer,
             new StreamsProducer(
-                logContext,
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
                     public synchronized Future<RecordMetadata> send(final ProducerRecord<byte[], byte[]> record, final Callback callback) {
                         callback.onCompletion(null, new Exception());
                         return null;
                     }
-                }),
+                },
+                false,
+                logContext,
+                null
+            ),
             new AlwaysContinueProductionExceptionHandler(),
             false,
             streamsMetrics
@@ -475,14 +478,17 @@ public class RecordCollectorTest {
             taskId,
             mockConsumer,
             new StreamsProducer(
-                logContext,
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
                     public synchronized Future<RecordMetadata> send(final ProducerRecord<byte[], byte[]> record, final Callback callback) {
                         callback.onCompletion(null, exception);
                         return null;
                     }
-                }),
+                },
+                false,
+                logContext,
+                null
+            ),
             new AlwaysContinueProductionExceptionHandler(),
             false,
             streamsMetrics
@@ -604,15 +610,16 @@ public class RecordCollectorTest {
             taskId,
             mockConsumer,
             new StreamsProducer(
-                logContext,
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
                     public void abortTransaction() {
                         functionCalled.set(true);
                     }
                 },
-                "appId",
-                taskId),
+                true,
+                logContext,
+                "appId"
+            ),
             productionExceptionHandler,
             true,
             streamsMetrics
@@ -629,13 +636,16 @@ public class RecordCollectorTest {
             taskId,
             mockConsumer,
             new StreamsProducer(
-                logContext,
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
                     public List<PartitionInfo> partitionsFor(final String topic) {
                         return Collections.emptyList();
                     }
-                }),
+                },
+                false,
+                logContext,
+                null
+            ),
             productionExceptionHandler,
             false,
             streamsMetrics
@@ -650,16 +660,12 @@ public class RecordCollectorTest {
     }
 
     @Test
-    public void shouldCloseInternalProducerForEOS() {
+    public void shouldNotCloseInternalProducerForEOS() {
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
             mockConsumer,
-            new StreamsProducer(
-                logContext,
-                mockProducer,
-                "appId",
-                taskId),
+            new StreamsProducer(mockProducer, true, logContext, "appId"),
             productionExceptionHandler,
             true,
             streamsMetrics
@@ -668,7 +674,7 @@ public class RecordCollectorTest {
         collector.close();
 
         // Flush should not throw as producer is still alive.
-        assertThrows(IllegalStateException.class, streamsProducer::flush);
+        streamsProducer.flush();
     }
 
     @Test
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 57cc1ed..13a669d 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -81,6 +81,7 @@ import java.io.File;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -90,7 +91,11 @@ import java.util.Properties;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Stream;
 
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.emptySet;
+import static java.util.Collections.singletonMap;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkProperties;
@@ -98,9 +103,10 @@ import static org.apache.kafka.streams.processor.internals.StateManagerUtil.CHEC
 import static org.apache.kafka.streams.processor.internals.StreamThread.getSharedAdminClientId;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.not;
-import static org.hamcrest.CoreMatchers.nullValue;
 import static org.hamcrest.CoreMatchers.startsWith;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
@@ -190,7 +196,7 @@ public class StreamThreadTest {
         );
 
         internalTopologyBuilder.buildTopology();
-        
+
         return StreamThread.create(
             internalTopologyBuilder,
             config,
@@ -426,8 +432,6 @@ public class StreamThreadTest {
             mockTime,
             config,
             null,
-            null,
-            null,
             consumer,
             consumer,
             null,
@@ -472,7 +476,7 @@ public class StreamThreadTest {
 
         thread.taskManager().handleAssignment(
             Collections.singletonMap(task1, assignedPartitions),
-            Collections.emptyMap()
+            emptyMap()
         );
 
         final MockConsumer<byte[], byte[]> mockConsumer = (MockConsumer<byte[], byte[]>) thread.mainConsumer;
@@ -556,8 +560,6 @@ public class StreamThreadTest {
             mockTime,
             config,
             null,
-            null,
-            null,
             consumer,
             consumer,
             null,
@@ -594,8 +596,6 @@ public class StreamThreadTest {
             mockTime,
             config,
             null,
-            null,
-            null,
             consumer,
             consumer,
             changelogReader,
@@ -636,7 +636,7 @@ public class StreamThreadTest {
         activeTasks.put(task1, Collections.singleton(t1p1));
         activeTasks.put(task2, Collections.singleton(t1p2));
 
-        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+        thread.taskManager().handleAssignment(activeTasks, emptyMap());
 
         final MockConsumer<byte[], byte[]> mockConsumer = (MockConsumer<byte[], byte[]>) thread.mainConsumer;
         mockConsumer.assign(assignedPartitions);
@@ -673,7 +673,7 @@ public class StreamThreadTest {
         activeTasks.put(task1, Collections.singleton(t1p1));
         activeTasks.put(task2, Collections.singleton(t1p2));
 
-        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+        thread.taskManager().handleAssignment(activeTasks, emptyMap());
 
         final MockConsumer<byte[], byte[]> mockConsumer = (MockConsumer<byte[], byte[]>) thread.mainConsumer;
         mockConsumer.assign(assignedPartitions);
@@ -714,7 +714,7 @@ public class StreamThreadTest {
         activeTasks.put(task1, Collections.singleton(t1p1));
         activeTasks.put(task2, Collections.singleton(t1p2));
 
-        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+        thread.taskManager().handleAssignment(activeTasks, emptyMap());
 
         thread.shutdown();
 
@@ -758,7 +758,7 @@ public class StreamThreadTest {
         activeTasks.put(task1, Collections.singleton(t1p1));
         activeTasks.put(task2, Collections.singleton(t1p2));
 
-        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+        thread.taskManager().handleAssignment(activeTasks, emptyMap());
         thread.rebalanceListener.onPartitionsAssigned(assignedPartitions);
 
         thread.shutdown();
@@ -786,8 +786,6 @@ public class StreamThreadTest {
             mockTime,
             config,
             null,
-            null,
-            null,
             consumer,
             consumer,
             null,
@@ -823,8 +821,6 @@ public class StreamThreadTest {
             mockTime,
             config,
             null,
-            null,
-            null,
             consumer,
             consumer,
             null,
@@ -854,8 +850,6 @@ public class StreamThreadTest {
             mockTime,
             config,
             null,
-            null,
-            null,
             consumer,
             consumer,
             null,
@@ -888,7 +882,7 @@ public class StreamThreadTest {
         // assign single partition
         standbyTasks.put(task1, Collections.singleton(t1p1));
 
-        thread.taskManager().handleAssignment(Collections.emptyMap(), standbyTasks);
+        thread.taskManager().handleAssignment(emptyMap(), standbyTasks);
 
         thread.rebalanceListener.onPartitionsAssigned(Collections.emptyList());
     }
@@ -914,7 +908,7 @@ public class StreamThreadTest {
         assignedPartitions.add(t1p1);
         activeTasks.put(task1, Collections.singleton(t1p1));
 
-        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+        thread.taskManager().handleAssignment(activeTasks, emptyMap());
 
         final MockConsumer<byte[], byte[]> mockConsumer = (MockConsumer<byte[], byte[]>) thread.mainConsumer;
         mockConsumer.assign(assignedPartitions);
@@ -972,7 +966,7 @@ public class StreamThreadTest {
         assignedPartitions.add(t1p1);
         activeTasks.put(task1, Collections.singleton(t1p1));
 
-        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+        thread.taskManager().handleAssignment(activeTasks, emptyMap());
 
         final MockConsumer<byte[], byte[]> mockConsumer = (MockConsumer<byte[], byte[]>) thread.mainConsumer;
         mockConsumer.assign(assignedPartitions);
@@ -1011,7 +1005,7 @@ public class StreamThreadTest {
         assignedPartitions.add(t1p1);
         activeTasks.put(task1, Collections.singleton(t1p1));
 
-        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+        thread.taskManager().handleAssignment(activeTasks, emptyMap());
 
         final MockConsumer<byte[], byte[]> mockConsumer = (MockConsumer<byte[], byte[]>) thread.mainConsumer;
         mockConsumer.assign(assignedPartitions);
@@ -1059,7 +1053,7 @@ public class StreamThreadTest {
         assignedPartitions.add(t1p1);
         activeTasks.put(task1, Collections.singleton(t1p1));
 
-        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+        thread.taskManager().handleAssignment(activeTasks, emptyMap());
 
         final MockConsumer<byte[], byte[]> mockConsumer = (MockConsumer<byte[], byte[]>) thread.mainConsumer;
         mockConsumer.assign(assignedPartitions);
@@ -1080,7 +1074,31 @@ public class StreamThreadTest {
     public void shouldReturnActiveTaskMetadataWhileRunningState() {
         internalTopologyBuilder.addSource(null, "source", null, null, null, topic1);
 
-        final StreamThread thread = createStreamThread(CLIENT_ID, config, false);
+        clientSupplier.setClusterForAdminClient(createCluster());
+
+        final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(
+            metrics,
+            APPLICATION_ID,
+            config.getString(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG)
+        );
+
+        internalTopologyBuilder.buildTopology();
+
+        final StreamThread thread = StreamThread.create(
+            internalTopologyBuilder,
+            config,
+            clientSupplier,
+            clientSupplier.getAdmin(config.getAdminConfigs(CLIENT_ID)),
+            PROCESS_ID,
+            CLIENT_ID,
+            streamsMetrics,
+            mockTime,
+            streamsMetadataState,
+            0,
+            stateDirectory,
+            new MockStateRestoreListener(),
+            threadIdx
+        );
 
         thread.setState(StreamThread.State.STARTING);
         thread.rebalanceListener.onPartitionsRevoked(Collections.emptySet());
@@ -1092,7 +1110,7 @@ public class StreamThreadTest {
         assignedPartitions.add(t1p1);
         activeTasks.put(task1, Collections.singleton(t1p1));
 
-        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+        thread.taskManager().handleAssignment(activeTasks, emptyMap());
 
         final MockConsumer<byte[], byte[]> mockConsumer = (MockConsumer<byte[], byte[]>) thread.mainConsumer;
         mockConsumer.assign(assignedPartitions);
@@ -1148,7 +1166,7 @@ public class StreamThreadTest {
         // assign single partition
         standbyTasks.put(task1, Collections.singleton(t1p1));
 
-        thread.taskManager().handleAssignment(Collections.emptyMap(), standbyTasks);
+        thread.taskManager().handleAssignment(emptyMap(), standbyTasks);
 
         thread.rebalanceListener.onPartitionsAssigned(Collections.emptyList());
 
@@ -1201,15 +1219,15 @@ public class StreamThreadTest {
         standbyTasks.put(task1, Collections.singleton(t1p1));
         standbyTasks.put(task3, Collections.singleton(t2p1));
 
-        thread.taskManager().handleAssignment(Collections.emptyMap(), standbyTasks);
+        thread.taskManager().handleAssignment(emptyMap(), standbyTasks);
         thread.taskManager().tryToCompleteRestoration();
 
         thread.rebalanceListener.onPartitionsAssigned(Collections.emptyList());
 
         thread.runOnce();
 
-        final StandbyTask standbyTask1 = thread.taskManager().standbyTask(t1p1);
-        final StandbyTask standbyTask2 = thread.taskManager().standbyTask(t2p1);
+        final StandbyTask standbyTask1 = standbyTask(thread.taskManager(), t1p1);
+        final StandbyTask standbyTask2 = standbyTask(thread.taskManager(), t2p1);
         assertEquals(task1, standbyTask1.id());
         assertEquals(task3, standbyTask2.id());
 
@@ -1245,18 +1263,14 @@ public class StreamThreadTest {
         setupInternalTopologyWithoutState();
         internalTopologyBuilder.addStateStore(new MockKeyValueStoreBuilder("myStore", true), "processor1");
 
-        final StandbyTask standbyTask = createStandbyTask();
-
-        assertThat(standbyTask, not(nullValue()));
+        assertThat(createStandbyTask(), not(empty()));
     }
 
     @Test
     public void shouldNotCreateStandbyTaskWithoutStateStores() {
         setupInternalTopologyWithoutState();
 
-        final StandbyTask standbyTask = createStandbyTask();
-
-        assertThat(standbyTask, nullValue());
+        assertThat(createStandbyTask(), empty());
     }
 
     @Test
@@ -1267,9 +1281,7 @@ public class StreamThreadTest {
         storeBuilder.withLoggingDisabled();
         internalTopologyBuilder.addStateStore(storeBuilder, "processor1");
 
-        final StandbyTask standbyTask = createStandbyTask();
-
-        assertThat(standbyTask, nullValue());
+        assertThat(createStandbyTask(), empty());
     }
 
     @Test
@@ -1306,7 +1318,7 @@ public class StreamThreadTest {
         assignedPartitions.add(t1p1);
         activeTasks.put(task1, Collections.singleton(t1p1));
 
-        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+        thread.taskManager().handleAssignment(activeTasks, emptyMap());
 
         clientSupplier.consumer.assign(assignedPartitions);
         clientSupplier.consumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L));
@@ -1434,7 +1446,7 @@ public class StreamThreadTest {
         final TaskId task0 = new TaskId(0, 0);
         activeTasks.put(task0, topicPartitionSet);
 
-        thread.taskManager().handleAssignment(activeTasks, Collections.emptyMap());
+        thread.taskManager().handleAssignment(activeTasks, emptyMap());
 
         mockConsumer.updatePartitions(
             "topic",
@@ -1563,7 +1575,7 @@ public class StreamThreadTest {
         final Set<TopicPartition> assignedPartitions = Collections.singleton(t1p1);
         thread.taskManager().handleAssignment(
             Collections.singletonMap(task1, assignedPartitions),
-            Collections.emptyMap());
+            emptyMap());
 
         final MockConsumer<byte[], byte[]> mockConsumer = (MockConsumer<byte[], byte[]>) thread.mainConsumer;
         mockConsumer.assign(Collections.singleton(t1p1));
@@ -1641,8 +1653,6 @@ public class StreamThreadTest {
             mockTime,
             config,
             null,
-            null,
-            null,
             consumer,
             consumer,
             null,
@@ -1685,8 +1695,6 @@ public class StreamThreadTest {
             mockTime,
             config,
             null,
-            null,
-            null,
             consumer,
             consumer,
             null,
@@ -1740,7 +1748,7 @@ public class StreamThreadTest {
             Collections.singletonMap(
                 task1,
                 assignedPartitions),
-            Collections.emptyMap());
+            emptyMap());
 
         final MockConsumer<byte[], byte[]> mockConsumer = (MockConsumer<byte[], byte[]>) thread.mainConsumer;
         mockConsumer.assign(Collections.singleton(t1p1));
@@ -1834,30 +1842,11 @@ public class StreamThreadTest {
     }
 
     @Test
-    public void shouldConstructProducerMetricsWithoutEOS() {
-        final MockProducer<byte[], byte[]> producer = new MockProducer<>();
+    public void shouldTransmitTaskManagerMetrics() {
         final Consumer<byte[], byte[]> consumer = EasyMock.createNiceMock(Consumer.class);
-        final TaskManager taskManager = mockTaskManagerCommit(consumer, 1, 0);
 
-        final StreamsMetricsImpl streamsMetrics =
-            new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST);
-        final StreamThread thread = new StreamThread(
-            mockTime,
-            config,
-            producer,
-            null,
-            null,
-            consumer,
-            consumer,
-            null,
-            null,
-            taskManager,
-            streamsMetrics,
-            internalTopologyBuilder,
-            CLIENT_ID,
-            new LogContext(""),
-            new AtomicInteger()
-        );
+        final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class);
+
         final MetricName testMetricName = new MetricName("test_metric", "", "", new HashMap<>());
         final Metric testMetric = new KafkaMetric(
             new Object(),
@@ -1865,24 +1854,16 @@ public class StreamThreadTest {
             (Measurable) (config, now) -> 0,
             null,
             new MockTime());
-        producer.setMockMetrics(testMetricName, testMetric);
-        final Map<MetricName, Metric> producerMetrics = thread.producerMetrics();
-        assertEquals(testMetricName, producerMetrics.get(testMetricName).metricName());
-    }
+        final Map<MetricName, Metric> dummyProducerMetrics = singletonMap(testMetricName, testMetric);
 
-    @Test
-    public void shouldConstructProducerMetricsWithEOS() {
-        final MockProducer<byte[], byte[]> producer = new MockProducer<>();
-        final Consumer<byte[], byte[]> consumer = EasyMock.createNiceMock(Consumer.class);
-        final TaskManager taskManager = mockTaskManagerCommit(consumer, 1, 0);
+        EasyMock.expect(taskManager.producerMetrics()).andReturn(dummyProducerMetrics);
+        EasyMock.replay(taskManager, consumer);
 
         final StreamsMetricsImpl streamsMetrics =
             new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST);
         final StreamThread thread = new StreamThread(
             mockTime,
             new StreamsConfig(configProps(true)),
-            null,       // with EOS the thread producer should be null
-            null,
             null,
             consumer,
             consumer,
@@ -1895,18 +1876,8 @@ public class StreamThreadTest {
             new LogContext(""),
             new AtomicInteger()
         );
-        final MetricName testMetricName = new MetricName("test_metric", "", "", new HashMap<>());
-        final Metric testMetric = new KafkaMetric(
-            new Object(),
-            testMetricName,
-            (Measurable) (config, now) -> 0,
-            null,
-            new MockTime());
 
-        // without creating tasks the metrics should be empty
-        producer.setMockMetrics(testMetricName, testMetric);
-        final Map<MetricName, Metric> producerMetrics = thread.producerMetrics();
-        assertEquals(Collections.<MetricName, Metric>emptyMap(), producerMetrics);
+        assertThat(dummyProducerMetrics, is(thread.producerMetrics()));
     }
 
     @Test
@@ -1917,7 +1888,6 @@ public class StreamThreadTest {
 
         final MockAdminClient adminClient = new MockAdminClient(cluster, broker1, null);
 
-        final MockProducer<byte[], byte[]> producer = new MockProducer<>();
         final Consumer<byte[], byte[]> consumer = EasyMock.createNiceMock(Consumer.class);
         final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class);
 
@@ -1925,8 +1895,6 @@ public class StreamThreadTest {
         final StreamThread thread = new StreamThread(
             mockTime,
             config,
-            producer,
-            null,
             adminClient,
             consumer,
             consumer,
@@ -1958,7 +1926,6 @@ public class StreamThreadTest {
                                               final int numberOfCommits,
                                               final int commits) {
         final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class);
-        EasyMock.expect(taskManager.fixmeStreamTasks()).andReturn(Collections.emptyMap()).anyTimes();
         EasyMock.expect(taskManager.commitAll()).andReturn(commits).times(numberOfCommits);
         EasyMock.replay(taskManager, consumer);
         return taskManager;
@@ -1970,24 +1937,20 @@ public class StreamThreadTest {
         internalTopologyBuilder.addProcessor("processor1", () -> mockProcessor, "source1");
     }
 
-    private StandbyTask createStandbyTask() {
+    private Collection<Task> createStandbyTask() {
         final LogContext logContext = new LogContext("test");
         final Logger log = logContext.logger(StreamThreadTest.class);
         final StreamsMetricsImpl streamsMetrics =
             new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST);
-        final StreamThread.StandbyTaskCreator standbyTaskCreator = new StreamThread.StandbyTaskCreator(
+        final StandbyTaskCreator standbyTaskCreator = new StandbyTaskCreator(
             internalTopologyBuilder,
             config,
             streamsMetrics,
             stateDirectory,
             new MockChangelogReader(),
-            mockTime,
             CLIENT_ID,
             log);
-        return standbyTaskCreator.createTask(
-            new MockConsumer<>(OffsetResetStrategy.EARLIEST),
-            new TaskId(1, 2),
-            Collections.emptySet());
+        return standbyTaskCreator.createTasks(singletonMap(new TaskId(1, 2), emptySet()));
     }
 
     private void addRecord(final MockConsumer<byte[], byte[]> mockConsumer,
@@ -2010,4 +1973,14 @@ public class StreamThreadTest {
             new byte[0],
             new byte[0]));
     }
+
+    StandbyTask standbyTask(final TaskManager taskManager, final TopicPartition partition) {
+        final Stream<Task> standbys = taskManager.tasks().values().stream().filter(t -> !t.isActive());
+        for (final Task task : (Iterable<Task>) standbys::iterator) {
+            if (task.inputPartitions().contains(partition)) {
+                return (StandbyTask) task;
+            }
+        }
+        return null;
+    }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsProducerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsProducerTest.java
index 5dfae2d..6205198 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsProducerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsProducerTest.java
@@ -34,7 +34,6 @@ import org.apache.kafka.common.serialization.ByteArraySerializer;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
-import org.apache.kafka.streams.processor.TaskId;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -50,12 +49,9 @@ import static org.easymock.EasyMock.mock;
 import static org.easymock.EasyMock.replay;
 import static org.easymock.EasyMock.verify;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.core.IsEqual.equalTo;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertSame;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.sameInstance;
 import static org.junit.Assert.assertThrows;
-import static org.junit.Assert.assertTrue;
 
 public class StreamsProducerTest {
 
@@ -68,7 +64,7 @@ public class StreamsProducerTest {
         Collections.emptySet(),
         Collections.emptySet()
     );
-    private final TaskId taskId = new TaskId(0, 0);
+
     private final ByteArraySerializer byteArraySerializer = new ByteArraySerializer();
     private final Map<TopicPartition, OffsetAndMetadata> offsetsAndMetadata = mkMap(
         mkEntry(new TopicPartition(topic, 0), new OffsetAndMetadata(0L, null))
@@ -76,11 +72,13 @@ public class StreamsProducerTest {
 
     private final MockProducer<byte[], byte[]> mockProducer = new MockProducer<>(
         cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer);
-    private final StreamsProducer streamsProducer = new StreamsProducer(logContext, mockProducer);
+    private final StreamsProducer aloStreamsProducer =
+        new StreamsProducer(mockProducer, false, logContext, null);
 
     private final MockProducer<byte[], byte[]> eosMockProducer = new MockProducer<>(
         cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer);
-    private final StreamsProducer eosStreamsProducer = new StreamsProducer(logContext, eosMockProducer, "appId", taskId);
+    private final StreamsProducer eosStreamsProducer =
+        new StreamsProducer(eosMockProducer, true, logContext, "appId");
 
     private final ProducerRecord<byte[], byte[]> record =
         new ProducerRecord<>(topic, 0, 0L, new byte[0], new byte[0], new RecordHeaders());
@@ -95,61 +93,38 @@ public class StreamsProducerTest {
         {
             final NullPointerException thrown = assertThrows(
                 NullPointerException.class,
-                () -> new StreamsProducer(logContext, null)
+                () -> new StreamsProducer(null, false, logContext, null)
             );
 
-            assertThat(thrown.getMessage(), equalTo("producer cannot be null"));
+            assertThat(thrown.getMessage(), is("producer cannot be null"));
         }
 
         {
             final NullPointerException thrown = assertThrows(
                 NullPointerException.class,
-                () -> new StreamsProducer(logContext, null, "appId", taskId)
+                () -> new StreamsProducer(null, true, logContext, "appId")
             );
 
-            assertThat(thrown.getMessage(), equalTo("producer cannot be null"));
-        }
-    }
-
-    @Test
-    public void shouldThrowIfIncorrectlyInitialized() {
-        {
-            final IllegalArgumentException thrown = assertThrows(
-                IllegalArgumentException.class,
-                () -> new StreamsProducer(logContext, mockProducer, null, taskId)
-            );
-            assertThat(thrown.getMessage(), equalTo("applicationId and taskId must either be both null or both be not null"));
-        }
-
-        {
-            final IllegalArgumentException thrown = assertThrows(
-                IllegalArgumentException.class,
-                () -> new StreamsProducer(logContext, mockProducer, "appId", null)
-            );
-            assertThat(thrown.getMessage(), equalTo("applicationId and taskId must either be both null or both be not null"));
+            assertThat(thrown.getMessage(), is("producer cannot be null"));
         }
     }
 
-    // non-eos tests
-
-    // functional tests
-
     @Test
     public void shouldNotInitTxIfEosDisable() {
-        assertFalse(mockProducer.transactionInitialized());
+        assertThat(mockProducer.transactionInitialized(), is(false));
     }
 
     @Test
     public void shouldNotBeginTxOnSendIfEosDisable() {
-        streamsProducer.send(record, null);
-        assertFalse(mockProducer.transactionInFlight());
+        aloStreamsProducer.send(record, null);
+        assertThat(mockProducer.transactionInFlight(), is(false));
     }
 
     @Test
     public void shouldForwardRecordOnSend() {
-        streamsProducer.send(record, null);
-        assertThat(mockProducer.history().size(), equalTo(1));
-        assertThat(mockProducer.history().get(0), equalTo(record));
+        aloStreamsProducer.send(record, null);
+        assertThat(mockProducer.history().size(), is(1));
+        assertThat(mockProducer.history().get(0), is(record));
     }
 
     @Test
@@ -160,11 +135,12 @@ public class StreamsProducerTest {
         expect(producer.partitionsFor("topic")).andReturn(expectedPartitionInfo);
         replay(producer);
 
-        final StreamsProducer streamsProducer = new StreamsProducer(logContext, producer);
+        final StreamsProducer streamsProducer =
+            new StreamsProducer(producer, false, logContext, null);
 
         final List<PartitionInfo> partitionInfo = streamsProducer.partitionsFor(topic);
 
-        assertSame(expectedPartitionInfo, partitionInfo);
+        assertThat(partitionInfo, sameInstance(expectedPartitionInfo));
         verify(producer);
     }
 
@@ -176,7 +152,8 @@ public class StreamsProducerTest {
         expectLastCall();
         replay(producer);
 
-        final StreamsProducer streamsProducer = new StreamsProducer(logContext, producer);
+        final StreamsProducer streamsProducer =
+            new StreamsProducer(producer, false, logContext, null);
 
         streamsProducer.flush();
 
@@ -189,10 +166,10 @@ public class StreamsProducerTest {
     public void shouldFailOnInitTxIfEosDisabled() {
         final IllegalStateException thrown = assertThrows(
             IllegalStateException.class,
-            streamsProducer::initTransaction
+            aloStreamsProducer::initTransaction
         );
 
-        assertThat(thrown.getMessage(), equalTo("EOS is disabled"));
+        assertThat(thrown.getMessage(), is("EOS is disabled [test, alo]"));
     }
 
     @Test
@@ -201,11 +178,12 @@ public class StreamsProducerTest {
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
-            () -> streamsProducer.send(record, null)
+            () -> aloStreamsProducer.send(record, null)
         );
 
-        assertEquals(mockProducer.sendException, thrown.getCause());
-        assertThat(thrown.getMessage(), equalTo("Error encountered sending record to topic topic due to:\norg.apache.kafka.common.KafkaException: KABOOM!"));
+        assertThat(thrown.getCause(), is(mockProducer.sendException));
+        assertThat(thrown.getMessage(), is("Error encountered sending record to topic topic [test, alo]"));
+        assertThat(thrown.getCause(), is(mockProducer.sendException));
     }
 
     @Test
@@ -214,38 +192,30 @@ public class StreamsProducerTest {
 
         final RuntimeException thrown = assertThrows(
             RuntimeException.class,
-            () -> streamsProducer.send(record, null)
+            () -> aloStreamsProducer.send(record, null)
         );
 
-        assertThat(thrown.getMessage(), equalTo("KABOOM!"));
+        assertThat(thrown.getMessage(), is("KABOOM!"));
     }
 
     @Test
     public void shouldFailOnCommitIfEosDisabled() {
         final IllegalStateException thrown = assertThrows(
             IllegalStateException.class,
-            () -> streamsProducer.commitTransaction(null)
+            () -> aloStreamsProducer.commitTransaction(null)
         );
 
-        assertThat(thrown.getMessage(), equalTo("EOS is disabled"));
+        assertThat(thrown.getMessage(), is("EOS is disabled [test, alo]"));
     }
 
     @Test
     public void shouldFailOnAbortIfEosDisabled() {
         final IllegalStateException thrown = assertThrows(
             IllegalStateException.class,
-            streamsProducer::abortTransaction
+            aloStreamsProducer::abortTransaction
         );
 
-        assertThat(thrown.getMessage(), equalTo("EOS is disabled"));
-    }
-
-    @Test
-    public void shouldNotCloseProducerIfEosDisabled() {
-        mockProducer.closeException = new KafkaException("KABOOM!");
-        streamsProducer.close();
-
-        assertFalse(mockProducer.closed());
+        assertThat(thrown.getMessage(), is("EOS is disabled [test, alo]"));
     }
 
     // EOS tests
@@ -254,30 +224,30 @@ public class StreamsProducerTest {
 
     @Test
     public void shouldInitTxOnEos() {
-        assertTrue(eosMockProducer.transactionInitialized());
+        assertThat(eosMockProducer.transactionInitialized(), is(true));
     }
 
     @Test
     public void shouldBeginTxOnEosSend() {
         eosStreamsProducer.send(record, null);
-        assertTrue(eosMockProducer.transactionInFlight());
+        assertThat(eosMockProducer.transactionInFlight(), is(true));
     }
 
     @Test
     public void shouldContinueTxnSecondEosSend() {
         eosStreamsProducer.send(record, null);
         eosStreamsProducer.send(record, null);
-        assertTrue(eosMockProducer.transactionInFlight());
-        assertThat(eosMockProducer.uncommittedRecords().size(), equalTo(2));
+        assertThat(eosMockProducer.transactionInFlight(), is(true));
+        assertThat(eosMockProducer.uncommittedRecords().size(), is(2));
     }
 
     @Test
     public void shouldForwardRecordButNotCommitOnEosSend() {
         eosStreamsProducer.send(record, null);
-        assertTrue(eosMockProducer.transactionInFlight());
-        assertTrue(eosMockProducer.history().isEmpty());
-        assertThat(eosMockProducer.uncommittedRecords().size(), equalTo(1));
-        assertThat(eosMockProducer.uncommittedRecords().get(0), equalTo(record));
+        assertThat(eosMockProducer.transactionInFlight(), is(true));
+        assertThat(eosMockProducer.history().isEmpty(), is(true));
+        assertThat(eosMockProducer.uncommittedRecords().size(), is(1));
+        assertThat(eosMockProducer.uncommittedRecords().get(0), is(record));
     }
 
     @Test
@@ -291,7 +261,8 @@ public class StreamsProducerTest {
         expectLastCall();
         replay(producer);
 
-        final StreamsProducer streamsProducer = new StreamsProducer(logContext, producer, "appId", taskId);
+        final StreamsProducer streamsProducer = 
+            new StreamsProducer(producer, true, logContext, "appId");
         streamsProducer.initTransaction();
 
         streamsProducer.commitTransaction(offsetsAndMetadata);
@@ -302,40 +273,40 @@ public class StreamsProducerTest {
     @Test
     public void shouldSendOffsetToTxOnEosCommit() {
         eosStreamsProducer.commitTransaction(offsetsAndMetadata);
-        assertTrue(eosMockProducer.sentOffsets());
+        assertThat(eosMockProducer.sentOffsets(), is(true));
     }
 
     @Test
     public void shouldCommitTxOnEosCommit() {
         eosStreamsProducer.send(record, null);
-        assertTrue(eosMockProducer.transactionInFlight());
+        assertThat(eosMockProducer.transactionInFlight(), is(true));
 
         eosStreamsProducer.commitTransaction(offsetsAndMetadata);
 
-        assertFalse(eosMockProducer.transactionInFlight());
-        assertTrue(eosMockProducer.uncommittedRecords().isEmpty());
-        assertTrue(eosMockProducer.uncommittedOffsets().isEmpty());
-        assertThat(eosMockProducer.history().size(), equalTo(1));
-        assertThat(eosMockProducer.history().get(0), equalTo(record));
-        assertThat(eosMockProducer.consumerGroupOffsetsHistory().size(), equalTo(1));
-        assertThat(eosMockProducer.consumerGroupOffsetsHistory().get(0).get("appId"), equalTo(offsetsAndMetadata));
+        assertThat(eosMockProducer.transactionInFlight(), is(false));
+        assertThat(eosMockProducer.uncommittedRecords().isEmpty(), is(true));
+        assertThat(eosMockProducer.uncommittedOffsets().isEmpty(), is(true));
+        assertThat(eosMockProducer.history().size(), is(1));
+        assertThat(eosMockProducer.history().get(0), is(record));
+        assertThat(eosMockProducer.consumerGroupOffsetsHistory().size(), is(1));
+        assertThat(eosMockProducer.consumerGroupOffsetsHistory().get(0).get("appId"), is(offsetsAndMetadata));
     }
 
     @Test
     public void shouldAbortTxOnEosAbort() {
         // call `send()` to start a transaction
         eosStreamsProducer.send(record, null);
-        assertTrue(eosMockProducer.transactionInFlight());
-        assertThat(eosMockProducer.uncommittedRecords().size(), equalTo(1));
-        assertThat(eosMockProducer.uncommittedRecords().get(0), equalTo(record));
+        assertThat(eosMockProducer.transactionInFlight(), is(true));
+        assertThat(eosMockProducer.uncommittedRecords().size(), is(1));
+        assertThat(eosMockProducer.uncommittedRecords().get(0), is(record));
 
         eosStreamsProducer.abortTransaction();
 
-        assertFalse(eosMockProducer.transactionInFlight());
-        assertTrue(eosMockProducer.uncommittedRecords().isEmpty());
-        assertTrue(eosMockProducer.uncommittedOffsets().isEmpty());
-        assertTrue(eosMockProducer.history().isEmpty());
-        assertTrue(eosMockProducer.consumerGroupOffsetsHistory().isEmpty());
+        assertThat(eosMockProducer.transactionInFlight(), is(false));
+        assertThat(eosMockProducer.uncommittedRecords().isEmpty(), is(true));
+        assertThat(eosMockProducer.uncommittedOffsets().isEmpty(), is(true));
+        assertThat(eosMockProducer.history().isEmpty(), is(true));
+        assertThat(eosMockProducer.consumerGroupOffsetsHistory().isEmpty(), is(true));
     }
 
     @Test
@@ -346,7 +317,8 @@ public class StreamsProducerTest {
         expectLastCall();
         replay(producer);
 
-        final StreamsProducer streamsProducer = new StreamsProducer(logContext, producer, "appId", taskId);
+        final StreamsProducer streamsProducer = 
+            new StreamsProducer(producer, true, logContext, "appId");
         streamsProducer.initTransaction();
 
         streamsProducer.abortTransaction();
@@ -360,43 +332,46 @@ public class StreamsProducerTest {
     public void shouldThrowTimeoutExceptionOnEosInitTxTimeout() {
         // use `mockProducer` instead of `eosMockProducer` to avoid double Tx-Init
         mockProducer.initTransactionException = new TimeoutException("KABOOM!");
-        final StreamsProducer streamsProducer = new StreamsProducer(logContext, mockProducer, "appId", taskId);
+        final StreamsProducer streamsProducer = 
+            new StreamsProducer(mockProducer, true, logContext, "appId");
 
         final TimeoutException thrown = assertThrows(
             TimeoutException.class,
             streamsProducer::initTransaction
         );
 
-        assertThat(thrown.getMessage(), equalTo("KABOOM!"));
+        assertThat(thrown.getMessage(), is("KABOOM!"));
     }
 
     @Test
     public void shouldThrowStreamsExceptionOnEosInitError() {
         // use `mockProducer` instead of `eosMockProducer` to avoid double Tx-Init
         mockProducer.initTransactionException = new KafkaException("KABOOM!");
-        final StreamsProducer streamsProducer = new StreamsProducer(logContext, mockProducer, "appId", taskId);
+        final StreamsProducer streamsProducer = 
+            new StreamsProducer(mockProducer, true, logContext, "appId");
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
             streamsProducer::initTransaction
         );
 
-        assertEquals(mockProducer.initTransactionException, thrown.getCause());
-        assertThat(thrown.getMessage(), equalTo("Error encountered while initializing transactions for task 0_0"));
+        assertThat(thrown.getCause(), is(mockProducer.initTransactionException));
+        assertThat(thrown.getMessage(), is("Error encountered while initializing transactions [test, eos]"));
     }
 
     @Test
     public void shouldFailOnEosInitFatal() {
         // use `mockProducer` instead of `eosMockProducer` to avoid double Tx-Init
         mockProducer.initTransactionException = new RuntimeException("KABOOM!");
-        final StreamsProducer streamsProducer = new StreamsProducer(logContext, mockProducer, "appId", taskId);
+        final StreamsProducer streamsProducer =
+            new StreamsProducer(mockProducer, true, logContext, "appId");
 
         final RuntimeException thrown = assertThrows(
             RuntimeException.class,
             streamsProducer::initTransaction
         );
 
-        assertThat(thrown.getMessage(), equalTo("KABOOM!"));
+        assertThat(thrown.getMessage(), is("KABOOM!"));
     }
 
     @Test
@@ -408,7 +383,11 @@ public class StreamsProducerTest {
             () -> eosStreamsProducer.send(null, null)
         );
 
-        assertThat(thrown.getMessage(), equalTo("Producer get fenced trying to begin a new transaction; it means all tasks belonging to this thread should be migrated."));
+        assertThat(
+            thrown.getMessage(),
+            is("Producer get fenced trying to begin a new transaction [test, eos];" +
+                   " it means all tasks belonging to this thread should be migrated.")
+        );
     }
 
     @Test
@@ -420,8 +399,11 @@ public class StreamsProducerTest {
             StreamsException.class,
             () -> eosStreamsProducer.send(null, null));
 
-        assertEquals(eosMockProducer.beginTransactionException, thrown.getCause());
-        assertThat(thrown.getMessage(), equalTo("Producer encounter unexpected error trying to begin a new transaction for task 0_0"));
+        assertThat(thrown.getCause(), is(eosMockProducer.beginTransactionException));
+        assertThat(
+            thrown.getMessage(),
+            is("Producer encounter unexpected error trying to begin a new transaction [test, eos]")
+        );
     }
 
     @Test
@@ -433,7 +415,7 @@ public class StreamsProducerTest {
             RuntimeException.class,
             () -> eosStreamsProducer.send(null, null));
 
-        assertThat(thrown.getMessage(), equalTo("KABOOM!"));
+        assertThat(thrown.getMessage(), is("KABOOM!"));
     }
 
     @Test
@@ -448,8 +430,12 @@ public class StreamsProducerTest {
             () -> eosStreamsProducer.send(record, null)
         );
 
-        assertEquals(exception, thrown.getCause());
-        assertThat(thrown.getMessage(), equalTo("Producer cannot send records anymore since it got fenced; it means all tasks belonging to this thread should be migrated."));
+        assertThat(thrown.getCause(), is(exception));
+        assertThat(
+            thrown.getMessage(),
+            is("Producer cannot send records anymore since it got fenced [test, eos];" +
+                   " it means all tasks belonging to this thread should be migrated.")
+        );
     }
 
     @Test
@@ -463,8 +449,12 @@ public class StreamsProducerTest {
             () -> eosStreamsProducer.send(record, null)
         );
 
-        assertEquals(exception, thrown.getCause());
-        assertThat(thrown.getMessage(), equalTo("Producer cannot send records anymore since it got fenced; it means all tasks belonging to this thread should be migrated."));
+        assertThat(thrown.getCause(), is(exception));
+        assertThat(
+            thrown.getMessage(),
+            is("Producer cannot send records anymore since it got fenced [test, eos];" +
+                   " it means all tasks belonging to this thread should be migrated.")
+        );
     }
 
     @Test
@@ -479,8 +469,12 @@ public class StreamsProducerTest {
             () -> eosStreamsProducer.commitTransaction(null)
         );
 
-        assertEquals(eosMockProducer.sendOffsetsToTransactionException, thrown.getCause());
-        assertThat(thrown.getMessage(), equalTo("Producer get fenced trying to commit a transaction; it means all tasks belonging to this thread should be migrated."));
+        assertThat(thrown.getCause(), is(eosMockProducer.sendOffsetsToTransactionException));
+        assertThat(
+            thrown.getMessage(),
+            is("Producer get fenced trying to commit a transaction [test, eos];" +
+                   " it means all tasks belonging to this thread should be migrated.")
+        );
     }
 
     @Test
@@ -494,8 +488,11 @@ public class StreamsProducerTest {
             () -> eosStreamsProducer.commitTransaction(null)
         );
 
-        assertEquals(eosMockProducer.sendOffsetsToTransactionException, thrown.getCause());
-        assertThat(thrown.getMessage(), equalTo("Producer encounter unexpected error trying to commit a transaction for task 0_0"));
+        assertThat(thrown.getCause(), is(eosMockProducer.sendOffsetsToTransactionException));
+        assertThat(
+            thrown.getMessage(),
+            is("Producer encounter unexpected error trying to commit a transaction [test, eos]")
+        );
     }
 
     @Test
@@ -509,7 +506,7 @@ public class StreamsProducerTest {
             () -> eosStreamsProducer.commitTransaction(null)
         );
 
-        assertThat(thrown.getMessage(), equalTo("KABOOM!"));
+        assertThat(thrown.getMessage(), is("KABOOM!"));
     }
 
     @Test
@@ -522,9 +519,13 @@ public class StreamsProducerTest {
             () -> eosStreamsProducer.commitTransaction(offsetsAndMetadata)
         );
 
-        assertTrue(eosMockProducer.sentOffsets());
-        assertEquals(eosMockProducer.commitTransactionException, thrown.getCause());
-        assertThat(thrown.getMessage(), equalTo("Producer get fenced trying to commit a transaction; it means all tasks belonging to this thread should be migrated."));
+        assertThat(eosMockProducer.sentOffsets(), is(true));
+        assertThat(thrown.getCause(), is(eosMockProducer.commitTransactionException));
+        assertThat(
+            thrown.getMessage(),
+            is("Producer get fenced trying to commit a transaction [test, eos];" +
+                   " it means all tasks belonging to this thread should be migrated.")
+        );
     }
 
     @Test
@@ -537,9 +538,9 @@ public class StreamsProducerTest {
             () -> eosStreamsProducer.commitTransaction(offsetsAndMetadata)
         );
 
-        assertTrue(eosMockProducer.sentOffsets());
-        assertEquals(eosMockProducer.commitTransactionException, thrown.getCause());
-        assertThat(thrown.getMessage(), equalTo("Timed out while committing a transaction for task " + taskId));
+        assertThat(eosMockProducer.sentOffsets(), is(true));
+        assertThat(thrown.getCause(), is(eosMockProducer.commitTransactionException));
+        assertThat(thrown.getMessage(), is("Timed out while committing a transaction [test, eos]"));
     }
 
     @Test
@@ -551,9 +552,12 @@ public class StreamsProducerTest {
             () -> eosStreamsProducer.commitTransaction(offsetsAndMetadata)
         );
 
-        assertTrue(eosMockProducer.sentOffsets());
-        assertEquals(eosMockProducer.commitTransactionException, thrown.getCause());
-        assertThat(thrown.getMessage(), equalTo("Producer encounter unexpected error trying to commit a transaction for task 0_0"));
+        assertThat(eosMockProducer.sentOffsets(), is(true));
+        assertThat(thrown.getCause(), is(eosMockProducer.commitTransactionException));
+        assertThat(
+            thrown.getMessage(),
+            is("Producer encounter unexpected error trying to commit a transaction [test, eos]")
+        );
     }
 
     @Test
@@ -565,8 +569,8 @@ public class StreamsProducerTest {
             () -> eosStreamsProducer.commitTransaction(offsetsAndMetadata)
         );
 
-        assertTrue(eosMockProducer.sentOffsets());
-        assertThat(thrown.getMessage(), equalTo("KABOOM!"));
+        assertThat(eosMockProducer.sentOffsets(), is(true));
+        assertThat(thrown.getMessage(), is("KABOOM!"));
     }
 
     @Test
@@ -580,7 +584,8 @@ public class StreamsProducerTest {
         expectLastCall().andThrow(new ProducerFencedException("KABOOM!"));
         replay(producer);
 
-        final StreamsProducer streamsProducer = new StreamsProducer(logContext, producer, "appId", taskId);
+        final StreamsProducer streamsProducer =
+            new StreamsProducer(producer, true, logContext, "appId");
         streamsProducer.initTransaction();
         // call `send()` to start a transaction
         streamsProducer.send(record, null);
@@ -598,8 +603,11 @@ public class StreamsProducerTest {
 
         final StreamsException thrown = assertThrows(StreamsException.class, eosStreamsProducer::abortTransaction);
 
-        assertEquals(eosMockProducer.abortTransactionException, thrown.getCause());
-        assertThat(thrown.getMessage(), equalTo("Producer encounter unexpected error trying to abort a transaction for task 0_0"));
+        assertThat(thrown.getCause(), is(eosMockProducer.abortTransactionException));
+        assertThat(
+            thrown.getMessage(),
+            is("Producer encounter unexpected error trying to abort a transaction [test, eos]")
+        );
     }
 
     @Test
@@ -610,27 +618,6 @@ public class StreamsProducerTest {
 
         final RuntimeException thrown = assertThrows(RuntimeException.class, eosStreamsProducer::abortTransaction);
 
-        assertThat(thrown.getMessage(), equalTo("KABOOM!"));
-    }
-
-    @Test
-    public void shouldFailOnCloseFatal() {
-        eosMockProducer.closeException = new RuntimeException("KABOOM!");
-
-        final RuntimeException thrown = assertThrows(
-            RuntimeException.class,
-            eosStreamsProducer::close
-        );
-
-        assertThat(thrown.getMessage(), equalTo("KABOOM!"));
-    }
-
-    @Test
-    public void shouldCloseProducerIfEosEnabled() {
-        eosStreamsProducer.close();
-
-        final RuntimeException thrown = assertThrows(IllegalStateException.class, () -> eosStreamsProducer.send(record, null));
-
-        assertThat(thrown.getMessage(), equalTo("MockProducer is already closed."));
+        assertThat(thrown.getMessage(), is("KABOOM!"));
     }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 3f38a11..40b8782 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -24,14 +24,20 @@ import org.apache.kafka.clients.admin.RecordsToDelete;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.Metric;
+import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.internals.KafkaFutureImpl;
+import org.apache.kafka.common.metrics.KafkaMetric;
+import org.apache.kafka.common.metrics.Measurable;
 import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
-
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.easymock.EasyMock;
@@ -108,9 +114,9 @@ public class TaskManagerTest {
     @Mock(type = MockType.STRICT)
     private Consumer<byte[], byte[]> consumer;
     @Mock(type = MockType.STRICT)
-    private StreamThread.AbstractTaskCreator<Task> activeTaskCreator;
+    private ActiveTaskCreator activeTaskCreator;
     @Mock(type = MockType.NICE)
-    private StreamThread.AbstractTaskCreator<Task> standbyTaskCreator;
+    private StandbyTaskCreator standbyTaskCreator;
     @Mock(type = MockType.NICE)
     private Admin adminClient;
 
@@ -128,9 +134,8 @@ public class TaskManagerTest {
                                       streamsMetrics,
                                       activeTaskCreator,
                                       standbyTaskCreator,
-                                      new HashMap<>(),
                                       topologyBuilder,
-                                      adminClient);
+                                      adminClient, stateDirectory);
         taskManager.setMainConsumer(consumer);
     }
 
@@ -139,7 +144,6 @@ public class TaskManagerTest {
         final TopicPartition newTopicPartition = new TopicPartition("topic2", 1);
         final Map<TaskId, Set<TopicPartition>> assignment = mkMap(mkEntry(taskId01, mkSet(t1p1, newTopicPartition)));
 
-        expect(activeTaskCreator.builder()).andReturn(topologyBuilder).anyTimes();
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(emptyList()).anyTimes();
 
         topologyBuilder.addSubscribedTopicsFromAssignment(eq(asList(t1p1, newTopicPartition)), anyString());
@@ -164,7 +168,6 @@ public class TaskManagerTest {
         assertThat((new File(taskFolders[1], StateManagerUtil.CHECKPOINT_FILE_NAME)).createNewFile(), is(true));
         assertThat((new File(taskFolders[3], StateManagerUtil.CHECKPOINT_FILE_NAME)).createNewFile(), is(true));
 
-        expect(activeTaskCreator.stateDirectory()).andReturn(stateDirectory).once();
         expect(stateDirectory.listTaskDirectories()).andReturn(taskFolders).once();
 
         replay(activeTaskCreator, stateDirectory);
@@ -183,7 +186,9 @@ public class TaskManagerTest {
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
         expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
-        expect(standbyTaskCreator.createTasks(anyObject(), anyObject())).andReturn(emptyList()).anyTimes();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
+        expectLastCall();
+        expect(standbyTaskCreator.createTasks(anyObject())).andReturn(emptyList()).anyTimes();
 
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
@@ -203,11 +208,105 @@ public class TaskManagerTest {
     }
 
     @Test
+    public void shouldCloseActiveTasksWhenHandlingLostTasks() {
+        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, false);
+
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
+        expectLastCall();
+        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andReturn(singletonList(task01)).anyTimes();
+
+        topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
+        expectLastCall().anyTimes();
+
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
+
+        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
+
+        assertThat(taskManager.tryToCompleteRestoration(), is(true));
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+        taskManager.handleLostAll();
+        assertThat(task00.state(), is(Task.State.CLOSED));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+        assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
+        assertThat(taskManager.standbyTaskMap(), is(singletonMap(taskId01, task01)));
+    }
+
+    @Test
+    public void shouldReviveCorruptTasks() {
+        final ProcessorStateManager stateManager = EasyMock.createStrictMock(ProcessorStateManager.class);
+        stateManager.markChangelogAsCorrupted(taskId00Partitions);
+        replay(stateManager);
+        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
+
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
+        expectLastCall();
+        expect(standbyTaskCreator.createTasks(anyObject())).andReturn(emptyList()).anyTimes();
+
+        topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
+        expectLastCall().anyTimes();
+
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
+
+        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+
+        assertThat(taskManager.tryToCompleteRestoration(), is(true));
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        taskManager.handleCorruption(singletonMap(taskId00, taskId00Partitions));
+        assertThat(task00.state(), is(Task.State.CREATED));
+        assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, task00)));
+        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+        verify(stateManager);
+    }
+
+    @Test
+    public void shouldReviveCorruptTasksEvenIfTheyCannotCloseClean() {
+        final ProcessorStateManager stateManager = EasyMock.createStrictMock(ProcessorStateManager.class);
+        stateManager.markChangelogAsCorrupted(taskId00Partitions);
+        replay(stateManager);
+        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager) {
+            @Override
+            public void closeClean() {
+                throw new RuntimeException("oops");
+            }
+        };
+
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
+        expectLastCall();
+        expect(standbyTaskCreator.createTasks(anyObject())).andReturn(emptyList()).anyTimes();
+
+        topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
+        expectLastCall().anyTimes();
+
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
+
+        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+
+        assertThat(taskManager.tryToCompleteRestoration(), is(true));
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        taskManager.handleCorruption(singletonMap(taskId00, taskId00Partitions));
+        assertThat(task00.state(), is(Task.State.CREATED));
+        assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, task00)));
+        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+        verify(stateManager);
+    }
+
+    @Test
     public void shouldCloseStandbyUnassignedTasksWhenCreatingNewTasks() {
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, false);
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(standbyTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
+        expect(standbyTaskCreator.createTasks(eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
         taskManager.handleAssignment(emptyMap(), taskId00Assignment);
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
@@ -221,20 +320,24 @@ public class TaskManagerTest {
     @Test
     public void shouldAddNonResumedSuspendedTasks() {
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, false); // standby task (active == false)
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
         // expect these calls twice (because we're going to tryToCompleteRestoration twice)
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
-        replay(activeTaskCreator, consumer, changeLogReader);
+        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andReturn(singletonList(task01)); // recorded once
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
-        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment); // first assignment: active + standby
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
 
-        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment); // same assignment again
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task01.state(), is(Task.State.RUNNING)); // both tasks must still be RUNNING
 
         verify(activeTaskCreator);
     }
@@ -251,7 +354,7 @@ public class TaskManagerTest {
         changeLogReader.transitToRestoreActive();
         expectLastCall();
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(singletonList(task00)).anyTimes();
-        expect(standbyTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
+        expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andReturn(emptyList()).anyTimes();
         replay(consumer, activeTaskCreator, standbyTaskCreator, changeLogReader);
 
         taskManager.handleAssignment(assignment, emptyMap());
@@ -267,6 +370,89 @@ public class TaskManagerTest {
     }
 
     @Test
+    public void shouldNotCompleteRestorationIfTasksCannotInitialize() {
+        final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
+            mkEntry(taskId00, taskId00Partitions),
+            mkEntry(taskId01, taskId01Partitions)
+        );
+        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+            @Override
+            public void initializeIfNeeded() {
+                throw new LockException("can't lock");
+            }
+        };
+        final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, true) {
+            @Override
+            public void initializeIfNeeded() {
+                throw new TimeoutException("timed out");
+            }
+        };
+        // task00 cannot take its state lock and task01 times out, so neither task can initialize
+        expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
+        expect(consumer.assignment()).andReturn(emptySet());
+        consumer.resume(eq(emptySet()));
+        expectLastCall();
+        changeLogReader.transitToRestoreActive();
+        expectLastCall();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(asList(task00, task01)).anyTimes();
+        expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andReturn(emptyList()).anyTimes();
+        replay(consumer, activeTaskCreator, standbyTaskCreator, changeLogReader);
+
+        taskManager.handleAssignment(assignment, emptyMap());
+        // handleAssignment only creates the tasks; both start out CREATED
+        assertThat(task00.state(), is(Task.State.CREATED));
+        assertThat(task01.state(), is(Task.State.CREATED));
+
+        assertThat(taskManager.tryToCompleteRestoration(), is(false));
+        // the failed initializations leave both tasks CREATED and still owned by the task manager
+        assertThat(task00.state(), is(Task.State.CREATED));
+        assertThat(task01.state(), is(Task.State.CREATED));
+        assertThat(
+            taskManager.activeTaskMap(),
+            Matchers.equalTo(mkMap(mkEntry(taskId00, task00), mkEntry(taskId01, task01)))
+        );
+        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+        verify(activeTaskCreator);
+    }
+
+    @Test
+    public void shouldNotCompleteRestorationIfTaskCannotCompleteRestoration() {
+        final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
+            mkEntry(taskId00, taskId00Partitions)
+        );
+        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+            @Override
+            public void completeRestoration() {
+                throw new TimeoutException("timeout!");
+            }
+        };
+        // task00 can initialize, but completing its restoration times out
+        expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
+        expect(consumer.assignment()).andReturn(emptySet());
+        consumer.resume(eq(emptySet()));
+        expectLastCall();
+        changeLogReader.transitToRestoreActive();
+        expectLastCall();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(singletonList(task00)).anyTimes();
+        expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andReturn(emptyList()).anyTimes();
+        replay(consumer, activeTaskCreator, standbyTaskCreator, changeLogReader);
+
+        taskManager.handleAssignment(assignment, emptyMap());
+
+        assertThat(task00.state(), is(Task.State.CREATED));
+
+        assertThat(taskManager.tryToCompleteRestoration(), is(false));
+        // the task got as far as RESTORING and stays there because restoration could not complete
+        assertThat(task00.state(), is(Task.State.RESTORING));
+        assertThat(
+            taskManager.activeTaskMap(),
+            Matchers.equalTo(mkMap(mkEntry(taskId00, task00)))
+        );
+        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+        verify(activeTaskCreator);
+    }
+
+    @Test
     public void shouldSuspendActiveTasks() {
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
 
@@ -304,17 +490,101 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldCloseActiveTasksOnShutdown() {
+    public void shouldCloseActiveTasksAndPropagateExceptionsOnCleanShutdown() {
         final TopicPartition changelog = new TopicPartition("changelog", 0);
-        final Map<TaskId, Set<TopicPartition>> assignment = singletonMap(taskId00, taskId00Partitions);
+        final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
+            mkEntry(taskId00, taskId00Partitions),
+            mkEntry(taskId01, taskId01Partitions),
+            mkEntry(taskId02, taskId02Partitions)
+        );
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
             @Override
             public Collection<TopicPartition> changelogPartitions() {
                 return singletonList(changelog);
             }
         };
+        final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, true) {
+            @Override
+            public void closeClean() {
+                throw new TaskMigratedException("migrated", new RuntimeException("cause"));
+            }
+        };
+        final Task task02 = new StateMachineTask(taskId02, taskId02Partitions, true) {
+            @Override
+            public void closeClean() {
+                throw new RuntimeException("oops");
+            }
+        };
 
-        EasyMock.resetToStrict(changeLogReader);
+        resetToStrict(changeLogReader);
+        changeLogReader.transitToRestoreActive();
+        expectLastCall();
+        expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
+        // make sure we also remove the changelog partitions from the changelog reader
+        changeLogReader.remove(eq(singletonList(changelog)));
+        expectLastCall();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(asList(task00, task01, task02)).anyTimes();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
+        expectLastCall();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId01));
+        expectLastCall();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId02));
+        expectLastCall();
+        activeTaskCreator.closeThreadProducerIfNeeded();
+        expectLastCall();
+        expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andReturn(emptyList()).anyTimes();
+        replay(activeTaskCreator, standbyTaskCreator, changeLogReader);
+
+        taskManager.handleAssignment(assignment, emptyMap());
+
+        assertThat(task00.state(), is(Task.State.CREATED));
+        assertThat(task01.state(), is(Task.State.CREATED));
+        assertThat(task02.state(), is(Task.State.CREATED));
+
+        taskManager.tryToCompleteRestoration();
+
+        assertThat(task00.state(), is(Task.State.RESTORING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+        assertThat(task02.state(), is(Task.State.RUNNING));
+        assertThat(
+            taskManager.activeTaskMap(),
+            Matchers.equalTo(
+                mkMap(
+                    mkEntry(taskId00, task00),
+                    mkEntry(taskId01, task01),
+                    mkEntry(taskId02, task02)
+                )
+            )
+        );
+        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+
+        final RuntimeException exception = assertThrows(RuntimeException.class, () -> taskManager.shutdown(true));
+        // all tasks get closed despite the failures, and task02's RuntimeException surfaces as the cause
+        assertThat(task00.state(), is(Task.State.CLOSED));
+        assertThat(task01.state(), is(Task.State.CLOSED));
+        assertThat(task02.state(), is(Task.State.CLOSED));
+        assertThat(exception.getMessage(), is("Unexpected exception while closing task"));
+        assertThat(exception.getCause().getMessage(), is("oops"));
+        assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
+        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+        // every task producer and the thread producer must have been closed, and the changelog removed
+        verify(activeTaskCreator, changeLogReader);
+    }
+
+    @Test
+    public void shouldCloseActiveTasksAndPropagateTaskProducerExceptionsOnCleanShutdown() {
+        final TopicPartition changelog = new TopicPartition("changelog", 0);
+        final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
+            mkEntry(taskId00, taskId00Partitions)
+        );
+        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+            @Override
+            public Collection<TopicPartition> changelogPartitions() {
+                return singletonList(changelog);
+            }
+        };
+
+        resetToStrict(changeLogReader);
         changeLogReader.transitToRestoreActive();
         expectLastCall();
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
@@ -322,9 +592,11 @@ public class TaskManagerTest {
         changeLogReader.remove(eq(singletonList(changelog)));
         expectLastCall();
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(singletonList(task00)).anyTimes();
-        activeTaskCreator.close();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
+        expectLastCall().andThrow(new RuntimeException("whatever"));
+        activeTaskCreator.closeThreadProducerIfNeeded();
         expectLastCall();
-        expect(standbyTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
+        expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andReturn(emptyList()).anyTimes();
         replay(activeTaskCreator, standbyTaskCreator, changeLogReader);
 
         taskManager.handleAssignment(assignment, emptyMap());
@@ -334,11 +606,157 @@ public class TaskManagerTest {
         taskManager.tryToCompleteRestoration();
 
         assertThat(task00.state(), is(Task.State.RESTORING));
-        assertThat(taskManager.activeTaskMap(), Matchers.equalTo(singletonMap(taskId00, task00)));
+        assertThat(
+            taskManager.activeTaskMap(),
+            Matchers.equalTo(
+                mkMap(
+                    mkEntry(taskId00, task00)
+                )
+            )
+        );
+        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+
+        final RuntimeException exception = assertThrows(RuntimeException.class, () -> taskManager.shutdown(true));
+
+        assertThat(task00.state(), is(Task.State.CLOSED));
+        assertThat(exception.getMessage(), is("Unexpected exception while closing task"));
+        assertThat(exception.getCause().getMessage(), is("whatever"));
+        assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
+        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+        // the active task creator should also get closed (so that it closes the thread producer if applicable)
+        verify(activeTaskCreator, changeLogReader);
+    }
+
+    @Test
+    public void shouldCloseActiveTasksAndPropagateThreadProducerExceptionsOnCleanShutdown() {
+        final TopicPartition changelog = new TopicPartition("changelog", 0);
+        final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
+            mkEntry(taskId00, taskId00Partitions)
+        );
+        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+            @Override
+            public Collection<TopicPartition> changelogPartitions() {
+                return singletonList(changelog);
+            }
+        };
+        // closing the thread producer fails; a clean shutdown must surface that exception
+        resetToStrict(changeLogReader);
+        changeLogReader.transitToRestoreActive();
+        expectLastCall();
+        expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
+        // make sure we also remove the changelog partitions from the changelog reader
+        changeLogReader.remove(eq(singletonList(changelog)));
+        expectLastCall();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(singletonList(task00)).anyTimes();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
+        expectLastCall();
+        activeTaskCreator.closeThreadProducerIfNeeded();
+        expectLastCall().andThrow(new RuntimeException("whatever"));
+        expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andReturn(emptyList()).anyTimes();
+        replay(activeTaskCreator, standbyTaskCreator, changeLogReader);
+
+        taskManager.handleAssignment(assignment, emptyMap());
+
+        assertThat(task00.state(), is(Task.State.CREATED));
+
+        taskManager.tryToCompleteRestoration();
+
+        assertThat(task00.state(), is(Task.State.RESTORING));
+        assertThat(
+            taskManager.activeTaskMap(),
+            Matchers.equalTo(
+                mkMap(
+                    mkEntry(taskId00, task00)
+                )
+            )
+        );
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
-        taskManager.shutdown(true);
+
+        final RuntimeException exception = assertThrows(RuntimeException.class, () -> taskManager.shutdown(true));
 
         assertThat(task00.state(), is(Task.State.CLOSED));
+        assertThat(exception.getMessage(), is("Unexpected exception while closing task"));
+        assertThat(exception.getCause().getMessage(), is("whatever"));
+        assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
+        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+        // the task producer and the thread producer must both have been closed, and the changelog removed
+        verify(activeTaskCreator, changeLogReader);
+    }
+
+    @Test
+    public void shouldCloseActiveTasksAndIgnoreExceptionsOnUncleanShutdown() {
+        final TopicPartition changelog = new TopicPartition("changelog", 0);
+        final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
+            mkEntry(taskId00, taskId00Partitions),
+            mkEntry(taskId01, taskId01Partitions),
+            mkEntry(taskId02, taskId02Partitions)
+        );
+        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+            @Override
+            public Collection<TopicPartition> changelogPartitions() {
+                return singletonList(changelog);
+            }
+        };
+        final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, true) {
+            @Override
+            public void closeClean() {
+                throw new TaskMigratedException("migrated", new RuntimeException("cause"));
+            }
+        };
+        final Task task02 = new StateMachineTask(taskId02, taskId02Partitions, true) {
+            @Override
+            public void closeClean() {
+                throw new RuntimeException("oops");
+            }
+        };
+
+        resetToStrict(changeLogReader);
+        changeLogReader.transitToRestoreActive();
+        expectLastCall();
+        expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
+        // make sure we also remove the changelog partitions from the changelog reader
+        changeLogReader.remove(eq(singletonList(changelog)));
+        expectLastCall();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(asList(task00, task01, task02)).anyTimes();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
+        expectLastCall().andThrow(new RuntimeException("whatever 0"));
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId01));
+        expectLastCall().andThrow(new RuntimeException("whatever 1"));
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId02));
+        expectLastCall().andThrow(new RuntimeException("whatever 2"));
+        activeTaskCreator.closeThreadProducerIfNeeded();
+        expectLastCall().andThrow(new RuntimeException("whatever all"));
+        expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andReturn(emptyList()).anyTimes();
+        replay(activeTaskCreator, standbyTaskCreator, changeLogReader);
+
+        taskManager.handleAssignment(assignment, emptyMap());
+
+        assertThat(task00.state(), is(Task.State.CREATED));
+        assertThat(task01.state(), is(Task.State.CREATED));
+        assertThat(task02.state(), is(Task.State.CREATED));
+
+        taskManager.tryToCompleteRestoration();
+
+        assertThat(task00.state(), is(Task.State.RESTORING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+        assertThat(task02.state(), is(Task.State.RUNNING));
+        assertThat(
+            taskManager.activeTaskMap(),
+            Matchers.equalTo(
+                mkMap(
+                    mkEntry(taskId00, task00),
+                    mkEntry(taskId01, task01),
+                    mkEntry(taskId02, task02)
+                )
+            )
+        );
+        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+
+        taskManager.shutdown(false);
+
+        assertThat(task00.state(), is(Task.State.CLOSED));
+        assertThat(task01.state(), is(Task.State.CLOSED));
+        assertThat(task02.state(), is(Task.State.CLOSED));
         assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
         // the active task creator should also get closed (so that it closes the thread producer if applicable)
@@ -355,9 +773,9 @@ public class TaskManagerTest {
         consumer.resume(eq(emptySet()));
         expectLastCall();
         expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
-        activeTaskCreator.close();
+        activeTaskCreator.closeThreadProducerIfNeeded();
         expectLastCall();
-        expect(standbyTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(singletonList(task00)).anyTimes();
+        expect(standbyTaskCreator.createTasks(eq(assignment))).andReturn(singletonList(task00)).anyTimes();
         replay(consumer, activeTaskCreator, standbyTaskCreator, changeLogReader);
 
         taskManager.handleAssignment(emptyMap(), assignment);
@@ -402,7 +820,7 @@ public class TaskManagerTest {
         final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false);
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(standbyTaskCreator.createTasks(anyObject(), eq(taskId01Assignment)))
+        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment)))
             .andReturn(singletonList(task01)).anyTimes();
 
         replay(standbyTaskCreator, consumer, changeLogReader);
@@ -416,6 +834,20 @@ public class TaskManagerTest {
     }
 
     @Test
+    public void shouldHandleRebalanceEvents() {
+        final Set<TopicPartition> assignment = singleton(new TopicPartition("assignment", 0));
+        expect(consumer.assignment()).andReturn(assignment);
+        consumer.pause(assignment); // recording a void call expects it exactly once
+        replay(consumer);
+        assertThat(taskManager.isRebalanceInProgress(), is(false));
+        taskManager.handleRebalanceStart(emptySet());
+        assertThat(taskManager.isRebalanceInProgress(), is(true));
+        taskManager.handleRebalanceComplete();
+        assertThat(taskManager.isRebalanceInProgress(), is(false));
+        verify(consumer); // the assigned partitions must actually have been paused
+    }
+
+    @Test
     public void shouldCommitActiveAndStandbyTasks() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
         final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false);
@@ -423,7 +855,7 @@ public class TaskManagerTest {
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
             .andReturn(singletonList(task00)).anyTimes();
-        expect(standbyTaskCreator.createTasks(anyObject(), eq(taskId01Assignment)))
+        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment)))
             .andReturn(singletonList(task01)).anyTimes();
 
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
@@ -441,6 +873,41 @@ public class TaskManagerTest {
     }
 
     @Test
+    public void shouldNotCommitActiveAndStandbyTasksWhileRebalanceInProgress() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false);
+
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
+            .andReturn(singletonList(task00)).anyTimes();
+        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment)))
+            .andReturn(singletonList(task01)).anyTimes();
+
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
+
+        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
+        assertThat(taskManager.tryToCompleteRestoration(), is(true));
+
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+
+        task00.setCommitNeeded();
+        task01.setCommitNeeded();
+
+        taskManager.handleRebalanceStart(emptySet());
+
+        // -1 is the sentinel indicating that nothing was done because a rebalance is in progress
+        assertThat(taskManager.commitAll(), equalTo(-1));
+        assertThat(taskManager.maybeCommitActiveTasksPerUserRequested(), equalTo(-1));
+
+        // neither task may have been committed, so both must still report that a commit is needed
+        assertThat(task00.commitNeeded(), is(true));
+        assertThat(task01.commitNeeded(), is(true));
+
+        verify(activeTaskCreator, standbyTaskCreator);
+    }
+
+    @Test
     public void shouldPropagateExceptionFromActiveCommit() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
             @Override
@@ -477,7 +944,7 @@ public class TaskManagerTest {
         };
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(standbyTaskCreator.createTasks(anyObject(), eq(taskId01Assignment)))
+        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment)))
             .andReturn(singletonList(task01)).anyTimes();
 
         replay(standbyTaskCreator, consumer, changeLogReader);
@@ -659,6 +1126,111 @@ public class TaskManagerTest {
     }
 
     @Test
+    public void shouldPropagateTaskMigratedExceptionsInProcessActiveTasks() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+            @Override
+            public boolean process(final long wallClockTime) {
+                throw new TaskMigratedException("migrated", new RuntimeException("cause"));
+            }
+        };
+
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
+            .andReturn(singletonList(task00)).anyTimes();
+        replay(activeTaskCreator, consumer, changeLogReader);
+
+        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+        assertThat(taskManager.tryToCompleteRestoration(), is(true));
+
+        assertThat(task00.state(), is(Task.State.RUNNING));
+
+        final TopicPartition partition = taskId00Partitions.iterator().next();
+        task00.addRecords(
+            partition,
+            singletonList(new ConsumerRecord<>(partition.topic(), partition.partition(), 0L, null, null))
+        );
+
+        final TaskMigratedException exception = assertThrows(TaskMigratedException.class, () -> taskManager.process(0L));
+        assertThat(exception.getCause().getMessage(), is("cause")); // same exception the task threw
+    }
+
+    @Test
+    public void shouldPropagateRuntimeExceptionsInProcessActiveTasks() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+            @Override
+            public boolean process(final long wallClockTime) {
+                throw new RuntimeException("oops");
+            }
+        };
+        // process() always fails with a plain RuntimeException
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
+            .andReturn(singletonList(task00)).anyTimes();
+
+        replay(activeTaskCreator, consumer, changeLogReader);
+
+        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+        assertThat(taskManager.tryToCompleteRestoration(), is(true));
+
+        assertThat(task00.state(), is(Task.State.RUNNING));
+
+        final TopicPartition partition = taskId00Partitions.iterator().next();
+        task00.addRecords(
+            partition,
+            singletonList(new ConsumerRecord<>(partition.topic(), partition.partition(), 0L, null, null))
+        );
+        // the caller must see the task's exception with its original message
+        final RuntimeException exception = assertThrows(RuntimeException.class, () -> taskManager.process(0L));
+        assertThat(exception.getMessage(), is("oops"));
+    }
+
+    @Test
+    public void shouldPropagateTaskMigratedExceptionsInPunctuateActiveTasks() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+            @Override
+            public boolean maybePunctuateStreamTime() {
+                throw new TaskMigratedException("migrated", new RuntimeException("cause"));
+            }
+        };
+
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
+            .andReturn(singletonList(task00)).anyTimes();
+        replay(activeTaskCreator, consumer, changeLogReader);
+
+        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+        assertThat(taskManager.tryToCompleteRestoration(), is(true));
+
+        assertThat(task00.state(), is(Task.State.RUNNING));
+
+        final TaskMigratedException exception = assertThrows(TaskMigratedException.class, () -> taskManager.punctuate());
+        assertThat(exception.getCause().getMessage(), is("cause")); // same exception the task threw
+    }
+
+    @Test
+    public void shouldPropagateKafkaExceptionsInPunctuateActiveTasks() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+            @Override
+            public boolean maybePunctuateStreamTime() {
+                throw new KafkaException("oops");
+            }
+        };
+
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
+            .andReturn(singletonList(task00)).anyTimes();
+        replay(activeTaskCreator, consumer, changeLogReader);
+
+        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+        assertThat(taskManager.tryToCompleteRestoration(), is(true));
+
+        assertThat(task00.state(), is(Task.State.RUNNING));
+
+        final KafkaException exception = assertThrows(KafkaException.class, () -> taskManager.punctuate());
+        assertThat(exception.getMessage(), is("oops")); // same exception the task threw
+    }
+
+    @Test
     public void shouldPunctuateActiveTasks() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
             @Override
@@ -729,8 +1301,9 @@ public class TaskManagerTest {
 
         final List<String> messages = appender.getMessages();
         assertThat(messages, hasItem("taskManagerTestThe following partitions [unknown-0] are missing " +
-            "from the task partitions. It could potentially due to race condition of consumer " +
-            "detecting the heartbeat failure, or the tasks have been cleaned up by the handleAssignment callback."));
+                                         "from the task partitions. It could potentially due to race " +
+                                         "condition of consumer detecting the heartbeat failure, or the " +
+                                         "tasks have been cleaned up by the handleAssignment callback."));
     }
 
     @Test
@@ -751,8 +1324,10 @@ public class TaskManagerTest {
         taskManager.tasks().put(taskId01, migratedTask01);
         taskManager.tasks().put(taskId02, migratedTask02);
 
-        final TaskMigratedException thrown = assertThrows(TaskMigratedException.class,
-            () -> taskManager.handleAssignment(emptyMap(), emptyMap()));
+        final TaskMigratedException thrown = assertThrows(
+            TaskMigratedException.class,
+            () -> taskManager.handleAssignment(emptyMap(), emptyMap())
+        );
         // The task map orders tasks based on topic group id and partition, so here
         // t1 should always be the first.
         assertThat(thrown.getMessage(), equalTo("t1 close exception; it means all tasks belonging to this thread should be migrated."));
@@ -776,11 +1351,13 @@ public class TaskManagerTest {
         taskManager.tasks().put(taskId01, migratedTask01);
         taskManager.tasks().put(taskId02, migratedTask02);
 
-        final RuntimeException thrown = assertThrows(RuntimeException.class,
-            () -> taskManager.handleAssignment(emptyMap(), emptyMap()));
+        final RuntimeException thrown = assertThrows(
+            RuntimeException.class,
+            () -> taskManager.handleAssignment(emptyMap(), emptyMap())
+        );
         // Fatal exception thrown first.
         assertThat(thrown.getMessage(), equalTo("Unexpected failure to close 2 task(s) [[0_1, 0_2]]. " +
-            "First unexpected exception (for task 0_2) follows."));
+                                                    "First unexpected exception (for task 0_2) follows."));
 
         assertThat(thrown.getCause().getMessage(), equalTo("t2 illegal state exception"));
     }
@@ -803,8 +1380,10 @@ public class TaskManagerTest {
         taskManager.tasks().put(taskId01, migratedTask01);
         taskManager.tasks().put(taskId02, migratedTask02);
 
-        final KafkaException thrown = assertThrows(KafkaException.class,
-            () -> taskManager.handleAssignment(emptyMap(), emptyMap()));
+        final KafkaException thrown = assertThrows(
+            KafkaException.class,
+            () -> taskManager.handleAssignment(emptyMap(), emptyMap())
+        );
 
         // Expecting the original Kafka exception instead of a wrapped one.
         assertThat(thrown.getMessage(), equalTo("Kaboom for t2!"));
@@ -812,6 +1391,23 @@ public class TaskManagerTest {
         assertThat(thrown.getCause().getMessage(), equalTo(null));
     }
 
+    @Test
+    public void shouldTransmitProducerMetrics() {
+        final MetricName testMetricName = new MetricName("test_metric", "", "", new HashMap<>());
+        final Metric testMetric = new KafkaMetric(
+            new Object(),
+            testMetricName,
+            (Measurable) (config, now) -> 0,
+            null,
+            new MockTime());
+        final Map<MetricName, Metric> dummyProducerMetrics = singletonMap(testMetricName, testMetric);
+
+        expect(activeTaskCreator.producerMetrics()).andReturn(dummyProducerMetrics);
+        replay(activeTaskCreator);
+
+        assertThat(taskManager.producerMetrics(), is(dummyProducerMetrics));
+    }
+
     private static void expectRestoreToBeCompleted(final Consumer<byte[], byte[]> consumer,
                                                    final ChangelogReader changeLogReader) {
         final Set<TopicPartition> assignment = singleton(new TopicPartition("assignment", 0));
@@ -837,7 +1433,14 @@ public class TaskManagerTest {
         StateMachineTask(final TaskId id,
                          final Set<TopicPartition> partitions,
                          final boolean active) {
-            super(id, null, null, null, partitions);
+            this(id, partitions, active, null);
+        }
+
+        StateMachineTask(final TaskId id,
+                         final Set<TopicPartition> partitions,
+                         final boolean active,
+                         final ProcessorStateManager processorStateManager) {
+            super(id, null, null, processorStateManager, partitions);
             this.active = active;
         }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
index 5e6a3e4..1c8bf98 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
@@ -201,7 +201,7 @@ public class KeyValueStoreTestDriver<K, V> {
             logContext,
             new TaskId(0, 0),
             consumer,
-            new StreamsProducer(logContext, producer),
+            new StreamsProducer(producer, false, logContext, null),
             new DefaultProductionExceptionHandler(),
             false,
             new MockStreamsMetrics(new Metrics())
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
index 0478168..0307b0c 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
@@ -374,13 +374,7 @@ public class StreamThreadStateStoreProviderTest {
             logContext,
             taskId,
             clientSupplier.consumer,
-            eosEnabled ?
-                new StreamsProducer(
-                    logContext,
-                    clientSupplier.getProducer(new HashMap<>()),
-                    streamsConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG),
-                    taskId) :
-                new StreamsProducer(logContext, clientSupplier.getProducer(new HashMap<>())),
+            new StreamsProducer(clientSupplier.getProducer(new HashMap<>()), eosEnabled, logContext, streamsConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG)),
             streamsConfig.defaultProductionExceptionHandler(),
             eosEnabled,
             new MockStreamsMetrics(metrics));
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
index b997688..41fef28 100644
--- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
@@ -443,11 +443,7 @@ public class TopologyTestDriver implements Closeable {
                 logContext,
                 TASK_ID,
                 consumer,
-                new StreamsProducer(
-                    logContext,
-                    producer,
-                    eosEnabled ? streamsConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG) : null,
-                    eosEnabled ? TASK_ID : null),
+                new StreamsProducer(producer, eosEnabled, logContext, streamsConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG)),
                 streamsConfig.defaultProductionExceptionHandler(),
                 eosEnabled,
                 streamsMetrics);