Posted to commits@kafka.apache.org by mj...@apache.org on 2020/03/19 18:32:20 UTC

[kafka] branch trunk updated: KAFKA-9441: Unify committing within TaskManager (#8218)

This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 89cd2f2  KAFKA-9441: Unify committing within TaskManager (#8218)
89cd2f2 is described below

commit 89cd2f2a0b21368297323437fd15ba6341e4707b
Author: Matthias J. Sax <ma...@confluent.io>
AuthorDate: Thu Mar 19 11:31:51 2020 -0700

    KAFKA-9441: Unify committing within TaskManager (#8218)
    
     - part of KIP-447
     - commit all tasks at once when using non-EOS (eos-beta in follow-up work)
     - unified commit logic into TaskManager
     - split existing methods of the Task interface into pre/post parts
    
    Reviewers: Boyang Chen <bo...@confluent.io>, Guozhang Wang <gu...@confluent.io>
---
 checkstyle/suppressions.xml                        |   2 +-
 .../processor/internals/ActiveTaskCreator.java     |  68 ++-
 .../processor/internals/RecordCollector.java       |   3 -
 .../processor/internals/RecordCollectorImpl.java   |  29 +-
 .../streams/processor/internals/StandbyTask.java   | 136 ++++--
 .../streams/processor/internals/StreamTask.java    | 235 +++++----
 .../streams/processor/internals/StreamThread.java  |   5 +-
 .../processor/internals/StreamsProducer.java       |  58 ++-
 .../kafka/streams/processor/internals/Task.java    |  37 +-
 .../streams/processor/internals/TaskManager.java   | 194 ++++++--
 .../integration/MetricsIntegrationTest.java        |  11 +-
 .../processor/internals/ActiveTaskCreatorTest.java | 125 +++++
 .../processor/internals/RecordCollectorTest.java   | 208 +-------
 .../processor/internals/StandbyTaskTest.java       |  34 +-
 .../processor/internals/StreamTaskTest.java        | 175 ++++---
 .../processor/internals/StreamThreadTest.java      |   8 +
 .../processor/internals/StreamsProducerTest.java   | 183 ++++---
 .../processor/internals/TaskManagerTest.java       | 527 ++++++++++++++++++---
 .../streams/state/KeyValueStoreTestDriver.java     |   4 +-
 .../StreamThreadStateStoreProviderTest.java        |   9 +-
 .../org/apache/kafka/test/MockRecordCollector.java |  13 -
 .../apache/kafka/streams/TopologyTestDriver.java   |  70 ++-
 .../processor/internals/TestDriverProducer.java    |  39 ++
 23 files changed, 1467 insertions(+), 706 deletions(-)
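
The heart of the change is visible in the Task and TaskManager hunks below: the
single-task commit() call is replaced by a prepare/commit/post sequence that
TaskManager drives for all tasks at once. As a rough sketch of that flow (not
code from this commit -- simplified, error handling omitted, and assuming the
TaskManager fields shown in the diff):

    void commitAll() {
        final List<Task> tasksToCommit = new ArrayList<>();
        final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> offsetsPerTask = new HashMap<>();
        for (final Task task : tasks.values()) {
            if (task.commitNeeded()) {
                task.prepareCommit();  // flush state stores and the record collector
                final Map<TopicPartition, OffsetAndMetadata> offsets = task.committableOffsetsAndMetadata();
                if (!offsets.isEmpty()) {
                    offsetsPerTask.put(task.id(), offsets);
                }
                tasksToCommit.add(task);
            }
        }
        if (!offsetsPerTask.isEmpty()) {
            commitOffsetsOrTransaction(offsetsPerTask);  // single commit point for all tasks
        }
        for (final Task task : tasksToCommit) {
            task.postCommit();  // write checkpoint, reset commitNeeded/commitRequested
        }
    }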

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index aa3a086..5cd6e52 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -181,7 +181,7 @@
               files="StreamsPartitionAssignor.java"/>
 
     <suppress checks="NPathComplexity"
-              files="(ProcessorStateManager|InternalTopologyBuilder|StreamsPartitionAssignor|StreamThread).java"/>
+              files="(ProcessorStateManager|InternalTopologyBuilder|StreamsPartitionAssignor|StreamThread|TaskManager).java"/>
 
     <suppress checks="(FinalLocalVariable|UnnecessaryParentheses|BooleanExpressionComplexity|CyclomaticComplexity|WhitespaceAfter|LocalVariableName)"
               files="Murmur3.java"/>
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
index 2f40556..d2fa577 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
@@ -47,20 +47,20 @@ import java.util.stream.Collectors;
 import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE;
 
 class ActiveTaskCreator {
-    private final String applicationId;
     private final InternalTopologyBuilder builder;
     private final StreamsConfig config;
     private final StreamsMetricsImpl streamsMetrics;
     private final StateDirectory stateDirectory;
     private final ChangelogReader storeChangelogReader;
-    private final Time time;
-    private final Logger log;
-    private final String threadId;
     private final ThreadCache cache;
-    private final Producer<byte[], byte[]> threadProducer;
+    private final Time time;
     private final KafkaClientSupplier clientSupplier;
-    private final Map<TaskId, Producer<byte[], byte[]>> taskProducers;
+    private final String threadId;
+    private final Logger log;
     private final Sensor createTaskSensor;
+    private final String applicationId;
+    private final Producer<byte[], byte[]> threadProducer;
+    private final Map<TaskId, StreamsProducer> taskProducers;
 
     private static String getThreadProducerClientId(final String threadClientId) {
         return threadClientId + "-producer";
@@ -80,32 +80,44 @@ class ActiveTaskCreator {
                       final KafkaClientSupplier clientSupplier,
                       final String threadId,
                       final Logger log) {
-        applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG);
         this.builder = builder;
         this.config = config;
         this.streamsMetrics = streamsMetrics;
         this.stateDirectory = stateDirectory;
         this.storeChangelogReader = storeChangelogReader;
+        this.cache = cache;
         this.time = time;
+        this.clientSupplier = clientSupplier;
+        this.threadId = threadId;
         this.log = log;
 
+        createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics);
+        applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG);
+
         if (EXACTLY_ONCE.equals(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG))) {
             threadProducer = null;
             taskProducers = new HashMap<>();
         } else {
+            log.info("Creating thread producer client");
+
             final String threadProducerClientId = getThreadProducerClientId(threadId);
             final Map<String, Object> producerConfigs = config.getProducerConfigs(threadProducerClientId);
-            log.info("Creating thread producer client");
+
             threadProducer = clientSupplier.getProducer(producerConfigs);
             taskProducers = Collections.emptyMap();
         }
+    }
 
+    StreamsProducer streamsProducerForTask(final TaskId taskId) {
+        if (threadProducer != null) {
+            throw new IllegalStateException("Producer per thread is used");
+        }
 
-        this.cache = cache;
-        this.threadId = threadId;
-        this.clientSupplier = clientSupplier;
-
-        createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics);
+        final StreamsProducer taskProducer = taskProducers.get(taskId);
+        if (taskProducer == null) {
+            throw new IllegalStateException("Unknown TaskId: " + taskId);
+        }
+        return taskProducer;
     }
 
     Collection<Task> createTasks(final Consumer<byte[], byte[]> consumer,
@@ -132,23 +144,27 @@ class ActiveTaskCreator {
                 partitions
             );
 
+            final StreamsProducer streamsProducer;
             if (threadProducer == null) {
                 final String taskProducerClientId = getTaskProducerClientId(threadId, taskId);
                 final Map<String, Object> producerConfigs = config.getProducerConfigs(taskProducerClientId);
                 producerConfigs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, applicationId + "-" + taskId);
                 log.info("Creating producer client for task {}", taskId);
-                taskProducers.put(taskId, clientSupplier.getProducer(producerConfigs));
+                streamsProducer = new StreamsProducer(
+                    clientSupplier.getProducer(producerConfigs),
+                    true,
+                    applicationId,
+                    logContext);
+                taskProducers.put(taskId, streamsProducer);
+            } else {
+                streamsProducer = new StreamsProducer(threadProducer, false, null, logContext);
             }
 
             final RecordCollector recordCollector = new RecordCollectorImpl(
                 logContext,
                 taskId,
-                consumer,
-                threadProducer != null ?
-                    new StreamsProducer(threadProducer, false, logContext, applicationId) :
-                    new StreamsProducer(taskProducers.get(taskId), true, logContext, applicationId),
+                streamsProducer,
                 config.defaultProductionExceptionHandler(),
-                EXACTLY_ONCE.equals(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG)),
                 streamsMetrics
             );
 
@@ -178,18 +194,18 @@ class ActiveTaskCreator {
             try {
                 threadProducer.close();
             } catch (final RuntimeException e) {
-                throw new StreamsException("Thread Producer encounter unexpected error trying to close", e);
+                throw new StreamsException("Thread Producer encounter error trying to close", e);
             }
         }
     }
 
     void closeAndRemoveTaskProducerIfNeeded(final TaskId id) {
-        final Producer<byte[], byte[]> producer = taskProducers.remove(id);
-        if (producer != null) {
+        final StreamsProducer taskProducer = taskProducers.remove(id);
+        if (taskProducer != null) {
             try {
-                producer.close();
+                taskProducer.kafkaProducer().close();
             } catch (final RuntimeException e) {
-                throw new StreamsException("[" + id + "] Producer encounter unexpected error trying to close", e);
+                throw new StreamsException("[" + id + "] Producer encounter error trying to close", e);
             }
         }
     }
@@ -205,8 +221,8 @@ class ActiveTaskCreator {
             // When EOS is turned on, each task will have its own producer client
             // and the producer object passed in here will be null. We would then iterate through
             // all the active tasks and add their metrics to the output metrics map.
-            for (final Map.Entry<TaskId, Producer<byte[], byte[]>> entry : taskProducers.entrySet()) {
-                final Map<MetricName, ? extends Metric> taskProducerMetrics = entry.getValue().metrics();
+            for (final Map.Entry<TaskId, StreamsProducer> entry : taskProducers.entrySet()) {
+                final Map<MetricName, ? extends Metric> taskProducerMetrics = entry.getValue().kafkaProducer().metrics();
                 result.putAll(taskProducerMetrics);
             }
         }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
index 9594679..0101054 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
-import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.header.Headers;
@@ -45,8 +44,6 @@ public interface RecordCollector {
                      final Serializer<V> valueSerializer,
                      final StreamPartitioner<? super K, ? super V> partitioner);
 
-    void commit(final Map<TopicPartition, OffsetAndMetadata> offsets);
-
     /**
      * Initialize the internal {@link Producer}; note this function should be made idempotent
      *
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
index 8e69e4d..124e6f6 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
@@ -16,9 +16,6 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
-import org.apache.kafka.clients.consumer.CommitFailedException;
-import org.apache.kafka.clients.consumer.Consumer;
-import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.common.KafkaException;
@@ -33,7 +30,6 @@ import org.apache.kafka.common.errors.ProducerFencedException;
 import org.apache.kafka.common.errors.RetriableException;
 import org.apache.kafka.common.errors.SecurityDisabledException;
 import org.apache.kafka.common.errors.SerializationException;
-import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.errors.UnknownServerException;
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.metrics.Sensor;
@@ -59,7 +55,6 @@ public class RecordCollectorImpl implements RecordCollector {
 
     private final Logger log;
     private final TaskId taskId;
-    private final Consumer<byte[], byte[]> mainConsumer;
     private final StreamsProducer streamsProducer;
     private final ProductionExceptionHandler productionExceptionHandler;
     private final Sensor droppedRecordsSensor;
@@ -73,17 +68,14 @@ public class RecordCollectorImpl implements RecordCollector {
      */
     public RecordCollectorImpl(final LogContext logContext,
                                final TaskId taskId,
-                               final Consumer<byte[], byte[]> mainConsumer,
                                final StreamsProducer streamsProducer,
                                final ProductionExceptionHandler productionExceptionHandler,
-                               final boolean eosEnabled,
                                final StreamsMetricsImpl streamsMetrics) {
         this.log = logContext.logger(getClass());
         this.taskId = taskId;
-        this.mainConsumer = mainConsumer;
         this.streamsProducer = streamsProducer;
         this.productionExceptionHandler = productionExceptionHandler;
-        this.eosEnabled = eosEnabled;
+        this.eosEnabled = streamsProducer.eosEnabled();
 
         final String threadId = Thread.currentThread().getName();
         this.droppedRecordsSensor = TaskMetrics.droppedRecordsSensorOrSkippedRecordsSensor(threadId, taskId.toString(), streamsMetrics);
@@ -242,25 +234,6 @@ public class RecordCollectorImpl implements RecordCollector {
         return securityException || communicationException;
     }
 
-    public void commit(final Map<TopicPartition, OffsetAndMetadata> offsets) {
-        if (eosEnabled) {
-            streamsProducer.commitTransaction(offsets);
-        } else {
-            try {
-                mainConsumer.commitSync(offsets);
-            } catch (final CommitFailedException error) {
-                throw new TaskMigratedException("Consumer committing offsets failed, " +
-                    "indicating the corresponding thread is no longer part of the group", error);
-            } catch (final TimeoutException error) {
-                // TODO KIP-447: we can consider treating it as non-fatal and retry on the thread level
-                throw new StreamsException("Timed out while committing offsets via consumer for task " + taskId, error);
-            } catch (final KafkaException error) {
-                throw new StreamsException("Error encountered committing offsets via consumer for task " + taskId, error);
-            }
-        }
-
-    }
-
     /**
      * @throws StreamsException fatal error that should cause the thread to die
      * @throws TaskMigratedException recoverable error that would cause the task to be removed
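
The commit() logic removed from RecordCollector(Impl) above does not disappear;
it moves into TaskManager so the offsets of all tasks can be committed in one
shot. The TaskManager hunk in this mail is truncated before that method, so the
following is only an illustration of where the decision plausibly lands, pieced
together from the eosEnabled flag and streamsProducerForTask(..) visible in the
diff -- not the committed code:

    private void commitOffsetsOrTransaction(final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> offsetsPerTask) {
        if (eosEnabled) {
            // eos-alpha: each task owns a transactional producer, so commit per task
            for (final Map.Entry<TaskId, Map<TopicPartition, OffsetAndMetadata>> entry : offsetsPerTask.entrySet()) {
                activeTaskCreator.streamsProducerForTask(entry.getKey()).commitTransaction(entry.getValue());
            }
        } else {
            // non-EOS: one consumer commitSync(..) covering every task's offsets
            // (mainConsumer: the stream thread's consumer, assumed to be a field here)
            final Map<TopicPartition, OffsetAndMetadata> allOffsets = new HashMap<>();
            offsetsPerTask.values().forEach(allOffsets::putAll);
            mainConsumer.commitSync(allOffsets);
        }
    }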
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
index ad050c9..4f4dda9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
@@ -27,12 +27,13 @@ import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
+import org.slf4j.Logger;
 
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
-import org.slf4j.Logger;
 
 /**
  * A StandbyTask
@@ -102,6 +103,11 @@ public class StandbyTask extends AbstractTask implements Task {
     }
 
     @Override
+    public void prepareSuspend() {
+        log.trace("No-op prepareSuspend with state {}", state());
+    }
+
+    @Override
     public void suspend() {
         log.trace("No-op suspend with state {}", state());
     }
@@ -115,48 +121,45 @@ public class StandbyTask extends AbstractTask implements Task {
      * 1. flush store
      * 2. write checkpoint file
      *
-     * @throws TaskMigratedException all the task has been migrated
      * @throws StreamsException fatal error, should close the thread
      */
     @Override
-    public void commit() {
-        switch (state()) {
-            case RUNNING:
-                stateMgr.flush();
-
-                // since there's no written offsets we can checkpoint with empty map,
-                // and the state current offset would be used to checkpoint
-                stateMgr.checkpoint(Collections.emptyMap());
-
-                offsetSnapshotSinceLastCommit = new HashMap<>(stateMgr.changelogOffsets());
-
-                log.info("Committed");
-                break;
-
-            case CLOSING:
-                // do nothing and also not throw
-                log.trace("Skip committing since task is closing");
-
-                break;
-
-            default:
-                throw new IllegalStateException("Illegal state " + state() + " while committing standby task " + id);
+    public void prepareCommit() {
+        if (state() == State.RUNNING) {
+            stateMgr.flush();
+            log.info("Task ready for committing");
+        } else {
+            throw new IllegalStateException("Illegal state " + state() + " while preparing standby task " + id + " for committing ");
+        }
+    }
 
+    @Override
+    public void postCommit() {
+        if (state() == State.RUNNING) {
+            // since there's no written offsets we can checkpoint with empty map,
+            // and the state current offset would be used to checkpoint
+            stateMgr.checkpoint(Collections.emptyMap());
+            offsetSnapshotSinceLastCommit = new HashMap<>(stateMgr.changelogOffsets());
+            log.info("Finalized commit");
+        } else {
+            throw new IllegalStateException("Illegal state " + state() + " while post committing standby task " + id);
         }
     }
 
     @Override
-    public void closeClean() {
-        close(true);
+    public Map<TopicPartition, Long> prepareCloseClean() {
+        prepareClose(true);
 
-        log.info("Closed clean");
+        log.info("Prepared clean close");
+
+        return Collections.emptyMap();
     }
 
     @Override
-    public void closeDirty() {
-        close(false);
+    public void prepareCloseDirty() {
+        prepareClose(false);
 
-        log.info("Closed dirty");
+        log.info("Prepared dirty close");
     }
 
     /**
@@ -166,29 +169,70 @@ public class StandbyTask extends AbstractTask implements Task {
      * @throws TaskMigratedException all the task has been migrated
      * @throws StreamsException fatal error, should close the thread
      */
-    private void close(final boolean clean) {
+    private void prepareClose(final boolean clean) {
         if (state() == State.CREATED) {
             // the task is created and not initialized, do nothing
-            transitionTo(State.CLOSING);
-        } else {
-            if (state() == State.RUNNING) {
-                if (clean) {
-                    commit();
-                }
+            return;
+        }
 
-                transitionTo(State.CLOSING);
+        if (state() == State.RUNNING) {
+            if (clean) {
+                stateMgr.flush();
             }
+        } else {
+            throw new IllegalStateException("Illegal state " + state() + " while closing standby task " + id);
+        }
+    }
+
+    @Override
+    public void closeClean(final Map<TopicPartition, Long> checkpoint) {
+        Objects.requireNonNull(checkpoint);
+        close(true, checkpoint);
+
+        log.info("Closed clean");
+    }
+
+    @Override
+    public void closeDirty() {
+        close(false, null);
+
+        log.info("Closed dirty");
+    }
 
-            if (state() == State.CLOSING) {
-                executeAndMaybeSwallow(clean, () -> {
-                    StateManagerUtil.closeStateManager(log, logPrefix, clean,
-                        false, stateMgr, stateDirectory, TaskType.STANDBY);
-                }, "state manager close", log);
+    private void close(final boolean clean,
+                       final Map<TopicPartition, Long> checkpoint) {
+        if (state() == State.CREATED) {
+            // the task is created and not initialized, do nothing
+            closeTaskSensor.record();
+            transitionTo(State.CLOSED);
+            return;
+        }
 
-                // TODO: if EOS is enabled, we should wipe out the state stores like we did for StreamTask too
-            } else {
-                throw new IllegalStateException("Illegal state " + state() + " while closing standby task " + id);
+        if (state() == State.RUNNING) {
+            if (clean) {
+                // since there's no written offsets we can checkpoint with empty map,
+                // and the state current offset would be used to checkpoint
+                stateMgr.checkpoint(Collections.emptyMap());
+                offsetSnapshotSinceLastCommit = new HashMap<>(stateMgr.changelogOffsets());
             }
+
+            executeAndMaybeSwallow(clean, () -> {
+                StateManagerUtil.closeStateManager(
+                    log,
+                    logPrefix,
+                    clean,
+                    false,
+                    stateMgr,
+                    stateDirectory,
+                    TaskType.STANDBY);
+                },
+                "state manager close",
+                log
+            );
+
+            // TODO: if EOS is enabled, we should wipe out the state stores like we did for StreamTask too
+        } else {
+            throw new IllegalStateException("Illegal state " + state() + " while closing standby task " + id);
         }
 
         closeTaskSensor.record();
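
For standbys the prepare/post split is simpler, since a standby consumes no
input offsets that need committing; per the hunks above the sequence reduces to
(sketch):

    standbyTask.prepareCommit();   // stateMgr.flush()
    // committableOffsetsAndMetadata() keeps the Task default: an empty map
    standbyTask.postCommit();      // checkpoint(emptyMap()) and snapshot changelog offsets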
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index 64adb6a..beb0fd9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -42,6 +42,7 @@ import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.V
 import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
 import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
 import org.apache.kafka.streams.state.internals.ThreadCache;
+import org.slf4j.Logger;
 
 import java.io.IOException;
 import java.io.PrintWriter;
@@ -55,7 +56,6 @@ import java.util.Map;
 import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
-import org.slf4j.Logger;
 
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.singleton;
@@ -92,7 +92,6 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     private final Sensor closeTaskSensor;
     private final Sensor processLatencySensor;
     private final Sensor punctuateLatencySensor;
-    private final Sensor commitSensor;
     private final Sensor enforcedProcessingSensor;
     private final InternalProcessorContext processorContext;
 
@@ -128,10 +127,8 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         final String taskId = id.toString();
         if (streamsMetrics.version() == Version.FROM_0100_TO_24) {
             final Sensor parent = ThreadMetrics.commitOverTasksSensor(threadId, streamsMetrics);
-            commitSensor = TaskMetrics.commitSensor(threadId, taskId, streamsMetrics, parent);
             enforcedProcessingSensor = TaskMetrics.enforcedProcessingSensor(threadId, taskId, streamsMetrics, parent);
         } else {
-            commitSensor = TaskMetrics.commitSensor(threadId, taskId, streamsMetrics);
             enforcedProcessingSensor = TaskMetrics.enforcedProcessingSensor(threadId, taskId, streamsMetrics);
         }
         processLatencySensor = TaskMetrics.processLatencySensor(threadId, taskId, streamsMetrics);
@@ -231,19 +228,33 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
      *                               or if the task producer got fenced (EOS)
      */
     @Override
-    public void suspend() {
-        if (state() == State.CREATED || state() == State.CLOSING || state() == State.SUSPENDED) {
+    public void prepareSuspend() {
+        if (state() == State.CREATED || state() == State.SUSPENDED) {
             // do nothing
-            log.trace("Skip suspending since state is {}", state());
+            log.trace("Skip prepare suspending since state is {}", state());
         } else if (state() == State.RUNNING) {
             closeTopology(true);
 
-            commitState();
-            // whenever we have successfully committed state during suspension, it is safe to checkpoint
-            // the state as well no matter if EOS is enabled or not
-            stateMgr.checkpoint(checkpointableOffsets());
+            stateMgr.flush();
+            recordCollector.flush();
 
-            // we should also clear any buffered records of a task when suspending it
+            log.info("Prepare suspending running");
+        } else if (state() == State.RESTORING) {
+            stateMgr.flush();
+
+            log.info("Prepare suspending restoring");
+        } else {
+            throw new IllegalStateException("Illegal state " + state() + " while suspending active task " + id);
+        }
+    }
+
+    @Override
+    public void suspend() {
+        if (state() == State.CREATED || state() == State.SUSPENDED) {
+            // do nothing
+            log.trace("Skip suspending since state is {}", state());
+        } else if (state() == State.RUNNING) {
+            stateMgr.checkpoint(checkpointableOffsets());
             partitionGroup.clear();
 
             transitionTo(State.SUSPENDED);
@@ -251,7 +262,6 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         } else if (state() == State.RESTORING) {
             // we just checkpoint the position that we've restored up to without
             // going through the commit process
-            stateMgr.flush();
             stateMgr.checkpoint(emptyMap());
 
             // we should also clear any buffered records of a task when suspending it
@@ -273,7 +283,6 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     public void resume() {
         switch (state()) {
             case CREATED:
-            case CLOSING:
             case RUNNING:
             case RESTORING:
                 // no need to do anything, just let them continue running / restoring / closing
@@ -293,18 +302,30 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         }
     }
 
-    /**
-     * @throws TaskMigratedException if committing offsets failed (non-EOS)
-     *                               or if the task producer got fenced (EOS)
-     */
     @Override
-    public void commit() {
+    public void prepareCommit() {
         switch (state()) {
             case RUNNING:
             case RESTORING:
-                commitState();
+                stateMgr.flush();
+                recordCollector.flush();
+
+                log.info("Prepared task for committing");
+
+                break;
+
+            default:
+                throw new IllegalStateException("Illegal state " + state() + " while preparing active task " + id + " for committing");
+        }
+    }
+
+    @Override
+    public void postCommit() {
+        switch (state()) {
+            case RUNNING:
+                commitNeeded = false;
+                commitRequested = false;
 
-                // this is an optimization for non-EOS only
                 if (eosDisabled) {
                     stateMgr.checkpoint(checkpointableOffsets());
                 }
@@ -313,35 +334,31 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
 
                 break;
 
-            case CLOSING:
-                // do nothing
+            case RESTORING:
+                commitNeeded = false;
+                commitRequested = false;
+
+                stateMgr.checkpoint(checkpointableOffsets());
+
+                log.info("Committed");
+
                 break;
 
             default:
-                throw new IllegalStateException("Illegal state " + state() + " while committing standby task " + id);
+                throw new IllegalStateException("Illegal state " + state() + " while post committing active task " + id);
         }
     }
 
-    /**
-     * <pre>
-     * the following order must be followed:
-     *  1. flush the state, send any left changelog records
-     *  2. then flush the record collector
-     *  3. then commit the record collector -- for EOS this is the synchronization barrier
-     * </pre>
-     *
-     * @throws TaskMigratedException if committing offsets failed (non-EOS)
-     *                               or if the task producer got fenced (EOS)
-     */
-    private void commitState() {
-        final long startNs = time.nanoseconds();
-
-        stateMgr.flush();
+    @Override
+    public Map<TopicPartition, OffsetAndMetadata> committableOffsetsAndMetadata() {
+        if (state() == State.CLOSED) {
+            throw new IllegalStateException("Task " + id + " is closed.");
+        }
 
-        recordCollector.flush();
+        if (state() != State.RUNNING) {
+            return Collections.emptyMap();
+        }
 
-        // we need to preserve the original partitions times before calling commit
-        // because all partition times are reset to -1 during close
         final Map<TopicPartition, Long> partitionTimes = extractPartitionTimes();
 
         final Map<TopicPartition, OffsetAndMetadata> consumedOffsetsAndMetadata = new HashMap<>(consumedOffsets.size());
@@ -365,11 +382,8 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
             final long partitionTime = partitionTimes.get(partition);
             consumedOffsetsAndMetadata.put(partition, new OffsetAndMetadata(offset, encodeTimestamp(partitionTime)));
         }
-        recordCollector.commit(consumedOffsetsAndMetadata);
 
-        commitNeeded = false;
-        commitRequested = false;
-        commitSensor.record(time.nanoseconds() - startNs);
+        return consumedOffsetsAndMetadata;
     }
 
     private Map<TopicPartition, Long> extractPartitionTimes() {
@@ -381,15 +395,31 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     }
 
     @Override
-    public void closeClean() {
-        close(true);
+    public Map<TopicPartition, Long> prepareCloseClean() {
+        final Map<TopicPartition, Long> checkpoint = prepareClose(true);
+
+        log.info("Prepared clean close");
+
+        return checkpoint;
+    }
+
+    @Override
+    public void closeClean(final Map<TopicPartition, Long> checkpoint) {
+        close(true, checkpoint);
 
         log.info("Closed clean");
     }
 
     @Override
+    public void prepareCloseDirty() {
+        prepareClose(false);
+
+        log.info("Prepared dirty close");
+    }
+
+    @Override
     public void closeDirty() {
-        close(false);
+        close(false, null);
 
         log.info("Closed dirty");
     }
@@ -400,65 +430,87 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
      *  1. first close topology to make sure all cached records in the topology are processed
      *  2. then flush the state, send any left changelog records
      *  3. then flush the record collector
-     *  4. then commit the record collector -- for EOS this is the synchronization barrier
-     *  5. then checkpoint the state manager -- even if we crash before this step, EOS is still guaranteed
-     *  6. then if we are closing on EOS and dirty, wipe out the state store directory
-     *  7. finally release the state manager lock
      * </pre>
      *
-     * @param clean    shut down cleanly (ie, incl. flush and commit) if {@code true} --
+     * @param clean    shut down cleanly (ie, incl. flush) if {@code true} --
      *                 otherwise, just close open resources
-     * @throws TaskMigratedException if committing offsets failed (non-EOS)
-     *                               or if the task producer got fenced (EOS)
+     * @throws TaskMigratedException if the task producer got fenced (EOS)
      */
-    private void close(final boolean clean) {
+    private Map<TopicPartition, Long> prepareClose(final boolean clean) {
+        final Map<TopicPartition, Long> checkpoint;
+
         if (state() == State.CREATED) {
             // the task is created and not initialized, just re-write the checkpoint file
-            executeAndMaybeSwallow(clean, () -> {
-                stateMgr.checkpoint(Collections.emptyMap());
-            }, "state manager checkpoint", log);
-
-            transitionTo(State.CLOSING);
+            checkpoint = Collections.emptyMap();
         } else if (state() == State.RUNNING) {
             closeTopology(clean);
 
             if (clean) {
-                commitState();
-                // whenever we have successfully committed state, it is safe to checkpoint
-                // the state as well no matter if EOS is enabled or not
-                stateMgr.checkpoint(checkpointableOffsets());
+                stateMgr.flush();
+                recordCollector.flush();
+                checkpoint = checkpointableOffsets();
             } else {
+                checkpoint = null; // `null` indicates to not write a checkpoint
                 executeAndMaybeSwallow(false, stateMgr::flush, "state manager flush", log);
             }
-
-            transitionTo(State.CLOSING);
         } else if (state() == State.RESTORING) {
-            executeAndMaybeSwallow(clean, () -> {
-                stateMgr.flush();
-                stateMgr.checkpoint(Collections.emptyMap());
-            }, "state manager flush and checkpoint", log);
-
-            transitionTo(State.CLOSING);
+            executeAndMaybeSwallow(clean, stateMgr::flush, "state manager flush", log);
+            checkpoint = Collections.emptyMap();
         } else if (state() == State.SUSPENDED) {
-            // do not need to commit / checkpoint, since when suspending we've already committed the state
-            transitionTo(State.CLOSING);
+            // if `SUSPENDED` do not need to checkpoint, since when suspending we've already committed the state
+            checkpoint = null; // `null` indicates to not write a checkpoint
+        } else {
+            throw new IllegalStateException("Illegal state " + state() + " while prepare closing active task " + id);
         }
 
-        if (state() == State.CLOSING) {
-            // if EOS is enabled, we wipe out the whole state store for unclean close
-            // since they are invalid to use anymore
-            final boolean wipeStateStore = !clean && !eosDisabled;
+        return checkpoint;
+    }
 
-            // first close state manager (which is idempotent) then close the record collector (which could throw),
-            // if the latter throws and we re-close dirty which would close the state manager again.
-            executeAndMaybeSwallow(clean, () -> {
-                StateManagerUtil.closeStateManager(log, logPrefix, clean,
-                        wipeStateStore, stateMgr, stateDirectory, TaskType.ACTIVE);
-            }, "state manager close", log);
+    /**
+     * <pre>
+     * the following order must be followed:
+     *  1. checkpoint the state manager -- even if we crash before this step, EOS is still guaranteed
+     *  2. then if we are closing on EOS and dirty, wipe out the state store directory
+     *  3. finally release the state manager lock
+     * </pre>
+     */
+    private void close(final boolean clean,
+                       final Map<TopicPartition, Long> checkpoint) {
+        if (clean && checkpoint != null) {
+            executeAndMaybeSwallow(clean, () -> stateMgr.checkpoint(checkpoint), "state manager checkpoint", log);
+        }
 
-            executeAndMaybeSwallow(clean, recordCollector::close, "record collector close", log);
-        } else {
-            throw new IllegalStateException("Illegal state " + state() + " while closing active task " + id);
+        switch (state()) {
+            case CREATED:
+            case RUNNING:
+            case RESTORING:
+            case SUSPENDED:
+                // if EOS is enabled, we wipe out the whole state store for unclean close
+                // since they are invalid to use anymore
+                final boolean wipeStateStore = !clean && !eosDisabled;
+
+                // first close state manager (which is idempotent) then close the record collector (which could throw),
+                // if the latter throws and we re-close dirty which would close the state manager again.
+                executeAndMaybeSwallow(
+                    clean,
+                    () -> StateManagerUtil.closeStateManager(
+                        log,
+                        logPrefix,
+                        clean,
+                        wipeStateStore,
+                        stateMgr,
+                        stateDirectory,
+                        TaskType.ACTIVE
+                    ),
+                    "state manager close",
+                    log);
+
+                executeAndMaybeSwallow(clean, recordCollector::close, "record collector close", log);
+
+                break;
+
+            default:
+                throw new IllegalStateException("Illegal state " + state() + " while closing active task " + id);
         }
 
         partitionGroup.close();
@@ -472,7 +524,7 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
      * source topic partitions, or if it is enforced to be processable
      */
     public boolean isProcessable(final long wallClockTime) {
-        if (state() == State.CLOSED || state() == State.CLOSING) {
+        if (state() == State.CLOSED) {
             // a task is only closing / closed when 1) task manager is closing, 2) a rebalance is undergoing;
             // in either case we can just log it and move on without notifying the thread since the consumer
             // would soon be updated to not return any records for this task anymore.
@@ -509,8 +561,7 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
      * @return true if this method processes a record, false if it does not process a record.
      * @throws TaskMigratedException if the task producer got fenced (EOS only)
      */
-    @SuppressWarnings("unchecked")
-    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
     public boolean process(final long wallClockTime) {
         if (!isProcessable(wallClockTime)) {
             return false;
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index adb6f04..3a09552 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -57,6 +57,8 @@ import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE;
+
 public class StreamThread extends Thread {
 
     private final Admin adminClient;
@@ -338,7 +340,8 @@ public class StreamThread extends Thread {
             standbyTaskCreator,
             builder,
             adminClient,
-            stateDirectory
+            stateDirectory,
+            EXACTLY_ONCE.equals(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG))
         );
 
         log.info("Creating consumer client");
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsProducer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsProducer.java
index 0aa4de6..2c4f592 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsProducer.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsProducer.java
@@ -59,24 +59,27 @@ public class StreamsProducer {
 
     public StreamsProducer(final Producer<byte[], byte[]> producer,
                            final boolean eosEnabled,
-                           final LogContext logContext,
-                           final String applicationId) {
-        log = logContext.logger(getClass());
-        logPrefix = logContext.logPrefix().trim();
-
+                           final String applicationId,
+                           final LogContext logContext) {
         this.producer = Objects.requireNonNull(producer, "producer cannot be null");
-        this.applicationId = applicationId;
         this.eosEnabled = eosEnabled;
+        this.applicationId = applicationId;
+        if (eosEnabled && applicationId == null) {
+            throw new IllegalArgumentException("applicationId cannot be null if EOS is enabled");
+        }
+
+        log = Objects.requireNonNull(logContext, "logContext cannot be null").logger(getClass());
+        logPrefix = logContext.logPrefix().trim();
     }
 
     private String formatException(final String message) {
-        return message + " [" + logPrefix + ", " + (eosEnabled ? "eos" : "alo") + "]";
+        return message + " [" + logPrefix + "]";
     }
 
     /**
      * @throws IllegalStateException if EOS is disabled
      */
-    public void initTransaction() {
+    void initTransaction() {
         if (!eosEnabled) {
             throw new IllegalStateException(formatException("EOS is disabled"));
         }
@@ -88,7 +91,7 @@ public class StreamsProducer {
                 transactionInitialized = true;
             } catch (final TimeoutException exception) {
                 log.warn(
-                    "Timeout exception caught when initializing transactions. " +
+                    "Timeout exception caught trying to initialize transactions. " +
                         "The broker is either slow or in bad state (like not having enough replicas) in " +
                         "responding to the request, or the connection to broker was interrupted sending " +
                         "the request or receiving the response. " +
@@ -100,34 +103,34 @@ public class StreamsProducer {
                 throw exception;
             } catch (final KafkaException exception) {
                 throw new StreamsException(
-                    formatException("Error encountered while initializing transactions"),
+                    formatException("Error encountered trying to initialize transactions"),
                     exception
                 );
             }
         }
     }
 
-    private void maybeBeginTransaction() throws ProducerFencedException {
+    void maybeBeginTransaction() throws ProducerFencedException {
         if (eosEnabled && !transactionInFlight) {
             try {
                 producer.beginTransaction();
                 transactionInFlight = true;
             } catch (final ProducerFencedException error) {
                 throw new TaskMigratedException(
-                    formatException("Producer get fenced trying to begin a new transaction"),
+                    formatException("Producer got fenced trying to begin a new transaction"),
                     error
                 );
             } catch (final KafkaException error) {
                 throw new StreamsException(
-                    formatException("Producer encounter unexpected error trying to begin a new transaction"),
+                    formatException("Error encountered trying to begin a new transaction"),
                     error
                 );
             }
         }
     }
 
-    public Future<RecordMetadata> send(final ProducerRecord<byte[], byte[]> record,
-                                       final Callback callback) {
+    Future<RecordMetadata> send(final ProducerRecord<byte[], byte[]> record,
+                                final Callback callback) {
         maybeBeginTransaction();
         try {
             return producer.send(record, callback);
@@ -137,12 +140,12 @@ public class StreamsProducer {
                 // in this case we should throw its wrapped inner cause so that it can be
                 // captured and re-wrapped as TaskMigrationException
                 throw new TaskMigratedException(
-                    formatException("Producer cannot send records anymore since it got fenced"),
+                    formatException("Producer got fenced trying to send a record"),
                     uncaughtException.getCause()
                 );
             } else {
                 throw new StreamsException(
-                    formatException(String.format("Error encountered sending record to topic %s", record.topic())),
+                    formatException(String.format("Error encountered trying to send record to topic %s", record.topic())),
                     uncaughtException
                 );
             }
@@ -158,7 +161,7 @@ public class StreamsProducer {
      * @throws IllegalStateException if EOS is disabled
      * @throws TaskMigratedException
      */
-    public void commitTransaction(final Map<TopicPartition, OffsetAndMetadata> offsets) throws ProducerFencedException {
+    void commitTransaction(final Map<TopicPartition, OffsetAndMetadata> offsets) throws ProducerFencedException {
         if (!eosEnabled) {
             throw new IllegalStateException(formatException("EOS is disabled"));
         }
@@ -169,15 +172,15 @@ public class StreamsProducer {
             transactionInFlight = false;
         } catch (final ProducerFencedException error) {
             throw new TaskMigratedException(
-                formatException("Producer get fenced trying to commit a transaction"),
+                formatException("Producer got fenced trying to commit a transaction"),
                 error
             );
         } catch (final TimeoutException error) {
             // TODO KIP-447: we can consider treating it as non-fatal and retry on the thread level
-            throw new StreamsException(formatException("Timed out while committing a transaction"), error);
+            throw new StreamsException(formatException("Timed out trying to commit a transaction"), error);
         } catch (final KafkaException error) {
             throw new StreamsException(
-                formatException("Producer encounter unexpected error trying to commit a transaction"),
+                formatException("Error encountered trying to commit a transaction"),
                 error
             );
         }
@@ -186,7 +189,7 @@ public class StreamsProducer {
     /**
      * @throws IllegalStateException if EOS is disabled
      */
-    public void abortTransaction() throws ProducerFencedException {
+    void abortTransaction() throws ProducerFencedException {
         if (!eosEnabled) {
             throw new IllegalStateException(formatException("EOS is disabled"));
         }
@@ -205,7 +208,7 @@ public class StreamsProducer {
                 log.debug("Encountered {} while aborting the transaction; this is expected and hence swallowed", error.getMessage());
             } catch (final KafkaException error) {
                 throw new StreamsException(
-                    formatException("Producer encounter unexpected error trying to abort a transaction"),
+                    formatException("Error encounter trying to abort a transaction"),
                     error
                 );
             }
@@ -213,15 +216,18 @@ public class StreamsProducer {
         }
     }
 
-    public List<PartitionInfo> partitionsFor(final String topic) throws TimeoutException {
+    List<PartitionInfo> partitionsFor(final String topic) throws TimeoutException {
         return producer.partitionsFor(topic);
     }
 
-    public void flush() {
+    void flush() {
         producer.flush();
     }
 
-    // for testing only
+    boolean eosEnabled() {
+        return eosEnabled;
+    }
+
     Producer<byte[], byte[]> kafkaProducer() {
         return producer;
     }
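
Note that the StreamsProducer constructor's last two parameters swapped order
(applicationId now precedes logContext), and a null applicationId is rejected
when EOS is enabled. The two construction paths, as seen in the
ActiveTaskCreator hunk earlier in this diff:

    // EOS (eos-alpha): a transactional producer per task; applicationId must be non-null
    streamsProducer = new StreamsProducer(
        clientSupplier.getProducer(producerConfigs),
        true,            // eosEnabled
        applicationId,
        logContext);

    // non-EOS: the shared per-thread producer; applicationId may be null
    streamsProducer = new StreamsProducer(threadProducer, false, null, logContext);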
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
index 2bdce69..34fc600 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.StreamsException;
@@ -73,8 +74,7 @@ public interface Task {
         RESTORING(2, 3, 4),    // 1
         RUNNING(3, 4),         // 2
         SUSPENDED(1, 4),       // 3
-        CLOSING(4, 5),         // 4, we allow CLOSING to transit to itself to make close idempotent
-        CLOSED(0);             // 5, we allow CLOSED to transit to CREATED to handle corrupted tasks
+        CLOSED(0);             // 4, we allow CLOSED to transit to CREATED to handle corrupted tasks
 
         private final Set<Integer> validTransitions = new HashSet<>();
 
@@ -125,34 +125,48 @@ public interface Task {
     boolean commitNeeded();
 
     /**
-     * @throws TaskMigratedException all the task has been migrated
      * @throws StreamsException fatal error, should close the thread
      */
-    void commit();
+    void prepareCommit();
+
+    void postCommit();
 
     /**
      * @throws TaskMigratedException all the task has been migrated
      * @throws StreamsException fatal error, should close the thread
      */
-    void suspend();
+    void prepareSuspend();
 
+    void suspend();
     /**
+     *
      * @throws StreamsException fatal error, should close the thread
      */
     void resume();
 
     /**
-     * Close a task that we still own. Commit all progress and close the task gracefully.
+     * Prepare to close a task that we still own and prepare it for committing
      * Throws an exception if this couldn't be done.
+     * Must be idempotent.
      *
-     * @throws TaskMigratedException all the task has been migrated
      * @throws StreamsException fatal error, should close the thread
      */
-    void closeClean();
+    Map<TopicPartition, Long> prepareCloseClean();
+
+    /**
+     * Must be idempotent.
+     */
+    void closeClean(final Map<TopicPartition, Long> checkpoint);
 
     /**
-     * Close a task that we may not own. Discard any uncommitted progress and close the task.
+     * Prepare to close a task that we may not own. Discard any uncommitted progress and close the task.
      * Never throws an exception, but just makes all attempts to release resources while closing.
+     * Must be idempotent.
+     */
+    void prepareCloseDirty();
+
+    /**
+     * Must be idempotent.
      */
     void closeDirty();
 
@@ -182,6 +196,10 @@ public interface Task {
         return Collections.emptyMap();
     }
 
+    default Map<TopicPartition, OffsetAndMetadata> committableOffsetsAndMetadata() {
+        return Collections.emptyMap();
+    }
+
     default boolean process(final long wallClockTime) {
         return false;
     }
@@ -197,4 +215,5 @@ public interface Task {
     default boolean maybePunctuateSystemTime() {
         return false;
     }
+
 }
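
The close path follows the same two-phase shape as commit. Based on the
handleAssignment() changes in the TaskManager hunk that follows, a clean close
is driven roughly as (sketch, simplified):

    final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();  // flush, compute checkpoint
    final Map<TopicPartition, OffsetAndMetadata> offsets = task.committableOffsetsAndMetadata();
    // ... offsets of all closing tasks are committed together via commitOffsetsOrTransaction(..) ...
    task.closeClean(checkpoint);  // write checkpoint, release state manager and record collector
    // on failure at any step: task.prepareCloseDirty(); task.closeDirty();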
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index b804a4a..071c7cb 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -21,7 +21,9 @@ import java.util.Collections;
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.admin.DeleteRecordsResult;
 import org.apache.kafka.clients.admin.RecordsToDelete;
+import org.apache.kafka.clients.consumer.CommitFailedException;
 import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
@@ -71,6 +73,7 @@ public class TaskManager {
     private final InternalTopologyBuilder builder;
     private final Admin adminClient;
     private final StateDirectory stateDirectory;
+    private final boolean eosEnabled;
 
     private final Map<TaskId, Task> tasks = new TreeMap<>();
     // materializing this relationship because the lookup is on the hot path
@@ -93,7 +96,8 @@ public class TaskManager {
                 final StandbyTaskCreator standbyTaskCreator,
                 final InternalTopologyBuilder builder,
                 final Admin adminClient,
-                final StateDirectory stateDirectory) {
+                final StateDirectory stateDirectory,
+                final boolean eosEnabled) {
         this.changelogReader = changelogReader;
         this.processId = processId;
         this.logPrefix = logPrefix;
@@ -103,6 +107,7 @@ public class TaskManager {
         this.builder = builder;
         this.adminClient = adminClient;
         this.stateDirectory = stateDirectory;
+        this.eosEnabled = eosEnabled;
 
         final LogContext logContext = new LogContext(logPrefix);
         log = logContext.logger(getClass());
@@ -154,6 +159,7 @@ public class TaskManager {
             final Collection<TopicPartition> corruptedPartitions = entry.getValue();
             task.markChangelogAsCorrupted(corruptedPartitions);
 
+            task.prepareCloseDirty();
             task.closeDirty();
             task.revive();
         }
@@ -179,6 +185,11 @@ public class TaskManager {
 
         // first rectify all existing tasks
         final LinkedHashMap<TaskId, RuntimeException> taskCloseExceptions = new LinkedHashMap<>();
+
+        final Map<Task, Map<TopicPartition, Long>> checkpointPerTask = new HashMap<>();
+        final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
+        final Set<Task> dirtyTasks = new HashSet<>();
+
         final Iterator<Task> iterator = tasks.values().iterator();
         while (iterator.hasNext()) {
             final Task task = iterator.next();
@@ -192,30 +203,52 @@ public class TaskManager {
                 cleanupTask(task);
 
                 try {
-                    task.closeClean();
+                    final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+                    final Map<TopicPartition, OffsetAndMetadata> committableOffsets = task.committableOffsetsAndMetadata();
+
+                    checkpointPerTask.put(task, checkpoint);
+                    if (!committableOffsets.isEmpty()) {
+                        consumedOffsetsAndMetadataPerTask.put(task.id(), committableOffsets);
+                    }
                 } catch (final RuntimeException e) {
                     final String uncleanMessage = String.format("Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:", task.id());
                     log.error(uncleanMessage, e);
                     taskCloseExceptions.put(task.id(), e);
                     // We've already recorded the exception (which is the point of clean).
                     // Now, we should go ahead and complete the close because a half-closed task is no good to anyone.
-                    task.closeDirty();
-                } finally {
-                    if (task.isActive()) {
-                        try {
-                            activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
-                        } catch (final RuntimeException e) {
-                            final String uncleanMessage = String.format("Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:", task.id());
-                            log.error(uncleanMessage, e);
-                            taskCloseExceptions.putIfAbsent(task.id(), e);
-                        }
-                    }
+                    task.prepareCloseDirty();
+                    dirtyTasks.add(task);
                 }
 
                 iterator.remove();
             }
         }
 
+        if (!consumedOffsetsAndMetadataPerTask.isEmpty()) {
+            commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
+        }
+
+        for (final Map.Entry<Task, Map<TopicPartition, Long>> taskAndCheckpoint : checkpointPerTask.entrySet()) {
+            final Task task = taskAndCheckpoint.getKey();
+            try {
+                task.closeClean(checkpointPerTask.get(task));
+            } catch (final RuntimeException e) {
+                final String uncleanMessage = String.format("Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:", task.id());
+                log.error(uncleanMessage, e);
+                taskCloseExceptions.put(task.id(), e);
+                // We've already recorded the exception (which is the point of clean).
+                // Now, we should go ahead and complete the close because a half-closed task is no good to anyone.
+                task.closeDirty();
+            } finally {
+                cleanUpTaskProducer(task, taskCloseExceptions);
+            }
+        }
+
+        for (final Task task : dirtyTasks) {
+            task.closeDirty();
+            cleanUpTaskProducer(task, taskCloseExceptions);
+        }
+
         if (!taskCloseExceptions.isEmpty()) {
             for (final Map.Entry<TaskId, RuntimeException> entry : taskCloseExceptions.entrySet()) {
                 if (!(entry.getValue() instanceof TaskMigratedException)) {
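
Closing a task cleanly is now a two-phase protocol, which is what lets all tasks be committed together in between. A condensed sketch of the pattern above (the Task methods are the ones introduced by this commit; tasksToClose is a placeholder and error handling is elided):

    // Phase 1: prepare every task, capturing its checkpoint and committable offsets.
    final Map<Task, Map<TopicPartition, Long>> checkpointPerTask = new HashMap<>();
    final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> offsetsPerTask = new HashMap<>();
    for (final Task task : tasksToClose) {
        checkpointPerTask.put(task, task.prepareCloseClean());
        final Map<TopicPartition, OffsetAndMetadata> offsets = task.committableOffsetsAndMetadata();
        if (!offsets.isEmpty()) {
            offsetsPerTask.put(task.id(), offsets);
        }
    }

    // Phase 2: commit all offsets at once, then finish each close with the
    // checkpoint captured in phase 1.
    commitOffsetsOrTransaction(offsetsPerTask);
    for (final Map.Entry<Task, Map<TopicPartition, Long>> entry : checkpointPerTask.entrySet()) {
        entry.getKey().closeClean(entry.getValue());
    }
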
@@ -257,6 +290,19 @@ public class TaskManager {
         changelogReader.transitToRestoreActive();
     }
 
+    private void cleanUpTaskProducer(final Task task,
+                                     final Map<TaskId, RuntimeException> taskCloseExceptions) {
+        if (task.isActive()) {
+            try {
+                activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
+            } catch (final RuntimeException e) {
+                final String uncleanMessage = String.format("Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:", task.id());
+                log.error(uncleanMessage, e);
+                taskCloseExceptions.putIfAbsent(task.id(), e);
+            }
+        }
+    }
+
     private void addNewTask(final Task task) {
         final Task previous = tasks.put(task.id(), task);
         if (previous != null) {
@@ -330,13 +376,28 @@ public class TaskManager {
     void handleRevocation(final Collection<TopicPartition> revokedPartitions) {
         final Set<TopicPartition> remainingPartitions = new HashSet<>(revokedPartitions);
 
+        final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
         for (final Task task : tasks.values()) {
             if (remainingPartitions.containsAll(task.inputPartitions())) {
-                task.suspend();
+                task.prepareSuspend();
+                final Map<TopicPartition, OffsetAndMetadata> committableOffsets = task.committableOffsetsAndMetadata();
+                if (!committableOffsets.isEmpty()) {
+                    consumedOffsetsAndMetadataPerTask.put(task.id(), committableOffsets);
+                }
             }
             remainingPartitions.removeAll(task.inputPartitions());
         }
 
+        if (!consumedOffsetsAndMetadataPerTask.isEmpty()) {
+            commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
+        }
+
+        for (final Task task : tasks.values()) {
+            if (consumedOffsetsAndMetadataPerTask.containsKey(task.id())) {
+                task.suspend();
+            }
+        }
+
         if (!remainingPartitions.isEmpty()) {
             log.warn("The following partitions {} are missing from the task partitions. It could potentially " +
                          "due to race condition of consumer detecting the heartbeat failure, or the tasks " +
@@ -362,6 +423,7 @@ public class TaskManager {
             // standby tasks while we rejoin.
             if (task.isActive()) {
                 cleanupTask(task);
+                task.prepareCloseDirty();
                 task.closeDirty();
                 iterator.remove();
                 try {
@@ -497,24 +559,52 @@ public class TaskManager {
 
     void shutdown(final boolean clean) {
         final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
-        final Iterator<Task> iterator = tasks.values().iterator();
-        while (iterator.hasNext()) {
-            final Task task = iterator.next();
+
+        final Map<Task, Map<TopicPartition, Long>> checkpointPerTask = new HashMap<>();
+        final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
+
+        for (final Task task : tasks.values()) {
             cleanupTask(task);
 
             if (clean) {
                 try {
-                    task.closeClean();
+                    final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+                    final Map<TopicPartition, OffsetAndMetadata> committableOffsets = task.committableOffsetsAndMetadata();
+
+                    checkpointPerTask.put(task, checkpoint);
+                    if (!committableOffsets.isEmpty()) {
+                        consumedOffsetsAndMetadataPerTask.put(task.id(), committableOffsets);
+                    }
                 } catch (final TaskMigratedException e) {
                     // just ignore the exception as it doesn't matter during shutdown
+                    task.prepareCloseDirty();
                     task.closeDirty();
                 } catch (final RuntimeException e) {
                     firstException.compareAndSet(null, e);
+                    task.prepareCloseDirty();
                     task.closeDirty();
                 }
             } else {
+                task.prepareCloseDirty();
                 task.closeDirty();
             }
+        }
+
+        if (clean && !consumedOffsetsAndMetadataPerTask.isEmpty()) {
+            commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
+        }
+
+        for (final Map.Entry<Task, Map<TopicPartition, Long>> taskAndCheckpoint : checkpointPerTask.entrySet()) {
+            final Task task = taskAndCheckpoint.getKey();
+            try {
+                task.closeClean(checkpointPerTask.get(task));
+            } catch (final RuntimeException e) {
+                firstException.compareAndSet(null, e);
+                task.closeDirty();
+            }
+        }
+
+        for (final Task task : tasks.values()) {
             if (task.isActive()) {
                 try {
                     activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(task.id());
@@ -526,9 +616,10 @@ public class TaskManager {
                     }
                 }
             }
-            iterator.remove();
         }
 
+        tasks.clear();
+
         try {
             activeTaskCreator.closeThreadProducerIfNeeded();
         } catch (final RuntimeException e) {
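
shutdown(clean) keeps closing the remaining tasks when one of them fails, remembering only the first exception and re-throwing it at the end. A self-contained illustration of that pattern (plain JDK, independent of the Streams types):

    import java.util.List;
    import java.util.concurrent.atomic.AtomicReference;

    final class FirstErrorWins {
        static void closeAll(final List<AutoCloseable> resources) throws Exception {
            final AtomicReference<Exception> firstException = new AtomicReference<>(null);
            for (final AutoCloseable resource : resources) {
                try {
                    resource.close();
                } catch (final Exception e) {
                    firstException.compareAndSet(null, e);   // keep the first error, drop the rest
                }
            }
            final Exception fatal = firstException.get();
            if (fatal != null) {
                throw fatal;   // re-throw only after every resource had its chance to close
            }
        }
    }
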
@@ -604,14 +695,28 @@ public class TaskManager {
         if (rebalanceInProgress) {
             return -1;
         } else {
-            int commits = 0;
+            int committed = 0;
+            final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
+            for (final Task task : tasks.values()) {
+                if (task.commitNeeded()) {
+                    task.prepareCommit();
+                    final Map<TopicPartition, OffsetAndMetadata> offsetAndMetadata = task.committableOffsetsAndMetadata();
+                    if (!offsetAndMetadata.isEmpty()) {
+                        consumedOffsetsAndMetadataPerTask.put(task.id(), offsetAndMetadata);
+                    }
+                }
+            }
+
+            commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
+
             for (final Task task : tasks.values()) {
                 if (task.commitNeeded()) {
-                    task.commit();
-                    commits++;
+                    ++committed;
+                    task.postCommit();
                 }
             }
-            return commits;
+
+            return committed;
         }
     }
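
The former per-task commit() is thereby split into three phases driven centrally by the TaskManager. One full cycle, condensed (tasks is the manager's task registry):

    // 1) prepare: each task flushes and computes the offsets it needs committed
    // 2) commit:  all offsets go out in a single consumer or producer call
    // 3) post:    each task finalizes, e.g. writes its checkpoint file
    final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> offsetsPerTask = new HashMap<>();
    for (final Task task : tasks.values()) {
        if (task.commitNeeded()) {
            task.prepareCommit();
            final Map<TopicPartition, OffsetAndMetadata> offsets = task.committableOffsetsAndMetadata();
            if (!offsets.isEmpty()) {
                offsetsPerTask.put(task.id(), offsets);
            }
        }
    }
    commitOffsetsOrTransaction(offsetsPerTask);
    for (final Task task : tasks.values()) {
        if (task.commitNeeded()) {
            task.postCommit();
        }
    }
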
 
@@ -623,14 +728,49 @@ public class TaskManager {
         if (rebalanceInProgress) {
             return -1;
         } else {
-            int commits = 0;
+            final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
             for (final Task task : activeTaskIterable()) {
                 if (task.commitRequested() && task.commitNeeded()) {
-                    task.commit();
-                    commits++;
+                    task.prepareCommit();
+                    final Map<TopicPartition, OffsetAndMetadata> offsetAndMetadata = task.committableOffsetsAndMetadata();
+                    if (!offsetAndMetadata.isEmpty()) {
+                        consumedOffsetsAndMetadataPerTask.put(task.id(), offsetAndMetadata);
+                    }
                 }
             }
-            return commits;
+
+            commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
+
+            for (final Task task : tasks.values()) {
+                if (consumedOffsetsAndMetadataPerTask.containsKey(task.id())) {
+                    task.postCommit();
+                }
+            }
+
+            return consumedOffsetsAndMetadataPerTask.size();
+        }
+    }
+
+    private void commitOffsetsOrTransaction(final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> offsetsPerTask) {
+        if (eosEnabled) {
+            for (final Map.Entry<TaskId, Map<TopicPartition, OffsetAndMetadata>> taskToCommit : offsetsPerTask.entrySet()) {
+                activeTaskCreator.streamsProducerForTask(taskToCommit.getKey()).commitTransaction(taskToCommit.getValue());
+            }
+        } else {
+            try {
+                final Map<TopicPartition, OffsetAndMetadata> allOffsets = offsetsPerTask.values().stream()
+                    .flatMap(e -> e.entrySet().stream()).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+                mainConsumer.commitSync(allOffsets);
+            } catch (final CommitFailedException error) {
+                throw new TaskMigratedException("Consumer committing offsets failed, " +
+                    "indicating the corresponding thread is no longer part of the group", error);
+            } catch (final TimeoutException error) {
+                // TODO KIP-447: we could consider treating this as non-fatal and retrying at the thread level
+                throw new StreamsException("Timed out while committing offsets via consumer", error);
+            } catch (final KafkaException error) {
+                throw new StreamsException("Error encountered committing offsets via consumer", error);
+            }
         }
     }
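
Two details of commitOffsetsOrTransaction() are worth calling out. Under EOS every task still owns a dedicated producer, so the offsets ride on one transaction per task; without EOS they collapse into a single consumer commit, and the Collectors.toMap merge cannot hit duplicate keys because each input partition is owned by exactly one task. A hedged sketch of the branch (producerForTask and merge stand in for the calls above):

    if (eosEnabled) {
        // one transaction per task: each task-level producer commits its own offsets
        for (final Map.Entry<TaskId, Map<TopicPartition, OffsetAndMetadata>> entry : offsetsPerTask.entrySet()) {
            producerForTask(entry.getKey()).commitTransaction(entry.getValue());
        }
    } else {
        // one commitSync covering all tasks; a CommitFailedException is translated
        // into TaskMigratedException because the group has already evicted this member
        mainConsumer.commitSync(merge(offsetsPerTask));
    }
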
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java
index 63ce9c9..2937498 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java
@@ -61,7 +61,6 @@ import java.util.stream.Collectors;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 
-@SuppressWarnings("unchecked")
 @Category({IntegrationTest.class})
 public class MetricsIntegrationTest {
 
@@ -261,7 +260,7 @@ public class MetricsIntegrationTest {
 
         verifyStateMetric(State.CREATED);
         verifyTopologyDescriptionMetric(topology.describe().toString());
-        verifyApplicationIdMetric(APPLICATION_ID_VALUE);
+        verifyApplicationIdMetric();
 
         kafkaStreams.start();
         TestUtils.waitForCondition(
@@ -501,13 +500,13 @@ public class MetricsIntegrationTest {
         assertThat(metricsList.get(0).metricValue(), is(topologyDescription));
     }
 
-    private void verifyApplicationIdMetric(final String applicationId) {
+    private void verifyApplicationIdMetric() {
         final List<Metric> metricsList = new ArrayList<Metric>(kafkaStreams.metrics().values()).stream()
             .filter(m -> m.metricName().name().equals(APPLICATION_ID) &&
                 m.metricName().group().equals(STREAM_CLIENT_NODE_METRICS))
             .collect(Collectors.toList());
         assertThat(metricsList.size(), is(1));
-        assertThat(metricsList.get(0).metricValue(), is(applicationId));
+        assertThat(metricsList.get(0).metricValue(), is(APPLICATION_ID_VALUE));
     }
 
     private void checkClientLevelMetrics() {
@@ -565,10 +564,6 @@ public class MetricsIntegrationTest {
             .collect(Collectors.toList());
         final int numberOfAddedMetrics = StreamsConfig.METRICS_0100_TO_24.equals(builtInMetricsVersion) ? 0 : 4;
         final int numberOfMetricsWithRemovedParent = StreamsConfig.METRICS_0100_TO_24.equals(builtInMetricsVersion) ? 5 : 4;
-        checkMetricByName(listMetricTask, COMMIT_LATENCY_AVG, numberOfMetricsWithRemovedParent);
-        checkMetricByName(listMetricTask, COMMIT_LATENCY_MAX, numberOfMetricsWithRemovedParent);
-        checkMetricByName(listMetricTask, COMMIT_RATE, numberOfMetricsWithRemovedParent);
-        checkMetricByName(listMetricTask, COMMIT_TOTAL, numberOfMetricsWithRemovedParent);
         checkMetricByName(listMetricTask, ENFORCED_PROCESSING_RATE, 4);
         checkMetricByName(listMetricTask, ENFORCED_PROCESSING_TOTAL, 4);
         checkMetricByName(listMetricTask, RECORD_LATENESS_AVG, 4);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java
index 0f9ecdf..b1c4b2c 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.streams.processor.internals;
 
 import java.io.File;
+
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.MockConsumer;
@@ -43,14 +44,22 @@ import org.junit.runner.RunWith;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.easymock.EasyMock.anyString;
 import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.mock;
 import static org.easymock.EasyMock.replay;
+import static org.easymock.EasyMock.same;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.IsNot.not;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
 
 @RunWith(EasyMockRunner.class)
 public class ActiveTaskCreatorTest {
@@ -74,6 +83,122 @@ public class ActiveTaskCreatorTest {
     private ActiveTaskCreator activeTaskCreator;
 
     @Test
+    public void shouldFailForNonEosOnStreamsProducerPerTask() {
+        expect(config.getString(StreamsConfig.APPLICATION_ID_CONFIG)).andReturn("appId");
+        expect(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG)).andReturn(StreamsConfig.AT_LEAST_ONCE);
+        expect(config.getProducerConfigs(anyString())).andReturn(Collections.emptyMap());
+        replay(config);
+
+        activeTaskCreator = new ActiveTaskCreator(
+            builder,
+            config,
+            streamsMetrics,
+            stateDirectory,
+            changeLogReader,
+            new ThreadCache(new LogContext(), 0L, streamsMetrics),
+            new MockTime(),
+            mockClientSupplier,
+            "threadId",
+            new LogContext().logger(ActiveTaskCreator.class)
+        );
+
+        final IllegalStateException thrown = assertThrows(
+            IllegalStateException.class,
+            () -> activeTaskCreator.streamsProducerForTask(null)
+        );
+
+        assertThat(thrown.getMessage(), is("Producer per thread is used"));
+    }
+
+    @Test
+    public void shouldFailForUnknownTaskOnStreamsProducerPerTask() {
+        expect(config.getString(StreamsConfig.APPLICATION_ID_CONFIG)).andReturn("appId");
+        expect(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG)).andReturn(StreamsConfig.EXACTLY_ONCE);
+        expect(config.getProducerConfigs(anyString())).andReturn(Collections.emptyMap());
+        replay(config);
+
+        activeTaskCreator = new ActiveTaskCreator(
+            builder,
+            config,
+            streamsMetrics,
+            stateDirectory,
+            changeLogReader,
+            new ThreadCache(new LogContext(), 0L, streamsMetrics),
+            new MockTime(),
+            mockClientSupplier,
+            "threadId",
+            new LogContext().logger(ActiveTaskCreator.class)
+        );
+
+        {
+            final IllegalStateException thrown = assertThrows(
+                IllegalStateException.class,
+                () -> activeTaskCreator.streamsProducerForTask(null)
+            );
+
+            assertThat(thrown.getMessage(), is("Unknown TaskId: null"));
+        }
+        {
+            final IllegalStateException thrown = assertThrows(
+                IllegalStateException.class,
+                () -> activeTaskCreator.streamsProducerForTask(new TaskId(0, 0))
+            );
+
+            assertThat(thrown.getMessage(), is("Unknown TaskId: 0_0"));
+        }
+    }
+
+    @Test
+    public void shouldReturnStreamsProducerPerTask() {
+        final TaskId task00 = new TaskId(0, 0);
+        final TaskId task01 = new TaskId(0, 1);
+        final ProcessorTopology topology = mock(ProcessorTopology.class);
+
+        expect(config.getString(StreamsConfig.APPLICATION_ID_CONFIG)).andReturn("appId");
+        expect(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG)).andReturn(StreamsConfig.EXACTLY_ONCE);
+        expect(config.getProducerConfigs(anyString())).andReturn(new HashMap<>()).anyTimes();
+        expect(config.getLong(anyString())).andReturn(0L).anyTimes();
+        expect(config.getInt(anyString())).andReturn(0).anyTimes();
+        expect(builder.buildSubtopology(task00.topicGroupId)).andReturn(topology).anyTimes();
+        expect(stateDirectory.directoryForTask(task00)).andReturn(new File(task00.toString()));
+        expect(stateDirectory.directoryForTask(task01)).andReturn(new File(task01.toString()));
+        expect(topology.storeToChangelogTopic()).andReturn(Collections.emptyMap()).anyTimes();
+        expect(topology.source("topic")).andReturn(mock(SourceNode.class)).andReturn(mock(SourceNode.class));
+        expect(topology.globalStateStores()).andReturn(Collections.emptyList()).anyTimes();
+        replay(config, builder, stateDirectory, topology);
+
+        mockClientSupplier.setApplicationIdForProducer("appId");
+        activeTaskCreator = new ActiveTaskCreator(
+            builder,
+            config,
+            streamsMetrics,
+            stateDirectory,
+            changeLogReader,
+            new ThreadCache(new LogContext(), 0L, streamsMetrics),
+            new MockTime(),
+            mockClientSupplier,
+            "threadId",
+            new LogContext().logger(ActiveTaskCreator.class)
+        );
+
+        assertThat(
+            activeTaskCreator.createTasks(
+                null,
+                mkMap(
+                    mkEntry(task00, Collections.singleton(new TopicPartition("topic", 0))),
+                    mkEntry(task01, Collections.singleton(new TopicPartition("topic", 1)))
+                )
+            ).stream().map(Task::id).collect(Collectors.toSet()),
+            equalTo(mkSet(task00, task01))
+        );
+
+        final StreamsProducer streamsProducer1 = activeTaskCreator.streamsProducerForTask(new TaskId(0, 0));
+        final StreamsProducer streamsProducer2 = activeTaskCreator.streamsProducerForTask(new TaskId(0, 1));
+
+        assertThat(streamsProducer1, not(same(streamsProducer2)));
+    }
+
+    @Test
     public void shouldConstructProducerMetricsWithoutEOS() {
         expect(config.getString(StreamsConfig.APPLICATION_ID_CONFIG)).andReturn("appId");
         expect(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG)).andReturn(StreamsConfig.AT_LEAST_ONCE);
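
The new tests pin down the contract of streamsProducerForTask(): it may only be called with exactly-once enabled and for a task that was actually created; anything else throws an IllegalStateException with the messages asserted above. A hedged usage sketch (taskId and offsets are placeholders):

    // valid only when PROCESSING_GUARANTEE_CONFIG == EXACTLY_ONCE and taskId is known
    final StreamsProducer producer = activeTaskCreator.streamsProducerForTask(taskId);
    producer.commitTransaction(offsets);   // offsets: Map<TopicPartition, OffsetAndMetadata>
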
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
index e09d99c..5df3f12 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
@@ -16,11 +16,6 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
-import org.apache.kafka.clients.consumer.CommitFailedException;
-import org.apache.kafka.clients.consumer.KafkaConsumer;
-import org.apache.kafka.clients.consumer.MockConsumer;
-import org.apache.kafka.clients.consumer.OffsetAndMetadata;
-import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.MockProducer;
 import org.apache.kafka.clients.producer.ProducerRecord;
@@ -35,7 +30,6 @@ import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.errors.ProducerFencedException;
-import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.header.Header;
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.header.internals.RecordHeader;
@@ -67,6 +61,7 @@ import java.util.Map;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.expectLastCall;
 import static org.easymock.EasyMock.mock;
 import static org.easymock.EasyMock.replay;
@@ -104,10 +99,9 @@ public class RecordCollectorTest {
 
     private final StreamPartitioner<String, Object> streamPartitioner = (topic, key, value, numPartitions) -> Integer.parseInt(key) % numPartitions;
 
-    private final MockConsumer<byte[], byte[]> mockConsumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
     private final MockProducer<byte[], byte[]> mockProducer = new MockProducer<>(
         cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer);
-    private final StreamsProducer streamsProducer = new StreamsProducer(mockProducer, false, logContext, null);
+    private final StreamsProducer streamsProducer = new StreamsProducer(mockProducer, false, null, logContext);
 
     private RecordCollectorImpl collector;
 
@@ -116,10 +110,8 @@ public class RecordCollectorTest {
         collector = new RecordCollectorImpl(
             logContext,
             taskId,
-            mockConsumer,
             streamsProducer,
             productionExceptionHandler,
-            false,
             streamsMetrics);
     }
 
@@ -243,51 +235,29 @@ public class RecordCollectorTest {
     }
 
     @Test
-    public void shouldCommitViaConsumerIfEosDisabled() {
-        final KafkaConsumer<byte[], byte[]> consumer = mock(KafkaConsumer.class);
-        consumer.commitSync((Map<TopicPartition, OffsetAndMetadata>) null);
-        expectLastCall();
-        replay(consumer);
-
-        final RecordCollector collector = new RecordCollectorImpl(
-            logContext,
-            taskId,
-            consumer,
-            streamsProducer,
-            productionExceptionHandler,
-            false,
-            streamsMetrics);
-
-        collector.commit(null);
-
-        verify(consumer);
-
-    }
-
-    @Test
-    public void shouldCommitViaProducerIfEosEnabled() {
+    public void shouldForwardFlushToStreamsProducer() {
         final StreamsProducer streamsProducer = mock(StreamsProducer.class);
-        streamsProducer.commitTransaction(null);
+        expect(streamsProducer.eosEnabled()).andReturn(false);
+        streamsProducer.flush();
         expectLastCall();
         replay(streamsProducer);
 
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
-            mockConsumer,
             streamsProducer,
             productionExceptionHandler,
-            true,
             streamsMetrics);
 
-        collector.commit(null);
+        collector.flush();
 
         verify(streamsProducer);
     }
 
     @Test
-    public void shouldForwardFlushToTransactionManager() {
+    public void shouldForwardFlushToStreamsProducerEosEnabled() {
         final StreamsProducer streamsProducer = mock(StreamsProducer.class);
+        expect(streamsProducer.eosEnabled()).andReturn(true);
         streamsProducer.flush();
         expectLastCall();
         replay(streamsProducer);
@@ -295,10 +265,8 @@ public class RecordCollectorTest {
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
-            mockConsumer,
             streamsProducer,
             productionExceptionHandler,
-            true,
             streamsMetrics);
 
         collector.flush();
@@ -307,37 +275,17 @@ public class RecordCollectorTest {
     }
 
     @Test
-    public void shouldForwardCloseToTransactionManager() {
-        final StreamsProducer streamsProducer = mock(StreamsProducer.class);
-        replay(streamsProducer);
-
-        final RecordCollector collector = new RecordCollectorImpl(
-            logContext,
-            taskId,
-            mockConsumer,
-            streamsProducer,
-            productionExceptionHandler,
-            false,
-            streamsMetrics);
-
-        collector.close();
-
-        verify(streamsProducer);
-    }
-
-    @Test
     public void shouldAbortTxIfEosEnabled() {
         final StreamsProducer streamsProducer = mock(StreamsProducer.class);
+        expect(streamsProducer.eosEnabled()).andReturn(true);
         streamsProducer.abortTransaction();
         replay(streamsProducer);
 
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
-            mockConsumer,
             streamsProducer,
             productionExceptionHandler,
-            true,
             streamsMetrics);
 
         collector.close();
@@ -345,7 +293,7 @@ public class RecordCollectorTest {
         verify(streamsProducer);
     }
 
-    @SuppressWarnings("unchecked")
+    @SuppressWarnings({"unchecked", "rawtypes"})
     @Test
     public void shouldThrowInformativeStreamsExceptionOnKeyClassCastException() {
         final StreamsException expected = assertThrows(
@@ -373,7 +321,7 @@ public class RecordCollectorTest {
         );
     }
 
-    @SuppressWarnings("unchecked")
+    @SuppressWarnings({"unchecked", "rawtypes"})
     @Test
     public void shouldThrowInformativeStreamsExceptionOnKeyAndNullValueClassCastException() {
         final StreamsException expected = assertThrows(
@@ -401,7 +349,7 @@ public class RecordCollectorTest {
         );
     }
 
-    @SuppressWarnings("unchecked")
+    @SuppressWarnings({"unchecked", "rawtypes"})
     @Test
     public void shouldThrowInformativeStreamsExceptionOnValueClassCastException() {
         final StreamsException expected = assertThrows(
@@ -429,7 +377,7 @@ public class RecordCollectorTest {
         );
     }
 
-    @SuppressWarnings("unchecked")
+    @SuppressWarnings({"unchecked", "rawtypes"})
     @Test
     public void shouldThrowInformativeStreamsExceptionOnValueAndNullKeyClassCastException() {
         final StreamsException expected = assertThrows(
@@ -463,7 +411,6 @@ public class RecordCollectorTest {
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
-            mockConsumer,
             new StreamsProducer(
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
@@ -473,11 +420,10 @@ public class RecordCollectorTest {
                     }
                 },
                 true,
-                logContext,
-                "appId"
+                "appId",
+                logContext
             ),
             productionExceptionHandler,
-            true,
             streamsMetrics
         );
         collector.initialize();
@@ -506,7 +452,6 @@ public class RecordCollectorTest {
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
-            mockConsumer,
             new StreamsProducer(
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
@@ -516,11 +461,10 @@ public class RecordCollectorTest {
                     }
                 },
                 false,
-                logContext,
-                null
+                null,
+                logContext
             ),
             productionExceptionHandler,
-            false,
             streamsMetrics
         );
 
@@ -548,7 +492,6 @@ public class RecordCollectorTest {
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
-            mockConsumer,
             new StreamsProducer(
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
@@ -558,11 +501,10 @@ public class RecordCollectorTest {
                     }
                 },
                 false,
-                logContext,
-                null
+                null,
+                logContext
             ),
             new AlwaysContinueProductionExceptionHandler(),
-            false,
             streamsMetrics
         );
 
@@ -591,7 +533,6 @@ public class RecordCollectorTest {
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
-            mockConsumer,
             new StreamsProducer(
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
@@ -601,11 +542,10 @@ public class RecordCollectorTest {
                     }
                 },
                 false,
-                logContext,
-                null
+                null,
+                logContext
             ),
             new AlwaysContinueProductionExceptionHandler(),
-            false,
             streamsMetrics
         );
 
@@ -628,102 +568,11 @@ public class RecordCollectorTest {
     }
 
     @Test
-    public void shouldThrowTaskMigratedExceptionOnCommitFailed() {
-        final RecordCollector collector = new RecordCollectorImpl(
-            logContext,
-            taskId,
-            new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
-                @Override
-                public void commitSync(final Map<TopicPartition, OffsetAndMetadata> offsets) {
-                    throw new CommitFailedException();
-                }
-            },
-            streamsProducer,
-            productionExceptionHandler,
-            false,
-            streamsMetrics
-        );
-
-        final TaskMigratedException thrown = assertThrows(TaskMigratedException.class, () -> collector.commit(null));
-
-        assertThat(thrown.getMessage(), equalTo("Consumer committing offsets failed, indicating the corresponding thread is no longer part of the group; it means all tasks belonging to this thread should be migrated."));
-    }
-
-    @Test
-    public void shouldThrowStreamsExceptionOnCommitTimeout() {
-        final RecordCollector collector = new RecordCollectorImpl(
-            logContext,
-            taskId,
-            new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
-                @Override
-                public void commitSync(final Map<TopicPartition, OffsetAndMetadata> offsets) {
-                    throw new TimeoutException();
-                }
-            },
-            streamsProducer,
-            productionExceptionHandler,
-            false,
-            streamsMetrics
-        );
-
-        final StreamsException thrown = assertThrows(StreamsException.class, () -> collector.commit(null));
-
-        assertThat(thrown.getMessage(), equalTo("Timed out while committing offsets via consumer for task 0_0"));
-    }
-
-    @Test
-    public void shouldStreamsExceptionOnCommitError() {
-        final RecordCollector collector = new RecordCollectorImpl(
-            logContext,
-            taskId,
-            new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
-                @Override
-                public void commitSync(final Map<TopicPartition, OffsetAndMetadata> offsets) {
-                    throw new KafkaException();
-                }
-            },
-            streamsProducer,
-            productionExceptionHandler,
-            false,
-            streamsMetrics
-        );
-        collector.initialize();
-
-        final StreamsException thrown = assertThrows(StreamsException.class, () -> collector.commit(null));
-
-        assertThat(thrown.getMessage(), equalTo("Error encountered committing offsets via consumer for task 0_0"));
-    }
-
-    @Test
-    public void shouldFailOnCommitFatal() {
-        final RecordCollector collector = new RecordCollectorImpl(
-            logContext,
-            taskId,
-            new MockConsumer<byte[], byte[]>(OffsetResetStrategy.EARLIEST) {
-                @Override
-                public void commitSync(final Map<TopicPartition, OffsetAndMetadata> offsets) {
-                    throw new RuntimeException("KABOOM!");
-                }
-            },
-            streamsProducer,
-            productionExceptionHandler,
-            false,
-            streamsMetrics
-        );
-        collector.initialize();
-
-        final RuntimeException thrown = assertThrows(RuntimeException.class, () -> collector.commit(null));
-
-        assertThat(thrown.getMessage(), equalTo("KABOOM!"));
-    }
-
-    @Test
     public void shouldNotAbortTxnOnEOSCloseIfNothingSent() {
         final AtomicBoolean functionCalled = new AtomicBoolean(false);
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
-            mockConsumer,
             new StreamsProducer(
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
@@ -732,11 +581,10 @@ public class RecordCollectorTest {
                     }
                 },
                 true,
-                logContext,
-                "appId"
+                "appId",
+                logContext
             ),
             productionExceptionHandler,
-            true,
             streamsMetrics
         );
 
@@ -749,7 +597,6 @@ public class RecordCollectorTest {
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
-            mockConsumer,
             new StreamsProducer(
                 new MockProducer<byte[], byte[]>(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) {
                     @Override
@@ -758,11 +605,10 @@ public class RecordCollectorTest {
                     }
                 },
                 false,
-                logContext,
-                null
+                null,
+                logContext
             ),
             productionExceptionHandler,
-            false,
             streamsMetrics
         );
         collector.initialize();
@@ -779,10 +625,8 @@ public class RecordCollectorTest {
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
-            mockConsumer,
-            new StreamsProducer(mockProducer, true, logContext, "appId"),
+            new StreamsProducer(mockProducer, true, "appId", logContext),
             productionExceptionHandler,
-            true,
             streamsMetrics
         );
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index 2bde646..a12da5b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -49,6 +49,7 @@ import org.junit.runner.RunWith;
 import java.io.File;
 import java.io.IOException;
 import java.util.Collections;
+import java.util.Map;
 
 import static java.util.Arrays.asList;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
@@ -133,6 +134,7 @@ public class StandbyTaskTest {
     @After
     public void cleanup() throws IOException {
         if (task != null && !task.isClosed()) {
+            task.prepareCloseDirty();
             task.closeDirty();
             task = null;
         }
@@ -179,7 +181,7 @@ public class StandbyTaskTest {
     public void shouldThrowIfCommittingOnIllegalState() {
         task = createStandbyTask();
 
-        assertThrows(IllegalStateException.class, task::commit);
+        assertThrows(IllegalStateException.class, task::prepareCommit);
     }
 
     @Test
@@ -193,7 +195,8 @@ public class StandbyTaskTest {
 
         task = createStandbyTask();
         task.initializeIfNeeded();
-        task.commit();
+        task.prepareCommit();
+        task.postCommit();
 
         EasyMock.verify(stateManager);
     }
@@ -222,7 +225,8 @@ public class StandbyTaskTest {
         final MetricName metricName = setupCloseTaskMetric();
 
         task = createStandbyTask();
-        task.closeClean();
+        final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+        task.closeClean(checkpoint);
 
         assertEquals(Task.State.CLOSED, task.state());
 
@@ -246,6 +250,7 @@ public class StandbyTaskTest {
 
         task = createStandbyTask();
         task.initializeIfNeeded();
+        task.prepareCloseDirty();
         task.closeDirty();
 
         assertEquals(Task.State.CLOSED, task.state());
@@ -285,7 +290,8 @@ public class StandbyTaskTest {
 
         task = createStandbyTask();
         task.initializeIfNeeded();
-        task.closeClean();
+        final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+        task.closeClean(checkpoint);
 
         assertEquals(Task.State.CLOSED, task.state());
 
@@ -313,7 +319,8 @@ public class StandbyTaskTest {
 
         assertTrue(task.commitNeeded());
 
-        task.commit();
+        task.prepareCommit();
+        task.postCommit();
 
         // do not need to commit if there's no update
         assertFalse(task.commitNeeded());
@@ -336,9 +343,8 @@ public class StandbyTaskTest {
         task = createStandbyTask();
         task.initializeIfNeeded();
 
-        assertThrows(RuntimeException.class, task::closeClean);
-
-        assertEquals(Task.State.CLOSING, task.state());
+        final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+        assertThrows(RuntimeException.class, () -> task.closeClean(checkpoint));
 
         final double expectedCloseTaskMetric = 0.0;
         verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName);
@@ -360,7 +366,7 @@ public class StandbyTaskTest {
         task = createStandbyTask();
         task.initializeIfNeeded();
 
-        assertThrows(RuntimeException.class, task::closeClean);
+        assertThrows(RuntimeException.class, task::prepareCloseClean);
         assertEquals(Task.State.RUNNING, task.state());
 
         final double expectedCloseTaskMetric = 0.0;
@@ -383,7 +389,8 @@ public class StandbyTaskTest {
         task = createStandbyTask();
         task.initializeIfNeeded();
 
-        assertThrows(RuntimeException.class, task::closeClean);
+        task.prepareCommit();
+        assertThrows(RuntimeException.class, task::postCommit);
 
         assertEquals(Task.State.RUNNING, task.state());
 
@@ -400,11 +407,12 @@ public class StandbyTaskTest {
     public void shouldThrowIfClosingOnIllegalState() {
         task = createStandbyTask();
 
-        task.closeClean();
+        final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+        task.closeClean(checkpoint);
 
         // close calls are not idempotent since the task is already closed
-        assertThrows(IllegalStateException.class, task::closeClean);
-        assertThrows(IllegalStateException.class, task::closeDirty);
+        assertThrows(IllegalStateException.class, task::prepareCloseClean);
+        assertThrows(IllegalStateException.class, task::prepareCloseDirty);
     }
 
     private StandbyTask createStandbyTask() {
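
The standby-task changes mirror the split, but a standby has nothing to commit to the input topics; as these tests exercise it, its cycle is flush-and-checkpoint only. A minimal sketch (task created as in createStandbyTask() above):

    task.prepareCommit();   // flushes the state stores
    task.postCommit();      // writes the checkpoint file
    // a standby contributes no consumer offsets, so the TaskManager's
    // commitOffsetsOrTransaction() has nothing to send on its behalf
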
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index dc0404b..5218020 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -86,6 +86,7 @@ import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetric
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.nullValue;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
@@ -207,6 +208,7 @@ public class StreamTaskTest {
     @After
     public void cleanup() throws IOException {
         if (task != null && !task.isClosed()) {
+            task.prepareCloseDirty();
             task.closeDirty();
             task = null;
         }
@@ -264,7 +266,7 @@ public class StreamTaskTest {
         ctrl.replay();
 
         task = createStatefulTask(createConfig(true, "100"), true, stateManager);
-        task.transitionTo(Task.State.CLOSING);
+        task.prepareCloseDirty();
         task.closeDirty();
         task = null;
 
@@ -383,31 +385,6 @@ public class StreamTaskTest {
         task = createStatelessTask(createConfig(false, "100"), builtInMetricsVersion);
 
         assertNotNull(getMetric(
-            "commit",
-            "%s-latency-avg",
-            task.id().toString(),
-            builtInMetricsVersion
-        ));
-        assertNotNull(getMetric(
-            "commit",
-            "%s-latency-max",
-            task.id().toString(),
-            builtInMetricsVersion
-        ));
-        assertNotNull(getMetric(
-            "commit",
-            "%s-rate",
-            task.id().toString(),
-            builtInMetricsVersion
-        ));
-        assertNotNull(getMetric(
-            "commit",
-            "%s-total",
-            task.id().toString(),
-            builtInMetricsVersion
-        ));
-
-        assertNotNull(getMetric(
             "enforced-processing",
             "%s-rate",
             task.id().toString(),
@@ -714,44 +691,80 @@ public class StreamTaskTest {
         assertTrue(task.process(0L));
         assertTrue(task.commitNeeded());
 
-        task.commit();
+        task.prepareCommit();
+        assertTrue(task.commitNeeded());
+
+        task.postCommit();
         assertFalse(task.commitNeeded());
 
         assertTrue(task.maybePunctuateStreamTime());
         assertTrue(task.commitNeeded());
 
-        task.commit();
+        task.prepareCommit();
+        assertTrue(task.commitNeeded());
+
+        task.postCommit();
         assertFalse(task.commitNeeded());
 
         time.sleep(10);
         assertTrue(task.maybePunctuateSystemTime());
         assertTrue(task.commitNeeded());
 
-        task.commit();
+        task.prepareCommit();
+        assertTrue(task.commitNeeded());
+
+        task.postCommit();
         assertFalse(task.commitNeeded());
     }
 
     @Test
     public void shouldCommitNextOffsetFromQueueIfAvailable() {
-        recordCollector.commit(EasyMock.eq(mkMap(mkEntry(partition1, new OffsetAndMetadata(5L, encodeTimestamp(5L))))));
-        EasyMock.expectLastCall();
-
         task = createStatelessTask(createConfig(false, "0"), StreamsConfig.METRICS_LATEST);
         task.initializeIfNeeded();
         task.completeRestoration();
 
         task.addRecords(partition1, Arrays.asList(getConsumerRecord(partition1, 0L), getConsumerRecord(partition1, 5L)));
         task.process(0L);
-        task.commit();
+        task.prepareCommit();
+        final Map<TopicPartition, OffsetAndMetadata> offsetsAndMetadata = task.committableOffsetsAndMetadata();
 
-        EasyMock.verify(recordCollector);
+        assertThat(offsetsAndMetadata, equalTo(mkMap(mkEntry(partition1, new OffsetAndMetadata(5L, encodeTimestamp(5L))))));
     }
 
     @Test
     public void shouldCommitConsumerPositionIfRecordQueueIsEmpty() {
-        recordCollector.commit(EasyMock.eq(mkMap(mkEntry(partition1, new OffsetAndMetadata(3L, encodeTimestamp(0L))))));
-        EasyMock.expectLastCall();
+        task = createStatelessTask(createConfig(false, "0"), StreamsConfig.METRICS_LATEST);
+        task.initializeIfNeeded();
+        task.completeRestoration();
+
+        consumer.addRecord(getConsumerRecord(partition1, 0L));
+        consumer.addRecord(getConsumerRecord(partition1, 1L));
+        consumer.addRecord(getConsumerRecord(partition1, 2L));
+        consumer.poll(Duration.ZERO);
+
+        task.addRecords(partition1, singletonList(getConsumerRecord(partition1, 0L)));
+        task.process(0L);
+        task.prepareCommit();
+        final Map<TopicPartition, OffsetAndMetadata> offsetsAndMetadata = task.committableOffsetsAndMetadata();
+
+        assertThat(offsetsAndMetadata, equalTo(mkMap(mkEntry(partition1, new OffsetAndMetadata(3L, encodeTimestamp(0L))))));
+    }
+
+    @Test
+    public void shouldFailOnCommitIfTaskIsClosed() {
+        task = createStatelessTask(createConfig(false, "0"), StreamsConfig.METRICS_LATEST);
+        task.transitionTo(Task.State.CLOSED);
+
+        final IllegalStateException thrown = assertThrows(
+            IllegalStateException.class,
+            task::committableOffsetsAndMetadata
+        );
+
+        assertThat(thrown.getMessage(), is("Task 0_0 is closed."));
+    }
 
+    @Test
+    public void shouldOnlyCommitConsumerPositionIfTaskIsRunning() {
         task = createStatelessTask(createConfig(false, "0"), StreamsConfig.METRICS_LATEST);
         task.initializeIfNeeded();
         task.completeRestoration();
@@ -763,9 +776,23 @@ public class StreamTaskTest {
 
         task.addRecords(partition1, singletonList(getConsumerRecord(partition1, 0L)));
         task.process(0L);
-        task.commit();
+        task.prepareCommit();
+        task.postCommit();
+
+        final Map<TopicPartition, OffsetAndMetadata> offsetsAndMetadata = task.committableOffsetsAndMetadata();
+
+        assertThat(offsetsAndMetadata, equalTo(mkMap(mkEntry(partition1, new OffsetAndMetadata(3L, encodeTimestamp(0L))))));
+
+        task.transitionTo(Task.State.SUSPENDED);
+        assertTrue(task.committableOffsetsAndMetadata().isEmpty());
+
+        task.transitionTo(Task.State.CLOSED);
+        task.transitionTo(Task.State.CREATED);
+        assertTrue(task.committableOffsetsAndMetadata().isEmpty());
+
+        task.transitionTo(Task.State.RESTORING);
+        assertTrue(task.committableOffsetsAndMetadata().isEmpty());
 
-        EasyMock.verify(recordCollector);
     }
 
     @Test
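
As the surrounding tests pin down, committableOffsetsAndMetadata() is state dependent. A short sketch of the observable contract (task set up as in the tests):

    task.prepareCommit();
    final Map<TopicPartition, OffsetAndMetadata> offsets = task.committableOffsetsAndMetadata();
    // RUNNING:                       returns the offsets to commit
    // SUSPENDED, CREATED, RESTORING: returns an empty map
    // CLOSED:                        throws IllegalStateException("Task 0_0 is closed.")
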
@@ -1041,8 +1068,6 @@ public class StreamTaskTest {
         stateDirectory = EasyMock.createNiceMock(StateDirectory.class);
         EasyMock.expect(stateDirectory.lock(taskId)).andReturn(true);
         EasyMock.expect(recordCollector.offsets()).andReturn(Collections.singletonMap(changelogPartition, 10L));
-        recordCollector.commit(EasyMock.eq(Collections.emptyMap()));
-        EasyMock.expectLastCall();
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes();
         stateManager.checkpoint(EasyMock.eq(Collections.singletonMap(changelogPartition, 10L)));
         EasyMock.expectLastCall();
@@ -1053,6 +1078,7 @@ public class StreamTaskTest {
         task.initializeIfNeeded();
         task.completeRestoration();
 
+        task.prepareSuspend();
         task.suspend();
 
         assertEquals(Task.State.SUSPENDED, task.state());
@@ -1082,8 +1108,6 @@ public class StreamTaskTest {
         stateDirectory = EasyMock.createNiceMock(StateDirectory.class);
         EasyMock.expect(stateDirectory.lock(taskId)).andReturn(true);
         EasyMock.expect(recordCollector.offsets()).andThrow(new AssertionError("Should not try to read offsets")).anyTimes();
-        recordCollector.commit(EasyMock.anyObject());
-        EasyMock.expectLastCall().andThrow(new AssertionError("Should not try to commit")).anyTimes();
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes();
         stateManager.checkpoint(EasyMock.eq(Collections.emptyMap()));
         EasyMock.expectLastCall();
@@ -1094,6 +1118,7 @@ public class StreamTaskTest {
 
         task.initializeIfNeeded();
 
+        task.prepareSuspend();
         task.suspend();
 
         assertEquals(Task.State.SUSPENDED, task.state());
@@ -1120,8 +1145,6 @@ public class StreamTaskTest {
         final Long offset = 543L;
 
         EasyMock.expect(recordCollector.offsets()).andReturn(Collections.singletonMap(changelogPartition, offset));
-        recordCollector.commit(EasyMock.eq(Collections.emptyMap()));
-        EasyMock.expectLastCall();
         stateManager.checkpoint(EasyMock.eq(Collections.singletonMap(changelogPartition, offset)));
         EasyMock.expectLastCall();
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.singleton(changelogPartition));
@@ -1133,7 +1156,8 @@ public class StreamTaskTest {
 
         task.initializeIfNeeded();
         task.completeRestoration();
-        task.commit();
+        task.prepareCommit();
+        task.postCommit();
 
         EasyMock.verify(recordCollector);
     }
@@ -1149,7 +1173,8 @@ public class StreamTaskTest {
 
         task.initializeIfNeeded();
         task.completeRestoration();
-        task.commit();
+        task.prepareCommit();
+        task.postCommit();
         final File checkpointFile = new File(
             stateDirectory.directoryForTask(taskId),
             StateManagerUtil.CHECKPOINT_FILE_NAME
@@ -1220,6 +1245,7 @@ public class StreamTaskTest {
         task.initializeIfNeeded();
         task.completeRestoration();
 
+        task.prepareCloseDirty();
         task.closeDirty();
 
         EasyMock.verify(stateManager);
@@ -1262,7 +1288,7 @@ public class StreamTaskTest {
         assertTrue(task.process(0L));
         assertTrue(task.process(0L));
 
-        task.commit();
+        task.prepareCommit();
 
         final Map<TopicPartition, Long> map = task.purgeableOffsets();
 
@@ -1291,8 +1317,19 @@ public class StreamTaskTest {
     @Test
     public void shouldThrowIfCommittingOnIllegalState() {
         task = createStatelessTask(createConfig(false, "100"), StreamsConfig.METRICS_LATEST);
+        assertThrows(IllegalStateException.class, task::prepareCommit);
 
-        assertThrows(IllegalStateException.class, task::commit);
+        task.transitionTo(Task.State.CLOSED);
+        assertThrows(IllegalStateException.class, task::prepareCommit);
+    }
+
+    @Test
+    public void shouldThrowIfPostCommittingOnIllegalState() {
+        task = createStatelessTask(createConfig(false, "100"), StreamsConfig.METRICS_LATEST);
+        assertThrows(IllegalStateException.class, task::postCommit);
+
+        task.transitionTo(Task.State.CLOSED);
+        assertThrows(IllegalStateException.class, task::postCommit);
     }
 
     @Test
@@ -1325,7 +1362,8 @@ public class StreamTaskTest {
 
         task = createOptimizedStatefulTask(createConfig(false, "100"), consumer);
 
-        task.closeClean();
+        final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+        task.closeClean(checkpoint);
 
         assertEquals(Task.State.CLOSED, task.state());
         assertFalse(source1.initialized);
@@ -1354,6 +1392,7 @@ public class StreamTaskTest {
         task.initializeIfNeeded();
         task.completeRestoration();
 
+        task.prepareCloseDirty();
         task.closeDirty();
 
         assertEquals(Task.State.CLOSED, task.state());
@@ -1372,14 +1411,13 @@ public class StreamTaskTest {
         EasyMock.expectLastCall();
         stateManager.checkpoint(EasyMock.eq(Collections.emptyMap()));
         EasyMock.expectLastCall();
-        recordCollector.commit(EasyMock.anyObject());
-        EasyMock.expectLastCall().andThrow(new AssertionError("Should not call this function")).anyTimes();
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes();
         EasyMock.replay(stateManager);
 
         task = createOptimizedStatefulTask(createConfig(false, "100"), consumer);
 
         task.initializeIfNeeded();
+        task.prepareSuspend();
         task.suspend();
 
         assertEquals(Task.State.SUSPENDED, task.state());
@@ -1393,15 +1431,14 @@ public class StreamTaskTest {
         EasyMock.expectLastCall();
         stateManager.checkpoint(EasyMock.eq(Collections.emptyMap()));
         EasyMock.expectLastCall();
-        recordCollector.commit(EasyMock.anyObject());
-        EasyMock.expectLastCall().andThrow(new AssertionError("Should not call this function")).anyTimes();
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes();
         EasyMock.replay(stateManager);
 
         task = createOptimizedStatefulTask(createConfig(false, "100"), consumer);
 
         task.initializeIfNeeded();
-        task.closeClean();
+        final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+        task.closeClean(checkpoint);
 
         assertEquals(Task.State.CLOSED, task.state());
 
@@ -1413,8 +1450,6 @@ public class StreamTaskTest {
         final long offset = 543L;
 
         EasyMock.expect(recordCollector.offsets()).andReturn(Collections.singletonMap(changelogPartition, offset));
-        recordCollector.commit(EasyMock.eq(Collections.emptyMap()));
-        EasyMock.expectLastCall();
         stateManager.close();
         EasyMock.expectLastCall();
         stateManager.flush();
@@ -1428,7 +1463,8 @@ public class StreamTaskTest {
         task = createOptimizedStatefulTask(createConfig(false, "100"), consumer);
         task.initializeIfNeeded();
         task.completeRestoration();
-        task.closeClean();
+        final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+        task.closeClean(checkpoint);
 
         assertEquals(Task.State.CLOSED, task.state());
 
@@ -1439,12 +1475,10 @@ public class StreamTaskTest {
     }
 
     @Test
-    public void shouldThrowOnCloseCleanError() {
+    public void shouldSwallowExceptionOnCloseCleanError() {
         final long offset = 543L;
 
         EasyMock.expect(recordCollector.offsets()).andReturn(Collections.singletonMap(changelogPartition, offset));
-        recordCollector.commit(EasyMock.eq(Collections.emptyMap()));
-        EasyMock.expectLastCall();
         stateManager.checkpoint(EasyMock.eq(Collections.singletonMap(changelogPartition, offset)));
         EasyMock.expectLastCall();
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.singleton(changelogPartition)).anyTimes();
@@ -1457,9 +1491,8 @@ public class StreamTaskTest {
         task.initializeIfNeeded();
         task.completeRestoration();
 
-        assertThrows(ProcessorStateException.class, task::closeClean);
-
-        assertEquals(Task.State.CLOSING, task.state());
+        final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+        assertThrows(ProcessorStateException.class, () -> task.closeClean(checkpoint));
 
         final double expectedCloseTaskMetric = 0.0;
         verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName);
@@ -1467,6 +1500,8 @@ public class StreamTaskTest {
         EasyMock.verify(stateManager);
         EasyMock.reset(stateManager);
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.singleton(changelogPartition)).anyTimes();
+        stateManager.close();
+        EasyMock.expectLastCall();
         EasyMock.replay(stateManager);
     }
 
@@ -1475,8 +1510,6 @@ public class StreamTaskTest {
         final long offset = 543L;
 
         EasyMock.expect(recordCollector.offsets()).andReturn(Collections.singletonMap(changelogPartition, offset));
-        recordCollector.commit(EasyMock.eq(Collections.emptyMap()));
-        EasyMock.expectLastCall();
         stateManager.flush();
         EasyMock.expectLastCall().andThrow(new ProcessorStateException("KABOOM!")).anyTimes();
         stateManager.checkpoint(EasyMock.anyObject());
@@ -1490,7 +1523,7 @@ public class StreamTaskTest {
         task = createOptimizedStatefulTask(createConfig(false, "100"), consumer);
         task.initializeIfNeeded();
 
-        assertThrows(ProcessorStateException.class, task::closeClean);
+        assertThrows(ProcessorStateException.class, task::prepareCloseClean);
 
         assertEquals(Task.State.RESTORING, task.state());
 
@@ -1508,8 +1541,6 @@ public class StreamTaskTest {
         final long offset = 543L;
 
         EasyMock.expect(recordCollector.offsets()).andReturn(Collections.singletonMap(changelogPartition, offset));
-        recordCollector.commit(EasyMock.eq(Collections.emptyMap()));
-        EasyMock.expectLastCall();
         stateManager.flush();
         EasyMock.expectLastCall();
         stateManager.checkpoint(Collections.emptyMap());
@@ -1524,7 +1555,8 @@ public class StreamTaskTest {
         task = createOptimizedStatefulTask(createConfig(false, "100"), consumer);
         task.initializeIfNeeded();
 
-        assertThrows(ProcessorStateException.class, task::closeClean);
+        final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+        assertThrows(ProcessorStateException.class, () -> task.closeClean(checkpoint));
 
         assertEquals(Task.State.RESTORING, task.state());
 
@@ -1558,10 +1590,11 @@ public class StreamTaskTest {
 
         task = createOptimizedStatefulTask(createConfig(false, "100"), consumer);
 
-        task.closeClean();
+        final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+        task.closeClean(checkpoint);
 
         // close calls are not idempotent since the task is already closed
-        assertThrows(IllegalStateException.class, task::closeClean);
+        assertThrows(IllegalStateException.class, () -> task.closeClean(checkpoint));
         assertThrows(IllegalStateException.class, task::closeDirty);
 
         EasyMock.reset(stateManager);
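
The rewritten close tests follow the same pre/post split: checkpointable offsets are computed in a prepare step and then handed to the actual close, and the TaskManager falls back to a dirty close when the clean path throws (see the TaskManagerTest changes below). A hedged sketch of the call order exercised above (local interface; names match the diff, the clean/dirty fallback wiring is assumed):

    import java.util.Map;
    import org.apache.kafka.common.TopicPartition;

    interface TwoPhaseClose {
        Map<TopicPartition, Long> prepareCloseClean();          // flush + compute checkpoint
        void closeClean(Map<TopicPartition, Long> checkpoint);  // write checkpoint + close
        void prepareCloseDirty();
        void closeDirty();
    }

    final class CloseOrder {
        static void close(final TwoPhaseClose task, final boolean clean) {
            if (clean) {
                // either phase may throw, e.g. ProcessorStateException, as asserted above
                task.closeClean(task.prepareCloseClean());
            } else {
                task.prepareCloseDirty();
                task.closeDirty();
            }
        }
    }
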
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 6cd7129..2ddafd0 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -1046,6 +1046,10 @@ public class StreamThreadTest {
 
         assertThat(thread.activeTasks().size(), equalTo(1));
 
+        // need to process a record to enable committing
+        addRecord(mockConsumer, 0L);
+        thread.runOnce();
+
         clientSupplier.producers.get(0).commitTransactionException = new ProducerFencedException("Producer is fenced");
         assertThrows(TaskMigratedException.class, () -> thread.rebalanceListener.onPartitionsRevoked(assignedPartitions));
         assertFalse(clientSupplier.producers.get(0).transactionCommitted());
@@ -1133,6 +1137,10 @@ public class StreamThreadTest {
 
         assertThat(thread.activeTasks().size(), equalTo(1));
 
+        // need to process a record to enable committing
+        addRecord(mockConsumer, 0L);
+        thread.runOnce();
+
         thread.rebalanceListener.onPartitionsRevoked(assignedPartitions);
         assertTrue(clientSupplier.producers.get(0).transactionCommitted());
         assertFalse(clientSupplier.producers.get(0).closed());
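
Both EOS thread tests now push a record through before revoking partitions, because committing is skipped for tasks that have nothing to commit; without a processed record no transaction is in flight, so there is no commitTransaction to fence or verify. A hedged sketch of that guard (commitNeeded is the flag the tests in this diff toggle; the loop itself is illustrative):

    import java.util.Collection;

    final class CommitGate {
        // Local stand-in for the bits of Task used here.
        interface Task {
            boolean commitNeeded();  // false until the task has processed at least one record
            void prepareCommit();
            void postCommit();
        }

        // Sketch: only tasks with pending work are committed.
        static int commitAll(final Collection<Task> tasks) {
            int committed = 0;
            for (final Task task : tasks) {
                if (task.commitNeeded()) {
                    task.prepareCommit();
                    // offsets committed centrally (consumer, or producer transaction under EOS)
                    task.postCommit();
                    committed++;
                }
            }
            return committed;
        }
    }
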
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsProducerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsProducerTest.java
index c9945c3..90dedab 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsProducerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsProducerTest.java
@@ -70,15 +70,15 @@ public class StreamsProducerTest {
         mkEntry(new TopicPartition(topic, 0), new OffsetAndMetadata(0L, null))
     );
 
-    private final MockProducer<byte[], byte[]> mockProducer = new MockProducer<>(
+    private final MockProducer<byte[], byte[]> nonEosMockProducer = new MockProducer<>(
         cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer);
-    private final StreamsProducer aloStreamsProducer =
-        new StreamsProducer(mockProducer, false, logContext, null);
+    private final StreamsProducer nonEosStreamsProducer =
+        new StreamsProducer(nonEosMockProducer, false, null, logContext);
 
     private final MockProducer<byte[], byte[]> eosMockProducer = new MockProducer<>(
         cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer);
     private final StreamsProducer eosStreamsProducer =
-        new StreamsProducer(eosMockProducer, true, logContext, "appId");
+        new StreamsProducer(eosMockProducer, true, "appId", logContext);
 
     private final ProducerRecord<byte[], byte[]> record =
         new ProducerRecord<>(topic, 0, 0L, new byte[0], new byte[0], new RecordHeaders());
@@ -88,44 +88,11 @@ public class StreamsProducerTest {
         eosStreamsProducer.initTransaction();
     }
 
-    @Test
-    public void shouldFailIfProducerIsNull() {
-        {
-            final NullPointerException thrown = assertThrows(
-                NullPointerException.class,
-                () -> new StreamsProducer(null, false, logContext, null)
-            );
 
-            assertThat(thrown.getMessage(), is("producer cannot be null"));
-        }
-
-        {
-            final NullPointerException thrown = assertThrows(
-                NullPointerException.class,
-                () -> new StreamsProducer(null, true, logContext, "appId")
-            );
-
-            assertThat(thrown.getMessage(), is("producer cannot be null"));
-        }
-    }
-
-    @Test
-    public void shouldNotInitTxIfEosDisable() {
-        assertThat(mockProducer.transactionInitialized(), is(false));
-    }
 
-    @Test
-    public void shouldNotBeginTxOnSendIfEosDisable() {
-        aloStreamsProducer.send(record, null);
-        assertThat(mockProducer.transactionInFlight(), is(false));
-    }
+    // generic tests (non-EOS and EOS)
 
-    @Test
-    public void shouldForwardRecordOnSend() {
-        aloStreamsProducer.send(record, null);
-        assertThat(mockProducer.history().size(), is(1));
-        assertThat(mockProducer.history().get(0), is(record));
-    }
+    // functional tests
 
     @Test
     public void shouldForwardCallToPartitionsFor() {
@@ -136,7 +103,7 @@ public class StreamsProducerTest {
         replay(producer);
 
         final StreamsProducer streamsProducer =
-            new StreamsProducer(producer, false, logContext, null);
+            new StreamsProducer(producer, false, null, logContext);
 
         final List<PartitionInfo> partitionInfo = streamsProducer.partitionsFor(topic);
 
@@ -153,7 +120,7 @@ public class StreamsProducerTest {
         replay(producer);
 
         final StreamsProducer streamsProducer =
-            new StreamsProducer(producer, false, logContext, null);
+            new StreamsProducer(producer, false, null, logContext);
 
         streamsProducer.flush();
 
@@ -163,36 +130,91 @@ public class StreamsProducerTest {
     // error handling tests
 
     @Test
+    public void shouldFailIfProducerIsNull() {
+        {
+            final NullPointerException thrown = assertThrows(
+                NullPointerException.class,
+                () -> new StreamsProducer(null, false, "appId", logContext)
+            );
+
+            assertThat(thrown.getMessage(), is("producer cannot be null"));
+        }
+
+        {
+            final NullPointerException thrown = assertThrows(
+                NullPointerException.class,
+                () -> new StreamsProducer(null, true, "appId", logContext)
+            );
+
+            assertThat(thrown.getMessage(), is("producer cannot be null"));
+        }
+    }
+
+    @Test
+    public void shouldFailIfLogContextIsNull() {
+        final NullPointerException thrown = assertThrows(
+            NullPointerException.class,
+            () -> new StreamsProducer(nonEosMockProducer, false, "appId", null)
+        );
+
+        assertThat(thrown.getMessage(), is("logContext cannot be null"));
+    }
+
+
+    // non-EOS tests
+
+    // functional tests
+
+    @Test
+    public void shouldNotInitTxIfEosDisable() {
+        assertThat(nonEosMockProducer.transactionInitialized(), is(false));
+    }
+
+    @Test
+    public void shouldNotBeginTxOnSendIfEosDisable() {
+        nonEosStreamsProducer.send(record, null);
+        assertThat(nonEosMockProducer.transactionInFlight(), is(false));
+    }
+
+    @Test
+    public void shouldForwardRecordOnSend() {
+        nonEosStreamsProducer.send(record, null);
+        assertThat(nonEosMockProducer.history().size(), is(1));
+        assertThat(nonEosMockProducer.history().get(0), is(record));
+    }
+
+    // error handling tests
+
+    @Test
     public void shouldFailOnInitTxIfEosDisabled() {
         final IllegalStateException thrown = assertThrows(
             IllegalStateException.class,
-            aloStreamsProducer::initTransaction
+            nonEosStreamsProducer::initTransaction
         );
 
-        assertThat(thrown.getMessage(), is("EOS is disabled [test, alo]"));
+        assertThat(thrown.getMessage(), is("EOS is disabled [test]"));
     }
 
     @Test
     public void shouldThrowStreamsExceptionOnSendError() {
-        mockProducer.sendException  = new KafkaException("KABOOM!");
+        nonEosMockProducer.sendException = new KafkaException("KABOOM!");
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
-            () -> aloStreamsProducer.send(record, null)
+            () -> nonEosStreamsProducer.send(record, null)
         );
 
-        assertThat(thrown.getCause(), is(mockProducer.sendException));
-        assertThat(thrown.getMessage(), is("Error encountered sending record to topic topic [test, alo]"));
-        assertThat(thrown.getCause(), is(mockProducer.sendException));
+        assertThat(thrown.getCause(), is(nonEosMockProducer.sendException));
+        assertThat(thrown.getMessage(), is("Error encountered trying to send record to topic topic [test]"));
     }
 
     @Test
     public void shouldFailOnSendFatal() {
-        mockProducer.sendException = new RuntimeException("KABOOM!");
+        nonEosMockProducer.sendException = new RuntimeException("KABOOM!");
 
         final RuntimeException thrown = assertThrows(
             RuntimeException.class,
-            () -> aloStreamsProducer.send(record, null)
+            () -> nonEosStreamsProducer.send(record, null)
         );
 
         assertThat(thrown.getMessage(), is("KABOOM!"));
@@ -202,22 +224,23 @@ public class StreamsProducerTest {
     public void shouldFailOnCommitIfEosDisabled() {
         final IllegalStateException thrown = assertThrows(
             IllegalStateException.class,
-            () -> aloStreamsProducer.commitTransaction(null)
+            () -> nonEosStreamsProducer.commitTransaction(null)
         );
 
-        assertThat(thrown.getMessage(), is("EOS is disabled [test, alo]"));
+        assertThat(thrown.getMessage(), is("EOS is disabled [test]"));
     }
 
     @Test
     public void shouldFailOnAbortIfEosDisabled() {
         final IllegalStateException thrown = assertThrows(
             IllegalStateException.class,
-            aloStreamsProducer::abortTransaction
+            nonEosStreamsProducer::abortTransaction
         );
 
-        assertThat(thrown.getMessage(), is("EOS is disabled [test, alo]"));
+        assertThat(thrown.getMessage(), is("EOS is disabled [test]"));
     }
 
+
     // EOS tests
 
     // functional tests
@@ -262,7 +285,7 @@ public class StreamsProducerTest {
         replay(producer);
 
         final StreamsProducer streamsProducer =
-            new StreamsProducer(producer, true, logContext, "appId");
+            new StreamsProducer(producer, true, "appId", logContext);
         streamsProducer.initTransaction();
 
         streamsProducer.commitTransaction(offsetsAndMetadata);
@@ -318,7 +341,7 @@ public class StreamsProducerTest {
         replay(producer);
 
         final StreamsProducer streamsProducer =
-            new StreamsProducer(producer, true, logContext, "appId");
+            new StreamsProducer(producer, true, "appId", logContext);
         streamsProducer.initTransaction();
 
         streamsProducer.abortTransaction();
@@ -329,11 +352,21 @@ public class StreamsProducerTest {
     // error handling tests
 
     @Test
+    public void shouldFailIfApplicationIdIsNullOnEos() {
+        final IllegalArgumentException thrown = assertThrows(
+            IllegalArgumentException.class,
+            () -> new StreamsProducer(eosMockProducer, true, null, logContext)
+        );
+
+        assertThat(thrown.getMessage(), is("applicationId cannot be null if EOS is enabled"));
+    }
+
+    @Test
     public void shouldThrowTimeoutExceptionOnEosInitTxTimeout() {
         // use `nonEosMockProducer` instead of `eosMockProducer` to avoid double Tx-Init
-        mockProducer.initTransactionException = new TimeoutException("KABOOM!");
+        nonEosMockProducer.initTransactionException = new TimeoutException("KABOOM!");
         final StreamsProducer streamsProducer =
-            new StreamsProducer(mockProducer, true, logContext, "appId");
+            new StreamsProducer(nonEosMockProducer, true, "appId", logContext);
 
         final TimeoutException thrown = assertThrows(
             TimeoutException.class,
@@ -346,25 +379,25 @@ public class StreamsProducerTest {
     @Test
     public void shouldThrowStreamsExceptionOnEosInitError() {
         // use `nonEosMockProducer` instead of `eosMockProducer` to avoid double Tx-Init
-        mockProducer.initTransactionException = new KafkaException("KABOOM!");
+        nonEosMockProducer.initTransactionException = new KafkaException("KABOOM!");
         final StreamsProducer streamsProducer =
-            new StreamsProducer(mockProducer, true, logContext, "appId");
+            new StreamsProducer(nonEosMockProducer, true, "appId", logContext);
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
             streamsProducer::initTransaction
         );
 
-        assertThat(thrown.getCause(), is(mockProducer.initTransactionException));
-        assertThat(thrown.getMessage(), is("Error encountered while initializing transactions [test, eos]"));
+        assertThat(thrown.getCause(), is(nonEosMockProducer.initTransactionException));
+        assertThat(thrown.getMessage(), is("Error encountered trying to initialize transactions [test]"));
     }
 
     @Test
     public void shouldFailOnEosInitFatal() {
         // use `nonEosMockProducer` instead of `eosMockProducer` to avoid double Tx-Init
-        mockProducer.initTransactionException = new RuntimeException("KABOOM!");
+        nonEosMockProducer.initTransactionException = new RuntimeException("KABOOM!");
         final StreamsProducer streamsProducer =
-            new StreamsProducer(mockProducer, true, logContext, "appId");
+            new StreamsProducer(nonEosMockProducer, true, "appId", logContext);
 
         final RuntimeException thrown = assertThrows(
             RuntimeException.class,
@@ -385,7 +418,7 @@ public class StreamsProducerTest {
 
         assertThat(
             thrown.getMessage(),
-            is("Producer get fenced trying to begin a new transaction [test, eos];" +
+            is("Producer got fenced trying to begin a new transaction [test];" +
                    " it means all tasks belonging to this thread should be migrated.")
         );
     }
@@ -402,7 +435,7 @@ public class StreamsProducerTest {
         assertThat(thrown.getCause(), is(eosMockProducer.beginTransactionException));
         assertThat(
             thrown.getMessage(),
-            is("Producer encounter unexpected error trying to begin a new transaction [test, eos]")
+            is("Error encountered trying to begin a new transaction [test]")
         );
     }
 
@@ -433,7 +466,7 @@ public class StreamsProducerTest {
         assertThat(thrown.getCause(), is(exception));
         assertThat(
             thrown.getMessage(),
-            is("Producer cannot send records anymore since it got fenced [test, eos];" +
+            is("Producer got fenced trying to send a record [test];" +
                    " it means all tasks belonging to this thread should be migrated.")
         );
     }
@@ -452,7 +485,7 @@ public class StreamsProducerTest {
         assertThat(thrown.getCause(), is(exception));
         assertThat(
             thrown.getMessage(),
-            is("Producer cannot send records anymore since it got fenced [test, eos];" +
+            is("Producer got fenced trying to send a record [test];" +
                    " it means all tasks belonging to this thread should be migrated.")
         );
     }
@@ -472,7 +505,7 @@ public class StreamsProducerTest {
         assertThat(thrown.getCause(), is(eosMockProducer.sendOffsetsToTransactionException));
         assertThat(
             thrown.getMessage(),
-            is("Producer get fenced trying to commit a transaction [test, eos];" +
+            is("Producer got fenced trying to commit a transaction [test];" +
                    " it means all tasks belonging to this thread should be migrated.")
         );
     }
@@ -491,7 +524,7 @@ public class StreamsProducerTest {
         assertThat(thrown.getCause(), is(eosMockProducer.sendOffsetsToTransactionException));
         assertThat(
             thrown.getMessage(),
-            is("Producer encounter unexpected error trying to commit a transaction [test, eos]")
+            is("Error encountered trying to commit a transaction [test]")
         );
     }
 
@@ -523,7 +556,7 @@ public class StreamsProducerTest {
         assertThat(thrown.getCause(), is(eosMockProducer.commitTransactionException));
         assertThat(
             thrown.getMessage(),
-            is("Producer get fenced trying to commit a transaction [test, eos];" +
+            is("Producer got fenced trying to commit a transaction [test];" +
                    " it means all tasks belonging to this thread should be migrated.")
         );
     }
@@ -540,7 +573,7 @@ public class StreamsProducerTest {
 
         assertThat(eosMockProducer.sentOffsets(), is(true));
         assertThat(thrown.getCause(), is(eosMockProducer.commitTransactionException));
-        assertThat(thrown.getMessage(), is("Timed out while committing a transaction [test, eos]"));
+        assertThat(thrown.getMessage(), is("Timed out trying to commit a transaction [test]"));
     }
 
     @Test
@@ -556,7 +589,7 @@ public class StreamsProducerTest {
         assertThat(thrown.getCause(), is(eosMockProducer.commitTransactionException));
         assertThat(
             thrown.getMessage(),
-            is("Producer encounter unexpected error trying to commit a transaction [test, eos]")
+            is("Error encountered trying to commit a transaction [test]")
         );
     }
 
@@ -585,7 +618,7 @@ public class StreamsProducerTest {
         replay(producer);
 
         final StreamsProducer streamsProducer =
-            new StreamsProducer(producer, true, logContext, "appId");
+            new StreamsProducer(producer, true, "appId", logContext);
         streamsProducer.initTransaction();
         // call `send()` to start a transaction
         streamsProducer.send(record, null);
@@ -606,7 +639,7 @@ public class StreamsProducerTest {
         assertThat(thrown.getCause(), is(eosMockProducer.abortTransactionException));
         assertThat(
             thrown.getMessage(),
-            is("Producer encounter unexpected error trying to abort a transaction [test, eos]")
+            is("Error encounter trying to abort a transaction [test]")
         );
     }
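
Every call site in this test now passes the application id before the log context, and three new constructor tests pin down the argument checks. A hedged reconstruction of the constructor shape the updated tests assume (parameter names inferred from the assertion messages above; not the exact trunk source):

    import java.util.Objects;
    import org.apache.kafka.clients.producer.Producer;
    import org.apache.kafka.common.utils.LogContext;

    final class StreamsProducerShape {
        StreamsProducerShape(final Producer<byte[], byte[]> producer,
                             final boolean eosEnabled,
                             final String applicationId,  // nullable only if EOS is disabled
                             final LogContext logContext) {
            Objects.requireNonNull(producer, "producer cannot be null");
            Objects.requireNonNull(logContext, "logContext cannot be null");
            if (eosEnabled && applicationId == null) {
                throw new IllegalArgumentException("applicationId cannot be null if EOS is enabled");
            }
        }
    }
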
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 2f7681d..d8bca5e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -17,13 +17,16 @@
 package org.apache.kafka.streams.processor.internals;
 
 import java.util.HashSet;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.stream.Collectors;
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.admin.DeleteRecordsResult;
 import org.apache.kafka.clients.admin.DeletedRecords;
 import org.apache.kafka.clients.admin.RecordsToDelete;
+import org.apache.kafka.clients.consumer.CommitFailedException;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
@@ -36,6 +39,7 @@ import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.LockException;
+import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
@@ -50,12 +54,14 @@ import org.hamcrest.Matchers;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.function.ThrowingRunnable;
 import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 
 import java.io.File;
 import java.io.IOException;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.Deque;
 import java.util.HashMap;
 import java.util.LinkedList;
@@ -79,12 +85,14 @@ import static org.easymock.EasyMock.anyString;
 import static org.easymock.EasyMock.eq;
 import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.expectLastCall;
+import static org.easymock.EasyMock.mock;
 import static org.easymock.EasyMock.replay;
 import static org.easymock.EasyMock.reset;
 import static org.easymock.EasyMock.resetToStrict;
 import static org.easymock.EasyMock.verify;
 import static org.hamcrest.CoreMatchers.hasItem;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.core.IsEqual.equalTo;
@@ -110,6 +118,10 @@ public class TaskManagerTest {
     private final TopicPartition t1p2 = new TopicPartition(topic1, 2);
     private final Set<TopicPartition> taskId02Partitions = mkSet(t1p2);
 
+    private final TaskId taskId03 = new TaskId(0, 3);
+    private final TopicPartition t1p3 = new TopicPartition(topic1, 3);
+    private final Set<TopicPartition> taskId03Partitions = mkSet(t1p3);
+
     private final TaskId taskId10 = new TaskId(1, 0);
 
     @Mock(type = MockType.STRICT)
@@ -144,7 +156,8 @@ public class TaskManagerTest {
             standbyTaskCreator,
             topologyBuilder,
             adminClient,
-            stateDirectory
+            stateDirectory,
+            false
         );
         taskManager.setMainConsumer(consumer);
     }
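
The extra boolean threaded into the TaskManager constructor here is the EOS flag that selects the unified commit path; the new tests shouldCommitViaConsumerIfEosDisabled and shouldCommitViaProducerIfEosEnabled below verify both branches. A minimal sketch of that routing with simplified stand-in types (only the two call shapes, consumer.commitSync and streamsProducerForTask(...).commitTransaction, come from those tests):

    import java.util.Map;
    import org.apache.kafka.clients.consumer.Consumer;
    import org.apache.kafka.clients.consumer.OffsetAndMetadata;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.streams.processor.TaskId;

    final class CommitRouting {
        interface TaskProducers {  // stand-in for ActiveTaskCreator
            TxProducer streamsProducerForTask(TaskId taskId);
        }
        interface TxProducer {     // stand-in for StreamsProducer
            void commitTransaction(Map<TopicPartition, OffsetAndMetadata> offsets);
        }

        // EOS commits offsets through the task's producer transaction,
        // non-EOS through the shared main consumer.
        static void commitOffsets(final boolean eosEnabled,
                                  final TaskId taskId,
                                  final Map<TopicPartition, OffsetAndMetadata> offsets,
                                  final Consumer<byte[], byte[]> mainConsumer,
                                  final TaskProducers producers) {
            if (eosEnabled) {
                producers.streamsProducerForTask(taskId).commitTransaction(offsets);
            } else {
                mainConsumer.commitSync(offsets);
            }
        }
    }
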
@@ -155,10 +168,8 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignment = mkMap(mkEntry(taskId01, mkSet(t1p1, newTopicPartition)));
 
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(emptyList()).anyTimes();
-
         topologyBuilder.addSubscribedTopicsFromAssignment(eq(asList(t1p1, newTopicPartition)), anyString());
         expectLastCall();
-
         replay(activeTaskCreator, topologyBuilder);
 
         taskManager.handleAssignment(assignment, emptyMap());
@@ -341,26 +352,37 @@ public class TaskManagerTest {
 
     @Test
     public void shouldCloseActiveUnassignedSuspendedTasksWhenClosingRevokedTasks() {
-        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        task00.setCommittableOffsetsAndMetadata(offsets);
 
+        // first `handleAssignment`
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
         expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
         expectLastCall();
         expect(standbyTaskCreator.createTasks(anyObject())).andReturn(emptyList()).anyTimes();
-
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
 
+        // `handleRevocation`
+        consumer.commitSync(offsets);
+        expectLastCall();
+
+        // second `handleAssignment`
+        consumer.commitSync(offsets);
+        expectLastCall();
+
         replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
-
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
+
         taskManager.handleRevocation(taskId00Partitions);
         assertThat(task00.state(), is(Task.State.SUSPENDED));
+
         taskManager.handleAssignment(emptyMap(), emptyMap());
         assertThat(task00.state(), is(Task.State.CLOSED));
         assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
@@ -368,27 +390,61 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldCloseActiveTasksWhenHandlingLostTasks() {
-        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
-        final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, false);
+    public void shouldCloseDirtyActiveUnassignedSuspendedTasksWhenErrorCommittingRevokedTask() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+            @Override
+            public Map<TopicPartition, OffsetAndMetadata> committableOffsetsAndMetadata() {
+                throw new RuntimeException("KABOOM!");
+            }
+        };
 
+        // first `handleAssignment`
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
         expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
         expectLastCall();
-        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andReturn(singletonList(task01)).anyTimes();
+        expect(standbyTaskCreator.createTasks(anyObject())).andReturn(emptyList()).anyTimes();
+        topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
+        expectLastCall().anyTimes();
+
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
+
+        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+
+        final RuntimeException thrown = assertThrows(
+            RuntimeException.class,
+            () -> taskManager.handleAssignment(emptyMap(), emptyMap())
+        );
+
+        assertThat(task00.state(), is(Task.State.CLOSED));
+        assertThat(thrown.getMessage(), is("Unexpected failure to close 1 task(s) [[0_0]]. First unexpected exception (for task 0_0) follows."));
+        assertThat(thrown.getCause().getMessage(), is("KABOOM!"));
+    }
 
+    @Test
+    public void shouldCloseActiveTasksWhenHandlingLostTasks() {
+        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, false);
+
+        // `handleAssignment`
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
+        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andReturn(singletonList(task01)).anyTimes();
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
 
+        // `handleLostAll`
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
+        expectLastCall();
+
         replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
-
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
         assertThat(task01.state(), is(Task.State.RUNNING));
+
         taskManager.handleLostAll();
         assertThat(task00.state(), is(Task.State.CLOSED));
         assertThat(task01.state(), is(Task.State.RUNNING));
@@ -397,32 +453,65 @@ public class TaskManagerTest {
     }
 
     @Test
+    public void shouldThrowWhenHandlingClosingTasksOnProducerCloseError() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        task00.setCommittableOffsetsAndMetadata(offsets);
+
+        // `handleAssignment`
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
+        expect(standbyTaskCreator.createTasks(anyObject())).andReturn(emptyList()).anyTimes();
+        topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
+        expectLastCall().anyTimes();
+
+        // `handleAssignment`
+        consumer.commitSync(offsets);
+        expectLastCall();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
+        expectLastCall().andThrow(new RuntimeException("KABOOM!"));
+
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
+
+        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+        assertThat(taskManager.tryToCompleteRestoration(), is(true));
+        assertThat(task00.state(), is(Task.State.RUNNING));
+
+        final RuntimeException thrown = assertThrows(
+            RuntimeException.class,
+            () -> taskManager.handleAssignment(emptyMap(), emptyMap())
+        );
+
+        assertThat(thrown.getMessage(), is("Unexpected failure to close 1 task(s) [[0_0]]. First unexpected exception (for task 0_0) follows."));
+        assertThat(thrown.getCause(), instanceOf(RuntimeException.class));
+        assertThat(thrown.getCause().getMessage(), is("KABOOM!"));
+    }
+
+    @Test
     public void shouldReviveCorruptTasks() {
         final ProcessorStateManager stateManager = EasyMock.createStrictMock(ProcessorStateManager.class);
         stateManager.markChangelogAsCorrupted(taskId00Partitions);
         replay(stateManager);
+
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
 
+        // `handleAssignment`
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
-        expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
-        expectLastCall();
-        expect(standbyTaskCreator.createTasks(anyObject())).andReturn(emptyList()).anyTimes();
-
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
 
-        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
+        replay(activeTaskCreator, topologyBuilder, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
-
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
+
         taskManager.handleCorruption(singletonMap(taskId00, taskId00Partitions));
         assertThat(task00.state(), is(Task.State.CREATED));
         assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, task00)));
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+
         verify(stateManager);
     }
 
@@ -431,33 +520,30 @@ public class TaskManagerTest {
         final ProcessorStateManager stateManager = EasyMock.createStrictMock(ProcessorStateManager.class);
         stateManager.markChangelogAsCorrupted(taskId00Partitions);
         replay(stateManager);
+
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager) {
             @Override
-            public void closeClean() {
+            public Map<TopicPartition, Long> prepareCloseClean() {
                 throw new RuntimeException("oops");
             }
         };
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
-        expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
-        expectLastCall();
-        expect(standbyTaskCreator.createTasks(anyObject())).andReturn(emptyList()).anyTimes();
-
         topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString());
         expectLastCall().anyTimes();
 
-        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader);
+        replay(activeTaskCreator, topologyBuilder, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
-
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
+
         taskManager.handleCorruption(singletonMap(taskId00, taskId00Partitions));
         assertThat(task00.state(), is(Task.State.CREATED));
         assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, task00)));
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+
         verify(stateManager);
     }
 
@@ -467,10 +553,14 @@ public class TaskManagerTest {
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(standbyTaskCreator.createTasks(eq(taskId00Assignment))).andReturn(singletonList(task00)).anyTimes();
+        consumer.commitSync(Collections.emptyMap());
+        expectLastCall();
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
+
         taskManager.handleAssignment(emptyMap(), taskId00Assignment);
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
+
         taskManager.handleAssignment(emptyMap(), emptyMap());
         assertThat(task00.state(), is(Task.State.CLOSED));
         assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
@@ -548,6 +638,8 @@ public class TaskManagerTest {
             }
         };
 
+        consumer.commitSync(Collections.emptyMap());
+        expectLastCall();
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
         expect(consumer.assignment()).andReturn(emptySet());
         consumer.resume(eq(emptySet()));
@@ -587,6 +679,8 @@ public class TaskManagerTest {
             }
         };
 
+        consumer.commitSync(Collections.emptyMap());
+        expectLastCall();
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
         expect(consumer.assignment()).andReturn(emptySet());
         consumer.resume(eq(emptySet()));
@@ -614,12 +708,17 @@ public class TaskManagerTest {
 
     @Test
     public void shouldSuspendActiveTasks() {
-        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        task00.setCommittableOffsetsAndMetadata(offsets);
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
+        consumer.commitSync(offsets);
+        expectLastCall();
 
         replay(activeTaskCreator, consumer, changeLogReader);
+
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
@@ -629,10 +728,24 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldPassUpIfExceptionDuringSuspend() {
+    public void shouldNotCommitCreatedTasksOnSuspend() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
+        replay(activeTaskCreator, consumer, changeLogReader);
+
+        taskManager.handleAssignment(taskId00Assignment, emptyMap());
+        assertThat(task00.state(), is(Task.State.CREATED));
+
+        taskManager.handleRevocation(taskId00Partitions);
+        assertThat(task00.state(), is(Task.State.CREATED));
+    }
+
+    @Test
+    public void shouldPassUpIfExceptionDuringPrepareSuspend() {
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
             @Override
-            public void suspend() {
+            public void prepareSuspend() {
                 throw new RuntimeException("KABOOM!");
             }
         };
@@ -647,6 +760,8 @@ public class TaskManagerTest {
 
         assertThrows(RuntimeException.class, () -> taskManager.handleRevocation(taskId00Partitions));
         assertThat(task00.state(), is(Task.State.RUNNING));
+
+        verify(consumer);
     }
 
     @Test
@@ -655,7 +770,8 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
             mkEntry(taskId00, taskId00Partitions),
             mkEntry(taskId01, taskId01Partitions),
-            mkEntry(taskId02, taskId02Partitions)
+            mkEntry(taskId02, taskId02Partitions),
+            mkEntry(taskId03, taskId03Partitions)
         );
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
             @Override
@@ -663,17 +779,65 @@ public class TaskManagerTest {
                 return singletonList(changelog);
             }
         };
+        final AtomicBoolean prepareClosedDirtyTask01 = new AtomicBoolean(false);
+        final AtomicBoolean prepareClosedDirtyTask02 = new AtomicBoolean(false);
+        final AtomicBoolean prepareClosedDirtyTask03 = new AtomicBoolean(false);
+        final AtomicBoolean closedDirtyTask01 = new AtomicBoolean(false);
+        final AtomicBoolean closedDirtyTask02 = new AtomicBoolean(false);
+        final AtomicBoolean closedDirtyTask03 = new AtomicBoolean(false);
         final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, true) {
             @Override
-            public void closeClean() {
+            public Map<TopicPartition, Long> prepareCloseClean() {
                 throw new TaskMigratedException("migrated", new RuntimeException("cause"));
             }
+
+            @Override
+            public void prepareCloseDirty() {
+                super.prepareCloseDirty();
+                prepareClosedDirtyTask01.set(true);
+            }
+
+            @Override
+            public void closeDirty() {
+                super.closeDirty();
+                closedDirtyTask01.set(true);
+            }
         };
         final Task task02 = new StateMachineTask(taskId02, taskId02Partitions, true) {
             @Override
-            public void closeClean() {
+            public Map<TopicPartition, Long> prepareCloseClean() {
                 throw new RuntimeException("oops");
             }
+
+            @Override
+            public void prepareCloseDirty() {
+                super.prepareCloseDirty();
+                prepareClosedDirtyTask02.set(true);
+            }
+
+            @Override
+            public void closeDirty() {
+                super.closeDirty();
+                closedDirtyTask02.set(true);
+            }
+        };
+        final Task task03 = new StateMachineTask(taskId03, taskId03Partitions, true) {
+            @Override
+            public Map<TopicPartition, OffsetAndMetadata> committableOffsetsAndMetadata() {
+                throw new RuntimeException("oops");
+            }
+
+            @Override
+            public void prepareCloseDirty() {
+                super.prepareCloseDirty();
+                prepareClosedDirtyTask03.set(true);
+            }
+
+            @Override
+            public void closeDirty() {
+                super.closeDirty();
+                closedDirtyTask03.set(true);
+            }
         };
 
         resetToStrict(changeLogReader);
@@ -683,13 +847,16 @@ public class TaskManagerTest {
         // make sure we also remove the changelog partitions from the changelog reader
         changeLogReader.remove(eq(singletonList(changelog)));
         expectLastCall();
-        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(asList(task00, task01, task02)).anyTimes();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment)))
+            .andReturn(asList(task00, task01, task02, task03)).anyTimes();
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
         expectLastCall();
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId01));
         expectLastCall();
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId02));
         expectLastCall();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId03));
+        expectLastCall();
         activeTaskCreator.closeThreadProducerIfNeeded();
         expectLastCall();
         expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andReturn(emptyList()).anyTimes();
@@ -700,29 +867,42 @@ public class TaskManagerTest {
         assertThat(task00.state(), is(Task.State.CREATED));
         assertThat(task01.state(), is(Task.State.CREATED));
         assertThat(task02.state(), is(Task.State.CREATED));
+        assertThat(task03.state(), is(Task.State.CREATED));
 
         taskManager.tryToCompleteRestoration();
 
         assertThat(task00.state(), is(Task.State.RESTORING));
         assertThat(task01.state(), is(Task.State.RUNNING));
         assertThat(task02.state(), is(Task.State.RUNNING));
+        assertThat(task03.state(), is(Task.State.RUNNING));
         assertThat(
             taskManager.activeTaskMap(),
             Matchers.equalTo(
                 mkMap(
                     mkEntry(taskId00, task00),
                     mkEntry(taskId01, task01),
-                    mkEntry(taskId02, task02)
+                    mkEntry(taskId02, task02),
+                    mkEntry(taskId03, task03)
                 )
             )
         );
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
 
-        final RuntimeException exception = assertThrows(RuntimeException.class, () -> taskManager.shutdown(true));
+        final RuntimeException exception = assertThrows(
+            RuntimeException.class,
+            () -> taskManager.shutdown(true)
+        );
 
+        assertThat(prepareClosedDirtyTask01.get(), is(true));
+        assertThat(closedDirtyTask01.get(), is(true));
+        assertThat(prepareClosedDirtyTask02.get(), is(true));
+        assertThat(closedDirtyTask02.get(), is(true));
+        assertThat(prepareClosedDirtyTask03.get(), is(true));
+        assertThat(closedDirtyTask03.get(), is(true));
         assertThat(task00.state(), is(Task.State.CLOSED));
         assertThat(task01.state(), is(Task.State.CLOSED));
         assertThat(task02.state(), is(Task.State.CLOSED));
+        assertThat(task03.state(), is(Task.State.CLOSED));
         assertThat(exception.getMessage(), is("Unexpected exception while closing task"));
         assertThat(exception.getCause().getMessage(), is("oops"));
         assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
@@ -737,12 +917,14 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
             mkEntry(taskId00, taskId00Partitions)
         );
-        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
             @Override
             public Collection<TopicPartition> changelogPartitions() {
                 return singletonList(changelog);
             }
         };
+        final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        task00.setCommittableOffsetsAndMetadata(offsets);
 
         resetToStrict(changeLogReader);
         changeLogReader.transitToRestoreActive();
@@ -859,13 +1041,13 @@ public class TaskManagerTest {
         };
         final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, true) {
             @Override
-            public void closeClean() {
+            public Map<TopicPartition, Long> prepareCloseClean() {
                 throw new TaskMigratedException("migrated", new RuntimeException("cause"));
             }
         };
         final Task task02 = new StateMachineTask(taskId02, taskId02Partitions, true) {
             @Override
-            public void closeClean() {
+            public Map<TopicPartition, Long> prepareCloseClean() {
                 throw new RuntimeException("oops");
             }
         };
@@ -928,28 +1110,32 @@ public class TaskManagerTest {
         final Map<TaskId, Set<TopicPartition>> assignment = singletonMap(taskId00, taskId00Partitions);
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, false);
 
+        // `handleAssignment`
+        expect(standbyTaskCreator.createTasks(eq(assignment))).andReturn(singletonList(task00)).anyTimes();
+
+        // `tryToCompleteRestoration`
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
         expect(consumer.assignment()).andReturn(emptySet());
         consumer.resume(eq(emptySet()));
         expectLastCall();
-        expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andReturn(emptyList()).anyTimes();
+
+        // `shutdown`
+        consumer.commitSync(Collections.emptyMap());
+        expectLastCall();
         activeTaskCreator.closeThreadProducerIfNeeded();
         expectLastCall();
-        expect(standbyTaskCreator.createTasks(eq(assignment))).andReturn(singletonList(task00)).anyTimes();
+
         replay(consumer, activeTaskCreator, standbyTaskCreator, changeLogReader);
 
         taskManager.handleAssignment(emptyMap(), assignment);
-
         assertThat(task00.state(), is(Task.State.CREATED));
 
         taskManager.tryToCompleteRestoration();
-
         assertThat(task00.state(), is(Task.State.RUNNING));
         assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
         assertThat(taskManager.standbyTaskMap(), Matchers.equalTo(singletonMap(taskId00, task00)));
 
         taskManager.shutdown(true);
-
         assertThat(task00.state(), is(Task.State.CLOSED));
         assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
         assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
@@ -1011,6 +1197,8 @@ public class TaskManagerTest {
     @Test
     public void shouldCommitActiveAndStandbyTasks() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        task00.setCommittableOffsetsAndMetadata(offsets);
         final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false);
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
@@ -1018,6 +1206,8 @@ public class TaskManagerTest {
             .andReturn(singletonList(task00)).anyTimes();
         expect(standbyTaskCreator.createTasks(eq(taskId01Assignment)))
             .andReturn(singletonList(task01)).anyTimes();
+        consumer.commitSync(offsets);
+        expectLastCall();
 
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
@@ -1031,6 +1221,8 @@ public class TaskManagerTest {
         task01.setCommitNeeded();
 
         assertThat(taskManager.commitAll(), equalTo(2));
+        assertThat(task00.commitNeeded, is(false));
+        assertThat(task01.commitNeeded, is(false));
     }
 
     @Test
@@ -1071,10 +1263,61 @@ public class TaskManagerTest {
     }
 
     @Test
+    public void shouldCommitViaConsumerIfEosDisabled() {
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null));
+        task01.setCommittableOffsetsAndMetadata(offsets);
+        task01.setCommitNeeded();
+        taskManager.tasks().put(taskId01, task01);
+
+        consumer.commitSync(offsets);
+        expectLastCall();
+        replay(consumer);
+
+        taskManager.commitAll();
+
+        verify(consumer);
+    }
+
+    @Test
+    public void shouldCommitViaProducerIfEosEnabled() {
+        final StreamsProducer producer = mock(StreamsProducer.class);
+        final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(new Metrics(), "clientId", StreamsConfig.METRICS_LATEST);
+        taskManager = new TaskManager(
+            changeLogReader,
+            UUID.randomUUID(),
+            "taskManagerTest",
+            streamsMetrics,
+            activeTaskCreator,
+            standbyTaskCreator,
+            topologyBuilder,
+            adminClient,
+            stateDirectory,
+            true
+        );
+        taskManager.setMainConsumer(consumer);
+
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null));
+        task01.setCommittableOffsetsAndMetadata(offsets);
+        task01.setCommitNeeded();
+        taskManager.tasks().put(taskId01, task01);
+
+        expect(activeTaskCreator.streamsProducerForTask(taskId01)).andReturn(producer);
+        producer.commitTransaction(offsets);
+        expectLastCall();
+        replay(activeTaskCreator, producer);
+
+        taskManager.commitAll();
+
+        verify(producer);
+    }
+
+    @Test
     public void shouldPropagateExceptionFromActiveCommit() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
             @Override
-            public void commit() {
+            public void prepareCommit() {
                 throw new RuntimeException("opsh.");
             }
         };
@@ -1101,7 +1344,7 @@ public class TaskManagerTest {
     public void shouldPropagateExceptionFromStandbyCommit() {
         final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false) {
             @Override
-            public void commit() {
+            public void prepareCommit() {
                 throw new RuntimeException("opsh.");
             }
         };
@@ -1231,28 +1474,46 @@ public class TaskManagerTest {
     @Test
     public void shouldMaybeCommitActiveTasks() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets0 = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        task00.setCommittableOffsetsAndMetadata(offsets0);
         final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets1 = singletonMap(t1p1, new OffsetAndMetadata(1L, null));
+        task01.setCommittableOffsetsAndMetadata(offsets1);
         final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets2 = singletonMap(t1p2, new OffsetAndMetadata(2L, null));
+        task02.setCommittableOffsetsAndMetadata(offsets2);
+        final StateMachineTask task03 = new StateMachineTask(taskId03, taskId03Partitions, false);
+        final Map<TopicPartition, OffsetAndMetadata> offsets3 = singletonMap(t1p3, new OffsetAndMetadata(3L, null));
+        task03.setCommittableOffsetsAndMetadata(offsets3);
 
-        final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
+        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
             mkEntry(taskId00, taskId00Partitions),
             mkEntry(taskId01, taskId01Partitions),
             mkEntry(taskId02, taskId02Partitions)
         );
 
+        final Map<TaskId, Set<TopicPartition>> assignmentStandby = mkMap(
+            mkEntry(taskId03, taskId03Partitions)
+        );
+
         expectRestoreToBeCompleted(consumer, changeLogReader);
-        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment)))
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive)))
             .andReturn(asList(task00, task01, task02)).anyTimes();
+        expect(standbyTaskCreator.createTasks(eq(assignmentStandby)))
+            .andReturn(singletonList(task03)).anyTimes();
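+        // of the tasks below, only task00's offsets are expected to reach the consumer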
+        consumer.commitSync(offsets0);
+        expectLastCall();
 
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
 
-        taskManager.handleAssignment(assignment, emptyMap());
+        taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
 
         assertThat(task00.state(), is(Task.State.RUNNING));
         assertThat(task01.state(), is(Task.State.RUNNING));
         assertThat(task02.state(), is(Task.State.RUNNING));
+        assertThat(task03.state(), is(Task.State.RUNNING));
 
         task00.setCommitNeeded();
         task00.setCommitRequested();
@@ -1261,6 +1522,9 @@ public class TaskManagerTest {
 
         task02.setCommitRequested();
 
+        task03.setCommitNeeded();
+        task03.setCommitRequested();
+
         assertThat(taskManager.maybeCommitActiveTasksPerUserRequested(), equalTo(1));
     }
 
@@ -1450,9 +1714,13 @@ public class TaskManagerTest {
         final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
 
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        task00.setCommittableOffsetsAndMetadata(offsets);
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
+        consumer.commitSync(offsets);
+        expectLastCall();
 
         replay(activeTaskCreator, consumer, changeLogReader);
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
@@ -1473,14 +1741,14 @@ public class TaskManagerTest {
     public void shouldThrowTaskMigratedWhenAllTaskCloseExceptionsAreTaskMigrated() {
         final StateMachineTask migratedTask01 = new StateMachineTask(taskId01, taskId01Partitions, false) {
             @Override
-            public void closeClean() {
+            public Map<TopicPartition, Long> prepareCloseClean() {
                 throw new TaskMigratedException("t1 close exception", new RuntimeException());
             }
         };
 
         final StateMachineTask migratedTask02 = new StateMachineTask(taskId02, taskId02Partitions, false) {
             @Override
-            public void closeClean() {
+            public Map<TopicPartition, Long> prepareCloseClean() {
                 throw new TaskMigratedException("t2 close exception", new RuntimeException());
             }
         };
@@ -1500,14 +1768,14 @@ public class TaskManagerTest {
     public void shouldThrowRuntimeExceptionWhenEncounteredUnknownExceptionDuringTaskClose() {
         final StateMachineTask migratedTask01 = new StateMachineTask(taskId01, taskId01Partitions, false) {
             @Override
-            public void closeClean() {
+            public Map<TopicPartition, Long> prepareCloseClean() {
                 throw new TaskMigratedException("t1 close exception", new RuntimeException());
             }
         };
 
         final StateMachineTask migratedTask02 = new StateMachineTask(taskId02, taskId02Partitions, false) {
             @Override
-            public void closeClean() {
+            public Map<TopicPartition, Long> prepareCloseClean() {
                 throw new IllegalStateException("t2 illegal state exception", new RuntimeException());
             }
         };
@@ -1529,14 +1797,14 @@ public class TaskManagerTest {
     public void shouldThrowSameKafkaExceptionWhenEncounteredDuringTaskClose() {
         final StateMachineTask migratedTask01 = new StateMachineTask(taskId01, taskId01Partitions, false) {
             @Override
-            public void closeClean() {
+            public Map<TopicPartition, Long> prepareCloseClean() {
                 throw new TaskMigratedException("t1 close exception", new RuntimeException());
             }
         };
 
         final StateMachineTask migratedTask02 = new StateMachineTask(taskId02, taskId02Partitions, false) {
             @Override
-            public void closeClean() {
+            public Map<TopicPartition, Long> prepareCloseClean() {
                 throw new KafkaException("Kaboom for t2!", new RuntimeException());
             }
         };
@@ -1645,8 +1913,125 @@ public class TaskManagerTest {
         consumer.pause(assignment);
     }
 
+    @Test
+    public void shouldThrowTaskMigratedExceptionOnCommitFailed() {
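+        // a CommitFailedException from the consumer indicates a rebalance;
+        // commitAll() must rethrow it wrapped as TaskMigratedException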
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
+        task01.setCommitNeeded();
+        taskManager.tasks().put(taskId01, task01);
+
+        consumer.commitSync(Collections.emptyMap());
+        expectLastCall().andThrow(new CommitFailedException());
+        replay(consumer);
+
+        final TaskMigratedException thrown = assertThrows(
+            TaskMigratedException.class,
+            () -> taskManager.commitAll()
+        );
+
+        assertThat(thrown.getCause(), instanceOf(CommitFailedException.class));
+        assertThat(thrown.getMessage(), equalTo("Consumer committing offsets failed, indicating the corresponding thread is no longer part of the group; it means all tasks belonging to this thread should be migrated."));
+        assertThat(task01.state(), is(Task.State.CREATED));
+    }
+
+    @Test
+    public void shouldThrowStreamsExceptionOnCommitTimeout() {
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
+        task01.setCommitNeeded();
+        taskManager.tasks().put(taskId01, task01);
+
+        consumer.commitSync(Collections.emptyMap());
+        expectLastCall().andThrow(new TimeoutException());
+        replay(consumer);
+
+        final StreamsException thrown = assertThrows(
+            StreamsException.class,
+            () -> taskManager.commitAll()
+        );
+
+        assertThat(thrown.getCause(), instanceOf(TimeoutException.class));
+        assertThat(thrown.getMessage(), equalTo("Timed out while committing offsets via consumer"));
+        assertThat(task01.state(), is(Task.State.CREATED));
+    }
+
+    @Test
+    public void shouldThrowStreamsExceptionOnCommitError() {
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
+        task01.setCommitNeeded();
+        taskManager.tasks().put(taskId01, task01);
+
+        consumer.commitSync(Collections.emptyMap());
+        expectLastCall().andThrow(new KafkaException());
+        replay(consumer);
+
+        final StreamsException thrown = assertThrows(
+            StreamsException.class,
+            () -> taskManager.commitAll()
+        );
+
+        assertThat(thrown.getCause(), instanceOf(KafkaException.class));
+        assertThat(thrown.getMessage(), equalTo("Error encountered committing offsets via consumer"));
+        assertThat(task01.state(), is(Task.State.CREATED));
+    }
+
+    @Test
+    public void shouldFailOnCommitFatal() {
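+        // any other RuntimeException from the consumer is fatal and must be rethrown as-is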
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
+        task01.setCommitNeeded();
+        taskManager.tasks().put(taskId01, task01);
+
+        consumer.commitSync(Collections.emptyMap());
+        expectLastCall().andThrow(new RuntimeException("KABOOM"));
+        replay(consumer);
+
+        final RuntimeException thrown = assertThrows(
+            RuntimeException.class,
+            () -> taskManager.commitAll()
+        );
+
+        assertThat(thrown.getMessage(), equalTo("KABOOM"));
+        assertThat(task01.state(), is(Task.State.CREATED));
+    }
+
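+    // the next three tests share one scenario: when committing fails during the
+    // action, the affected task must not be closed (it stays in CREATED)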
+    @Test
+    public void shouldNotCloseTasksIfCommittingFailsDuringAssignment() {
+        shouldNotCloseTaskIfCommitFailsDuringAction(() -> taskManager.handleAssignment(Collections.emptyMap(), Collections.emptyMap()));
+    }
+
+    @Test
+    public void shouldNotCloseTasksIfCommittingFailsDuringRevocation() {
+        shouldNotCloseTaskIfCommitFailsDuringAction(() -> taskManager.handleRevocation(singletonList(t1p0)));
+    }
+
+    @Test
+    public void shouldNotCloseTasksIfCommittingFailsDuringShutdown() {
+        shouldNotCloseTaskIfCommitFailsDuringAction(() -> taskManager.shutdown(true));
+    }
+
+    private void shouldNotCloseTaskIfCommitFailsDuringAction(final ThrowingRunnable action) {
+        final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
+            @Override
+            public Map<TopicPartition, OffsetAndMetadata> committableOffsetsAndMetadata() {
+                return offsets;
+            }
+        };
+
+        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
+            .andReturn(singletonList(task00));
+        consumer.commitSync(offsets);
+        expectLastCall().andThrow(new RuntimeException("KABOOM!"));
+        replay(activeTaskCreator, consumer);
+
+        taskManager.handleAssignment(taskId00Assignment, Collections.emptyMap());
+
+        final RuntimeException thrown = assertThrows(RuntimeException.class, action);
+
+        assertThat(thrown.getMessage(), is("KABOOM!"));
+        assertThat(task00.state(), is(Task.State.CREATED));
+    }
+
     private static void expectRestoreToBeCompleted(final Consumer<byte[], byte[]> consumer,
-                                                   final ChangelogReader changeLogReader) {
+                                                   final ChangelogReader changeLogReader) {
         final Set<TopicPartition> assignment = singleton(new TopicPartition("assignment", 0));
         expect(consumer.assignment()).andReturn(assignment);
         consumer.resume(assignment);
@@ -1683,6 +2068,7 @@ public class TaskManagerTest {
         private final boolean active;
         private boolean commitNeeded = false;
         private boolean commitRequested = false;
+        private Map<TopicPartition, OffsetAndMetadata> committableOffsets = Collections.emptyMap();
         private Map<TopicPartition, Long> purgeableOffsets;
         private Map<TopicPartition, Long> changelogOffsets;
         private Map<TopicPartition, LinkedList<ConsumerRecord<byte[], byte[]>>> queue = new HashMap<>();
@@ -1732,7 +2118,15 @@ public class TaskManagerTest {
         }
 
         @Override
-        public void commit() {}
+        public void prepareCommit() {}
+
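+        // commitNeeded is cleared only in postCommit, reflecting the new split of
+        // commit() into prepare and post phases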
+        @Override
+        public void postCommit() {
+            commitNeeded = false;
+        }
+
+        @Override
+        public void prepareSuspend() {}
 
         @Override
         public void suspend() {
@@ -1747,17 +2141,32 @@ public class TaskManagerTest {
         }
 
         @Override
-        public void closeClean() {
-            transitionTo(State.CLOSING);
+        public Map<TopicPartition, Long> prepareCloseClean() {
+            return Collections.emptyMap();
+        }
+
+        @Override
+        public void prepareCloseDirty() {}
+
+        @Override
+        public void closeClean(final Map<TopicPartition, Long> checkpoint) {
             transitionTo(State.CLOSED);
         }
 
         @Override
         public void closeDirty() {
-            transitionTo(State.CLOSING);
             transitionTo(State.CLOSED);
         }
 
+        void setCommittableOffsetsAndMetadata(final Map<TopicPartition, OffsetAndMetadata> committableOffsets) {
+            this.committableOffsets = committableOffsets;
+        }
+
+        @Override
+        public Map<TopicPartition, OffsetAndMetadata> committableOffsetsAndMetadata() {
+            return committableOffsets;
+        }
+
         @Override
         public StateStore getStore(final String name) {
             return null;
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
index 1c8bf98..b212f37 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
@@ -200,10 +200,8 @@ public class KeyValueStoreTestDriver<K, V> {
         final RecordCollector recordCollector = new RecordCollectorImpl(
             logContext,
             new TaskId(0, 0),
-            consumer,
-            new StreamsProducer(producer, false, logContext, null),
+            new StreamsProducer(producer, false, null, logContext),
             new DefaultProductionExceptionHandler(),
-            false,
             new MockStreamsMetrics(new Metrics())
         ) {
             @Override
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
index 2d7db43..9b03fc0 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
@@ -374,10 +374,13 @@ public class StreamThreadStateStoreProviderTest {
         final RecordCollector recordCollector = new RecordCollectorImpl(
             logContext,
             taskId,
-            clientSupplier.consumer,
-            new StreamsProducer(clientSupplier.getProducer(new HashMap<>()), eosEnabled, logContext, streamsConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG)),
+            new StreamsProducer(
+                clientSupplier.getProducer(new HashMap<>()),
+                eosEnabled,
+                streamsConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG),
+                logContext
+            ),
             streamsConfig.defaultProductionExceptionHandler(),
-            eosEnabled,
             new MockStreamsMetrics(metrics));
         return new StreamTask(
             taskId,
diff --git a/streams/src/test/java/org/apache/kafka/test/MockRecordCollector.java b/streams/src/test/java/org/apache/kafka/test/MockRecordCollector.java
index e34fab1..2a781f0 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockRecordCollector.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockRecordCollector.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.test;
 
-import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.header.Headers;
@@ -36,9 +35,6 @@ public class MockRecordCollector implements RecordCollector {
     // remember all records that are collected so far
     private final List<ProducerRecord<Object, Object>> collected = new LinkedList<>();
 
-    // remember all commits that are submitted so far
-    private final List<Map<TopicPartition, OffsetAndMetadata>> committed = new LinkedList<>();
-
     // remember if flushed is called
     private boolean flushed = false;
 
@@ -80,11 +76,6 @@ public class MockRecordCollector implements RecordCollector {
     public void initialize() {}
 
     @Override
-    public void commit(final Map<TopicPartition, OffsetAndMetadata> offsets) {
-        committed.add(offsets);
-    }
-
-    @Override
     public void flush() {
         flushed = true;
     }
@@ -101,10 +92,6 @@ public class MockRecordCollector implements RecordCollector {
         return unmodifiableList(collected);
     }
 
-    public List<Map<TopicPartition, OffsetAndMetadata>> committed() {
-        return unmodifiableList(committed);
-    }
-
     public boolean flushed() {
         return flushed;
     }
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
index 65eb17a..d0ee43a 100644
--- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
@@ -19,6 +19,7 @@ package org.apache.kafka.streams;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.MockConsumer;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.clients.producer.MockProducer;
 import org.apache.kafka.clients.producer.Producer;
@@ -65,7 +66,7 @@ import org.apache.kafka.streams.processor.internals.StateDirectory;
 import org.apache.kafka.streams.processor.internals.StoreChangelogReader;
 import org.apache.kafka.streams.processor.internals.StreamTask;
 import org.apache.kafka.streams.processor.internals.Task;
-import org.apache.kafka.streams.processor.internals.StreamsProducer;
+import org.apache.kafka.streams.processor.internals.TestDriverProducer;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
 import org.apache.kafka.streams.state.KeyValueStore;
@@ -213,7 +214,9 @@ public class TopologyTestDriver implements Closeable {
     ProcessorTopology processorTopology;
     ProcessorTopology globalTopology;
 
+    private final MockConsumer<byte[], byte[]> consumer;
     private final MockProducer<byte[], byte[]> producer;
+    private final TestDriverProducer testDriverProducer;
 
     private final Map<String, TopicPartition> partitionsByInputTopic = new HashMap<>();
     private final Map<String, TopicPartition> globalPartitionsByInputTopic = new HashMap<>();
@@ -300,8 +303,10 @@ public class TopologyTestDriver implements Closeable {
         final ThreadCache cache = new ThreadCache(
             logContext,
             Math.max(0, streamsConfig.getLong(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG)),
-            streamsMetrics);
+            streamsMetrics
+        );
 
+        consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
         final Serializer<byte[]> bytesSerializer = new ByteArraySerializer();
         producer = new MockProducer<byte[], byte[]>(true, bytesSerializer, bytesSerializer) {
             @Override
@@ -309,6 +314,12 @@ public class TopologyTestDriver implements Closeable {
                 return Collections.singletonList(new PartitionInfo(topic, PARTITION_ID, null, null, null));
             }
         };
+        testDriverProducer = new TestDriverProducer(
+            producer,
+            eosEnabled,
+            streamsConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG),
+            logContext
+        );
 
         setupGlobalTask(streamsConfig, streamsMetrics, cache);
         setupTask(streamsConfig, streamsMetrics, cache);
@@ -388,7 +399,8 @@ public class TopologyTestDriver implements Closeable {
                 globalConsumer,
                 stateDirectory,
                 stateRestoreListener,
-                streamsConfig);
+                streamsConfig
+            );
 
             final GlobalProcessorContextImpl globalProcessorContext =
                 new GlobalProcessorContextImpl(streamsConfig, globalStateManager, streamsMetrics, cache);
@@ -407,7 +419,8 @@ public class TopologyTestDriver implements Closeable {
                 -1L,
                 -1,
                 ProcessorContextImpl.NONEXIST_TOPIC,
-                new RecordHeaders()));
+                new RecordHeaders())
+            );
         } else {
             globalStateManager = null;
             globalStateTask = null;
@@ -418,7 +431,6 @@ public class TopologyTestDriver implements Closeable {
                            final StreamsMetricsImpl streamsMetrics,
                            final ThreadCache cache) {
         if (!partitionsByInputTopic.isEmpty()) {
-            final MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
             consumer.assign(partitionsByInputTopic.values());
             final Map<TopicPartition, Long> startOffsets = new HashMap<>();
             for (final TopicPartition topicPartition : partitionsByInputTopic.values()) {
@@ -439,15 +451,15 @@ public class TopologyTestDriver implements Closeable {
                     createRestoreConsumer(processorTopology.storeToChangelogTopic()),
                     stateRestoreListener),
                 processorTopology.storeToChangelogTopic(),
-                new HashSet<>(partitionsByInputTopic.values()));
+                new HashSet<>(partitionsByInputTopic.values())
+            );
             final RecordCollector recordCollector = new RecordCollectorImpl(
                 logContext,
                 TASK_ID,
-                consumer,
-                new StreamsProducer(producer, eosEnabled, logContext, streamsConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG)),
+                testDriverProducer,
                 streamsConfig.defaultProductionExceptionHandler(),
-                eosEnabled,
-                streamsMetrics);
+                streamsMetrics
+            );
             task = new StreamTask(
                 TASK_ID,
                 new HashSet<>(partitionsByInputTopic.values()),
@@ -459,7 +471,8 @@ public class TopologyTestDriver implements Closeable {
                 cache,
                 mockWallClockTime,
                 stateManager,
-                recordCollector);
+                recordCollector
+            );
             task.initializeIfNeeded();
             task.completeRestoration();
             ((InternalProcessorContext) task.context()).setRecordContext(new ProcessorRecordContext(
@@ -467,7 +480,8 @@ public class TopologyTestDriver implements Closeable {
                 -1L,
                 -1,
                 ProcessorContextImpl.NONEXIST_TOPIC,
-                new RecordHeaders()));
+                new RecordHeaders())
+            );
         } else {
             task = null;
         }
@@ -497,7 +511,8 @@ public class TopologyTestDriver implements Closeable {
             consumerRecord.timestamp(),
             consumerRecord.key(),
             consumerRecord.value(),
-            consumerRecord.headers());
+            consumerRecord.headers()
+        );
     }
 
     private void pipeRecord(final String topicName,
@@ -539,7 +554,8 @@ public class TopologyTestDriver implements Closeable {
             value == null ? ConsumerRecord.NULL_SIZE : value.length,
             key,
             value,
-            headers)));
+            headers))
+        );
     }
 
     private void completeAllProcessableWork() {
@@ -557,7 +573,8 @@ public class TopologyTestDriver implements Closeable {
                 // Process the record ...
                 task.process(mockWallClockTime.milliseconds());
                 task.maybePunctuateStreamTime();
-                task.commit();
+                task.prepareCommit();
+                commit(task.committableOffsetsAndMetadata());
                 captureOutputsAndReEnqueueInternalResults();
             }
             if (task.hasRecordsQueued()) {
@@ -570,6 +587,14 @@ public class TopologyTestDriver implements Closeable {
         }
     }
 
+    private void commit(final Map<TopicPartition, OffsetAndMetadata> offsets) {
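+        // mirror the runtime commit path: with EOS the offsets are committed through
+        // the transactional producer, otherwise through the consumer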
+        if (eosEnabled) {
+            testDriverProducer.commitTransaction(offsets);
+        } else {
+            consumer.commitSync(offsets);
+        }
+    }
+
     private void processGlobalRecord(final TopicPartition globalInputTopicPartition,
                                      final long timestamp,
                                      final byte[] key,
@@ -586,7 +611,8 @@ public class TopologyTestDriver implements Closeable {
             value == null ? ConsumerRecord.NULL_SIZE : value.length,
             key,
             value,
-            headers));
+            headers)
+        );
         globalStateTask.flushState();
     }
 
@@ -637,7 +663,8 @@ public class TopologyTestDriver implements Closeable {
                     record.timestamp(),
                     record.key(),
                     record.value(),
-                    record.headers());
+                    record.headers()
+                );
             }
 
             if (globalInputTopicPartition != null) {
@@ -646,7 +673,8 @@ public class TopologyTestDriver implements Closeable {
                     record.timestamp(),
                     record.key(),
                     record.value(),
-                    record.headers());
+                    record.headers()
+                );
             }
         }
     }
@@ -691,7 +719,8 @@ public class TopologyTestDriver implements Closeable {
         mockWallClockTime.sleep(advance.toMillis());
         if (task != null) {
             task.maybePunctuateSystemTime();
-            task.commit();
+            task.prepareCommit();
+            commit(task.committableOffsetsAndMetadata());
         }
         completeAllProcessableWork();
     }
@@ -1096,7 +1125,8 @@ public class TopologyTestDriver implements Closeable {
      */
     public void close() {
         if (task != null) {
-            task.closeClean();
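+            // clean close is now two steps: prepare returns the checkpoint offsets to write on close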
+            final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
+            task.closeClean(checkpoint);
         }
         if (globalStateTask != null) {
             try {
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/internals/TestDriverProducer.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/internals/TestDriverProducer.java
new file mode 100644
index 0000000..a33b519
--- /dev/null
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/internals/TestDriverProducer.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.ProducerFencedException;
+import org.apache.kafka.common.utils.LogContext;
+
+import java.util.Map;
+
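+/**
+ * Helper class that exposes {@code StreamsProducer#commitTransaction(Map)} with
+ * {@code public} visibility, so that {@code TopologyTestDriver} can commit offsets
+ * transactionally when EOS is enabled.
+ */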
+public class TestDriverProducer extends StreamsProducer {
+
+    public TestDriverProducer(final Producer<byte[], byte[]> producer,
+                              final boolean eosEnabled,
+                              final String applicationId,
+                              final LogContext logContext) {
+        super(producer, eosEnabled, applicationId, logContext);
+    }
+
+    @Override
+    public void commitTransaction(final Map<TopicPartition, OffsetAndMetadata> offsets) throws ProducerFencedException {
+        super.commitTransaction(offsets);
+    }
+}