You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by ca...@apache.org on 2022/09/07 12:21:52 UTC

[kafka] branch trunk updated: KAFKA-10199: Separate state updater from old restore (#12583)

This is an automated email from the ASF dual-hosted git repository.

cadonna pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 44b500b679 KAFKA-10199: Separate state updater from old restore (#12583)
44b500b679 is described below

commit 44b500b6795c4eb9f77362e25baee60fd83e5ce4
Author: Bruno Cadonna <ca...@apache.org>
AuthorDate: Wed Sep 7 14:21:36 2022 +0200

    KAFKA-10199: Separate state updater from old restore (#12583)
    
    Separates the code path for the new state updater from
    the code path of the old restoration.
    
    Ensures that, when the state updater is enabled, tasks are
    processed even before all tasks are running.
    
    Reviewers: Guozhang Wang <wa...@gmail.com>, Walker Carlson <wcarlson@confluent.io>
---
 .../streams/processor/internals/StreamThread.java  |  55 ++--
 .../streams/processor/internals/TaskManager.java   | 125 ++++----
 .../processor/internals/StreamThreadTest.java      |  54 +++-
 .../processor/internals/TaskManagerTest.java       | 336 +++++++++------------
 4 files changed, 303 insertions(+), 267 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index aa1aa2c315..0bc19cefc3 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -320,6 +320,7 @@ public class StreamThread extends Thread {
     private final AtomicLong cacheResizeSize = new AtomicLong(-1L);
     private final AtomicBoolean leaveGroupRequested = new AtomicBoolean(false);
     private final boolean eosEnabled;
+    private final boolean stateUpdaterEnabled;
 
     public static StreamThread create(final TopologyMetadata topologyMetadata,
                                       final StreamsConfig config,
@@ -540,6 +541,7 @@ public class StreamThread extends Thread {
 
         this.numIterations = 1;
         this.eosEnabled = eosEnabled(config);
+        this.stateUpdaterEnabled = InternalConfig.getBoolean(config.originals(), InternalConfig.STATE_UPDATER_ENABLED, false);
     }
 
     private static final class InternalConsumerConfig extends ConsumerConfig {
@@ -770,7 +772,8 @@ public class StreamThread extends Thread {
         long totalCommitLatency = 0L;
         long totalProcessLatency = 0L;
         long totalPunctuateLatency = 0L;
-        if (state == State.RUNNING) {
+        if (state == State.RUNNING
+            || (stateUpdaterEnabled && isRunning())) {
             /*
              * Within an iteration, after processing up to N (N initialized as 1 upon start up) records for each applicable tasks, check the current time:
              *  1. If it is time to punctuate, do it;
@@ -867,37 +870,47 @@ public class StreamThread extends Thread {
     }
 
     private void initializeAndRestorePhase() {
-        // only try to initialize the assigned tasks
-        // if the state is still in PARTITION_ASSIGNED after the poll call
+        final java.util.function.Consumer<Set<TopicPartition>> offsetResetter = partitions -> resetOffsets(partitions, null);
         final State stateSnapshot = state;
-        if (stateSnapshot == State.PARTITIONS_ASSIGNED
-            || stateSnapshot == State.RUNNING && taskManager.needsInitializationOrRestoration()) {
+        if (stateUpdaterEnabled) {
+            checkStateUpdater();
+        } else {
+            // only try to initialize the assigned tasks
+            // if the state is still in PARTITION_ASSIGNED after the poll call
+            if (stateSnapshot == State.PARTITIONS_ASSIGNED
+                || stateSnapshot == State.RUNNING && taskManager.needsInitializationOrRestoration()) {
 
-            log.debug("State is {}; initializing tasks if necessary", stateSnapshot);
+                log.debug("State is {}; initializing tasks if necessary", stateSnapshot);
 
-            // transit to restore active is idempotent so we can call it multiple times
-            changelogReader.enforceRestoreActive();
+                if (taskManager.tryToCompleteRestoration(now, offsetResetter)) {
+                    log.info("Restoration took {} ms for all tasks {}", time.milliseconds() - lastPartitionAssignedMs,
+                        taskManager.allTasks().keySet());
+                    setState(State.RUNNING);
+                }
 
-            if (taskManager.tryToCompleteRestoration(now, partitions -> resetOffsets(partitions, null))) {
-                changelogReader.transitToUpdateStandby();
-                log.info("Restoration took {} ms for all tasks {}", time.milliseconds() - lastPartitionAssignedMs,
-                    taskManager.allTasks().keySet());
-                setState(State.RUNNING);
+                if (log.isDebugEnabled()) {
+                    log.debug("Initialization call done. State is {}", state);
+                }
             }
 
             if (log.isDebugEnabled()) {
-                log.debug("Initialization call done. State is {}", state);
+                log.debug("Idempotently invoking restoration logic in state {}", state);
             }
+            // we can always let changelog reader try restoring in order to initialize the changelogs;
+            // if there's no active restoring or standby updating it would not try to fetch any data
+            // After KAFKA-13873, we only restore the not paused tasks.
+            changelogReader.restore(taskManager.notPausedTasks());
+            log.debug("Idempotent restore call done. Thread state has not changed.");
         }
+    }
 
-        if (log.isDebugEnabled()) {
-            log.debug("Idempotently invoking restoration logic in state {}", state);
+    private void checkStateUpdater() {
+        final java.util.function.Consumer<Set<TopicPartition>> offsetResetter = partitions -> resetOffsets(partitions, null);
+        final State stateSnapshot = state;
+        final boolean allRunning = taskManager.checkStateUpdater(now, offsetResetter);
+        if (allRunning && stateSnapshot == State.PARTITIONS_ASSIGNED) {
+            setState(State.RUNNING);
         }
-        // we can always let changelog reader try restoring in order to initialize the changelogs;
-        // if there's no active restoring or standby updating it would not try to fetch any data
-        // After KAFKA-13873, we only restore the not paused tasks.
-        changelogReader.restore(taskManager.notPausedTasks());
-        log.debug("Idempotent restore call done. Thread state has not changed.");
     }
 
     // Check if the topology has been updated since we last checked, ie via #addNamedTopology or #removeNamedTopology
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 82fb110664..d9d05391e9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -610,73 +610,79 @@ public class TaskManager {
      * @throws StreamsException if the store's change log does not contain the partition
      * @return {@code true} if all tasks are fully restored
      */
-    boolean tryToCompleteRestoration(final long now, final java.util.function.Consumer<Set<TopicPartition>> offsetResetter) {
-        if (stateUpdater == null) {
-            boolean allRunning = true;
+    boolean tryToCompleteRestoration(final long now,
+                                     final java.util.function.Consumer<Set<TopicPartition>> offsetResetter) {
+        boolean allRunning = true;
 
-            final List<Task> activeTasks = new LinkedList<>();
-            for (final Task task : tasks.allTasks()) {
-                try {
-                    task.initializeIfNeeded();
-                    task.clearTaskTimeout();
-                } catch (final LockException lockException) {
-                    // it is possible that if there are multiple threads within the instance that one thread
-                    // trying to grab the task from the other, while the other has not released the lock since
-                    // it did not participate in the rebalance. In this case we can just retry in the next iteration
-                    log.debug("Could not initialize task {} since: {}; will retry", task.id(), lockException.getMessage());
-                    allRunning = false;
-                } catch (final TimeoutException timeoutException) {
-                    task.maybeInitTaskTimeoutOrThrow(now, timeoutException);
-                    allRunning = false;
-                }
+        // transit to restore active is idempotent so we can call it multiple times
+        changelogReader.enforceRestoreActive();
 
-                if (task.isActive()) {
-                    activeTasks.add(task);
-                }
+        final List<Task> activeTasks = new LinkedList<>();
+        for (final Task task : tasks.allTasks()) {
+            try {
+                task.initializeIfNeeded();
+                task.clearTaskTimeout();
+            } catch (final LockException lockException) {
+                // it is possible that if there are multiple threads within the instance that one thread
+                // trying to grab the task from the other, while the other has not released the lock since
+                // it did not participate in the rebalance. In this case we can just retry in the next iteration
+                log.debug("Could not initialize task {} since: {}; will retry", task.id(), lockException.getMessage());
+                allRunning = false;
+            } catch (final TimeoutException timeoutException) {
+                task.maybeInitTaskTimeoutOrThrow(now, timeoutException);
+                allRunning = false;
             }
 
-            if (allRunning && !activeTasks.isEmpty()) {
-
-                final Set<TopicPartition> restored = changelogReader.completedChangelogs();
-
-                for (final Task task : activeTasks) {
-                    if (restored.containsAll(task.changelogPartitions())) {
-                        try {
-                            task.completeRestoration(offsetResetter);
-                            task.clearTaskTimeout();
-                        } catch (final TimeoutException timeoutException) {
-                            task.maybeInitTaskTimeoutOrThrow(now, timeoutException);
-                            log.debug(
-                                String.format(
-                                    "Could not complete restoration for %s due to the following exception; will retry",
-                                    task.id()),
-                                timeoutException
-                            );
-
-                            allRunning = false;
-                        }
-                    } else {
-                        // we found a restoring task that isn't done restoring, which is evidence that
-                        // not all tasks are running
+            if (task.isActive()) {
+                activeTasks.add(task);
+            }
+        }
+
+        if (allRunning && !activeTasks.isEmpty()) {
+
+            final Set<TopicPartition> restored = changelogReader.completedChangelogs();
+
+            for (final Task task : activeTasks) {
+                if (restored.containsAll(task.changelogPartitions())) {
+                    try {
+                        task.completeRestoration(offsetResetter);
+                        task.clearTaskTimeout();
+                    } catch (final TimeoutException timeoutException) {
+                        task.maybeInitTaskTimeoutOrThrow(now, timeoutException);
+                        log.debug(
+                            String.format(
+                                "Could not complete restoration for %s due to the following exception; will retry",
+                                task.id()),
+                            timeoutException
+                        );
+
                         allRunning = false;
                     }
+                } else {
+                    // we found a restoring task that isn't done restoring, which is evidence that
+                    // not all tasks are running
+                    allRunning = false;
                 }
             }
-            if (allRunning) {
-                // we can call resume multiple times since it is idempotent.
-                mainConsumer.resume(mainConsumer.assignment());
-            }
-
-            return allRunning;
-        } else {
-            addTasksToStateUpdater();
-
-            handleExceptionsFromStateUpdater();
+        }
+        if (allRunning) {
+            // we can call resume multiple times since it is idempotent.
+            mainConsumer.resume(mainConsumer.assignment());
+            changelogReader.transitToUpdateStandby();
+        }
 
-            handleRemovedTasksFromStateUpdater();
+        return allRunning;
+    }
 
-            return handleRestoredTasksFromStateUpdater(now, offsetResetter);
+    public boolean checkStateUpdater(final long now,
+                                     final java.util.function.Consumer<Set<TopicPartition>> offsetResetter) {
+        addTasksToStateUpdater();
+        handleExceptionsFromStateUpdater();
+        handleRemovedTasksFromStateUpdater();
+        if (stateUpdater.restoresActiveTasks()) {
+            handleRestoredTasksFromStateUpdater(now, offsetResetter);
         }
+        return !stateUpdater.restoresActiveTasks();
     }
 
     private void recycleTask(final Task task,
@@ -685,6 +691,7 @@ public class TaskManager {
                              final Map<TaskId, RuntimeException> taskExceptions) {
         Task newTask = null;
         try {
+            task.suspend();
             newTask = task.isActive() ?
                 convertActiveToStandby((StreamTask) task, inputPartitions) :
                 convertStandbyToActive((StandbyTask) task, inputPartitions);
@@ -817,8 +824,6 @@ public class TaskManager {
                 tasksToCloseDirty.add(task);
             } else if ((inputPartitions = tasks.removePendingTaskToUpdateInputPartitions(task.id())) != null) {
                 task.updateInputPartitions(inputPartitions, topologyMetadata.nodeToSourceTopics(task.id()));
-                // if the restored task happen to need input partition update, we can transit it to running
-                // right after completing the update as well
                 transitRestoredTaskToRunning(task, now, offsetResetter);
             } else {
                 transitRestoredTaskToRunning(task, now, offsetResetter);
@@ -1224,13 +1229,13 @@ public class TaskManager {
     private void shutdownStateUpdater() {
         if (stateUpdater != null) {
             stateUpdater.shutdown(Duration.ofMillis(Long.MAX_VALUE));
-            closeFailedTasks();
+            closeFailedTasksFromStateUpdater();
             addRestoredTasksToTaskRegistry();
             addRemovedTasksToTaskRegistry();
         }
     }
 
-    private void closeFailedTasks() {
+    private void closeFailedTasksFromStateUpdater() {
         final Set<Task> tasksToCloseDirty = stateUpdater.drainExceptionsAndFailedTasks().stream()
             .flatMap(exAndTasks -> exAndTasks.getTasks().stream()).collect(Collectors.toSet());
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 9edfbc10b3..9e2dcdfa71 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -52,6 +52,7 @@ import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.StreamsConfig.InternalConfig;
 import org.apache.kafka.streams.ThreadMetadata;
 import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
 import org.apache.kafka.streams.errors.StreamsException;
@@ -69,6 +70,7 @@ import org.apache.kafka.streams.processor.api.Processor;
 import org.apache.kafka.streams.processor.api.ProcessorContext;
 import org.apache.kafka.streams.processor.api.ProcessorSupplier;
 import org.apache.kafka.streams.processor.api.Record;
+import org.apache.kafka.streams.processor.internals.StreamThread.State;
 import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
@@ -87,6 +89,10 @@ import org.easymock.EasyMock;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
 import org.slf4j.Logger;
 
 import java.io.File;
@@ -109,6 +115,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.BiConsumer;
+import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import static java.util.Collections.emptyMap;
@@ -142,7 +149,9 @@ import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Mockito.when;
 
+@RunWith(MockitoJUnitRunner.class)
 public class StreamThreadTest {
 
     private final static String APPLICATION_ID = "stream-thread-test";
@@ -164,6 +173,9 @@ public class StreamThreadTest {
     private final InternalTopologyBuilder internalTopologyBuilder = new InternalTopologyBuilder();
     private final InternalStreamsBuilder internalStreamsBuilder = new InternalStreamsBuilder(internalTopologyBuilder);
 
+    @Mock
+    private Consumer<byte[], byte[]> mainConsumer;
+
     private StreamsMetadataState streamsMetadataState;
     private final static BiConsumer<Throwable, Boolean> HANDLER = (e, b) -> {
         if (e instanceof RuntimeException) {
@@ -853,7 +865,7 @@ public class StreamThreadTest {
 
         final TaskManager taskManager = new TaskManager(
             null,
-            null,
+            changelogReader,
             null,
             null,
             activeTaskCreator,
@@ -2913,6 +2925,46 @@ public class StreamThreadTest {
         assertThat(failedThreads.metricValue(), is(shouldFail ? 1.0 : 0.0));
     }
 
+    @Test
+    public void shouldCheckStateUpdater() {
+        final Properties streamsConfigProps = StreamsTestUtils.getStreamsConfig();
+        final ConsumerGroupMetadata consumerGroupMetadata = Mockito.mock(ConsumerGroupMetadata.class);
+        when(consumerGroupMetadata.groupInstanceId()).thenReturn(Optional.empty());
+        when(mainConsumer.poll(Mockito.any(Duration.class))).thenReturn(new ConsumerRecords<>(Collections.emptyMap()));
+        when(mainConsumer.groupMetadata()).thenReturn(consumerGroupMetadata);
+        final TaskManager taskManager = Mockito.mock(TaskManager.class);
+        streamsConfigProps.put(InternalConfig.STATE_UPDATER_ENABLED, true);
+        final TopologyMetadata topologyMetadata = new TopologyMetadata(internalTopologyBuilder, config);
+        topologyMetadata.buildAndRewriteTopology();
+        final StreamThread streamThread = new StreamThread(
+            mockTime,
+            new StreamsConfig(streamsConfigProps.entrySet().stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))),
+            null,
+            mainConsumer,
+            null,
+            changelogReader,
+            "",
+            taskManager,
+            new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime),
+            topologyMetadata,
+            "thread-id",
+            new LogContext(),
+            null,
+            null,
+            new LinkedList<>(),
+            null,
+            null,
+            null
+        );
+        streamThread.setState(State.STARTING);
+
+        streamThread.runOnce();
+
+        Mockito.verify(taskManager).checkStateUpdater(Mockito.anyLong(), Mockito.any());
+        Mockito.verify(taskManager).process(Mockito.anyInt(), Mockito.any());
+    }
+
     private TaskManager mockTaskManagerPurge(final int numberOfPurges) {
         final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class);
         final Task runningTask = mock(Task.class);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 15ffbe7a23..5335d4dbbe 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -121,8 +121,8 @@ import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anyMap;
 import static org.mockito.ArgumentMatchers.argThat;
-import static org.mockito.ArgumentMatchers.isNull;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.never;
@@ -280,25 +280,17 @@ public class TaskManagerTest {
 
     @Test
     public void shouldAddTasksToStateUpdater() {
-        final StreamTask task00 = statefulTask(taskId00, taskId00Partitions)
+        final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions)
             .withInputPartitions(taskId00Partitions)
-            .inState(State.RESTORING)
-            .build();
-        final StandbyTask task01 = standbyTask(taskId01, taskId01Partitions)
+            .inState(State.RESTORING).build();
+        final StandbyTask task01 = standbyTask(taskId01, taskId01ChangelogPartitions)
             .withInputPartitions(taskId01Partitions)
-            .inState(State.RUNNING)
-            .build();
-        expect(changeLogReader.completedChangelogs()).andReturn(emptySet()).anyTimes();
-        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
-        consumer.resume(anyObject());
-        expectLastCall().anyTimes();
-        expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00));
-        expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andStubReturn(singletonList(task01));
-        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
+            .inState(State.RUNNING).build();
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.drainPendingTaskToInit()).thenReturn(mkSet(task00, task01));
+        taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, tasks, true);
 
-        taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, true);
-        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter -> { });
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
 
         Mockito.verify(task00).initializeIfNeeded();
         Mockito.verify(task01).initializeIfNeeded();
@@ -308,38 +300,34 @@ public class TaskManagerTest {
 
     @Test
     public void shouldHandleRemovedTasksToRecycleFromStateUpdater() {
-        final StreamTask task00 = statefulTask(taskId00, taskId00Partitions)
+        final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions)
             .withInputPartitions(taskId00Partitions)
-            .inState(State.RESTORING)
-            .build();
-        final StandbyTask task01 = standbyTask(taskId01, taskId01Partitions)
+            .inState(State.RESTORING).build();
+        final StandbyTask task01 = standbyTask(taskId01, taskId01ChangelogPartitions)
             .withInputPartitions(taskId01Partitions)
-            .inState(State.RUNNING)
-            .build();
+            .inState(State.RUNNING).build();
         final StandbyTask task00Converted = standbyTask(taskId00, taskId00Partitions)
-            .withInputPartitions(taskId00Partitions)
-            .build();
+            .withInputPartitions(taskId00Partitions).build();
         final StreamTask task01Converted = statefulTask(taskId01, taskId01Partitions)
-            .withInputPartitions(taskId01Partitions)
-            .build();
+            .withInputPartitions(taskId01Partitions).build();
         when(stateUpdater.drainRemovedTasks()).thenReturn(mkSet(task00, task01));
-
-        taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, true);
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.removePendingTaskToRecycle(task00.id())).thenReturn(taskId00Partitions);
+        when(tasks.removePendingTaskToRecycle(task01.id())).thenReturn(taskId01Partitions);
+        taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, tasks, true);
         expect(activeTaskCreator.createActiveTaskFromStandby(eq(task01), eq(taskId01Partitions), eq(consumer)))
             .andStubReturn(task01Converted);
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(anyObject());
         expectLastCall().once();
         expect(standbyTaskCreator.createStandbyTaskFromActive(eq(task00), eq(taskId00Partitions)))
             .andStubReturn(task00Converted);
-        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
-        consumer.resume(anyObject());
-        expectLastCall().anyTimes();
-        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);
+        replay(activeTaskCreator, standbyTaskCreator);
 
-        taskManager.tasks().addPendingTaskToRecycle(taskId00, taskId00Partitions);
-        taskManager.tasks().addPendingTaskToRecycle(taskId01, taskId01Partitions);
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter -> { });
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
 
+        verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(task00).suspend();
+        Mockito.verify(task01).suspend();
         Mockito.verify(task00Converted).initializeIfNeeded();
         Mockito.verify(task01Converted).initializeIfNeeded();
         Mockito.verify(stateUpdater).add(task00Converted);
@@ -348,29 +336,25 @@ public class TaskManagerTest {
 
     @Test
     public void shouldHandleRemovedTasksToCloseFromStateUpdater() {
-        final StreamTask task00 = statefulTask(taskId00, taskId00Partitions)
+        final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions)
             .withInputPartitions(taskId00Partitions)
-            .inState(State.RESTORING)
-            .build();
-        final StandbyTask task01 = standbyTask(taskId01, taskId01Partitions)
+            .inState(State.RESTORING).build();
+        final StandbyTask task01 = standbyTask(taskId01, taskId01ChangelogPartitions)
             .withInputPartitions(taskId01Partitions)
-            .inState(State.RUNNING)
-            .build();
+            .inState(State.RUNNING).build();
         when(stateUpdater.drainRemovedTasks()).thenReturn(mkSet(task00, task01));
-
-        taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, true);
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.removePendingTaskToRecycle(any())).thenReturn(null);
+        when(tasks.removePendingTaskToCloseClean(task00.id())).thenReturn(true);
+        when(tasks.removePendingTaskToCloseClean(task01.id())).thenReturn(true);
+        taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, tasks, true);
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(anyObject());
         expectLastCall().once();
-        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
-        consumer.resume(anyObject());
-        expectLastCall().anyTimes();
-        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);
-
-        taskManager.tasks().addPendingTaskToCloseClean(taskId00);
-        taskManager.tasks().addPendingTaskToCloseClean(taskId01);
+        replay(activeTaskCreator);
 
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter -> { });
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
 
+        verify(activeTaskCreator);
         Mockito.verify(task00).suspend();
         Mockito.verify(task00).closeClean();
         Mockito.verify(task01).suspend();
@@ -379,30 +363,29 @@ public class TaskManagerTest {
 
     @Test
     public void shouldHandleRemovedTasksToUpdateInputPartitionsFromStateUpdater() {
-        final StreamTask task00 = statefulTask(taskId00, taskId00Partitions)
+        final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions)
             .withInputPartitions(taskId00Partitions)
-            .inState(State.RESTORING)
-            .build();
-        final StandbyTask task01 = standbyTask(taskId01, taskId01Partitions)
+            .inState(State.RESTORING).build();
+        final StandbyTask task01 = standbyTask(taskId01, taskId01ChangelogPartitions)
             .withInputPartitions(taskId01Partitions)
-            .inState(State.RUNNING)
-            .build();
+            .inState(State.RUNNING).build();
         when(stateUpdater.drainRemovedTasks()).thenReturn(mkSet(task00, task01));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.removePendingTaskToRecycle(any())).thenReturn(null);
+        when(tasks.removePendingTaskToUpdateInputPartitions(task00.id())).thenReturn(taskId02Partitions);
+        when(tasks.removePendingTaskToUpdateInputPartitions(task01.id())).thenReturn(taskId03Partitions);
+        taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        replay(topologyBuilder);
 
-        taskManager = setUpTaskManager(StreamsConfigUtils.ProcessingMode.AT_LEAST_ONCE, true);
-        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
-        consumer.resume(anyObject());
-        expectLastCall().anyTimes();
-        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);
-
-        taskManager.tasks().addPendingTaskToUpdateInputPartitions(taskId00, taskId02Partitions);
-        taskManager.tasks().addPendingTaskToUpdateInputPartitions(taskId01, taskId03Partitions);
-
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter -> { });
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
 
-        Mockito.verify(task00).updateInputPartitions(taskId02Partitions, emptyMap());
+        Mockito.verify(task00).updateInputPartitions(Mockito.eq(taskId02Partitions), anyMap());
+        Mockito.verify(task00, never()).closeDirty();
+        Mockito.verify(task00, never()).closeClean();
         Mockito.verify(stateUpdater).add(task00);
-        Mockito.verify(task01).updateInputPartitions(taskId03Partitions, emptyMap());
+        Mockito.verify(task01).updateInputPartitions(Mockito.eq(taskId03Partitions), anyMap());
+        Mockito.verify(task01, never()).closeDirty();
+        Mockito.verify(task01, never()).closeClean();
         Mockito.verify(stateUpdater).add(task01);
     }
 
@@ -411,15 +394,14 @@ public class TaskManagerTest {
         final StreamTask task = statefulTask(taskId00, taskId00ChangelogPartitions)
             .inState(State.RESTORING)
             .withInputPartitions(taskId00Partitions).build();
-        final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), mkSet(task));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), tasks);
+        when(stateUpdater.getTasks()).thenReturn(mkSet(task));
 
-        taskManager.handleRevocation(taskId00Partitions);
+        taskManager.handleRevocation(task.inputPartitions());
 
+        Mockito.verify(tasks).addPendingTaskToCloseClean(task.id());
         Mockito.verify(stateUpdater).remove(task.id());
-
-        taskManager.tryToCompleteRestoration(time.milliseconds(), null);
-
-        Mockito.verify(task).closeClean();
     }
 
     public void shouldRemoveMultipleStatefulTaskWithRevokedInputPartitionsFromStateUpdaterOnRevocation() {
@@ -429,17 +411,15 @@ public class TaskManagerTest {
         final StreamTask task2 = statefulTask(taskId01, taskId01ChangelogPartitions)
             .inState(State.RESTORING)
             .withInputPartitions(taskId01Partitions).build();
-        final TaskManager taskManager = setupForRevocationAndLost(mkSet(task1, task2), mkSet(task1, task2));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = setupForRevocationAndLost(mkSet(task1, task2), tasks);
 
         taskManager.handleRevocation(union(HashSet::new, taskId00Partitions, taskId01Partitions));
 
+        Mockito.verify(tasks).addPendingTaskToCloseClean(task1.id());
+        Mockito.verify(tasks).addPendingTaskToCloseClean(task2.id());
         Mockito.verify(stateUpdater).remove(task1.id());
         Mockito.verify(stateUpdater).remove(task2.id());
-
-        taskManager.tryToCompleteRestoration(time.milliseconds(), null);
-
-        Mockito.verify(task1).closeClean();
-        Mockito.verify(task2).closeClean();
     }
 
     @Test
@@ -447,15 +427,13 @@ public class TaskManagerTest {
         final StreamTask task = statefulTask(taskId00, taskId00ChangelogPartitions)
             .inState(State.RESTORING)
             .withInputPartitions(taskId00Partitions).build();
-        final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), Collections.emptySet());
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), tasks);
 
         taskManager.handleRevocation(taskId01Partitions);
 
         Mockito.verify(stateUpdater, never()).remove(task.id());
-
-        taskManager.tryToCompleteRestoration(time.milliseconds(), null);
-
-        Mockito.verify(task, never()).closeClean();
+        Mockito.verify(tasks, never()).addPendingTaskToCloseClean(task.id());
     }
 
     @Test
@@ -463,15 +441,13 @@ public class TaskManagerTest {
         final StandbyTask task = standbyTask(taskId00, taskId00ChangelogPartitions)
             .inState(State.RESTORING)
             .withInputPartitions(taskId00Partitions).build();
-        final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), Collections.emptySet());
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = setupForRevocationAndLost(mkSet(task), tasks);
 
         taskManager.handleRevocation(taskId00Partitions);
 
         Mockito.verify(stateUpdater, never()).remove(task.id());
-
-        taskManager.tryToCompleteRestoration(time.milliseconds(), null);
-
-        Mockito.verify(task, never()).closeClean();
+        Mockito.verify(tasks, never()).addPendingTaskToCloseClean(task.id());
     }
 
     @Test
