You are viewing a plain text version of this content. The canonical link for it is here.
Posted to jira@kafka.apache.org by GitBox <gi...@apache.org> on 2020/09/16 01:50:10 UTC

[GitHub] [kafka] ableegoldman commented on a change in pull request #8988: KAFKA-10199: Separate restore threads

ableegoldman commented on a change in pull request #8988:
URL: https://github.com/apache/kafka/pull/8988#discussion_r488999403



##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
##########
@@ -65,10 +65,6 @@ public ProcessorContextImpl(final TaskId id,
 
     @Override
     public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) {
-        if (stateManager.taskType() != TaskType.ACTIVE) {

Review comment:
       If this is because we don't transition the state manager's type until later, maybe we should assert that `stateManager.taskType() == TaskType.STANDBY` instead of removing the check altogether (and vice versa down below).

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreThread.java
##########
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StateRestoreListener;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+
+/**
+ * This is the thread responsible for restoring state stores for both active and standby tasks
+ */
+public class StateRestoreThread extends Thread {
+
+    private final Time time;
+    private final Logger log;
+    private final ChangelogReader changelogReader;
+    private final AtomicBoolean isRunning = new AtomicBoolean(true);
+    private final CountDownLatch shutdownLatch = new CountDownLatch(1);
+    private final LinkedBlockingDeque<TaskItem> taskItemQueue;
+    private final AtomicReference<Set<TopicPartition>> completedChangelogs;
+    private final LinkedBlockingDeque<TaskCorruptedException> corruptedExceptions;
+
+    public boolean isRunning() {
+        return isRunning.get();
+    }
+
+    public StateRestoreThread(final Time time,
+                              final StreamsConfig config,
+                              final String threadClientId,
+                              final Admin adminClient,
+                              final String groupId,
+                              final Consumer<byte[], byte[]> restoreConsumer,
+                              final StateRestoreListener userStateRestoreListener) {
+        this(time, threadClientId, new StoreChangelogReader(time, config, threadClientId,
+                adminClient, groupId, restoreConsumer, userStateRestoreListener));
+    }
+
+    // for testing only
+    public StateRestoreThread(final Time time,
+                              final String threadClientId,
+                              final ChangelogReader changelogReader) {
+        super(threadClientId);
+
+        final String logPrefix = String.format("state-restore-thread [%s] ", threadClientId);
+        final LogContext logContext = new LogContext(logPrefix);
+
+        this.time = time;
+        this.log = logContext.logger(getClass());
+        this.taskItemQueue = new LinkedBlockingDeque<>();
+        this.corruptedExceptions = new LinkedBlockingDeque<>();
+        this.completedChangelogs = new AtomicReference<>(Collections.emptySet());
+
+        this.changelogReader = changelogReader;
+    }
+
+    private synchronized void waitIfAllChangelogsCompleted() {
+        final Set<TopicPartition> allChangelogs = changelogReader.allChangelogs();
+        if (allChangelogs.equals(changelogReader.completedChangelogs())) {
+            log.debug("All changelogs {} have completed restoration so far, will wait " +
+                    "until new changelogs are registered", allChangelogs);
+
+            while (isRunning.get() && taskItemQueue.isEmpty()) {
+                try {
+                    wait();
+                } catch (final InterruptedException e) {
+                    // do nothing
+                }
+            }
+        }
+    }
+
+    public synchronized void addInitializedTasks(final List<AbstractTask> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final AbstractTask task: tasks) {
+                taskItemQueue.add(new TaskItem(task, ItemType.CREATE, task.changelogPartitions()));
+            }
+            notifyAll();
+        }
+    }
+
+    public synchronized void addClosedTasks(final Map<AbstractTask, Collection<TopicPartition>> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final Map.Entry<AbstractTask, Collection<TopicPartition>> entry : tasks.entrySet()) {
+                taskItemQueue.add(new TaskItem(entry.getKey(), ItemType.CLOSE, entry.getValue()));
+            }
+            notifyAll();
+        }
+    }
+
+    public Set<TopicPartition> completedChangelogs() {
+        return completedChangelogs.get();
+    }
+
+    @Override
+    public void run() {
+        try {
+            while (isRunning()) {
+                runOnce();
+            }
+        } catch (final Exception e) {
+            log.error("Encountered the following exception while restoring states " +
+                    "and the thread is going to shut down: ", e);
+            throw e;
+        } finally {
+            try {
+                changelogReader.clear();
+            } catch (final Throwable e) {
+                log.error("Failed to close changelog reader due to the following error:", e);
+            }
+
+            shutdownLatch.countDown();
+        }
+    }
+
+    // Visible for testing
+    void runOnce() {
+        waitIfAllChangelogsCompleted();
+
+        if (!isRunning.get())
+            return;
+
+        // a task being recycled maybe in both closed and initialized tasks,
+        // and hence we should process the closed ones first and then initialized ones
+        final List<TaskItem> items = new ArrayList<>();
+        taskItemQueue.drainTo(items);
+
+        if (!items.isEmpty()) {
+            for (final TaskItem item : items) {
+                // TODO: we should consider also call the listener if the
+                //       changelog is not yet completed
+                if (item.type == ItemType.CLOSE) {
+                    changelogReader.unregister(item.changelogPartitions);
+
+                    log.info("Unregistered changelogs {} for closing task {}",
+                            item.task.changelogPartitions(),
+                            item.task.id());
+                } else if (item.type == ItemType.CREATE) {
+                    // we should only convert the state manager type right StateRestoreThreadTest.javabefore re-registering the changelog

Review comment:
       Looks like something extra slipped into this comment: the file name `StateRestoreThreadTest.java` got pasted into the middle of the sentence.

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreThread.java
##########
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StateRestoreListener;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+
+/**
+ * This is the thread responsible for restoring state stores for both active and standby tasks
+ */
+public class StateRestoreThread extends Thread {
+
+    private final Time time;
+    private final Logger log;
+    private final ChangelogReader changelogReader;
+    private final AtomicBoolean isRunning = new AtomicBoolean(true);
+    private final CountDownLatch shutdownLatch = new CountDownLatch(1);
+    private final LinkedBlockingDeque<TaskItem> taskItemQueue;
+    private final AtomicReference<Set<TopicPartition>> completedChangelogs;
+    private final LinkedBlockingDeque<TaskCorruptedException> corruptedExceptions;
+
+    public boolean isRunning() {
+        return isRunning.get();
+    }
+
+    public StateRestoreThread(final Time time,
+                              final StreamsConfig config,
+                              final String threadClientId,
+                              final Admin adminClient,
+                              final String groupId,
+                              final Consumer<byte[], byte[]> restoreConsumer,
+                              final StateRestoreListener userStateRestoreListener) {
+        this(time, threadClientId, new StoreChangelogReader(time, config, threadClientId,
+                adminClient, groupId, restoreConsumer, userStateRestoreListener));
+    }
+
+    // for testing only
+    public StateRestoreThread(final Time time,
+                              final String threadClientId,
+                              final ChangelogReader changelogReader) {
+        super(threadClientId);
+
+        final String logPrefix = String.format("state-restore-thread [%s] ", threadClientId);
+        final LogContext logContext = new LogContext(logPrefix);
+
+        this.time = time;
+        this.log = logContext.logger(getClass());
+        this.taskItemQueue = new LinkedBlockingDeque<>();
+        this.corruptedExceptions = new LinkedBlockingDeque<>();
+        this.completedChangelogs = new AtomicReference<>(Collections.emptySet());
+
+        this.changelogReader = changelogReader;
+    }
+
+    private synchronized void waitIfAllChangelogsCompleted() {
+        final Set<TopicPartition> allChangelogs = changelogReader.allChangelogs();
+        if (allChangelogs.equals(changelogReader.completedChangelogs())) {
+            log.debug("All changelogs {} have completed restoration so far, will wait " +
+                    "until new changelogs are registered", allChangelogs);
+
+            while (isRunning.get() && taskItemQueue.isEmpty()) {
+                try {
+                    wait();
+                } catch (final InterruptedException e) {
+                    // do nothing
+                }
+            }
+        }
+    }
+
+    public synchronized void addInitializedTasks(final List<AbstractTask> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final AbstractTask task: tasks) {
+                taskItemQueue.add(new TaskItem(task, ItemType.CREATE, task.changelogPartitions()));
+            }
+            notifyAll();
+        }
+    }
+
+    public synchronized void addClosedTasks(final Map<AbstractTask, Collection<TopicPartition>> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final Map.Entry<AbstractTask, Collection<TopicPartition>> entry : tasks.entrySet()) {
+                taskItemQueue.add(new TaskItem(entry.getKey(), ItemType.CLOSE, entry.getValue()));
+            }
+            notifyAll();
+        }
+    }
+
+    public Set<TopicPartition> completedChangelogs() {
+        return completedChangelogs.get();
+    }
+
+    @Override
+    public void run() {
+        try {
+            while (isRunning()) {
+                runOnce();
+            }
+        } catch (final Exception e) {
+            log.error("Encountered the following exception while restoring states " +
+                    "and the thread is going to shut down: ", e);
+            throw e;
+        } finally {
+            try {
+                changelogReader.clear();
+            } catch (final Throwable e) {
+                log.error("Failed to close changelog reader due to the following error:", e);
+            }
+
+            shutdownLatch.countDown();
+        }
+    }
+
+    // Visible for testing
+    void runOnce() {
+        waitIfAllChangelogsCompleted();
+
+        if (!isRunning.get())
+            return;
+
+        // a task being recycled maybe in both closed and initialized tasks,
+        // and hence we should process the closed ones first and then initialized ones
+        final List<TaskItem> items = new ArrayList<>();
+        taskItemQueue.drainTo(items);
+
+        if (!items.isEmpty()) {
+            for (final TaskItem item : items) {
+                // TODO: we should consider also call the listener if the

Review comment:
       Do you mean `StateRestoreListener#onRestoreEnd`? I agree it would be useful to provide a callback that is guaranteed to be invoked once for every invocation of `#onRestoreStart`. It's not immediately obvious that `#onRestoreEnd` doesn't provide that guarantee today, and it has been requested before. Maybe we should file a ticket to consider this?

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreThread.java
##########
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StateRestoreListener;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+
+/**
+ * This is the thread responsible for restoring state stores for both active and standby tasks
+ */
+public class StateRestoreThread extends Thread {
+
+    private final Time time;
+    private final Logger log;
+    private final ChangelogReader changelogReader;
+    private final AtomicBoolean isRunning = new AtomicBoolean(true);
+    private final CountDownLatch shutdownLatch = new CountDownLatch(1);
+    private final LinkedBlockingDeque<TaskItem> taskItemQueue;
+    private final AtomicReference<Set<TopicPartition>> completedChangelogs;
+    private final LinkedBlockingDeque<TaskCorruptedException> corruptedExceptions;
+
+    public boolean isRunning() {
+        return isRunning.get();
+    }
+
+    public StateRestoreThread(final Time time,
+                              final StreamsConfig config,
+                              final String threadClientId,
+                              final Admin adminClient,
+                              final String groupId,
+                              final Consumer<byte[], byte[]> restoreConsumer,
+                              final StateRestoreListener userStateRestoreListener) {
+        this(time, threadClientId, new StoreChangelogReader(time, config, threadClientId,
+                adminClient, groupId, restoreConsumer, userStateRestoreListener));
+    }
+
+    // for testing only
+    public StateRestoreThread(final Time time,
+                              final String threadClientId,
+                              final ChangelogReader changelogReader) {
+        super(threadClientId);
+
+        final String logPrefix = String.format("state-restore-thread [%s] ", threadClientId);
+        final LogContext logContext = new LogContext(logPrefix);
+
+        this.time = time;
+        this.log = logContext.logger(getClass());
+        this.taskItemQueue = new LinkedBlockingDeque<>();
+        this.corruptedExceptions = new LinkedBlockingDeque<>();
+        this.completedChangelogs = new AtomicReference<>(Collections.emptySet());
+
+        this.changelogReader = changelogReader;
+    }
+
+    private synchronized void waitIfAllChangelogsCompleted() {
+        final Set<TopicPartition> allChangelogs = changelogReader.allChangelogs();
+        if (allChangelogs.equals(changelogReader.completedChangelogs())) {
+            log.debug("All changelogs {} have completed restoration so far, will wait " +
+                    "until new changelogs are registered", allChangelogs);
+
+            while (isRunning.get() && taskItemQueue.isEmpty()) {
+                try {
+                    wait();
+                } catch (final InterruptedException e) {
+                    // do nothing
+                }
+            }
+        }
+    }
+
+    public synchronized void addInitializedTasks(final List<AbstractTask> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final AbstractTask task: tasks) {
+                taskItemQueue.add(new TaskItem(task, ItemType.CREATE, task.changelogPartitions()));
+            }
+            notifyAll();
+        }
+    }
+
+    public synchronized void addClosedTasks(final Map<AbstractTask, Collection<TopicPartition>> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final Map.Entry<AbstractTask, Collection<TopicPartition>> entry : tasks.entrySet()) {
+                taskItemQueue.add(new TaskItem(entry.getKey(), ItemType.CLOSE, entry.getValue()));
+            }
+            notifyAll();
+        }
+    }
+
+    public Set<TopicPartition> completedChangelogs() {
+        return completedChangelogs.get();
+    }
+
+    @Override
+    public void run() {
+        try {
+            while (isRunning()) {
+                runOnce();
+            }
+        } catch (final Exception e) {
+            log.error("Encountered the following exception while restoring states " +
+                    "and the thread is going to shut down: ", e);
+            throw e;
+        } finally {
+            try {
+                changelogReader.clear();
+            } catch (final Throwable e) {
+                log.error("Failed to close changelog reader due to the following error:", e);
+            }
+
+            shutdownLatch.countDown();
+        }
+    }
+
+    // Visible for testing
+    void runOnce() {
+        waitIfAllChangelogsCompleted();
+
+        if (!isRunning.get())
+            return;
+
+        // a task being recycled maybe in both closed and initialized tasks,
+        // and hence we should process the closed ones first and then initialized ones
+        final List<TaskItem> items = new ArrayList<>();
+        taskItemQueue.drainTo(items);
+
+        if (!items.isEmpty()) {
+            for (final TaskItem item : items) {
+                // TODO: we should consider also call the listener if the
+                //       changelog is not yet completed
+                if (item.type == ItemType.CLOSE) {
+                    changelogReader.unregister(item.changelogPartitions);
+
+                    log.info("Unregistered changelogs {} for closing task {}",
+                            item.task.changelogPartitions(),
+                            item.task.id());
+                } else if (item.type == ItemType.CREATE) {
+                    // we should only convert the state manager type right StateRestoreThreadTest.javabefore re-registering the changelog
+                    item.task.stateMgr.maybeConvertToNewTaskType();

Review comment:
       nit: can we rename this to `maybeCompleteTaskTypeTransition` or `completeTaskTypeConversionIfNecessary`, etc.? Right now it sounds like we're randomly attempting to convert the task to a new type out of nowhere.

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
##########
@@ -623,16 +613,46 @@ void runOnce() {
             return;
         }
 
-        initializeAndRestorePhase();
+        // we need to first add closed tasks and then created tasks to work with those revived / recycled tasks
+        restoreThread.addClosedTasks(taskManager.drainRemovedTasks());
+
+        // try to initialize created tasks that are either newly assigned or re-created from corrupted tasks
+        final List<AbstractTask> initializedTasks;
+        if (!(initializedTasks = taskManager.tryInitializeNewTasks()).isEmpty()) {
+            if (log.isDebugEnabled()) {
+                log.debug("Initializing newly created tasks {} under state {}",
+                        initializedTasks.stream().map(AbstractTask::id).collect(Collectors.toList()), state);
+            }
+
+            restoreThread.addInitializedTasks(initializedTasks);
+        }
+
+        // try complete restoration if there are any restoring tasks

Review comment:
       nit:
   ```suggestion
           // try to complete restoration for any tasks that have finished restoring
   ```

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreThread.java
##########
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StateRestoreListener;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+
+/**
+ * This is the thread responsible for restoring state stores for both active and standby tasks
+ */
+public class StateRestoreThread extends Thread {
+
+    private final Time time;
+    private final Logger log;
+    private final ChangelogReader changelogReader;
+    private final AtomicBoolean isRunning = new AtomicBoolean(true);
+    private final CountDownLatch shutdownLatch = new CountDownLatch(1);
+    private final LinkedBlockingDeque<TaskItem> taskItemQueue;
+    private final AtomicReference<Set<TopicPartition>> completedChangelogs;
+    private final LinkedBlockingDeque<TaskCorruptedException> corruptedExceptions;
+
+    public boolean isRunning() {
+        return isRunning.get();
+    }
+
+    public StateRestoreThread(final Time time,
+                              final StreamsConfig config,
+                              final String threadClientId,
+                              final Admin adminClient,
+                              final String groupId,
+                              final Consumer<byte[], byte[]> restoreConsumer,
+                              final StateRestoreListener userStateRestoreListener) {
+        this(time, threadClientId, new StoreChangelogReader(time, config, threadClientId,
+                adminClient, groupId, restoreConsumer, userStateRestoreListener));
+    }
+
+    // for testing only
+    public StateRestoreThread(final Time time,
+                              final String threadClientId,
+                              final ChangelogReader changelogReader) {
+        super(threadClientId);
+
+        final String logPrefix = String.format("state-restore-thread [%s] ", threadClientId);
+        final LogContext logContext = new LogContext(logPrefix);
+
+        this.time = time;
+        this.log = logContext.logger(getClass());
+        this.taskItemQueue = new LinkedBlockingDeque<>();
+        this.corruptedExceptions = new LinkedBlockingDeque<>();
+        this.completedChangelogs = new AtomicReference<>(Collections.emptySet());
+
+        this.changelogReader = changelogReader;
+    }
+
+    private synchronized void waitIfAllChangelogsCompleted() {
+        final Set<TopicPartition> allChangelogs = changelogReader.allChangelogs();
+        if (allChangelogs.equals(changelogReader.completedChangelogs())) {
+            log.debug("All changelogs {} have completed restoration so far, will wait " +
+                    "until new changelogs are registered", allChangelogs);
+
+            while (isRunning.get() && taskItemQueue.isEmpty()) {
+                try {
+                    wait();
+                } catch (final InterruptedException e) {
+                    // do nothing
+                }
+            }
+        }
+    }
+
+    public synchronized void addInitializedTasks(final List<AbstractTask> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final AbstractTask task: tasks) {
+                taskItemQueue.add(new TaskItem(task, ItemType.CREATE, task.changelogPartitions()));
+            }
+            notifyAll();
+        }
+    }
+
+    public synchronized void addClosedTasks(final Map<AbstractTask, Collection<TopicPartition>> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final Map.Entry<AbstractTask, Collection<TopicPartition>> entry : tasks.entrySet()) {
+                taskItemQueue.add(new TaskItem(entry.getKey(), ItemType.CLOSE, entry.getValue()));
+            }
+            notifyAll();
+        }
+    }
+
+    public Set<TopicPartition> completedChangelogs() {
+        return completedChangelogs.get();
+    }
+
+    @Override
+    public void run() {
+        try {
+            while (isRunning()) {
+                runOnce();
+            }
+        } catch (final Exception e) {
+            log.error("Encountered the following exception while restoring states " +
+                    "and the thread is going to shut down: ", e);
+            throw e;
+        } finally {
+            try {
+                changelogReader.clear();
+            } catch (final Throwable e) {
+                log.error("Failed to close changelog reader due to the following error:", e);
+            }
+
+            shutdownLatch.countDown();
+        }
+    }
+
+    // Visible for testing
+    void runOnce() {
+        waitIfAllChangelogsCompleted();
+
+        if (!isRunning.get())
+            return;
+
+        // a task being recycled maybe in both closed and initialized tasks,
+        // and hence we should process the closed ones first and then initialized ones
+        final List<TaskItem> items = new ArrayList<>();
+        taskItemQueue.drainTo(items);
+
+        if (!items.isEmpty()) {
+            for (final TaskItem item : items) {
+                // TODO: we should consider also call the listener if the
+                //       changelog is not yet completed
+                if (item.type == ItemType.CLOSE) {
+                    changelogReader.unregister(item.changelogPartitions);
+
+                    log.info("Unregistered changelogs {} for closing task {}",
+                            item.task.changelogPartitions(),
+                            item.task.id());
+                } else if (item.type == ItemType.CREATE) {
+                    // we should only convert the state manager type right StateRestoreThreadTest.javabefore re-registering the changelog
+                    item.task.stateMgr.maybeConvertToNewTaskType();
+
+                    for (final TopicPartition partition : item.changelogPartitions) {
+                        changelogReader.register(partition, item.task.stateMgr);
+                    }
+
+                    log.info("Registered changelogs {} for created task {}",
+                            item.task.changelogPartitions(),
+                            item.task.id());
+                }
+            }
+        }
+        items.clear();
+
+        // try to restore some changelogs
+        final long startMs = time.milliseconds();
+        try {
+            final int numRestored = changelogReader.restore();
+            // TODO: we should record the restoration related metrics

Review comment:
       If this won't be added as part of this PR, can you file a ticket? TODOs are all too easily forgotten

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
##########
@@ -115,13 +114,16 @@ public boolean isClosed() {
     @Override
     public void revive() {
         if (state == CLOSED) {
+            // clear all the stores since they should be re-registered

Review comment:
       Can you elaborate on what you mean by "materialize" the changelogs here?

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
##########
@@ -185,7 +181,7 @@ StreamTask createActiveTaskFromStandby(final StandbyTask standbyTask,
         final LogContext logContext = getLogContext(standbyTask.id);
 
         standbyTask.closeCleanAndRecycleState();
-        stateManager.transitionTaskType(TaskType.ACTIVE, logContext);
+        stateManager.prepareNewTaskType(TaskType.ACTIVE, logContext);

Review comment:
       Can you elaborate here as well? Maybe I'm missing something, but it doesn't look like the restore thread ever checks the task/state-manager type at all. For example, `waitIfAllChangelogsCompleted` just compares the changelog reader's `allChangelogs()` against its `completedChangelogs()`, and neither of those takes the active/standby status of the task into account. Should they?

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreThread.java
##########
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StateRestoreListener;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+
+/**
+ * This is the thread responsible for restoring state stores for both active and standby tasks
+ */
+public class StateRestoreThread extends Thread {
+
+    private final Time time;
+    private final Logger log;
+    private final ChangelogReader changelogReader;
+    private final AtomicBoolean isRunning = new AtomicBoolean(true);
+    private final CountDownLatch shutdownLatch = new CountDownLatch(1);
+    private final LinkedBlockingDeque<TaskItem> taskItemQueue;
+    private final AtomicReference<Set<TopicPartition>> completedChangelogs;
+    private final LinkedBlockingDeque<TaskCorruptedException> corruptedExceptions;
+
+    public boolean isRunning() {
+        return isRunning.get();
+    }
+
+    public StateRestoreThread(final Time time,
+                              final StreamsConfig config,
+                              final String threadClientId,
+                              final Admin adminClient,
+                              final String groupId,
+                              final Consumer<byte[], byte[]> restoreConsumer,
+                              final StateRestoreListener userStateRestoreListener) {
+        this(time, threadClientId, new StoreChangelogReader(time, config, threadClientId,
+                adminClient, groupId, restoreConsumer, userStateRestoreListener));
+    }
+
+    // for testing only
+    public StateRestoreThread(final Time time,
+                              final String threadClientId,
+                              final ChangelogReader changelogReader) {
+        super(threadClientId);
+
+        final String logPrefix = String.format("state-restore-thread [%s] ", threadClientId);
+        final LogContext logContext = new LogContext(logPrefix);
+
+        this.time = time;
+        this.log = logContext.logger(getClass());
+        this.taskItemQueue = new LinkedBlockingDeque<>();
+        this.corruptedExceptions = new LinkedBlockingDeque<>();
+        this.completedChangelogs = new AtomicReference<>(Collections.emptySet());
+
+        this.changelogReader = changelogReader;
+    }
+
+    private synchronized void waitIfAllChangelogsCompleted() {
+        final Set<TopicPartition> allChangelogs = changelogReader.allChangelogs();
+        if (allChangelogs.equals(changelogReader.completedChangelogs())) {
+            log.debug("All changelogs {} have completed restoration so far, will wait " +
+                    "until new changelogs are registered", allChangelogs);
+
+            while (isRunning.get() && taskItemQueue.isEmpty()) {
+                try {
+                    wait();
+                } catch (final InterruptedException e) {
+                    // do nothing
+                }
+            }
+        }
+    }
+
+    public synchronized void addInitializedTasks(final List<AbstractTask> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final AbstractTask task: tasks) {
+                taskItemQueue.add(new TaskItem(task, ItemType.CREATE, task.changelogPartitions()));
+            }
+            notifyAll();
+        }
+    }
+
+    public synchronized void addClosedTasks(final Map<AbstractTask, Collection<TopicPartition>> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final Map.Entry<AbstractTask, Collection<TopicPartition>> entry : tasks.entrySet()) {
+                taskItemQueue.add(new TaskItem(entry.getKey(), ItemType.CLOSE, entry.getValue()));
+            }
+            notifyAll();
+        }
+    }
+
+    public Set<TopicPartition> completedChangelogs() {
+        return completedChangelogs.get();
+    }
+
+    @Override
+    public void run() {
+        try {
+            while (isRunning()) {
+                runOnce();
+            }
+        } catch (final Exception e) {
+            log.error("Encountered the following exception while restoring states " +
+                    "and the thread is going to shut down: ", e);
+            throw e;
+        } finally {
+            try {
+                changelogReader.clear();
+            } catch (final Throwable e) {
+                log.error("Failed to close changelog reader due to the following error:", e);
+            }
+
+            shutdownLatch.countDown();
+        }
+    }
+
+    // Visible for testing
+    void runOnce() {
+        waitIfAllChangelogsCompleted();
+
+        if (!isRunning.get())
+            return;
+
+        // a task being recycled maybe in both closed and initialized tasks,
+        // and hence we should process the closed ones first and then initialized ones
+        final List<TaskItem> items = new ArrayList<>();
+        taskItemQueue.drainTo(items);
+
+        if (!items.isEmpty()) {
+            for (final TaskItem item : items) {
+                // TODO: we should consider also call the listener if the
+                //       changelog is not yet completed
+                if (item.type == ItemType.CLOSE) {
+                    changelogReader.unregister(item.changelogPartitions);
+
+                    log.info("Unregistered changelogs {} for closing task {}",
+                            item.task.changelogPartitions(),
+                            item.task.id());
+                } else if (item.type == ItemType.CREATE) {
+                    // we should only convert the state manager type right StateRestoreThreadTest.javabefore re-registering the changelog
+                    item.task.stateMgr.maybeConvertToNewTaskType();
+
+                    for (final TopicPartition partition : item.changelogPartitions) {
+                        changelogReader.register(partition, item.task.stateMgr);
+                    }
+
+                    log.info("Registered changelogs {} for created task {}",
+                            item.task.changelogPartitions(),
+                            item.task.id());
+                }
+            }
+        }
+        items.clear();
+
+        // try to restore some changelogs
+        final long startMs = time.milliseconds();
+        try {
+            final int numRestored = changelogReader.restore();
+            // TODO: we should record the restoration related metrics
+            log.debug("Restored {} records in {} ms", numRestored, time.milliseconds() - startMs);
+        } catch (final TaskCorruptedException e) {
+            log.warn("Detected the states of tasks " + e.corruptedTaskWithChangelogs() + " are corrupted. " +
+                    "Will close the task as dirty and re-create and bootstrap from scratch.", e);
+
+            // remove corrupted partitions form the changelog reader and continue; we can still proceed
+            // and restore other partitions until the main thread come to handle this exception
+            changelogReader.unregister(e.corruptedTaskWithChangelogs().values().stream()
+                    .flatMap(Collection::stream)
+                    .collect(Collectors.toList()));
+
+            corruptedExceptions.add(e);
+        } catch (final StreamsException e) {
+            // if we are shutting down, the consumer could throw interrupt exception which can be ignored;
+            // otherwise, we re-throw
+            if (!(e.getCause() instanceof InterruptException) || isRunning.get()) {
+                throw e;
+            }
+        } catch (final TimeoutException e) {
+            log.info("Encountered timeout when restoring states, will retry in the next loop");
+        }
+
+        // finally update completed changelogs
+        completedChangelogs.set(changelogReader.completedChangelogs());
+    }
+
+    public TaskCorruptedException nextCorruptedException() {
+        return corruptedExceptions.poll();
+    }
+
+    public void shutdown(final long timeoutMs) throws InterruptedException {
+        log.info("Shutting down");
+
+        isRunning.set(false);
+        interrupt();
+
+        final boolean ret = shutdownLatch.await(timeoutMs, TimeUnit.MILLISECONDS);
+
+        if (ret) {
+            log.info("Shutdown complete");
+        } else {
+            log.warn("Shutdown timed out after {}", timeoutMs);
+        }
+    }
+
+    private enum ItemType {
+        CREATE,
+        CLOSE,
+        REVIVE
+    }
+
+    private static class TaskItem {
+        private final ItemType type;
+        private final AbstractTask task;

Review comment:
       Why `AbstractTask` and not just `Task`? We shouldn't be relying on specific APIs or members that exist in an abstract implementation but not on the interface itself; otherwise, what would be the point of the interface? If we need access to the `stateMgr`, let's just add a `Task#stateManager` method.

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
##########
@@ -220,11 +224,6 @@ public void closeDirty() {
     @Override
     public void closeCleanAndRecycleState() {
         streamsMetrics.removeAllTaskLevelSensors(Thread.currentThread().getName(), id.toString());
-        if (state() == State.SUSPENDED) {

Review comment:
       If it's still true that a task should always be suspended before recycling, let's keep the check here.

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
##########
@@ -537,36 +513,37 @@ public void close() throws ProcessorStateException {
                 }
             }
 
-            stores.clear();
+            // do not clear the store map since they will still be used for changelog de-registeration later
         }
 
         if (firstException != null) {
             throw firstException;
         }
     }
 
-    /**
-     * Alternative to {@link #close()} that just resets the changelogs without closing any of the underlying state
-     * or unregistering the stores themselves
-     */
-    void recycle() {
-        log.debug("Recycling state for {} task {}.", taskType, taskId);
-
-        final List<TopicPartition> allChangelogs = getAllChangelogTopicPartitions();
-        changelogReader.unregister(allChangelogs);
+    public void clear() {
+        stores.clear();
     }
 
-    void transitionTaskType(final TaskType newType, final LogContext logContext) {
+    void prepareNewTaskType(final TaskType newType, final LogContext logContext) {
         if (taskType.equals(newType)) {
             throw new IllegalStateException("Tried to recycle state for task type conversion but new type was the same.");
         }
 
-        final TaskType oldType = taskType;
-        taskType = newType;
-        log = logContext.logger(ProcessorStateManager.class);
-        logPrefix = logContext.logPrefix();
+        // we'd have to defer the actual conversion until the registration of changelogs have completed
+        this.newType = newType;
+        this.logPrefix = logContext.logPrefix();
+        this.log = logContext.logger(ProcessorStateManager.class);
 
-        log.debug("Transitioning state manager for {} task {} to {}", oldType, taskId, newType);
+        log.debug("Preparing to transit state manager for task {} from {} to {}", taskId, taskType, newType);
+    }
+
+    void maybeConvertToNewTaskType() {

Review comment:
       nit: call it `maybeCompleteTaskTypeConversion` or `completeTaskTypeConversionIfNecessary`? Right now it sounds like we're just randomly converting the task to a new type, when really we're just finishing off the conversion that `prepareNewTaskType` started.

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreThread.java
##########
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StateRestoreListener;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+
+/**
+ * This is the thread responsible for restoring state stores for both active and standby tasks
+ */
+public class StateRestoreThread extends Thread {
+
+    private final Time time;
+    private final Logger log;
+    private final ChangelogReader changelogReader;
+    private final AtomicBoolean isRunning = new AtomicBoolean(true);
+    private final CountDownLatch shutdownLatch = new CountDownLatch(1);
+    private final LinkedBlockingDeque<TaskItem> taskItemQueue;
+    private final AtomicReference<Set<TopicPartition>> completedChangelogs;
+    private final LinkedBlockingDeque<TaskCorruptedException> corruptedExceptions;
+
+    public boolean isRunning() {
+        return isRunning.get();
+    }
+
+    public StateRestoreThread(final Time time,
+                              final StreamsConfig config,
+                              final String threadClientId,
+                              final Admin adminClient,
+                              final String groupId,
+                              final Consumer<byte[], byte[]> restoreConsumer,
+                              final StateRestoreListener userStateRestoreListener) {
+        this(time, threadClientId, new StoreChangelogReader(time, config, threadClientId,
+                adminClient, groupId, restoreConsumer, userStateRestoreListener));
+    }
+
+    // for testing only
+    public StateRestoreThread(final Time time,
+                              final String threadClientId,
+                              final ChangelogReader changelogReader) {
+        super(threadClientId);
+
+        final String logPrefix = String.format("state-restore-thread [%s] ", threadClientId);
+        final LogContext logContext = new LogContext(logPrefix);
+
+        this.time = time;
+        this.log = logContext.logger(getClass());
+        this.taskItemQueue = new LinkedBlockingDeque<>();
+        this.corruptedExceptions = new LinkedBlockingDeque<>();
+        this.completedChangelogs = new AtomicReference<>(Collections.emptySet());
+
+        this.changelogReader = changelogReader;
+    }
+
+    private synchronized void waitIfAllChangelogsCompleted() {
+        final Set<TopicPartition> allChangelogs = changelogReader.allChangelogs();
+        if (allChangelogs.equals(changelogReader.completedChangelogs())) {
+            log.debug("All changelogs {} have completed restoration so far, will wait " +
+                    "until new changelogs are registered", allChangelogs);
+
+            while (isRunning.get() && taskItemQueue.isEmpty()) {
+                try {
+                    wait();
+                } catch (final InterruptedException e) {
+                    // do nothing
+                }
+            }
+        }
+    }
+
+    public synchronized void addInitializedTasks(final List<AbstractTask> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final AbstractTask task: tasks) {

Review comment:
       ```suggestion
               for (final AbstractTask task : tasks) {
   ```

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreThread.java
##########
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StateRestoreListener;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+
+/**
+ * This is the thread responsible for restoring state stores for both active and standby tasks
+ */
+public class StateRestoreThread extends Thread {
+
+    private final Time time;
+    private final Logger log;
+    private final ChangelogReader changelogReader;
+    private final AtomicBoolean isRunning = new AtomicBoolean(true);
+    private final CountDownLatch shutdownLatch = new CountDownLatch(1);
+    private final LinkedBlockingDeque<TaskItem> taskItemQueue;
+    private final AtomicReference<Set<TopicPartition>> completedChangelogs;
+    private final LinkedBlockingDeque<TaskCorruptedException> corruptedExceptions;
+
+    public boolean isRunning() {
+        return isRunning.get();
+    }
+
+    public StateRestoreThread(final Time time,
+                              final StreamsConfig config,
+                              final String threadClientId,
+                              final Admin adminClient,
+                              final String groupId,
+                              final Consumer<byte[], byte[]> restoreConsumer,
+                              final StateRestoreListener userStateRestoreListener) {
+        this(time, threadClientId, new StoreChangelogReader(time, config, threadClientId,
+                adminClient, groupId, restoreConsumer, userStateRestoreListener));
+    }
+
+    // for testing only
+    public StateRestoreThread(final Time time,
+                              final String threadClientId,
+                              final ChangelogReader changelogReader) {
+        super(threadClientId);
+
+        final String logPrefix = String.format("state-restore-thread [%s] ", threadClientId);
+        final LogContext logContext = new LogContext(logPrefix);
+
+        this.time = time;
+        this.log = logContext.logger(getClass());
+        this.taskItemQueue = new LinkedBlockingDeque<>();
+        this.corruptedExceptions = new LinkedBlockingDeque<>();
+        this.completedChangelogs = new AtomicReference<>(Collections.emptySet());
+
+        this.changelogReader = changelogReader;
+    }
+
+    private synchronized void waitIfAllChangelogsCompleted() {
+        final Set<TopicPartition> allChangelogs = changelogReader.allChangelogs();
+        if (allChangelogs.equals(changelogReader.completedChangelogs())) {
+            log.debug("All changelogs {} have completed restoration so far, will wait " +
+                    "until new changelogs are registered", allChangelogs);
+
+            while (isRunning.get() && taskItemQueue.isEmpty()) {
+                try {
+                    wait();
+                } catch (final InterruptedException e) {
+                    // do nothing
+                }
+            }
+        }
+    }
+
+    public synchronized void addInitializedTasks(final List<AbstractTask> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final AbstractTask task: tasks) {
+                taskItemQueue.add(new TaskItem(task, ItemType.CREATE, task.changelogPartitions()));
+            }
+            notifyAll();
+        }
+    }
+
+    public synchronized void addClosedTasks(final Map<AbstractTask, Collection<TopicPartition>> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final Map.Entry<AbstractTask, Collection<TopicPartition>> entry : tasks.entrySet()) {
+                taskItemQueue.add(new TaskItem(entry.getKey(), ItemType.CLOSE, entry.getValue()));
+            }
+            notifyAll();
+        }
+    }
+
+    public Set<TopicPartition> completedChangelogs() {
+        return completedChangelogs.get();
+    }
+
+    @Override
+    public void run() {
+        try {
+            while (isRunning()) {
+                runOnce();
+            }
+        } catch (final Exception e) {
+            log.error("Encountered the following exception while restoring states " +
+                    "and the thread is going to shut down: ", e);
+            throw e;
+        } finally {
+            try {
+                changelogReader.clear();
+            } catch (final Throwable e) {
+                log.error("Failed to close changelog reader due to the following error:", e);
+            }
+
+            shutdownLatch.countDown();
+        }
+    }
+
+    // Visible for testing
+    void runOnce() {
+        waitIfAllChangelogsCompleted();
+
+        if (!isRunning.get())
+            return;
+
+        // a task being recycled maybe in both closed and initialized tasks,
+        // and hence we should process the closed ones first and then initialized ones
+        final List<TaskItem> items = new ArrayList<>();
+        taskItemQueue.drainTo(items);
+
+        if (!items.isEmpty()) {

Review comment:
       This `isEmpty()` check seems redundant, since iterating over an empty list is already a no-op.

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreThread.java
##########
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StateRestoreListener;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+
+/**
+ * This is the thread responsible for restoring state stores for both active and standby tasks
+ */
+public class StateRestoreThread extends Thread {
+
+    private final Time time;
+    private final Logger log;
+    private final ChangelogReader changelogReader;
+    private final AtomicBoolean isRunning = new AtomicBoolean(true);
+    private final CountDownLatch shutdownLatch = new CountDownLatch(1);
+    private final LinkedBlockingDeque<TaskItem> taskItemQueue;
+    private final AtomicReference<Set<TopicPartition>> completedChangelogs;
+    private final LinkedBlockingDeque<TaskCorruptedException> corruptedExceptions;
+
+    public boolean isRunning() {
+        return isRunning.get();
+    }
+
+    public StateRestoreThread(final Time time,
+                              final StreamsConfig config,
+                              final String threadClientId,
+                              final Admin adminClient,
+                              final String groupId,
+                              final Consumer<byte[], byte[]> restoreConsumer,
+                              final StateRestoreListener userStateRestoreListener) {
+        this(time, threadClientId, new StoreChangelogReader(time, config, threadClientId,
+                adminClient, groupId, restoreConsumer, userStateRestoreListener));
+    }
+
+    // for testing only
+    public StateRestoreThread(final Time time,
+                              final String threadClientId,
+                              final ChangelogReader changelogReader) {
+        super(threadClientId);
+
+        final String logPrefix = String.format("state-restore-thread [%s] ", threadClientId);
+        final LogContext logContext = new LogContext(logPrefix);
+
+        this.time = time;
+        this.log = logContext.logger(getClass());
+        this.taskItemQueue = new LinkedBlockingDeque<>();
+        this.corruptedExceptions = new LinkedBlockingDeque<>();
+        this.completedChangelogs = new AtomicReference<>(Collections.emptySet());
+
+        this.changelogReader = changelogReader;
+    }
+
+    private synchronized void waitIfAllChangelogsCompleted() {
+        final Set<TopicPartition> allChangelogs = changelogReader.allChangelogs();
+        if (allChangelogs.equals(changelogReader.completedChangelogs())) {
+            log.debug("All changelogs {} have completed restoration so far, will wait " +
+                    "until new changelogs are registered", allChangelogs);
+
+            while (isRunning.get() && taskItemQueue.isEmpty()) {
+                try {
+                    wait();
+                } catch (final InterruptedException e) {
+                    // do nothing
+                }
+            }
+        }
+    }
+
+    public synchronized void addInitializedTasks(final List<AbstractTask> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final AbstractTask task: tasks) {
+                taskItemQueue.add(new TaskItem(task, ItemType.CREATE, task.changelogPartitions()));
+            }
+            notifyAll();
+        }
+    }
+
+    public synchronized void addClosedTasks(final Map<AbstractTask, Collection<TopicPartition>> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final Map.Entry<AbstractTask, Collection<TopicPartition>> entry : tasks.entrySet()) {
+                taskItemQueue.add(new TaskItem(entry.getKey(), ItemType.CLOSE, entry.getValue()));
+            }
+            notifyAll();
+        }
+    }
+
+    public Set<TopicPartition> completedChangelogs() {
+        return completedChangelogs.get();
+    }
+
+    @Override
+    public void run() {
+        try {
+            while (isRunning()) {
+                runOnce();
+            }
+        } catch (final Exception e) {
+            log.error("Encountered the following exception while restoring states " +
+                    "and the thread is going to shut down: ", e);
+            throw e;
+        } finally {
+            try {
+                changelogReader.clear();
+            } catch (final Throwable e) {
+                log.error("Failed to close changelog reader due to the following error:", e);
+            }
+
+            shutdownLatch.countDown();
+        }
+    }
+
+    // Visible for testing
+    void runOnce() {
+        waitIfAllChangelogsCompleted();
+
+        if (!isRunning.get())
+            return;
+
+        // a task being recycled maybe in both closed and initialized tasks,
+        // and hence we should process the closed ones first and then initialized ones
+        final List<TaskItem> items = new ArrayList<>();
+        taskItemQueue.drainTo(items);
+
+        if (!items.isEmpty()) {
+            for (final TaskItem item : items) {
+                // TODO: we should consider also call the listener if the
+                //       changelog is not yet completed
+                if (item.type == ItemType.CLOSE) {
+                    changelogReader.unregister(item.changelogPartitions);
+
+                    log.info("Unregistered changelogs {} for closing task {}",
+                            item.task.changelogPartitions(),
+                            item.task.id());
+                } else if (item.type == ItemType.CREATE) {
+                    // we should only convert the state manager type right StateRestoreThreadTest.javabefore re-registering the changelog
+                    item.task.stateMgr.maybeConvertToNewTaskType();
+
+                    for (final TopicPartition partition : item.changelogPartitions) {
+                        changelogReader.register(partition, item.task.stateMgr);
+                    }
+
+                    log.info("Registered changelogs {} for created task {}",
+                            item.task.changelogPartitions(),
+                            item.task.id());
+                }
+            }

Review comment:
       Do we need to handle reviving tasks here? It seems like we don't actually use the `REVIVE` `ItemType` at all.

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
##########
@@ -623,16 +613,46 @@ void runOnce() {
             return;
         }
 
-        initializeAndRestorePhase();
+        // we need to first add closed tasks and then created tasks to work with those revived / recycled tasks

Review comment:
       ```suggestion
           // we need to first add any closed revoked/corrupted/recycled tasks and then add the initialized tasks to update the changelogs of revived/recycled tasks
   ```

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
##########
@@ -623,16 +613,46 @@ void runOnce() {
             return;
         }
 
-        initializeAndRestorePhase();
+        // we need to first add closed tasks and then created tasks to work with those revived / recycled tasks
+        restoreThread.addClosedTasks(taskManager.drainRemovedTasks());
+
+        // try to initialize created tasks that are either newly assigned or re-created from corrupted tasks
+        final List<AbstractTask> initializedTasks;
+        if (!(initializedTasks = taskManager.tryInitializeNewTasks()).isEmpty()) {

Review comment:
       req:
   ```suggestion
           final List<AbstractTask> initializedTasks = taskManager.tryInitializeNewTasks();
           if (!initializedTasks.isEmpty()) {
   ```

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
##########
@@ -623,16 +613,46 @@ void runOnce() {
             return;
         }
 
-        initializeAndRestorePhase();
+        // we need to first add closed tasks and then created tasks to work with those revived / recycled tasks
+        restoreThread.addClosedTasks(taskManager.drainRemovedTasks());
+
+        // try to initialize created tasks that are either newly assigned or re-created from corrupted tasks
+        final List<AbstractTask> initializedTasks;
+        if (!(initializedTasks = taskManager.tryInitializeNewTasks()).isEmpty()) {
+            if (log.isDebugEnabled()) {
+                log.debug("Initializing newly created tasks {} under state {}",
+                        initializedTasks.stream().map(AbstractTask::id).collect(Collectors.toList()), state);
+            }
+
+            restoreThread.addInitializedTasks(initializedTasks);
+        }
+
+        // try complete restoration if there are any restoring tasks
+        if (taskManager.tryToCompleteRestoration(restoreThread.completedChangelogs())) {
+            log.debug("Completed restoring all tasks now");

Review comment:
       I don't think this log message adds any value on top of the `All tasks are now running and transited State to RUNNING` message below. Also, that one will only be logged when we actually transition, whereas this one will be logged on every iteration of the main loop, even after restoration is long done, which is confusing. I think we can just remove it, and then let `taskManager.tryToCompleteRestoration` have return type `void`.

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
##########
@@ -200,20 +200,23 @@ private void closeAndRevive(final Map<Task, Collection<TopicPartition>> taskWith
             }
             task.closeDirty();
 
+            removedTasks.put((AbstractTask) task, task.changelogPartitions());
+            ((AbstractTask) task).stateMgr.clear();

Review comment:
       This doesn't seem good: we shouldn't be casting to an abstract class. Let's just add whatever we need as a method on the `Task` interface.

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
##########
@@ -623,16 +613,46 @@ void runOnce() {
             return;
         }
 
-        initializeAndRestorePhase();
+        // we need to first add closed tasks and then created tasks to work with those revived / recycled tasks
+        restoreThread.addClosedTasks(taskManager.drainRemovedTasks());
+
+        // try to initialize created tasks that are either newly assigned or re-created from corrupted tasks
+        final List<AbstractTask> initializedTasks;
+        if (!(initializedTasks = taskManager.tryInitializeNewTasks()).isEmpty()) {
+            if (log.isDebugEnabled()) {
+                log.debug("Initializing newly created tasks {} under state {}",
+                        initializedTasks.stream().map(AbstractTask::id).collect(Collectors.toList()), state);
+            }
+
+            restoreThread.addInitializedTasks(initializedTasks);
+        }
+
+        // try complete restoration if there are any restoring tasks
+        if (taskManager.tryToCompleteRestoration(restoreThread.completedChangelogs())) {
+            log.debug("Completed restoring all tasks now");
+        }
+
+        if (state == State.PARTITIONS_ASSIGNED && taskManager.allTasksRunning()) {
+            // it is possible that we have no assigned tasks in which case we would still transit state
+            setState(State.RUNNING);
+
+            log.debug("All tasks are now running and transited State to {}", State.RUNNING);

Review comment:
       ```suggestion
               log.info("All tasks are now running");
   ```
   (the `setState` call will log the state change, so in my opinion we don't need to log it again here)

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
##########
@@ -623,16 +613,46 @@ void runOnce() {
             return;
         }
 
-        initializeAndRestorePhase();
+        // we need to first add closed tasks and then created tasks to work with those revived / recycled tasks
+        restoreThread.addClosedTasks(taskManager.drainRemovedTasks());
+
+        // try to initialize created tasks that are either newly assigned or re-created from corrupted tasks

Review comment:
       ```suggestion
           // try to initialize created tasks that are either newly assigned, recycled, or revived from corrupted tasks
   ```

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
##########
@@ -623,16 +613,46 @@ void runOnce() {
             return;
         }
 
-        initializeAndRestorePhase();
+        // we need to first add closed tasks and then created tasks to work with those revived / recycled tasks
+        restoreThread.addClosedTasks(taskManager.drainRemovedTasks());
+
+        // try to initialize created tasks that are either newly assigned or re-created from corrupted tasks
+        final List<AbstractTask> initializedTasks;
+        if (!(initializedTasks = taskManager.tryInitializeNewTasks()).isEmpty()) {
+            if (log.isDebugEnabled()) {
+                log.debug("Initializing newly created tasks {} under state {}",
+                        initializedTasks.stream().map(AbstractTask::id).collect(Collectors.toList()), state);
+            }
+
+            restoreThread.addInitializedTasks(initializedTasks);
+        }
+
+        // try complete restoration if there are any restoring tasks
+        if (taskManager.tryToCompleteRestoration(restoreThread.completedChangelogs())) {
+            log.debug("Completed restoring all tasks now");
+        }
+
+        if (state == State.PARTITIONS_ASSIGNED && taskManager.allTasksRunning()) {
+            // it is possible that we have no assigned tasks in which case we would still transit state
+            setState(State.RUNNING);
+
+            log.debug("All tasks are now running and transited State to {}", State.RUNNING);
+        }
 
-        // TODO: we should record the restore latency and its relative time spent ratio after
-        //       we figure out how to move this method out of the stream thread
-        advanceNowAndComputeLatency();
+        // check if restore thread has encountered TaskCorrupted exception; if yes
+        // rethrow it to trigger the handling logic
+        final TaskCorruptedException e = restoreThread.nextCorruptedException();
+        if (e != null) {
+            throw e;
+        }
 
         int totalProcessed = 0;
         long totalCommitLatency = 0L;
         long totalProcessLatency = 0L;
         long totalPunctuateLatency = 0L;
+
+        // TODO: we should allow active tasks processing even if we are not yet in RUNNING
+        //       after restoration is moved to the other thread

Review comment:
       This seems like a big TODO. Can you make a ticket for it, or do you plan to do the follow-up right away? It seems like without this fix, we won't actually see any benefit from moving restoration to a separate thread (if anything, it'll be slightly worse due to the synchronization overhead).
   

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
##########
@@ -107,6 +107,7 @@
 
         private final ProcessorStateManager stateManager;
 
+        // NOTE only this field may be concurrently accessed by stream and restore threads

Review comment:
       Also out of date?

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
##########
@@ -946,14 +934,14 @@ private void completeShutdown(final boolean cleanRun) {
         log.info("Shutting down");
 
         try {
-            taskManager.shutdown(cleanRun);
+            restoreThread.shutdown(10_000L);

Review comment:
       Why choose an arbitrary number here? Shouldn't we apply the timeout supplied to `KafkaStreams#close`?

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
##########
@@ -185,55 +186,55 @@ int bufferedLimitIndex() {
     }
 
     private final static long DEFAULT_OFFSET_UPDATE_MS = Duration.ofMinutes(5L).toMillis();
-
-    private ChangelogReaderState state;
+    private static final long RESTORE_LOG_INTERVAL_MS = Duration.ofSeconds(10L).toMillis();
 
     private final Time time;
     private final Logger log;
+    private final String groupId;
     private final Duration pollTime;
     private final long updateOffsetIntervalMs;
+    private long lastUpdateOffsetTime = 0L;
+    private long lastRestoreLogTime = 0L;
 
     // 1) we keep adding partitions to restore consumer whenever new tasks are registered with the state manager;
     // 2) we do not unassign partitions when we switch between standbys and actives, we just pause / resume them;
     // 3) we only remove an assigned partition when the corresponding task is being removed from the thread.
     private final Consumer<byte[], byte[]> restoreConsumer;
     private final StateRestoreListener stateRestoreListener;
 
+    // the changelog reader needs the admin client to list end offsets and the group's committed offsets
+    private final Admin adminClient;
+
     // source of the truth of the current registered changelogs;
     // NOTE a changelog would only be removed when its corresponding task
     // is being removed from the thread; otherwise it would stay in this map even after completed
+    //
+    // this map may be concurrently accessed by different threads and hence need to be guarded

Review comment:
       I'll stop pointing them all out and let you give this class a pass for any out of date comments 🙂 

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
##########
@@ -55,16 +55,16 @@
 import static org.apache.kafka.streams.processor.internals.ClientUtils.fetchCommittedOffsets;
 
 /**
- * ChangelogReader is created and maintained by the stream thread and used for both updating standby tasks and
+ * ChangelogReader is created and shared by the stream thread and restore thread. It is used for both updating standby tasks and
  * restoring active tasks. It manages the restore consumer, including its assigned partitions, when to pause / resume
  * these partitions, etc.
  * <p>
+ * This object is thread-safe for concurrent access between the two threads.
+ * <p>

Review comment:
       These comments need to be updated: the changelog reader is no longer shared between threads or thread-safe, right?

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreThread.java
##########
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StateRestoreListener;
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+
+
+/**
+ * This is the thread responsible for restoring state stores for both active and standby tasks
+ */
+public class StateRestoreThread extends Thread {
+
+    private final Time time;
+    private final Logger log;
+    private final ChangelogReader changelogReader;
+    private final AtomicBoolean isRunning = new AtomicBoolean(true);
+    private final CountDownLatch shutdownLatch = new CountDownLatch(1);
+    private final LinkedBlockingDeque<TaskItem> taskItemQueue;
+    private final AtomicReference<Set<TopicPartition>> completedChangelogs;
+    private final LinkedBlockingDeque<TaskCorruptedException> corruptedExceptions;
+
+    public boolean isRunning() {
+        return isRunning.get();
+    }
+
+    public StateRestoreThread(final Time time,
+                              final StreamsConfig config,
+                              final String threadClientId,
+                              final Admin adminClient,
+                              final String groupId,
+                              final Consumer<byte[], byte[]> restoreConsumer,
+                              final StateRestoreListener userStateRestoreListener) {
+        this(time, threadClientId, new StoreChangelogReader(time, config, threadClientId,
+                adminClient, groupId, restoreConsumer, userStateRestoreListener));
+    }
+
+    // for testing only
+    public StateRestoreThread(final Time time,
+                              final String threadClientId,
+                              final ChangelogReader changelogReader) {
+        super(threadClientId);
+
+        final String logPrefix = String.format("state-restore-thread [%s] ", threadClientId);
+        final LogContext logContext = new LogContext(logPrefix);
+
+        this.time = time;
+        this.log = logContext.logger(getClass());
+        this.taskItemQueue = new LinkedBlockingDeque<>();
+        this.corruptedExceptions = new LinkedBlockingDeque<>();
+        this.completedChangelogs = new AtomicReference<>(Collections.emptySet());
+
+        this.changelogReader = changelogReader;
+    }
+
+    private synchronized void waitIfAllChangelogsCompleted() {
+        final Set<TopicPartition> allChangelogs = changelogReader.allChangelogs();
+        if (allChangelogs.equals(changelogReader.completedChangelogs())) {
+            log.debug("All changelogs {} have completed restoration so far, will wait " +
+                    "until new changelogs are registered", allChangelogs);
+
+            while (isRunning.get() && taskItemQueue.isEmpty()) {
+                try {
+                    wait();
+                } catch (final InterruptedException e) {
+                    // do nothing
+                }
+            }
+        }
+    }
+
+    public synchronized void addInitializedTasks(final List<AbstractTask> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final AbstractTask task: tasks) {
+                taskItemQueue.add(new TaskItem(task, ItemType.CREATE, task.changelogPartitions()));
+            }
+            notifyAll();
+        }
+    }
+
+    public synchronized void addClosedTasks(final Map<AbstractTask, Collection<TopicPartition>> tasks) {
+        if (!tasks.isEmpty()) {
+            for (final Map.Entry<AbstractTask, Collection<TopicPartition>> entry : tasks.entrySet()) {
+                taskItemQueue.add(new TaskItem(entry.getKey(), ItemType.CLOSE, entry.getValue()));
+            }
+            notifyAll();
+        }
+    }
+
+    public Set<TopicPartition> completedChangelogs() {
+        return completedChangelogs.get();
+    }
+
+    @Override
+    public void run() {
+        try {
+            while (isRunning()) {
+                runOnce();
+            }
+        } catch (final Exception e) {
+            log.error("Encountered the following exception while restoring states " +
+                    "and the thread is going to shut down: ", e);
+            throw e;
+        } finally {
+            try {
+                changelogReader.clear();
+            } catch (final Throwable e) {
+                log.error("Failed to close changelog reader due to the following error:", e);
+            }
+
+            shutdownLatch.countDown();
+        }
+    }
+
+    // Visible for testing
+    void runOnce() {
+        waitIfAllChangelogsCompleted();
+
+        if (!isRunning.get())
+            return;
+
+        // a task being recycled maybe in both closed and initialized tasks,
+        // and hence we should process the closed ones first and then initialized ones
+        final List<TaskItem> items = new ArrayList<>();
+        taskItemQueue.drainTo(items);
+
+        if (!items.isEmpty()) {
+            for (final TaskItem item : items) {
+                // TODO: we should consider also call the listener if the
+                //       changelog is not yet completed
+                if (item.type == ItemType.CLOSE) {
+                    changelogReader.unregister(item.changelogPartitions);
+
+                    log.info("Unregistered changelogs {} for closing task {}",
+                            item.task.changelogPartitions(),
+                            item.task.id());
+                } else if (item.type == ItemType.CREATE) {
+                    // we should only convert the state manager type right StateRestoreThreadTest.javabefore re-registering the changelog
+                    item.task.stateMgr.maybeConvertToNewTaskType();
+
+                    for (final TopicPartition partition : item.changelogPartitions) {
+                        changelogReader.register(partition, item.task.stateMgr);
+                    }
+
+                    log.info("Registered changelogs {} for created task {}",
+                            item.task.changelogPartitions(),
+                            item.task.id());
+                }
+            }
+        }
+        items.clear();
+
+        // try to restore some changelogs
+        final long startMs = time.milliseconds();
+        try {
+            final int numRestored = changelogReader.restore();
+            // TODO: we should record the restoration related metrics
+            log.debug("Restored {} records in {} ms", numRestored, time.milliseconds() - startMs);
+        } catch (final TaskCorruptedException e) {
+            log.warn("Detected the states of tasks " + e.corruptedTaskWithChangelogs() + " are corrupted. " +
+                    "Will close the task as dirty and re-create and bootstrap from scratch.", e);
+
+            // remove corrupted partitions form the changelog reader and continue; we can still proceed
+            // and restore other partitions until the main thread come to handle this exception
+            changelogReader.unregister(e.corruptedTaskWithChangelogs().values().stream()
+                    .flatMap(Collection::stream)
+                    .collect(Collectors.toList()));
+
+            corruptedExceptions.add(e);
+        } catch (final StreamsException e) {
+            // if we are shutting down, the consumer could throw interrupt exception which can be ignored;
+            // otherwise, we re-throw
+            if (!(e.getCause() instanceof InterruptException) || isRunning.get()) {
+                throw e;
+            }
+        } catch (final TimeoutException e) {
+            log.info("Encountered timeout when restoring states, will retry in the next loop");
+        }
+
+        // finally update completed changelogs
+        completedChangelogs.set(changelogReader.completedChangelogs());
+    }
+
+    public TaskCorruptedException nextCorruptedException() {
+        return corruptedExceptions.poll();
+    }
+
+    public void shutdown(final long timeoutMs) throws InterruptedException {
+        log.info("Shutting down");
+
+        isRunning.set(false);
+        interrupt();
+
+        final boolean ret = shutdownLatch.await(timeoutMs, TimeUnit.MILLISECONDS);

Review comment:
       nit: `ret` --> `shutdownComplete` 

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
##########
@@ -623,16 +613,46 @@ void runOnce() {
             return;
         }
 
-        initializeAndRestorePhase();
+        // we need to first add closed tasks and then created tasks to work with those revived / recycled tasks
+        restoreThread.addClosedTasks(taskManager.drainRemovedTasks());
+
+        // try to initialize created tasks that are either newly assigned or re-created from corrupted tasks
+        final List<AbstractTask> initializedTasks;
+        if (!(initializedTasks = taskManager.tryInitializeNewTasks()).isEmpty()) {
+            if (log.isDebugEnabled()) {
+                log.debug("Initializing newly created tasks {} under state {}",
+                        initializedTasks.stream().map(AbstractTask::id).collect(Collectors.toList()), state);
+            }
+

Review comment:
       I think it's worth logging an INFO-level message here, something like `All tasks initialized and started restoring`.

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
##########
@@ -440,8 +455,11 @@ public void restore() {
 
                 final Map<TaskId, Collection<TopicPartition>> taskWithCorruptedChangelogs = new HashMap<>();
                 for (final TopicPartition partition : e.partitions()) {
-                    final TaskId taskId = changelogs.get(partition).stateManager.taskId();
-                    taskWithCorruptedChangelogs.computeIfAbsent(taskId, k -> new HashSet<>()).add(partition);
+                    final ChangelogMetadata metadata = changelogs.get(partition);
+                    if (metadata != null) {

Review comment:
       Can `metadata` be null? It doesn't seem like we should ever get an `InvalidOffsetException` for changelogs we don't have. At the very least we should log a warning, if not throw an `IllegalStateException`.

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
##########
@@ -659,13 +665,12 @@ void runOnce() {
             }
         }
 
-        // we can always let changelog reader try restoring in order to initialize the changelogs;
-        // if there's no active restoring or standby updating it would not try to fetch any data
-        changelogReader.restore();
-
-        // TODO: we should record the restore latency and its relative time spent ratio after
-        //       we figure out how to move this method out of the stream thread
-        advanceNowAndComputeLatency();
+        // check if restore thread has encountered TaskCorrupted exception; if yes
+        // rethrow it to trigger the handling logic
+        final TaskCorruptedException e = restoreThread.nextCorruptedException();

Review comment:
       Sounds good. Just to throw out one idea that I think would take a fairly small level of effort: instead of queueing up new TaskCorruptedExceptions, we could store a single TaskCorruptedException and update its taskWithChangelogs map as new corrupted tasks are detected. Since TaskCorruptedException is already multi-task, and `TaskManager#handleCorruption` takes a map of tasks as input, it seems natural to expand the one exception with new tasks instead of building up multiple multi-task exceptions.
   
   I think it would also be good to have a mutable task/changelog map so we can make sure any revoked tasks/changelogs are removed from the TaskCorruptedException. Currently, it seems like the StreamThread could actually pull and rethrow a TaskCorruptedException for a task that no longer exists on the instance.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org