Posted to commits@kafka.apache.org by gu...@apache.org on 2020/06/16 23:31:27 UTC

[kafka] branch trunk updated: KAFKA-10150: task state transitions/management and committing cleanup (#8856)

This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 2239004  KAFKA-10150: task state transitions/management and committing cleanup (#8856)
2239004 is described below

commit 2239004907b29e00811fee9ded5a790172701a03
Author: A. Sophie Blee-Goldman <so...@confluent.io>
AuthorDate: Tue Jun 16 16:30:37 2020 -0700

    KAFKA-10150: task state transitions/management and committing cleanup (#8856)
    
    * KAFKA-10150: always transition to SUSPENDED during suspend, no matter the current state; only call prepareCommit before closing if task.commitNeeded is true
    
    * Don't commit any consumed offsets during handleAssignment -- revoked active tasks (and any others that need committing) will be committed during handleRevocation, so we only need to worry about cleaning them up in handleAssignment
    
    * KAFKA-10152: when recycling a task we should always commit consumed offsets (if any), but don't need to write the checkpoint (since changelog offsets are preserved across task transitions)
    
    * Make sure we close all tasks during shutdown, even if an exception is thrown during commit
    
    Reviewers: Matthias J. Sax <ma...@confluent.io>, Guozhang Wang <wa...@gmail.com>
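
The bullets above boil down to one caller-side ordering for shutting a task down cleanly: suspend it first, commit only if it actually has progress to commit, and only then close or recycle it. The following is a minimal, hypothetical sketch of that ordering; SketchTask and closeCleanly are illustrative stand-ins, not part of the Task interface or TaskManager code changed by this commit.

import java.util.Map;

// Illustrative sketch only: mirrors the suspend -> (prepareCommit/postCommit if
// commitNeeded) -> close ordering that this commit enforces.
final class TaskCloseSketch {

    interface SketchTask {
        void suspend();                    // always lands in SUSPENDED, whatever the current state
        boolean commitNeeded();            // true only if the task has processed data since the last commit
        Map<String, Long> prepareCommit(); // stand-in for the consumed offsets to be committed
        void postCommit();                 // only called once the commit has succeeded; may write a checkpoint
        void closeClean();                 // rejects the close if uncommitted data is still pending
    }

    static void closeCleanly(final SketchTask task) {
        task.suspend();
        if (task.commitNeeded()) {
            final Map<String, Long> offsets = task.prepareCommit();
            // ... commit `offsets` via the consumer or a transactional producer ...
            task.postCommit();
        }
        task.closeClean();
    }
}
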
---
 .../streams/processor/internals/StandbyTask.java   |  31 ++-
 .../streams/processor/internals/StreamTask.java    |  97 +++----
 .../kafka/streams/processor/internals/Task.java    |  22 +-
 .../streams/processor/internals/TaskManager.java   | 241 +++++++++--------
 .../processor/internals/StandbyTaskTest.java       |  57 +++-
 .../processor/internals/StreamTaskTest.java        |  69 +++--
 .../processor/internals/TaskManagerTest.java       | 296 ++++++++-------------
 7 files changed, 402 insertions(+), 411 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
index 1b069d6..5df59f6 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
@@ -111,11 +111,25 @@ public class StandbyTask extends AbstractTask implements Task {
 
     @Override
     public void suspend() {
-        log.trace("No-op suspend with state {}", state());
-        if (state() == State.RUNNING) {
-            transitionTo(State.SUSPENDED);
-        } else if (state() == State.RESTORING) {
-            throw new IllegalStateException("Illegal state " + state() + " while suspending standby task " + id);
+        switch (state()) {
+            case CREATED:
+            case RUNNING:
+                log.info("Suspended {}", state());
+                transitionTo(State.SUSPENDED);
+
+                break;
+
+            case SUSPENDED:
+                log.info("Skip suspending since state is {}", state());
+
+                break;
+
+            case RESTORING:
+            case CLOSED:
+                throw new IllegalStateException("Illegal state " + state() + " while suspending standby task " + id);
+
+            default:
+                throw new IllegalStateException("Unknown state " + state() + " while suspending standby task " + id);
         }
     }
 
@@ -172,10 +186,7 @@ public class StandbyTask extends AbstractTask implements Task {
 
     @Override
     public void closeAndRecycleState() {
-        suspend();
-        prepareCommit();
-
-        if (state() == State.CREATED || state() == State.SUSPENDED) {
+        if (state() == State.SUSPENDED) {
             stateMgr.recycle();
         } else {
             throw new IllegalStateException("Illegal state " + state() + " while closing standby task " + id);
@@ -189,7 +200,6 @@ public class StandbyTask extends AbstractTask implements Task {
 
     private void close(final boolean clean) {
         switch (state()) {
-            case CREATED:
             case SUSPENDED:
                 executeAndMaybeSwallow(
                     clean,
@@ -212,6 +222,7 @@ public class StandbyTask extends AbstractTask implements Task {
                 log.trace("Skip closing since state is {}", state());
                 return;
 
+            case CREATED:
             case RESTORING: // a StandbyTask is never in RESTORING state
             case RUNNING:
                 throw new IllegalStateException("Illegal state " + state() + " while closing standby task " + id);
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index 7f08643..fa8b94b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -107,8 +107,6 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     private boolean commitNeeded = false;
     private boolean commitRequested = false;
 
-    private Map<TopicPartition, Long> checkpoint = null;
-
     public StreamTask(final TaskId id,
                       final Set<TopicPartition> partitions,
                       final ProcessorTopology topology,
@@ -250,14 +248,9 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     public void suspend() {
         switch (state()) {
             case CREATED:
-            case SUSPENDED:
-                log.info("Skip suspending since state is {}", state());
-
-                break;
-
             case RESTORING:
+                log.info("Suspended {}", state());
                 transitionTo(State.SUSPENDED);
-                log.info("Suspended restoring");
 
                 break;
 
@@ -272,6 +265,12 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
 
                 break;
 
+            case SUSPENDED:
+                log.info("Skip suspending since state is {}", state());
+
+                break;
+
+
             case CLOSED:
                 throw new IllegalStateException("Illegal state " + state() + " while suspending active task " + id);
 
@@ -342,7 +341,6 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
             case RUNNING:
             case RESTORING:
             case SUSPENDED:
-                maybeScheduleCheckpoint();
                 stateMgr.flush();
                 recordCollector.flush();
 
@@ -409,6 +407,9 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         return committableOffsets;
     }
 
+    /**
+     * This should only be called if the attempted commit succeeded for this task
+     */
     @Override
     public void postCommit() {
         commitRequested = false;
@@ -416,23 +417,28 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
 
         switch (state()) {
             case RESTORING:
-                writeCheckpointIfNeed();
+                writeCheckpoint();
 
                 break;
 
             case RUNNING:
-                if (!eosEnabled) { // if RUNNING, checkpoint only for non-eos
-                    writeCheckpointIfNeed();
+                if (!eosEnabled) {
+                    writeCheckpoint();
                 }
 
                 break;
 
             case SUSPENDED:
-                writeCheckpointIfNeed();
-                // we cannot `clear()` the `PartitionGroup` in `suspend()` already, but only after committing,
-                // because otherwise we loose the partition-time information
+                /*
+                 * We must clear the `PartitionGroup` only after committing, and not in `suspend()`,
+                 * because otherwise we lose the partition-time information.
+                 * We also must clear it when the task is revoked, and not in `close()`, as the consumer will clear
+                 * its internal buffer when the corresponding partition is revoked but the task may be reassigned
+                 */
                 partitionGroup.clear();
 
+                writeCheckpoint();
+
                 break;
 
             case CREATED:
@@ -474,27 +480,22 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
 
     @Override
     public void closeAndRecycleState() {
-        suspend();
-        prepareCommit();
-        writeCheckpointIfNeed();
-
         switch (state()) {
-            case CREATED:
             case SUSPENDED:
                 stateMgr.recycle();
                 recordCollector.close();
 
                 break;
 
-            case RESTORING: // we should have transitioned to `SUSPENDED` already
-            case RUNNING: // we should have transitioned to `SUSPENDED` already
+            case CREATED:
+            case RESTORING:
+            case RUNNING:
             case CLOSED:
                 throw new IllegalStateException("Illegal state " + state() + " while recycling active task " + id);
             default:
                 throw new IllegalStateException("Unknown state " + state() + " while recycling active task " + id);
         }
 
-        partitionGroup.clear();
         closeTaskSensor.record();
 
         transitionTo(State.CLOSED);
@@ -502,56 +503,24 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         log.info("Closed clean and recycled state");
     }
 
-    private void maybeScheduleCheckpoint() {
-        switch (state()) {
-            case RESTORING:
-            case SUSPENDED:
-                this.checkpoint = checkpointableOffsets();
-
-                break;
-
-            case RUNNING:
-                if (!eosEnabled) {
-                    this.checkpoint = checkpointableOffsets();
-                }
-
-                break;
-
-            case CREATED:
-            case CLOSED:
-                throw new IllegalStateException("Illegal state " + state() + " while scheduling checkpoint for active task " + id);
-
-            default:
-                throw new IllegalStateException("Unknown state " + state() + " while scheduling checkpoint for active task " + id);
-        }
-    }
-
-    private void writeCheckpointIfNeed() {
+    private void writeCheckpoint() {
         if (commitNeeded) {
+            log.error("Tried to write a checkpoint with pending uncommitted data, should complete the commit first.");
             throw new IllegalStateException("A checkpoint should only be written if no commit is needed.");
         }
-        if (checkpoint != null) {
-            stateMgr.checkpoint(checkpoint);
-            checkpoint = null;
-        }
+        stateMgr.checkpoint(checkpointableOffsets());
     }
 
     /**
-     * <pre>
-     * the following order must be followed:
-     *  1. checkpoint the state manager -- even if we crash before this step, EOS is still guaranteed
-     *  2. then if we are closing on EOS and dirty, wipe out the state store directory
-     *  3. finally release the state manager lock
-     * </pre>
+     * You must commit a task and checkpoint the state manager before closing as this will release the state dir lock
      */
     private void close(final boolean clean) {
-        if (clean) {
-            executeAndMaybeSwallow(true, this::writeCheckpointIfNeed, "state manager checkpoint", log);
+        if (clean && commitNeeded) {
+            log.debug("Tried to close clean but there was pending uncommitted data, this means we failed to"
+                          + " commit and should close as dirty instead");
+            throw new StreamsException("Tried to close dirty task as clean");
         }
-
         switch (state()) {
-            case CREATED:
-            case RESTORING:
             case SUSPENDED:
                 // first close state manager (which is idempotent) then close the record collector
                 // if the latter throws and we re-close dirty which would close the state manager again.
@@ -577,6 +546,8 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
                 log.trace("Skip closing since state is {}", state());
                 return;
 
+            case CREATED:
+            case RESTORING:
             case RUNNING:
                 throw new IllegalStateException("Illegal state " + state() + " while closing active task " + id);
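
The StreamTask hunks above route all checkpoint writing through postCommit(), keyed off the task state. A hypothetical summary of that decision follows; shouldWriteCheckpoint is an illustrative helper, not a method introduced by this commit.

// Sketch of the state-dependent checkpointing in postCommit() above.
final class CheckpointRuleSketch {

    enum SketchState { CREATED, RESTORING, RUNNING, SUSPENDED, CLOSED }

    static boolean shouldWriteCheckpoint(final SketchState state, final boolean eosEnabled) {
        switch (state) {
            case RESTORING:
            case SUSPENDED:
                return true;        // checkpoint once the commit has succeeded
            case RUNNING:
                return !eosEnabled; // with eos enabled, RUNNING tasks skip the checkpoint
            default:
                return false;       // CREATED/CLOSED: postCommit() is illegal in these states
        }
    }
}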
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
index 62332c7..0200870 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
@@ -56,21 +56,21 @@ public interface Task {
      *          |            |              |     |
      *          |            v              |     |
      *          |     +------+--------+     |     |
-     *          |     | Suspended (3) | <---+     |    //TODO Suspended(3) could be removed after we've stable on KIP-429
-     *          |     +------+--------+           |
-     *          |            |                    |
-     *          |            v                    |
-     *          |      +-----+-------+            |
-     *          +----> | Closed (4)  | -----------+
+     *          +---> | Suspended (3) | ----+     |    //TODO Suspended(3) could be removed after we've stable on KIP-429
+     *                +------+--------+           |
+     *                       |                    |
+     *                       v                    |
+     *                 +-----+-------+            |
+     *                 | Closed (4)  | -----------+
      *                 +-------------+
      * </pre>
      */
     enum State {
-        CREATED(1, 4),         // 0
-        RESTORING(2, 3, 4),    // 1
-        RUNNING(3),            // 2
-        SUSPENDED(1, 4),       // 3
-        CLOSED(0);             // 4, we allow CLOSED to transit to CREATED to handle corrupted tasks
+        CREATED(1, 3),            // 0
+        RESTORING(2, 3),          // 1
+        RUNNING(3),               // 2
+        SUSPENDED(1, 4),          // 3
+        CLOSED(0);                // 4, we allow CLOSED to transit to CREATED to handle corrupted tasks
 
         private final Set<Integer> validTransitions = new HashSet<>();
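
The tightened transition table above reads as: CREATED -> {RESTORING, SUSPENDED}, RESTORING -> {RUNNING, SUSPENDED}, RUNNING -> {SUSPENDED}, SUSPENDED -> {RESTORING, CLOSED}, and CLOSED -> {CREATED} for corrupted-task recovery; notably, no state may jump straight to CLOSED without passing through SUSPENDED. The ordinals are presumably checked along these lines (a sketch, not the Task.java code itself):

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Sketch of ordinal-based transition checking, assuming each state records the
// ordinals of the states it may legally move to.
enum SketchTaskState {
    CREATED(1, 3),     // -> RESTORING, SUSPENDED
    RESTORING(2, 3),   // -> RUNNING, SUSPENDED
    RUNNING(3),        // -> SUSPENDED
    SUSPENDED(1, 4),   // -> RESTORING, CLOSED
    CLOSED(0);         // -> CREATED (corrupted-task recovery)

    private final Set<Integer> validTransitions = new HashSet<>();

    SketchTaskState(final Integer... validTransitions) {
        this.validTransitions.addAll(Arrays.asList(validTransitions));
    }

    boolean isValidTransition(final SketchTaskState newState) {
        return validTransitions.contains(newState.ordinal());
    }
}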
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 689be9b..92885fd 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -215,91 +215,54 @@ public class TaskManager {
                      "\tExisting standby tasks: {}",
                  activeTasks.keySet(), standbyTasks.keySet(), activeTaskIds(), standbyTaskIds());
 
-        final Map<TaskId, Set<TopicPartition>> activeTasksToCreate = new HashMap<>(activeTasks);
-        final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate = new HashMap<>(standbyTasks);
-        final Set<Task> tasksToRecycle = new HashSet<>();
-
         builder.addSubscribedTopicsFromAssignment(
             activeTasks.values().stream().flatMap(Collection::stream).collect(Collectors.toList()),
             logPrefix
         );
 
-        // first rectify all existing tasks
         final LinkedHashMap<TaskId, RuntimeException> taskCloseExceptions = new LinkedHashMap<>();
 
-        final Set<Task> tasksToClose = new HashSet<>();
-        final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
-        final Set<Task> additionalTasksForCommitting = new HashSet<>();
+        final Map<TaskId, Set<TopicPartition>> activeTasksToCreate = new HashMap<>(activeTasks);
+        final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate = new HashMap<>(standbyTasks);
+        final List<Task> tasksToClose = new LinkedList<>();
+        final Set<Task> tasksToRecycle = new HashSet<>();
         final Set<Task> dirtyTasks = new HashSet<>();
 
+        // first rectify all existing tasks
         for (final Task task : tasks.values()) {
             if (activeTasks.containsKey(task.id()) && task.isActive()) {
                 updateInputPartitionsAndResume(task, activeTasks.get(task.id()));
-                if (task.commitNeeded()) {
-                    additionalTasksForCommitting.add(task);
-                }
                 activeTasksToCreate.remove(task.id());
             } else if (standbyTasks.containsKey(task.id()) && !task.isActive()) {
                 updateInputPartitionsAndResume(task, standbyTasks.get(task.id()));
                 standbyTasksToCreate.remove(task.id());
-                // check for tasks that were owned previously but have changed active/standby status
             } else if (activeTasks.containsKey(task.id()) || standbyTasks.containsKey(task.id())) {
+                // check for tasks that were owned previously but have changed active/standby status
                 tasksToRecycle.add(task);
             } else {
-                try {
-                    task.suspend();
-                    final Map<TopicPartition, OffsetAndMetadata> committableOffsets = task.prepareCommit();
-
-                    tasksToClose.add(task);
-                    if (!committableOffsets.isEmpty()) {
-                        consumedOffsetsAndMetadataPerTask.put(task.id(), committableOffsets);
-                    }
-                } catch (final RuntimeException e) {
-                    final String uncleanMessage = String.format(
-                        "Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:",
-                        task.id());
-                    log.error(uncleanMessage, e);
-                    taskCloseExceptions.put(task.id(), e);
-                    // We've already recorded the exception (which is the point of clean).
-                    // Now, we should go ahead and complete the close because a half-closed task is no good to anyone.
-                    dirtyTasks.add(task);
-                }
+                tasksToClose.add(task);
             }
         }
 
-        if (!consumedOffsetsAndMetadataPerTask.isEmpty()) {
+        for (final Task task : tasksToClose) {
             try {
-                for (final Task task : additionalTasksForCommitting) {
-                    final Map<TopicPartition, OffsetAndMetadata> committableOffsets = task.prepareCommit();
-                    if (!committableOffsets.isEmpty()) {
-                        consumedOffsetsAndMetadataPerTask.put(task.id(), committableOffsets);
+                task.suspend(); // Should be a no-op for active tasks since they're suspended in handleRevocation
+                if (task.commitNeeded()) {
+                    if (task.isActive()) {
+                        log.error("Active task {} was revoked and should have already been committed", task.id());
+                        throw new IllegalStateException("Revoked active task was not committed during handleRevocation");
+                    } else {
+                        task.prepareCommit();
+                        task.postCommit();
                     }
                 }
-
-                commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
-
-                for (final Task task : additionalTasksForCommitting) {
-                    task.postCommit();
-                }
-            } catch (final RuntimeException e) {
-                log.error("Failed to batch commit tasks, " +
-                    "will close all tasks involved in this commit as dirty by the end", e);
-                dirtyTasks.addAll(additionalTasksForCommitting);
-                dirtyTasks.addAll(tasksToClose);
-
-                tasksToClose.clear();
-                // Just add first taskId to re-throw by the end.
-                taskCloseExceptions.put(consumedOffsetsAndMetadataPerTask.keySet().iterator().next(), e);
-            }
-        }
-
-        for (final Task task : tasksToClose) {
-            try {
                 completeTaskCloseClean(task);
                 cleanUpTaskProducer(task, taskCloseExceptions);
                 tasks.remove(task.id());
             } catch (final RuntimeException e) {
-                final String uncleanMessage = String.format("Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:", task.id());
+                final String uncleanMessage = String.format(
+                    "Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:",
+                    task.id());
                 log.error(uncleanMessage, e);
                 taskCloseExceptions.put(task.id(), e);
                 // We've already recorded the exception (which is the point of clean).
@@ -315,6 +278,7 @@ public class TaskManager {
                     final Set<TopicPartition> partitions = standbyTasksToCreate.remove(oldTask.id());
                     newTask = standbyTaskCreator.createStandbyTaskFromActive((StreamTask) oldTask, partitions);
                 } else {
+                    oldTask.suspend(); // Only need to suspend transitioning standbys, actives should be suspended already
                     final Set<TopicPartition> partitions = activeTasksToCreate.remove(oldTask.id());
                     newTask = activeTaskCreator.createActiveTaskFromStandby((StandbyTask) oldTask, partitions, mainConsumer);
                 }
@@ -465,44 +429,76 @@ public class TaskManager {
     }
 
     /**
+     * Handle the revoked partitions and prepare for closing the associated tasks in {@link #handleAssignment(Map, Map)}
+     * We should commit the revoked tasks now as we will not officially own them anymore when {@link #handleAssignment(Map, Map)}
+     * is called. Note that only active task partitions are passed in from the rebalance listener, so we only need to
+     * consider/commit active tasks here
+     *
+     * If eos-beta is used, we must commit ALL tasks. Otherwise, we can just commit those (active) tasks which are revoked
+     *
      * @throws TaskMigratedException if the task producer got fenced (EOS only)
      */
     void handleRevocation(final Collection<TopicPartition> revokedPartitions) {
-        final Set<TopicPartition> remainingPartitions = new HashSet<>(revokedPartitions);
-
-        final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
-        for (final Task task : tasks.values()) {
-            if (remainingPartitions.containsAll(task.inputPartitions())) {
-                task.suspend();
-                final Map<TopicPartition, OffsetAndMetadata> committableOffsets = task.prepareCommit();
+        final Set<TopicPartition> remainingRevokedPartitions = new HashSet<>(revokedPartitions);
 
-                if (!committableOffsets.isEmpty()) {
-                    consumedOffsetsAndMetadataPerTask.put(task.id(), committableOffsets);
-                }
-            } else if (task.isActive() && task.commitNeeded()) {
-                final Map<TopicPartition, OffsetAndMetadata> committableOffsets = task.prepareCommit();
+        final Set<Task> tasksToCommit = new HashSet<>();
+        final Set<Task> additionalTasksForCommitting = new HashSet<>();
 
-                if (!committableOffsets.isEmpty()) {
-                    consumedOffsetsAndMetadataPerTask.put(task.id(), committableOffsets);
+        final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
+        for (final Task task : activeTaskIterable()) {
+            if (remainingRevokedPartitions.containsAll(task.inputPartitions())) {
+                try {
+                    task.suspend();
+                    if (task.commitNeeded()) {
+                        tasksToCommit.add(task);
+                    }
+                } catch (final RuntimeException e) {
+                    log.error("Caught the following exception while trying to suspend revoked task " + task.id(), e);
+                    firstException.compareAndSet(null, new StreamsException("Failed to suspend " + task.id(), e));
                 }
+            } else if (task.commitNeeded()) {
+                additionalTasksForCommitting.add(task);
             }
-            remainingPartitions.removeAll(task.inputPartitions());
+            remainingRevokedPartitions.removeAll(task.inputPartitions());
         }
 
-        if (!consumedOffsetsAndMetadataPerTask.isEmpty()) {
-            commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
+        if (!remainingRevokedPartitions.isEmpty()) {
+            log.warn("The following partitions {} are missing from the task partitions. It could potentially " +
+                         "due to race condition of consumer detecting the heartbeat failure, or the tasks " +
+                         "have been cleaned up by the handleAssignment callback.", remainingRevokedPartitions);
         }
 
-        for (final Task task : tasks.values()) {
-            if (consumedOffsetsAndMetadataPerTask.containsKey(task.id())) {
+        final RuntimeException suspendException = firstException.get();
+        if (suspendException != null) {
+            throw suspendException;
+        }
+
+        // If using eos-beta, if we must commit any task then we must commit all of them
+        // TODO: when KAFKA-9450 is done this will be less expensive, and we can simplify by always committing everything
+        if (processingMode ==  EXACTLY_ONCE_BETA && !tasksToCommit.isEmpty()) {
+            tasksToCommit.addAll(additionalTasksForCommitting);
+        }
+
+        final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
+        for (final Task task : tasksToCommit) {
+            final Map<TopicPartition, OffsetAndMetadata> committableOffsets = task.prepareCommit();
+            consumedOffsetsAndMetadataPerTask.put(task.id(), committableOffsets);
+        }
+
+        commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
+
+        for (final Task task : tasksToCommit) {
+            try {
                 task.postCommit();
+            } catch (final RuntimeException e) {
+                log.error("Exception caught while post-committing task " + task.id(), e);
+                firstException.compareAndSet(null, e);
             }
         }
 
-        if (!remainingPartitions.isEmpty()) {
-            log.warn("The following partitions {} are missing from the task partitions. It could potentially " +
-                         "due to race condition of consumer detecting the heartbeat failure, or the tasks " +
-                         "have been cleaned up by the handleAssignment callback.", remainingPartitions);
+        final RuntimeException commitException = firstException.get();
+        if (commitException != null) {
+            throw commitException;
         }
     }
 
@@ -690,18 +686,21 @@ public class TaskManager {
         final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
 
         final Set<Task> tasksToClose = new HashSet<>();
+        final Set<Task> tasksToCommit = new HashSet<>();
         final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
 
         for (final Task task : tasks.values()) {
             if (clean) {
                 try {
                     task.suspend();
-                    final Map<TopicPartition, OffsetAndMetadata> committableOffsets = task.prepareCommit();
-
-                    tasksToClose.add(task);
-                    if (!committableOffsets.isEmpty()) {
-                        consumedOffsetsAndMetadataPerTask.put(task.id(), committableOffsets);
+                    if (task.commitNeeded()) {
+                        tasksToCommit.add(task);
+                        final Map<TopicPartition, OffsetAndMetadata> committableOffsets = task.prepareCommit();
+                        if (task.isActive()) {
+                            consumedOffsetsAndMetadataPerTask.put(task.id(), committableOffsets);
+                        }
                     }
+                    tasksToClose.add(task);
                 } catch (final TaskMigratedException e) {
                     // just ignore the exception as it doesn't matter during shutdown
                     closeTaskDirty(task);
@@ -714,13 +713,25 @@ public class TaskManager {
             }
         }
 
-        if (clean && !consumedOffsetsAndMetadataPerTask.isEmpty()) {
-            commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
+        try {
+            if (clean) {
+                commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
+                for (final Task task : tasksToCommit) {
+                    try {
+                        task.postCommit();
+                    } catch (final RuntimeException e) {
+                        log.error("Exception caught while post-committing task " + task.id(), e);
+                        firstException.compareAndSet(null, e);
+                    }
+                }
+            }
+        } catch (final RuntimeException e) {
+            log.error("Exception caught while committing tasks during shutdown", e);
+            firstException.compareAndSet(null, e);
         }
 
         for (final Task task : tasksToClose) {
             try {
-                task.postCommit();
                 completeTaskCloseClean(task);
             } catch (final RuntimeException e) {
                 firstException.compareAndSet(null, e);
@@ -835,26 +846,24 @@ public class TaskManager {
      *                               or if the task producer got fenced (EOS)
      * @return number of committed offsets, or -1 if we are in the middle of a rebalance and cannot commit
      */
-    int commit(final Collection<Task> tasks) {
+    int commit(final Collection<Task> tasksToCommit) {
         if (rebalanceInProgress) {
             return -1;
         } else {
             int committed = 0;
             final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadataPerTask = new HashMap<>();
-            for (final Task task : tasks) {
+            for (final Task task : tasksToCommit) {
                 if (task.commitNeeded()) {
                     final Map<TopicPartition, OffsetAndMetadata> offsetAndMetadata = task.prepareCommit();
-                    if (!offsetAndMetadata.isEmpty()) {
+                    if (task.isActive()) {
                         consumedOffsetsAndMetadataPerTask.put(task.id(), offsetAndMetadata);
                     }
                 }
             }
 
-            if (!consumedOffsetsAndMetadataPerTask.isEmpty()) {
-                commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
-            }
+            commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask);
 
-            for (final Task task : tasks) {
+            for (final Task task : tasksToCommit) {
                 if (task.commitNeeded()) {
                     ++committed;
                     task.postCommit();
@@ -883,28 +892,30 @@ public class TaskManager {
     }
 
     private void commitOffsetsOrTransaction(final Map<TaskId, Map<TopicPartition, OffsetAndMetadata>> offsetsPerTask) {
-        if (processingMode == EXACTLY_ONCE_ALPHA) {
-            for (final Map.Entry<TaskId, Map<TopicPartition, OffsetAndMetadata>> taskToCommit : offsetsPerTask.entrySet()) {
-                activeTaskCreator.streamsProducerForTask(taskToCommit.getKey())
-                    .commitTransaction(taskToCommit.getValue(), mainConsumer.groupMetadata());
-            }
-        } else {
-            final Map<TopicPartition, OffsetAndMetadata> allOffsets = offsetsPerTask.values().stream()
-                .flatMap(e -> e.entrySet().stream()).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
-
-            if (processingMode == EXACTLY_ONCE_BETA) {
-                activeTaskCreator.threadProducer().commitTransaction(allOffsets, mainConsumer.groupMetadata());
+        if (!offsetsPerTask.isEmpty()) {
+            if (processingMode == EXACTLY_ONCE_ALPHA) {
+                for (final Map.Entry<TaskId, Map<TopicPartition, OffsetAndMetadata>> taskToCommit : offsetsPerTask.entrySet()) {
+                    activeTaskCreator.streamsProducerForTask(taskToCommit.getKey())
+                        .commitTransaction(taskToCommit.getValue(), mainConsumer.groupMetadata());
+                }
             } else {
-                try {
-                    mainConsumer.commitSync(allOffsets);
-                } catch (final CommitFailedException error) {
-                    throw new TaskMigratedException("Consumer committing offsets failed, " +
-                        "indicating the corresponding thread is no longer part of the group", error);
-                } catch (final TimeoutException error) {
-                    // TODO KIP-447: we can consider treating it as non-fatal and retry on the thread level
-                    throw new StreamsException("Timed out while committing offsets via consumer", error);
-                } catch (final KafkaException error) {
-                    throw new StreamsException("Error encountered committing offsets via consumer", error);
+                final Map<TopicPartition, OffsetAndMetadata> allOffsets = offsetsPerTask.values().stream()
+                    .flatMap(e -> e.entrySet().stream()).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+                if (processingMode == EXACTLY_ONCE_BETA) {
+                    activeTaskCreator.threadProducer().commitTransaction(allOffsets, mainConsumer.groupMetadata());
+                } else {
+                    try {
+                        mainConsumer.commitSync(allOffsets);
+                    } catch (final CommitFailedException error) {
+                        throw new TaskMigratedException("Consumer committing offsets failed, " +
+                                                            "indicating the corresponding thread is no longer part of the group", error);
+                    } catch (final TimeoutException error) {
+                        // TODO KIP-447: we can consider treating it as non-fatal and retry on the thread level
+                        throw new StreamsException("Timed out while committing offsets via consumer", error);
+                    } catch (final KafkaException error) {
+                        throw new StreamsException("Error encountered committing offsets via consumer", error);
+                    }
                 }
             }
         }
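
One subtlety in the handleRevocation() change above: under eos-beta a single thread-level producer commits one transaction covering all active tasks, so if any revoked task needs committing, every other task with pending data gets pulled into the same commit. A hypothetical helper showing just that selection rule (not TaskManager code):

import java.util.HashSet;
import java.util.Set;

final class EosBetaCommitRuleSketch {

    // Returns the tasks to commit: the revoked tasks that need committing plus,
    // under eos-beta, every other task with pending data (all-or-nothing).
    static <T> Set<T> tasksToCommit(final Set<T> revokedNeedingCommit,
                                    final Set<T> othersNeedingCommit,
                                    final boolean eosBeta) {
        final Set<T> result = new HashSet<>(revokedNeedingCommit);
        if (eosBeta && !result.isEmpty()) {
            result.addAll(othersNeedingCommit);
        }
        return result;
    }
}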
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index 8784cf1..3f4b410 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -57,6 +57,9 @@ import static java.util.Arrays.asList;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkProperties;
+import static org.apache.kafka.streams.processor.internals.Task.State.CREATED;
+import static org.apache.kafka.streams.processor.internals.Task.State.RUNNING;
+import static org.apache.kafka.streams.processor.internals.Task.State.SUSPENDED;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
@@ -139,7 +142,7 @@ public class StandbyTaskTest {
             try {
                 task.suspend();
             } catch (final IllegalStateException maybeSwallow) {
-                if (!maybeSwallow.getMessage().startsWith("Invalid transition from CLOSED to SUSPENDED")) {
+                if (!maybeSwallow.getMessage().startsWith("Illegal state CLOSED while suspending standby task")) {
                     throw maybeSwallow;
                 }
             }
@@ -171,16 +174,16 @@ public class StandbyTaskTest {
 
         task = createStandbyTask();
 
-        assertEquals(Task.State.CREATED, task.state());
+        assertEquals(CREATED, task.state());
 
         task.initializeIfNeeded();
 
-        assertEquals(Task.State.RUNNING, task.state());
+        assertEquals(RUNNING, task.state());
 
         // initialize should be idempotent
         task.initializeIfNeeded();
 
-        assertEquals(Task.State.RUNNING, task.state());
+        assertEquals(RUNNING, task.state());
 
         EasyMock.verify(stateManager);
     }
@@ -263,7 +266,7 @@ public class StandbyTaskTest {
     }
 
     @Test
-    public void shouldCommitOnCloseClean() {
+    public void shouldSuspendAndCommitBeforeCloseClean() {
         stateManager.close();
         EasyMock.expectLastCall();
         stateManager.checkpoint(EasyMock.eq(Collections.emptyMap()));
@@ -289,6 +292,17 @@ public class StandbyTaskTest {
     }
 
     @Test
+    public void shouldRequireSuspendingCreatedTasksBeforeClose() {
+        EasyMock.replay(stateManager);
+        task = createStandbyTask();
+        assertThat(task.state(), equalTo(CREATED));
+        assertThrows(IllegalStateException.class, () -> task.closeClean());
+
+        task.suspend();
+        task.closeClean();
+    }
+
+    @Test
     public void shouldOnlyNeedCommitWhenChangelogOffsetChanged() {
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.singleton(partition)).anyTimes();
         EasyMock.expect(stateManager.changelogOffsets())
@@ -355,7 +369,7 @@ public class StandbyTaskTest {
         task.prepareCommit();
         assertThrows(RuntimeException.class, task::postCommit);
 
-        assertEquals(Task.State.RUNNING, task.state());
+        assertEquals(RUNNING, task.state());
 
         final double expectedCloseTaskMetric = 0.0;
         verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName);
@@ -376,6 +390,7 @@ public class StandbyTaskTest {
         final MetricName metricName = setupCloseTaskMetric();
 
         task = createStandbyTask();
+        task.suspend();
 
         task.closeDirty();
 
@@ -405,6 +420,7 @@ public class StandbyTaskTest {
         )));
 
         task = createStandbyTask();
+        task.suspend();
 
         task.closeDirty();
 
@@ -435,6 +451,7 @@ public class StandbyTaskTest {
 
         task = createStandbyTask();
 
+        task.suspend();
         task.closeDirty();
 
         final double expectedCloseTaskMetric = 1.0;
@@ -447,20 +464,40 @@ public class StandbyTaskTest {
 
     @Test
     public void shouldRecycleTask() {
-        stateManager.flush();
-        EasyMock.expectLastCall();
         stateManager.recycle();
-        EasyMock.expectLastCall();
         EasyMock.replay(stateManager);
 
         task = createStandbyTask();
+        assertThrows(IllegalStateException.class, () -> task.closeAndRecycleState()); // CREATED
+
         task.initializeIfNeeded();
+        assertThrows(IllegalStateException.class, () -> task.closeAndRecycleState()); // RUNNING
 
-        task.closeAndRecycleState();
+        task.suspend();
+        task.closeAndRecycleState(); // SUSPENDED
 
         EasyMock.verify(stateManager);
     }
 
+    @Test
+    public void shouldAlwaysSuspendCreatedTasks() {
+        EasyMock.replay(stateManager);
+        task = createStandbyTask();
+        assertThat(task.state(), equalTo(CREATED));
+        task.suspend();
+        assertThat(task.state(), equalTo(SUSPENDED));
+    }
+
+    @Test
+    public void shouldAlwaysSuspendRunningTasks() {
+        EasyMock.replay(stateManager);
+        task = createStandbyTask();
+        task.initializeIfNeeded();
+        assertThat(task.state(), equalTo(RUNNING));
+        task.suspend();
+        assertThat(task.state(), equalTo(SUSPENDED));
+    }
+
     private StandbyTask createStandbyTask() {
 
         final ThreadCache cache = new ThreadCache(
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index 0426f68..7a2cf7a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -87,6 +87,10 @@ import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkProperties;
 import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.streams.processor.internals.StreamTask.encodeTimestamp;
+import static org.apache.kafka.streams.processor.internals.Task.State.CREATED;
+import static org.apache.kafka.streams.processor.internals.Task.State.RESTORING;
+import static org.apache.kafka.streams.processor.internals.Task.State.RUNNING;
+import static org.apache.kafka.streams.processor.internals.Task.State.SUSPENDED;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.THREAD_ID_TAG;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.THREAD_ID_TAG_0100_TO_24;
 import static org.apache.kafka.test.StreamsTestUtils.getMetricByNameFilterByTags;
@@ -329,14 +333,14 @@ public class StreamTaskTest {
 
         task.initializeIfNeeded();
 
-        assertEquals(Task.State.RESTORING, task.state());
+        assertEquals(RESTORING, task.state());
         assertFalse(source1.initialized);
         assertFalse(source2.initialized);
 
         // initialize should be idempotent
         task.initializeIfNeeded();
 
-        assertEquals(Task.State.RESTORING, task.state());
+        assertEquals(RESTORING, task.state());
 
         task.completeRestoration();
 
@@ -958,6 +962,7 @@ public class StreamTaskTest {
     @Test
     public void shouldFailOnCommitIfTaskIsClosed() {
         task = createStatelessTask(createConfig(false, "0"), StreamsConfig.METRICS_LATEST);
+        task.suspend();
         task.transitionTo(Task.State.CLOSED);
 
         final IllegalStateException thrown = assertThrows(
@@ -1299,7 +1304,7 @@ public class StreamTaskTest {
 
         task.resume();
 
-        assertEquals(Task.State.RESTORING, task.state());
+        assertEquals(RESTORING, task.state());
         assertFalse(source1.initialized);
         assertFalse(source2.initialized);
 
@@ -1506,6 +1511,7 @@ public class StreamTaskTest {
         task = createStatelessTask(createConfig(false, "100"), StreamsConfig.METRICS_LATEST);
         assertThrows(IllegalStateException.class, task::prepareCommit);
 
+        task.transitionTo(Task.State.SUSPENDED);
         task.transitionTo(Task.State.CLOSED);
         assertThrows(IllegalStateException.class, task::prepareCommit);
     }
@@ -1515,6 +1521,7 @@ public class StreamTaskTest {
         task = createStatelessTask(createConfig(false, "100"), StreamsConfig.METRICS_LATEST);
         assertThrows(IllegalStateException.class, task::postCommit);
 
+        task.transitionTo(Task.State.SUSPENDED);
         task.transitionTo(Task.State.CLOSED);
         assertThrows(IllegalStateException.class, task::postCommit);
     }
@@ -1608,6 +1615,7 @@ public class StreamTaskTest {
         task.completeRestoration();
         task.suspend();
         task.prepareCommit();
+        task.postCommit();
         task.closeClean();
 
         assertEquals(Task.State.CLOSED, task.state());
@@ -1633,6 +1641,7 @@ public class StreamTaskTest {
         task.completeRestoration();
         task.suspend();
         task.prepareCommit();
+        task.postCommit();
         task.closeClean();
 
         assertEquals(Task.State.CLOSED, task.state());
@@ -1662,6 +1671,7 @@ public class StreamTaskTest {
 
         task.suspend();
         task.prepareCommit();
+        task.postCommit();
         assertThrows(ProcessorStateException.class, () -> task.closeClean());
 
         final double expectedCloseTaskMetric = 0.0;
@@ -1669,7 +1679,7 @@ public class StreamTaskTest {
 
         EasyMock.verify(stateManager);
         EasyMock.reset(stateManager);
-        EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.singleton(changelogPartition)).anyTimes();
+        EasyMock.expect(stateManager.changelogPartitions()).andStubReturn(Collections.singleton(changelogPartition));
         stateManager.close();
         EasyMock.expectLastCall();
         EasyMock.replay(stateManager);
@@ -1695,7 +1705,7 @@ public class StreamTaskTest {
 
         assertThrows(ProcessorStateException.class, task::prepareCommit);
 
-        assertEquals(Task.State.RESTORING, task.state());
+        assertEquals(RESTORING, task.state());
 
         final double expectedCloseTaskMetric = 0.0;
         verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName);
@@ -1788,29 +1798,56 @@ public class StreamTaskTest {
     }
 
     @Test
-    public void shouldRecycleTask() {
-        EasyMock.expect(recordCollector.offsets()).andReturn(Collections.emptyMap()).anyTimes();
-        recordCollector.flush();
-        EasyMock.expectLastCall();
-        stateManager.flush();
-        EasyMock.expectLastCall();
-        stateManager.checkpoint(Collections.emptyMap());
-        EasyMock.expectLastCall();
+    public void shouldOnlyRecycleSuspendedTasks() {
         stateManager.recycle();
-        EasyMock.expectLastCall();
         recordCollector.close();
-        EasyMock.expectLastCall();
         EasyMock.replay(stateManager, recordCollector);
 
         task = createStatefulTask(createConfig(false, "100"), true);
+        assertThrows(IllegalStateException.class, () -> task.closeAndRecycleState()); // CREATED
+
         task.initializeIfNeeded();
+        assertThrows(IllegalStateException.class, () -> task.closeAndRecycleState()); // RESTORING
+
         task.completeRestoration();
+        assertThrows(IllegalStateException.class, () -> task.closeAndRecycleState()); // RUNNING
 
-        task.closeAndRecycleState();
+        task.suspend();
+        task.closeAndRecycleState(); // SUSPENDED
 
         EasyMock.verify(stateManager, recordCollector);
     }
 
+    @Test
+    public void shouldAlwaysSuspendCreatedTasks() {
+        EasyMock.replay(stateManager);
+        task = createStatefulTask(createConfig(false, "100"), true);
+        assertThat(task.state(), equalTo(CREATED));
+        task.suspend();
+        assertThat(task.state(), equalTo(SUSPENDED));
+    }
+
+    @Test
+    public void shouldAlwaysSuspendRestoringTasks() {
+        EasyMock.replay(stateManager);
+        task = createStatefulTask(createConfig(false, "100"), true);
+        task.initializeIfNeeded();
+        assertThat(task.state(), equalTo(RESTORING));
+        task.suspend();
+        assertThat(task.state(), equalTo(SUSPENDED));
+    }
+
+    @Test
+    public void shouldAlwaysSuspendRunningTasks() {
+        EasyMock.replay(stateManager);
+        task = createFaultyStatefulTask(createConfig(false, "100"));
+        task.initializeIfNeeded();
+        task.completeRestoration();
+        assertThat(task.state(), equalTo(RUNNING));
+        assertThrows(RuntimeException.class, () -> task.suspend());
+        assertThat(task.state(), equalTo(SUSPENDED));
+    }
+
     private StreamTask createOptimizedStatefulTask(final StreamsConfig config, final Consumer<byte[], byte[]> consumer) {
         final StateStore stateStore = new MockKeyValueStore(storeName, true);
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 76136b9..a0f3be5 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -42,6 +42,7 @@ import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
@@ -53,7 +54,6 @@ import org.hamcrest.Matchers;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
-import org.junit.function.ThrowingRunnable;
 import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 
@@ -84,7 +84,6 @@ import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.easymock.EasyMock.anyObject;
 import static org.easymock.EasyMock.anyString;
-import static org.easymock.EasyMock.checkOrder;
 import static org.easymock.EasyMock.eq;
 import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.expectLastCall;
@@ -421,10 +420,11 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldCloseDirtyActiveUnassignedSuspendedTasksWhenErrorCommittingRevokedTask() {
+    public void shouldCloseDirtyActiveUnassignedTasksWhenErrorSuspendingTask() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
             @Override
-            public Map<TopicPartition, OffsetAndMetadata> prepareCommit() {
+            public void suspend() {
+                super.suspend();
                 throw new RuntimeException("KABOOM!");
             }
         };
@@ -510,21 +510,7 @@ public class TaskManagerTest {
         expectLastCall();
         replay(activeTaskCreator);
 
-        final StreamsMetricsImpl streamsMetrics =
-            new StreamsMetricsImpl(new Metrics(), "clientId", StreamsConfig.METRICS_LATEST);
-        taskManager = new TaskManager(
-            changeLogReader,
-            UUID.randomUUID(),
-            "taskManagerTest",
-            streamsMetrics,
-            activeTaskCreator,
-            standbyTaskCreator,
-            topologyBuilder,
-            adminClient,
-            stateDirectory,
-            StreamThread.ProcessingMode.EXACTLY_ONCE_BETA
-        );
-        taskManager.setMainConsumer(consumer);
+        setUpTaskManager(ProcessingMode.EXACTLY_ONCE_BETA);
 
         taskManager.handleLostAll();
 
@@ -649,6 +635,7 @@ public class TaskManagerTest {
         expectLastCall().anyTimes();
 
         expectRestoreToBeCompleted(consumer, changeLogReader);
+        consumer.commitSync(eq(emptyMap()));
 
         replay(activeTaskCreator, topologyBuilder, consumer, changeLogReader);
 
@@ -916,7 +903,7 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldSuspendActiveTasks() {
+    public void shouldSuspendActiveTasksDuringRevocation() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
         final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task00.setCommittableOffsetsAndMetadata(offsets);
@@ -937,10 +924,14 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldCommitAllActiveTasksTheNeedCommittingOnHandleAssignmentIfOneTaskClosed() {
+    public void shouldCommitAllActiveTasksThatNeedCommittingOnHandleRevocationWithEosBeta() {
+        final StreamsProducer producer = mock(StreamsProducer.class);
+        setUpTaskManager(ProcessingMode.EXACTLY_ONCE_BETA);
+
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
         final Map<TopicPartition, OffsetAndMetadata> offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task00.setCommittableOffsetsAndMetadata(offsets00);
+        task00.setCommitNeeded();
 
         final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
         final Map<TopicPartition, OffsetAndMetadata> offsets01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null));
@@ -970,11 +961,14 @@ public class TaskManagerTest {
 
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive)))
             .andReturn(asList(task00, task01, task02));
+        expect(activeTaskCreator.threadProducer()).andReturn(producer);
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
-        expectLastCall();
         expect(standbyTaskCreator.createTasks(eq(assignmentStandby)))
             .andReturn(singletonList(task10));
-        consumer.commitSync(expectedCommittedOffsets);
+
+        final ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata("appId");
+        expect(consumer.groupMetadata()).andReturn(groupMetadata);
+        producer.commitTransaction(expectedCommittedOffsets, groupMetadata);
         expectLastCall();
 
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
@@ -986,8 +980,7 @@ public class TaskManagerTest {
         assertThat(task02.state(), is(Task.State.RUNNING));
         assertThat(task10.state(), is(Task.State.RUNNING));
 
-        assignmentActive.remove(taskId00);
-        taskManager.handleAssignment(assignmentActive, assignmentStandby);
+        taskManager.handleRevocation(taskId00Partitions);
 
         assertThat(task00.commitNeeded, is(false));
         assertThat(task01.commitNeeded, is(false));
@@ -996,37 +989,65 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldNotCommitOnHandleAssignmentIfNoTaskClosed() {
+    public void shouldCommitOnlyRevokedActiveTasksThatNeedCommittingOnHandleRevocation() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
         final Map<TopicPartition, OffsetAndMetadata> offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task00.setCommittableOffsetsAndMetadata(offsets00);
         task00.setCommitNeeded();
 
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null));
+        task01.setCommittableOffsetsAndMetadata(offsets01);
+        task01.setCommitNeeded();
+
+        final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets02 = singletonMap(t1p2, new OffsetAndMetadata(2L, null));
+        task02.setCommittableOffsetsAndMetadata(offsets02);
+
         final StateMachineTask task10 = new StateMachineTask(taskId10, taskId10Partitions, false);
 
-        final Map<TaskId, Set<TopicPartition>> assignmentActive = singletonMap(taskId00, taskId00Partitions);
-        final Map<TaskId, Set<TopicPartition>> assignmentStandby = singletonMap(taskId10, taskId10Partitions);
+        final Map<TopicPartition, OffsetAndMetadata> expectedCommittedOffsets = new HashMap<>();
+        expectedCommittedOffsets.putAll(offsets00);
+
+        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
+            mkEntry(taskId00, taskId00Partitions),
+            mkEntry(taskId01, taskId01Partitions),
+            mkEntry(taskId02, taskId02Partitions)
+        );
 
+        final Map<TaskId, Set<TopicPartition>> assignmentStandby = mkMap(
+            mkEntry(taskId10, taskId10Partitions)
+        );
         expectRestoreToBeCompleted(consumer, changeLogReader);
 
-        expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))).andReturn(singleton(task00));
-        expect(standbyTaskCreator.createTasks(eq(assignmentStandby))).andReturn(singletonList(task10));
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive)))
+            .andReturn(asList(task00, task01, task02));
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00);
+        expectLastCall();
+        expect(standbyTaskCreator.createTasks(eq(assignmentStandby)))
+            .andReturn(singletonList(task10));
+        consumer.commitSync(expectedCommittedOffsets);
+        expectLastCall();
 
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+        assertThat(task02.state(), is(Task.State.RUNNING));
         assertThat(task10.state(), is(Task.State.RUNNING));
 
-        taskManager.handleAssignment(assignmentActive, assignmentStandby);
+        taskManager.handleRevocation(taskId00Partitions);
 
-        assertThat(task00.commitNeeded, is(true));
+        assertThat(task00.commitNeeded, is(false));
+        assertThat(task01.commitPrepared, is(false));
+        assertThat(task02.commitPrepared, is(false));
         assertThat(task10.commitPrepared, is(false));
     }
 
     @Test
-    public void shouldNotCommitOnHandleAssignmentIfOnlyStandbyTaskClosed() {
+    public void shouldNotCommitOnHandleAssignmentIfNoTaskClosed() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
         final Map<TopicPartition, OffsetAndMetadata> offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task00.setCommittableOffsetsAndMetadata(offsets00);
@@ -1049,127 +1070,57 @@ public class TaskManagerTest {
         assertThat(task00.state(), is(Task.State.RUNNING));
         assertThat(task10.state(), is(Task.State.RUNNING));
 
-        taskManager.handleAssignment(assignmentActive, Collections.emptyMap());
+        taskManager.handleAssignment(assignmentActive, assignmentStandby);
 
         assertThat(task00.commitNeeded, is(true));
+        assertThat(task10.commitPrepared, is(false));
     }
 
     @Test
-    public void shouldCleanupAnyTasksClosedAsDirtyAfterCommitException() {
-        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
-        final Map<TopicPartition, OffsetAndMetadata> offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
-        task00.setCommittableOffsetsAndMetadata(offsets00);
-
-        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
-        final Map<TopicPartition, OffsetAndMetadata> offsets01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null));
-        task01.setCommittableOffsetsAndMetadata(offsets01);
-        task01.setCommitNeeded();
-
-        task01.setChangelogOffsets(singletonMap(t1p1, 0L));
-
-        final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true);
-        final Map<TopicPartition, OffsetAndMetadata> offsets02 = singletonMap(t1p2, new OffsetAndMetadata(2L, null));
-        task02.setCommittableOffsetsAndMetadata(offsets02);
-
-        final Map<TopicPartition, OffsetAndMetadata> expectedCommittedOffsets = new HashMap<>();
-        expectedCommittedOffsets.putAll(offsets00);
-        expectedCommittedOffsets.putAll(offsets01);
-
-        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
-            mkEntry(taskId00, taskId00Partitions),
-            mkEntry(taskId01, taskId01Partitions),
-            mkEntry(taskId02, taskId02Partitions)
-        );
-
-        expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive)))
-            .andReturn(asList(task00, task01, task02));
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(EasyMock.anyObject(TaskId.class));
-        expectLastCall().anyTimes();
-
-        consumer.commitSync(expectedCommittedOffsets);
-        expectLastCall().andThrow(new RuntimeException("Something went wrong!"));
-
-        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
-
-        taskManager.handleAssignment(assignmentActive, emptyMap());
-
-        assignmentActive.remove(taskId00);
-        assertThrows(
-            RuntimeException.class,
-            () -> taskManager.handleAssignment(assignmentActive, emptyMap())
-        );
-
-        verify(changeLogReader);
-    }
-
-    @Test
-    public void shouldCommitAllActiveTasksTheNeedCommittingOnRevocation() {
+    public void shouldNotCommitOnHandleAssignmentIfOnlyStandbyTaskClosed() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
         final Map<TopicPartition, OffsetAndMetadata> offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
         task00.setCommittableOffsetsAndMetadata(offsets00);
-
-        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
-        final Map<TopicPartition, OffsetAndMetadata> offsets01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null));
-        task01.setCommittableOffsetsAndMetadata(offsets01);
-        task01.setCommitNeeded();
-
-        final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true);
-        final Map<TopicPartition, OffsetAndMetadata> offsets02 = singletonMap(t1p2, new OffsetAndMetadata(2L, null));
-        task02.setCommittableOffsetsAndMetadata(offsets02);
+        task00.setCommitNeeded();
 
         final StateMachineTask task10 = new StateMachineTask(taskId10, taskId10Partitions, false);
 
-        final Map<TopicPartition, OffsetAndMetadata> expectedCommittedOffsets = new HashMap<>();
-        expectedCommittedOffsets.putAll(offsets00);
-        expectedCommittedOffsets.putAll(offsets01);
-
-        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
-            mkEntry(taskId00, taskId00Partitions),
-            mkEntry(taskId01, taskId01Partitions),
-            mkEntry(taskId02, taskId02Partitions)
-        );
+        final Map<TaskId, Set<TopicPartition>> assignmentActive = singletonMap(taskId00, taskId00Partitions);
+        final Map<TaskId, Set<TopicPartition>> assignmentStandby = singletonMap(taskId10, taskId10Partitions);
 
-        final Map<TaskId, Set<TopicPartition>> assignmentStandby = mkMap(
-            mkEntry(taskId10, taskId10Partitions)
-        );
         expectRestoreToBeCompleted(consumer, changeLogReader);
 
-        expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive)))
-            .andReturn(asList(task00, task01, task02));
-        expect(standbyTaskCreator.createTasks(eq(assignmentStandby)))
-            .andReturn(singletonList(task10));
-        consumer.commitSync(expectedCommittedOffsets);
-        expectLastCall();
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))).andReturn(singleton(task00));
+        expect(standbyTaskCreator.createTasks(eq(assignmentStandby))).andReturn(singletonList(task10));
 
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(assignmentActive, assignmentStandby);
         assertThat(taskManager.tryToCompleteRestoration(), is(true));
         assertThat(task00.state(), is(Task.State.RUNNING));
-        assertThat(task01.state(), is(Task.State.RUNNING));
-        assertThat(task02.state(), is(Task.State.RUNNING));
         assertThat(task10.state(), is(Task.State.RUNNING));
 
-        taskManager.handleRevocation(taskId00Partitions);
+        taskManager.handleAssignment(assignmentActive, Collections.emptyMap());
 
-        assertThat(task01.commitPrepared, is(true));
-        assertThat(task01.commitNeeded, is(false));
-        assertThat(task02.commitPrepared, is(false));
-        assertThat(task10.commitPrepared, is(false));
+        assertThat(task00.commitNeeded, is(true));
     }
 
     @Test
-    public void shouldNotCommitCreatedTasksOnSuspend() {
+    public void shouldNotCommitCreatedTasksOnRevocationOrClosure() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
 
         expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00));
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
         replay(activeTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(taskId00Assignment, emptyMap());
         assertThat(task00.state(), is(Task.State.CREATED));
 
         taskManager.handleRevocation(taskId00Partitions);
-        assertThat(task00.state(), is(Task.State.CREATED));
+        assertThat(task00.state(), is(Task.State.SUSPENDED));
+
+        taskManager.handleAssignment(emptyMap(), emptyMap());
+        assertThat(task00.state(), is(Task.State.CLOSED));
     }
 
     @Test
@@ -1423,91 +1374,64 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldCloseActiveTasksDirtyAndPropagatePrepareCommitException() {
+    public void shouldOnlyCommitRevokedStandbyTaskAndPropagatePrepareCommitException() {
         setUpTaskManager(StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA);
 
-        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, false);
 
-        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true) {
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false) {
             @Override
             public Map<TopicPartition, OffsetAndMetadata> prepareCommit() {
                 throw new RuntimeException("task 0_1 prepare commit boom!");
             }
         };
-
-        task01.setCommittableOffsetsAndMetadata(singletonMap(t1p1, new OffsetAndMetadata(0L, null)));
         task01.setCommitNeeded();
 
-        final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true);
-        final Map<TopicPartition, OffsetAndMetadata> offsetsT02 = singletonMap(t1p2, new OffsetAndMetadata(1L, null));
-
-        task02.setCommittableOffsetsAndMetadata(offsetsT02);
-        task02.setCommitNeeded();
-
         taskManager.tasks().put(taskId00, task00);
         taskManager.tasks().put(taskId01, task01);
-        taskManager.tasks().put(taskId02, task02);
-
-        checkOrder(activeTaskCreator, false);
-
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId01);
-        expectLastCall();
-
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId02);
-        expectLastCall();
-
-        replay(activeTaskCreator);
 
         final RuntimeException thrown = assertThrows(RuntimeException.class,
-            () -> taskManager.handleAssignment(mkMap(mkEntry(taskId00, taskId00Partitions),
-                mkEntry(taskId01, taskId01Partitions)), Collections.emptyMap()));
+            () -> taskManager.handleAssignment(
+                Collections.emptyMap(),
+                singletonMap(taskId00, taskId00Partitions)
+            ));
         assertThat(thrown.getCause().getMessage(), is("task 0_1 prepare commit boom!"));
 
         assertThat(task00.state(), is(Task.State.CREATED));
         assertThat(task01.state(), is(Task.State.CLOSED));
-        assertThat(task02.state(), is(Task.State.CLOSED));
 
        // All the tasks involved in the commit should already be removed.
         assertThat(taskManager.tasks(), is(Collections.singletonMap(taskId00, task00)));
-
-        verify(activeTaskCreator);
     }
 
     @Test
-    public void shouldCloseActiveTasksDirtyAndPropagateCommitException() {
+    public void shouldCloseActiveTasksDirtyAndPropagateSuspendException() {
         setUpTaskManager(StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA);
 
         final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
 
-        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
-        task01.setCommittableOffsetsAndMetadata(singletonMap(t1p1, new OffsetAndMetadata(0L, null)));
-        task01.setCommitNeeded();
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true) {
+            @Override
+            public void suspend() {
+                super.suspend();
+                throw new RuntimeException("task 0_1 suspend boom!");
+            }
+        };
 
         final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true);
-        final Map<TopicPartition, OffsetAndMetadata> offsetsT02 = singletonMap(t1p2, new OffsetAndMetadata(1L, null));
-
-        task02.setCommittableOffsetsAndMetadata(offsetsT02);
-        task02.setCommitNeeded();
 
         taskManager.tasks().put(taskId00, task00);
         taskManager.tasks().put(taskId01, task01);
         taskManager.tasks().put(taskId02, task02);
 
-        expect(activeTaskCreator.streamsProducerForTask(taskId01)).andThrow(new RuntimeException("task 0_1 producer boom!"));
-
-        checkOrder(activeTaskCreator, false);
-
-        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId01);
-        expectLastCall();
-
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId02);
-        expectLastCall();
+        activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId01);
 
         replay(activeTaskCreator);
 
         final RuntimeException thrown = assertThrows(RuntimeException.class,
             () -> taskManager.handleAssignment(mkMap(mkEntry(taskId00, taskId00Partitions)), Collections.emptyMap()));
-        assertThat(thrown.getCause().getMessage(), is("task 0_1 producer boom!"));
+        assertThat(thrown.getCause().getMessage(), is("task 0_1 suspend boom!"));
 
         assertThat(task00.state(), is(Task.State.CREATED));
         assertThat(task01.state(), is(Task.State.CLOSED));
@@ -1741,7 +1665,8 @@ public class TaskManagerTest {
             .andReturn(Arrays.asList(task00, task01, task02)).anyTimes();
         expect(standbyTaskCreator.createTasks(eq(assignmentStandby)))
             .andReturn(Arrays.asList(task03, task04, task05)).anyTimes();
-        expectLastCall();
+
+        consumer.commitSync(eq(emptyMap()));
 
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
@@ -2632,36 +2557,31 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldNotCloseTasksIfCommittingFailsDuringRevocation() {
-        shouldNotCloseTaskIfCommitFailsDuringAction(() -> taskManager.handleRevocation(singletonList(t1p0)));
-    }
-
-    @Test
-    public void shouldNotCloseTasksIfCommittingFailsDuringShutdown() {
-        shouldNotCloseTaskIfCommitFailsDuringAction(() -> taskManager.shutdown(true));
-    }
-
-    private void shouldNotCloseTaskIfCommitFailsDuringAction(final ThrowingRunnable action) {
-        final Map<TopicPartition, OffsetAndMetadata> offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+    public void shouldSuspendAllTasksButSkipCommitIfSuspendingFailsDuringRevocation() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) {
             @Override
-            public Map<TopicPartition, OffsetAndMetadata> prepareCommit() {
-                return offsets;
+            public void suspend() {
+                super.suspend();
+                throw new RuntimeException("KABOOM!");
             }
         };
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
 
-        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment)))
-            .andReturn(singletonList(task00));
-        consumer.commitSync(offsets);
-        expectLastCall().andThrow(new RuntimeException("KABOOM!"));
+        final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>(taskId00Assignment);
+        assignment.putAll(taskId01Assignment);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment)))
+            .andReturn(asList(task00, task01));
         replay(activeTaskCreator, consumer);
 
-        taskManager.handleAssignment(taskId00Assignment, Collections.emptyMap());
+        taskManager.handleAssignment(assignment, Collections.emptyMap());
 
-        final RuntimeException thrown = assertThrows(RuntimeException.class, action);
+        final RuntimeException thrown = assertThrows(
+            RuntimeException.class,
+            () -> taskManager.handleRevocation(asList(t1p0, t1p1)));
 
-        assertThat(thrown.getMessage(), is("KABOOM!"));
-        assertThat(task00.state(), is(Task.State.CREATED));
+        assertThat(thrown.getCause().getMessage(), is("KABOOM!"));
+        assertThat(task00.state(), is(Task.State.SUSPENDED));
+        assertThat(task01.state(), is(Task.State.SUSPENDED));
     }
 
     private static void expectRestoreToBeCompleted(final Consumer<byte[], byte[]> consumer,
@@ -2782,7 +2702,11 @@ public class TaskManagerTest {
 
         @Override
         public void suspend() {
-            if (state() == State.RUNNING) {
+            if (state() == State.CLOSED) {
+                throw new IllegalStateException("Illegal state " + state() + " while suspending active task " + id);
+            } else if (state() == State.SUSPENDED) {
+                // do nothing
+            } else {
                 transitionTo(State.SUSPENDED);
             }
         }