@@ -485,102 +461,78 @@ public class TaskManagerTest {
         final StreamTask task3 = statefulTask(taskId02, taskId02ChangelogPartitions)
             .inState(State.RESTORING)
             .withInputPartitions(taskId02Partitions).build();
-        final TaskManager taskManager = setupForRevocationAndLost(mkSet(task1, task2, task3), mkSet(task1, task3));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = setupForRevocationAndLost(mkSet(task1, task2, task3), tasks);
 
         taskManager.handleLostAll();
 
         Mockito.verify(stateUpdater).remove(task1.id());
+        Mockito.verify(stateUpdater, never()).remove(task2.id());
         Mockito.verify(stateUpdater).remove(task3.id());
-
-        taskManager.tryToCompleteRestoration(time.milliseconds(), null);
-
-        Mockito.verify(task1).closeDirty();
-        Mockito.verify(task3).closeDirty();
-        Mockito.verify(task2, never()).closeDirty();
-        Mockito.verify(task2, never()).closeClean();
+        Mockito.verify(tasks).addPendingTaskToCloseDirty(task1.id());
+        Mockito.verify(tasks, never()).addPendingTaskToCloseDirty(task2.id());
+        Mockito.verify(tasks, never()).addPendingTaskToCloseClean(task2.id());
+        Mockito.verify(tasks).addPendingTaskToCloseDirty(task3.id());
     }
 
     private TaskManager setupForRevocationAndLost(final Set<Task> tasksInStateUpdater,
-                                                  final Set<Task> removedTasks) {
-        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, true);
+                                                  final TasksRegistry tasks) {
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
         when(stateUpdater.getTasks()).thenReturn(tasksInStateUpdater);
-        when(stateUpdater.drainRemovedTasks()).thenReturn(removedTasks);
-        expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
-        consumer.resume(anyObject());
-        expectLastCall().anyTimes();
-        replay(consumer);
 
         return taskManager;
     }
 
     @Test
     public void shouldHandleRemovedTasksFromStateUpdater() {
-        // tasks to recycle
-        final StreamTask task00 = mock(StreamTask.class);
-        final StandbyTask task01 = mock(StandbyTask.class);
-        final StandbyTask task00Converted = mock(StandbyTask.class);
-        final StreamTask task01Converted = mock(StreamTask.class);
-        // task to close
-        final StreamTask task02 = mock(StreamTask.class);
-        // task to update inputs
-        final StreamTask task03 = mock(StreamTask.class);
-        when(task00.id()).thenReturn(taskId00);
-        when(task01.id()).thenReturn(taskId01);
-        when(task02.id()).thenReturn(taskId02);
-        when(task03.id()).thenReturn(taskId03);
-        when(task00.inputPartitions()).thenReturn(taskId00Partitions);
-        when(task01.inputPartitions()).thenReturn(taskId01Partitions);
-        when(task02.inputPartitions()).thenReturn(taskId02Partitions);
-        when(task03.inputPartitions()).thenReturn(taskId03Partitions);
-        when(task00.isActive()).thenReturn(true);
-        when(task01.isActive()).thenReturn(false);
-        when(task02.isActive()).thenReturn(true);
-        when(task03.isActive()).thenReturn(true);
-        when(task00.state()).thenReturn(State.RESTORING);
-        when(task01.state()).thenReturn(State.RUNNING);
-        when(task02.state()).thenReturn(State.RESTORING);
-        when(task03.state()).thenReturn(State.RESTORING);
-        when(stateUpdater.drainRemovedTasks()).thenReturn(mkSet(task00, task01, task02, task03));
-
-        expect(activeTaskCreator.createActiveTaskFromStandby(eq(task01), eq(taskId01Partitions), eq(consumer)))
-            .andStubReturn(task01Converted);
+        final StreamTask taskToRecycle0 = statefulTask(taskId00, taskId00ChangelogPartitions)
+            .inState(State.RESTORING)
+            .withInputPartitions(taskId00Partitions).build();
+        final StandbyTask taskToRecycle1 = standbyTask(taskId01, taskId01ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId01Partitions).build();
+        final StandbyTask convertedTask0 = standbyTask(taskId00, taskId00ChangelogPartitions).build();
+        final StreamTask convertedTask1 = statefulTask(taskId01, taskId01ChangelogPartitions).build();
+        final StreamTask taskToClose = statefulTask(taskId02, taskId02ChangelogPartitions)
+            .inState(State.RESTORING)
+            .withInputPartitions(taskId02Partitions).build();
+        final StreamTask taskToUpdateInputPartitions = statefulTask(taskId03, taskId03ChangelogPartitions)
+            .inState(State.RESTORING)
+            .withInputPartitions(taskId03Partitions).build();
+        when(stateUpdater.drainRemovedTasks())
+            .thenReturn(mkSet(taskToRecycle0, taskToRecycle1, taskToClose, taskToUpdateInputPartitions));
+        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
+        expect(activeTaskCreator.createActiveTaskFromStandby(eq(taskToRecycle1), eq(taskId01Partitions), eq(consumer)))
+            .andStubReturn(convertedTask1);
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(anyObject());
         expectLastCall().times(2);
-        expect(standbyTaskCreator.createStandbyTaskFromActive(eq(task00), eq(taskId00Partitions)))
-            .andStubReturn(task00Converted);
+        expect(standbyTaskCreator.createStandbyTaskFromActive(eq(taskToRecycle0), eq(taskId00Partitions)))
+            .andStubReturn(convertedTask0);
         expect(consumer.assignment()).andReturn(emptySet()).anyTimes();
         consumer.resume(anyObject());
         expectLastCall().anyTimes();
-        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);
-
-        taskManager = new TaskManager(
-            time,
-            changeLogReader,
-            UUID.randomUUID(),
-            "taskManagerTest",
-            activeTaskCreator,
-            standbyTaskCreator,
-            new Tasks(new LogContext()),
-            topologyMetadata,
-            adminClient,
-            stateDirectory,
-            stateUpdater
-        );
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.removePendingTaskToCloseClean(taskToClose.id())).thenReturn(true);
+        when(tasks.removePendingTaskToRecycle(taskToRecycle0.id())).thenReturn(taskId00Partitions);
+        when(tasks.removePendingTaskToRecycle(taskToRecycle1.id())).thenReturn(taskId01Partitions);
+        when(tasks.removePendingTaskToRecycle(
+            argThat(taskId -> !taskId.equals(taskToRecycle0.id()) && !taskId.equals(taskToRecycle1.id())))
+        ).thenReturn(null);
+        when(tasks.removePendingTaskToUpdateInputPartitions(taskToUpdateInputPartitions.id())).thenReturn(taskId04Partitions);
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
         taskManager.setMainConsumer(consumer);
