You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by ca...@apache.org on 2022/06/07 14:02:33 UTC

[kafka] branch trunk updated: KAFKA-13945: add bytes/records consumed and produced metrics (#12235)

This is an automated email from the ASF dual-hosted git repository.

cadonna pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new a6c5a74fdb KAFKA-13945: add bytes/records consumed and produced metrics (#12235)
a6c5a74fdb is described below

commit a6c5a74fdbdce9a992b47706913c920902cda28c
Author: A. Sophie Blee-Goldman <so...@confluent.io>
AuthorDate: Tue Jun 7 07:02:17 2022 -0700

    KAFKA-13945: add bytes/records consumed and produced metrics (#12235)
    
    Implementation of KIP-846: Source/sink node metrics for Consumed/Produced throughput in Streams
    
    Adds the following INFO topic-level metrics for the total bytes/records consumed and produced:
    
        bytes-consumed-total
        records-consumed-total
        bytes-produced-total
        records-produced-total
    
    Reviewers: Kvicii <Ka...@gmail.com>, Guozhang Wang <gu...@apache.org>, Bruno Cadonna <ca...@apache.org>
---
 checkstyle/suppressions.xml                        |   2 +-
 docs/ops.html                                      |   4 +-
 .../processor/internals/ActiveTaskCreator.java     |   3 +-
 .../streams/processor/internals/ClientUtils.java   |  47 ++++
 .../processor/internals/PartitionGroup.java        |   8 +-
 .../processor/internals/ProcessorContextImpl.java  |   5 +-
 .../processor/internals/RecordCollector.java       |   6 +-
 .../processor/internals/RecordCollectorImpl.java   |  69 ++++-
 .../streams/processor/internals/RecordQueue.java   |  50 ++--
 .../streams/processor/internals/SinkNode.java      |  12 +-
 .../streams/processor/internals/StreamTask.java    |  11 +-
 .../internals/metrics/StreamsMetricsImpl.java      |  64 +++++
 .../processor/internals/metrics/TaskMetrics.java   |   6 +-
 .../processor/internals/metrics/TopicMetrics.java  |  92 +++++++
 .../InMemoryTimeOrderedKeyValueBuffer.java         |  10 +-
 .../integration/MetricsIntegrationTest.java        |  18 ++
 .../processor/internals/ActiveTaskCreatorTest.java |   2 +
 .../processor/internals/ClientUtilsTest.java       | 114 ++++++++-
 .../processor/internals/PartitionGroupTest.java    |  17 +-
 .../internals/ProcessorContextImplTest.java        |  39 +--
 .../processor/internals/RecordCollectorTest.java   | 279 +++++++++++++++------
 .../processor/internals/RecordQueueTest.java       | 101 ++++++--
 .../internals/WriteConsistencyVectorTest.java      |  19 +-
 .../internals/metrics/StreamsMetricsImplTest.java  |  82 +++++-
 .../internals/metrics/TaskMetricsTest.java         |   2 +-
 .../internals/metrics/TopicMetricsTest.java        | 118 +++++++++
 .../streams/state/KeyValueStoreTestDriver.java     |  25 +-
 .../StreamThreadStateStoreProviderTest.java        |   4 +-
 .../kafka/test/InternalMockProcessorContext.java   |   4 +-
 .../org/apache/kafka/test/MockRecordCollector.java |  17 +-
 .../streams/scala/kstream/KStreamSplitTest.scala   |  69 ++---
 .../apache/kafka/streams/TopologyTestDriver.java   |   3 +-
 .../org/apache/kafka/streams/TestTopicsTest.java   |  58 +++--
 33 files changed, 1098 insertions(+), 262 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 9d48741edd..773af1bc1b 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -213,7 +213,7 @@
 
     <!-- Streams tests -->
     <suppress checks="ClassFanOutComplexity"
-              files="(StreamsPartitionAssignorTest|StreamThreadTest|StreamTaskTest|TaskManagerTest|TopologyTestDriverTest).java"/>
+              files="(RecordCollectorTest|StreamsPartitionAssignorTest|StreamThreadTest|StreamTaskTest|TaskManagerTest|TopologyTestDriverTest).java"/>
 
     <suppress checks="MethodLength"
               files="(EosIntegrationTest|EosV2UpgradeIntegrationTest|KStreamKStreamJoinTest|RocksDBWindowStoreTest).java"/>
diff --git a/docs/ops.html b/docs/ops.html
index 843c5cc59e..0c1973bd69 100644
--- a/docs/ops.html
+++ b/docs/ops.html
@@ -2308,7 +2308,7 @@ $ bin/kafka-acls.sh \
   <h4 class="anchor-heading"><a id="kafka_streams_monitoring" class="anchor-link"></a><a href="#kafka_streams_monitoring">Streams Monitoring</a></h4>
 
   A Kafka Streams instance contains all the producer and consumer metrics as well as additional metrics specific to Streams.
-  By default Kafka Streams has metrics with three recording levels: <code>info</code>, <code>debug</code>, and <code>trace</code>.
+  The metrics have three recording levels: <code>info</code>, <code>debug</code>, and <code>trace</code>.
 
   <p>
     Note that the metrics have a 4-layer hierarchy. At the top level there are client-level metrics for each started
@@ -2617,7 +2617,7 @@ active-process-ratio metrics which have a recording level of <code>info</code>:
  </table>
 
  <h5 class="anchor-heading"><a id="kafka_streams_store_monitoring" class="anchor-link"></a><a href="#kafka_streams_store_monitoring">State Store Metrics</a></h5>
-All of the following metrics have a recording level of <code>debug</code>, except for the record-e2e-latency-* metrics which have a recording level <code>trace></code>.
+All of the following metrics have a recording level of <code>debug</code>, except for the record-e2e-latency-* metrics which have a recording level <code>trace</code>.
 Note that the <code>store-scope</code> value is specified in <code>StoreSupplier#metricsScope()</code> for user's customized state stores;
 for built-in state stores, currently we have:
   <ul>
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
index d5545eafec..d90266dbf4 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
@@ -260,7 +260,8 @@ class ActiveTaskCreator {
             taskId,
             streamsProducer,
             applicationConfig.defaultProductionExceptionHandler(),
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
 
         final StreamTask task = new StreamTask(
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ClientUtils.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ClientUtils.java
index b47b68f8c4..ea48c64617 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ClientUtils.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ClientUtils.java
@@ -21,12 +21,17 @@ import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
 import org.apache.kafka.clients.admin.OffsetSpec;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.KafkaFuture;
 import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.header.Header;
+import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.TaskId;
@@ -166,4 +171,46 @@ public class ClientUtils {
         final int index = fullThreadName.indexOf("StreamThread-");
         return fullThreadName.substring(index);
     }
+
+    public static long producerRecordSizeInBytes(final ProducerRecord<byte[], byte[]> record) {
+        return recordSizeInBytes(
+            record.key() == null ? 0 : record.key().length, // the key may be null, just like the value
+            record.value() == null ? 0 : record.value().length,
+            record.topic(),
+            record.headers()
+        );
+    }
+
+    public static long consumerRecordSizeInBytes(final ConsumerRecord<byte[], byte[]> record) {
+        return recordSizeInBytes(
+            Math.max(0, record.serializedKeySize()), // ConsumerRecord reports -1 (NULL_SIZE) for a null key
+            Math.max(0, record.serializedValueSize()), // ... and likewise for a null value
+            record.topic(),
+            record.headers()
+        );
+    }
+
+    public static long recordSizeInBytes(final long keyBytes,
+                                         final long valueBytes,
+                                         final String topic,
+                                         final Headers headers) {
+        long headerSizeInBytes = 0L;
+
+        if (headers != null) {
+            for (final Header header : headers.toArray()) {
+                headerSizeInBytes += Utils.utf8(header.key()).length;
+                if (header.value() != null) {
+                    headerSizeInBytes += header.value().length;
+                }
+            }
+        }
+
+        return keyBytes +
+            valueBytes +
+            8L + // timestamp
+            8L + // offset
+            Utils.utf8(topic).length +
+            4L + // partition
+            headerSizeInBytes;
+    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
index dd1257f179..21d3cbfa3f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
@@ -268,7 +268,7 @@ public class PartitionGroup {
             // get the buffer size of queue before poll
             final long oldBufferSize = queue.getTotalBytesBuffered();
             // get the first record from this queue.
-            record = queue.poll();
+            record = queue.poll(wallClockTime);
             // After polling, the buffer size would have reduced.
             final long newBufferSize = queue.getTotalBytesBuffered();
 
@@ -392,6 +392,12 @@ public class PartitionGroup {
         streamTime = RecordQueue.UNKNOWN;
     }
 
+    void close() {
+        for (final RecordQueue queue : partitionQueues.values()) {
+            queue.close();
+        }
+    }
+
     // Below methods are for only testing.
 
     boolean allPartitionsBufferedLocally() {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
index 5a548aba97..ffa5dcaf73 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
@@ -147,8 +147,9 @@ public class ProcessorContextImpl extends AbstractProcessorContext<Object, Objec
             changelogPartition.partition(),
             timestamp,
             BYTES_KEY_SERIALIZER,
-            BYTEARRAY_VALUE_SERIALIZER
-        );
+            BYTEARRAY_VALUE_SERIALIZER,
+            null,
+            null);
     }
 
     /**
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
index 8b22f22f82..a48a671d46 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java
@@ -33,7 +33,9 @@ public interface RecordCollector {
                      final Integer partition,
                      final Long timestamp,
                      final Serializer<K> keySerializer,
-                     final Serializer<V> valueSerializer);
+                     final Serializer<V> valueSerializer,
+                     final String processorNodeId,
+                     final InternalProcessorContext<Void, Void> context);
 
     <K, V> void send(final String topic,
                      final K key,
@@ -42,6 +44,8 @@ public interface RecordCollector {
                      final Long timestamp,
                      final Serializer<K> keySerializer,
                      final Serializer<V> valueSerializer,
+                     final String processorNodeId,
+                     final InternalProcessorContext<Void, Void> context,
                      final StreamPartitioner<? super K, ? super V> partitioner);
 
     /**
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
index f8c9cf9d7d..358af6b1b3 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java
@@ -46,6 +46,8 @@ import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
+import org.apache.kafka.streams.processor.internals.metrics.TopicMetrics;
+
 import org.slf4j.Logger;
 
 import java.util.Collections;
@@ -54,6 +56,8 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.apache.kafka.streams.processor.internals.ClientUtils.producerRecordSizeInBytes;
+
 public class RecordCollectorImpl implements RecordCollector {
     private final static String SEND_EXCEPTION_MESSAGE = "Error encountered sending record to topic %s for task %s due to:%n%s";
 
@@ -61,10 +65,13 @@ public class RecordCollectorImpl implements RecordCollector {
     private final TaskId taskId;
     private final StreamsProducer streamsProducer;
     private final ProductionExceptionHandler productionExceptionHandler;
-    private final Sensor droppedRecordsSensor;
     private final boolean eosEnabled;
     private final Map<TopicPartition, Long> offsets;
 
+    private final StreamsMetricsImpl streamsMetrics;
+    private final Sensor droppedRecordsSensor;
+    private final Map<String, Map<String, Sensor>> sinkNodeToProducedSensorByTopic = new HashMap<>();
+
     private final AtomicReference<KafkaException> sendException = new AtomicReference<>(null);
 
     /**
@@ -74,15 +81,29 @@ public class RecordCollectorImpl implements RecordCollector {
                                final TaskId taskId,
                                final StreamsProducer streamsProducer,
                                final ProductionExceptionHandler productionExceptionHandler,
-                               final StreamsMetricsImpl streamsMetrics) {
+                               final StreamsMetricsImpl streamsMetrics,
+                               final ProcessorTopology topology) {
         this.log = logContext.logger(getClass());
         this.taskId = taskId;
         this.streamsProducer = streamsProducer;
         this.productionExceptionHandler = productionExceptionHandler;
         this.eosEnabled = streamsProducer.eosEnabled();
+        this.streamsMetrics = streamsMetrics;
 
         final String threadId = Thread.currentThread().getName();
         this.droppedRecordsSensor = TaskMetrics.droppedRecordsSensor(threadId, taskId.toString(), streamsMetrics);
+        for (final String topic : topology.sinkTopics()) {
+            final String processorNodeId = topology.sink(topic).name();
+            sinkNodeToProducedSensorByTopic.computeIfAbsent(processorNodeId, t -> new HashMap<>()).put(
+                topic,
+                TopicMetrics.producedSensor(
+                    threadId,
+                    taskId.toString(),
+                    processorNodeId,
+                    topic,
+                    streamsMetrics
+                ));
+        }
 
         this.offsets = new HashMap<>();
     }
@@ -106,6 +127,8 @@ public class RecordCollectorImpl implements RecordCollector {
                             final Long timestamp,
                             final Serializer<K> keySerializer,
                             final Serializer<V> valueSerializer,
+                            final String processorNodeId,
+                            final InternalProcessorContext<Void, Void> context,
                             final StreamPartitioner<? super K, ? super V> partitioner) {
         final Integer partition;
 
@@ -122,7 +145,7 @@ public class RecordCollectorImpl implements RecordCollector {
                 // here we cannot drop the message on the floor even if it is a transient timeout exception,
                 // so we treat everything the same as a fatal exception
                 throw new StreamsException("Could not determine the number of partitions for topic '" + topic +
-                    "' for task " + taskId + " due to " + fatal.toString(),
+                    "' for task " + taskId + " due to " + fatal,
                     fatal
                 );
             }
@@ -136,7 +159,7 @@ public class RecordCollectorImpl implements RecordCollector {
             partition = null;
         }
 
-        send(topic, key, value, headers, partition, timestamp, keySerializer, valueSerializer);
+        send(topic, key, value, headers, partition, timestamp, keySerializer, valueSerializer, processorNodeId, context);
     }
 
     @Override
@@ -147,7 +170,9 @@ public class RecordCollectorImpl implements RecordCollector {
                             final Integer partition,
                             final Long timestamp,
                             final Serializer<K> keySerializer,
-                            final Serializer<V> valueSerializer) {
+                            final Serializer<V> valueSerializer,
+                            final String processorNodeId,
+                            final InternalProcessorContext<Void, Void> context) {
         checkForException();
 
         final byte[] keyBytes;
@@ -173,7 +198,7 @@ public class RecordCollectorImpl implements RecordCollector {
                     valueClass),
                 exception);
         } catch (final RuntimeException exception) {
-            final String errorMessage = String.format(SEND_EXCEPTION_MESSAGE, topic, taskId, exception.toString());
+            final String errorMessage = String.format(SEND_EXCEPTION_MESSAGE, topic, taskId, exception);
             throw new StreamsException(errorMessage, exception);
         }
 
@@ -192,6 +217,28 @@ public class RecordCollectorImpl implements RecordCollector {
                 } else {
                     log.warn("Received offset={} in produce response for {}", metadata.offset(), tp);
                 }
+
+                if (!topic.endsWith("-changelog")) {
+                    // we may not have created a sensor yet if the node uses dynamic topic routing
+                    final Map<String, Sensor> producedSensorByTopic = sinkNodeToProducedSensorByTopic.get(processorNodeId);
+                    if (producedSensorByTopic == null) {
+                        log.error("Unable to record bytes produced to topic {} by sink node {} as the node is not recognized.\n"
+                                      + "Known sink nodes are {}.", topic, processorNodeId, sinkNodeToProducedSensorByTopic.keySet());
+                    } else {
+                        final Sensor topicProducedSensor = producedSensorByTopic.computeIfAbsent(
+                            topic,
+                            t -> TopicMetrics.producedSensor(
+                                Thread.currentThread().getName(), // NOTE(review): send callbacks usually run on the producer I/O thread, so this may not be the stream thread name - confirm
+                                taskId.toString(),
+                                processorNodeId,
+                                topic,
+                                context.metrics()
+                            )
+                        );
+                        final long bytesProduced = producerRecordSizeInBytes(serializedRecord);
+                        topicProducedSensor.record(bytesProduced, context.currentSystemTimeMs());
+                    }
+                }
             } else {
                 recordSendError(topic, exception, serializedRecord);
 
@@ -267,6 +314,8 @@ public class RecordCollectorImpl implements RecordCollector {
     public void closeClean() {
         log.info("Closing record collector clean");
 
+        removeAllProducedSensors();
+
         // No need to abort transaction during a clean close: either we have successfully committed the ongoing
         // transaction during handleRevocation and thus there is no transaction in flight, or else none of the revoked
         // tasks had any data in the current transaction and therefore there is no need to commit or abort it.
@@ -290,6 +339,14 @@ public class RecordCollectorImpl implements RecordCollector {
         checkForException();
     }
 
+    private void removeAllProducedSensors() {
+        for (final Map<String, Sensor> nodeMap : sinkNodeToProducedSensorByTopic.values()) {
+            for (final Sensor sensor : nodeMap.values()) {
+                streamsMetrics.removeSensor(sensor);
+            }
+        }
+    }
+
     @Override
     public Map<TopicPartition, Long> offsets() {
         return Collections.unmodifiableMap(new HashMap<>(offsets));
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
index 90d67a7b0f..8aa58414a9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
@@ -18,19 +18,21 @@ package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.common.header.Header;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.errors.DeserializationExceptionHandler;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
 import org.apache.kafka.streams.processor.api.ProcessorContext;
 import org.apache.kafka.streams.processor.TimestampExtractor;
+import org.apache.kafka.streams.processor.internals.metrics.TopicMetrics;
+
 import org.slf4j.Logger;
 
 import java.util.ArrayDeque;
 
+import static org.apache.kafka.streams.processor.internals.ClientUtils.consumerRecordSizeInBytes;
+
 /**
  * RecordQueue is a FIFO queue of {@link StampedRecord} (ConsumerRecord + timestamp). It also keeps track of the
  * partition timestamp defined as the largest timestamp seen on the partition so far; this is passed to the
@@ -52,6 +54,7 @@ public class RecordQueue {
     private long partitionTime = UNKNOWN;
 
     private final Sensor droppedRecordsSensor;
+    private final Sensor consumedSensor;
     private long totalBytesBuffered;
     private long headRecordSizeInBytes;
 
@@ -66,9 +69,18 @@ public class RecordQueue {
         this.fifoQueue = new ArrayDeque<>();
         this.timestampExtractor = timestampExtractor;
         this.processorContext = processorContext;
+
+        final String threadName = Thread.currentThread().getName();
         droppedRecordsSensor = TaskMetrics.droppedRecordsSensor(
-            Thread.currentThread().getName(),
+            threadName,
+            processorContext.taskId().toString(),
+            processorContext.metrics()
+        );
+        consumedSensor = TopicMetrics.consumedSensor(
+            threadName,
             processorContext.taskId().toString(),
+            source.name(),
+            partition.topic(),
             processorContext.metrics()
         );
         recordDeserializer = new RecordDeserializer(
@@ -104,25 +116,6 @@ public class RecordQueue {
         return partition;
     }
 
-    private long sizeInBytes(final ConsumerRecord<byte[], byte[]> record) {
-        long headerSizeInBytes = 0L;
-
-        for (final Header header: record.headers().toArray()) {
-            headerSizeInBytes += Utils.utf8(header.key()).length;
-            if (header.value() != null) {
-                headerSizeInBytes += header.value().length;
-            }
-        }
-
-        return record.serializedKeySize() +
-                record.serializedValueSize() +
-                8L + // timestamp
-                8L + // offset
-                Utils.utf8(record.topic()).length +
-                4L + // partition
-                headerSizeInBytes;
-    }
-
     /**
      * Add a batch of {@link ConsumerRecord} into the queue
      *
@@ -132,7 +125,7 @@ public class RecordQueue {
     int addRawRecords(final Iterable<ConsumerRecord<byte[], byte[]>> rawRecords) {
         for (final ConsumerRecord<byte[], byte[]> rawRecord : rawRecords) {
             fifoQueue.addLast(rawRecord);
-            this.totalBytesBuffered += sizeInBytes(rawRecord);
+            this.totalBytesBuffered += consumerRecordSizeInBytes(rawRecord);
         }
 
         updateHead();
@@ -145,8 +138,11 @@ public class RecordQueue {
      *
      * @return StampedRecord
      */
-    public StampedRecord poll() {
+    public StampedRecord poll(final long wallClockTime) {
         final StampedRecord recordToReturn = headRecord;
+
+        consumedSensor.record(headRecordSizeInBytes, wallClockTime);
+
         totalBytesBuffered -= headRecordSizeInBytes;
         headRecord = null;
         headRecordSizeInBytes = 0L;
@@ -199,6 +195,10 @@ public class RecordQueue {
         partitionTime = UNKNOWN;
     }
 
+    public void close() {
+        processorContext.metrics().removeSensor(consumedSensor);
+    }
+
     private void updateHead() {
         ConsumerRecord<byte[], byte[]> lastCorruptedRecord = null;
 
@@ -235,7 +235,7 @@ public class RecordQueue {
                 continue;
             }
             headRecord = new StampedRecord(deserialized, timestamp);
-            headRecordSizeInBytes = sizeInBytes(raw);
+            headRecordSizeInBytes = consumerRecordSizeInBytes(raw);
         }
 
         // if all records in the FIFO queue are corrupted, make the last one the headRecord
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/SinkNode.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/SinkNode.java
index f30e2d2847..6f508eff27 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/SinkNode.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/SinkNode.java
@@ -82,7 +82,17 @@ public class SinkNode<KIn, VIn> extends ProcessorNode<KIn, VIn, Void, Void> {
 
         final String topic = topicExtractor.extract(key, value, contextForExtraction);
 
-        collector.send(topic, key, value, record.headers(), timestamp, keySerializer, valSerializer, partitioner);
+        collector.send(
+            topic,
+            key,
+            value,
+            record.headers(),
+            timestamp,
+            keySerializer,
+            valSerializer,
+            name(),
+            context,
+            partitioner);
     }
 
     /**
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index 3e6513cf7b..ea593a2973 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -184,7 +184,7 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
             createPartitionQueues(),
             mainConsumer::currentLag,
             TaskMetrics.recordLatenessSensor(threadId, taskId, streamsMetrics),
-            TaskMetrics.totalBytesSensor(threadId, taskId, streamsMetrics),
+            TaskMetrics.totalInputBufferBytesSensor(threadId, taskId, streamsMetrics),
             enforcedProcessingSensor,
             maxTaskIdleMs
         );
@@ -553,6 +553,7 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         switch (state()) {
             case SUSPENDED:
                 stateMgr.recycle();
+                partitionGroup.close();
                 recordCollector.closeClean();
 
                 break;
@@ -614,6 +615,13 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
     private void close(final boolean clean) {
         switch (state()) {
             case SUSPENDED:
+                TaskManager.executeAndMaybeSwallow(
+                    clean,
+                    partitionGroup::close,
+                    "partition group close",
+                    log
+                );
+
                 // first close state manager (which is idempotent) then close the record collector
                 // if the latter throws and we re-close dirty which would close the state manager again.
                 TaskManager.executeAndMaybeSwallow(
@@ -653,7 +661,6 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         }
 
         record = null;
-        partitionGroup.clear();
         closeTaskSensor.record();
 
         transitionTo(State.CLOSED);
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
index dea23993d2..4bfd96265b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
@@ -95,6 +95,7 @@ public class StreamsMetricsImpl implements StreamsMetrics {
     private final Map<String, Deque<String>> threadLevelSensors = new HashMap<>();
     private final Map<String, Deque<String>> taskLevelSensors = new HashMap<>();
     private final Map<String, Deque<String>> nodeLevelSensors = new HashMap<>();
+    private final Map<String, Deque<String>> topicLevelSensors = new HashMap<>(); // topic sensor prefix -> names of its sensors; accessed only under synchronized (topicLevelSensors)
     private final Map<String, Deque<String>> cacheLevelSensors = new HashMap<>();
     private final ConcurrentMap<String, Deque<String>> storeLevelSensors = new ConcurrentHashMap<>();
     private final ConcurrentMap<String, Deque<MetricName>> storeLevelMetrics = new ConcurrentHashMap<>();
@@ -105,6 +106,7 @@ public class StreamsMetricsImpl implements StreamsMetrics {
     private static final String SENSOR_NAME_DELIMITER = ".s.";
     private static final String SENSOR_TASK_LABEL = "task";
     private static final String SENSOR_NODE_LABEL = "node";
+    private static final String SENSOR_TOPIC_LABEL = "topic"; // label inserted into topic-level sensor prefixes after the node prefix
     private static final String SENSOR_CACHE_LABEL = "cache";
     private static final String SENSOR_STORE_LABEL = "store";
     private static final String SENSOR_ENTITY_LABEL = "entity";
@@ -115,6 +117,7 @@ public class StreamsMetricsImpl implements StreamsMetrics {
     public static final String THREAD_ID_TAG = "thread-id";
     public static final String TASK_ID_TAG = "task-id";
     public static final String PROCESSOR_NODE_ID_TAG = "processor-node-id";
+    public static final String TOPIC_NAME_TAG = "topic-name"; // tag holding the topic name of topic-level metrics
     public static final String STORE_ID_TAG = "state-id";
     public static final String RECORD_CACHE_ID_TAG = "record-cache-id";
 
@@ -136,6 +139,7 @@ public class StreamsMetricsImpl implements StreamsMetrics {
     public static final String THREAD_LEVEL_GROUP = GROUP_PREFIX + "thread" + GROUP_SUFFIX;
     public static final String TASK_LEVEL_GROUP = GROUP_PREFIX + "task" + GROUP_SUFFIX;
     public static final String PROCESSOR_NODE_LEVEL_GROUP = GROUP_PREFIX + "processor-node" + GROUP_SUFFIX;
+    public static final String TOPIC_LEVEL_GROUP = GROUP_PREFIX + "topic" + GROUP_SUFFIX; // group of the bytes/records consumed and produced metrics
     public static final String STATE_STORE_LEVEL_GROUP = GROUP_PREFIX + "state" + GROUP_SUFFIX;
     public static final String CACHE_LEVEL_GROUP = GROUP_PREFIX + "record-cache" + GROUP_SUFFIX;
 
@@ -325,6 +329,18 @@ public class StreamsMetricsImpl implements StreamsMetrics {
         return tagMap;
     }
 
+    /**
+     * Returns the tags of a topic-level metric: the processor-node-level tags plus {@link #TOPIC_NAME_TAG}.
+     */
+    public Map<String, String> topicLevelTagMap(final String threadId,
+                                                final String taskName,
+                                                final String processorNodeName,
+                                                final String topicName) {
+        final Map<String, String> topicTags = nodeLevelTagMap(threadId, taskName, processorNodeName);
+        topicTags.put(TOPIC_NAME_TAG, topicName);
+        return topicTags;
+    }
+
     public Map<String, String> storeLevelTagMap(final String taskName,
                                                 final String storeType,
                                                 final String storeName) {
@@ -388,6 +401,52 @@ public class StreamsMetricsImpl implements StreamsMetrics {
             + SENSOR_PREFIX_DELIMITER + SENSOR_NODE_LABEL + SENSOR_PREFIX_DELIMITER + processorNodeName;
     }
 
+    /**
+     * Obtains, through {@code getSensors}, the topic-level sensor {@code sensorName} scoped to the given thread,
+     * task, processor node and topic. The lookup key is the same prefix that {@link #removeAllTopicLevelSensors}
+     * removes, so (assuming {@code getSensors} records the sensor under that key) it is cleaned up by that method.
+     */
+    public Sensor topicLevelSensor(final String threadId,
+                                   final String taskId,
+                                   final String processorNodeName,
+                                   final String topicName,
+                                   final String sensorName,
+                                   final Sensor.RecordingLevel recordingLevel,
+                                   final Sensor... parents) {
+        final String sensorPrefix = topicSensorPrefix(threadId, taskId, processorNodeName, topicName);
+        synchronized (topicLevelSensors) {
+            return getSensors(topicLevelSensors, sensorName, sensorPrefix, recordingLevel, parents);
+        }
+    }
+
+    /**
+     * Removes from the metrics registry every sensor tracked under the topic-level prefix of the given thread,
+     * task, processor node and topic, and drops that prefix from the bookkeeping map.
+     */
+    public final void removeAllTopicLevelSensors(final String threadId,
+                                                 final String taskId,
+                                                 final String processorNodeName,
+                                                 final String topicName) {
+        final String sensorPrefix = topicSensorPrefix(threadId, taskId, processorNodeName, topicName);
+        synchronized (topicLevelSensors) {
+            final Deque<String> sensorNames = topicLevelSensors.remove(sensorPrefix);
+            if (sensorNames != null) {
+                while (!sensorNames.isEmpty()) {
+                    metrics.removeSensor(sensorNames.pop());
+                }
+            }
+        }
+    }
+
+    // Topic-level sensor prefix: the processor-node-level prefix, then the "topic" label, then the topic name.
+    private String topicSensorPrefix(final String threadId,
+                                     final String taskId,
+                                     final String processorNodeName,
+                                     final String topicName) {
+        final String nodePrefix = nodeSensorPrefix(threadId, taskId, processorNodeName);
+        return nodePrefix + SENSOR_PREFIX_DELIMITER + SENSOR_TOPIC_LABEL + SENSOR_PREFIX_DELIMITER + topicName;
+    }
+
     public Sensor cacheLevelSensor(final String threadId,
                                    final String taskName,
                                    final String storeName,
@@ -795,6 +842,27 @@ public class StreamsMetricsImpl implements StreamsMetrics {
         );
     }
 
+    /**
+     * Adds to {@code sensor} a {@link CumulativeCount} metric named {@code countMetricNamePrefix + TOTAL_SUFFIX},
+     * which counts the recorded values, and a {@link CumulativeSum} metric named
+     * {@code sumMetricNamePrefix + TOTAL_SUFFIX}, which adds them up. Both metrics share {@code group} and {@code tags}.
+     */
+    public static void addTotalCountAndSumMetricsToSensor(final Sensor sensor,
+                                                          final String group,
+                                                          final Map<String, String> tags,
+                                                          final String countMetricNamePrefix,
+                                                          final String sumMetricNamePrefix,
+                                                          final String descriptionOfCount,
+                                                          final String descriptionOfTotal) {
+        final MetricName countMetricName =
+            new MetricName(countMetricNamePrefix + TOTAL_SUFFIX, group, descriptionOfCount, tags);
+        sensor.add(countMetricName, new CumulativeCount());
+
+        final MetricName sumMetricName =
+            new MetricName(sumMetricNamePrefix + TOTAL_SUFFIX, group, descriptionOfTotal, tags);
+        sensor.add(sumMetricName, new CumulativeSum());
+    }
+
     public static void maybeMeasureLatency(final Runnable actionToMeasure,
                                            final Time time,
                                            final Sensor sensor) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java
index f173bac403..8949390e7f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java
@@ -133,9 +133,9 @@ public class TaskMetrics {
         return sensor;
     }
 
-    public static Sensor totalBytesSensor(final String threadId,
-                                          final String taskId,
-                                          final StreamsMetricsImpl streamsMetrics) {
+    public static Sensor totalInputBufferBytesSensor(final String threadId,
+                                                     final String taskId,
+                                                     final StreamsMetricsImpl streamsMetrics) {
         final String name = INPUT_BUFFER_BYTES_TOTAL;
         final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, name, RecordingLevel.INFO);
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TopicMetrics.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TopicMetrics.java
new file mode 100644
index 0000000000..85b438d969
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TopicMetrics.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.metrics;
+
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.metrics.Sensor.RecordingLevel;
+
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOPIC_LEVEL_GROUP;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOTAL_DESCRIPTION;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addTotalCountAndSumMetricsToSensor;
+
+/**
+ * Factory methods for the INFO-level topic-level sensors of Kafka Streams. Each sensor is scoped to a thread,
+ * task, processor node and topic, and carries a cumulative count metric (records-*-total) and a cumulative sum
+ * metric (bytes-*-total) for the records consumed from or produced to that topic.
+ */
+public class TopicMetrics {
+
+    private static final String CONSUMED = "consumed";
+    private static final String BYTES_CONSUMED = "bytes-consumed";
+    private static final String BYTES_CONSUMED_DESCRIPTION = "bytes consumed from this topic";
+    private static final String BYTES_CONSUMED_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + BYTES_CONSUMED_DESCRIPTION;
+    private static final String RECORDS_CONSUMED = "records-consumed";
+    private static final String RECORDS_CONSUMED_DESCRIPTION = "records consumed from this topic";
+    private static final String RECORDS_CONSUMED_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + RECORDS_CONSUMED_DESCRIPTION;
+
+    private static final String PRODUCED = "produced";
+    private static final String BYTES_PRODUCED = "bytes-produced";
+    private static final String BYTES_PRODUCED_DESCRIPTION = "bytes produced to this topic";
+    private static final String BYTES_PRODUCED_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + BYTES_PRODUCED_DESCRIPTION;
+    private static final String RECORDS_PRODUCED = "records-produced";
+    private static final String RECORDS_PRODUCED_DESCRIPTION = "records produced to this topic";
+    private static final String RECORDS_PRODUCED_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + RECORDS_PRODUCED_DESCRIPTION;
+
+    private TopicMetrics() {}
+
+    /**
+     * Returns the sensor for records consumed by the given processor node from the given topic, with the
+     * {@code records-consumed-total} (cumulative count) and {@code bytes-consumed-total} (cumulative sum)
+     * metrics attached. Callers are expected to record each consumed record's size in bytes.
+     */
+    public static Sensor consumedSensor(final String threadId,
+                                        final String taskId,
+                                        final String processorNodeId,
+                                        final String topic,
+                                        final StreamsMetricsImpl streamsMetrics) {
+        final Sensor sensor = streamsMetrics.topicLevelSensor(
+            threadId,
+            taskId,
+            processorNodeId,
+            topic,
+            CONSUMED,
+            RecordingLevel.INFO);
+        addTotalCountAndSumMetricsToSensor(
+            sensor,
+            TOPIC_LEVEL_GROUP,
+            streamsMetrics.topicLevelTagMap(threadId, taskId, processorNodeId, topic),
+            RECORDS_CONSUMED,
+            BYTES_CONSUMED,
+            RECORDS_CONSUMED_TOTAL_DESCRIPTION,
+            BYTES_CONSUMED_TOTAL_DESCRIPTION
+        );
+        return sensor;
+    }
+
+    /**
+     * Returns the sensor for records produced by the given processor node to the given topic, with the
+     * {@code records-produced-total} (cumulative count) and {@code bytes-produced-total} (cumulative sum)
+     * metrics attached. Callers are expected to record each produced record's size in bytes.
+     */
+    public static Sensor producedSensor(final String threadId,
+                                        final String taskId,
+                                        final String processorNodeId,
+                                        final String topic,
+                                        final StreamsMetricsImpl streamsMetrics) {
+        final Sensor sensor = streamsMetrics.topicLevelSensor(
+            threadId,
+            taskId,
+            processorNodeId,
+            topic,
+            PRODUCED,
+            RecordingLevel.INFO);
+        addTotalCountAndSumMetricsToSensor(
+            sensor,
+            TOPIC_LEVEL_GROUP,
+            streamsMetrics.topicLevelTagMap(threadId, taskId, processorNodeId, topic),
+            RECORDS_PRODUCED,
+            BYTES_PRODUCED,
+            RECORDS_PRODUCED_TOTAL_DESCRIPTION,
+            BYTES_PRODUCED_TOTAL_DESCRIPTION
+        );
+        return sensor;
+    }
+
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java
index a861da6469..5894023bbe 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java
@@ -292,8 +292,9 @@ public final class InMemoryTimeOrderedKeyValueBuffer<K, V> implements TimeOrdere
             partition,
             null,
             KEY_SERIALIZER,
-            VALUE_SERIALIZER
-        );
+            VALUE_SERIALIZER,
+            null,
+            null);
     }
 
     private void logTombstone(final Bytes key) {
@@ -305,8 +306,9 @@ public final class InMemoryTimeOrderedKeyValueBuffer<K, V> implements TimeOrdere
             partition,
             null,
             KEY_SERIALIZER,
-            VALUE_SERIALIZER
-        );
+            VALUE_SERIALIZER,
+            null,
+            null);
     }
 
     private void restoreBatch(final Collection<ConsumerRecord<byte[], byte[]>> batch) {
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java
index 0da0ac9eef..2e1dd141f8 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java
@@ -93,6 +93,7 @@ public class MetricsIntegrationTest {
     private static final String STREAM_THREAD_NODE_METRICS = "stream-thread-metrics";
     private static final String STREAM_TASK_NODE_METRICS = "stream-task-metrics";
     private static final String STREAM_PROCESSOR_NODE_METRICS = "stream-processor-node-metrics";
+    private static final String STREAM_TOPIC_METRICS = "stream-topic-metrics"; // metric group filtered on in checkTopicLevelMetrics
     private static final String STREAM_CACHE_NODE_METRICS = "stream-record-cache-metrics";
 
     private static final String IN_MEMORY_KVSTORE_TAG_KEY = "in-memory-state-id";
@@ -217,6 +218,10 @@ public class MetricsIntegrationTest {
     private static final String RECORD_E2E_LATENCY_AVG = "record-e2e-latency-avg";
     private static final String RECORD_E2E_LATENCY_MIN = "record-e2e-latency-min";
     private static final String RECORD_E2E_LATENCY_MAX = "record-e2e-latency-max";
+    private static final String BYTES_CONSUMED_TOTAL = "bytes-consumed-total"; // topic-level metric names checked by checkTopicLevelMetrics
+    private static final String RECORDS_CONSUMED_TOTAL = "records-consumed-total";
+    private static final String BYTES_PRODUCED_TOTAL = "bytes-produced-total";
+    private static final String RECORDS_PRODUCED_TOTAL = "records-produced-total";
 
     // stores name
     private static final String TIME_WINDOWED_AGGREGATED_STREAM_STORE = "time-windowed-aggregated-stream-store";
@@ -360,6 +365,7 @@ public class MetricsIntegrationTest {
         checkThreadLevelMetrics();
         checkTaskLevelMetrics();
         checkProcessorNodeLevelMetrics();
+        checkTopicLevelMetrics();
         checkKeyValueStoreMetrics(IN_MEMORY_KVSTORE_TAG_KEY);
         checkKeyValueStoreMetrics(ROCKSDB_KVSTORE_TAG_KEY);
         checkKeyValueStoreMetrics(IN_MEMORY_LRUCACHE_TAG_KEY);
@@ -548,6 +554,23 @@ public class MetricsIntegrationTest {
         checkMetricByName(listMetricProcessor, RECORD_E2E_LATENCY_MAX, numberOfSourceNodes + numberOfTerminalNodes);
     }
 
+    /**
+     * Checks, via {@code checkMetricByName}, the metrics of the topic-level group: the consumed totals against
+     * the number of source topics and the produced totals against the number of sink topics.
+     */
+    private void checkTopicLevelMetrics() {
+        final int sourceTopicCount = 4;
+        final int sinkTopicCount = 4;
+        final List<Metric> topicMetrics = new ArrayList<Metric>(kafkaStreams.metrics().values()).stream()
+            .filter(metric -> metric.metricName().group().equals(STREAM_TOPIC_METRICS))
+            .collect(Collectors.toList());
+
+        checkMetricByName(topicMetrics, BYTES_CONSUMED_TOTAL, sourceTopicCount);
+        checkMetricByName(topicMetrics, RECORDS_CONSUMED_TOTAL, sourceTopicCount);
+        checkMetricByName(topicMetrics, BYTES_PRODUCED_TOTAL, sinkTopicCount);
+        checkMetricByName(topicMetrics, RECORDS_PRODUCED_TOTAL, sinkTopicCount);
+    }
+
     private void checkKeyValueStoreMetrics(final String tagKey) {
         final List<Metric> listMetricStore = new ArrayList<Metric>(kafkaStreams.metrics().values()).stream()
             .filter(m -> m.metricName().tags().containsKey(tagKey) && m.metricName().group().equals(STATE_STORE_LEVEL_GROUP))
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java
index e2ef0d16d4..538360bd63 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java
@@ -60,6 +60,7 @@ import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.core.IsNot.not;
 import static org.junit.Assert.assertThrows;
+import static java.util.Collections.emptySet;
 
 @RunWith(EasyMockRunner.class)
 public class ActiveTaskCreatorTest {
@@ -478,6 +479,7 @@ public class ActiveTaskCreatorTest {
         reset(builder, stateDirectory);
         expect(builder.topologyConfigs()).andStubReturn(new TopologyConfig(new StreamsConfig(properties)));
         expect(builder.buildSubtopology(0)).andReturn(topology).anyTimes();
+        expect(topology.sinkTopics()).andStubReturn(emptySet());
         expect(stateDirectory.getOrCreateDirectoryForTask(task00)).andReturn(mock(File.class));
         expect(stateDirectory.checkpointFileFor(task00)).andReturn(mock(File.class));
         expect(stateDirectory.getOrCreateDirectoryForTask(task01)).andReturn(mock(File.class));
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ClientUtilsTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ClientUtilsTest.java
index a6c5e3d0b4..775268496e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ClientUtilsTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ClientUtilsTest.java
@@ -17,6 +17,8 @@
 package org.apache.kafka.streams.processor.internals;
 
+import java.nio.charset.StandardCharsets;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import org.apache.kafka.clients.admin.Admin;
@@ -24,29 +25,76 @@ import org.apache.kafka.clients.admin.AdminClient;
 import org.apache.kafka.clients.admin.ListOffsetsResult;
 import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
 import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.KafkaFuture;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.TimeoutException;
+import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.header.internals.RecordHeader;
+import org.apache.kafka.common.header.internals.RecordHeaders;
+import org.apache.kafka.common.record.TimestampType;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.easymock.EasyMock;
 import org.junit.Test;
 
+import static java.util.Arrays.asList;
 import static java.util.Collections.emptySet;
 import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.ClientUtils.consumerRecordSizeInBytes;
 import static org.apache.kafka.streams.processor.internals.ClientUtils.fetchCommittedOffsets;
 import static org.apache.kafka.streams.processor.internals.ClientUtils.fetchEndOffsets;
+import static org.apache.kafka.streams.processor.internals.ClientUtils.producerRecordSizeInBytes;
+
 import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.replay;
 import static org.easymock.EasyMock.verify;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 public class ClientUtilsTest {
 
+    // records use utf8 for topic name, header keys, etc; encode explicitly so sizes don't depend on the platform charset
+    private static final String TOPIC = "topic";
+    private static final int TOPIC_BYTES = 5;
+
+    private static final byte[] KEY = "key".getBytes(StandardCharsets.UTF_8);
+    private static final int KEY_BYTES = 3;
+
+    private static final byte[] VALUE = "value".getBytes(StandardCharsets.UTF_8);
+    private static final int VALUE_BYTES = 5;
+
+    private static final Headers HEADERS = new RecordHeaders(asList(
+        new RecordHeader("h1", "headerVal1".getBytes(StandardCharsets.UTF_8)),   // 2 + 10 --> 12 bytes
+        new RecordHeader("h2", "headerVal2".getBytes(StandardCharsets.UTF_8))    // 2 + 10 --> 12 bytes
+    ));
+    private static final int HEADERS_BYTES = 24;
+
+    private static final int RECORD_METADATA_BYTES =
+        8 + // timestamp
+        8 + // offset
+        4;  // partition
+
+    // 57 bytes
+    private static final long SIZE_IN_BYTES =
+        KEY_BYTES +
+        VALUE_BYTES +
+        TOPIC_BYTES +
+        HEADERS_BYTES +
+        RECORD_METADATA_BYTES;
+
+    private static final long TOMBSTONE_SIZE_IN_BYTES =
+        KEY_BYTES +
+        TOPIC_BYTES +
+        HEADERS_BYTES +
+        RECORD_METADATA_BYTES;
+
     private static final Set<TopicPartition> PARTITIONS = mkSet(
-        new TopicPartition("topic", 1),
-        new TopicPartition("topic", 2)
+        new TopicPartition(TOPIC, 1),
+        new TopicPartition(TOPIC, 2)
     );
 
     @Test
@@ -121,5 +169,54 @@ public class ClientUtilsTest {
         assertThrows(StreamsException.class, () -> fetchEndOffsets(PARTITIONS, adminClient));
         verify(adminClient);
     }
+
+    @Test
+    public void shouldComputeSizeInBytesForConsumerRecord() {
+        final ConsumerRecord<byte[], byte[]> consumerRecord = consumerRecord(VALUE, VALUE_BYTES);
+
+        assertThat(consumerRecordSizeInBytes(consumerRecord), equalTo(SIZE_IN_BYTES));
+    }
+
+    @Test
+    public void shouldComputeSizeInBytesForProducerRecord() {
+        final ProducerRecord<byte[], byte[]> producerRecord = producerRecord(VALUE);
+
+        assertThat(producerRecordSizeInBytes(producerRecord), equalTo(SIZE_IN_BYTES));
+    }
+
+    @Test
+    public void shouldComputeSizeInBytesForConsumerRecordWithNullValue() {
+        final ConsumerRecord<byte[], byte[]> tombstone = consumerRecord(null, 0);
+
+        assertThat(consumerRecordSizeInBytes(tombstone), equalTo(TOMBSTONE_SIZE_IN_BYTES));
+    }
 
+    @Test
+    public void shouldComputeSizeInBytesForProducerRecordWithNullValue() {
+        final ProducerRecord<byte[], byte[]> tombstone = producerRecord(null);
+
+        assertThat(producerRecordSizeInBytes(tombstone), equalTo(TOMBSTONE_SIZE_IN_BYTES));
+    }
+
+    // Builds a consumer record on partition 1 of TOPIC, offset 0, timestamp 0, carrying KEY and HEADERS.
+    private static ConsumerRecord<byte[], byte[]> consumerRecord(final byte[] value, final int serializedValueSize) {
+        return new ConsumerRecord<>(
+            TOPIC,
+            1,
+            0L,
+            0L,
+            TimestampType.CREATE_TIME,
+            KEY_BYTES,
+            serializedValueSize,
+            KEY,
+            value,
+            HEADERS,
+            Optional.empty()
+        );
+    }
+
+    // Builds a producer record for partition 1 of TOPIC with timestamp 0, carrying KEY and HEADERS.
+    private static ProducerRecord<byte[], byte[]> producerRecord(final byte[] value) {
+        return new ProducerRecord<>(TOPIC, 1, 0L, KEY, value, HEADERS);
+    }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java
index 389d0d58c8..012373607e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java
@@ -19,7 +19,6 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.common.header.Header;
 import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.Sensor;
@@ -53,6 +52,8 @@ import java.util.Optional;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.processor.internals.ClientUtils.consumerRecordSizeInBytes;
+
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.equalTo;
@@ -846,19 +847,6 @@ public class PartitionGroupTest {
     private long getBytesBufferedForRawRecords(final List<ConsumerRecord<byte[], byte[]>> rawRecords) {
-        long rawRecordsSizeInBytes = 0L;
-        for (final ConsumerRecord<byte[], byte[]> rawRecord : rawRecords) {
-            long headerSizeInBytes = 0L;
-
-            for (final Header header: rawRecord.headers().toArray()) {
-                headerSizeInBytes += header.key().getBytes().length + header.value().length;
-            }
-
-            rawRecordsSizeInBytes += rawRecord.serializedKeySize() +
-                    rawRecord.serializedValueSize() +
-                    8L + // timestamp
-                    8L + // offset
-                    rawRecord.topic().getBytes().length +
-                    4L + // partition
-                    headerSizeInBytes;
-        }
-        return rawRecordsSizeInBytes;
+        // total of the per-record sizes computed by ClientUtils.consumerRecordSizeInBytes; 0 for an empty list
+        return rawRecords.stream()
+            .mapToLong(rawRecord -> consumerRecordSizeInBytes(rawRecord))
+            .sum();
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextImplTest.java
index b6cc7a789e..3e1832048d 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextImplTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextImplTest.java
@@ -397,15 +397,17 @@ public class ProcessorContextImplTest {
     @Test
     public void shouldNotSendRecordHeadersToChangelogTopic() {
         recordCollector.send(
-                CHANGELOG_PARTITION.topic(),
-                KEY_BYTES,
-                VALUE_BYTES,
-                null,
-                CHANGELOG_PARTITION.partition(),
-                TIMESTAMP,
-                BYTES_KEY_SERIALIZER,
-                BYTEARRAY_VALUE_SERIALIZER
-        );
+            CHANGELOG_PARTITION.topic(),
+            KEY_BYTES,
+            VALUE_BYTES,
+            null,
+            CHANGELOG_PARTITION.partition(),
+            TIMESTAMP,
+            BYTES_KEY_SERIALIZER,
+            BYTEARRAY_VALUE_SERIALIZER,
+            null,
+            null);
+
         final StreamTask task = EasyMock.createNiceMock(StreamTask.class);
 
         replay(recordCollector, task);
@@ -423,15 +425,16 @@ public class ProcessorContextImplTest {
         headers.add(new RecordHeader(ChangelogRecordDeserializationHelper.CHANGELOG_POSITION_HEADER_KEY,
                 PositionSerde.serialize(position).array()));
         recordCollector.send(
-                CHANGELOG_PARTITION.topic(),
-                KEY_BYTES,
-                VALUE_BYTES,
-                headers,
-                CHANGELOG_PARTITION.partition(),
-                TIMESTAMP,
-                BYTES_KEY_SERIALIZER,
-                BYTEARRAY_VALUE_SERIALIZER
-        );
+            CHANGELOG_PARTITION.topic(),
+            KEY_BYTES,
+            VALUE_BYTES,
+            headers,
+            CHANGELOG_PARTITION.partition(),
+            TIMESTAMP,
+            BYTES_KEY_SERIALIZER,
+            BYTEARRAY_VALUE_SERIALIZER,
+            null,
+            null);
 
         final StreamTask task = EasyMock.createNiceMock(StreamTask.class);
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
index 2da9b397e5..b3fa516a3f 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
@@ -53,6 +53,7 @@ import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
+import org.apache.kafka.test.InternalMockProcessorContext;
 import org.apache.kafka.test.MockClientSupplier;
 
 import java.util.UUID;
@@ -69,6 +70,9 @@ import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.streams.processor.internals.ClientUtils.producerRecordSizeInBytes;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOPIC_LEVEL_GROUP;
+
 import static org.easymock.EasyMock.expect;
 import static org.easymock.EasyMock.expectLastCall;
 import static org.easymock.EasyMock.mock;
@@ -81,6 +85,10 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
+import static java.util.Collections.emptyList;
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.emptySet;
+import static java.util.Collections.singletonMap;
 
 public class RecordCollectorTest {
 
@@ -99,6 +107,7 @@ public class RecordCollectorTest {
     ));
 
     private final String topic = "topic";
+    private final String sinkNodeName = "output-node"; // sink node name in the test topology; also the processor-node-id tag of the produced metrics
     private final Cluster cluster = new Cluster(
         "cluster",
         Collections.singletonList(Node.noNode()),
@@ -120,6 +129,8 @@ public class RecordCollectorTest {
 
     private MockProducer<byte[], byte[]> mockProducer;
     private StreamsProducer streamsProducer;
+    private ProcessorTopology topology; // single-sink topology built in setUp and passed to the RecordCollectorImpl under test
+    private final InternalProcessorContext<Void, Void> context = new InternalMockProcessorContext<>(); // context passed to collector.send
 
     private RecordCollectorImpl collector;
 
@@ -137,12 +148,29 @@ public class RecordCollectorTest {
             Time.SYSTEM
         );
         mockProducer = clientSupplier.producers.get(0);
+        final SinkNode<?, ?> sinkNode = new SinkNode<>(
+            sinkNodeName,
+            new StaticTopicNameExtractor<>(topic),
+            stringSerializer,
+            byteArraySerializer,
+            streamPartitioner);
+        topology = new ProcessorTopology(
+            emptyList(),
+            emptyMap(),
+            singletonMap(topic, sinkNode),
+            emptyList(),
+            emptyList(),
+            emptyMap(),
+            emptySet()
+        );
         collector = new RecordCollectorImpl(
             logContext,
             taskId,
             streamsProducer,
             productionExceptionHandler,
-            streamsMetrics);
+            streamsMetrics,
+            topology
+        );
     }
 
     @After
@@ -150,16 +178,73 @@ public class RecordCollectorTest {
         collector.closeClean();
     }
 
+    /**
+     * Verifies that every record sent for the sink node increments the topic-level
+     * records-produced-total metric by one and bytes-produced-total by the size (per
+     * producerRecordSizeInBytes) of the record captured by the mock producer.
+     */
+    @Test
+    public void shouldRecordRecordsAndBytesProduced() {
+        final Headers headers = new RecordHeaders(new Header[]{new RecordHeader("key", "value".getBytes())});
+
+        final String threadId = Thread.currentThread().getName();
+        // use the test-class 'topic' field: the topology registers the sink node under it
+        final Map<String, String> tags =
+            streamsMetrics.topicLevelTagMap(threadId, taskId.toString(), sinkNodeName, topic);
+        final Metric recordsProduced = streamsMetrics.metrics().get(
+            new MetricName("records-produced-total",
+                           TOPIC_LEVEL_GROUP,
+                           "The total number of records produced from this topic",
+                           tags)
+        );
+        final Metric bytesProduced = streamsMetrics.metrics().get(
+            new MetricName("bytes-produced-total",
+                           TOPIC_LEVEL_GROUP,
+                           "The total number of bytes produced from this topic",
+                           tags)
+        );
+
+        double totalRecords = 0D;
+        double totalBytes = 0D;
+
+        // nothing has been sent yet, so both metrics must still read zero
+        assertThat(recordsProduced.metricValue(), equalTo(totalRecords));
+        assertThat(bytesProduced.metricValue(), equalTo(totalBytes));
+
+        // alternate records without headers (partition 0) and with headers (partition 1)
+        // so the byte count is checked for both kinds of record
+        for (int i = 0; i < 5; i++) {
+            final boolean withHeaders = i % 2 == 1;
+            collector.send(
+                topic,
+                "999",
+                "0",
+                withHeaders ? headers : null,
+                withHeaders ? 1 : 0,
+                null,
+                stringSerializer,
+                stringSerializer,
+                sinkNodeName,
+                context
+            );
+            ++totalRecords;
+            // expected size is computed from the record as captured by the mock producer
+            totalBytes += producerRecordSizeInBytes(mockProducer.history().get(i));
+            assertThat(recordsProduced.metricValue(), equalTo(totalRecords));
+            assertThat(bytesProduced.metricValue(), equalTo(totalBytes));
+        }
+    }
+
     @Test
     public void shouldSendToSpecificPartition() {
         final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
 
-        collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer);
-        collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer);
-        collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer);
-        collector.send(topic, "999", "0", headers, 1, null, stringSerializer, stringSerializer);
-        collector.send(topic, "999", "0", headers, 1, null, stringSerializer, stringSerializer);
-        collector.send(topic, "999", "0", headers, 2, null, stringSerializer, stringSerializer);
+        collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "999", "0", headers, 1, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "999", "0", headers, 1, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "999", "0", headers, 2, null, stringSerializer, stringSerializer, null, null);
 
         Map<TopicPartition, Long> offsets = collector.offsets();
 
@@ -168,9 +253,9 @@ public class RecordCollectorTest {
         assertEquals(0L, (long) offsets.get(new TopicPartition(topic, 2)));
         assertEquals(6, mockProducer.history().size());
 
-        collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer);
-        collector.send(topic, "999", "0", null, 1, null, stringSerializer, stringSerializer);
-        collector.send(topic, "999", "0", headers, 2, null, stringSerializer, stringSerializer);
+        collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "999", "0", null, 1, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "999", "0", headers, 2, null, stringSerializer, stringSerializer, null, null);
 
         offsets = collector.offsets();
 
@@ -184,15 +269,15 @@ public class RecordCollectorTest {
     public void shouldSendWithPartitioner() {
         final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
 
-        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
-        collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
-        collector.send(topic, "27", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
-        collector.send(topic, "81", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
-        collector.send(topic, "243", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
-        collector.send(topic, "28", "0", headers, null, stringSerializer, stringSerializer, streamPartitioner);
-        collector.send(topic, "82", "0", headers, null, stringSerializer, stringSerializer, streamPartitioner);
-        collector.send(topic, "244", "0", headers, null, stringSerializer, stringSerializer, streamPartitioner);
-        collector.send(topic, "245", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
+        collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
+        collector.send(topic, "27", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
+        collector.send(topic, "81", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
+        collector.send(topic, "243", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
+        collector.send(topic, "28", "0", headers, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
+        collector.send(topic, "82", "0", headers, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
+        collector.send(topic, "244", "0", headers, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
+        collector.send(topic, "245", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
 
         final Map<TopicPartition, Long> offsets = collector.offsets();
 
@@ -210,15 +295,15 @@ public class RecordCollectorTest {
     public void shouldSendWithNoPartition() {
         final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())});
 
-        collector.send(topic, "3", "0", headers, null, null, stringSerializer, stringSerializer);
-        collector.send(topic, "9", "0", headers, null, null, stringSerializer, stringSerializer);
-        collector.send(topic, "27", "0", headers, null, null, stringSerializer, stringSerializer);
-        collector.send(topic, "81", "0", headers, null, null, stringSerializer, stringSerializer);
-        collector.send(topic, "243", "0", headers, null, null, stringSerializer, stringSerializer);
-        collector.send(topic, "28", "0", headers, null, null, stringSerializer, stringSerializer);
-        collector.send(topic, "82", "0", headers, null, null, stringSerializer, stringSerializer);
-        collector.send(topic, "244", "0", headers, null, null, stringSerializer, stringSerializer);
-        collector.send(topic, "245", "0", headers, null, null, stringSerializer, stringSerializer);
+        collector.send(topic, "3", "0", headers, null, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "9", "0", headers, null, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "27", "0", headers, null, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "81", "0", headers, null, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "243", "0", headers, null, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "28", "0", headers, null, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "82", "0", headers, null, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "244", "0", headers, null, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "245", "0", headers, null, null, stringSerializer, stringSerializer, null, null);
 
         final Map<TopicPartition, Long> offsets = collector.offsets();
 
@@ -233,9 +318,9 @@ public class RecordCollectorTest {
     public void shouldUpdateOffsetsUponCompletion() {
         Map<TopicPartition, Long> offsets = collector.offsets();
 
-        collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer);
-        collector.send(topic, "999", "0", null, 1, null, stringSerializer, stringSerializer);
-        collector.send(topic, "999", "0", null, 2, null, stringSerializer, stringSerializer);
+        collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "999", "0", null, 1, null, stringSerializer, stringSerializer, null, null);
+        collector.send(topic, "999", "0", null, 2, null, stringSerializer, stringSerializer, null, null);
 
         assertEquals(Collections.<TopicPartition, Long>emptyMap(), offsets);
 
@@ -253,7 +338,7 @@ public class RecordCollectorTest {
         final CustomStringSerializer valueSerializer = new CustomStringSerializer();
         keySerializer.configure(Collections.emptyMap(), true);
 
-        collector.send(topic, "3", "0", new RecordHeaders(), null, keySerializer, valueSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", new RecordHeaders(), null, keySerializer, valueSerializer, null, null, streamPartitioner);
 
         final List<ProducerRecord<byte[], byte[]>> recordHistory = mockProducer.history();
         for (final ProducerRecord<byte[], byte[]> sentRecord : recordHistory) {
@@ -270,14 +355,19 @@ public class RecordCollectorTest {
         expect(streamsProducer.eosEnabled()).andReturn(false);
         streamsProducer.flush();
         expectLastCall();
-        replay(streamsProducer);
+
+        final ProcessorTopology topology = mock(ProcessorTopology.class);
+        expect(topology.sinkTopics()).andStubReturn(Collections.emptySet());
+        replay(streamsProducer, topology);
 
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
             streamsProducer,
             productionExceptionHandler,
-            streamsMetrics);
+            streamsMetrics, 
+            topology
+        );
 
         collector.flush();
 
@@ -290,14 +380,18 @@ public class RecordCollectorTest {
         expect(streamsProducer.eosEnabled()).andReturn(true);
         streamsProducer.flush();
         expectLastCall();
-        replay(streamsProducer);
-
+        final ProcessorTopology topology = mock(ProcessorTopology.class);
+        expect(topology.sinkTopics()).andStubReturn(Collections.emptySet());
+        replay(streamsProducer, topology);
+        
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
             streamsProducer,
             productionExceptionHandler,
-            streamsMetrics);
+            streamsMetrics,
+            topology
+        );
 
         collector.flush();
 
@@ -308,15 +402,20 @@ public class RecordCollectorTest {
     public void shouldNotAbortTxOnCloseCleanIfEosEnabled() {
         final StreamsProducer streamsProducer = mock(StreamsProducer.class);
         expect(streamsProducer.eosEnabled()).andReturn(true);
-        replay(streamsProducer);
-
+        
+        final ProcessorTopology topology = mock(ProcessorTopology.class);
+        expect(topology.sinkTopics()).andStubReturn(Collections.emptySet());
+        replay(streamsProducer, topology);
+        
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
             streamsProducer,
             productionExceptionHandler,
-            streamsMetrics);
-
+            streamsMetrics,
+            topology
+        );
+       
         collector.closeClean();
 
         verify(streamsProducer);
@@ -327,14 +426,19 @@ public class RecordCollectorTest {
         final StreamsProducer streamsProducer = mock(StreamsProducer.class);
         expect(streamsProducer.eosEnabled()).andReturn(true);
         streamsProducer.abortTransaction();
-        replay(streamsProducer);
-
+        
+        final ProcessorTopology topology = mock(ProcessorTopology.class);
+        expect(topology.sinkTopics()).andStubReturn(Collections.emptySet());
+        replay(streamsProducer, topology);
+        
         final RecordCollector collector = new RecordCollectorImpl(
             logContext,
             taskId,
             streamsProducer,
             productionExceptionHandler,
-            streamsMetrics);
+            streamsMetrics,
+            topology
+        );
 
         collector.closeDirty();
 
@@ -354,7 +458,7 @@ public class RecordCollectorTest {
                 0,
                 0L,
                 (Serializer) new LongSerializer(), // need to add cast to trigger `ClassCastException`
-                new StringSerializer())
+                new StringSerializer(), null, null)
         );
 
         assertThat(expected.getCause(), instanceOf(ClassCastException.class));
@@ -382,7 +486,7 @@ public class RecordCollectorTest {
                 0,
                 0L,
                 (Serializer) new LongSerializer(), // need to add cast to trigger `ClassCastException`
-                new StringSerializer())
+                new StringSerializer(), null, null)
         );
 
         assertThat(expected.getCause(), instanceOf(ClassCastException.class));
@@ -410,7 +514,7 @@ public class RecordCollectorTest {
                 0,
                 0L,
                 new StringSerializer(),
-                (Serializer) new LongSerializer()) // need to add cast to trigger `ClassCastException`
+                (Serializer) new LongSerializer(), null, null) // need to add cast to trigger `ClassCastException`
         );
 
         assertThat(expected.getCause(), instanceOf(ClassCastException.class));
@@ -438,7 +542,7 @@ public class RecordCollectorTest {
                 0,
                 0L,
                 new StringSerializer(),
-                (Serializer) new LongSerializer()) // need to add cast to trigger `ClassCastException`
+                (Serializer) new LongSerializer(), null, null) // need to add cast to trigger `ClassCastException`
         );
 
         assertThat(expected.getCause(), instanceOf(ClassCastException.class));
@@ -460,13 +564,14 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamProducerOnPartitionsFor(new KafkaException("Kaboom!")),
             productionExceptionHandler,
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
         collector.initialize();
 
         final StreamsException exception = assertThrows(
             StreamsException.class,
-            () -> collector.send(topic, "0", "0", null, null, stringSerializer, stringSerializer, streamPartitioner)
+            () -> collector.send(topic, "0", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner)
         );
         assertThat(
             exception.getMessage(),
@@ -491,13 +596,14 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamProducerOnPartitionsFor(runtimeException),
             productionExceptionHandler,
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
         collector.initialize();
 
         final RuntimeException exception = assertThrows(
             runtimeException.getClass(),
-            () -> collector.send(topic, "0", "0", null, null, stringSerializer, stringSerializer, streamPartitioner)
+            () -> collector.send(topic, "0", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner)
         );
         assertThat(exception.getMessage(), equalTo("Kaboom!"));
     }
@@ -518,15 +624,16 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
             productionExceptionHandler,
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
         collector.initialize();
 
-        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
 
         final TaskMigratedException thrown = assertThrows(
             TaskMigratedException.class,
-            () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner)
+            () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner)
         );
         assertEquals(exception, thrown.getCause());
     }
@@ -547,11 +654,12 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
             productionExceptionHandler,
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
         collector.initialize();
 
-        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
 
         final TaskMigratedException thrown = assertThrows(TaskMigratedException.class, collector::flush);
         assertEquals(exception, thrown.getCause());
@@ -573,11 +681,12 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
             productionExceptionHandler,
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
         collector.initialize();
 
-        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
 
         final TaskMigratedException thrown = assertThrows(TaskMigratedException.class, collector::closeClean);
         assertEquals(exception, thrown.getCause());
@@ -591,14 +700,15 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
             productionExceptionHandler,
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
 
-        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
-            () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner)
+            () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner)
         );
         assertEquals(exception, thrown.getCause());
         assertThat(
@@ -617,10 +727,11 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
             productionExceptionHandler,
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
 
-        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
 
         final StreamsException thrown = assertThrows(StreamsException.class, collector::flush);
         assertEquals(exception, thrown.getCause());
@@ -640,10 +751,11 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
             productionExceptionHandler,
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
 
-        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
 
         final StreamsException thrown = assertThrows(StreamsException.class, collector::closeClean);
         assertEquals(exception, thrown.getCause());
@@ -663,14 +775,15 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
             new AlwaysContinueProductionExceptionHandler(),
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
 
-        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
-            () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner)
+            () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner)
         );
         assertEquals(exception, thrown.getCause());
         assertThat(
@@ -689,10 +802,11 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
             new AlwaysContinueProductionExceptionHandler(),
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
 
-        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
 
         final StreamsException thrown = assertThrows(StreamsException.class, collector::flush);
         assertEquals(exception, thrown.getCause());
@@ -712,10 +826,11 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamsProducerOnSend(exception),
             new AlwaysContinueProductionExceptionHandler(),
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
 
-        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
 
         final StreamsException thrown = assertThrows(StreamsException.class, collector::closeClean);
         assertEquals(exception, thrown.getCause());
@@ -734,13 +849,14 @@ public class RecordCollectorTest {
             taskId,
             getExceptionalStreamsProducerOnSend(new Exception()),
             new AlwaysContinueProductionExceptionHandler(),
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
 
         try (final LogCaptureAppender logCaptureAppender =
                  LogCaptureAppender.createAndRegister(RecordCollectorImpl.class)) {
 
-            collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+            collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
             collector.flush();
 
             final List<String> messages = logCaptureAppender.getMessages();
@@ -766,7 +882,7 @@ public class RecordCollectorTest {
         ));
         assertEquals(1.0, metric.metricValue());
 
-        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner);
+        collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner);
         collector.flush();
         collector.closeClean();
     }
@@ -797,7 +913,8 @@ public class RecordCollectorTest {
                 Time.SYSTEM
             ),
             productionExceptionHandler,
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
 
         collector.closeDirty();
@@ -829,13 +946,14 @@ public class RecordCollectorTest {
                 Time.SYSTEM
             ),
             productionExceptionHandler,
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
         collector.initialize();
 
         final StreamsException thrown = assertThrows(
             StreamsException.class,
-            () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner)
+            () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, null, null, streamPartitioner)
         );
         assertThat(
             thrown.getMessage(),
@@ -864,7 +982,8 @@ public class RecordCollectorTest {
                 Time.SYSTEM
             ),
             productionExceptionHandler,
-            streamsMetrics
+            streamsMetrics,
+            topology
         );
 
         collector.closeClean();
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
index bea7a05700..9741ba1c17 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
@@ -17,9 +17,12 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.Metric;
+import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.SerializationException;
 import org.apache.kafka.common.header.internals.RecordHeaders;
+import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.record.TimestampType;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.IntegerDeserializer;
@@ -28,12 +31,15 @@ import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
 import org.apache.kafka.streams.errors.LogAndFailExceptionHandler;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.FailOnInvalidTimestamp;
 import org.apache.kafka.streams.processor.LogAndSkipOnInvalidTimestamp;
 import org.apache.kafka.streams.processor.TimestampExtractor;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.StateSerdes;
 import org.apache.kafka.test.InternalMockProcessorContext;
 import org.apache.kafka.test.MockRecordCollector;
@@ -48,6 +54,9 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
 
+import static org.apache.kafka.streams.processor.internals.ClientUtils.consumerRecordSizeInBytes;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOPIC_LEVEL_GROUP;
+
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.equalTo;
@@ -62,10 +71,15 @@ public class RecordQueueTest {
     private final Deserializer<Integer> intDeserializer = new IntegerDeserializer();
     private final TimestampExtractor timestampExtractor = new MockTimestampExtractor();
 
+    private final Metrics metrics = new Metrics();
+    private final StreamsMetricsImpl streamsMetrics =
+        new StreamsMetricsImpl(metrics, "mock", StreamsConfig.METRICS_LATEST, new MockTime());
+
     @SuppressWarnings("rawtypes")
     final InternalMockProcessorContext context = new InternalMockProcessorContext<>(
         StateSerdes.withBuiltinTypes("anyName", Bytes.class, Bytes.class),
-        new MockRecordCollector()
+        new MockRecordCollector(),
+        metrics
     );
     private final MockSourceNode<Integer, Integer> mockSourceNodeWithMetrics
         = new MockSourceNode<>(intDeserializer, intDeserializer);
@@ -98,6 +112,57 @@ public class RecordQueueTest {
         mockSourceNodeWithMetrics.close();
     }
 
+    /**
+     * Verifies that polling records out of the queue advances the topic-level
+     * {@code records-consumed-total} and {@code bytes-consumed-total} metrics,
+     * one record per poll.
+     */
+    @Test
+    public void testConsumedSensor() {
+        // three raw records on partition 1 of "topic" (offsets 1 to 3) sharing the same key and value
+        final List<ConsumerRecord<byte[], byte[]>> records = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 2, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 3, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()));
+        queue.addRawRecords(records);
+
+        // the topic-level metrics are tagged by thread, task, processor node and topic
+        final String topic = "topic";
+        final String processorNodeId = mockSourceNodeWithMetrics.name();
+        final String taskId = context.taskId().toString();
+        final String threadId = Thread.currentThread().getName();
+
+        final Metric recordsConsumed = context.metrics().metrics().get(
+            new MetricName("records-consumed-total",
+                           TOPIC_LEVEL_GROUP,
+                           "The total number of records consumed from this topic",
+                           streamsMetrics.topicLevelTagMap(threadId, taskId, processorNodeId, topic))
+        );
+        final Metric bytesConsumed = context.metrics().metrics().get(
+            new MetricName("bytes-consumed-total",
+                           TOPIC_LEVEL_GROUP,
+                           "The total number of bytes consumed from this topic",
+                           streamsMetrics.topicLevelTagMap(threadId, taskId, processorNodeId, topic))
+        );
+
+        // poll the records one by one (poll times 5, 6, 7); after each poll both totals
+        // must reflect exactly the records handed out so far
+        double expectedRecords = 0D;
+        double expectedBytes = 0D;
+        for (int i = 0; i < records.size(); i++) {
+            final long pollTime = 5L + i;
+            queue.poll(pollTime);
+
+            // each poll hands out exactly one record
+            expectedRecords += 1;
+            expectedBytes += consumerRecordSizeInBytes(records.get(i));
+
+            // the totals are cumulative across polls
+            assertThat(bytesConsumed.metricValue(), equalTo(expectedBytes));
+            assertThat(recordsConsumed.metricValue(), equalTo(expectedRecords));
+        }
+    }
+
     @Test
     public void testTimeTracking() {
         assertTrue(queue.isEmpty());
@@ -118,13 +183,13 @@ public class RecordQueueTest {
         assertEquals(2L, queue.headRecordOffset().longValue());
 
         // poll the first record, now with 1, 3
-        assertEquals(2L, queue.poll().timestamp);
+        assertEquals(2L, queue.poll(0).timestamp);
         assertEquals(2, queue.size());
         assertEquals(1L, queue.headRecordTimestamp());
         assertEquals(1L, queue.headRecordOffset().longValue());
 
         // poll the second record, now with 3
-        assertEquals(1L, queue.poll().timestamp);
+        assertEquals(1L, queue.poll(0).timestamp);
         assertEquals(1, queue.size());
         assertEquals(3L, queue.headRecordTimestamp());
         assertEquals(3L, queue.headRecordOffset().longValue());
@@ -143,21 +208,21 @@ public class RecordQueueTest {
         assertEquals(3L, queue.headRecordOffset().longValue());
 
         // poll the third record, now with 4, 1, 2
-        assertEquals(3L, queue.poll().timestamp);
+        assertEquals(3L, queue.poll(0).timestamp);
         assertEquals(3, queue.size());
         assertEquals(4L, queue.headRecordTimestamp());
         assertEquals(4L, queue.headRecordOffset().longValue());
 
         // poll the rest records
-        assertEquals(4L, queue.poll().timestamp);
+        assertEquals(4L, queue.poll(0).timestamp);
         assertEquals(1L, queue.headRecordTimestamp());
         assertEquals(1L, queue.headRecordOffset().longValue());
 
-        assertEquals(1L, queue.poll().timestamp);
+        assertEquals(1L, queue.poll(0).timestamp);
         assertEquals(2L, queue.headRecordTimestamp());
         assertEquals(2L, queue.headRecordOffset().longValue());
 
-        assertEquals(2L, queue.poll().timestamp);
+        assertEquals(2L, queue.poll(0).timestamp);
         assertTrue(queue.isEmpty());
         assertEquals(0, queue.size());
         assertEquals(RecordQueue.UNKNOWN, queue.headRecordTimestamp());
@@ -176,7 +241,7 @@ public class RecordQueueTest {
         assertEquals(4L, queue.headRecordOffset().longValue());
 
         // poll one record again, the timestamp should advance now
-        assertEquals(4L, queue.poll().timestamp);
+        assertEquals(4L, queue.poll(0).timestamp);
         assertEquals(2, queue.size());
         assertEquals(5L, queue.headRecordTimestamp());
         assertEquals(5L, queue.headRecordOffset().longValue());
@@ -218,13 +283,13 @@ public class RecordQueueTest {
         queue.addRawRecords(list1);
         assertThat(queue.partitionTime(), is(RecordQueue.UNKNOWN));
 
-        queue.poll();
+        queue.poll(0);
         assertThat(queue.partitionTime(), is(2L));
 
-        queue.poll();
+        queue.poll(0);
         assertThat(queue.partitionTime(), is(2L));
 
-        queue.poll();
+        queue.poll(0);
         assertThat(queue.partitionTime(), is(3L));
     }
 
@@ -251,13 +316,13 @@ public class RecordQueueTest {
         queue.addRawRecords(list1);
         assertThat(queue.partitionTime(), is(150L));
 
-        queue.poll();
+        queue.poll(0);
         assertThat(queue.partitionTime(), is(200L));
 
         queue.setPartitionTime(500L);
         assertThat(queue.partitionTime(), is(500L));
 
-        queue.poll();
+        queue.poll(0);
         assertThat(queue.partitionTime(), is(500L));
     }
 
@@ -299,7 +364,7 @@ public class RecordQueueTest {
 
         queueThatSkipsDeserializeErrors.addRawRecords(records);
         assertEquals(1, queueThatSkipsDeserializeErrors.size());
-        assertEquals(new CorruptedRecord(record), queueThatSkipsDeserializeErrors.poll());
+        assertEquals(new CorruptedRecord(record), queueThatSkipsDeserializeErrors.poll(0));
     }
 
     @Test
@@ -313,7 +378,7 @@ public class RecordQueueTest {
 
         queueThatSkipsDeserializeErrors.addRawRecords(records);
         assertEquals(1, queueThatSkipsDeserializeErrors.size());
-        assertEquals(new CorruptedRecord(record), queueThatSkipsDeserializeErrors.poll());
+        assertEquals(new CorruptedRecord(record), queueThatSkipsDeserializeErrors.poll(0));
     }
 
     @Test
@@ -394,13 +459,13 @@ public class RecordQueueTest {
         // no (known) timestamp has yet been passed to the timestamp extractor
         assertEquals(RecordQueue.UNKNOWN, timestampExtractor.partitionTime);
 
-        queue.poll();
+        queue.poll(0);
         assertEquals(2L, timestampExtractor.partitionTime);
 
-        queue.poll();
+        queue.poll(0);
         assertEquals(2L, timestampExtractor.partitionTime);
 
-        queue.poll();
+        queue.poll(0);
         assertEquals(3L, timestampExtractor.partitionTime);
 
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/WriteConsistencyVectorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/WriteConsistencyVectorTest.java
index 1ca19d8963..d9a68a81b6 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/WriteConsistencyVectorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/WriteConsistencyVectorTest.java
@@ -115,15 +115,16 @@ public class WriteConsistencyVectorTest {
                 ChangelogRecordDeserializationHelper.CHANGELOG_POSITION_HEADER_KEY,
                 PositionSerde.serialize(position).array()));
         recordCollector.send(
-                CHANGELOG_PARTITION.topic(),
-                KEY_BYTES,
-                VALUE_BYTES,
-                headers,
-                CHANGELOG_PARTITION.partition(),
-                TIMESTAMP,
-                BYTES_KEY_SERIALIZER,
-                BYTEARRAY_VALUE_SERIALIZER
-        );
+            CHANGELOG_PARTITION.topic(),
+            KEY_BYTES,
+            VALUE_BYTES,
+            headers,
+            CHANGELOG_PARTITION.partition(),
+            TIMESTAMP,
+            BYTES_KEY_SERIALIZER,
+            BYTEARRAY_VALUE_SERIALIZER,
+            null,
+            null);
 
         final StreamTask task = EasyMock.createNiceMock(StreamTask.class);
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImplTest.java
index 24cf8c7f1c..b8d3d92e62 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImplTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImplTest.java
@@ -58,6 +58,7 @@ import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetric
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.ROLLUP_VALUE;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.STATE_STORE_LEVEL_GROUP;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.THREAD_LEVEL_GROUP;
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOPIC_LEVEL_GROUP;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOTAL_SUFFIX;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndMaxLatencyToSensor;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCountToSensor;
@@ -99,6 +100,10 @@ public class StreamsMetricsImplTest {
     private final static String THREAD_ID1 = "test-thread-1";
     private final static String TASK_ID1 = "test-task-1";
     private final static String TASK_ID2 = "test-task-2";
+    private final static String NODE_ID1 = "test-node-1";
+    private final static String NODE_ID2 = "test-node-2";
+    private final static String TOPIC_ID1 = "test-topic-1";
+    private final static String TOPIC_ID2 = "test-topic-2";
     private final static String METRIC_NAME1 = "test-metric1";
     private final static String METRIC_NAME2 = "test-metric2";
     private final static String THREAD_ID_TAG = "thread-id";
@@ -319,6 +324,46 @@ public class StreamsMetricsImplTest {
         assertThat(actualSensor, is(equalToObject(sensor)));
     }
 
+    @Test
+    public void shouldGetNewTopicLevelSensor() {
+        // the setup helper presumably primes the mocked registry for the sensor-creation path
+        final Metrics metrics = mock(Metrics.class);
+        final RecordingLevel level = RecordingLevel.INFO;
+        setupGetNewSensorTest(metrics, level);
+        final StreamsMetricsImpl streamsMetricsUnderTest =
+            new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time);
+
+        // a topic-level sensor is addressed by thread, task, processor node and topic
+        final Sensor created = streamsMetricsUnderTest.topicLevelSensor(
+            THREAD_ID1, TASK_ID1, NODE_ID1, TOPIC_ID1,
+            SENSOR_NAME_1,
+            level
+        );
+        // verify the interactions with the mocked registry
+        verify(metrics);
+        assertThat(created, is(equalToObject(sensor)));
+    }
+
+    @Test
+    public void shouldGetExistingTopicLevelSensor() {
+        // the setup helper presumably primes the mocked registry so the sensor already exists
+        final Metrics metrics = mock(Metrics.class);
+        final RecordingLevel level = RecordingLevel.INFO;
+        setupGetExistingSensorTest(metrics);
+        final StreamsMetricsImpl streamsMetricsUnderTest =
+            new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time);
+
+        // a topic-level sensor is addressed by thread, task, processor node and topic
+        final Sensor fetched = streamsMetricsUnderTest.topicLevelSensor(
+            THREAD_ID1, TASK_ID1, NODE_ID1, TOPIC_ID1,
+            SENSOR_NAME_1,
+            level
+        );
+        // verify the interactions with the mocked registry
+        verify(metrics);
+        assertThat(fetched, is(equalToObject(sensor)));
+    }
+
     @Test
     public void shouldGetNewStoreLevelSensorIfNoneExists() {
         final Metrics metrics = mock(Metrics.class);
@@ -505,14 +550,13 @@ public class StreamsMetricsImplTest {
     public void shouldGetNewNodeLevelSensor() {
         final Metrics metrics = mock(Metrics.class);
         final RecordingLevel recordingLevel = RecordingLevel.INFO;
-        final String processorNodeName = "processorNodeName";
         setupGetNewSensorTest(metrics, recordingLevel);
         final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time);
 
         final Sensor actualSensor = streamsMetrics.nodeLevelSensor(
             THREAD_ID1,
             TASK_ID1,
-            processorNodeName,
+            NODE_ID1,
             SENSOR_NAME_1,
             recordingLevel
         );
@@ -525,14 +569,13 @@ public class StreamsMetricsImplTest {
     public void shouldGetExistingNodeLevelSensor() {
         final Metrics metrics = mock(Metrics.class);
         final RecordingLevel recordingLevel = RecordingLevel.INFO;
-        final String processorNodeName = "processorNodeName";
         setupGetExistingSensorTest(metrics);
         final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time);
 
         final Sensor actualSensor = streamsMetrics.nodeLevelSensor(
             THREAD_ID1,
             TASK_ID1,
-            processorNodeName,
+            NODE_ID1,
             SENSOR_NAME_1,
             recordingLevel
         );
@@ -732,6 +775,9 @@ public class StreamsMetricsImplTest {
         final String processorNodeName = "processorNodeName";
         final Map<String, String> nodeTags = mkMap(mkEntry("nkey", "value"));
 
+        final String topicName = "topicName";
+        final Map<String, String> topicTags = mkMap(mkEntry("tkey", "value"));
+
         final Sensor parent1 = metrics.taskLevelSensor(THREAD_ID1, taskName, operation, RecordingLevel.DEBUG);
         addAvgAndMaxLatencyToSensor(parent1, PROCESSOR_NODE_LEVEL_GROUP, taskTags, operation);
         addInvocationRateAndCountToSensor(parent1, PROCESSOR_NODE_LEVEL_GROUP, taskTags, operation, "", "");
@@ -744,6 +790,18 @@ public class StreamsMetricsImplTest {
 
         assertThat(registry.metrics().size(), greaterThan(numberOfTaskMetrics));
 
+        final int numberOfNodeMetrics = registry.metrics().size();
+
+        final Sensor child1 = metrics.topicLevelSensor(THREAD_ID1, taskName, processorNodeName, topicName, operation, RecordingLevel.DEBUG, sensor1);
+        addAvgAndMaxLatencyToSensor(child1, TOPIC_LEVEL_GROUP, topicTags, operation);
+        addInvocationRateAndCountToSensor(child1, TOPIC_LEVEL_GROUP, topicTags, operation, "", "");
+
+        assertThat(registry.metrics().size(), greaterThan(numberOfNodeMetrics));
+
+        metrics.removeAllTopicLevelSensors(THREAD_ID1, taskName, processorNodeName, topicName);
+
+        assertThat(registry.metrics().size(), equalTo(numberOfNodeMetrics));
+
         metrics.removeAllNodeLevelSensors(THREAD_ID1, taskName, processorNodeName);
 
         assertThat(registry.metrics().size(), equalTo(numberOfTaskMetrics));
@@ -1104,6 +1162,22 @@ public class StreamsMetricsImplTest {
         assertThat(metrics.metrics().size(), equalTo(1 + 1)); // one metric is added automatically in the constructor of Metrics
     }
 
+    @Test
+    public void shouldAddTotalCountAndSumMetricsToSensor() {
+        // the first prefix names the cumulative-count metric, the second the cumulative-sum metric
+        final String countMetricNamePrefix = "count";
+        final String sumMetricNamePrefix = "sum";
+        StreamsMetricsImpl.addTotalCountAndSumMetricsToSensor(sensor, group, tags, countMetricNamePrefix, sumMetricNamePrefix, DESCRIPTION1, DESCRIPTION2);
+        final double valueToRecord1 = 18.0;
+        final double valueToRecord2 = 42.0;
+        final double expectedCountMetricValue = 2; // each metric verification records both values once
+        verifyMetric(countMetricNamePrefix + "-total", DESCRIPTION1, valueToRecord1, valueToRecord2, expectedCountMetricValue);
+        // the sum also includes the two values already recorded while verifying the count metric
+        final double expectedSumMetricValue = 2 * valueToRecord1 + 2 * valueToRecord2;
+        verifyMetric(sumMetricNamePrefix + "-total", DESCRIPTION2, valueToRecord1, valueToRecord2, expectedSumMetricValue);
+        assertThat(metrics.metrics().size(), equalTo(2 + 1)); // one metric is added automatically in the constructor of Metrics
+    }
+
     @Test
     public void shouldAddAvgAndTotalMetricsToSensor() {
         StreamsMetricsImpl
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetricsTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetricsTest.java
index cababb1c31..9f697759b7 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetricsTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetricsTest.java
@@ -100,7 +100,7 @@ public class TaskMetricsTest {
         );
 
 
-        final Sensor sensor = TaskMetrics.totalBytesSensor(THREAD_ID, TASK_ID, streamsMetrics);
+        final Sensor sensor = TaskMetrics.totalInputBufferBytesSensor(THREAD_ID, TASK_ID, streamsMetrics);
 
         assertThat(sensor, is(expectedSensor));
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/TopicMetricsTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/TopicMetricsTest.java
new file mode 100644
index 0000000000..c359affe5d
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/TopicMetricsTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.metrics;
+
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.metrics.Sensor.RecordingLevel;
+
+import org.junit.AfterClass;
+import org.junit.Test;
+import org.mockito.MockedStatic;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.function.Supplier;
+
+import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOPIC_LEVEL_GROUP;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests that each {@link TopicMetrics} sensor factory returns the stubbed topic-level sensor and registers its totals on it.
+ */
+public class TopicMetricsTest {
+
+    private static final String THREAD_ID = "test-thread";
+    private static final String TASK_ID = "test-task";
+    private static final String PROCESSOR_NODE_ID = "test-processor";
+    private static final String TOPIC_NAME = "topic";
+
+    private final Map<String, String> tagMap = Collections.singletonMap("hello", "world");
+
+    private final Sensor expectedSensor = mock(Sensor.class);
+    // intercepts the static helpers of StreamsMetricsImpl so the metric registration can be verified; closed in cleanUp()
+    private static final MockedStatic<StreamsMetricsImpl> STREAMS_METRICS_STATIC_MOCK = mockStatic(StreamsMetricsImpl.class);
+    private final StreamsMetricsImpl streamsMetrics = mock(StreamsMetricsImpl.class);
+
+    @AfterClass
+    public static void cleanUp() {
+        STREAMS_METRICS_STATIC_MOCK.close();
+    }
+
+    @Test
+    public void shouldGetRecordsAndBytesConsumedSensor() {
+        final String recordsMetricNamePrefix = "records-consumed";
+        final String bytesMetricNamePrefix = "bytes-consumed";
+        final String descriptionOfRecordsTotal = "The total number of records consumed from this topic";
+        final String descriptionOfBytesTotal = "The total number of bytes consumed from this topic";
+
+        when(streamsMetrics.topicLevelSensor(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, TOPIC_NAME, "consumed", RecordingLevel.INFO))
+            .thenReturn(expectedSensor);
+        when(streamsMetrics.topicLevelTagMap(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, TOPIC_NAME)).thenReturn(tagMap);
+
+        verifySensor(() -> TopicMetrics.consumedSensor(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, TOPIC_NAME, streamsMetrics));
+
+        STREAMS_METRICS_STATIC_MOCK.verify(
+            () -> StreamsMetricsImpl.addTotalCountAndSumMetricsToSensor(
+                expectedSensor,
+                TOPIC_LEVEL_GROUP,
+                tagMap,
+                recordsMetricNamePrefix,
+                bytesMetricNamePrefix,
+                descriptionOfRecordsTotal,
+                descriptionOfBytesTotal
+            )
+        );
+    }
+
+    @Test
+    public void shouldGetRecordsAndBytesProducedSensor() {
+        final String recordsMetricNamePrefix = "records-produced";
+        final String bytesMetricNamePrefix = "bytes-produced";
+        final String descriptionOfRecordsTotal = "The total number of records produced to this topic";
+        final String descriptionOfBytesTotal = "The total number of bytes produced to this topic";
+
+        when(streamsMetrics.topicLevelSensor(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, TOPIC_NAME, "produced", RecordingLevel.INFO))
+            .thenReturn(expectedSensor);
+        when(streamsMetrics.topicLevelTagMap(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, TOPIC_NAME)).thenReturn(tagMap);
+
+        verifySensor(() -> TopicMetrics.producedSensor(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, TOPIC_NAME, streamsMetrics));
+
+        STREAMS_METRICS_STATIC_MOCK.verify(
+            () -> StreamsMetricsImpl.addTotalCountAndSumMetricsToSensor(
+                expectedSensor,
+                TOPIC_LEVEL_GROUP,
+                tagMap,
+                recordsMetricNamePrefix,
+                bytesMetricNamePrefix,
+                descriptionOfRecordsTotal,
+                descriptionOfBytesTotal
+            )
+        );
+    }
+
+    // asserts that the factory under test hands back exactly the stubbed topic-level sensor
+    private void verifySensor(final Supplier<Sensor> sensorSupplier) {
+        final Sensor sensor = sensorSupplier.get();
+        assertThat(sensor, is(expectedSensor));
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
index 6a95ccbd08..1783994260 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
@@ -32,7 +32,9 @@ import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
 import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
 import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
+import org.apache.kafka.streams.processor.internals.ProcessorTopology;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
 import org.apache.kafka.streams.processor.internals.RecordCollectorImpl;
 import org.apache.kafka.streams.processor.internals.StreamsProducer;
@@ -45,6 +47,7 @@ import org.apache.kafka.test.MockTimestampExtractor;
 import org.apache.kafka.test.TestUtils;
 
 import java.io.File;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedList;
@@ -54,6 +57,9 @@ import java.util.Objects;
 import java.util.Properties;
 import java.util.Set;
 
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
 /**
  * A component that provides a {@link #context() ProcessingContext} that can be supplied to a {@link KeyValueStore} so that
  * all entries written to the Kafka topic by the store during {@link KeyValueStore#flush()} are captured for testing purposes.
@@ -199,6 +205,9 @@ public class KeyValueStoreTestDriver<K, V> {
         props.put(StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG, MockRocksDbConfigSetter.class);
         props.put(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, "DEBUG");
 
+        final ProcessorTopology topology = mock(ProcessorTopology.class);
+        when(topology.sinkTopics()).thenReturn(Collections.emptySet());
+
         final LogContext logContext = new LogContext("KeyValueStoreTestDriver ");
         final RecordCollector recordCollector = new RecordCollectorImpl(
             logContext,
@@ -212,7 +221,8 @@ public class KeyValueStoreTestDriver<K, V> {
                 logContext,
                 Time.SYSTEM),
             new DefaultProductionExceptionHandler(),
-            new MockStreamsMetrics(new Metrics())
+            new MockStreamsMetrics(new Metrics()),
+            topology
         ) {
             @Override
             public <K1, V1> void send(final String topic,
@@ -222,11 +232,16 @@ public class KeyValueStoreTestDriver<K, V> {
                                       final Integer partition,
                                       final Long timestamp,
                                       final Serializer<K1> keySerializer,
-                                      final Serializer<V1> valueSerializer) {
+                                      final Serializer<V1> valueSerializer,
+                                      final String processorNodeId,
+                                      final InternalProcessorContext<Void, Void> context) {
                 // for byte arrays we need to wrap it for comparison
 
-                final K keyTest = serdes.keyFrom(keySerializer.serialize(topic, headers, key));
-                final V valueTest = serdes.valueFrom(valueSerializer.serialize(topic, headers, value));
+                final byte[] keyBytes = keySerializer.serialize(topic, headers, key);
+                final byte[] valueBytes = valueSerializer.serialize(topic, headers, value);
+
+                final K keyTest = serdes.keyFrom(keyBytes);
+                final V valueTest = serdes.valueFrom(valueBytes);
 
                 recordFlushed(keyTest, valueTest);
             }
@@ -239,6 +254,8 @@ public class KeyValueStoreTestDriver<K, V> {
                                       final Long timestamp,
                                       final Serializer<K1> keySerializer,
                                       final Serializer<V1> valueSerializer,
+                                      final String processorNodeId,
+                                      final InternalProcessorContext<Void, Void> context,
                                       final StreamPartitioner<? super K1, ? super V1> partitioner) {
                 throw new UnsupportedOperationException();
             }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
index 5b0479d325..a70bf57814 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
@@ -430,7 +430,9 @@ public class StreamThreadStateStoreProviderTest {
                 Time.SYSTEM
             ),
             streamsConfig.defaultProductionExceptionHandler(),
-            new MockStreamsMetrics(metrics));
+            new MockStreamsMetrics(metrics),
+            topology
+        );
         final StreamsMetricsImpl streamsMetrics = new MockStreamsMetrics(metrics);
         final InternalProcessorContext context = new ProcessorContextImpl(
             taskId,
diff --git a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
index 0744995a07..5192a1f678 100644
--- a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
@@ -460,7 +460,9 @@ public class InternalMockProcessorContext<KOut, VOut>
             taskId().partition(),
             timestamp,
             BYTES_KEY_SERIALIZER,
-            BYTEARRAY_VALUE_SERIALIZER);
+            BYTEARRAY_VALUE_SERIALIZER,
+            null,
+            null);
     }
 
     @Override
diff --git a/streams/src/test/java/org/apache/kafka/test/MockRecordCollector.java b/streams/src/test/java/org/apache/kafka/test/MockRecordCollector.java
index 505ee6858a..c99a32c980 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockRecordCollector.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockRecordCollector.java
@@ -21,6 +21,7 @@ import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.streams.processor.StreamPartitioner;
+import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
 
 import java.util.Collections;
@@ -46,13 +47,15 @@ public class MockRecordCollector implements RecordCollector {
                             final Integer partition,
                             final Long timestamp,
                             final Serializer<K> keySerializer,
-                            final Serializer<V> valueSerializer) {
+                            final Serializer<V> valueSerializer,
+                            final String processorNodeId,
+                            final InternalProcessorContext<Void, Void> context) {
         collected.add(new ProducerRecord<>(topic,
-            partition,
-            timestamp,
-            key,
-            value,
-            headers));
+                                           partition,
+                                           timestamp,
+                                           key,
+                                           value,
+                                           headers));
     }
 
     @Override
@@ -63,6 +66,8 @@ public class MockRecordCollector implements RecordCollector {
                             final Long timestamp,
                             final Serializer<K> keySerializer,
                             final Serializer<V> valueSerializer,
+                            final String processorNodeId,
+                            final InternalProcessorContext<Void, Void> context,
                             final StreamPartitioner<? super K, ? super V> partitioner) {
         collected.add(new ProducerRecord<>(topic,
             0, // partition id
diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KStreamSplitTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KStreamSplitTest.scala
index bbcc1b503f..89adbd6b1e 100644
--- a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KStreamSplitTest.scala
+++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KStreamSplitTest.scala
@@ -17,6 +17,7 @@
 package org.apache.kafka.streams.scala.kstream
 
 import org.apache.kafka.streams.kstream.Named
+import org.apache.kafka.streams.KeyValue
 import org.apache.kafka.streams.scala.ImplicitConversions._
 import org.apache.kafka.streams.scala.StreamsBuilder
 import org.apache.kafka.streams.scala.serialization.Serdes._
@@ -35,7 +36,7 @@ class KStreamSplitTest extends TestDriver {
     val sinkTopic = Array("default", "even", "three");
 
     val m = builder
-      .stream[Integer, Integer](sourceTopic)
+      .stream[Int, Int](sourceTopic)
       .split(Named.as("_"))
       .branch((_, v) => v % 2 == 0)
       .branch((_, v) => v % 3 == 0)
@@ -46,14 +47,17 @@ class KStreamSplitTest extends TestDriver {
     m("_2").to(sinkTopic(2))
 
     val testDriver = createTestDriver(builder)
-    val testInput = testDriver.createInput[Integer, Integer](sourceTopic)
-    val testOutput = sinkTopic.map(name => testDriver.createOutput[Integer, Integer](name))
-
-    testInput.pipeValueList(
-      List(1, 2, 3, 4, 5)
-        .map(Integer.valueOf)
-        .asJava
-    )
+    val testInput = testDriver.createInput[Int, Int](sourceTopic)
+    val testOutput = sinkTopic.map(name => testDriver.createOutput[Int, Int](name))
+
+    testInput pipeKeyValueList List(
+      new KeyValue(1, 1),
+      new KeyValue(1, 2),
+      new KeyValue(1, 3),
+      new KeyValue(1, 4),
+      new KeyValue(1, 5)
+    ).asJava
+
     assertEquals(List(1, 5), testOutput(0).readValuesToList().asScala)
     assertEquals(List(2, 4), testOutput(1).readValuesToList().asScala)
     assertEquals(List(3), testOutput(2).readValuesToList().asScala)
@@ -67,7 +71,7 @@ class KStreamSplitTest extends TestDriver {
     val sourceTopic = "source"
 
     val m = builder
-      .stream[Integer, Integer](sourceTopic)
+      .stream[Int, Int](sourceTopic)
       .split(Named.as("_"))
       .branch((_, v) => v % 2 == 0, Branched.withConsumer(ks => ks.to("even"), "consumedEvens"))
       .branch((_, v) => v % 3 == 0, Branched.withFunction(ks => ks.mapValues(x => x * x), "mapped"))
@@ -76,15 +80,18 @@ class KStreamSplitTest extends TestDriver {
     m("_mapped").to("mapped")
 
     val testDriver = createTestDriver(builder)
-    val testInput = testDriver.createInput[Integer, Integer](sourceTopic)
-    testInput.pipeValueList(
-      List(1, 2, 3, 4, 5, 9)
-        .map(Integer.valueOf)
-        .asJava
-    )
-
-    val even = testDriver.createOutput[Integer, Integer]("even")
-    val mapped = testDriver.createOutput[Integer, Integer]("mapped")
+    val testInput = testDriver.createInput[Int, Int](sourceTopic)
+    testInput pipeKeyValueList List(
+      new KeyValue(1, 1),
+      new KeyValue(1, 2),
+      new KeyValue(1, 3),
+      new KeyValue(1, 4),
+      new KeyValue(1, 5),
+      new KeyValue(1, 9)
+    ).asJava
+
+    val even = testDriver.createOutput[Int, Int]("even")
+    val mapped = testDriver.createOutput[Int, Int]("mapped")
 
     assertEquals(List(2, 4), even.readValuesToList().asScala)
     assertEquals(List(9, 81), mapped.readValuesToList().asScala)
@@ -98,7 +105,7 @@ class KStreamSplitTest extends TestDriver {
     val sourceTopic = "source"
 
     val m = builder
-      .stream[Integer, Integer](sourceTopic)
+      .stream[Int, Int](sourceTopic)
       .split(Named.as("_"))
       .branch((_, v) => v % 2 == 0, Branched.withConsumer(ks => ks.to("even")))
       .branch((_, v) => v % 3 == 0, Branched.withFunction(ks => ks.mapValues(x => x * x)))
@@ -107,19 +114,23 @@ class KStreamSplitTest extends TestDriver {
     m("_2").to("mapped")
 
     val testDriver = createTestDriver(builder)
-    val testInput = testDriver.createInput[Integer, Integer](sourceTopic)
-    testInput.pipeValueList(
-      List(1, 2, 3, 4, 5, 9)
-        .map(Integer.valueOf)
-        .asJava
-    )
-
-    val even = testDriver.createOutput[Integer, Integer]("even")
-    val mapped = testDriver.createOutput[Integer, Integer]("mapped")
+    val testInput = testDriver.createInput[Int, Int](sourceTopic)
+    testInput pipeKeyValueList List(
+      new KeyValue(1, 1),
+      new KeyValue(1, 2),
+      new KeyValue(1, 3),
+      new KeyValue(1, 4),
+      new KeyValue(1, 5),
+      new KeyValue(1, 9)
+    ).asJava
+
+    val even = testDriver.createOutput[Int, Int]("even")
+    val mapped = testDriver.createOutput[Int, Int]("mapped")
 
     assertEquals(List(2, 4), even.readValuesToList().asScala)
     assertEquals(List(9, 81), mapped.readValuesToList().asScala)
 
     testDriver.close()
   }
+
 }
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
index 49ebf732e9..c918a491ae 100644
--- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
@@ -498,7 +498,8 @@ public class TopologyTestDriver implements Closeable {
                 TASK_ID,
                 testDriverProducer,
                 streamsConfig.defaultProductionExceptionHandler(),
-                streamsMetrics
+                streamsMetrics,
+                processorTopology
             );
 
             final InternalProcessorContext context = new ProcessorContextImpl(
diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/TestTopicsTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/TestTopicsTest.java
index 766729fffa..f39b2b554a 100644
--- a/streams/test-utils/src/test/java/org/apache/kafka/streams/TestTopicsTest.java
+++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/TestTopicsTest.java
@@ -43,6 +43,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
 import java.util.Properties;
+import java.util.stream.Collectors;
 
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.equalTo;
@@ -89,12 +90,12 @@ public class TestTopicsTest {
 
     @Test
     public void testValue() {
-        final TestInputTopic<String, String> inputTopic =
-            testDriver.createInputTopic(INPUT_TOPIC, stringSerde.serializer(), stringSerde.serializer());
+        final TestInputTopic<Long, String> inputTopic =
+            testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer());
         final TestOutputTopic<String, String> outputTopic =
             testDriver.createOutputTopic(OUTPUT_TOPIC, stringSerde.deserializer(), stringSerde.deserializer());
-        //Feed word "Hello" to inputTopic and no kafka key, timestamp is irrelevant in this case
-        inputTopic.pipeInput("Hello");
+        //Feed word "Hello" to inputTopic, timestamp and key irrelevant in this case
+        inputTopic.pipeInput(1L, "Hello");
         assertThat(outputTopic.readValue(), equalTo("Hello"));
         //No more output in topic
         assertThat(outputTopic.isEmpty(), is(true));
@@ -102,16 +103,20 @@ public class TestTopicsTest {
 
     @Test
     public void testValueList() {
-        final TestInputTopic<String, String> inputTopic =
-            testDriver.createInputTopic(INPUT_TOPIC, stringSerde.serializer(), stringSerde.serializer());
+        final TestInputTopic<Long, String> inputTopic =
+            testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer());
         final TestOutputTopic<String, String> outputTopic =
             testDriver.createOutputTopic(OUTPUT_TOPIC, stringSerde.deserializer(), stringSerde.deserializer());
-        final List<String> inputList = Arrays.asList("This", "is", "an", "example");
-        //Feed list of words to inputTopic and no kafka key, timestamp is irrelevant in this case
-        inputTopic.pipeValueList(inputList);
+        final List<KeyValue<Long, String>> inputList = Arrays.asList(
+            new KeyValue<>(1L, "This"),
+            new KeyValue<>(2L, "is"),
+            new KeyValue<>(3L, "an"),
+            new KeyValue<>(4L, "example"));
+        //Feed list of words to inputTopic; keys must be non-null (see testPipeInputWithNullKey), but their values and the timestamps are irrelevant in this case
+        inputTopic.pipeKeyValueList(inputList);
         final List<String> output = outputTopic.readValuesToList();
         assertThat(output, hasItems("This", "is", "an", "example"));
-        assertThat(output, is(equalTo(inputList)));
+        assertThat(output, is(equalTo(inputList.stream().map(kv -> kv.value).collect(Collectors.toList())))); // only the piped values are compared; keys are ignored
     }
 
     @Test
@@ -166,15 +171,16 @@ public class TestTopicsTest {
     }
 
     @Test
-    public void testKeyValuesToMapWithNull() {
+    public void testPipeInputWithNullKey() {
         final TestInputTopic<Long, String> inputTopic =
             testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer());
         final TestOutputTopic<Long, String> outputTopic =
             testDriver.createOutputTopic(OUTPUT_TOPIC, longSerde.deserializer(), stringSerde.deserializer());
-        inputTopic.pipeInput("value");
-        assertThrows(IllegalStateException.class, outputTopic::readKeyValuesToMap);
-    }
+        final StreamsException exception = assertThrows(StreamsException.class, () -> inputTopic.pipeInput("value")); // value-only pipeInput sends a null key, which makes piping fail
+        assertThat(exception.getCause() instanceof NullPointerException, is(true)); // the underlying failure is an NPE wrapped as the cause
+        assertThat(outputTopic.readKeyValuesToMap().isEmpty(), is(true)); // the rejected record must not reach the output topic

+    }
 
     @Test
     public void testKeyValueListDuration() {
@@ -229,8 +235,8 @@ public class TestTopicsTest {
             testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer());
         final TestOutputTopic<Long, String> outputTopic =
             testDriver.createOutputTopic(OUTPUT_TOPIC, longSerde.deserializer(), stringSerde.deserializer());
-        inputTopic.pipeInput(null, "Hello", baseTime);
-        assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(null, "Hello", null, baseTime))));
+        inputTopic.pipeInput(1L, "Hello", baseTime);
+        assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(1L, "Hello", null, baseTime))));
 
         inputTopic.pipeInput(2L, "Kafka", ++baseTime);
         assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(2L, "Kafka", null, baseTime))));
@@ -238,13 +244,15 @@ public class TestTopicsTest {
         inputTopic.pipeInput(2L, "Kafka", testBaseTime);
         assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(2L, "Kafka", testBaseTime))));
 
-        final List<String> inputList = Arrays.asList("Advancing", "time");
+        final List<KeyValue<Long, String>> inputList = Arrays.asList(
+            new KeyValue<>(1L, "Advancing"),
+            new KeyValue<>(2L, "time"));
         //Feed list of words to inputTopic and no kafka key, timestamp advancing from testInstant
         final Duration advance = Duration.ofSeconds(15);
         final Instant recordInstant = testBaseTime.plus(Duration.ofDays(1));
-        inputTopic.pipeValueList(inputList, recordInstant, advance);
-        assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(null, "Advancing", recordInstant))));
-        assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(null, "time", null, recordInstant.plus(advance)))));
+        inputTopic.pipeKeyValueList(inputList, recordInstant, advance);
+        assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(1L, "Advancing", recordInstant))));
+        assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(2L, "time", null, recordInstant.plus(advance)))));
     }
 
     @Test
@@ -292,8 +300,8 @@ public class TestTopicsTest {
             testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer(), testBaseTime, advance);
         final TestOutputTopic<Long, String> outputTopic =
             testDriver.createOutputTopic(OUTPUT_TOPIC, longSerde.deserializer(), stringSerde.deserializer());
-        inputTopic.pipeInput("Hello");
-        assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(null, "Hello", testBaseTime))));
+        inputTopic.pipeInput(1L, "Hello");
+        assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(1L, "Hello", testBaseTime))));
         inputTopic.pipeInput(2L, "Kafka");
         assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(2L, "Kafka", testBaseTime.plus(advance)))));
     }
@@ -337,12 +345,12 @@ public class TestTopicsTest {
 
     @Test
     public void testEmptyTopic() {
-        final TestInputTopic<String, String> inputTopic =
-            testDriver.createInputTopic(INPUT_TOPIC, stringSerde.serializer(), stringSerde.serializer());
+        final TestInputTopic<Long, String> inputTopic =
+            testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer());
         final TestOutputTopic<String, String> outputTopic =
             testDriver.createOutputTopic(OUTPUT_TOPIC, stringSerde.deserializer(), stringSerde.deserializer());
         //Feed word "Hello" to inputTopic and no kafka key, timestamp is irrelevant in this case
-        inputTopic.pipeInput("Hello");
+        inputTopic.pipeInput(1L, "Hello");
         assertThat(outputTopic.readValue(), equalTo("Hello"));
         //No more output in topic
         assertThrows(NoSuchElementException.class, outputTopic::readRecord, "Empty topic");