Posted to commits@kafka.apache.org by gu...@apache.org on 2023/02/24 18:25:30 UTC

[kafka] branch trunk updated: KAFKA-10199: Add task updater metrics, part 1 (#13228)

This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 2fad1652942 KAFKA-10199: Add task updater metrics, part 1 (#13228)
2fad1652942 is described below

commit 2fad1652942226454a44038f2350642817f9f74b
Author: Guozhang Wang <wa...@gmail.com>
AuthorDate: Fri Feb 24 10:25:11 2023 -0800

    KAFKA-10199: Add task updater metrics, part 1 (#13228)
    
    * Moved the pausing-tasks logic out of the commit-interval loop and into the top-level loop, mirroring how tasks are resumed.
    * Added thread-level restoration metrics.
    * Added related unit tests.
    
    Reviewers: Lucas Brutschy <lu...@users.noreply.github.com>, Matthias J. Sax <ma...@confluent.io>
---
 .../processor/internals/ChangelogReader.java       |  10 +-
 .../processor/internals/DefaultStateUpdater.java   | 230 +++++++++++++++++++--
 .../processor/internals/StoreChangelogReader.java  |  94 +++++----
 .../streams/processor/internals/StreamThread.java  |   5 +-
 .../streams/processor/internals/TaskManager.java   |   4 +-
 .../internals/metrics/StreamsMetricsImpl.java      |   5 +
 .../internals/DefaultStateUpdaterTest.java         | 151 ++++++++++++--
 .../processor/internals/MockChangelogReader.java   |   8 +-
 .../processor/internals/StreamThreadTest.java      |   2 +-
 9 files changed, 427 insertions(+), 82 deletions(-)
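
Note for readers: the new metrics land in the metric group "stream-state-updater-metrics", tagged
with the updater thread's "thread-id", and are registered only while the state updater is enabled.
Below is a minimal sketch, not part of the patch, of how they could be read from a running
application via KafkaStreams#metrics(); the application id, bootstrap servers, and topology are
placeholders.

    import java.util.Map;
    import java.util.Properties;
    import org.apache.kafka.common.Metric;
    import org.apache.kafka.common.MetricName;
    import org.apache.kafka.streams.KafkaStreams;
    import org.apache.kafka.streams.StreamsBuilder;
    import org.apache.kafka.streams.StreamsConfig;

    public class StateUpdaterMetricsProbe {
        public static void main(final String[] args) {
            final Properties props = new Properties();
            props.put(StreamsConfig.APPLICATION_ID_CONFIG, "metrics-probe");      // placeholder
            props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // placeholder

            final StreamsBuilder builder = new StreamsBuilder();
            builder.stream("input-topic").to("output-topic");                     // placeholder topology

            final KafkaStreams streams = new KafkaStreams(builder.build(), props);
            try {
                streams.start();
                // Print every metric registered under the group added by this commit.
                for (final Map.Entry<MetricName, ? extends Metric> entry : streams.metrics().entrySet()) {
                    final MetricName name = entry.getKey();
                    if ("stream-state-updater-metrics".equals(name.group())) {
                        System.out.printf("%s{thread-id=%s} = %s%n",
                            name.name(), name.tags().get("thread-id"), entry.getValue().metricValue());
                    }
                }
            } finally {
                streams.close();
            }
        }
    }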

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
index 03199d294ca..1cf8ef628da 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
@@ -28,8 +28,10 @@ import java.util.Set;
 public interface ChangelogReader extends ChangelogRegister {
     /**
      * Restore all registered state stores by reading from their changelogs
+     *
+     * @return the total number of records restored in this call
      */
-    void restore(final Map<TaskId, Task> tasks);
+    long restore(final Map<TaskId, Task> tasks);
 
     /**
      * Transit to restore active changelogs mode
@@ -41,6 +43,12 @@ public interface ChangelogReader extends ChangelogRegister {
      */
     void transitToUpdateStandby();
 
+    /**
+     * @return true if the reader is in restoring active changelog mode;
+     *         false if the reader is in updating standby changelog mode
+     */
+    boolean isRestoringActive();
+
     /**
      * @return the changelog partitions that have been completed restoring
      */
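
The interface change above is small but load-bearing: restore() now reports how many records it
applied in the call, and isRestoringActive() tells the caller which mode the reader is in, so the
state updater can attribute its restore time to either active restoration or standby updates. A
minimal sketch of the new contract follows; it uses a trimmed-down stand-in interface (the real
ChangelogReader has several more methods, and TaskId/Task are simplified to String/Object here):

    import java.util.HashMap;
    import java.util.Map;

    public class ChangelogReaderContractSketch {
        // Trimmed-down stand-in for the real interface, just to illustrate the new return value.
        interface MiniChangelogReader {
            long restore(Map<String, Object> tasks);  // number of records restored in this call
            boolean isRestoringActive();              // true while restoring active changelogs
        }

        static final class InMemoryReader implements MiniChangelogReader {
            private boolean restoringActive = true;

            @Override
            public long restore(final Map<String, Object> tasks) {
                // Pretend one buffered record was applied per updating task.
                return tasks.size();
            }

            @Override
            public boolean isRestoringActive() {
                return restoringActive;
            }

            void transitToUpdateStandby() {
                restoringActive = false;
            }
        }

        public static void main(final String[] args) {
            final InMemoryReader reader = new InMemoryReader();
            final Map<String, Object> updatingTasks = new HashMap<>();
            updatingTasks.put("0_0", new Object());
            updatingTasks.put("0_1", new Object());

            final long restored = reader.restore(updatingTasks);
            System.out.println("restored=" + restored + ", restoringActive=" + reader.isRestoringActive());

            reader.transitToUpdateStandby();
            System.out.println("restoringActive=" + reader.isRestoringActive());
        }
    }
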
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
index ae6618c304f..5e912c99a5b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -16,8 +16,15 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.metrics.Sensor.RecordingLevel;
+import org.apache.kafka.common.metrics.stats.Avg;
+import org.apache.kafka.common.metrics.stats.Rate;
+import org.apache.kafka.common.metrics.stats.WindowedCount;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
@@ -32,7 +39,9 @@ import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Deque;
 import java.util.HashSet;
+import java.util.LinkedHashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -51,6 +60,10 @@ import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RATE_DESCRIPTION;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RATIO_DESCRIPTION;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.THREAD_ID_TAG;
+
 public class DefaultStateUpdater implements StateUpdater {
 
     private final static String BUG_ERROR_MESSAGE = "This indicates a bug. " +
@@ -59,14 +72,20 @@ public class DefaultStateUpdater implements StateUpdater {
     private class StateUpdaterThread extends Thread {
 
         private final ChangelogReader changelogReader;
+        private final StateUpdaterMetrics updaterMetrics;
         private final AtomicBoolean isRunning = new AtomicBoolean(true);
         private final Map<TaskId, Task> updatingTasks = new ConcurrentHashMap<>();
         private final Map<TaskId, Task> pausedTasks = new ConcurrentHashMap<>();
         private final Logger log;
 
-        public StateUpdaterThread(final String name, final ChangelogReader changelogReader) {
+        private long totalCheckpointLatency = 0L;
+
+        public StateUpdaterThread(final String name,
+                                  final Metrics metrics,
+                                  final ChangelogReader changelogReader) {
             super(name);
             this.changelogReader = changelogReader;
+            this.updaterMetrics = new StateUpdaterMetrics(metrics, name);
 
             final String logPrefix = String.format("state-updater [%s] ", name);
             final LogContext logContext = new LogContext(logPrefix);
@@ -92,6 +111,30 @@ public class DefaultStateUpdater implements StateUpdater {
             return pausedTasks.values();
         }
 
+        public long getNumUpdatingStandbyTasks() {
+            return updatingTasks.values().stream()
+                .filter(t -> !t.isActive())
+                .count();
+        }
+
+        public long getNumRestoringActiveTasks() {
+            return updatingTasks.values().stream()
+                .filter(Task::isActive)
+                .count();
+        }
+
+        public long getNumPausedStandbyTasks() {
+            return pausedTasks.values().stream()
+                .filter(t -> !t.isActive())
+                .count();
+        }
+
+        public long getNumPausedActiveTasks() {
+            return pausedTasks.values().stream()
+                .filter(Task::isActive)
+                .count();
+        }
+
         @Override
         public void run() {
             log.info("State updater thread started");
@@ -109,17 +152,41 @@ public class DefaultStateUpdater implements StateUpdater {
                 Thread.interrupted(); // Clear the interrupted flag.
                 removeAddedTasksFromInputQueue();
                 removeUpdatingAndPausedTasks();
+                updaterMetrics.clear();
                 shutdownGate.countDown();
                 log.info("State updater thread shutdown");
             }
         }
 
+        // In each iteration:
+        //   1) check if updating tasks need to be paused
+        //   2) check if paused tasks need to be resumed
+        //   3) restore those updating tasks
+        //   4) checkpoint those updating task states
+        //   5) idle waiting if there are no more tasks to be restored
+        //
+        //   Note that 1-3) are measured as restoring time, while 4) and 5) are measured separately
+        //   as checkpointing time and idle time
         private void runOnce() throws InterruptedException {
+            final long totalStartTimeMs = time.milliseconds();
             performActionsOnTasks();
+
             resumeTasks();
-            restoreTasks();
-            checkAllUpdatingTaskStates(time.milliseconds());
+            pauseTasks();
+            restoreTasks(totalStartTimeMs);
+
+            final long checkpointStartTimeMs = time.milliseconds();
+            maybeCheckpointTasks(checkpointStartTimeMs);
+
+            final long waitStartTimeMs = time.milliseconds();
+
             waitIfAllChangelogsCompletelyRead();
+
+            final long endTimeMs = time.milliseconds();
+            final long totalWaitTime = Math.max(0L, endTimeMs - waitStartTimeMs);
+            final long totalTime = Math.max(0L, endTimeMs - totalStartTimeMs);
+
+            recordMetrics(endTimeMs, totalTime, totalWaitTime);
         }
 
         private void performActionsOnTasks() {
@@ -151,9 +218,18 @@ public class DefaultStateUpdater implements StateUpdater {
             }
         }
 
-        private void restoreTasks() {
+        private void pauseTasks() {
+            for (final Task task : updatingTasks.values()) {
+                if (topologyMetadata.isPaused(task.id().topologyName())) {
+                    pauseTask(task);
+                }
+            }
+        }
+
+        private void restoreTasks(final long now) {
             try {
-                changelogReader.restore(updatingTasks);
+                final long restored = changelogReader.restore(updatingTasks);
+                updaterMetrics.restoreSensor.record(restored, now);
             } catch (final TaskCorruptedException taskCorruptedException) {
                 handleTaskCorruptedException(taskCorruptedException);
             } catch (final StreamsException streamsException) {
@@ -193,7 +269,7 @@ public class DefaultStateUpdater implements StateUpdater {
             task.markChangelogAsCorrupted(task.changelogPartitions());
 
             // we need to enforce a checkpoint that removes the corrupted partitions
-            task.maybeCheckpoint(true);
+            measureCheckpointLatency(() -> task.maybeCheckpoint(true));
         }
 
         private void handleStreamsException(final StreamsException streamsException) {
@@ -251,10 +327,10 @@ public class DefaultStateUpdater implements StateUpdater {
 
         private void removeUpdatingAndPausedTasks() {
             changelogReader.clear();
-            updatingTasks.forEach((id, task) -> {
+            measureCheckpointLatency(() -> updatingTasks.forEach((id, task) -> {
                 task.maybeCheckpoint(true);
                 removedTasks.add(task);
-            });
+            }));
             updatingTasks.clear();
             pausedTasks.forEach((id, task) -> {
                 removedTasks.add(task);
@@ -313,7 +389,7 @@ public class DefaultStateUpdater implements StateUpdater {
             final Task task;
             if (updatingTasks.containsKey(taskId)) {
                 task = updatingTasks.get(taskId);
-                task.maybeCheckpoint(true);
+                measureCheckpointLatency(() -> task.maybeCheckpoint(true));
                 final Collection<TopicPartition> changelogPartitions = task.changelogPartitions();
                 changelogReader.unregister(changelogPartitions);
                 removedTasks.add(task);
@@ -339,7 +415,7 @@ public class DefaultStateUpdater implements StateUpdater {
         private void pauseTask(final Task task) {
             final TaskId taskId = task.id();
             // do not need to unregister changelog partitions for paused tasks
-            task.maybeCheckpoint(true);
+            measureCheckpointLatency(() -> task.maybeCheckpoint(true));
             pausedTasks.put(taskId, task);
             updatingTasks.remove(taskId);
             if (task.isActive()) {
@@ -373,7 +449,7 @@ public class DefaultStateUpdater implements StateUpdater {
                                               final Set<TopicPartition> restoredChangelogs) {
             final Collection<TopicPartition> changelogPartitions = task.changelogPartitions();
             if (restoredChangelogs.containsAll(changelogPartitions)) {
-                task.maybeCheckpoint(true);
+                measureCheckpointLatency(() -> task.maybeCheckpoint(true));
                 changelogReader.unregister(changelogPartitions);
                 addToRestoredTasks(task);
                 updatingTasks.remove(task.id());
@@ -399,31 +475,55 @@ public class DefaultStateUpdater implements StateUpdater {
             }
         }
 
-        private void checkAllUpdatingTaskStates(final long now) {
+        private void maybeCheckpointTasks(final long now) {
             final long elapsedMsSinceLastCommit = now - lastCommitMs;
             if (elapsedMsSinceLastCommit > commitIntervalMs) {
                 if (log.isDebugEnabled()) {
-                    log.debug("Checking all restoring task states since {}ms has elapsed (commit interval is {}ms)",
+                    log.debug("Checkpointing state of all restoring tasks since {}ms has elapsed (commit interval is {}ms)",
                         elapsedMsSinceLastCommit, commitIntervalMs);
                 }
 
-                for (final Task task : updatingTasks.values()) {
-                    if (topologyMetadata.isPaused(task.id().topologyName())) {
-                        pauseTask(task);
-                    } else {
-                        log.debug("Try to checkpoint current restoring progress for task {}", task.id());
+                measureCheckpointLatency(() -> {
+                    for (final Task task : updatingTasks.values()) {
                         // do not enforce checkpointing during restoration if its position has not advanced much
                         task.maybeCheckpoint(false);
                     }
-                }
+                });
 
                 lastCommitMs = now;
             }
         }
+
+        private void measureCheckpointLatency(final Runnable actionToMeasure) {
+            final long startMs = time.milliseconds();
+            try {
+                actionToMeasure.run();
+            } finally {
+                totalCheckpointLatency += Math.max(0L, time.milliseconds() - startMs);
+            }
+        }
+
+        private void recordMetrics(final long now, final long totalLatency, final long totalWaitLatency) {
+            final long totalRestoreLatency = Math.max(0L, totalLatency - totalWaitLatency - totalCheckpointLatency);
+
+            updaterMetrics.idleRatioSensor.record((double) totalWaitLatency / totalLatency, now);
+            updaterMetrics.checkpointRatioSensor.record((double) totalCheckpointLatency / totalLatency, now);
+
+            if (changelogReader.isRestoringActive()) {
+                updaterMetrics.activeRestoreRatioSensor.record((double) totalRestoreLatency / totalLatency, now);
+                updaterMetrics.standbyRestoreRatioSensor.record(0.0d, now);
+            } else {
+                updaterMetrics.standbyRestoreRatioSensor.record((double) totalRestoreLatency / totalLatency, now);
+                updaterMetrics.activeRestoreRatioSensor.record(0.0d, now);
+            }
+
+            totalCheckpointLatency = 0L;
+        }
     }
 
     private final Time time;
     private final String name;
+    private final Metrics metrics;
     private final ChangelogReader changelogReader;
     private final TopologyMetadata topologyMetadata;
     private final Queue<TaskAndAction> tasksAndActions = new LinkedList<>();
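
As a rough illustration of the bookkeeping in runOnce() and recordMetrics() above (not part of the
patch): each pass through the loop is timed as a whole, the time spent waiting and the time spent
checkpointing are measured separately, and whatever remains is attributed to restoration, so the
three recorded ratios should sum to roughly 1 per iteration. The numbers below are made up; only
the arithmetic mirrors the code.

    public class LoopRatioSketch {
        public static void main(final String[] args) {
            final long totalTimeMs = 200L;      // one full pass through runOnce()
            final long waitTimeMs = 120L;       // spent in waitIfAllChangelogsCompletelyRead()
            final long checkpointTimeMs = 30L;  // accumulated via measureCheckpointLatency()

            // Everything that is neither waiting nor checkpointing counts as restoration time.
            final long restoreTimeMs = Math.max(0L, totalTimeMs - waitTimeMs - checkpointTimeMs);

            final double idleRatio = (double) waitTimeMs / totalTimeMs;
            final double checkpointRatio = (double) checkpointTimeMs / totalTimeMs;
            final double restoreRatio = (double) restoreTimeMs / totalTimeMs;

            // The restore ratio is recorded on either the active-restore or the standby-update
            // sensor, depending on ChangelogReader#isRestoringActive(); the other one gets 0.0.
            System.out.printf("idle=%.2f checkpoint=%.2f restore=%.2f (sum=%.2f)%n",
                idleRatio, checkpointRatio, restoreRatio, idleRatio + checkpointRatio + restoreRatio);
        }
    }
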
@@ -443,12 +543,14 @@ public class DefaultStateUpdater implements StateUpdater {
     private CountDownLatch shutdownGate;
 
     public DefaultStateUpdater(final String name,
+                               final Metrics metrics,
                                final StreamsConfig config,
                                final ChangelogReader changelogReader,
                                final TopologyMetadata topologyMetadata,
                                final Time time) {
         this.time = time;
         this.name = name;
+        this.metrics = metrics;
         this.changelogReader = changelogReader;
         this.topologyMetadata = topologyMetadata;
         this.commitIntervalMs = config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG);
@@ -456,7 +558,7 @@ public class DefaultStateUpdater implements StateUpdater {
 
     public void start() {
         if (stateUpdaterThread == null) {
-            stateUpdaterThread = new StateUpdaterThread(name, changelogReader);
+            stateUpdaterThread = new StateUpdaterThread(name, metrics, changelogReader);
             stateUpdaterThread.start();
             shutdownGate = new CountDownLatch(1);
 
@@ -686,4 +788,92 @@ public class DefaultStateUpdater implements StateUpdater {
                             exceptionsAndFailedTasks.stream().flatMap(exceptionAndTasks -> exceptionAndTasks.getTasks().stream()),
                             removedTasks.stream()))));
     }
+
+    private class StateUpdaterMetrics {
+        private static final String STATE_LEVEL_GROUP = "stream-state-updater-metrics";
+
+        private static final String IDLE_RATIO_DESCRIPTION = RATIO_DESCRIPTION + "being idle";
+        private static final String RESTORE_RATIO_DESCRIPTION = RATIO_DESCRIPTION + "restoring active tasks";
+        private static final String UPDATE_RATIO_DESCRIPTION = RATIO_DESCRIPTION + "updating standby tasks";
+        private static final String CHECKPOINT_RATIO_DESCRIPTION = RATIO_DESCRIPTION + "checkpointing tasks restored progress";
+        private static final String RESTORE_RECORDS_RATE_DESCRIPTION = RATE_DESCRIPTION + "records restored";
+        private static final String RESTORE_RATE_DESCRIPTION = RATE_DESCRIPTION + "restore calls triggered";
+
+        private final Sensor restoreSensor;
+        private final Sensor idleRatioSensor;
+        private final Sensor activeRestoreRatioSensor;
+        private final Sensor standbyRestoreRatioSensor;
+        private final Sensor checkpointRatioSensor;
+
+        private final Deque<String> allSensorNames = new LinkedList<>();
+        private final Deque<MetricName> allMetricNames = new LinkedList<>();
+
+        private StateUpdaterMetrics(final Metrics metrics, final String threadId) {
+            final Map<String, String> threadLevelTags = new LinkedHashMap<>();
+            threadLevelTags.put(THREAD_ID_TAG, threadId);
+
+            MetricName metricName = metrics.metricName("active-restoring-tasks",
+                STATE_LEVEL_GROUP,
+                "The number of active tasks currently undergoing restoration",
+                threadLevelTags);
+            metrics.addMetric(metricName, (config, now) -> stateUpdaterThread != null ?
+                stateUpdaterThread.getNumRestoringActiveTasks() : 0);
+            allMetricNames.push(metricName);
+
+            metricName = metrics.metricName("standby-updating-tasks",
+                STATE_LEVEL_GROUP,
+                "The number of standby tasks currently undergoing state update",
+                threadLevelTags);
+            metrics.addMetric(metricName, (config, now) -> stateUpdaterThread != null ?
+                stateUpdaterThread.getNumUpdatingStandbyTasks() : 0);
+            allMetricNames.push(metricName);
+
+            metricName = metrics.metricName("active-paused-tasks",
+                STATE_LEVEL_GROUP,
+                "The number of active tasks paused restoring",
+                threadLevelTags);
+            metrics.addMetric(metricName, (config, now) -> stateUpdaterThread != null ?
+                stateUpdaterThread.getNumPausedActiveTasks() : 0);
+            allMetricNames.push(metricName);
+
+            metricName = metrics.metricName("standby-paused-tasks",
+                STATE_LEVEL_GROUP,
+                "The number of standby tasks paused state update",
+                threadLevelTags);
+            metrics.addMetric(metricName, (config, now) -> stateUpdaterThread != null ?
+                stateUpdaterThread.getNumPausedStandbyTasks() : 0);
+            allMetricNames.push(metricName);
+
+            this.idleRatioSensor = metrics.sensor("idle-ratio", RecordingLevel.INFO);
+            this.idleRatioSensor.add(new MetricName("idle-ratio", STATE_LEVEL_GROUP, IDLE_RATIO_DESCRIPTION, threadLevelTags), new Avg());
+            allSensorNames.add("idle-ratio");
+
+            this.activeRestoreRatioSensor = metrics.sensor("active-restore-ratio", RecordingLevel.INFO);
+            this.activeRestoreRatioSensor.add(new MetricName("active-restore-ratio", STATE_LEVEL_GROUP, RESTORE_RATIO_DESCRIPTION, threadLevelTags), new Avg());
+            allSensorNames.add("active-restore-ratio");
+
+            this.standbyRestoreRatioSensor = metrics.sensor("standby-update-ratio", RecordingLevel.INFO);
+            this.standbyRestoreRatioSensor.add(new MetricName("standby-update-ratio", STATE_LEVEL_GROUP, UPDATE_RATIO_DESCRIPTION, threadLevelTags), new Avg());
+            allSensorNames.add("standby-update-ratio");
+
+            this.checkpointRatioSensor = metrics.sensor("checkpoint-ratio", RecordingLevel.INFO);
+            this.checkpointRatioSensor.add(new MetricName("checkpoint-ratio", STATE_LEVEL_GROUP, CHECKPOINT_RATIO_DESCRIPTION, threadLevelTags), new Avg());
+            allSensorNames.add("checkpoint-ratio");
+
+            this.restoreSensor = metrics.sensor("restore-records", RecordingLevel.INFO);
+            this.restoreSensor.add(new MetricName("restore-records-rate", STATE_LEVEL_GROUP, RESTORE_RECORDS_RATE_DESCRIPTION, threadLevelTags), new Rate());
+            this.restoreSensor.add(new MetricName("restore-call-rate", STATE_LEVEL_GROUP, RESTORE_RATE_DESCRIPTION, threadLevelTags), new Rate(new WindowedCount()));
+            allSensorNames.add("restore-records");
+        }
+
+        void clear() {
+            while (!allSensorNames.isEmpty()) {
+                metrics.removeSensor(allSensorNames.pop());
+            }
+
+            while (!allMetricNames.isEmpty()) {
+                metrics.removeMetric(allMetricNames.pop());
+            }
+        }
+    }
 }
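
The StateUpdaterMetrics helper above follows the usual pattern for the client-side metrics
registry: task-count gauges are registered with Metrics#addMetric and a lambda that is evaluated
on read, ratio sensors carry an Avg stat, the restore sensor carries Rate stats, and clear() tears
everything down when the updater thread shuts down. A self-contained sketch of that
register/record/remove lifecycle against the public org.apache.kafka.common.metrics API follows;
the metric and sensor names in it are illustrative, not the ones added by this patch.

    import java.util.Collections;
    import org.apache.kafka.common.MetricName;
    import org.apache.kafka.common.metrics.Metrics;
    import org.apache.kafka.common.metrics.Sensor;
    import org.apache.kafka.common.metrics.stats.Avg;

    public class MetricsLifecycleSketch {
        public static void main(final String[] args) {
            final Metrics metrics = new Metrics();
            try {
                // Gauge-style metric: the value is computed when the metric is read,
                // like the active-restoring/standby-updating task counts above.
                final MetricName tasksGauge = metrics.metricName(
                    "example-updating-tasks", "example-group", "illustrative gauge",
                    Collections.singletonMap("thread-id", "example-thread"));
                metrics.addMetric(tasksGauge, (config, now) -> 3.0);

                // Sensor-style metric: values are recorded explicitly, like the idle/checkpoint ratios above.
                final Sensor ratioSensor = metrics.sensor("example-idle-ratio", Sensor.RecordingLevel.INFO);
                ratioSensor.add(metrics.metricName("example-idle-ratio", "example-group",
                    "illustrative ratio", Collections.singletonMap("thread-id", "example-thread")), new Avg());
                ratioSensor.record(0.25);
                ratioSensor.record(0.75);

                System.out.println("gauge = " + metrics.metrics().get(tasksGauge).metricValue());

                // Mirror of StateUpdaterMetrics#clear(): unregister everything on shutdown.
                metrics.removeSensor("example-idle-ratio");
                metrics.removeMetric(tasksGauge);
            } finally {
                metrics.close();
            }
        }
    }
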
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
index be580f3575c..701ecb9b567 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
@@ -330,6 +330,11 @@ public class StoreChangelogReader implements ChangelogReader {
         state = ChangelogReaderState.STANDBY_UPDATING;
     }
 
+    @Override
+    public boolean isRestoringActive() {
+        return state == ChangelogReaderState.ACTIVE_RESTORING;
+    }
+
     /**
      * Since it is shared for multiple tasks and hence multiple state managers, the registration would take its
      * corresponding state manager as well for restoring.
@@ -423,49 +428,22 @@ public class StoreChangelogReader implements ChangelogReader {
     // 2. if all changelogs have finished, return early;
     // 3. if there are any restoring changelogs, try to read from the restore consumer and process them.
     @Override
-    public void restore(final Map<TaskId, Task> tasks) {
-
-        // If we are updating only standby tasks, and are not using a separate thread, we should
-        // use a non-blocking poll to unblock the processing as soon as possible.
-        final boolean useNonBlockingPoll = state == ChangelogReaderState.STANDBY_UPDATING && !stateUpdaterEnabled;
-
+    public long restore(final Map<TaskId, Task> tasks) {
         initializeChangelogs(tasks, registeredChangelogs());
 
         if (!activeRestoringChangelogs().isEmpty() && state == ChangelogReaderState.STANDBY_UPDATING) {
             throw new IllegalStateException("Should not be in standby updating state if there are still un-completed active changelogs");
         }
 
+        long totalRestored = 0L;
         if (allChangelogsCompleted()) {
             log.debug("Finished restoring all changelogs {}", changelogs.keySet());
-            return;
+            return totalRestored;
         }
 
         final Set<TopicPartition> restoringChangelogs = restoringChangelogs();
         if (!restoringChangelogs.isEmpty()) {
-            final ConsumerRecords<byte[], byte[]> polledRecords;
-
-            try {
-                pauseResumePartitions(tasks, restoringChangelogs);
-
-                polledRecords = restoreConsumer.poll(useNonBlockingPoll ? Duration.ZERO : pollTime);
-
-                // TODO (?) If we cannot fetch records during restore, should we trigger `task.timeout.ms` ?
-                // TODO (?) If we cannot fetch records for standby task, should we trigger `task.timeout.ms` ?
-            } catch (final InvalidOffsetException e) {
-                log.warn("Encountered " + e.getClass().getName() +
-                    " fetching records from restore consumer for partitions " + e.partitions() + ", it is likely that " +
-                    "the consumer's position has fallen out of the topic partition offset range because the topic was " +
-                    "truncated or compacted on the broker, marking the corresponding tasks as corrupted and re-initializing" +
-                    " it later.", e);
-
-                final Set<TaskId> corruptedTasks = new HashSet<>();
-                e.partitions().forEach(partition -> corruptedTasks.add(changelogs.get(partition).stateManager.taskId()));
-                throw new TaskCorruptedException(corruptedTasks, e);
-            } catch (final InterruptException interruptException) {
-                throw interruptException;
-            } catch (final KafkaException e) {
-                throw new StreamsException("Restore consumer get unexpected error polling records.", e);
-            }
+            final ConsumerRecords<byte[], byte[]> polledRecords = pollRecordsFromRestoreConsumer(tasks, restoringChangelogs);
 
             for (final TopicPartition partition : polledRecords.partitions()) {
                 bufferChangelogRecords(restoringChangelogByPartition(partition), polledRecords.records(partition));
@@ -479,12 +457,15 @@ public class StoreChangelogReader implements ChangelogReader {
                 //       small batches; this can be optimized in the future, e.g. wait longer for larger batches.
                 final TaskId taskId = changelogs.get(partition).stateManager.taskId();
                 try {
-                    if (restoreChangelog(changelogs.get(partition))) {
+                    final ChangelogMetadata changelogMetadata = changelogs.get(partition);
+                    final int restored = restoreChangelog(changelogMetadata);
+                    if (restored > 0 || changelogMetadata.state().equals(ChangelogState.COMPLETED)) {
                         final Task task = tasks.get(taskId);
                         if (task != null) {
                             task.clearTaskTimeout();
                         }
                     }
+                    totalRestored += restored;
                 } catch (final TimeoutException timeoutException) {
                     tasks.get(taskId).maybeInitTaskTimeoutOrThrow(
                         time.milliseconds(),
@@ -497,6 +478,41 @@ public class StoreChangelogReader implements ChangelogReader {
 
             maybeLogRestorationProgress();
         }
+
+        return totalRestored;
+    }
+
+    private ConsumerRecords<byte[], byte[]> pollRecordsFromRestoreConsumer(final Map<TaskId, Task> tasks,
+                                                                           final Set<TopicPartition> restoringChangelogs) {
+        // If we are updating only standby tasks, and are not using a separate thread, we should
+        // use a non-blocking poll to unblock the processing as soon as possible.
+        final boolean useNonBlockingPoll = state == ChangelogReaderState.STANDBY_UPDATING && !stateUpdaterEnabled;
+        final ConsumerRecords<byte[], byte[]> polledRecords;
+
+        try {
+            pauseResumePartitions(tasks, restoringChangelogs);
+
+            polledRecords = restoreConsumer.poll(useNonBlockingPoll ? Duration.ZERO : pollTime);
+
+            // TODO (?) If we cannot fetch records during restore, should we trigger `task.timeout.ms` ?
+            // TODO (?) If we cannot fetch records for standby task, should we trigger `task.timeout.ms` ?
+        } catch (final InvalidOffsetException e) {
+            log.warn("Encountered " + e.getClass().getName() +
+                " fetching records from restore consumer for partitions " + e.partitions() + ", it is likely that " +
+                "the consumer's position has fallen out of the topic partition offset range because the topic was " +
+                "truncated or compacted on the broker, marking the corresponding tasks as corrupted and re-initializing " +
+                "it later.", e);
+
+            final Set<TaskId> corruptedTasks = new HashSet<>();
+            e.partitions().forEach(partition -> corruptedTasks.add(changelogs.get(partition).stateManager.taskId()));
+            throw new TaskCorruptedException(corruptedTasks, e);
+        } catch (final InterruptException interruptException) {
+            throw interruptException;
+        } catch (final KafkaException e) {
+            throw new StreamsException("Restore consumer get unexpected error polling records.", e);
+        }
+
+        return polledRecords;
     }
 
     private void pauseResumePartitions(final Map<TaskId, Task> tasks,
@@ -623,19 +639,17 @@ public class StoreChangelogReader implements ChangelogReader {
     /**
      * restore a changelog with its buffered records if there's any; for active changelogs also check if
      * it has completed the restoration and can transit to COMPLETED state and trigger restore callbacks
+     *
+     * @return number of records restored
      */
-    private boolean restoreChangelog(final ChangelogMetadata changelogMetadata) {
+    private int restoreChangelog(final ChangelogMetadata changelogMetadata) {
         final ProcessorStateManager stateManager = changelogMetadata.stateManager;
         final StateStoreMetadata storeMetadata = changelogMetadata.storeMetadata;
         final TopicPartition partition = storeMetadata.changelogPartition();
         final String storeName = storeMetadata.store().name();
         final int numRecords = changelogMetadata.bufferedLimitIndex;
 
-        boolean madeProgress = false;
-
         if (numRecords != 0) {
-            madeProgress = true;
-
             final List<ConsumerRecord<byte[], byte[]>> records = changelogMetadata.bufferedRecords.subList(0, numRecords);
             stateManager.restore(storeMetadata, records);
 
@@ -650,7 +664,7 @@ public class StoreChangelogReader implements ChangelogReader {
 
             final Long currentOffset = storeMetadata.offset();
             log.trace("Restored {} records from changelog {} to store {}, end offset is {}, current offset is {}",
-                partition, storeName, numRecords, recordEndOffset(changelogMetadata.restoreEndOffset), currentOffset);
+                numRecords, partition, storeName, recordEndOffset(changelogMetadata.restoreEndOffset), currentOffset);
 
             changelogMetadata.bufferedLimitIndex = 0;
             changelogMetadata.totalRestored += numRecords;
@@ -667,8 +681,6 @@ public class StoreChangelogReader implements ChangelogReader {
 
         // we should check even if there's nothing restored, but do not check completed if we are processing standby tasks
         if (changelogMetadata.stateManager.taskType() == Task.TaskType.ACTIVE && hasRestoredToEnd(changelogMetadata)) {
-            madeProgress = true;
-
             log.info("Finished restoring changelog {} to store {} with a total number of {} records",
                 partition, storeName, changelogMetadata.totalRestored);
 
@@ -682,7 +694,7 @@ public class StoreChangelogReader implements ChangelogReader {
             }
         }
 
-        return madeProgress;
+        return numRecords;
     }
 
     private Set<Task> getTasksFromPartitions(final Map<TaskId, Task> tasks,
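
One behavioral detail worth calling out in the refactor above: restoreChangelog() used to return a
boolean "made progress" flag and now returns the number of records it applied, so the caller treats
"restored > 0, or the changelog just completed" as progress when deciding whether to clear a task's
timeout, and sums the per-partition counts into the total handed back to the state updater's
restore sensor. A small sketch of that aggregation follows; the data and names are made up.

    import java.util.LinkedHashMap;
    import java.util.Map;

    public class RestoreProgressSketch {
        static final class PartitionResult {
            final int restored;
            final boolean completed;
            PartitionResult(final int restored, final boolean completed) {
                this.restored = restored;
                this.completed = completed;
            }
        }

        public static void main(final String[] args) {
            final Map<String, PartitionResult> results = new LinkedHashMap<>();
            results.put("store-changelog-0", new PartitionResult(250, false));  // restored a batch
            results.put("store-changelog-1", new PartitionResult(0, true));     // nothing buffered, but completed
            results.put("store-changelog-2", new PartitionResult(0, false));    // no progress at all

            long totalRestored = 0L;
            for (final Map.Entry<String, PartitionResult> entry : results.entrySet()) {
                final PartitionResult result = entry.getValue();
                if (result.restored > 0 || result.completed) {
                    System.out.println(entry.getKey() + ": progress, task timeout would be cleared");
                }
                totalRestored += result.restored;
            }
            System.out.println("totalRestored=" + totalRestored + " (returned to the state updater)");
        }
    }
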
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 02bd74a027d..1f2a91d27b7 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -404,7 +404,7 @@ public class StreamThread extends Thread {
             topologyMetadata,
             adminClient,
             stateDirectory,
-            maybeCreateAndStartStateUpdater(stateUpdaterEnabled, config, changelogReader, topologyMetadata, time, clientId, threadIdx)
+            maybeCreateAndStartStateUpdater(stateUpdaterEnabled, streamsMetrics, config, changelogReader, topologyMetadata, time, clientId, threadIdx)
         );
         referenceContainer.taskManager = taskManager;
 
@@ -448,6 +448,7 @@ public class StreamThread extends Thread {
     }
 
     private static StateUpdater maybeCreateAndStartStateUpdater(final boolean stateUpdaterEnabled,
+                                                                final StreamsMetricsImpl streamsMetrics,
                                                                 final StreamsConfig streamsConfig,
                                                                 final ChangelogReader changelogReader,
                                                                 final TopologyMetadata topologyMetadata,
@@ -456,7 +457,7 @@ public class StreamThread extends Thread {
                                                                 final int threadIdx) {
         if (stateUpdaterEnabled) {
             final String name = clientId + "-StateUpdater-" + threadIdx;
-            final StateUpdater stateUpdater = new DefaultStateUpdater(name, streamsConfig, changelogReader, topologyMetadata, time);
+            final StateUpdater stateUpdater = new DefaultStateUpdater(name, streamsMetrics.metricsRegistry(), streamsConfig, changelogReader, topologyMetadata, time);
             stateUpdater.start();
             return stateUpdater;
         } else {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index cf83cb27eba..c2f3b5253ac 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -1552,7 +1552,9 @@ public class TaskManager {
     /**
      * Returns tasks owned by the stream thread. With state updater disabled, these are all tasks. With
      * state updater enabled, this does not return any tasks currently owned by the state updater.
-     * @return
+     *
+     * TODO: after we complete the switch to the state updater, we could rename this function to allRunningTasks
+     *       to differentiate it from allTasks, which includes both running and restoring tasks
      */
     Map<TaskId, Task> allOwnedTasks() {
         // not bothering with an unmodifiable map, since the tasks themselves are mutable, but
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
index 3260bfc1b82..f3cd0982eaa 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
@@ -146,6 +146,7 @@ public class StreamsMetricsImpl implements StreamsMetrics {
     public static final String OPERATIONS = " operations";
     public static final String TOTAL_DESCRIPTION = "The total number of ";
     public static final String RATE_DESCRIPTION = "The average per-second number of ";
+    public static final String RATIO_DESCRIPTION = "The fraction of time the thread spent on ";
     public static final String AVG_LATENCY_DESCRIPTION = "The average latency of ";
     public static final String MAX_LATENCY_DESCRIPTION = "The maximum latency of ";
     public static final String RATE_DESCRIPTION_PREFIX = "The average number of ";
@@ -177,6 +178,10 @@ public class StreamsMetricsImpl implements StreamsMetrics {
         return version;
     }
 
+    public Metrics metricsRegistry() {
+        return metrics;
+    }
+
     public RocksDBMetricsRecordingTrigger rocksDBMetricsRecordingTrigger() {
         return rocksDBMetricsRecordingTrigger;
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
index b0c0ba7c156..b3407114209 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -17,7 +17,9 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.clients.producer.ProducerConfig;
+import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
@@ -26,6 +28,7 @@ import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTasks;
 import org.apache.kafka.streams.processor.internals.Task.State;
+import org.hamcrest.Matcher;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Test;
 import org.mockito.InOrder;
@@ -35,6 +38,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
@@ -51,6 +55,10 @@ import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statefulTask;
 import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statelessTask;
 import static org.apache.kafka.test.StreamsTestUtils.TopologyMetadataBuilder.unnamedTopology;
 import static org.apache.kafka.test.TestUtils.waitForCondition;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertThrows;
@@ -58,7 +66,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyMap;
 import static org.mockito.Mockito.atLeast;
-import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
@@ -76,6 +84,7 @@ class DefaultStateUpdaterTest {
     private final static TopicPartition TOPIC_PARTITION_A_0 = new TopicPartition("topicA", 0);
     private final static TopicPartition TOPIC_PARTITION_A_1 = new TopicPartition("topicA", 1);
     private final static TopicPartition TOPIC_PARTITION_B_0 = new TopicPartition("topicB", 0);
+    private final static TopicPartition TOPIC_PARTITION_B_1 = new TopicPartition("topicB", 1);
     private final static TopicPartition TOPIC_PARTITION_C_0 = new TopicPartition("topicC", 0);
     private final static TopicPartition TOPIC_PARTITION_D_0 = new TopicPartition("topicD", 0);
     private final static TaskId TASK_0_0 = new TaskId(0, 0);
@@ -84,14 +93,18 @@ class DefaultStateUpdaterTest {
     private final static TaskId TASK_1_0 = new TaskId(1, 0);
     private final static TaskId TASK_1_1 = new TaskId(1, 1);
     private final static TaskId TASK_A_0_0 = new TaskId(0, 0, "A");
+    private final static TaskId TASK_A_0_1 = new TaskId(0, 1, "A");
     private final static TaskId TASK_B_0_0 = new TaskId(0, 0, "B");
+    private final static TaskId TASK_B_0_1 = new TaskId(0, 1, "B");
 
     // need an auto-tick timer to work for draining with timeout
     private final Time time = new MockTime(1L);
+    private final Metrics metrics = new Metrics(time);
     private final StreamsConfig config = new StreamsConfig(configProps(COMMIT_INTERVAL));
     private final ChangelogReader changelogReader = mock(ChangelogReader.class);
     private final TopologyMetadata topologyMetadata = unnamedTopology().build();
-    private DefaultStateUpdater stateUpdater = new DefaultStateUpdater("test-state-updater", config, changelogReader, topologyMetadata, time);
+    private DefaultStateUpdater stateUpdater =
+        new DefaultStateUpdater("test-state-updater", metrics, config, changelogReader, topologyMetadata, time);
 
     @AfterEach
     public void tearDown() {
@@ -149,7 +162,7 @@ class DefaultStateUpdaterTest {
     @Test
     public void shouldRemoveUpdatingTasksOnShutdown() throws Exception {
         stateUpdater.shutdown(Duration.ofMillis(Long.MAX_VALUE));
-        stateUpdater = new DefaultStateUpdater("test-state-updater", new StreamsConfig(configProps(Integer.MAX_VALUE)), changelogReader, topologyMetadata, time);
+        stateUpdater = new DefaultStateUpdater("test-state-updater", metrics, new StreamsConfig(configProps(Integer.MAX_VALUE)), changelogReader, topologyMetadata, time);
         final StreamTask activeTask = statefulTask(TASK_0_0, mkSet(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
         final StandbyTask standbyTask = standbyTask(TASK_0_2, mkSet(TOPIC_PARTITION_C_0)).inState(State.RUNNING).build();
         when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
@@ -643,7 +656,7 @@ class DefaultStateUpdaterTest {
             mkEntry(activeTask2.id(), activeTask2),
             mkEntry(standbyTask.id(), standbyTask)
         );
-        doThrow(taskCorruptedException).doNothing().when(changelogReader).restore(updatingTasks1);
+        doThrow(taskCorruptedException).doReturn(0L).when(changelogReader).restore(updatingTasks1);
         when(changelogReader.allChangelogsCompleted()).thenReturn(false);
         stateUpdater.start();
 
@@ -670,7 +683,7 @@ class DefaultStateUpdaterTest {
         final TaskCorruptedException taskCorruptedException = new TaskCorruptedException(mkSet(task1.id()));
         final ExceptionAndTasks expectedExceptionAndTasks = new ExceptionAndTasks(mkSet(task1), taskCorruptedException);
         when(changelogReader.allChangelogsCompleted()).thenReturn(false);
-        doThrow(taskCorruptedException).doNothing().when(changelogReader).restore(updatingTasks);
+        doThrow(taskCorruptedException).doReturn(0L).when(changelogReader).restore(updatingTasks);
 
         stateUpdater.start();
         stateUpdater.add(task1);
@@ -828,7 +841,7 @@ class DefaultStateUpdaterTest {
             mkEntry(controlTask.id(), controlTask)
         );
         doThrow(streamsException)
-            .doNothing()
+            .doReturn(0L)
             .when(changelogReader).restore(updatingTasks);
         stateUpdater.start();
 
@@ -970,7 +983,7 @@ class DefaultStateUpdaterTest {
             mkEntry(controlTask.id(), controlTask)
         );
         doThrow(streamsException)
-            .doNothing()
+            .doReturn(0L)
             .when(changelogReader).restore(updatingTasks);
         stateUpdater.start();
 
@@ -1133,7 +1146,7 @@ class DefaultStateUpdaterTest {
                 mkEntry(controlTask.id(), controlTask)
         );
         doThrow(streamsException)
-                .doNothing()
+                .doReturn(0L)
                 .when(changelogReader).restore(updatingTasks);
         stateUpdater.start();
 
@@ -1184,7 +1197,7 @@ class DefaultStateUpdaterTest {
             mkEntry(task1.id(), task1),
             mkEntry(task2.id(), task2)
         );
-        doNothing().doThrow(streamsException).when(changelogReader).restore(updatingTasks);
+        doReturn(0L).doThrow(streamsException).when(changelogReader).restore(updatingTasks);
         stateUpdater.start();
 
         stateUpdater.add(task1);
@@ -1215,10 +1228,10 @@ class DefaultStateUpdaterTest {
             mkEntry(task2.id(), task2),
             mkEntry(task3.id(), task3)
         );
-        doNothing()
+        doReturn(0L)
             .doThrow(streamsException1)
             .when(changelogReader).restore(updatingTasksBeforeFirstThrow);
-        doNothing()
+        doReturn(0L)
             .doThrow(streamsException2)
             .when(changelogReader).restore(updatingTasksBeforeSecondThrow);
         stateUpdater.start();
@@ -1248,7 +1261,7 @@ class DefaultStateUpdaterTest {
             mkEntry(task2.id(), task2),
             mkEntry(task3.id(), task3)
         );
-        doNothing().doThrow(taskCorruptedException).doNothing().when(changelogReader).restore(updatingTasks);
+        doReturn(0L).doThrow(taskCorruptedException).doReturn(0L).when(changelogReader).restore(updatingTasks);
         stateUpdater.start();
 
         stateUpdater.add(task1);
@@ -1362,7 +1375,7 @@ class DefaultStateUpdaterTest {
     public void shouldNotAutoCheckpointTasksIfIntervalNotElapsed() {
         // we need to use a non auto-ticking timer here to control how much time elapsed exactly
         final Time time = new MockTime();
-        final DefaultStateUpdater stateUpdater = new DefaultStateUpdater("test-state-updater", config, changelogReader, topologyMetadata, time);
+        final DefaultStateUpdater stateUpdater = new DefaultStateUpdater("test-state-updater", metrics, config, changelogReader, topologyMetadata, time);
         try {
             final StreamTask task1 = statefulTask(TASK_0_0, mkSet(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
             final StreamTask task2 = statefulTask(TASK_0_2, mkSet(TOPIC_PARTITION_B_0)).inState(State.RESTORING).build();
@@ -1466,11 +1479,11 @@ class DefaultStateUpdaterTest {
             mkEntry(standbyTask1.id(), standbyTask1),
             mkEntry(standbyTask2.id(), standbyTask2)
         );
-        doNothing().doThrow(taskCorruptedException).doNothing().when(changelogReader).restore(updatingTasks1);
+        doReturn(0L).doThrow(taskCorruptedException).doReturn(0L).when(changelogReader).restore(updatingTasks1);
         final Map<TaskId, Task> updatingTasks2 = mkMap(
             mkEntry(activeTask1.id(), activeTask1)
         );
-        doNothing().doThrow(streamsException).doNothing().when(changelogReader).restore(updatingTasks2);
+        doReturn(0L).doThrow(streamsException).doReturn(0L).when(changelogReader).restore(updatingTasks2);
         stateUpdater.start();
         stateUpdater.add(standbyTask1);
         stateUpdater.add(activeTask1);
@@ -1526,6 +1539,114 @@ class DefaultStateUpdaterTest {
         verifyGetTasks(mkSet(activeTask), mkSet(standbyTask));
     }
 
+    @Test
+    public void shouldRecordMetrics() throws Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_A_0_0, mkSet(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_B_0_0, mkSet(TOPIC_PARTITION_B_0)).inState(State.RESTORING).build();
+        final StandbyTask standbyTask3 = standbyTask(TASK_A_0_1, mkSet(TOPIC_PARTITION_A_1)).inState(State.RUNNING).build();
+        final StandbyTask standbyTask4 = standbyTask(TASK_B_0_1, mkSet(TOPIC_PARTITION_B_1)).inState(State.RUNNING).build();
+        final Map<TaskId, Task> tasks1234 = mkMap(
+            mkEntry(activeTask1.id(), activeTask1),
+            mkEntry(activeTask2.id(), activeTask2),
+            mkEntry(standbyTask3.id(), standbyTask3),
+            mkEntry(standbyTask4.id(), standbyTask4)
+        );
+        final Map<TaskId, Task> tasks13 = mkMap(
+            mkEntry(activeTask1.id(), activeTask1),
+            mkEntry(standbyTask3.id(), standbyTask3)
+        );
+
+        when(topologyMetadata.isPaused("B")).thenReturn(true);
+        when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        when(changelogReader.restore(tasks1234)).thenReturn(1L);
+        when(changelogReader.restore(tasks13)).thenReturn(1L);
+        when(changelogReader.isRestoringActive()).thenReturn(true);
+        stateUpdater.start();
+        stateUpdater.add(activeTask1);
+        stateUpdater.add(activeTask2);
+        stateUpdater.add(standbyTask3);
+        stateUpdater.add(standbyTask4);
+
+        verifyPausedTasks(activeTask2, standbyTask4);
+        assertThat(metrics.metrics().size(), is(11));
+
+        final Map<String, String> tagMap = new LinkedHashMap<>();
+        tagMap.put("thread-id", "test-state-updater");
+
+        MetricName metricName = new MetricName("active-restoring-tasks",
+            "stream-state-updater-metrics",
+            "The number of active tasks currently undergoing restoration",
+            tagMap);
+        verifyMetric(metrics, metricName, is(1.0));
+
+        metricName = new MetricName("standby-updating-tasks",
+            "stream-state-updater-metrics",
+            "The number of standby tasks currently undergoing state update",
+            tagMap);
+        verifyMetric(metrics, metricName, is(1.0));
+
+        metricName = new MetricName("active-paused-tasks",
+            "stream-state-updater-metrics",
+            "The number of active tasks paused restoring",
+            tagMap);
+        verifyMetric(metrics, metricName, is(1.0));
+
+        metricName = new MetricName("standby-paused-tasks",
+            "stream-state-updater-metrics",
+            "The number of standby tasks paused state update",
+            tagMap);
+        verifyMetric(metrics, metricName, is(1.0));
+
+        metricName = new MetricName("idle-ratio",
+            "stream-state-updater-metrics",
+            "The fraction of time the thread spent on being idle",
+            tagMap);
+        verifyMetric(metrics, metricName, greaterThanOrEqualTo(0.0d));
+
+        metricName = new MetricName("active-restore-ratio",
+            "stream-state-updater-metrics",
+            "The fraction of time the thread spent on restoring active tasks",
+            tagMap);
+        verifyMetric(metrics, metricName, greaterThanOrEqualTo(0.0d));
+
+        metricName = new MetricName("standby-update-ratio",
+            "stream-state-updater-metrics",
+            "The fraction of time the thread spent on updating standby tasks",
+            tagMap);
+        verifyMetric(metrics, metricName, is(0.0d));
+
+        metricName = new MetricName("checkpoint-ratio",
+            "stream-state-updater-metrics",
+            "The fraction of time the thread spent on checkpointing tasks restored progress",
+            tagMap);
+        verifyMetric(metrics, metricName, greaterThanOrEqualTo(0.0d));
+
+        metricName = new MetricName("restore-records-rate",
+            "stream-state-updater-metrics",
+            "The average per-second number of records restored",
+            tagMap);
+        verifyMetric(metrics, metricName, not(0.0d));
+
+        metricName = new MetricName("restore-call-rate",
+            "stream-state-updater-metrics",
+            "The average per-second number of restore calls triggered",
+            tagMap);
+        verifyMetric(metrics, metricName, not(0.0d));
+
+        stateUpdater.shutdown(Duration.ofMinutes(1));
+        assertThat(metrics.metrics().size(), is(1));
+    }
+
+    @SuppressWarnings("unchecked")
+    private static <T> void verifyMetric(final Metrics metrics,
+                                         final MetricName metricName,
+                                         final Matcher<T> matcher) {
+        assertThat(metrics.metrics().get(metricName).metricName().description(), is(metricName.description()));
+        assertThat(metrics.metrics().get(metricName).metricName().tags(), is(metricName.tags()));
+        assertThat((T) metrics.metrics().get(metricName).metricValue(), matcher);
+    }
+
     private void verifyGetTasks(final Set<StreamTask> expectedActiveTasks,
                                 final Set<StandbyTask> expectedStandbyTasks) {
         final Set<Task> tasks = stateUpdater.getTasks();
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
index a70420cb44b..8d0f8c7a6b0 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
@@ -46,8 +46,9 @@ public class MockChangelogReader implements ChangelogReader {
     }
 
     @Override
-    public void restore(final Map<TaskId, Task> tasks) {
+    public long restore(final Map<TaskId, Task> tasks) {
         // do nothing
+        return 0L;
     }
 
     @Override
@@ -60,6 +61,11 @@ public class MockChangelogReader implements ChangelogReader {
         // do nothing
     }
 
+    @Override
+    public boolean isRestoringActive() {
+        return true;
+    }
+
     @Override
     public Set<TopicPartition> completedChangelogs() {
         // assuming all restoring partitions are completed
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index a8768ddd3eb..6050de77be0 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -1245,7 +1245,7 @@ public class StreamThreadTest {
 
         final ChangelogReader changelogReader = new MockChangelogReader() {
             @Override
-            public void restore(final Map<TaskId, Task> tasks) {
+            public long restore(final Map<TaskId, Task> tasks) {
                 consumer.addRecord(new ConsumerRecord<>(topic1, 1, 11, new byte[0], new byte[0]));
                 consumer.addRecord(new ConsumerRecord<>(topic1, 1, 12, new byte[1], new byte[0]));