-        taskManager.tasks().addPendingTaskToCloseClean(taskId02);
-        taskManager.tasks().addPendingTaskToRecycle(taskId00, taskId00Partitions);
-        taskManager.tasks().addPendingTaskToRecycle(taskId01, taskId01Partitions);
-        taskManager.tasks().addPendingTaskToUpdateInputPartitions(taskId03, taskId03Partitions);
+        replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);
 
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter -> { });
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter -> { });
 
-        Mockito.verify(task00Converted).initializeIfNeeded();
-        Mockito.verify(task01Converted).initializeIfNeeded();
-        Mockito.verify(stateUpdater).add(task00Converted);
-        Mockito.verify(stateUpdater).add(task01Converted);
-        Mockito.verify(task02).closeClean();
-        Mockito.verify(task03).updateInputPartitions(taskId03Partitions, emptyMap());
-        Mockito.verify(stateUpdater).add(task03);
+        verify(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer);
+        Mockito.verify(convertedTask0).initializeIfNeeded();
+        Mockito.verify(convertedTask1).initializeIfNeeded();
+        Mockito.verify(stateUpdater).add(convertedTask0);
+        Mockito.verify(stateUpdater).add(convertedTask1);
+        Mockito.verify(taskToClose).closeClean();
+        Mockito.verify(taskToUpdateInputPartitions).updateInputPartitions(Mockito.eq(taskId04Partitions), anyMap());
+        Mockito.verify(stateUpdater).add(taskToUpdateInputPartitions);
     }
 
     @Test
