Posted to commits@kafka.apache.org by mj...@apache.org on 2020/02/11 22:00:19 UTC

[kafka] branch 2.5 updated: KAFKA-6607: Commit correct offsets for transactional input data (#8040)

This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch 2.5
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.5 by this push:
     new 4912a8d  KAFKA-6607: Commit correct offsets for transactional input data (#8040)
4912a8d is described below

commit 4912a8d262df4a3ebb26a0d28f36a4c19b439ff8
Author: Matthias J. Sax <ma...@confluent.io>
AuthorDate: Tue Feb 11 13:59:47 2020 -0800

    KAFKA-6607: Commit correct offsets for transactional input data (#8040)
    
    Reviewer: Guozhang Wang <gu...@confluent.io>
---
 .../test/java/org/apache/kafka/test/TestUtils.java |   1 -
 .../processor/internals/PartitionGroup.java        |  41 +++++--
 .../streams/processor/internals/RecordQueue.java   |   6 +-
 .../streams/processor/internals/StreamTask.java    |  21 +++-
 .../streams/processor/internals/StreamThread.java  |  22 ++--
 .../integration/AbstractResetIntegrationTest.java  |  65 +++--------
 .../streams/integration/EosIntegrationTest.java    |  63 +++++++++--
 .../integration/utils/IntegrationTestUtils.java    |  46 +++++++-
 .../processor/internals/PartitionGroupTest.java    | 124 ++++++++++++++++++---
 .../processor/internals/RecordQueueTest.java       |  61 +++++++---
 .../processor/internals/StreamTaskTest.java        |  46 +++++++-
 .../apache/kafka/streams/TopologyTestDriver.java   |  16 ++-
 12 files changed, 390 insertions(+), 122 deletions(-)
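
Context for the change: when the input topic is transactional, transaction commit/abort markers occupy offsets in the input partitions. Committing "last processed offset + 1" can therefore leave the committed offset behind the consumer's actual position, so the committed offsets lag behind the log-end offset even after all data has been processed (KAFKA-6607, verified by the new shouldCommitCorrectOffsetIfInputTopicIsTransactional test below). With this patch, StreamTask commits the offset of the next record still buffered for the partition (the new PartitionGroup#headRecordOffset) and, if the buffer is fully drained, falls back to consumer.position(), which already points past any markers. The sketch below restates that selection logic outside of the Streams internals; apart from Consumer#position and OffsetAndMetadata, the class, method, and parameter names are simplified placeholders, not the actual StreamTask code.

    import java.util.HashMap;
    import java.util.Map;

    import org.apache.kafka.clients.consumer.Consumer;
    import org.apache.kafka.clients.consumer.OffsetAndMetadata;
    import org.apache.kafka.common.KafkaException;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.errors.TimeoutException;
    import org.apache.kafka.streams.errors.StreamsException;

    final class OffsetsToCommitSketch {

        // For each consumed partition, pick the offset to commit: the next buffered
        // record's offset if one exists, otherwise the consumer's local position.
        static Map<TopicPartition, OffsetAndMetadata> offsetsToCommit(
                final Map<TopicPartition, Long> consumedOffsets,   // last processed offset per partition
                final Map<TopicPartition, Long> headRecordOffsets, // next buffered offset, or null if drained
                final Consumer<byte[], byte[]> consumer) {
            final Map<TopicPartition, OffsetAndMetadata> result = new HashMap<>(consumedOffsets.size());
            for (final TopicPartition partition : consumedOffsets.keySet()) {
                Long offset = headRecordOffsets.get(partition);
                if (offset == null) {
                    try {
                        // Buffer drained: the position is already past any transaction markers
                        // and should be available locally, so this call should not block.
                        offset = consumer.position(partition);
                    } catch (final TimeoutException error) {
                        // A timeout here indicates a bug, because the partition was processed
                        // and the consumer must hold a valid local position for it.
                        throw new IllegalStateException(error);
                    } catch (final KafkaException fatal) {
                        throw new StreamsException(fatal);
                    }
                }
                result.put(partition, new OffsetAndMetadata(offset));
            }
            return result;
        }
    }

The real StreamTask additionally encodes the partition time into the offset metadata; that detail is omitted from the sketch.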

diff --git a/clients/src/test/java/org/apache/kafka/test/TestUtils.java b/clients/src/test/java/org/apache/kafka/test/TestUtils.java
index ad2ad99..14eab72 100644
--- a/clients/src/test/java/org/apache/kafka/test/TestUtils.java
+++ b/clients/src/test/java/org/apache/kafka/test/TestUtils.java
@@ -282,7 +282,6 @@ public class TestUtils {
         final Properties properties = new Properties();
         properties.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers);
         properties.put(ProducerConfig.ACKS_CONFIG, "all");
-        properties.put(ProducerConfig.RETRIES_CONFIG, 0);
         properties.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, keySerializer);
         properties.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, valueSerializer);
         properties.putAll(additional);
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
index 9160468..9b77345 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java
@@ -58,15 +58,14 @@ public class PartitionGroup {
     private int totalBuffered;
     private boolean allBuffered;
 
-
-    public static class RecordInfo {
+    static class RecordInfo {
         RecordQueue queue;
 
-        public ProcessorNode node() {
+        ProcessorNode node() {
             return queue.source();
         }
 
-        public TopicPartition partition() {
+        TopicPartition partition() {
             return queue.partition();
         }
 
@@ -84,20 +83,23 @@ public class PartitionGroup {
         streamTime = RecordQueue.UNKNOWN;
     }
 
-    // visible for testing
     long partitionTimestamp(final TopicPartition partition) {
         final RecordQueue queue = partitionQueues.get(partition);
+
         if (queue == null) {
-            throw new NullPointerException("Partition " + partition + " not found.");
+            throw new IllegalStateException("Partition " + partition + " not found.");
         }
+
         return queue.partitionTime();
     }
 
     void setPartitionTime(final TopicPartition partition, final long partitionTime) {
         final RecordQueue queue = partitionQueues.get(partition);
+
         if (queue == null) {
-            throw new NullPointerException("Partition " + partition + " not found.");
+            throw new IllegalStateException("Partition " + partition + " not found.");
         }
+
         if (streamTime < partitionTime) {
             streamTime = partitionTime;
         }
@@ -152,6 +154,10 @@ public class PartitionGroup {
     int addRawRecords(final TopicPartition partition, final Iterable<ConsumerRecord<byte[], byte[]>> rawRecords) {
         final RecordQueue recordQueue = partitionQueues.get(partition);
 
+        if (recordQueue == null) {
+            throw new IllegalStateException("Partition " + partition + " not found.");
+        }
+
         final int oldSize = recordQueue.size();
         final int newSize = recordQueue.addRawRecords(rawRecords);
 
@@ -172,17 +178,27 @@ public class PartitionGroup {
         return newSize;
     }
 
-    public Set<TopicPartition> partitions() {
+    Set<TopicPartition> partitions() {
         return Collections.unmodifiableSet(partitionQueues.keySet());
     }
 
     /**
      * Return the stream-time of this partition group defined as the largest timestamp seen across all partitions
      */
-    public long streamTime() {
+    long streamTime() {
         return streamTime;
     }
 
+    Long headRecordOffset(final TopicPartition partition) {
+        final RecordQueue recordQueue = partitionQueues.get(partition);
+
+        if (recordQueue == null) {
+            throw new IllegalStateException("Partition " + partition + " not found.");
+        }
+
+        return recordQueue.headRecordOffset();
+    }
+
     /**
      * @throws IllegalStateException if the record's partition does not belong to this partition group
      */
@@ -190,7 +206,7 @@ public class PartitionGroup {
         final RecordQueue recordQueue = partitionQueues.get(partition);
 
         if (recordQueue == null) {
-            throw new IllegalStateException(String.format("Record's partition %s does not belong to this partition-group.", partition));
+            throw new IllegalStateException("Partition " + partition + " not found.");
         }
 
         return recordQueue.size();
@@ -204,14 +220,15 @@ public class PartitionGroup {
         return allBuffered;
     }
 
-    public void close() {
+    void close() {
         clear();
         partitionQueues.clear();
     }
 
-    public void clear() {
+    void clear() {
         nonEmptyQueuesByTime.clear();
         streamTime = RecordQueue.UNKNOWN;
+        totalBuffered = 0;
         for (final RecordQueue queue : partitionQueues.values()) {
             queue.clear();
         }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
index 8736b15..e4a2d96 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java
@@ -75,7 +75,7 @@ public class RecordQueue {
         );
         this.log = logContext.logger(RecordQueue.class);
     }
- 
+
     void setPartitionTime(final long partitionTime) {
         this.partitionTime = partitionTime;
     }
@@ -156,6 +156,10 @@ public class RecordQueue {
         return headRecord == null ? UNKNOWN : headRecord.timestamp;
     }
 
+    public Long headRecordOffset() {
+        return headRecord == null ? null : headRecord.offset();
+    }
+
     /**
      * Clear the fifo queue of its elements, also clear the time tracker's kept stamped elements
      */
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index f541a37..5a7790a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -496,9 +496,24 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator
         }
 
         final Map<TopicPartition, OffsetAndMetadata> consumedOffsetsAndMetadata = new HashMap<>(consumedOffsets.size());
+
         for (final Map.Entry<TopicPartition, Long> entry : consumedOffsets.entrySet()) {
             final TopicPartition partition = entry.getKey();
-            final long offset = entry.getValue() + 1;
+            Long offset = partitionGroup.headRecordOffset(partition);
+            if (offset == null) {
+                try {
+                    offset = consumer.position(partition);
+                } catch (final TimeoutException error) {
+                    // the `consumer.position()` call should never block, because we know that we did process data
+                    // for the requested partition and thus the consumer should have a valid local position
+                    // that it can return immediately
+
+                    // hence, a `TimeoutException` indicates a bug and thus we rethrow it as fatal `IllegalStateException`
+                    throw new IllegalStateException(error);
+                } catch (final KafkaException fatal) {
+                    throw new StreamsException(fatal);
+                }
+            }
             final long partitionTime = partitionTimes.get(partition);
             consumedOffsetsAndMetadata.put(partition, new OffsetAndMetadata(offset, encodeTimestamp(partitionTime)));
         }
@@ -621,6 +636,8 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator
             try {
                 commit(false, partitionTimes);
             } finally {
+                partitionGroup.clear();
+
                 if (eosEnabled) {
                     stateMgr.checkpoint(activeTaskCheckpointableOffsets());
 
@@ -677,8 +694,6 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator
     private void closeTopology() {
         log.trace("Closing processor topology");
 
-        partitionGroup.clear();
-
         // close the processors
         // make sure close() is called for each node even when there is a RuntimeException
         RuntimeException exception = null;
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 8bf7ef9..b6b3d83 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -754,6 +754,13 @@ public class StreamThread extends Thread {
             throw new StreamsException(logPrefix + "Unexpected state " + state + " during normal iteration");
         }
 
+        final long pollLatency = advanceNowAndComputeLatency();
+
+        if (records != null && !records.isEmpty()) {
+            pollSensor.record(pollLatency, now);
+            addRecordsToTasks(records);
+        }
+
         // Shutdown hook could potentially be triggered and transit the thread state to PENDING_SHUTDOWN during #pollRequests().
         // The task manager internal states could be uninitialized if the state transition happens during #onPartitionsAssigned().
         // Should only proceed when the thread is still running after #pollRequests(), because no external state mutation
@@ -763,13 +770,6 @@ public class StreamThread extends Thread {
             return;
         }
 
-        final long pollLatency = advanceNowAndComputeLatency();
-
-        if (records != null && !records.isEmpty()) {
-            pollSensor.record(pollLatency, now);
-            addRecordsToTasks(records);
-        }
-
         // only try to initialize the assigned tasks
         // if the state is still in PARTITION_ASSIGNED after the poll call
         if (state == State.PARTITIONS_ASSIGNED) {
@@ -917,6 +917,14 @@ public class StreamThread extends Thread {
             final StreamTask task = taskManager.activeTask(partition);
 
             if (task == null) {
+                if (!isRunning()) {
+                    // if we are in PENDING_SHUTDOWN and don't find the task it implies that it was a newly assigned
+                    // task that we just skipped to create;
+                    // hence, we just skip adding the corresponding records
+                    log.info("State already transits to {}, skipping the add records to non-existing task for partition {}", state, partition);
+                    continue;
+                }
+
                 log.error(
                     "Unable to locate active task for received-record partition {}. Current tasks: {}",
                     partition,
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/AbstractResetIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/AbstractResetIntegrationTest.java
index 675286b..0def694 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/AbstractResetIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/AbstractResetIntegrationTest.java
@@ -19,7 +19,6 @@ package org.apache.kafka.streams.integration;
 import kafka.tools.StreamsResetter;
 import org.apache.kafka.clients.CommonClientConfigs;
 import org.apache.kafka.clients.admin.Admin;
-import org.apache.kafka.clients.admin.ConsumerGroupDescription;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.config.SslConfigs;
@@ -42,7 +41,6 @@ import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.kstream.Produced;
 import org.apache.kafka.streams.kstream.TimeWindows;
 import org.apache.kafka.test.IntegrationTest;
-import org.apache.kafka.test.TestCondition;
 import org.apache.kafka.test.TestUtils;
 import org.junit.AfterClass;
 import org.junit.Assert;
@@ -62,9 +60,9 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
-import java.util.concurrent.ExecutionException;
 
 import static java.time.Duration.ofMillis;
+import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForEmptyConsumerGroup;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 
@@ -170,25 +168,11 @@ public abstract class AbstractResetIntegrationTest {
     private static final long CLEANUP_CONSUMER_TIMEOUT = 2000L;
     private static final int TIMEOUT_MULTIPLIER = 15;
 
-    private class ConsumerGroupInactiveCondition implements TestCondition {
-        @Override
-        public boolean conditionMet() {
-            try {
-                final ConsumerGroupDescription groupDescription = adminClient.describeConsumerGroups(Collections.singletonList(appID)).describedGroups().get(appID).get();
-                return groupDescription.members().isEmpty();
-            } catch (final ExecutionException | InterruptedException e) {
-                return false;
-            }
-        }
-    }
-
     void prepareTest() throws Exception {
         prepareConfigs();
         prepareEnvironment();
 
-        // busy wait until cluster (ie, ConsumerGroupCoordinator) is available
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT,
-                "Test consumer group " + appID + " still active even after waiting " + (TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT);
 
         cluster.deleteAndRecreateTopics(INPUT_TOPIC, OUTPUT_TOPIC, OUTPUT_TOPIC_2, OUTPUT_TOPIC_2_RERUN);
 
@@ -286,15 +270,13 @@ public abstract class AbstractResetIntegrationTest {
         final List<KeyValue<Long, Long>> result = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10);
 
         streams.close();
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT,
-            "Streams Application consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT);
 
         // RESET
         streams = new KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig);
         streams.cleanUp();
         cleanGlobal(false, null, null);
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT,
-                "Reset Tool consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT);
 
         assertInternalTopicsGotDeleted(null);
 
@@ -305,8 +287,7 @@ public abstract class AbstractResetIntegrationTest {
 
         assertThat(resultRerun, equalTo(result));
 
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT,
-                "Reset Tool consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT);
         cleanGlobal(false, null, null);
     }
 
@@ -325,8 +306,7 @@ public abstract class AbstractResetIntegrationTest {
         final List<KeyValue<Long, Long>> result2 = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC_2, 40);
 
         streams.close();
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT,
-            "Streams Application consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT);
 
         // insert bad record to make sure intermediate user topic gets seekToEnd()
         mockTime.sleep(1);
@@ -341,8 +321,7 @@ public abstract class AbstractResetIntegrationTest {
         streams = new KafkaStreams(setupTopologyWithIntermediateUserTopic(OUTPUT_TOPIC_2_RERUN), streamsConfig);
         streams.cleanUp();
         cleanGlobal(true, null, null);
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT,
-                "Reset Tool consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT);
 
         assertInternalTopicsGotDeleted(INTERMEDIATE_USER_TOPIC);
 
@@ -363,8 +342,7 @@ public abstract class AbstractResetIntegrationTest {
         }
         assertThat(resultIntermediate.get(10), equalTo(badMessage));
 
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT,
-                "Reset Tool consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT);
         cleanGlobal(true, null, null);
 
         cluster.deleteTopicAndWait(INTERMEDIATE_USER_TOPIC);
@@ -380,8 +358,7 @@ public abstract class AbstractResetIntegrationTest {
         final List<KeyValue<Long, Long>> result = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10);
 
         streams.close();
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT,
-            "Streams Application consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT);
 
         // RESET
         final File resetFile = File.createTempFile("reset", ".csv");
@@ -393,8 +370,7 @@ public abstract class AbstractResetIntegrationTest {
         streams.cleanUp();
 
         cleanGlobal(false, "--from-file", resetFile.getAbsolutePath());
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT,
-                "Reset Tool consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT);
 
         assertInternalTopicsGotDeleted(null);
 
@@ -408,8 +384,7 @@ public abstract class AbstractResetIntegrationTest {
         result.remove(0);
         assertThat(resultRerun, equalTo(result));
 
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT,
-                "Reset Tool consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT);
         cleanGlobal(false, null, null);
     }
 
@@ -423,8 +398,7 @@ public abstract class AbstractResetIntegrationTest {
         final List<KeyValue<Long, Long>> result = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10);
 
         streams.close();
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT,
-            "Streams Application consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT);
 
         // RESET
         final File resetFile = File.createTempFile("reset", ".csv");
@@ -441,8 +415,7 @@ public abstract class AbstractResetIntegrationTest {
         calendar.add(Calendar.DATE, -1);
 
         cleanGlobal(false, "--to-datetime", format.format(calendar.getTime()));
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT,
-                "Reset Tool consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT);
 
         assertInternalTopicsGotDeleted(null);
 
@@ -455,8 +428,7 @@ public abstract class AbstractResetIntegrationTest {
 
         assertThat(resultRerun, equalTo(result));
 
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT,
-                "Reset Tool consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT);
         cleanGlobal(false, null, null);
     }
 
@@ -470,8 +442,7 @@ public abstract class AbstractResetIntegrationTest {
         final List<KeyValue<Long, Long>> result = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10);
 
         streams.close();
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT,
-            "Streams Application consumer group " + appID + "  did not time out after " + (TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT);
 
         // RESET
         final File resetFile = File.createTempFile("reset", ".csv");
@@ -483,8 +454,7 @@ public abstract class AbstractResetIntegrationTest {
         streams.cleanUp();
         cleanGlobal(false, "--by-duration", "PT1M");
 
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT,
-            "Reset Tool consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT);
 
         assertInternalTopicsGotDeleted(null);
 
@@ -497,8 +467,7 @@ public abstract class AbstractResetIntegrationTest {
 
         assertThat(resultRerun, equalTo(result));
 
-        TestUtils.waitForCondition(new ConsumerGroupInactiveCondition(), TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT,
-            "Reset Tool consumer group " + appID + " did not time out after " + (TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT) + " ms.");
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT);
         cleanGlobal(false, null, null);
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
index 4ff847f..cb06703 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
@@ -16,8 +16,15 @@
  */
 package org.apache.kafka.streams.integration;
 
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.admin.AdminClientConfig;
+import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.IsolationLevel;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.serialization.LongDeserializer;
 import org.apache.kafka.common.serialization.LongSerializer;
 import org.apache.kafka.common.serialization.Serdes;
@@ -50,6 +57,7 @@ import org.junit.experimental.categories.Category;
 
 import java.io.File;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -61,6 +69,9 @@ import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForEmptyConsumerGroup;
 import static org.apache.kafka.test.StreamsTestUtils.startKafkaStreamsAndWaitForRunningState;
 import static org.apache.kafka.test.TestUtils.waitForCondition;
 import static org.hamcrest.CoreMatchers.equalTo;
@@ -116,38 +127,66 @@ public class EosIntegrationTest {
 
     @Test
     public void shouldBeAbleToRunWithEosEnabled() throws Exception {
-        runSimpleCopyTest(1, SINGLE_PARTITION_INPUT_TOPIC, null, SINGLE_PARTITION_OUTPUT_TOPIC);
+        runSimpleCopyTest(1, SINGLE_PARTITION_INPUT_TOPIC, null, SINGLE_PARTITION_OUTPUT_TOPIC, false);
+    }
+
+    @Test
+    public void shouldCommitCorrectOffsetIfInputTopicIsTransactional() throws Exception {
+        runSimpleCopyTest(1, SINGLE_PARTITION_INPUT_TOPIC, null, SINGLE_PARTITION_OUTPUT_TOPIC, true);
+
+        try (final Admin adminClient = Admin.create(mkMap(mkEntry(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers())));
+            final Consumer<byte[], byte[]> consumer = new KafkaConsumer<>(mkMap(
+                mkEntry(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()),
+                mkEntry(ConsumerConfig.GROUP_ID_CONFIG, applicationId),
+                mkEntry(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class),
+                mkEntry(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class)))) {
+
+            waitForEmptyConsumerGroup(adminClient, applicationId, 5 * MAX_POLL_INTERVAL_MS);
+
+            final TopicPartition topicPartition = new TopicPartition(SINGLE_PARTITION_INPUT_TOPIC, 0);
+            final Collection<TopicPartition> topicPartitions = Collections.singleton(topicPartition);
+
+            final long committedOffset = adminClient.listConsumerGroupOffsets(applicationId).partitionsToOffsetAndMetadata().get().get(topicPartition).offset();
+
+            consumer.assign(topicPartitions);
+            final long consumerPosition = consumer.position(topicPartition);
+            final long endOffset = consumer.endOffsets(topicPartitions).get(topicPartition);
+
+            assertThat(committedOffset, equalTo(consumerPosition));
+            assertThat(committedOffset, equalTo(endOffset));
+        }
     }
 
     @Test
     public void shouldBeAbleToRestartAfterClose() throws Exception {
-        runSimpleCopyTest(2, SINGLE_PARTITION_INPUT_TOPIC, null, SINGLE_PARTITION_OUTPUT_TOPIC);
+        runSimpleCopyTest(2, SINGLE_PARTITION_INPUT_TOPIC, null, SINGLE_PARTITION_OUTPUT_TOPIC, false);
     }
 
     @Test
     public void shouldBeAbleToCommitToMultiplePartitions() throws Exception {
-        runSimpleCopyTest(1, SINGLE_PARTITION_INPUT_TOPIC, null, MULTI_PARTITION_OUTPUT_TOPIC);
+        runSimpleCopyTest(1, SINGLE_PARTITION_INPUT_TOPIC, null, MULTI_PARTITION_OUTPUT_TOPIC, false);
     }
 
     @Test
     public void shouldBeAbleToCommitMultiplePartitionOffsets() throws Exception {
-        runSimpleCopyTest(1, MULTI_PARTITION_INPUT_TOPIC, null, SINGLE_PARTITION_OUTPUT_TOPIC);
+        runSimpleCopyTest(1, MULTI_PARTITION_INPUT_TOPIC, null, SINGLE_PARTITION_OUTPUT_TOPIC, false);
     }
 
     @Test
     public void shouldBeAbleToRunWithTwoSubtopologies() throws Exception {
-        runSimpleCopyTest(1, SINGLE_PARTITION_INPUT_TOPIC, SINGLE_PARTITION_THROUGH_TOPIC, SINGLE_PARTITION_OUTPUT_TOPIC);
+        runSimpleCopyTest(1, SINGLE_PARTITION_INPUT_TOPIC, SINGLE_PARTITION_THROUGH_TOPIC, SINGLE_PARTITION_OUTPUT_TOPIC, false);
     }
 
     @Test
     public void shouldBeAbleToRunWithTwoSubtopologiesAndMultiplePartitions() throws Exception {
-        runSimpleCopyTest(1, MULTI_PARTITION_INPUT_TOPIC, MULTI_PARTITION_THROUGH_TOPIC, MULTI_PARTITION_OUTPUT_TOPIC);
+        runSimpleCopyTest(1, MULTI_PARTITION_INPUT_TOPIC, MULTI_PARTITION_THROUGH_TOPIC, MULTI_PARTITION_OUTPUT_TOPIC, false);
     }
 
     private void runSimpleCopyTest(final int numberOfRestarts,
                                    final String inputTopic,
                                    final String throughTopic,
-                                   final String outputTopic) throws Exception {
+                                   final String outputTopic,
+                                   final boolean inputTopicTransactional) throws Exception {
         final StreamsBuilder builder = new StreamsBuilder();
         final KStream<Long, Long> input = builder.stream(inputTopic);
         KStream<Long, Long> output = input;
@@ -177,11 +216,17 @@ public class EosIntegrationTest {
 
                 final List<KeyValue<Long, Long>> inputData = prepareData(i * 100, i * 100 + 10L, 0L, 1L);
 
+                final Properties producerConfigs = new Properties();
+                if (inputTopicTransactional) {
+                    producerConfigs.setProperty(ProducerConfig.TRANSACTIONAL_ID_CONFIG, applicationId + "-input-producer");
+                }
+
                 IntegrationTestUtils.produceKeyValuesSynchronously(
                     inputTopic,
                     inputData,
-                    TestUtils.producerConfig(CLUSTER.bootstrapServers(), LongSerializer.class, LongSerializer.class),
-                    CLUSTER.time
+                    TestUtils.producerConfig(CLUSTER.bootstrapServers(), LongSerializer.class, LongSerializer.class, producerConfigs),
+                    CLUSTER.time,
+                    inputTopicTransactional
                 );
 
                 final List<KeyValue<Long, Long>> committedRecords =
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
index ef508a0..42d203c 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
@@ -16,15 +16,11 @@
  */
 package org.apache.kafka.streams.integration.utils;
 
-import java.lang.reflect.Field;
-import java.util.Map.Entry;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.locks.Condition;
-import java.util.concurrent.locks.Lock;
-import java.util.concurrent.locks.ReentrantLock;
 import kafka.api.Request;
 import kafka.server.KafkaServer;
 import kafka.server.MetadataCache;
+import org.apache.kafka.clients.admin.Admin;
+import org.apache.kafka.clients.admin.ConsumerGroupDescription;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
@@ -55,6 +51,7 @@ import scala.Option;
 
 import java.io.File;
 import java.io.IOException;
+import java.lang.reflect.Field;
 import java.nio.file.Paths;
 import java.time.Duration;
 import java.util.ArrayList;
@@ -65,12 +62,17 @@ import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Properties;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 import java.util.stream.Collectors;
 
 import static org.apache.kafka.test.TestUtils.retryOnExceptionWithTimeout;
@@ -839,6 +841,38 @@ public class IntegrationTestUtils {
         });
     }
 
+    private static class ConsumerGroupInactiveCondition implements TestCondition {
+        private final Admin adminClient;
+        private final String applicationId;
+
+        private ConsumerGroupInactiveCondition(final Admin adminClient,
+                                               final String applicationId) {
+            this.adminClient = adminClient;
+            this.applicationId = applicationId;
+        }
+
+        @Override
+        public boolean conditionMet() {
+            try {
+                final ConsumerGroupDescription groupDescription =
+                    adminClient.describeConsumerGroups(Collections.singletonList(applicationId))
+                        .describedGroups()
+                        .get(applicationId)
+                        .get();
+                return groupDescription.members().isEmpty();
+            } catch (final ExecutionException | InterruptedException e) {
+                return false;
+            }
+        }
+    }
+
+    public static void waitForEmptyConsumerGroup(final Admin adminClient,
+                                                 final String applicationId,
+                                                 final long timeoutMs) throws Exception {
+        TestUtils.waitForCondition(new IntegrationTestUtils.ConsumerGroupInactiveCondition(adminClient, applicationId), timeoutMs,
+            "Test consumer group " + applicationId + " still active even after waiting " + timeoutMs + " ms.");
+    }
+
     private static StateListener getStateListener(final KafkaStreams streams) {
         try {
             final Field field = streams.getClass().getDeclaredField("stateListener");
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java
index 3584f9c..48a4160 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java
@@ -39,7 +39,10 @@ import java.util.List;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.IsEqual.equalTo;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 
 public class PartitionGroupTest {
@@ -47,8 +50,8 @@ public class PartitionGroupTest {
     private final Serializer<Integer> intSerializer = new IntegerSerializer();
     private final Deserializer<Integer> intDeserializer = new IntegerDeserializer();
     private final TimestampExtractor timestampExtractor = new MockTimestampExtractor();
-    private final TopicPartition randomPartition = new TopicPartition("random-partition", 0);
-    private final String errMessage = "Partition " + randomPartition + " not found.";
+    private final TopicPartition unknownPartition = new TopicPartition("unknown-partition", 0);
+    private final String errMessage = "Partition " + unknownPartition + " not found.";
     private final String[] topics = {"topic"};
     private final TopicPartition partition1 = new TopicPartition(topics[0], 1);
     private final TopicPartition partition2 = new TopicPartition(topics[0], 2);
@@ -88,6 +91,14 @@ public class PartitionGroupTest {
 
     @Test
     public void testTimeTracking() {
+        testFirstBatch();
+        testSecondBatch();
+    }
+
+    private void testFirstBatch() {
+        StampedRecord record;
+        final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo();
+
         assertEquals(0, group.numBuffered());
         // add three 3 records with timestamp 1, 3, 5 to partition-1
         final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
@@ -109,12 +120,13 @@ public class PartitionGroupTest {
         // st: -1 since no records was being processed yet
 
         verifyBuffered(6, 3, 3);
+        assertEquals(1L, group.partitionTimestamp(partition1));
+        assertEquals(2L, group.partitionTimestamp(partition2));
+        assertEquals(1L, group.headRecordOffset(partition1).longValue());
+        assertEquals(2L, group.headRecordOffset(partition2).longValue());
         assertEquals(-1L, group.streamTime());
         assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue());
 
-        StampedRecord record;
-        final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo();
-
         // get one record, now the time should be advanced
         record = group.nextRecord(info);
         // 1:[3, 5]
@@ -123,6 +135,8 @@ public class PartitionGroupTest {
         assertEquals(partition1, info.partition());
         assertEquals(3L, group.partitionTimestamp(partition1));
         assertEquals(2L, group.partitionTimestamp(partition2));
+        assertEquals(3L, group.headRecordOffset(partition1).longValue());
+        assertEquals(2L, group.headRecordOffset(partition2).longValue());
         assertEquals(1L, group.streamTime());
         verifyTimes(record, 1L, 1L);
         verifyBuffered(5, 2, 3);
@@ -136,10 +150,17 @@ public class PartitionGroupTest {
         assertEquals(partition2, info.partition());
         assertEquals(3L, group.partitionTimestamp(partition1));
         assertEquals(4L, group.partitionTimestamp(partition2));
+        assertEquals(3L, group.headRecordOffset(partition1).longValue());
+        assertEquals(4L, group.headRecordOffset(partition2).longValue());
         assertEquals(2L, group.streamTime());
         verifyTimes(record, 2L, 2L);
         verifyBuffered(4, 2, 2);
         assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue());
+    }
+
+    private void testSecondBatch() {
+        StampedRecord record;
+        final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo();
 
         // add 2 more records with timestamp 2, 4 to partition-1
         final List<ConsumerRecord<byte[], byte[]>> list3 = Arrays.asList(
@@ -153,6 +174,8 @@ public class PartitionGroupTest {
         verifyBuffered(6, 4, 2);
         assertEquals(3L, group.partitionTimestamp(partition1));
         assertEquals(4L, group.partitionTimestamp(partition2));
+        assertEquals(3L, group.headRecordOffset(partition1).longValue());
+        assertEquals(4L, group.headRecordOffset(partition2).longValue());
         assertEquals(2L, group.streamTime());
         assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue());
 
@@ -164,6 +187,8 @@ public class PartitionGroupTest {
         assertEquals(partition1, info.partition());
         assertEquals(5L, group.partitionTimestamp(partition1));
         assertEquals(4L, group.partitionTimestamp(partition2));
+        assertEquals(5L, group.headRecordOffset(partition1).longValue());
+        assertEquals(4L, group.headRecordOffset(partition2).longValue());
         assertEquals(3L, group.streamTime());
         verifyTimes(record, 3L, 3L);
         verifyBuffered(5, 3, 2);
@@ -177,6 +202,8 @@ public class PartitionGroupTest {
         assertEquals(partition2, info.partition());
         assertEquals(5L, group.partitionTimestamp(partition1));
         assertEquals(6L, group.partitionTimestamp(partition2));
+        assertEquals(5L, group.headRecordOffset(partition1).longValue());
+        assertEquals(6L, group.headRecordOffset(partition2).longValue());
         assertEquals(4L, group.streamTime());
         verifyTimes(record, 4L, 4L);
         verifyBuffered(4, 3, 1);
@@ -190,6 +217,8 @@ public class PartitionGroupTest {
         assertEquals(partition1, info.partition());
         assertEquals(5L, group.partitionTimestamp(partition1));
         assertEquals(6L, group.partitionTimestamp(partition2));
+        assertEquals(2L, group.headRecordOffset(partition1).longValue());
+        assertEquals(6L, group.headRecordOffset(partition2).longValue());
         assertEquals(5L, group.streamTime());
         verifyTimes(record, 5L, 5L);
         verifyBuffered(3, 2, 1);
@@ -203,6 +232,8 @@ public class PartitionGroupTest {
         assertEquals(partition1, info.partition());
         assertEquals(5L, group.partitionTimestamp(partition1));
         assertEquals(6L, group.partitionTimestamp(partition2));
+        assertEquals(4L, group.headRecordOffset(partition1).longValue());
+        assertEquals(6L, group.headRecordOffset(partition2).longValue());
         assertEquals(5L, group.streamTime());
         verifyTimes(record, 2L, 5L);
         verifyBuffered(2, 1, 1);
@@ -216,6 +247,8 @@ public class PartitionGroupTest {
         assertEquals(partition1, info.partition());
         assertEquals(5L, group.partitionTimestamp(partition1));
         assertEquals(6L, group.partitionTimestamp(partition2));
+        assertNull(group.headRecordOffset(partition1));
+        assertEquals(6L, group.headRecordOffset(partition2).longValue());
         assertEquals(5L, group.streamTime());
         verifyTimes(record, 4L, 5L);
         verifyBuffered(1, 0, 1);
@@ -229,6 +262,8 @@ public class PartitionGroupTest {
         assertEquals(partition2, info.partition());
         assertEquals(5L, group.partitionTimestamp(partition1));
         assertEquals(6L, group.partitionTimestamp(partition2));
+        assertNull(group.headRecordOffset(partition1));
+        assertNull(group.headRecordOffset(partition2));
         assertEquals(6L, group.streamTime());
         verifyTimes(record, 6L, 6L);
         verifyBuffered(0, 0, 0);
@@ -305,16 +340,79 @@ public class PartitionGroupTest {
     }
 
     @Test
-    public void shouldThrowNullpointerUponSetPartitionTimestampFailure() {
-        assertThrows(errMessage, NullPointerException.class, () -> {
-            group.setPartitionTime(randomPartition, 0L);
-        });
+    public void shouldThrowIllegalStateExceptionUponAddRecordsIfPartitionUnknown() {
+        final IllegalStateException exception = assertThrows(
+            IllegalStateException.class,
+            () -> group.addRawRecords(unknownPartition, null));
+        assertThat(errMessage, equalTo(exception.getMessage()));
+    }
+
+    @Test
+    public void shouldThrowIllegalStateExceptionUponNumBufferedIfPartitionUnknown() {
+        final IllegalStateException exception = assertThrows(
+            IllegalStateException.class,
+            () -> group.numBuffered(unknownPartition));
+        assertThat(errMessage, equalTo(exception.getMessage()));
+    }
+
+    @Test
+    public void shouldThrowIllegalStateExceptionUponSetPartitionTimestampIfPartitionUnknown() {
+        final IllegalStateException exception = assertThrows(
+            IllegalStateException.class,
+            () -> group.setPartitionTime(unknownPartition, 0L));
+        assertThat(errMessage, equalTo(exception.getMessage()));
+    }
+
+    @Test
+    public void shouldThrowIllegalStateExceptionUponGetPartitionTimestampIfPartitionUnknown() {
+        final IllegalStateException exception = assertThrows(
+            IllegalStateException.class,
+            () -> group.partitionTimestamp(unknownPartition));
+        assertThat(errMessage, equalTo(exception.getMessage()));
+    }
+
+    @Test
+    public void shouldThrowIllegalStateExceptionUponGetHeadRecordOffsetIfPartitionUnknown() {
+        final IllegalStateException exception = assertThrows(
+            IllegalStateException.class,
+            () -> group.headRecordOffset(unknownPartition));
+        assertThat(errMessage, equalTo(exception.getMessage()));
+    }
+
+    @Test
+    public void shouldEmpyPartitionsOnClean() {
+        final List<ConsumerRecord<byte[], byte[]>> list = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue),
+            new ConsumerRecord<>("topic", 1, 3L, recordKey, recordValue),
+            new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue));
+        group.addRawRecords(partition1, list);
+
+        group.clear();
+
+        assertThat(group.numBuffered(), equalTo(0));
+        assertThat(group.streamTime(), equalTo(RecordQueue.UNKNOWN));
+        assertThat(group.nextRecord(new PartitionGroup.RecordInfo()), equalTo(null));
+
+        group.addRawRecords(partition1, list);
     }
 
     @Test
-    public void shouldThrowNullpointerUponGetPartitionTimestampFailure() {
-        assertThrows(errMessage, NullPointerException.class, () -> {
-            group.partitionTimestamp(randomPartition);
-        });
+    public void shouldCleanPartitionsOnClose() {
+        final List<ConsumerRecord<byte[], byte[]>> list = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue),
+            new ConsumerRecord<>("topic", 1, 3L, recordKey, recordValue),
+            new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue));
+        group.addRawRecords(partition1, list);
+
+        group.close();
+
+        assertThat(group.numBuffered(), equalTo(0));
+        assertThat(group.streamTime(), equalTo(RecordQueue.UNKNOWN));
+        assertThat(group.nextRecord(new PartitionGroup.RecordInfo()), equalTo(null));
+
+        final IllegalStateException exception = assertThrows(
+            IllegalStateException.class,
+            () -> group.addRawRecords(partition1, list));
+        assertThat("Partition topic-1 not found.", equalTo(exception.getMessage()));
     }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
index d12481a..16daa52 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java
@@ -18,6 +18,7 @@ package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.SerializationException;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.record.TimestampType;
@@ -47,14 +48,18 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.instanceOf;
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 public class RecordQueueTest {
     private final Serializer<Integer> intSerializer = new IntegerSerializer();
     private final Deserializer<Integer> intDeserializer = new IntegerDeserializer();
     private final TimestampExtractor timestampExtractor = new MockTimestampExtractor();
-    private final String[] topics = {"topic"};
 
     private final Sensor droppedRecordsSensor = new Metrics().sensor("skipped-records");
 
@@ -67,16 +72,16 @@ public class RecordQueueTest {
             droppedRecordsSensor
         )
     );
-    private final MockSourceNode mockSourceNodeWithMetrics = new MockSourceNode<>(topics, intDeserializer, intDeserializer);
+    private final MockSourceNode mockSourceNodeWithMetrics = new MockSourceNode<>(new String[] {"topic"}, intDeserializer, intDeserializer);
     private final RecordQueue queue = new RecordQueue(
-        new TopicPartition(topics[0], 1),
+        new TopicPartition("topic", 1),
         mockSourceNodeWithMetrics,
         timestampExtractor,
         new LogAndFailExceptionHandler(),
         context,
         new LogContext());
     private final RecordQueue queueThatSkipsDeserializeErrors = new RecordQueue(
-        new TopicPartition(topics[0], 1),
+        new TopicPartition("topic", 1),
         mockSourceNodeWithMetrics,
         timestampExtractor,
         new LogAndContinueExceptionHandler(),
@@ -98,10 +103,10 @@ public class RecordQueueTest {
 
     @Test
     public void testTimeTracking() {
-
         assertTrue(queue.isEmpty());
         assertEquals(0, queue.size());
         assertEquals(RecordQueue.UNKNOWN, queue.headRecordTimestamp());
+        assertNull(queue.headRecordOffset());
 
         // add three 3 out-of-order records with timestamp 2, 1, 3
         final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
@@ -113,16 +118,19 @@ public class RecordQueueTest {
 
         assertEquals(3, queue.size());
         assertEquals(2L, queue.headRecordTimestamp());
+        assertEquals(2L, queue.headRecordOffset().longValue());
 
         // poll the first record, now with 1, 3
         assertEquals(2L, queue.poll().timestamp);
         assertEquals(2, queue.size());
         assertEquals(1L, queue.headRecordTimestamp());
+        assertEquals(1L, queue.headRecordOffset().longValue());
 
         // poll the second record, now with 3
         assertEquals(1L, queue.poll().timestamp);
         assertEquals(1, queue.size());
         assertEquals(3L, queue.headRecordTimestamp());
+        assertEquals(3L, queue.headRecordOffset().longValue());
 
         // add three 3 out-of-order records with timestamp 4, 1, 2
         // now with 3, 4, 1, 2
@@ -135,23 +143,28 @@ public class RecordQueueTest {
 
         assertEquals(4, queue.size());
         assertEquals(3L, queue.headRecordTimestamp());
+        assertEquals(3L, queue.headRecordOffset().longValue());
 
         // poll the third record, now with 4, 1, 2
         assertEquals(3L, queue.poll().timestamp);
         assertEquals(3, queue.size());
         assertEquals(4L, queue.headRecordTimestamp());
+        assertEquals(4L, queue.headRecordOffset().longValue());
 
         // poll the rest records
         assertEquals(4L, queue.poll().timestamp);
         assertEquals(1L, queue.headRecordTimestamp());
+        assertEquals(1L, queue.headRecordOffset().longValue());
 
         assertEquals(1L, queue.poll().timestamp);
         assertEquals(2L, queue.headRecordTimestamp());
+        assertEquals(2L, queue.headRecordOffset().longValue());
 
         assertEquals(2L, queue.poll().timestamp);
         assertTrue(queue.isEmpty());
         assertEquals(0, queue.size());
         assertEquals(RecordQueue.UNKNOWN, queue.headRecordTimestamp());
+        assertNull(queue.headRecordOffset());
 
         // add three more records with 4, 5, 6
         final List<ConsumerRecord<byte[], byte[]>> list3 = Arrays.asList(
@@ -163,23 +176,27 @@ public class RecordQueueTest {
 
         assertEquals(3, queue.size());
         assertEquals(4L, queue.headRecordTimestamp());
+        assertEquals(4L, queue.headRecordOffset().longValue());
 
         // poll one record again, the timestamp should advance now
         assertEquals(4L, queue.poll().timestamp);
         assertEquals(2, queue.size());
         assertEquals(5L, queue.headRecordTimestamp());
+        assertEquals(5L, queue.headRecordOffset().longValue());
 
         // clear the queue
         queue.clear();
         assertTrue(queue.isEmpty());
         assertEquals(0, queue.size());
         assertEquals(RecordQueue.UNKNOWN, queue.headRecordTimestamp());
+        assertNull(queue.headRecordOffset());
 
         // re-insert the three records with 4, 5, 6
         queue.addRawRecords(list3);
 
         assertEquals(3, queue.size());
         assertEquals(4L, queue.headRecordTimestamp());
+        assertEquals(4L, queue.headRecordOffset().longValue());
     }
 
     @Test
@@ -245,13 +262,17 @@ public class RecordQueueTest {
         queue.addRawRecords(records);
     }
 
-    @Test(expected = StreamsException.class)
+    @Test
     public void shouldThrowStreamsExceptionWhenValueDeserializationFails() {
         final byte[] value = Serdes.Long().serializer().serialize("foo", 1L);
         final List<ConsumerRecord<byte[], byte[]>> records = Collections.singletonList(
             new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 0L, 0, 0, recordKey, value));
 
-        queue.addRawRecords(records);
+        final StreamsException exception = assertThrows(
+            StreamsException.class,
+            () -> queue.addRawRecords(records)
+        );
+        assertThat(exception.getCause(), instanceOf(SerializationException.class));
     }
 
     @Test
@@ -274,21 +295,29 @@ public class RecordQueueTest {
         assertEquals(0, queueThatSkipsDeserializeErrors.size());
     }
 
-
-    @Test(expected = StreamsException.class)
+    @Test
     public void shouldThrowOnNegativeTimestamp() {
         final List<ConsumerRecord<byte[], byte[]>> records = Collections.singletonList(
             new ConsumerRecord<>("topic", 1, 1, -1L, TimestampType.CREATE_TIME, 0L, 0, 0, recordKey, recordValue));
 
         final RecordQueue queue = new RecordQueue(
-            new TopicPartition(topics[0], 1),
-            new MockSourceNode<>(topics, intDeserializer, intDeserializer),
+            new TopicPartition("topic", 1),
+            mockSourceNodeWithMetrics,
             new FailOnInvalidTimestamp(),
             new LogAndContinueExceptionHandler(),
             new InternalMockProcessorContext(),
             new LogContext());
 
-        queue.addRawRecords(records);
+        final StreamsException exception = assertThrows(
+            StreamsException.class,
+            () -> queue.addRawRecords(records)
+        );
+        assertThat(exception.getMessage(), equalTo("Input record ConsumerRecord(topic = topic, partition = 1, " +
+            "leaderEpoch = null, offset = 1, CreateTime = -1, serialized key size = 0, serialized value size = 0, " +
+            "headers = RecordHeaders(headers = [], isReadOnly = false), key = 1, value = 10) has invalid (negative) " +
+            "timestamp. Possibly because a pre-0.10 producer client was used to write this record to Kafka without " +
+            "embedding a timestamp, or because the input topic was created before upgrading the Kafka cluster to 0.10+. " +
+            "Use a different TimestampExtractor to process this data."));
     }
 
     @Test
@@ -297,8 +326,8 @@ public class RecordQueueTest {
             new ConsumerRecord<>("topic", 1, 1, -1L, TimestampType.CREATE_TIME, 0L, 0, 0, recordKey, recordValue));
 
         final RecordQueue queue = new RecordQueue(
-            new TopicPartition(topics[0], 1),
-            new MockSourceNode<>(topics, intDeserializer, intDeserializer),
+            new TopicPartition("topic", 1),
+            mockSourceNodeWithMetrics,
             new LogAndSkipOnInvalidTimestamp(),
             new LogAndContinueExceptionHandler(),
             new InternalMockProcessorContext(),
@@ -313,7 +342,7 @@ public class RecordQueueTest {
 
         final PartitionTimeTrackingTimestampExtractor timestampExtractor = new PartitionTimeTrackingTimestampExtractor();
         final RecordQueue queue = new RecordQueue(
-            new TopicPartition(topics[0], 1),
+            new TopicPartition("topic", 1),
             mockSourceNodeWithMetrics,
             timestampExtractor,
             new LogAndFailExceptionHandler(),
@@ -349,7 +378,7 @@ public class RecordQueueTest {
 
     }
 
-    class PartitionTimeTrackingTimestampExtractor implements TimestampExtractor {
+    private static class PartitionTimeTrackingTimestampExtractor implements TimestampExtractor {
         private long partitionTime = RecordQueue.UNKNOWN;
 
         public long extract(final ConsumerRecord<Object, Object> record, final long partitionTime) {
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index e804b11..1d0ca4f 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -62,10 +62,10 @@ import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
 import org.apache.kafka.test.MockKeyValueStore;
 import org.apache.kafka.test.MockProcessorNode;
+import org.apache.kafka.test.MockRecordCollector;
 import org.apache.kafka.test.MockSourceNode;
 import org.apache.kafka.test.MockStateRestoreListener;
 import org.apache.kafka.test.MockTimestampExtractor;
-import org.apache.kafka.test.MockRecordCollector;
 import org.apache.kafka.test.TestUtils;
 import org.junit.After;
 import org.junit.Before;
@@ -75,6 +75,7 @@ import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.time.Duration;
+import java.util.Arrays;
 import java.util.Base64;
 import java.util.Collections;
 import java.util.HashMap;
@@ -217,6 +218,7 @@ public class StreamTaskTest {
     @Before
     public void setup() {
         consumer.assign(asList(partition1, partition2));
+        consumer.updateBeginningOffsets(mkMap(mkEntry(partition1, 0L), mkEntry(partition2, 0L)));
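+        // seed beginning offsets of 0 for both partitions of the mock consumer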
         stateDirectory = new StateDirectory(createConfig(false), new MockTime(), true);
     }
 
@@ -776,6 +778,47 @@ public class StreamTaskTest {
     }
 
     @Test
+    public void shouldCommitNextOffsetFromQueueIfAvailable() {
+        task = createStatelessTask(createConfig(false), StreamsConfig.METRICS_LATEST);
+        task.initializeStateStores();
+        task.initializeTopology();
+
+        task.addRecords(partition1, Arrays.asList(getConsumerRecord(partition1, 0), getConsumerRecord(partition1, 5)));
+        task.process();
+        task.commit();
+
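+        // offset 0 was processed but offset 5 is still buffered, so the committed offset should be the head record's offset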
+        final Map<TopicPartition, Long> committedOffsets = getCommittedOffsets(consumer.committed(partitions));
+        assertThat(committedOffsets, equalTo(mkMap(mkEntry(partition1, 5L))));
+    }
+
+    @Test
+    public void shouldCommitConsumerPositionIfRecordQueueIsEmpty() {
+        task = createStatelessTask(createConfig(false), StreamsConfig.METRICS_LATEST);
+        task.initializeStateStores();
+        task.initializeTopology();
+
+        consumer.addRecord(getConsumerRecord(partition1, 0));
+        consumer.addRecord(getConsumerRecord(partition1, 1));
+        consumer.addRecord(getConsumerRecord(partition1, 2));
+        consumer.poll(Duration.ZERO);
+
+        task.addRecords(partition1, singletonList(getConsumerRecord(partition1, 0)));
+        task.process();
+        task.commit();
+
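+        // the queue is drained, so the committed offset should fall back to the consumer position (3 after polling offsets 0-2)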
+        final Map<TopicPartition, Long> committedOffsets = getCommittedOffsets(consumer.committed(partitions));
+        assertThat(committedOffsets, equalTo(mkMap(mkEntry(partition1, 3L))));
+    }
+
+    private Map<TopicPartition, Long> getCommittedOffsets(final Map<TopicPartition, OffsetAndMetadata> committedOffsetsAndMetadata) {
+        final Map<TopicPartition, Long> committedOffsets = new HashMap<>();
+        for (final Map.Entry<TopicPartition, OffsetAndMetadata> e : committedOffsetsAndMetadata.entrySet()) {
+            committedOffsets.put(e.getKey(), e.getValue().offset());
+        }
+        return committedOffsets;
+    }
+
+    @Test
     public void shouldRestorePartitionTimeAfterRestartWithEosDisabled() {
         createTaskWithProcessAndCommit(false);
 
@@ -1855,6 +1898,7 @@ public class StreamTaskTest {
             Collections.singleton(repartition.topic())
         );
         consumer.assign(asList(partition1, repartition));
+        consumer.updateBeginningOffsets(mkMap(mkEntry(repartition, 0L)));
 
         task = new StreamTask(
             taskId00,
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
index 98d7aad..cffffbe 100644
--- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
@@ -43,7 +43,6 @@ import org.apache.kafka.streams.errors.TopologyException;
 import org.apache.kafka.streams.internals.KeyValueStoreFacade;
 import org.apache.kafka.streams.internals.QuietStreamsConfig;
 import org.apache.kafka.streams.internals.WindowStoreFacade;
-import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
 import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.processor.PunctuationType;
 import org.apache.kafka.streams.processor.Punctuator;
@@ -63,6 +62,7 @@ import org.apache.kafka.streams.processor.internals.StateDirectory;
 import org.apache.kafka.streams.processor.internals.StoreChangelogReader;
 import org.apache.kafka.streams.processor.internals.StreamTask;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
 import org.apache.kafka.streams.state.KeyValueStore;
 import org.apache.kafka.streams.state.ReadOnlyKeyValueStore;
 import org.apache.kafka.streams.state.ReadOnlySessionStore;
@@ -333,22 +333,28 @@ public class TopologyTestDriver implements Closeable {
             offsetsByTopicPartition.put(tp, new AtomicLong());
         }
         consumer.assign(partitionsByTopic.values());
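+        // seed beginning offsets of 0 for all assigned partitions of the mock consumer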
+        final Map<TopicPartition, Long> startOffsets = new HashMap<>();
+        for (final TopicPartition topicPartition : partitionsByTopic.values()) {
+            startOffsets.put(topicPartition, 0L);
+        }
+        consumer.updateBeginningOffsets(startOffsets);
 
         if (globalTopology != null) {
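+            // use a dedicated mock consumer for the global topology instead of reusing the input-topic consumer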
+            final MockConsumer<byte[], byte[]> globalConsumer = new MockConsumer<>(OffsetResetStrategy.NONE);
             for (final String topicName : globalTopology.sourceTopics()) {
                 final TopicPartition partition = new TopicPartition(topicName, 0);
                 globalPartitionsByTopic.put(topicName, partition);
                 offsetsByTopicPartition.put(partition, new AtomicLong());
-                consumer.updatePartitions(topicName, Collections.singletonList(
+                globalConsumer.updatePartitions(topicName, Collections.singletonList(
                     new PartitionInfo(topicName, 0, null, null, null)));
-                consumer.updateBeginningOffsets(Collections.singletonMap(partition, 0L));
-                consumer.updateEndOffsets(Collections.singletonMap(partition, 0L));
+                globalConsumer.updateBeginningOffsets(Collections.singletonMap(partition, 0L));
+                globalConsumer.updateEndOffsets(Collections.singletonMap(partition, 0L));
             }
 
             globalStateManager = new GlobalStateManagerImpl(
                 new LogContext("mock "),
                 globalTopology,
-                consumer,
+                globalConsumer,
                 stateDirectory,
                 stateRestoreListener,
                 streamsConfig);