You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this rendering.
Posted to commits@kafka.apache.org by mj...@apache.org on 2023/04/05 18:49:17 UTC

[kafka] branch trunk updated: KAFKA-10199: Add task updater metrics, part 2 (#13300)

This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 653baa66948 KAFKA-10199: Add task updater metrics, part 2 (#13300)
653baa66948 is described below

commit 653baa669486cca928a7bafba8ec47626624fbdc
Author: Guozhang Wang <wa...@gmail.com>
AuthorDate: Wed Apr 5 11:49:08 2023 -0700

    KAFKA-10199: Add task updater metrics, part 2 (#13300)
    
    Part of KIP-869
    
    Reviewers: Lucas Brutschy <lb...@confluent.io>, Matthias J. Sax <ma...@confluent.io>
---
 .../processor/internals/DefaultStateUpdater.java   |  1 -
 .../streams/processor/internals/ReadOnlyTask.java  |  6 ++
 .../streams/processor/internals/StandbyTask.java   | 11 +++
 .../processor/internals/StoreChangelogReader.java  | 33 +++++---
 .../streams/processor/internals/StreamTask.java    | 15 ++++
 .../kafka/streams/processor/internals/Task.java    |  3 +
 .../internals/metrics/ProcessorNodeMetrics.java    | 26 +++---
 .../internals/metrics/StreamsMetricsImpl.java      | 79 ++++++++---------
 .../processor/internals/metrics/TaskMetrics.java   | 99 +++++++++++++++++++---
 .../processor/internals/metrics/ThreadMetrics.java | 16 ++--
 .../state/internals/metrics/RocksDBMetrics.java    |  4 +-
 .../state/internals/metrics/StateStoreMetrics.java | 51 +++--------
 .../internals/DefaultStateUpdaterTest.java         |  1 +
 .../processor/internals/StandbyTaskTest.java       | 49 ++++++++++-
 .../processor/internals/StreamTaskTest.java        | 31 +++++++
 .../processor/internals/TaskManagerTest.java       |  5 ++
 .../internals/metrics/StateStoreMetricsTest.java   | 26 ------
 17 files changed, 300 insertions(+), 156 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
index cb96f545870..12cda4688e3 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -168,7 +168,6 @@ public class DefaultStateUpdater implements StateUpdater {
             maybeCheckpointTasks(checkpointStartTimeMs);
 
             final long waitStartTimeMs = time.milliseconds();
-
             waitIfAllChangelogsCompletelyRead();
 
             final long endTimeMs = time.milliseconds();
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ReadOnlyTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ReadOnlyTask.java
index ee3989cf62e..4ae67856736 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ReadOnlyTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ReadOnlyTask.java
@@ -19,6 +19,7 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 
@@ -188,6 +189,11 @@ public class ReadOnlyTask implements Task {
         throw new UnsupportedOperationException("This task is read-only");
     }
 
+    @Override
+    public void maybeRecordRestored(final Time time, final long numRecords) {
+        throw new UnsupportedOperationException("This task is read-only");
+    }
+
     @Override
     public boolean commitNeeded() {
         throw new UnsupportedOperationException("This task is read-only");
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
index 5c416c5574e..bcaf9a9dae0 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
@@ -20,6 +20,7 @@ import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsMetrics;
 import org.apache.kafka.streams.errors.StreamsException;
@@ -27,6 +28,7 @@ import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
 import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
 import org.apache.kafka.streams.TopologyConfig.TaskConfig;
 import org.apache.kafka.streams.state.internals.ThreadCache;
@@ -36,12 +38,15 @@ import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
 
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeRecordSensor;
+
 /**
  * A StandbyTask
  */
 public class StandbyTask extends AbstractTask implements Task {
     private final boolean eosEnabled;
     private final Sensor closeTaskSensor;
+    private final Sensor updateSensor;
     private final StreamsMetricsImpl streamsMetrics;
 
     @SuppressWarnings("rawtypes")
@@ -81,6 +86,7 @@ public class StandbyTask extends AbstractTask implements Task {
         processorContext.transitionToStandby(cache);
 
         closeTaskSensor = ThreadMetrics.closeTaskSensor(Thread.currentThread().getName(), streamsMetrics);
+        updateSensor = TaskMetrics.updateSensor(Thread.currentThread().getName(), id.toString(), streamsMetrics);
         this.eosEnabled = config.eosEnabled;
     }
 
@@ -89,6 +95,11 @@ public class StandbyTask extends AbstractTask implements Task {
         return false;
     }
 
+    @Override
+    public void maybeRecordRestored(final Time time, final long numRecords) {
+        maybeRecordSensor(numRecords, time, updateSensor);
+    }
+
     /**
      * @throws TaskCorruptedException if the state cannot be reused (with EOS) and needs to be reset)
      * @throws StreamsException fatal error, should close the thread
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
index 701ecb9b567..8c201834c67 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
@@ -457,15 +457,9 @@ public class StoreChangelogReader implements ChangelogReader {
                 //       small batches; this can be optimized in the future, e.g. wait longer for larger batches.
                 final TaskId taskId = changelogs.get(partition).stateManager.taskId();
                 try {
+                    final Task task = tasks.get(taskId);
                     final ChangelogMetadata changelogMetadata = changelogs.get(partition);
-                    final int restored = restoreChangelog(changelogMetadata);
-                    if (restored > 0 || changelogMetadata.state().equals(ChangelogState.COMPLETED)) {
-                        final Task task = tasks.get(taskId);
-                        if (task != null) {
-                            task.clearTaskTimeout();
-                        }
-                    }
-                    totalRestored += restored;
+                    totalRestored += restoreChangelog(task, changelogMetadata);
                 } catch (final TimeoutException timeoutException) {
                     tasks.get(taskId).maybeInitTaskTimeoutOrThrow(
                         time.milliseconds(),
@@ -642,7 +636,7 @@ public class StoreChangelogReader implements ChangelogReader {
      *
      * @return number of records restored
      */
-    private int restoreChangelog(final ChangelogMetadata changelogMetadata) {
+    private int restoreChangelog(final Task task, final ChangelogMetadata changelogMetadata) {
         final ProcessorStateManager stateManager = changelogMetadata.stateManager;
         final StateStoreMetadata storeMetadata = changelogMetadata.storeMetadata;
         final TopicPartition partition = storeMetadata.changelogPartition();
@@ -662,6 +656,8 @@ public class StoreChangelogReader implements ChangelogReader {
                 changelogMetadata.bufferedRecords.clear();
             }
 
+            task.maybeRecordRestored(time, records.size());
+
             final Long currentOffset = storeMetadata.offset();
             log.trace("Restored {} records from changelog {} to store {}, end offset is {}, current offset is {}",
                 numRecords, partition, storeName, recordEndOffset(changelogMetadata.restoreEndOffset), currentOffset);
@@ -694,6 +690,10 @@ public class StoreChangelogReader implements ChangelogReader {
             }
         }
 
+        if (numRecords > 0 || changelogMetadata.state().equals(ChangelogState.COMPLETED)) {
+            task.clearTaskTimeout();
+        }
+
         return numRecords;
     }
 
@@ -857,7 +857,7 @@ public class StoreChangelogReader implements ChangelogReader {
             }
         }
 
-        // try initialize limit offsets for standby tasks for the first time
+        // try initializing limit offsets for standby tasks for the first time
         if (!committedOffsets.isEmpty()) {
             updateLimitOffsetsForStandbyChangelogs(committedOffsets);
         }
@@ -875,7 +875,7 @@ public class StoreChangelogReader implements ChangelogReader {
         }
 
         // prepare newly added partitions of the restore consumer by setting their starting position
-        prepareChangelogs(newPartitionsToRestore);
+        prepareChangelogs(tasks, newPartitionsToRestore);
     }
 
     private void addChangelogsToRestoreConsumer(final Set<TopicPartition> partitions) {
@@ -930,7 +930,8 @@ public class StoreChangelogReader implements ChangelogReader {
         log.debug("Resumed partitions {} from the restore consumer", partitions);
     }
 
-    private void prepareChangelogs(final Set<ChangelogMetadata> newPartitionsToRestore) {
+    private void prepareChangelogs(final Map<TaskId, Task> tasks,
+                                   final Set<ChangelogMetadata> newPartitionsToRestore) {
         // separate those who do not have the current offset loaded from checkpoint
         final Set<TopicPartition> newPartitionsWithoutStartOffset = new HashSet<>();
 
@@ -986,6 +987,14 @@ public class StoreChangelogReader implements ChangelogReader {
                 } catch (final Exception e) {
                     throw new StreamsException("State restore listener failed on batch restored", e);
                 }
+
+                final TaskId taskId = changelogs.get(partition).stateManager.taskId();
+                final StreamTask task = (StreamTask) tasks.get(taskId);
+                // if the log is truncated between when we get the log end offset and when we get the
+                // consumer position, then it's possible that the difference becomes negative and there's actually
+                // no records to restore; in this case we just initialize the sensor to zero
+                final long recordsToRestore = Math.max(changelogMetadata.restoreEndOffset - startOffset, 0L);
+                task.initRemainingRecordsToRestore(time, recordsToRestore);
             }
         }
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index f3b5818ddfb..a44c13d8e13 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -60,6 +60,7 @@ import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import static java.util.Collections.singleton;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeRecordSensor;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;
 
 /**
@@ -92,6 +93,8 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     private final Sensor closeTaskSensor;
     private final Sensor processRatioSensor;
     private final Sensor processLatencySensor;
+    private final Sensor restoreSensor;
+    private final Sensor restoreRemainingSensor;
     private final Sensor punctuateLatencySensor;
     private final Sensor bufferedRecordsSensor;
     private final Map<String, Sensor> e2eLatencySensors = new HashMap<>();
@@ -144,6 +147,8 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         this.streamsMetrics = streamsMetrics;
         closeTaskSensor = ThreadMetrics.closeTaskSensor(threadId, streamsMetrics);
         final String taskId = id.toString();
+        restoreSensor = TaskMetrics.restoreSensor(threadId, taskId, streamsMetrics);
+        restoreRemainingSensor = TaskMetrics.restoreRemainingRecordsSensor(threadId, taskId, streamsMetrics);
         processRatioSensor = TaskMetrics.activeProcessRatioSensor(threadId, taskId, streamsMetrics);
         processLatencySensor = TaskMetrics.processLatencySensor(threadId, taskId, streamsMetrics);
         punctuateLatencySensor = TaskMetrics.punctuateSensor(threadId, taskId, streamsMetrics);
@@ -213,6 +218,16 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         return true;
     }
 
+    @Override
+    public void maybeRecordRestored(final Time time, final long numRecords) {
+        maybeRecordSensor(numRecords, time, restoreSensor);
+        maybeRecordSensor(-1 * numRecords, time, restoreRemainingSensor);
+    }
+
+    public void initRemainingRecordsToRestore(final Time time, final long numRecords) {
+        maybeRecordSensor(numRecords, time, restoreRemainingSensor);
+    }
+
     /**
      * @throws TaskCorruptedException if the state cannot be reused (with EOS) and needs to be reset
      * @throws LockException    could happen when multi-threads within the single instance, could retry
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
index 58462760759..5b08de2bf45 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
@@ -19,6 +19,7 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.StreamsException;
@@ -202,6 +203,8 @@ public interface Task {
 
     void clearTaskTimeout();
 
+    void maybeRecordRestored(final Time time, final long numRecords);
+
     // task status inquiry
 
     TaskId id();
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ProcessorNodeMetrics.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ProcessorNodeMetrics.java
index 8dcd265a244..1a7d914b941 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ProcessorNodeMetrics.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ProcessorNodeMetrics.java
@@ -166,8 +166,8 @@ public class ProcessorNodeMetrics {
                                           final String taskId,
                                           final String processorNodeId,
                                           final StreamsMetricsImpl streamsMetrics) {
-        final String sensorName = processorNodeId + "-" + RECORD_E2E_LATENCY;
-        final Sensor sensor = streamsMetrics.nodeLevelSensor(threadId, taskId, processorNodeId, sensorName, RecordingLevel.INFO);
+        final String sensorSuffix = processorNodeId + "-" + RECORD_E2E_LATENCY;
+        final Sensor sensor = streamsMetrics.nodeLevelSensor(threadId, taskId, processorNodeId, sensorSuffix, RecordingLevel.INFO);
         final Map<String, String> tagMap = streamsMetrics.nodeLevelTagMap(threadId, taskId, processorNodeId);
         addAvgAndMinAndMaxToSensor(
             sensor,
@@ -185,8 +185,8 @@ public class ProcessorNodeMetrics {
                                                 final String taskId,
                                                 final String processorNodeId,
                                                 final StreamsMetricsImpl streamsMetrics) {
-        final String sensorName = processorNodeId + "-" + EMIT_FINAL_LATENCY;
-        final Sensor sensor = streamsMetrics.nodeLevelSensor(threadId, taskId, processorNodeId, sensorName, RecordingLevel.DEBUG);
+        final String sensorSuffix = processorNodeId + "-" + EMIT_FINAL_LATENCY;
+        final Sensor sensor = streamsMetrics.nodeLevelSensor(threadId, taskId, processorNodeId, sensorSuffix, RecordingLevel.DEBUG);
         final Map<String, String> tagMap = streamsMetrics.nodeLevelTagMap(threadId, taskId, processorNodeId);
         addAvgAndMaxToSensor(
             sensor,
@@ -203,8 +203,8 @@ public class ProcessorNodeMetrics {
                                               final String taskId,
                                               final String processorNodeId,
                                               final StreamsMetricsImpl streamsMetrics) {
-        final String sensorName = processorNodeId + "-" + EMITTED_RECORDS;
-        final Sensor sensor = streamsMetrics.nodeLevelSensor(threadId, taskId, processorNodeId, sensorName, RecordingLevel.DEBUG);
+        final String sensorSuffix = processorNodeId + "-" + EMITTED_RECORDS;
+        final Sensor sensor = streamsMetrics.nodeLevelSensor(threadId, taskId, processorNodeId, sensorSuffix, RecordingLevel.DEBUG);
         final Map<String, String> tagMap = streamsMetrics.nodeLevelTagMap(threadId, taskId, processorNodeId);
         addRateOfSumAndSumMetricsToSensor(
             sensor,
@@ -219,18 +219,19 @@ public class ProcessorNodeMetrics {
 
     private static Sensor throughputParentSensor(final String threadId,
                                                  final String taskId,
-                                                 final String metricNamePrefix,
+                                                 final String operation,
                                                  final String descriptionOfRate,
                                                  final String descriptionOfCount,
                                                  final RecordingLevel recordingLevel,
                                                  final StreamsMetricsImpl streamsMetrics) {
-        final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, metricNamePrefix, recordingLevel);
+        // use operation name as sensor suffix and metric name prefix
+        final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, operation, recordingLevel);
         final Map<String, String> parentTagMap = streamsMetrics.nodeLevelTagMap(threadId, taskId, ROLLUP_VALUE);
         addInvocationRateAndCountToSensor(
             sensor,
             PROCESSOR_NODE_LEVEL_GROUP,
             parentTagMap,
-            metricNamePrefix,
+            operation,
             descriptionOfRate,
             descriptionOfCount
         );
@@ -240,20 +241,21 @@ public class ProcessorNodeMetrics {
     private static Sensor throughputSensor(final String threadId,
                                            final String taskId,
                                            final String processorNodeId,
-                                           final String metricNamePrefix,
+                                           final String operationName,
                                            final String descriptionOfRate,
                                            final String descriptionOfCount,
                                            final RecordingLevel recordingLevel,
                                            final StreamsMetricsImpl streamsMetrics,
                                            final Sensor... parentSensors) {
+        // use operation name as sensor suffix and metric name prefix
         final Sensor sensor =
-            streamsMetrics.nodeLevelSensor(threadId, taskId, processorNodeId, metricNamePrefix, recordingLevel, parentSensors);
+            streamsMetrics.nodeLevelSensor(threadId, taskId, processorNodeId, operationName, recordingLevel, parentSensors);
         final Map<String, String> tagMap = streamsMetrics.nodeLevelTagMap(threadId, taskId, processorNodeId);
         addInvocationRateAndCountToSensor(
             sensor,
             PROCESSOR_NODE_LEVEL_GROUP,
             tagMap,
-            metricNamePrefix,
+            operationName,
             descriptionOfRate,
             descriptionOfCount
         );
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
index f3cd0982eaa..e7a5c3202a0 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
@@ -255,12 +255,12 @@ public class StreamsMetricsImpl implements StreamsMetrics {
     }
 
     public final Sensor threadLevelSensor(final String threadId,
-                                          final String sensorName,
+                                          final String sensorSuffix,
                                           final RecordingLevel recordingLevel,
                                           final Sensor... parents) {
-        final String key = threadSensorPrefix(threadId);
+        final String sensorPrefix = threadSensorPrefix(threadId);
         synchronized (threadLevelSensors) {
-            return getSensors(threadLevelSensors, sensorName, key, recordingLevel, parents);
+            return getSensors(threadLevelSensors, sensorSuffix, sensorPrefix, recordingLevel, parents);
         }
     }
 
@@ -353,12 +353,12 @@ public class StreamsMetricsImpl implements StreamsMetrics {
 
     public final Sensor taskLevelSensor(final String threadId,
                                         final String taskId,
-                                        final String sensorName,
+                                        final String sensorSuffix,
                                         final RecordingLevel recordingLevel,
                                         final Sensor... parents) {
-        final String key = taskSensorPrefix(threadId, taskId);
+        final String sensorPrefix = taskSensorPrefix(threadId, taskId);
         synchronized (taskLevelSensors) {
-            return getSensors(taskLevelSensors, sensorName, key, recordingLevel, parents);
+            return getSensors(taskLevelSensors, sensorSuffix, sensorPrefix, recordingLevel, parents);
         }
     }
 
@@ -380,12 +380,12 @@ public class StreamsMetricsImpl implements StreamsMetrics {
     public Sensor nodeLevelSensor(final String threadId,
                                   final String taskId,
                                   final String processorNodeName,
-                                  final String sensorName,
+                                  final String sensorSuffix,
                                   final Sensor.RecordingLevel recordingLevel,
                                   final Sensor... parents) {
-        final String key = nodeSensorPrefix(threadId, taskId, processorNodeName);
+        final String sensorPrefix = nodeSensorPrefix(threadId, taskId, processorNodeName);
         synchronized (nodeLevelSensors) {
-            return getSensors(nodeLevelSensors, sensorName, key, recordingLevel, parents);
+            return getSensors(nodeLevelSensors, sensorSuffix, sensorPrefix, recordingLevel, parents);
         }
     }
 
@@ -410,12 +410,12 @@ public class StreamsMetricsImpl implements StreamsMetrics {
                                    final String taskId,
                                    final String processorNodeName,
                                    final String topicName,
-                                   final String sensorName,
+                                   final String sensorSuffix,
                                    final Sensor.RecordingLevel recordingLevel,
                                    final Sensor... parents) {
-        final String key = topicSensorPrefix(threadId, taskId, processorNodeName, topicName);
+        final String sensorPrefix = topicSensorPrefix(threadId, taskId, processorNodeName, topicName);
         synchronized (topicLevelSensors) {
-            return getSensors(topicLevelSensors, sensorName, key, recordingLevel, parents);
+            return getSensors(topicLevelSensors, sensorSuffix, sensorPrefix, recordingLevel, parents);
         }
     }
 
@@ -443,12 +443,13 @@ public class StreamsMetricsImpl implements StreamsMetrics {
     public Sensor cacheLevelSensor(final String threadId,
                                    final String taskName,
                                    final String storeName,
-                                   final String sensorName,
+                                   final String ratioName,
                                    final Sensor.RecordingLevel recordingLevel,
                                    final Sensor... parents) {
-        final String key = cacheSensorPrefix(threadId, taskName, storeName);
+        // use ratio name as sensor suffix
+        final String sensorPrefix = cacheSensorPrefix(threadId, taskName, storeName);
         synchronized (cacheLevelSensors) {
-            return getSensors(cacheLevelSensors, sensorName, key, recordingLevel, parents);
+            return getSensors(cacheLevelSensors, ratioName, sensorPrefix, recordingLevel, parents);
         }
     }
 
@@ -479,10 +480,10 @@ public class StreamsMetricsImpl implements StreamsMetrics {
 
     public final Sensor storeLevelSensor(final String taskId,
                                          final String storeName,
-                                         final String sensorName,
+                                         final String sensorSuffix,
                                          final RecordingLevel recordingLevel,
                                          final Sensor... parents) {
-        final String key = storeSensorPrefix(Thread.currentThread().getName(), taskId, storeName);
+        final String sensorPrefix = storeSensorPrefix(Thread.currentThread().getName(), taskId, storeName);
             // since the keys in the map storeLevelSensors contain the name of the current thread and threads only
             // access keys in which their name is contained, the value in the maps do not need to be thread safe
             // and we can use a LinkedList here.
@@ -490,7 +491,7 @@ public class StreamsMetricsImpl implements StreamsMetrics {
             //  that contain its name. Similar is true for the other metric levels. Thread-level metrics need some
             //  special attention, since they are created before the thread is constructed. The creation of those
             //  metrics could be moved into the run() method of the thread.
-        return getSensors(storeLevelSensors, sensorName, key, recordingLevel, parents);
+        return getSensors(storeLevelSensors, sensorSuffix, sensorPrefix, recordingLevel, parents);
     }
 
     public <T> void addStoreLevelMutableMetric(final String taskId,
@@ -647,12 +648,12 @@ public class StreamsMetricsImpl implements StreamsMetrics {
     public static void addAvgAndMaxToSensor(final Sensor sensor,
                                             final String group,
                                             final Map<String, String> tags,
-                                            final String operation,
+                                            final String gaugeName,
                                             final String descriptionOfAvg,
                                             final String descriptionOfMax) {
         sensor.add(
             new MetricName(
-                operation + AVG_SUFFIX,
+                gaugeName + AVG_SUFFIX,
                 group,
                 descriptionOfAvg,
                 tags),
@@ -660,7 +661,7 @@ public class StreamsMetricsImpl implements StreamsMetrics {
         );
         sensor.add(
             new MetricName(
-                operation + MAX_SUFFIX,
+                gaugeName + MAX_SUFFIX,
                 group,
                 descriptionOfMax,
                 tags),
@@ -718,14 +719,14 @@ public class StreamsMetricsImpl implements StreamsMetrics {
     public static void addAvgAndMinAndMaxToSensor(final Sensor sensor,
                                                   final String group,
                                                   final Map<String, String> tags,
-                                                  final String operation,
+                                                  final String gaugeName,
                                                   final String descriptionOfAvg,
                                                   final String descriptionOfMin,
                                                   final String descriptionOfMax) {
-        addAvgAndMaxToSensor(sensor, group, tags, operation, descriptionOfAvg, descriptionOfMax);
+        addAvgAndMaxToSensor(sensor, group, tags, gaugeName, descriptionOfAvg, descriptionOfMax);
         sensor.add(
             new MetricName(
-                operation + MIN_SUFFIX,
+                gaugeName + MIN_SUFFIX,
                 group,
                 descriptionOfMin,
                 tags),
@@ -767,20 +768,6 @@ public class StreamsMetricsImpl implements StreamsMetrics {
         );
     }
 
-    public static void addInvocationRateAndCountToSensor(final Sensor sensor,
-                                                         final String group,
-                                                         final Map<String, String> tags,
-                                                         final String operation) {
-        addInvocationRateAndCountToSensor(
-            sensor,
-            group,
-            tags,
-            operation,
-            RATE_DESCRIPTION + operation,
-            TOTAL_DESCRIPTION + operation
-        );
-    }
-
     public static void addRateOfSumAndSumMetricsToSensor(final Sensor sensor,
                                                          final String group,
                                                          final Map<String, String> tags,
@@ -863,6 +850,14 @@ public class StreamsMetricsImpl implements StreamsMetrics {
         );
     }
 
+    /**
+     * Records {@code value} on {@code sensor}, timestamped with {@code time.milliseconds()}.
+     * Nothing is recorded unless both {@code sensor.shouldRecord()} and {@code sensor.hasMetrics()}
+     * return {@code true}.
+     *
+     * @param value  the value to record
+     * @param time   clock supplying the recording timestamp
+     * @param sensor the sensor to record on
+     */
+    public static void maybeRecordSensor(final double value,
+                                         final Time time,
+                                         final Sensor sensor) {
+        if (sensor.shouldRecord() && sensor.hasMetrics()) {
+            sensor.record(value, time.milliseconds());
+        }
+    }
+
     public static void maybeMeasureLatency(final Runnable actionToMeasure,
                                            final Time time,
                                            final Sensor sensor) {
@@ -894,14 +889,14 @@ public class StreamsMetricsImpl implements StreamsMetrics {
     }
 
     private Sensor getSensors(final Map<String, Deque<String>> sensors,
-                              final String sensorName,
-                              final String key,
+                              final String sensorSuffix,
+                              final String sensorPrefix,
                               final RecordingLevel recordingLevel,
                               final Sensor... parents) {
-        final String fullSensorName = key + SENSOR_NAME_DELIMITER + sensorName;
+        final String fullSensorName = sensorPrefix + SENSOR_NAME_DELIMITER + sensorSuffix;
         final Sensor sensor = metrics.getSensor(fullSensorName);
         if (sensor == null) {
-            sensors.computeIfAbsent(key, ignored -> new LinkedList<>()).push(fullSensorName);
+            sensors.computeIfAbsent(sensorPrefix, ignored -> new LinkedList<>()).push(fullSensorName);
             return metrics.sensor(fullSensorName, recordingLevel, parents);
         }
         return sensor;
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java
index 3566ac1ada5..0afdbb4bcef 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java
@@ -25,8 +25,11 @@ import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetric
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RATIO_SUFFIX;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TASK_LEVEL_GROUP;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOTAL_DESCRIPTION;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOTAL_SUFFIX;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndMaxToSensor;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCountToSensor;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateToSensor;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addSumMetricToSensor;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addValueMetricToSensor;
 
 public class TaskMetrics {
@@ -52,6 +55,21 @@ public class TaskMetrics {
     private static final String PUNCTUATE_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION + PUNCTUATE_DESCRIPTION;
     private static final String PUNCTUATE_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION + PUNCTUATE_DESCRIPTION;
 
+    // "restore" operation: name and descriptions of the records-restored rate/total metrics (see restoreSensor)
+    private static final String RESTORE = "restore";
+    private static final String RESTORE_DESCRIPTION = "records restored";
+    private static final String RESTORE_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + RESTORE_DESCRIPTION;
+    private static final String RESTORE_RATE_DESCRIPTION =
+        RATE_DESCRIPTION_PREFIX + RESTORE_DESCRIPTION + RATE_DESCRIPTION_SUFFIX;
+
+    // "update" operation: name and descriptions of the records-updated rate/total metrics (see updateSensor)
+    private static final String UPDATE = "update";
+    private static final String UPDATE_DESCRIPTION = "records updated";
+    private static final String UPDATE_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + UPDATE_DESCRIPTION;
+    private static final String UPDATE_RATE_DESCRIPTION =
+        RATE_DESCRIPTION_PREFIX + UPDATE_DESCRIPTION + RATE_DESCRIPTION_SUFFIX;
+
+    // Infix appended directly to RESTORE (hence the leading '-') to name the remaining-records metric
+    private static final String REMAINING_RECORDS = "-remaining-records";
+    private static final String REMAINING_RECORDS_DESCRIPTION = TOTAL_DESCRIPTION + "records remaining to be restored";
+
     private static final String ENFORCED_PROCESSING = "enforced-processing";
     private static final String ENFORCED_PROCESSING_TOTAL_DESCRIPTION =
         "The total number of occurrences of enforced-processing operations";
@@ -117,6 +135,15 @@ public class TaskMetrics {
         return sensor;
     }
 
+    /**
+     * Returns a task-level sensor, at INFO recording level, tracking the number of records still
+     * remaining to be restored for the task. The name {@code "restore-remaining-records" + TOTAL_SUFFIX}
+     * is used both as the sensor suffix and as the name of the single sum metric on the sensor.
+     */
+    public static Sensor restoreRemainingRecordsSensor(final String threadId,
+                                                       final String taskId,
+                                                       final StreamsMetricsImpl streamsMetrics) {
+        final String name = RESTORE + REMAINING_RECORDS + TOTAL_SUFFIX;
+        final Map<String, String> tags = streamsMetrics.taskLevelTagMap(threadId, taskId);
+        final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, name, Sensor.RecordingLevel.INFO);
+        // the name already ends with TOTAL_SUFFIX; the flag presumably stops addSumMetricToSensor from
+        // appending another suffix (invocationRateAndTotalSensor passes true) -- confirm in its implementation
+        addSumMetricToSensor(sensor, TASK_LEVEL_GROUP, tags, name, false, REMAINING_RECORDS_DESCRIPTION);
+        return sensor;
+    }
 
     public static Sensor activeProcessRatioSensor(final String threadId,
                                                   final String taskId,
@@ -180,6 +207,38 @@ public class TaskMetrics {
         );
     }
 
+    /**
+     * Returns the task-level sensor, at DEBUG recording level, for records restored. It carries a
+     * rate metric and a total (sum) metric for the {@code "restore"} operation.
+     *
+     * @param parentSensor parent sensors handed through to the task-level sensor
+     */
+    public static Sensor restoreSensor(final String threadId,
+                                       final String taskId,
+                                       final StreamsMetricsImpl streamsMetrics,
+                                       final Sensor... parentSensor) {
+        return invocationRateAndTotalSensor(
+            threadId,
+            taskId,
+            RESTORE,
+            RESTORE_RATE_DESCRIPTION,
+            RESTORE_TOTAL_DESCRIPTION,
+            Sensor.RecordingLevel.DEBUG,
+            streamsMetrics,
+            parentSensor
+        );
+    }
+
+    /**
+     * Returns the task-level sensor, at DEBUG recording level, for records updated. It carries a
+     * rate metric and a total (sum) metric for the {@code "update"} operation.
+     *
+     * @param parentSensor parent sensors handed through to the task-level sensor
+     */
+    public static Sensor updateSensor(final String threadId,
+                                      final String taskId,
+                                      final StreamsMetricsImpl streamsMetrics,
+                                      final Sensor... parentSensor) {
+        return invocationRateAndTotalSensor(
+            threadId,
+            taskId,
+            UPDATE,
+            UPDATE_RATE_DESCRIPTION,
+            UPDATE_TOTAL_DESCRIPTION,
+            Sensor.RecordingLevel.DEBUG,
+            streamsMetrics,
+            parentSensor
+        );
+    }
+
     public static Sensor enforcedProcessingSensor(final String threadId,
                                                   final String taskId,
                                                   final StreamsMetricsImpl streamsMetrics,
@@ -213,7 +272,7 @@ public class TaskMetrics {
     public static Sensor droppedRecordsSensor(final String threadId,
                                               final String taskId,
                                               final StreamsMetricsImpl streamsMetrics) {
-        return invocationRateAndCountSensor(
+        return invocationRateAndTotalSensor(
             threadId,
             taskId,
             DROPPED_RECORDS,
@@ -226,39 +285,56 @@ public class TaskMetrics {
 
+    // Obtains a task-level sensor named after the operation and attaches invocation rate and count metrics to it.
     private static Sensor invocationRateAndCountSensor(final String threadId,
                                                        final String taskId,
-                                                       final String metricName,
+                                                       final String operation,
                                                        final String descriptionOfRate,
                                                        final String descriptionOfCount,
                                                        final RecordingLevel recordingLevel,
                                                        final StreamsMetricsImpl streamsMetrics,
                                                        final Sensor... parentSensors) {
-        final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, metricName, recordingLevel, parentSensors);
+        final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, operation, recordingLevel, parentSensors);
         addInvocationRateAndCountToSensor(
             sensor,
             TASK_LEVEL_GROUP,
             streamsMetrics.taskLevelTagMap(threadId, taskId),
-            metricName,
+            operation,
             descriptionOfRate,
             descriptionOfCount
         );
         return sensor;
     }
 
+    /**
+     * Obtains a task-level sensor named after {@code operation} and attaches an invocation rate metric
+     * and a sum metric (the "total") to it, both in {@link StreamsMetricsImpl#TASK_LEVEL_GROUP}.
+     * The total is a sum of the recorded values, so recording {@code n} (e.g. a number of records)
+     * adds {@code n} to it.
+     */
+    private static Sensor invocationRateAndTotalSensor(final String threadId,
+                                                       final String taskId,
+                                                       final String operation,
+                                                       final String descriptionOfRate,
+                                                       final String descriptionOfTotal,
+                                                       final RecordingLevel recordingLevel,
+                                                       final StreamsMetricsImpl streamsMetrics,
+                                                       final Sensor... parentSensors) {
+        final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, operation, recordingLevel, parentSensors);
+        final Map<String, String> tags = streamsMetrics.taskLevelTagMap(threadId, taskId);
+
+        addInvocationRateToSensor(sensor, TASK_LEVEL_GROUP, tags, operation, descriptionOfRate);
+        addSumMetricToSensor(sensor, TASK_LEVEL_GROUP, tags, operation, true, descriptionOfTotal);
+        return sensor;
+    }
+
     private static Sensor avgAndMaxSensor(final String threadId,
                                           final String taskId,
-                                          final String metricName,
+                                          final String gaugeName,
                                           final String descriptionOfAvg,
                                           final String descriptionOfMax,
                                           final RecordingLevel recordingLevel,
                                           final StreamsMetricsImpl streamsMetrics,
                                           final Sensor... parentSensors) {
-        final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, metricName, recordingLevel, parentSensors);
+        // use gauge name as sensor suffix and metric name prefix
+        final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, gaugeName, recordingLevel, parentSensors);
         final Map<String, String> tagMap = streamsMetrics.taskLevelTagMap(threadId, taskId);
         addAvgAndMaxToSensor(
             sensor,
             TASK_LEVEL_GROUP,
             tagMap,
-            metricName,
+            gaugeName,
             descriptionOfAvg,
             descriptionOfMax
         );
@@ -267,7 +343,7 @@ public class TaskMetrics {
 
     private static Sensor invocationRateAndCountAndAvgAndMaxLatencySensor(final String threadId,
                                                                           final String taskId,
-                                                                          final String metricName,
+                                                                          final String operation,
                                                                           final String descriptionOfRate,
                                                                           final String descriptionOfCount,
                                                                           final String descriptionOfAvg,
@@ -275,13 +351,14 @@ public class TaskMetrics {
                                                                           final RecordingLevel recordingLevel,
                                                                           final StreamsMetricsImpl streamsMetrics,
                                                                           final Sensor... parentSensors) {
-        final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, metricName, recordingLevel, parentSensors);
+        // use operation name as sensor suffix and metric name prefix
+        final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, operation, recordingLevel, parentSensors);
         final Map<String, String> tagMap = streamsMetrics.taskLevelTagMap(threadId, taskId);
         addAvgAndMaxToSensor(
             sensor,
             TASK_LEVEL_GROUP,
             tagMap,
-            metricName + LATENCY_SUFFIX,
+            operation + LATENCY_SUFFIX,
             descriptionOfAvg,
             descriptionOfMax
         );
@@ -289,7 +366,7 @@ public class TaskMetrics {
             sensor,
             TASK_LEVEL_GROUP,
             tagMap,
-            metricName,
+            operation,
             descriptionOfRate,
             descriptionOfCount
         );
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ThreadMetrics.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ThreadMetrics.java
index eda173e532f..988d7302c90 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ThreadMetrics.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ThreadMetrics.java
@@ -324,17 +324,18 @@ public class ThreadMetrics {
     }
 
     private static Sensor invocationRateAndCountSensor(final String threadId,
-                                                       final String metricName,
+                                                       final String operation,
                                                        final String descriptionOfRate,
                                                        final String descriptionOfCount,
                                                        final RecordingLevel recordingLevel,
                                                        final StreamsMetricsImpl streamsMetrics) {
-        final Sensor sensor = streamsMetrics.threadLevelSensor(threadId, metricName, recordingLevel);
+        // use operation name as the sensor suffix and as the metric name prefix
+        final Sensor sensor = streamsMetrics.threadLevelSensor(threadId, operation, recordingLevel);
         addInvocationRateAndCountToSensor(
             sensor,
             THREAD_LEVEL_GROUP,
             streamsMetrics.threadLevelTagMap(threadId),
-            metricName,
+            operation,
             descriptionOfRate,
             descriptionOfCount
         );
@@ -342,20 +343,21 @@ public class ThreadMetrics {
     }
 
     private static Sensor invocationRateAndCountAndAvgAndMaxLatencySensor(final String threadId,
-                                                                          final String metricName,
+                                                                          final String operation,
                                                                           final String descriptionOfRate,
                                                                           final String descriptionOfCount,
                                                                           final String descriptionOfAvg,
                                                                           final String descriptionOfMax,
                                                                           final RecordingLevel recordingLevel,
                                                                           final StreamsMetricsImpl streamsMetrics) {
-        final Sensor sensor = streamsMetrics.threadLevelSensor(threadId, metricName, recordingLevel);
+        // use operation name as the sensor suffix and as the metric name prefix
+        final Sensor sensor = streamsMetrics.threadLevelSensor(threadId, operation, recordingLevel);
         final Map<String, String> tagMap = streamsMetrics.threadLevelTagMap(threadId);
         addAvgAndMaxToSensor(
             sensor,
             THREAD_LEVEL_GROUP,
             tagMap,
-            metricName + LATENCY_SUFFIX,
+            operation + LATENCY_SUFFIX,
             descriptionOfAvg,
             descriptionOfMax
         );
@@ -363,7 +365,7 @@ public class ThreadMetrics {
             sensor,
             THREAD_LEVEL_GROUP,
             tagMap,
-            metricName,
+            operation,
             descriptionOfRate,
             descriptionOfCount
         );
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetrics.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetrics.java
index a30a891fc06..a92d7f963bc 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetrics.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetrics.java
@@ -793,11 +793,11 @@ public class RocksDBMetrics {
 
+    // Obtains a DEBUG-level store-level sensor for the task and store identified by metricContext.
     private static Sensor createSensor(final StreamsMetricsImpl streamsMetrics,
                                        final RocksDBMetricContext metricContext,
-                                       final String sensorName) {
+                                       final String sensorSuffix) {
         return streamsMetrics.storeLevelSensor(
             metricContext.taskName(),
             metricContext.storeName(),
-            sensorName,
+            sensorSuffix,
             RecordingLevel.DEBUG);
     }
 }
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetrics.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetrics.java
index 360cd8d0e10..cee44ba906d 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetrics.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetrics.java
@@ -28,10 +28,8 @@ import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetric
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RECORD_E2E_LATENCY_MAX_DESCRIPTION;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RECORD_E2E_LATENCY_MIN_DESCRIPTION;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.STATE_STORE_LEVEL_GROUP;
-import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOTAL_DESCRIPTION;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndMaxToSensor;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndMinAndMaxToSensor;
-import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCountToSensor;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateToSensor;
 
 public class StateStoreMetrics {
@@ -146,13 +144,6 @@ public class StateStoreMetrics {
     private static final String SUPPRESSION_BUFFER_SIZE_MAX_DESCRIPTION =
         MAX_DESCRIPTION_PREFIX + SUPPRESSION_BUFFER_SIZE_DESCRIPTION;
 
-    private static final String EXPIRED_WINDOW_RECORD_DROP = "expired-window-record-drop";
-    private static final String EXPIRED_WINDOW_RECORD_DROP_DESCRIPTION = "dropped records due to an expired window";
-    private static final String EXPIRED_WINDOW_RECORD_DROP_TOTAL_DESCRIPTION =
-        TOTAL_DESCRIPTION + EXPIRED_WINDOW_RECORD_DROP_DESCRIPTION;
-    private static final String EXPIRED_WINDOW_RECORD_DROP_RATE_DESCRIPTION =
-        RATE_DESCRIPTION_PREFIX + EXPIRED_WINDOW_RECORD_DROP_DESCRIPTION + RATE_DESCRIPTION_SUFFIX;
-
     public static Sensor putSensor(final String taskId,
                                    final String storeType,
                                    final String storeName,
@@ -276,10 +267,7 @@ public class StateStoreMetrics {
                                           final String storeType,
                                           final String storeName,
                                           final StreamsMetricsImpl streamsMetrics) {
-
-        final String latencyMetricName = PREFIX_SCAN + LATENCY_SUFFIX;
         final Map<String, String> tagMap = streamsMetrics.storeLevelTagMap(taskId, storeType, storeName);
-
         final Sensor sensor = streamsMetrics.storeLevelSensor(taskId, storeName, PREFIX_SCAN, RecordingLevel.DEBUG);
         addInvocationRateToSensor(
             sensor,
@@ -292,7 +280,7 @@ public class StateStoreMetrics {
             sensor,
             STATE_STORE_LEVEL_GROUP,
             tagMap,
-            latencyMetricName,
+            PREFIX_SCAN + LATENCY_SUFFIX,
             PREFIX_SCAN_AVG_LATENCY_DESCRIPTION,
             PREFIX_SCAN_MAX_LATENCY_DESCRIPTION
         );
@@ -366,27 +354,6 @@ public class StateStoreMetrics {
         );
     }
 
-    public static Sensor expiredWindowRecordDropSensor(final String taskId,
-                                                       final String storeType,
-                                                       final String storeName,
-                                                       final StreamsMetricsImpl streamsMetrics) {
-        final Sensor sensor = streamsMetrics.storeLevelSensor(
-            taskId,
-            storeName,
-            EXPIRED_WINDOW_RECORD_DROP,
-            RecordingLevel.INFO
-        );
-        addInvocationRateAndCountToSensor(
-            sensor,
-            "stream-" + storeType + "-metrics",
-            streamsMetrics.storeLevelTagMap(taskId, storeType, storeName),
-            EXPIRED_WINDOW_RECORD_DROP,
-            EXPIRED_WINDOW_RECORD_DROP_RATE_DESCRIPTION,
-            EXPIRED_WINDOW_RECORD_DROP_TOTAL_DESCRIPTION
-        );
-        return sensor;
-    }
-
     public static Sensor suppressionBufferCountSensor(final String taskId,
                                                       final String storeType,
                                                       final String storeName,
@@ -440,34 +407,36 @@ public class StateStoreMetrics {
+    // Obtains a store-level sensor for a size or count gauge and attaches avg and max metrics to it.
     private static Sensor sizeOrCountSensor(final String taskId,
                                             final String storeType,
                                             final String storeName,
-                                            final String metricName,
+                                            final String gaugeName,
                                             final String descriptionOfAvg,
                                             final String descriptionOfMax,
                                             final RecordingLevel recordingLevel,
                                             final StreamsMetricsImpl streamsMetrics) {
-        final Sensor sensor = streamsMetrics.storeLevelSensor(taskId, storeName, metricName, recordingLevel);
+        // use the gauge name (either size or count) as the sensor suffix, and metric name prefix
+        final Sensor sensor = streamsMetrics.storeLevelSensor(taskId, storeName, gaugeName, recordingLevel);
         final String group;
         final Map<String, String> tagMap;
         group = STATE_STORE_LEVEL_GROUP;
         tagMap = streamsMetrics.storeLevelTagMap(taskId, storeType, storeName);
-        addAvgAndMaxToSensor(sensor, group, tagMap, metricName, descriptionOfAvg, descriptionOfMax);
+        addAvgAndMaxToSensor(sensor, group, tagMap, gaugeName, descriptionOfAvg, descriptionOfMax);
         return sensor;
     }
 
     private static Sensor throughputAndLatencySensor(final String taskId,
                                                      final String storeType,
                                                      final String storeName,
-                                                     final String metricName,
+                                                     final String operation,
                                                      final String descriptionOfRate,
                                                      final String descriptionOfAvg,
                                                      final String descriptionOfMax,
                                                      final RecordingLevel recordingLevel,
                                                      final StreamsMetricsImpl streamsMetrics) {
+        // use operation as the sensor suffix and metric name prefix
         final Sensor sensor;
-        final String latencyMetricName = metricName + LATENCY_SUFFIX;
+        final String latencyMetricName = operation + LATENCY_SUFFIX;
         final Map<String, String> tagMap = streamsMetrics.storeLevelTagMap(taskId, storeType, storeName);
-        sensor = streamsMetrics.storeLevelSensor(taskId, storeName, metricName, recordingLevel);
-        addInvocationRateToSensor(sensor, STATE_STORE_LEVEL_GROUP, tagMap, metricName, descriptionOfRate);
+        sensor = streamsMetrics.storeLevelSensor(taskId, storeName, operation, recordingLevel);
+        addInvocationRateToSensor(sensor, STATE_STORE_LEVEL_GROUP, tagMap, operation, descriptionOfRate);
         addAvgAndMaxToSensor(
             sensor,
             STATE_STORE_LEVEL_GROUP,
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
index f105a68bcd9..afd263775d6 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -1564,6 +1564,7 @@ class DefaultStateUpdaterTest {
         when(changelogReader.restore(tasks1234)).thenReturn(1L);
         when(changelogReader.restore(tasks13)).thenReturn(1L);
         when(changelogReader.isRestoringActive()).thenReturn(true);
+
         stateUpdater.start();
         stateUpdater.add(activeTask1);
         stateUpdater.add(activeTask2);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index ba484d210ca..578f47af9a7 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -22,6 +22,7 @@ import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.metrics.KafkaMetric;
+import org.apache.kafka.common.metrics.MetricConfig;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.metrics.stats.CumulativeSum;
@@ -60,17 +61,20 @@ import java.util.List;
 import java.util.stream.Collectors;
 
 import static java.util.Arrays.asList;
+import static org.apache.kafka.common.metrics.Sensor.RecordingLevel.DEBUG;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkProperties;
 import static org.apache.kafka.streams.processor.internals.Task.State.CREATED;
 import static org.apache.kafka.streams.processor.internals.Task.State.RUNNING;
 import static org.apache.kafka.streams.processor.internals.Task.State.SUSPENDED;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.THREAD_ID_TAG;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.isA;
+import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThrows;
@@ -97,8 +101,10 @@ public class StandbyTaskTest {
         asList(store1, store2),
         mkMap(mkEntry(storeName1, storeChangelogTopicName1), mkEntry(storeName2, storeChangelogTopicName2))
     );
-    private final StreamsMetricsImpl streamsMetrics =
-        new StreamsMetricsImpl(new Metrics(), threadName, StreamsConfig.METRICS_LATEST, new MockTime());
+
+    // One MockTime drives both Metrics and StreamsMetricsImpl; DEBUG recording level is used because
+    // the task-level "update" sensor exercised below is registered at DEBUG level.
+    private final MockTime time = new MockTime();
+    private final Metrics metrics = new Metrics(new MetricConfig().recordLevel(Sensor.RecordingLevel.DEBUG), time);
+    private final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, threadName, StreamsConfig.METRICS_LATEST, time);
 
     private File baseDir;
     private StreamsConfig config;
@@ -108,6 +114,7 @@ public class StandbyTaskTest {
     private StreamsConfig createConfig(final File baseDir) throws IOException {
         return new StreamsConfig(mkProperties(mkMap(
             mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, applicationId),
+            mkEntry(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, DEBUG.name),
             mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2171"),
             mkEntry(StreamsConfig.BUFFERED_RECORDS_PER_PARTITION_CONFIG, "3"),
             mkEntry(StreamsConfig.STATE_DIR_CONFIG, baseDir.getCanonicalPath()),
@@ -638,6 +645,44 @@ public class StandbyTaskTest {
         task.maybeInitTaskTimeoutOrThrow(Duration.ofMinutes(5).plus(Duration.ofMillis(1L)).toMillis(), null);
     }
 
+    /**
+     * Verifies that a standby task reports restored records on its "update" metrics: both start at 0,
+     * the total accumulates the recorded counts (25, then 75), and the rate becomes non-zero.
+     */
+    @Test
+    public void shouldRecordRestoredRecords() {
+        EasyMock.replay(stateManager);
+
+        task = createStandbyTask();
+
+        final KafkaMetric totalMetric = getMetric("update", "%s-total", task.id().toString());
+        final KafkaMetric rateMetric = getMetric("update", "%s-rate", task.id().toString());
+
+        assertThat(totalMetric.metricValue(), equalTo(0.0));
+        assertThat(rateMetric.metricValue(), equalTo(0.0));
+
+        task.maybeRecordRestored(time, 25L);
+
+        assertThat(totalMetric.metricValue(), equalTo(25.0));
+        assertThat(rateMetric.metricValue(), not(0.0));
+
+        task.maybeRecordRestored(time, 50L);
+
+        assertThat(totalMetric.metricValue(), equalTo(75.0));
+        assertThat(rateMetric.metricValue(), not(0.0));
+    }
+
+    /**
+     * Looks up a metric in the "stream-task-metrics" group for the given task, tagged with the
+     * current thread's name as the thread id.
+     *
+     * @param operation  operation name substituted into {@code nameFormat}
+     * @param nameFormat format string for the metric name, e.g. {@code "%s-total"}
+     * @param taskId     value of the task-id tag
+     * @return the registered metric, or {@code null} if the metrics map has no such entry
+     */
+    private KafkaMetric getMetric(final String operation,
+                                  final String nameFormat,
+                                  final String taskId) {
+        // presumably MetricName equality ignores the description, so an empty one still matches the key
+        final String descriptionIsNotVerified = "";
+        return metrics.metrics().get(metrics.metricName(
+            String.format(nameFormat, operation),
+            "stream-task-metrics",
+            descriptionIsNotVerified,
+            mkMap(
+                mkEntry("task-id", taskId),
+                mkEntry(THREAD_ID_TAG, Thread.currentThread().getName())
+            )
+        ));
+    }
+
     private StandbyTask createStandbyTask() {
 
         final ThreadCache cache = new ThreadCache(
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index 61b8791af74..f6d27f57281 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -97,6 +97,7 @@ import static java.util.Collections.emptyMap;
 import static java.util.Collections.singleton;
 import static java.util.Collections.singletonList;
 import static java.util.Collections.singletonMap;
+import static org.apache.kafka.common.metrics.Sensor.RecordingLevel.DEBUG;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkProperties;
@@ -242,6 +243,7 @@ public class StreamTaskTest {
             mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2171"),
             mkEntry(StreamsConfig.BUFFERED_RECORDS_PER_PARTITION_CONFIG, "3"),
             mkEntry(StreamsConfig.STATE_DIR_CONFIG, canonicalPath),
+            mkEntry(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, DEBUG.name),
             mkEntry(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, MockTimestampExtractor.class.getName()),
             mkEntry(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, eosConfig),
             mkEntry(StreamsConfig.MAX_TASK_IDLE_MS_CONFIG, enforcedProcessingValue),
@@ -767,6 +769,35 @@ public class StreamTaskTest {
         assertThat(terminalMax.metricValue(), equalTo(23.0));
     }
 
+    @Test
+    public void shouldRecordRestoredRecords() {
+        task = createSingleSourceStateless(createConfig(AT_LEAST_ONCE, "0"), StreamsConfig.METRICS_LATEST);
+        final String taskId = task.id().toString();
+
+        // An active task exposes restoration through the "restore" metrics, including the records still left to restore.
+        final KafkaMetric restoreTotal = getMetric("restore", "%s-total", taskId);
+        final KafkaMetric restoreRate = getMetric("restore", "%s-rate", taskId);
+        final KafkaMetric remainingRecords = getMetric("restore", "%s-remaining-records-total", taskId);
+
+        // All three metrics start at zero.
+        assertThat(restoreTotal.metricValue(), equalTo(0.0));
+        assertThat(restoreRate.metricValue(), equalTo(0.0));
+        assertThat(remainingRecords.metricValue(), equalTo(0.0));
+
+        // Seed the number of records that still need to be restored.
+        task.initRemainingRecordsToRestore(time, 100L);
+        assertThat(remainingRecords.metricValue(), equalTo(100.0));
+
+        // Each restored batch adds to the total and is subtracted from the remaining count: 100 -> 75 -> 25.
+        task.maybeRecordRestored(time, 25L);
+        assertThat(restoreTotal.metricValue(), equalTo(25.0));
+        assertThat(restoreRate.metricValue(), not(0.0));
+        assertThat(remainingRecords.metricValue(), equalTo(75.0));
+
+        task.maybeRecordRestored(time, 50L);
+        assertThat(restoreTotal.metricValue(), equalTo(75.0));
+        assertThat(restoreRate.metricValue(), not(0.0));
+        assertThat(remainingRecords.metricValue(), equalTo(25.0));
+    }
+
     @Test
     public void shouldThrowOnTimeoutExceptionAndBufferRecordForRetryIfEosDisabled() {
         createTimeoutTask(AT_LEAST_ONCE);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index f43de372388..3e714a3f6c9 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -4865,6 +4865,11 @@ public class TaskManagerTest {
             timeout = null;
         }
 
+        /**
+         * Ignores restoration progress; this test double records no metrics.
+         */
+        @Override
+        public void maybeRecordRestored(final Time time, final long numRecords) {
+            // intentionally a no-op
+        }
+
         @Override
         public void closeClean() {
             transitionTo(State.CLOSED);
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetricsTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetricsTest.java
index 48413306a93..4995559d248 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetricsTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetricsTest.java
@@ -300,32 +300,6 @@ public class StateStoreMetricsTest {
         );
     }
 
-    /**
-     * Verifies that expiredWindowRecordDropSensor() obtains an INFO-level store sensor and
-     * registers invocation rate and count metrics on it in the "stream-<storeType>-metrics" group.
-     */
-    @Test
-    public void shouldGetExpiredWindowRecordDropSensor() {
-        final String metricName = "expired-window-record-drop";
-        final String descriptionOfRate = "The average number of dropped records due to an expired window per second";
-        final String descriptionOfCount = "The total number of dropped records due to an expired window";
-        // Stub the store-level sensor and tag map that the factory method is expected to request.
-        when(streamsMetrics.storeLevelSensor(TASK_ID, STORE_NAME, metricName, RecordingLevel.INFO))
-            .thenReturn(expectedSensor);
-        when(streamsMetrics.storeLevelTagMap(TASK_ID, STORE_TYPE, STORE_NAME)).thenReturn(storeTagMap);
-
-        // Mock the static helper so the test can verify how the sensor's metrics are registered.
-        try (final MockedStatic<StreamsMetricsImpl> streamsMetricsStaticMock = mockStatic(StreamsMetricsImpl.class)) {
-            final Sensor sensor =
-                StateStoreMetrics.expiredWindowRecordDropSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics);
-            streamsMetricsStaticMock.verify(
-                () -> StreamsMetricsImpl.addInvocationRateAndCountToSensor(
-                    expectedSensor,
-                    "stream-" + STORE_TYPE + "-metrics",
-                    storeTagMap,
-                    metricName,
-                    descriptionOfRate,
-                    descriptionOfCount
-                )
-            );
-            // The returned sensor must be the one obtained from streamsMetrics.
-            assertThat(sensor, is(expectedSensor));
-        }
-    }
-
     @Test
     public void shouldGetRecordE2ELatencySensor() {
         final String metricName = "record-e2e-latency";