@@ -593,7 +545,7 @@ public class TaskManagerTest {
         consumer.resume(task.inputPartitions());
         replay(consumer);
 
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter);
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         Mockito.verify(task).completeRestoration(noOpResetter);
         Mockito.verify(task).clearTaskTimeout();
@@ -612,7 +564,7 @@ public class TaskManagerTest {
         doThrow(timeoutException).when(task).completeRestoration(noOpResetter);
         replay(consumer);
 
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter);
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         Mockito.verify(task).maybeInitTaskTimeoutOrThrow(anyLong(), Mockito.eq(timeoutException));
         Mockito.verify(tasks, never()).addTask(task);
@@ -624,6 +576,7 @@ public class TaskManagerTest {
                                                                final TasksRegistry tasks) {
         when(tasks.removePendingTaskToRecycle(statefulTask.id())).thenReturn(null);
         when(tasks.removePendingTaskToUpdateInputPartitions(statefulTask.id())).thenReturn(null);
+        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
         when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(statefulTask));
 
         return setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
@@ -632,10 +585,12 @@ public class TaskManagerTest {
     @Test
     public void shouldReturnCorrectBooleanWhenTryingToCompleteRestorationWithStateUpdater() {
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, true);
-        when(stateUpdater.restoresActiveTasks()).thenReturn(false).thenReturn(true);
 
-        assertTrue(taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter));
-        assertFalse(taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter));
+        when(stateUpdater.restoresActiveTasks()).thenReturn(false);
+        assertTrue(taskManager.checkStateUpdater(time.milliseconds(), noOpResetter));
+
+        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
+        assertFalse(taskManager.checkStateUpdater(time.milliseconds(), noOpResetter));
     }
 
     @Test
@@ -652,9 +607,10 @@ public class TaskManagerTest {
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(statefulTask.id());
         replay(activeTaskCreator, standbyTaskCreator);
 
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter);
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         verify(activeTaskCreator, standbyTaskCreator);
+        Mockito.verify(statefulTask).suspend();
         Mockito.verify(standbyTask).initializeIfNeeded();
         Mockito.verify(stateUpdater).add(standbyTask);
     }
@@ -671,7 +627,7 @@ public class TaskManagerTest {
 
         assertThrows(
             StreamsException.class,
-            () -> taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter)
+            () -> taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)
         );
 
         verify(standbyTaskCreator);
@@ -695,7 +651,7 @@ public class TaskManagerTest {
 
         assertThrows(
             StreamsException.class,
-            () -> taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter)
+            () -> taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)
         );
 
         verify(standbyTaskCreator);
@@ -706,6 +662,7 @@ public class TaskManagerTest {
     private TaskManager setUpRecycleRestoredTask(final StreamTask statefulTask) {
         final TasksRegistry tasks = mock(TasksRegistry.class);
         when(tasks.removePendingTaskToRecycle(statefulTask.id())).thenReturn(taskId00Partitions);
+        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
         when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(statefulTask));
 
         return setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
@@ -721,7 +678,7 @@ public class TaskManagerTest {
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(statefulTask.id());
         replay(activeTaskCreator);
 
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter);
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         verify(activeTaskCreator);
         Mockito.verify(statefulTask).suspend();
@@ -743,7 +700,7 @@ public class TaskManagerTest {
 
         assertThrows(
             RuntimeException.class,
-            () -> taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter)
+            () -> taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)
         );
 
         verify(activeTaskCreator);
@@ -764,7 +721,7 @@ public class TaskManagerTest {
 
         assertThrows(
             RuntimeException.class,
-            () -> taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter)
+            () -> taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)
         );
 
         verify(activeTaskCreator);
@@ -776,6 +733,7 @@ public class TaskManagerTest {
                                                     final TasksRegistry tasks) {
         when(tasks.removePendingTaskToRecycle(statefulTask.id())).thenReturn(null);
         when(tasks.removePendingTaskToCloseClean(statefulTask.id())).thenReturn(true);
+        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
         when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(statefulTask));
 
         return setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
@@ -790,11 +748,12 @@ public class TaskManagerTest {
         when(tasks.removePendingTaskToRecycle(statefulTask.id())).thenReturn(null);
         when(tasks.removePendingTaskToCloseDirty(statefulTask.id())).thenReturn(true);
         when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(statefulTask));
+        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(statefulTask.id());
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
         replay(activeTaskCreator);
 
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter);
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         verify(activeTaskCreator);
         Mockito.verify(statefulTask).prepareCommit();
@@ -813,11 +772,13 @@ public class TaskManagerTest {
         when(tasks.removePendingTaskToRecycle(statefulTask.id())).thenReturn(null);
         when(tasks.removePendingTaskToUpdateInputPartitions(statefulTask.id())).thenReturn(taskId01Partitions);
         when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(statefulTask));
+        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
         final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, true);
+        replay(topologyBuilder);
 
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter);
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
-        Mockito.verify(statefulTask).updateInputPartitions(Mockito.eq(taskId01Partitions), isNull());
+        Mockito.verify(statefulTask).updateInputPartitions(Mockito.eq(taskId01Partitions), anyMap());
         Mockito.verify(statefulTask, never()).closeDirty();
         Mockito.verify(statefulTask, never()).closeClean();
     }
@@ -862,6 +823,7 @@ public class TaskManagerTest {
         when(tasks.removePendingTaskToUpdateInputPartitions(
             argThat(taskId -> !taskId.equals(taskToUpdateInputPartitions.id())))
         ).thenReturn(null);
+        when(stateUpdater.restoresActiveTasks()).thenReturn(true);
         when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(mkSet(
             taskToTransitToRunning,
             taskToRecycle,
@@ -869,16 +831,16 @@ public class TaskManagerTest {
             taskToCloseDirty,
             taskToUpdateInputPartitions
         ));
-        replay(standbyTaskCreator);
+        replay(standbyTaskCreator, topologyBuilder);
 
-        taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter);
+        taskManager.checkStateUpdater(time.milliseconds(), noOpResetter);
 
         Mockito.verify(tasks).addTask(taskToTransitToRunning);
         Mockito.verify(stateUpdater).add(recycledStandbyTask);
         Mockito.verify(stateUpdater).add(recycledStandbyTask);
         Mockito.verify(taskToCloseClean).closeClean();
         Mockito.verify(taskToCloseDirty).closeDirty();
-        Mockito.verify(taskToUpdateInputPartitions).updateInputPartitions(Mockito.eq(taskId05Partitions), isNull());
+        Mockito.verify(taskToUpdateInputPartitions).updateInputPartitions(Mockito.eq(taskId05Partitions), anyMap());
     }
 
     @Test
@@ -898,7 +860,7 @@ public class TaskManagerTest {
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
-            () -> taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter)
+            () -> taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)
         );
 
         assertEquals(exception, thrown);
@@ -922,7 +884,7 @@ public class TaskManagerTest {
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
-            () -> taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter)
+            () -> taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)
         );
 
         assertEquals(exception, thrown.getCause());
@@ -953,7 +915,7 @@ public class TaskManagerTest {
 
         final TaskCorruptedException thrown = assertThrows(
             TaskCorruptedException.class,
-            () -> taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter)
+            () -> taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)
         );
 
         assertEquals(mkSet(taskId00, taskId01), thrown.corruptedTasks());
@@ -2449,6 +2411,7 @@ public class TaskManagerTest {
         };
 
         resetToStrict(changeLogReader);
+        changeLogReader.enforceRestoreActive();
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment)))
             .andStubReturn(asList(task00, task01, task02, task03));
@@ -2520,6 +2483,7 @@ public class TaskManagerTest {
         task00.setCommittableOffsetsAndMetadata(offsets);
 
         resetToStrict(changeLogReader);
+        changeLogReader.enforceRestoreActive();
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(singletonList(task00));
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
@@ -2570,6 +2534,7 @@ public class TaskManagerTest {
         };
 
         resetToStrict(changeLogReader);
+        changeLogReader.enforceRestoreActive();
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(singletonList(task00));
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
@@ -2698,6 +2663,7 @@ public class TaskManagerTest {
         };
 
         resetToStrict(changeLogReader);
+        changeLogReader.enforceRestoreActive();
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(asList(task00, task01, task02));
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(anyObject());