Posted to commits@kafka.apache.org by bb...@apache.org on 2019/04/23 17:42:48 UTC

[kafka] branch 2.1 updated: KAFKA-7895: fix Suppress changelog restore (#6536) (#6616)

This is an automated email from the ASF dual-hosted git repository.

bbejeck pushed a commit to branch 2.1
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.1 by this push:
     new a193f37  KAFKA-7895: fix Suppress changelog restore (#6536) (#6616)
a193f37 is described below

commit a193f370c0cd43d056dd913e97e457e636dbf76d
Author: John Roesler <vv...@users.noreply.github.com>
AuthorDate: Tue Apr 23 12:42:04 2019 -0500

    KAFKA-7895: fix Suppress changelog restore (#6536) (#6616)
    
    Several issues have come to light since the 2.2.0 release:
    * upon restore, suppress incorrectly set the record metadata using the changelog record, instead of preserving the original metadata
    * restoring a tombstone incorrectly didn't update the buffer size and min-timestamp
    
    Cherry-picked from #6536 / 6538e9e
    
    Reviewers: Bill Bejeck <bb...@gmail.com>
---
 .../kafka/clients/consumer/ConsumerRecord.java     |   6 +-
 .../suppress/KTableSuppressProcessor.java          |   4 +-
 .../internals/ProcessorRecordContext.java          | 137 ++++-
 .../streams/state/internals/ContextualRecord.java  |  43 +-
 .../InMemoryTimeOrderedKeyValueBuffer.java         | 173 ++++--
 .../apache/kafka/streams/KeyValueTimestamp.java    |  17 +
 .../SuppressionDurabilityIntegrationTest.java      | 166 ++++--
 .../integration/SuppressionIntegrationTest.java    |   3 +-
 .../integration/utils/IntegrationTestUtils.java    |  60 +-
 .../suppress/KTableSuppressProcessorTest.java      |   1 -
 .../internals/TimeOrderedKeyValueBufferTest.java   | 604 +++++++++++++++++++++
 .../kafka/test/MockInternalProcessorContext.java   | 104 +++-
 12 files changed, 1200 insertions(+), 118 deletions(-)

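For orientation before the diff: the core of the fix is that suppression changelog records are now written with a ("v", 1) header and carry the original record metadata inline, so that restore can rebuild each buffered entry's ProcessorRecordContext instead of taking metadata from the changelog record itself; restoring a tombstone now also corrects the buffer size and min-timestamp. The decoder below is an editorial sketch of that v1 value layout, pieced together from the InMemoryTimeOrderedKeyValueBuffer, ContextualRecord, and ProcessorRecordContext changes that follow; the class and field names are illustrative and not part of the commit.

    import java.nio.ByteBuffer;

    import static java.nio.charset.StandardCharsets.UTF_8;

    // Hypothetical decoder for the v1 suppress-changelog value written by logValue().
    // Only the fixed-position fields are shown; header entries and the buffered value
    // bytes follow the partition field (see ContextualRecord.serialize() in the diff).
    final class V1SuppressChangelogValue {
        final long bufferTime;  // buffer time prepended by logValue()
        final long timestamp;   // original record timestamp, not the changelog record's
        final long offset;      // original record offset
        final String topic;     // original source topic
        final int partition;    // original partition

        V1SuppressChangelogValue(final byte[] changelogValue) {
            final ByteBuffer buffer = ByteBuffer.wrap(changelogValue);
            bufferTime = buffer.getLong();
            timestamp = buffer.getLong();
            offset = buffer.getLong();
            final byte[] topicBytes = new byte[buffer.getInt()];
            buffer.get(topicBytes);
            topic = new String(topicBytes, UTF_8);
            partition = buffer.getInt();
            // remaining: header count (-1 if headers were null), then per-header
            // key/value pairs, then the buffered value length (-1 if null) and bytes
        }
    }
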
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRecord.java b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRecord.java
index 0413d5b..a7dad7b 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRecord.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRecord.java
@@ -157,6 +157,8 @@ public class ConsumerRecord<K, V> {
                           Optional<Integer> leaderEpoch) {
         if (topic == null)
             throw new IllegalArgumentException("Topic cannot be null");
+        if (headers == null)
+            throw new IllegalArgumentException("Headers cannot be null");
 
         this.topic = topic;
         this.partition = partition;
@@ -173,7 +175,7 @@ public class ConsumerRecord<K, V> {
     }
 
     /**
-     * The topic this record is received from
+     * The topic this record is received from (never null)
      */
     public String topic() {
         return this.topic;
@@ -187,7 +189,7 @@ public class ConsumerRecord<K, V> {
     }
 
     /**
-     * The headers
+     * The headers (never null)
      */
     public Headers headers() {
         return headers;
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessor.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessor.java
index 4058083..622223c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessor.java
@@ -31,8 +31,6 @@ import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 import org.apache.kafka.streams.state.internals.ContextualRecord;
 import org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBuffer;
 
-import java.util.Objects;
-
 import static java.util.Objects.requireNonNull;
 
 public class KTableSuppressProcessor<K, V> implements Processor<K, Change<V>> {
@@ -74,7 +72,7 @@ public class KTableSuppressProcessor<K, V> implements Processor<K, Change<V>> {
         internalProcessorContext = (InternalProcessorContext) context;
         keySerde = keySerde == null ? (Serde<K>) context.keySerde() : keySerde;
         valueSerde = valueSerde == null ? FullChangeSerde.castOrWrap(context.valueSerde()) : valueSerde;
-        buffer = Objects.requireNonNull((TimeOrderedKeyValueBuffer) context.getStateStore(storeName));
+        buffer = requireNonNull((TimeOrderedKeyValueBuffer) context.getStateStore(storeName));
     }
 
     @Override
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java
index da44e96..4f991a2 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java
@@ -18,10 +18,15 @@ package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.header.Header;
 import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.header.internals.RecordHeader;
+import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.streams.processor.RecordContext;
 
+import java.nio.ByteBuffer;
 import java.util.Objects;
 
+import static java.nio.charset.StandardCharsets.UTF_8;
+
 public class ProcessorRecordContext implements RecordContext {
 
     long timestamp;
@@ -80,13 +85,13 @@ public class ProcessorRecordContext implements RecordContext {
     }
 
     public long sizeBytes() {
-        long size = 0L;
-        size += 8; // value.context.timestamp
-        size += 8; // value.context.offset
+        long size = 0;
+        size += Long.BYTES; // value.context.timestamp
+        size += Long.BYTES; // value.context.offset
         if (topic != null) {
             size += topic.toCharArray().length;
         }
-        size += 4; // partition
+        size += Integer.BYTES; // partition
         if (headers != null) {
             for (final Header header : headers) {
                 size += header.key().toCharArray().length;
@@ -99,20 +104,136 @@ public class ProcessorRecordContext implements RecordContext {
         return size;
     }
 
+    public byte[] serialize() {
+        final byte[] topicBytes = topic.getBytes(UTF_8);
+        final byte[][] headerKeysBytes;
+        final byte[][] headerValuesBytes;
+
+
+        int size = 0;
+        size += Long.BYTES; // value.context.timestamp
+        size += Long.BYTES; // value.context.offset
+        size += Integer.BYTES; // size of topic
+        size += topicBytes.length;
+        size += Integer.BYTES; // partition
+        size += Integer.BYTES; // number of headers
+
+        if (headers == null) {
+            headerKeysBytes = headerValuesBytes = null;
+        } else {
+            final Header[] headers = this.headers.toArray();
+            headerKeysBytes = new byte[headers.length][];
+            headerValuesBytes = new byte[headers.length][];
+
+            for (int i = 0; i < headers.length; i++) {
+                size += 2 * Integer.BYTES; // sizes of key and value
+
+                final byte[] keyBytes = headers[i].key().getBytes(UTF_8);
+                size += keyBytes.length;
+                final byte[] valueBytes = headers[i].value();
+                if (valueBytes != null) {
+                    size += valueBytes.length;
+                }
+
+                headerKeysBytes[i] = keyBytes;
+                headerValuesBytes[i] = valueBytes;
+            }
+        }
+
+        final ByteBuffer buffer = ByteBuffer.allocate(size);
+        buffer.putLong(timestamp);
+        buffer.putLong(offset);
+
+        // not handling the null condition because we believe topic will never be null in cases where we serialize
+        buffer.putInt(topicBytes.length);
+        buffer.put(topicBytes);
+
+        buffer.putInt(partition);
+        if (headers == null) {
+            buffer.putInt(-1);
+        } else {
+            buffer.putInt(headerKeysBytes.length);
+            for (int i = 0; i < headerKeysBytes.length; i++) {
+                buffer.putInt(headerKeysBytes[i].length);
+                buffer.put(headerKeysBytes[i]);
+
+                if (headerValuesBytes[i] != null) {
+                    buffer.putInt(headerValuesBytes[i].length);
+                    buffer.put(headerValuesBytes[i]);
+                } else {
+                    buffer.putInt(-1);
+                }
+            }
+        }
+
+        return buffer.array();
+    }
+
+    public static ProcessorRecordContext deserialize(final ByteBuffer buffer) {
+        final long timestamp = buffer.getLong();
+        final long offset = buffer.getLong();
+        final int topicSize = buffer.getInt();
+        final String topic;
+        {
+            // not handling the null topic condition, because we believe the topic will never be null when we serialize
+            final byte[] topicBytes = new byte[topicSize];
+            buffer.get(topicBytes);
+            topic = new String(topicBytes, UTF_8);
+        }
+        final int partition = buffer.getInt();
+        final int headerCount = buffer.getInt();
+        final Headers headers;
+        if (headerCount == -1) {
+            headers = null;
+        } else {
+            final Header[] headerArr = new Header[headerCount];
+            for (int i = 0; i < headerCount; i++) {
+                final int keySize = buffer.getInt();
+                final byte[] keyBytes = new byte[keySize];
+                buffer.get(keyBytes);
+
+                final int valueSize = buffer.getInt();
+                final byte[] valueBytes;
+                if (valueSize == -1) {
+                    valueBytes = null;
+                } else {
+                    valueBytes = new byte[valueSize];
+                    buffer.get(valueBytes);
+                }
+
+                headerArr[i] = new RecordHeader(new String(keyBytes, UTF_8), valueBytes);
+            }
+            headers = new RecordHeaders(headerArr);
+        }
+
+        return new ProcessorRecordContext(timestamp, offset, partition, topic, headers);
+    }
+
     @Override
     public boolean equals(final Object o) {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         final ProcessorRecordContext that = (ProcessorRecordContext) o;
         return timestamp == that.timestamp &&
-                offset == that.offset &&
-                partition == that.partition &&
-                Objects.equals(topic, that.topic) &&
-                Objects.equals(headers, that.headers);
+            offset == that.offset &&
+            partition == that.partition &&
+            Objects.equals(topic, that.topic) &&
+            Objects.equals(headers, that.headers);
     }
 
     @Override
     public int hashCode() {
         return Objects.hash(timestamp, offset, topic, partition, headers);
     }
+
+    @Override
+    public String toString() {
+        return "ProcessorRecordContext{" +
+            "topic='" + topic + '\'' +
+            ", partition=" + partition +
+            ", offset=" + offset +
+            ", timestamp=" + timestamp +
+            ", headers=" + headers +
+            '}';
+    }
 }
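
A quick round trip through the serialize()/deserialize() pair added above, as an editorial sketch (the wrapper class and main method are hypothetical; the context is built with the same five-argument constructor that deserialize() itself uses):

    import java.nio.ByteBuffer;

    import org.apache.kafka.common.header.internals.RecordHeaders;
    import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;

    final class ProcessorRecordContextRoundTrip {
        public static void main(final String[] args) {
            // build a context with empty (non-null) headers and round-trip it
            final ProcessorRecordContext original =
                new ProcessorRecordContext(42L, 100L, 0, "input-topic", new RecordHeaders());
            final ProcessorRecordContext restored =
                ProcessorRecordContext.deserialize(ByteBuffer.wrap(original.serialize()));
            // equals() compares timestamp, offset, partition, topic, and headers
            System.out.println(original.equals(restored)); // expected: true
        }
    }
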
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ContextualRecord.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ContextualRecord.java
index 89935c0..7891d71 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/ContextualRecord.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ContextualRecord.java
@@ -18,6 +18,7 @@ package org.apache.kafka.streams.state.internals;
 
 import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 
+import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.Objects;
 
@@ -27,7 +28,7 @@ public class ContextualRecord {
 
     public ContextualRecord(final byte[] value, final ProcessorRecordContext recordContext) {
         this.value = value;
-        this.recordContext = recordContext;
+        this.recordContext = Objects.requireNonNull(recordContext);
     }
 
     public ProcessorRecordContext recordContext() {
@@ -42,6 +43,38 @@ public class ContextualRecord {
         return (value == null ? 0 : value.length) + recordContext.sizeBytes();
     }
 
+    byte[] serialize() {
+        final byte[] serializedContext = recordContext.serialize();
+
+        final int sizeOfContext = serializedContext.length;
+        final int sizeOfValueLength = Integer.BYTES;
+        final int sizeOfValue = value == null ? 0 : value.length;
+        final ByteBuffer buffer = ByteBuffer.allocate(sizeOfContext + sizeOfValueLength + sizeOfValue);
+
+        buffer.put(serializedContext);
+        if (value == null) {
+            buffer.putInt(-1);
+        } else {
+            buffer.putInt(value.length);
+            buffer.put(value);
+        }
+
+        return buffer.array();
+    }
+
+    static ContextualRecord deserialize(final ByteBuffer buffer) {
+        final ProcessorRecordContext context = ProcessorRecordContext.deserialize(buffer);
+
+        final int valueLength = buffer.getInt();
+        if (valueLength == -1) {
+            return new ContextualRecord(null, context);
+        } else {
+            final byte[] value = new byte[valueLength];
+            buffer.get(value);
+            return new ContextualRecord(value, context);
+        }
+    }
+
     @Override
     public boolean equals(final Object o) {
         if (this == o) return true;
@@ -55,4 +88,12 @@ public class ContextualRecord {
     public int hashCode() {
         return Objects.hash(value, recordContext);
     }
+
+    @Override
+    public String toString() {
+        return "ContextualRecord{" +
+            "recordContext=" + recordContext +
+            ", value=" + Arrays.toString(value) +
+            '}';
+    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java
index d94f671..d323d97 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java
@@ -17,6 +17,9 @@
 package org.apache.kafka.streams.state.internals;
 
 import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.header.Header;
+import org.apache.kafka.common.header.internals.RecordHeader;
+import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.serialization.ByteArraySerializer;
 import org.apache.kafka.common.serialization.BytesSerializer;
 import org.apache.kafka.common.utils.Bytes;
@@ -42,9 +45,13 @@ import java.util.TreeMap;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 
-public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuffer {
+import static java.util.Objects.requireNonNull;
+
+public final class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuffer {
     private static final BytesSerializer KEY_SERIALIZER = new BytesSerializer();
     private static final ByteArraySerializer VALUE_SERIALIZER = new ByteArraySerializer();
+    private static final RecordHeaders V_1_CHANGELOG_HEADERS =
+        new RecordHeaders(new Header[] {new RecordHeader("v", new byte[] {(byte) 1})});
 
     private final Map<Bytes, BufferKey> index = new HashMap<>();
     private final TreeMap<BufferKey, ContextualRecord> sortedMap = new TreeMap<>();
@@ -60,6 +67,8 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
 
     private volatile boolean open;
 
+    private int partition;
+
     public static class Builder implements StoreBuilder<StateStore> {
 
         private final String storeName;
@@ -125,7 +134,7 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
         }
     }
 
-    private static class BufferKey implements Comparable<BufferKey> {
+    private static final class BufferKey implements Comparable<BufferKey> {
         private final long time;
         private final Bytes key;
 
@@ -154,6 +163,14 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
             final int timeComparison = Long.compare(time, o.time);
             return timeComparison == 0 ? key.compareTo(o.key) : timeComparison;
         }
+
+        @Override
+        public String toString() {
+            return "BufferKey{" +
+                "key=" + key +
+                ", time=" + time +
+                '}';
+        }
     }
 
     private InMemoryTimeOrderedKeyValueBuffer(final String storeName, final boolean loggingEnabled) {
@@ -180,6 +197,7 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
             changelogTopic = ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName);
         }
         open = true;
+        partition = context.taskId().partition;
     }
 
     @Override
@@ -207,33 +225,52 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
 
                 if (bufferKey == null) {
                     // The record was evicted from the buffer. Send a tombstone.
-                    collector.send(changelogTopic, key, null, null, null, null, KEY_SERIALIZER, VALUE_SERIALIZER);
+                    logTombstone(key);
                 } else {
                     final ContextualRecord value = sortedMap.get(bufferKey);
 
-                    final byte[] innerValue = value.value();
-                    final byte[] timeAndValue = ByteBuffer.wrap(new byte[8 + innerValue.length])
-                                                          .putLong(bufferKey.time)
-                                                          .put(innerValue)
-                                                          .array();
-
-                    final ProcessorRecordContext recordContext = value.recordContext();
-                    collector.send(
-                        changelogTopic,
-                        key,
-                        timeAndValue,
-                        recordContext.headers(),
-                        recordContext.partition(),
-                        recordContext.timestamp(),
-                        KEY_SERIALIZER,
-                        VALUE_SERIALIZER
-                    );
+                    logValue(key, bufferKey, value);
                 }
             }
             dirtyKeys.clear();
         }
     }
 
+    private void logValue(final Bytes key, final BufferKey bufferKey, final ContextualRecord value) {
+        final byte[] serializedContextualRecord = value.serialize();
+
+        final int sizeOfBufferTime = Long.BYTES;
+        final int sizeOfContextualRecord = serializedContextualRecord.length;
+
+        final byte[] timeAndContextualRecord = ByteBuffer.wrap(new byte[sizeOfBufferTime + sizeOfContextualRecord])
+                                                         .putLong(bufferKey.time)
+                                                         .put(serializedContextualRecord)
+                                                         .array();
+
+        collector.send(
+            changelogTopic,
+            key,
+            timeAndContextualRecord,
+            V_1_CHANGELOG_HEADERS,
+            partition,
+            null,
+            KEY_SERIALIZER,
+            VALUE_SERIALIZER
+        );
+    }
+
+    private void logTombstone(final Bytes key) {
+        collector.send(changelogTopic,
+                       key,
+                       null,
+                       null,
+                       partition,
+                       null,
+                       KEY_SERIALIZER,
+                       VALUE_SERIALIZER
+        );
+    }
+
     private void restoreBatch(final Collection<ConsumerRecord<byte[], byte[]>> batch) {
         for (final ConsumerRecord<byte[], byte[]> record : batch) {
             final Bytes key = Bytes.wrap(record.key());
@@ -241,26 +278,62 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
                 // This was a tombstone. Delete the record.
                 final BufferKey bufferKey = index.remove(key);
                 if (bufferKey != null) {
-                    sortedMap.remove(bufferKey);
+                    final ContextualRecord removed = sortedMap.remove(bufferKey);
+                    if (removed != null) {
+                        memBufferSize -= computeRecordSize(bufferKey.key, removed);
+                    }
+                    if (bufferKey.time == minTimestamp) {
+                        minTimestamp = sortedMap.isEmpty() ? Long.MAX_VALUE : sortedMap.firstKey().time;
+                    }
+                }
+
+                if (record.partition() != partition) {
+                    throw new IllegalStateException(
+                        String.format(
+                            "record partition [%d] is being restored by the wrong suppress partition [%d]",
+                            record.partition(),
+                            partition
+                        )
+                    );
                 }
             } else {
                 final ByteBuffer timeAndValue = ByteBuffer.wrap(record.value());
                 final long time = timeAndValue.getLong();
                 final byte[] value = new byte[record.value().length - 8];
                 timeAndValue.get(value);
-
-                cleanPut(
-                    time,
-                    key,
-                    new ContextualRecord(
-                        value,
-                        new ProcessorRecordContext(
-                            record.timestamp(),
-                            record.offset(),
-                            record.partition(),
-                            record.topic(),
-                            record.headers()
+                if (record.headers().lastHeader("v") == null) {
+                    cleanPut(
+                        time,
+                        key,
+                        new ContextualRecord(
+                            value,
+                            new ProcessorRecordContext(
+                                record.timestamp(),
+                                record.offset(),
+                                record.partition(),
+                                record.topic(),
+                                record.headers()
+                            )
                         )
+                    );
+                } else if (V_1_CHANGELOG_HEADERS.lastHeader("v").equals(record.headers().lastHeader("v"))) {
+                    final ContextualRecord contextualRecord = ContextualRecord.deserialize(ByteBuffer.wrap(value));
+
+                    cleanPut(
+                        time,
+                        key,
+                        contextualRecord
+                    );
+                } else {
+                    throw new IllegalArgumentException("Restoring apparently invalid changelog record: " + record);
+                }
+            }
+            if (record.partition() != partition) {
+                throw new IllegalStateException(
+                    String.format(
+                        "record partition [%d] is being restored by the wrong suppress partition [%d]",
+                        record.partition(),
+                        partition
                     )
                 );
             }
@@ -281,6 +354,12 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
 
             // predicate being true means we read one record, call the callback, and then remove it
             while (next != null && predicate.get()) {
+                if (next.getKey().time != minTimestamp) {
+                    throw new IllegalStateException(
+                        "minTimestamp [" + minTimestamp + "] did not match the actual min timestamp [" +
+                            next.getKey().time + "]"
+                    );
+                }
                 callback.accept(new KeyValue<>(next.getKey().key, next.getValue()));
 
                 delegate.remove();
@@ -288,7 +367,7 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
 
                 dirtyKeys.add(next.getKey().key);
 
-                memBufferSize = memBufferSize - computeRecordSize(next.getKey().key, next.getValue());
+                memBufferSize -= computeRecordSize(next.getKey().key, next.getValue());
 
                 // peek at the next record so we can update the minTimestamp
                 if (delegate.hasNext()) {
@@ -305,8 +384,11 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
     @Override
     public void put(final long time,
                     final Bytes key,
-                    final ContextualRecord value) {
-        cleanPut(time, key, value);
+                    final ContextualRecord contextualRecord) {
+        requireNonNull(contextualRecord.value(), "value cannot be null");
+        requireNonNull(contextualRecord.recordContext(), "recordContext cannot be null");
+
+        cleanPut(time, key, contextualRecord);
         dirtyKeys.add(key);
     }
 
@@ -321,7 +403,7 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
             index.put(key, nextKey);
             sortedMap.put(nextKey, value);
             minTimestamp = Math.min(minTimestamp, time);
-            memBufferSize = memBufferSize + computeRecordSize(key, value);
+            memBufferSize += computeRecordSize(key, value);
         } else {
             final ContextualRecord removedValue = sortedMap.put(previousKey, value);
             memBufferSize =
@@ -346,7 +428,7 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
         return minTimestamp;
     }
 
-    private long computeRecordSize(final Bytes key, final ContextualRecord value) {
+    private static long computeRecordSize(final Bytes key, final ContextualRecord value) {
         long size = 0L;
         size += 8; // buffer time
         size += key.get().length;
@@ -355,4 +437,19 @@ public class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuf
         }
         return size;
     }
+
+    @Override
+    public String toString() {
+        return "InMemoryTimeOrderedKeyValueBuffer{" +
+            "storeName='" + storeName + '\'' +
+            ", changelogTopic='" + changelogTopic + '\'' +
+            ", open=" + open +
+            ", loggingEnabled=" + loggingEnabled +
+            ", minTimestamp=" + minTimestamp +
+            ", memBufferSize=" + memBufferSize +
+            ", \n\tdirtyKeys=" + dirtyKeys +
+            ", \n\tindex=" + index +
+            ", \n\tsortedMap=" + sortedMap +
+            '}';
+    }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/KeyValueTimestamp.java b/streams/src/test/java/org/apache/kafka/streams/KeyValueTimestamp.java
index 4213112..b578562 100644
--- a/streams/src/test/java/org/apache/kafka/streams/KeyValueTimestamp.java
+++ b/streams/src/test/java/org/apache/kafka/streams/KeyValueTimestamp.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.streams;
 
+import java.util.Objects;
+
 public class KeyValueTimestamp<K, V> {
     private final K key;
     private final V value;
@@ -43,4 +45,19 @@ public class KeyValueTimestamp<K, V> {
     public String toString() {
         return "KeyValueTimestamp{key=" + key + ", value=" + value + ", timestamp=" + timestamp + '}';
     }
+
+    @Override
+    public boolean equals(final Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        final KeyValueTimestamp<?, ?> that = (KeyValueTimestamp<?, ?>) o;
+        return timestamp == that.timestamp &&
+            Objects.equals(key, that.key) &&
+            Objects.equals(value, that.value);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(key, value, timestamp);
+    }
 }
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionDurabilityIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionDurabilityIntegrationTest.java
index c26b52f..fa49386 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionDurabilityIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionDurabilityIntegrationTest.java
@@ -34,11 +34,13 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
 import org.apache.kafka.streams.kstream.Consumed;
-import org.apache.kafka.streams.kstream.Grouped;
 import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.kstream.KTable;
 import org.apache.kafka.streams.kstream.Materialized;
 import org.apache.kafka.streams.kstream.Produced;
+import org.apache.kafka.streams.kstream.Transformer;
+import org.apache.kafka.streams.kstream.TransformerSupplier;
+import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.state.KeyValueStore;
 import org.apache.kafka.test.IntegrationTest;
 import org.junit.ClassRule;
@@ -47,13 +49,18 @@ import org.junit.experimental.categories.Category;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameters;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
-import java.util.Arrays;
 import java.util.Collection;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
+import java.util.Optional;
 import java.util.Properties;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static java.lang.Long.MAX_VALUE;
 import static java.time.Duration.ofMillis;
@@ -70,9 +77,10 @@ import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.maxRecord
 import static org.apache.kafka.streams.kstream.Suppressed.untilTimeLimit;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
 
 @RunWith(Parameterized.class)
-@Category({IntegrationTest.class})
+@Category(IntegrationTest.class)
 public class SuppressionDurabilityIntegrationTest {
     @ClassRule
     public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(
@@ -87,26 +95,13 @@ public class SuppressionDurabilityIntegrationTest {
     private static final int COMMIT_INTERVAL = 100;
     private final boolean eosEnabled;
 
-    public SuppressionDurabilityIntegrationTest(final boolean eosEnabled) {
-        this.eosEnabled = eosEnabled;
-    }
-
     @Parameters(name = "{index}: eosEnabled={0}")
     public static Collection<Object[]> parameters() {
-        return Arrays.asList(new Object[] {false}, new Object[] {true});
+        return asList(new Object[] {false}, new Object[] {true});
     }
 
-    private KTable<String, Long> buildCountsTable(final String input, final StreamsBuilder builder) {
-        return builder
-            .table(
-                input,
-                Consumed.with(STRING_SERDE, STRING_SERDE),
-                Materialized.<String, String, KeyValueStore<Bytes, byte[]>>with(STRING_SERDE, STRING_SERDE)
-                    .withCachingDisabled()
-                    .withLoggingDisabled()
-            )
-            .groupBy((k, v) -> new KeyValue<>(v, k), Grouped.with(STRING_SERDE, STRING_SERDE))
-            .count(Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("counts").withCachingDisabled());
+    public SuppressionDurabilityIntegrationTest(final boolean eosEnabled) {
+        this.eosEnabled = eosEnabled;
     }
 
     @Test
@@ -114,13 +109,21 @@ public class SuppressionDurabilityIntegrationTest {
         final String testId = "-shouldRecoverBufferAfterShutdown";
         final String appId = getClass().getSimpleName().toLowerCase(Locale.getDefault()) + testId;
         final String input = "input" + testId;
+        final String storeName = "counts";
         final String outputSuppressed = "output-suppressed" + testId;
         final String outputRaw = "output-raw" + testId;
 
-        cleanStateBeforeTest(CLUSTER, input, outputRaw, outputSuppressed);
+        // create multiple partitions as a trap, in case the buffer doesn't properly set the
+        // partition on the records, but instead relies on the default key partitioner
+        cleanStateBeforeTest(CLUSTER, 2, input, outputRaw, outputSuppressed);
 
         final StreamsBuilder builder = new StreamsBuilder();
-        final KTable<String, Long> valueCounts = buildCountsTable(input, builder);
+        final KTable<String, Long> valueCounts = builder
+            .stream(
+                input,
+                Consumed.with(STRING_SERDE, STRING_SERDE))
+            .groupByKey()
+            .count(Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as(storeName).withCachingDisabled());
 
         final KStream<String, Long> suppressedCounts = valueCounts
             .suppress(untilTimeLimit(ofMillis(MAX_VALUE), maxRecords(3L).emitEarlyWhenFull()))
@@ -129,11 +132,16 @@ public class SuppressionDurabilityIntegrationTest {
         final AtomicInteger eventCount = new AtomicInteger(0);
         suppressedCounts.foreach((key, value) -> eventCount.incrementAndGet());
 
+        // expect all post-suppress records to keep the right input topic
+        final MetadataValidator metadataValidator = new MetadataValidator(input);
+
         suppressedCounts
+            .transform(metadataValidator)
             .to(outputSuppressed, Produced.with(STRING_SERDE, Serdes.Long()));
 
         valueCounts
             .toStream()
+            .transform(metadataValidator)
             .to(outputRaw, Produced.with(STRING_SERDE, Serdes.Long()));
 
         final Properties streamsConfig = mkProperties(mkMap(
@@ -147,7 +155,9 @@ public class SuppressionDurabilityIntegrationTest {
         KafkaStreams driver = getStartedStreams(streamsConfig, builder, true);
         try {
             // start by putting some stuff in the buffer
-            produceSynchronously(
+            // note, we send all input records to partition 0
+            // to make sure that suppress doesn't erroneously send records to other partitions.
+            produceSynchronouslyToPartitionZero(
                 input,
                 asList(
                     new KeyValueTimestamp<>("k1", "v1", scaledTime(1L)),
@@ -157,16 +167,16 @@ public class SuppressionDurabilityIntegrationTest {
             );
             verifyOutput(
                 outputRaw,
-                asList(
-                    new KeyValueTimestamp<>("v1", 1L, scaledTime(1L)),
-                    new KeyValueTimestamp<>("v2", 1L, scaledTime(2L)),
-                    new KeyValueTimestamp<>("v3", 1L, scaledTime(3L))
-                )
+                new HashSet<>(asList(
+                    new KeyValueTimestamp<>("k1", 1L, scaledTime(1L)),
+                    new KeyValueTimestamp<>("k2", 1L, scaledTime(2L)),
+                    new KeyValueTimestamp<>("k3", 1L, scaledTime(3L))
+                ))
             );
             assertThat(eventCount.get(), is(0));
 
             // flush two of the first three events out.
-            produceSynchronously(
+            produceSynchronouslyToPartitionZero(
                 input,
                 asList(
                     new KeyValueTimestamp<>("k4", "v4", scaledTime(4L)),
@@ -175,17 +185,17 @@ public class SuppressionDurabilityIntegrationTest {
             );
             verifyOutput(
                 outputRaw,
-                asList(
-                    new KeyValueTimestamp<>("v4", 1L, scaledTime(4L)),
-                    new KeyValueTimestamp<>("v5", 1L, scaledTime(5L))
-                )
+                new HashSet<>(asList(
+                    new KeyValueTimestamp<>("k4", 1L, scaledTime(4L)),
+                    new KeyValueTimestamp<>("k5", 1L, scaledTime(5L))
+                ))
             );
             assertThat(eventCount.get(), is(2));
             verifyOutput(
                 outputSuppressed,
                 asList(
-                    new KeyValueTimestamp<>("v1", 1L, scaledTime(1L)),
-                    new KeyValueTimestamp<>("v2", 1L, scaledTime(2L))
+                    new KeyValueTimestamp<>("k1", 1L, scaledTime(1L)),
+                    new KeyValueTimestamp<>("k2", 1L, scaledTime(2L))
                 )
             );
 
@@ -199,7 +209,7 @@ public class SuppressionDurabilityIntegrationTest {
 
 
             // flush those recovered buffered events out.
-            produceSynchronously(
+            produceSynchronouslyToPartitionZero(
                 input,
                 asList(
                     new KeyValueTimestamp<>("k6", "v6", scaledTime(6L)),
@@ -209,29 +219,78 @@ public class SuppressionDurabilityIntegrationTest {
             );
             verifyOutput(
                 outputRaw,
-                asList(
-                    new KeyValueTimestamp<>("v6", 1L, scaledTime(6L)),
-                    new KeyValueTimestamp<>("v7", 1L, scaledTime(7L)),
-                    new KeyValueTimestamp<>("v8", 1L, scaledTime(8L))
-                )
+                new HashSet<>(asList(
+                    new KeyValueTimestamp<>("k6", 1L, scaledTime(6L)),
+                    new KeyValueTimestamp<>("k7", 1L, scaledTime(7L)),
+                    new KeyValueTimestamp<>("k8", 1L, scaledTime(8L))
+                ))
             );
-            assertThat(eventCount.get(), is(5));
+            assertThat("suppress has apparently produced some duplicates. There should only be 5 output events.",
+                       eventCount.get(), is(5));
+
             verifyOutput(
                 outputSuppressed,
                 asList(
-                    new KeyValueTimestamp<>("v3", 1L, scaledTime(3L)),
-                    new KeyValueTimestamp<>("v4", 1L, scaledTime(4L)),
-                    new KeyValueTimestamp<>("v5", 1L, scaledTime(5L))
+                    new KeyValueTimestamp<>("k3", 1L, scaledTime(3L)),
+                    new KeyValueTimestamp<>("k4", 1L, scaledTime(4L)),
+                    new KeyValueTimestamp<>("k5", 1L, scaledTime(5L))
                 )
             );
 
+            metadataValidator.raiseExceptionIfAny();
+
         } finally {
             driver.close();
             cleanStateAfterTest(CLUSTER, driver);
         }
     }
 
-    private void verifyOutput(final String topic, final List<KeyValueTimestamp<String, Long>> keyValueTimestamps) {
+    private static final class MetadataValidator implements TransformerSupplier<String, Long, KeyValue<String, Long>> {
+        private static final Logger LOG = LoggerFactory.getLogger(MetadataValidator.class);
+        private final AtomicReference<Throwable> firstException = new AtomicReference<>();
+        private final String topic;
+
+        MetadataValidator(final String topic) {
+            this.topic = topic;
+        }
+
+        @Override
+        public Transformer<String, Long, KeyValue<String, Long>> get() {
+            return new Transformer<String, Long, KeyValue<String, Long>>() {
+                private ProcessorContext context;
+
+                @Override
+                public void init(final ProcessorContext context) {
+                    this.context = context;
+                }
+
+                @Override
+                public KeyValue<String, Long> transform(final String key, final Long value) {
+                    try {
+                        assertThat(context.topic(), equalTo(topic));
+                    } catch (final Throwable e) {
+                        firstException.compareAndSet(null, e);
+                        LOG.error("Validation Failed", e);
+                    }
+                    return new KeyValue<>(key, value);
+                }
+
+                @Override
+                public void close() {
+
+                }
+            };
+        }
+
+        void raiseExceptionIfAny() {
+            final Throwable exception = firstException.get();
+            if (exception != null) {
+                throw new AssertionError("Got an exception during run", exception);
+            }
+        }
+    }
+
+    private static void verifyOutput(final String topic, final List<KeyValueTimestamp<String, Long>> keyValueTimestamps) {
         final Properties properties = mkProperties(
             mkMap(
                 mkEntry(ConsumerConfig.GROUP_ID_CONFIG, "test-group"),
@@ -241,24 +300,35 @@ public class SuppressionDurabilityIntegrationTest {
             )
         );
         IntegrationTestUtils.verifyKeyValueTimestamps(properties, topic, keyValueTimestamps);
+    }
 
+    private static void verifyOutput(final String topic, final Set<KeyValueTimestamp<String, Long>> keyValueTimestamps) {
+        final Properties properties = mkProperties(
+            mkMap(
+                mkEntry(ConsumerConfig.GROUP_ID_CONFIG, "test-group"),
+                mkEntry(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()),
+                mkEntry(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ((Deserializer<String>) STRING_DESERIALIZER).getClass().getName()),
+                mkEntry(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ((Deserializer<Long>) LONG_DESERIALIZER).getClass().getName())
+            )
+        );
+        IntegrationTestUtils.verifyKeyValueTimestamps(properties, topic, keyValueTimestamps);
     }
 
     /**
      * scaling to ensure that there are commits in between the various test events,
      * just to exercise that everything works properly in the presence of commits.
      */
-    private long scaledTime(final long unscaledTime) {
+    private static long scaledTime(final long unscaledTime) {
         return COMMIT_INTERVAL * 2 * unscaledTime;
     }
 
-    private void produceSynchronously(final String topic, final List<KeyValueTimestamp<String, String>> toProduce) {
+    private static void produceSynchronouslyToPartitionZero(final String topic, final List<KeyValueTimestamp<String, String>> toProduce) {
         final Properties producerConfig = mkProperties(mkMap(
             mkEntry(ProducerConfig.CLIENT_ID_CONFIG, "anything"),
             mkEntry(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ((Serializer<String>) STRING_SERIALIZER).getClass().getName()),
             mkEntry(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ((Serializer<String>) STRING_SERIALIZER).getClass().getName()),
             mkEntry(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers())
         ));
-        IntegrationTestUtils.produceSynchronously(producerConfig, false, topic, toProduce);
+        IntegrationTestUtils.produceSynchronously(producerConfig, false, topic, Optional.of(0), toProduce);
     }
 }
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionIntegrationTest.java
index ee32a1d..da91b91 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionIntegrationTest.java
@@ -51,6 +51,7 @@ import org.junit.experimental.categories.Category;
 import java.time.Duration;
 import java.util.List;
 import java.util.Locale;
+import java.util.Optional;
 import java.util.Properties;
 
 import static java.lang.Long.MAX_VALUE;
@@ -518,7 +519,7 @@ public class SuppressionIntegrationTest {
             mkEntry(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ((Serializer<String>) STRING_SERIALIZER).getClass().getName()),
             mkEntry(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers())
         ));
-        IntegrationTestUtils.produceSynchronously(producerConfig, false, topic, toProduce);
+        IntegrationTestUtils.produceSynchronously(producerConfig, false, topic, Optional.empty(), toProduce);
     }
 
     private void verifyErrorShutdown(final KafkaStreams driver) throws InterruptedException {
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
index 8bca79f..e6ba7b8 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
@@ -58,11 +58,16 @@ import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.Properties;
+import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.stream.Collectors;
 
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+
 /**
  * Utility functions to make integration testing more convenient.
  */
@@ -118,10 +123,14 @@ public class IntegrationTestUtils {
     }
 
     public static void cleanStateBeforeTest(final EmbeddedKafkaCluster cluster, final String... topics) {
+        cleanStateBeforeTest(cluster, 1, topics);
+    }
+
+    public static void cleanStateBeforeTest(final EmbeddedKafkaCluster cluster, final int partitionCount, final String... topics) {
         try {
             cluster.deleteAllTopicsAndWait(DEFAULT_TIMEOUT);
             for (final String topic : topics) {
-                cluster.createTopic(topic, 1, 1);
+                cluster.createTopic(topic, partitionCount, 1);
             }
         } catch (final InterruptedException e) {
             throw new RuntimeException(e);
@@ -147,13 +156,13 @@ public class IntegrationTestUtils {
     public static <K, V> void produceKeyValuesSynchronously(
         final String topic, final Collection<KeyValue<K, V>> records, final Properties producerConfig, final Time time)
         throws ExecutionException, InterruptedException {
-        IntegrationTestUtils.produceKeyValuesSynchronously(topic, records, producerConfig, time, false);
+        produceKeyValuesSynchronously(topic, records, producerConfig, time, false);
     }
 
     public static <K, V> void produceKeyValuesSynchronously(
         final String topic, final Collection<KeyValue<K, V>> records, final Properties producerConfig, final Headers headers, final Time time)
         throws ExecutionException, InterruptedException {
-        IntegrationTestUtils.produceKeyValuesSynchronously(topic, records, producerConfig, headers, time, false);
+        produceKeyValuesSynchronously(topic, records, producerConfig, headers, time, false);
     }
 
     /**
@@ -167,7 +176,7 @@ public class IntegrationTestUtils {
     public static <K, V> void produceKeyValuesSynchronously(
         final String topic, final Collection<KeyValue<K, V>> records, final Properties producerConfig, final Time time, final boolean enableTransactions)
         throws ExecutionException, InterruptedException {
-        IntegrationTestUtils.produceKeyValuesSynchronously(topic, records, producerConfig, null, time, enableTransactions);
+        produceKeyValuesSynchronously(topic, records, producerConfig, null, time, enableTransactions);
     }
 
     public static <K, V> void produceKeyValuesSynchronously(final String topic,
@@ -193,7 +202,7 @@ public class IntegrationTestUtils {
                                                                          final Properties producerConfig,
                                                                          final Long timestamp)
         throws ExecutionException, InterruptedException {
-        IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp(topic, records, producerConfig, timestamp, false);
+        produceKeyValuesSynchronouslyWithTimestamp(topic, records, producerConfig, timestamp, false);
     }
 
     public static <K, V> void produceKeyValuesSynchronouslyWithTimestamp(final String topic,
@@ -202,7 +211,7 @@ public class IntegrationTestUtils {
                                                                          final Long timestamp,
                                                                          final boolean enableTransactions)
         throws ExecutionException, InterruptedException {
-        IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp(topic, records, producerConfig, null, timestamp, enableTransactions);
+        produceKeyValuesSynchronouslyWithTimestamp(topic, records, producerConfig, null, timestamp, enableTransactions);
     }
 
     public static <K, V> void produceKeyValuesSynchronouslyWithTimestamp(final String topic,
@@ -230,20 +239,19 @@ public class IntegrationTestUtils {
     }
 
     public static <V, K> void produceSynchronously(final Properties producerConfig,
-                                                    final boolean eos,
-                                                    final String topic,
-                                                    final List<KeyValueTimestamp<K, V>> toProduce) {
+                                                   final boolean eos,
+                                                   final String topic,
+                                                   final Optional<Integer> partition,
+                                                   final List<KeyValueTimestamp<K, V>> toProduce) {
         try (final Producer<K, V> producer = new KafkaProducer<>(producerConfig)) {
-            // TODO: test EOS
-            //noinspection ConstantConditions
-            if (false) {
+            if (eos) {
                 producer.initTransactions();
                 producer.beginTransaction();
             }
             final LinkedList<Future<RecordMetadata>> futures = new LinkedList<>();
             for (final KeyValueTimestamp<K, V> record : toProduce) {
                 final Future<RecordMetadata> f = producer.send(
-                    new ProducerRecord<>(topic, null, record.timestamp(), record.key(), record.value(), null)
+                    new ProducerRecord<>(topic, partition.orElse(null), record.timestamp(), record.key(), record.value(), null)
                 );
                 futures.add(f);
             }
@@ -286,7 +294,7 @@ public class IntegrationTestUtils {
                                                       final Properties producerConfig,
                                                       final Time time)
         throws ExecutionException, InterruptedException {
-        IntegrationTestUtils.produceValuesSynchronously(topic, records, producerConfig, time, false);
+        produceValuesSynchronously(topic, records, producerConfig, time, false);
     }
 
     public static <V> void produceValuesSynchronously(final String topic,
@@ -540,7 +548,7 @@ public class IntegrationTestUtils {
 
         final List<ConsumerRecord<String, Long>> results;
         try {
-            results = IntegrationTestUtils.waitUntilMinRecordsReceived(consumerConfig, topic, expected.size());
+            results = waitUntilMinRecordsReceived(consumerConfig, topic, expected.size());
         } catch (final InterruptedException e) {
             throw new RuntimeException(e);
         }
@@ -559,6 +567,28 @@ public class IntegrationTestUtils {
         }
     }
 
+    public static void verifyKeyValueTimestamps(final Properties consumerConfig,
+                                                final String topic,
+                                                final Set<KeyValueTimestamp<String, Long>> expected) {
+        final List<ConsumerRecord<String, Long>> results;
+        try {
+            results = waitUntilMinRecordsReceived(consumerConfig, topic, expected.size());
+        } catch (final InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+
+        if (results.size() != expected.size()) {
+            throw new AssertionError(printRecords(results) + " != " + expected);
+        }
+
+        final Set<KeyValueTimestamp<String, Long>> actual =
+            results.stream()
+                   .map(result -> new KeyValueTimestamp<>(result.key(), result.value(), result.timestamp()))
+                   .collect(Collectors.toSet());
+
+        assertThat(actual, equalTo(expected));
+    }
+
     private static <K, V> void compareKeyValueTimestamp(final ConsumerRecord<K, V> record,
                                                         final K expectedKey,
                                                         final V expectedValue,
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorTest.java
index 43c3f40..bb1bc0f 100644
--- a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorTest.java
@@ -56,7 +56,6 @@ import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.fail;
 
-@SuppressWarnings("PointlessArithmeticExpression")
 public class KTableSuppressProcessorTest {
     private static final long ARBITRARY_LONG = 5L;
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBufferTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBufferTest.java
new file mode 100644
index 0000000..2953953
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBufferTest.java
@@ -0,0 +1,604 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.common.header.Header;
+import org.apache.kafka.common.header.internals.RecordHeader;
+import org.apache.kafka.common.header.internals.RecordHeaders;
+import org.apache.kafka.common.record.TimestampType;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
+import org.apache.kafka.streams.processor.internals.RecordBatchingStateRestoreCallback;
+import org.apache.kafka.test.MockInternalProcessorContext;
+import org.apache.kafka.test.MockInternalProcessorContext.MockRecordCollector;
+import org.apache.kafka.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Properties;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.fail;
+
+@RunWith(Parameterized.class)
+public class TimeOrderedKeyValueBufferTest<B extends TimeOrderedKeyValueBuffer> {
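+    // The "v" record header carries the changelog format version; value 1 denotes the newer format
+    // that preserves the buffered record's original context on restore (see shouldRestoreNewFormat below).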
+    private static final RecordHeaders V_1_CHANGELOG_HEADERS =
+        new RecordHeaders(new Header[] {new RecordHeader("v", new byte[] {(byte) 1})});
+
+    private static final String APP_ID = "test-app";
+    private final Function<String, B> bufferSupplier;
+    private final String testName;
+
+    // As we add more buffer implementations/configurations, we can add them here
+    @Parameterized.Parameters(name = "{index}: test={0}")
+    public static Collection<Object[]> parameters() {
+        return singletonList(
+            new Object[] {
+                "in-memory buffer",
+                (Function<String, InMemoryTimeOrderedKeyValueBuffer>) name ->
+                    (InMemoryTimeOrderedKeyValueBuffer) new InMemoryTimeOrderedKeyValueBuffer
+                        .Builder(name)
+                        .build()
+            }
+        );
+    }
+
+    public TimeOrderedKeyValueBufferTest(final String testName, final Function<String, B> bufferSupplier) {
+        this.testName = testName + "_" + new Random().nextInt(Integer.MAX_VALUE);
+        this.bufferSupplier = bufferSupplier;
+    }
+
+    private static MockInternalProcessorContext makeContext() {
+        final Properties properties = new Properties();
+        properties.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, APP_ID);
+        properties.setProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "");
+
+        final TaskId taskId = new TaskId(0, 0);
+
+        final MockInternalProcessorContext context = new MockInternalProcessorContext(properties, taskId, TestUtils.tempDirectory());
+        context.setRecordCollector(new MockRecordCollector());
+
+        return context;
+    }
+
+
+    private static void cleanup(final MockInternalProcessorContext context, final TimeOrderedKeyValueBuffer buffer) {
+        try {
+            buffer.close();
+            Utils.delete(context.stateDir());
+        } catch (final IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Test
+    public void shouldInit() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+        cleanup(context, buffer);
+    }
+
+    @Test
+    public void shouldAcceptData() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+        putRecord(buffer, context, "2p93nf", 0, "asdf");
+        cleanup(context, buffer);
+    }
+
+    @Test
+    public void shouldRejectNullValues() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+        try {
+            buffer.put(0, getBytes("asdf"), new ContextualRecord(
+                null,
+                new ProcessorRecordContext(0, 0, 0, "topic")
+            ));
+            fail("expected an exception");
+        } catch (final NullPointerException expected) {
+            // expected
+        }
+        cleanup(context, buffer);
+    }
+
+    private static ContextualRecord getRecord(final String value) {
+        return getRecord(value, 0L);
+    }
+
+    private static ContextualRecord getRecord(final String value, final long timestamp) {
+        return new ContextualRecord(
+            value.getBytes(UTF_8),
+            new ProcessorRecordContext(timestamp, 0, 0, "topic")
+        );
+    }
+
+    private static Bytes getBytes(final String key) {
+        return Bytes.wrap(key.getBytes(UTF_8));
+    }
+
+    @Test
+    public void shouldRemoveData() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+        putRecord(buffer, context, "qwer", 0, "asdf");
+        assertThat(buffer.numRecords(), is(1));
+        buffer.evictWhile(() -> true, kv -> { });
+        assertThat(buffer.numRecords(), is(0));
+        cleanup(context, buffer);
+    }
+
+    @Test
+    public void shouldRespectEvictionPredicate() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+        final Bytes firstKey = getBytes("asdf");
+        final ContextualRecord firstRecord = getRecord("eyt");
+        putRecord(0, buffer, context, firstRecord, firstKey);
+        putRecord(buffer, context, "rtg", 1, "zxcv");
+        assertThat(buffer.numRecords(), is(2));
+        final List<KeyValue<Bytes, ContextualRecord>> evicted = new LinkedList<>();
+        buffer.evictWhile(() -> buffer.numRecords() > 1, evicted::add);
+        assertThat(buffer.numRecords(), is(1));
+        assertThat(evicted, is(singletonList(new KeyValue<>(firstKey, firstRecord))));
+        cleanup(context, buffer);
+    }
+
+    @Test
+    public void shouldTrackCount() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+        putRecord(buffer, context, "oin", 0, "asdf");
+        assertThat(buffer.numRecords(), is(1));
+        putRecord(buffer, context, "wekjn", 1, "asdf");
+        assertThat(buffer.numRecords(), is(1));
+        putRecord(buffer, context, "24inf", 0, "zxcv");
+        assertThat(buffer.numRecords(), is(2));
+        cleanup(context, buffer);
+    }
+
+    @Test
+    public void shouldTrackSize() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
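+        // The expected sizes below presumably account for the key, the value, the buffer timestamp,
+        // and the serialized record context; putting an existing key replaces its entry, which is
+        // why the size can shrink between assertions.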
+        putRecord(buffer, context, "23roni", 0, "asdf");
+        assertThat(buffer.bufferSize(), is(43L));
+        putRecord(buffer, context, "3l", 1, "asdf");
+        assertThat(buffer.bufferSize(), is(39L));
+        putRecord(buffer, context, "qfowin", 0, "zxcv");
+        assertThat(buffer.bufferSize(), is(82L));
+        cleanup(context, buffer);
+    }
+
+    @Test
+    public void shouldTrackMinTimestamp() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+        putRecord(buffer, context, "2093j", 1, "asdf");
+        assertThat(buffer.minTimestamp(), is(1L));
+        putRecord(buffer, context, "3gon4i", 0, "zxcv");
+        assertThat(buffer.minTimestamp(), is(0L));
+        cleanup(context, buffer);
+    }
+
+    private static void putRecord(final TimeOrderedKeyValueBuffer buffer,
+                                  final MockInternalProcessorContext context,
+                                  final String value,
+                                  final int time,
+                                  final String key) {
+        putRecord(time, buffer, context, getRecord(value), getBytes(key));
+    }
+
+    private static void putRecord(final int time,
+                                  final TimeOrderedKeyValueBuffer buffer,
+                                  final MockInternalProcessorContext context,
+                                  final ContextualRecord firstRecord,
+                                  final Bytes firstKey) {
+        context.setRecordContext(firstRecord.recordContext());
+        buffer.put(time, firstKey, firstRecord);
+    }
+
+    @Test
+    public void shouldEvictOldestAndUpdateSizeAndCountAndMinTimestamp() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+
+        putRecord(buffer, context, "o23i4", 1, "zxcv");
+        assertThat(buffer.numRecords(), is(1));
+        assertThat(buffer.bufferSize(), is(42L));
+        assertThat(buffer.minTimestamp(), is(1L));
+
+        putRecord(buffer, context, "3ng", 0, "asdf");
+        assertThat(buffer.numRecords(), is(2));
+        assertThat(buffer.bufferSize(), is(82L));
+        assertThat(buffer.minTimestamp(), is(0L));
+
+        final AtomicInteger callbackCount = new AtomicInteger(0);
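+        // Eviction runs in buffer-time order, so the time-0 record ("asdf") reaches the callback first;
+        // note that the stats observed inside the callback reflect evictions completed so far, not the
+        // record currently in hand.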
+        buffer.evictWhile(() -> true, kv -> {
+            switch (callbackCount.incrementAndGet()) {
+                case 1: {
+                    assertThat(new String(kv.key.get(), UTF_8), is("asdf"));
+                    assertThat(buffer.numRecords(), is(2));
+                    assertThat(buffer.bufferSize(), is(82L));
+                    assertThat(buffer.minTimestamp(), is(0L));
+                    break;
+                }
+                case 2: {
+                    assertThat(new String(kv.key.get(), UTF_8), is("zxcv"));
+                    assertThat(buffer.numRecords(), is(1));
+                    assertThat(buffer.bufferSize(), is(42L));
+                    assertThat(buffer.minTimestamp(), is(1L));
+                    break;
+                }
+                default: {
+                    fail("too many invocations");
+                    break;
+                }
+            }
+        });
+        assertThat(callbackCount.get(), is(2));
+        assertThat(buffer.numRecords(), is(0));
+        assertThat(buffer.bufferSize(), is(0L));
+        assertThat(buffer.minTimestamp(), is(Long.MAX_VALUE));
+        cleanup(context, buffer);
+    }
+
+    @Test
+    public void shouldFlush() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+        putRecord(2, buffer, context, getRecord("2093j", 0L), getBytes("asdf"));
+        putRecord(1, buffer, context, getRecord("3gon4i", 1L), getBytes("zxcv"));
+        putRecord(0, buffer, context, getRecord("deadbeef", 2L), getBytes("deleteme"));
+
+        // evict "deleteme" so that the flush below writes a tombstone for it to the changelog
+        buffer.evictWhile(() -> buffer.minTimestamp() < 1, kv -> { });
+
+        // flush everything to the changelog
+        buffer.flush();
+
+        // the buffer should serialize the buffer time and the value as a byte[],
+        // which we can't compare for equality using ProducerRecord.
+        // As a workaround, deserialize them into a KeyValue, just for ease of testing.
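+        // (As read back below, the serialized value is an 8-byte buffer time followed by the
+        // ContextualRecord's own serialized form; the evicted "deleteme" key shows up as a
+        // tombstone with a null value.)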
+
+        final List<ProducerRecord<String, KeyValue<Long, ContextualRecord>>> collected =
+            ((MockRecordCollector) context.recordCollector())
+                .collected()
+                .stream()
+                .map(pr -> {
+                    final KeyValue<Long, ContextualRecord> niceValue;
+                    if (pr.value() == null) {
+                        niceValue = null;
+                    } else {
+                        final byte[] timestampAndValue = pr.value();
+                        final ByteBuffer wrap = ByteBuffer.wrap(timestampAndValue);
+                        final long timestamp = wrap.getLong();
+                        final ContextualRecord contextualRecord = ContextualRecord.deserialize(wrap);
+                        niceValue = new KeyValue<>(timestamp, contextualRecord);
+                    }
+
+                    return new ProducerRecord<>(pr.topic(),
+                                                pr.partition(),
+                                                pr.timestamp(),
+                                                new String(pr.key(), UTF_8),
+                                                niceValue,
+                                                pr.headers());
+                })
+                .collect(Collectors.toList());
+
+        assertThat(collected, is(asList(
+            new ProducerRecord<>(APP_ID + "-" + testName + "-changelog",
+                                 0,   // Producer will assign
+                                 null,
+                                 "deleteme",
+                                 null,
+                                 new RecordHeaders()
+            ),
+            new ProducerRecord<>(APP_ID + "-" + testName + "-changelog",
+                                 0,
+                                 null,
+                                 "zxcv",
+                                 new KeyValue<>(1L, getRecord("3gon4i", 1)),
+                                 V_1_CHANGELOG_HEADERS
+            ),
+            new ProducerRecord<>(APP_ID + "-" + testName + "-changelog",
+                                 0,
+                                 null,
+                                 "asdf",
+                                 new KeyValue<>(2L, getRecord("2093j", 0)),
+                                 V_1_CHANGELOG_HEADERS
+            )
+        )));
+
+        cleanup(context, buffer);
+    }
+
+
+    @Test
+    public void shouldRestoreOldFormat() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+
+        final RecordBatchingStateRestoreCallback stateRestoreCallback =
+            (RecordBatchingStateRestoreCallback) context.stateRestoreCallback(testName);
+
+        context.setRecordContext(new ProcessorRecordContext(0, 0, 0, ""));
+
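+        // Pre-fix (v0) changelog values are just an 8-byte buffer time followed by the raw value bytes,
+        // with no serialized record context, so on restore the context can only be reconstructed from
+        // the changelog record itself (see the notes on the assertions below).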
+        stateRestoreCallback.restoreBatch(asList(
+            new ConsumerRecord<>("changelog-topic",
+                                 0,
+                                 0,
+                                 0,
+                                 TimestampType.CREATE_TIME,
+                                 -1,
+                                 -1,
+                                 -1,
+                                 "todelete".getBytes(UTF_8),
+                                 ByteBuffer.allocate(Long.BYTES + 6).putLong(0L).put("doomed".getBytes(UTF_8)).array()),
+            new ConsumerRecord<>("changelog-topic",
+                                 0,
+                                 1,
+                                 1,
+                                 TimestampType.CREATE_TIME,
+                                 -1,
+                                 -1,
+                                 -1,
+                                 "asdf".getBytes(UTF_8),
+                                 ByteBuffer.allocate(Long.BYTES + 4).putLong(2L).put("qwer".getBytes(UTF_8)).array()),
+            new ConsumerRecord<>("changelog-topic",
+                                 0,
+                                 2,
+                                 2,
+                                 TimestampType.CREATE_TIME,
+                                 -1,
+                                 -1,
+                                 -1,
+                                 "zxcv".getBytes(UTF_8),
+                                 ByteBuffer.allocate(Long.BYTES + 5).putLong(1L).put("3o4im".getBytes(UTF_8)).array())
+        ));
+
+        assertThat(buffer.numRecords(), is(3));
+        assertThat(buffer.minTimestamp(), is(0L));
+        assertThat(buffer.bufferSize(), is(160L));
+
+        stateRestoreCallback.restoreBatch(singletonList(
+            new ConsumerRecord<>("changelog-topic",
+                                 0,
+                                 3,
+                                 3,
+                                 TimestampType.CREATE_TIME,
+                                 -1,
+                                 -1,
+                                 -1,
+                                 "todelete".getBytes(UTF_8),
+                                 null)
+        ));
+
+        assertThat(buffer.numRecords(), is(2));
+        assertThat(buffer.minTimestamp(), is(1L));
+        assertThat(buffer.bufferSize(), is(103L));
+
+        // evict the buffer contents into a list, in buffer order, so we can make assertions about them.
+
+        final List<KeyValue<Bytes, ContextualRecord>> evicted = new LinkedList<>();
+        buffer.evictWhile(() -> true, evicted::add);
+
+        // Several things to note:
+        // * The buffered records are ordered according to their buffer time (serialized in the value of the changelog)
+        // * The record timestamps are properly restored, and not conflated with the record's buffer time.
+        // * The keys and values are properly restored
+        // * The record topic is set to the changelog topic. This was an oversight in the original implementation,
+        //   which is fixed in changelog format v1. But upgraded applications still need to be able to handle the
+        //   original format.
+
+        assertThat(evicted, is(asList(
+            new KeyValue<>(
+                getBytes("zxcv"),
+                new ContextualRecord("3o4im".getBytes(UTF_8),
+                                     new ProcessorRecordContext(2,
+                                                                2,
+                                                                0,
+                                                                "changelog-topic",
+                                                                new RecordHeaders()))),
+            new KeyValue<>(
+                getBytes("asdf"),
+                new ContextualRecord("qwer".getBytes(UTF_8),
+                                     new ProcessorRecordContext(1,
+                                                                1,
+                                                                0,
+                                                                "changelog-topic",
+                                                                new RecordHeaders())))
+        )));
+
+        cleanup(context, buffer);
+    }
+
+    @Test
+    public void shouldRestoreNewFormat() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+
+        final RecordBatchingStateRestoreCallback stateRestoreCallback =
+            (RecordBatchingStateRestoreCallback) context.stateRestoreCallback(testName);
+
+        context.setRecordContext(new ProcessorRecordContext(0, 0, 0, ""));
+
+        final RecordHeaders v1FlagHeaders = new RecordHeaders(new Header[] {new RecordHeader("v", new byte[] {(byte) 1})});
+
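+        // v1 changelog values prepend the 8-byte buffer time to a fully serialized ContextualRecord,
+        // so the original topic, offset, and timestamp survive the round trip (asserted below).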
+        final byte[] todeleteValue = getRecord("doomed", 0).serialize();
+        final byte[] asdfValue = getRecord("qwer", 1).serialize();
+        final byte[] zxcvValue = getRecord("3o4im", 2).serialize();
+        stateRestoreCallback.restoreBatch(asList(
+            new ConsumerRecord<>("changelog-topic",
+                                 0,
+                                 0,
+                                 999,
+                                 TimestampType.CREATE_TIME,
+                                 -1L,
+                                 -1,
+                                 -1,
+                                 "todelete".getBytes(UTF_8),
+                                 ByteBuffer.allocate(Long.BYTES + todeleteValue.length).putLong(0L).put(todeleteValue).array(),
+                                 v1FlagHeaders),
+            new ConsumerRecord<>("changelog-topic",
+                                 0,
+                                 1,
+                                 9999,
+                                 TimestampType.CREATE_TIME,
+                                 -1L,
+                                 -1,
+                                 -1,
+                                 "asdf".getBytes(UTF_8),
+                                 ByteBuffer.allocate(Long.BYTES + asdfValue.length).putLong(2L).put(asdfValue).array(),
+                                 v1FlagHeaders),
+            new ConsumerRecord<>("changelog-topic",
+                                 0,
+                                 2,
+                                 99,
+                                 TimestampType.CREATE_TIME,
+                                 -1L,
+                                 -1,
+                                 -1,
+                                 "zxcv".getBytes(UTF_8),
+                                 ByteBuffer.allocate(Long.BYTES + zxcvValue.length).putLong(1L).put(zxcvValue).array(),
+                                 v1FlagHeaders)
+        ));
+
+        assertThat(buffer.numRecords(), is(3));
+        assertThat(buffer.minTimestamp(), is(0L));
+        assertThat(buffer.bufferSize(), is(130L));
+
+        stateRestoreCallback.restoreBatch(singletonList(
+            new ConsumerRecord<>("changelog-topic",
+                                 0,
+                                 3,
+                                 3,
+                                 TimestampType.CREATE_TIME,
+                                 -1L,
+                                 -1,
+                                 -1,
+                                 "todelete".getBytes(UTF_8),
+                                 null)
+        ));
+
+        assertThat(buffer.numRecords(), is(2));
+        assertThat(buffer.minTimestamp(), is(1L));
+        assertThat(buffer.bufferSize(), is(83L));
+
+        // evict the buffer contents into a list, in buffer order, so we can make assertions about them.
+
+        final List<KeyValue<Bytes, ContextualRecord>> evicted = new LinkedList<>();
+        buffer.evictWhile(() -> true, evicted::add);
+
+        // Several things to note:
+        // * The buffered records are ordered according to their buffer time (serialized in the value of the changelog)
+        // * The record timestamps are properly restored, and not conflated with the record's buffer time.
+        // * The keys and values are properly restored
+        // * The record topic is set to the original input topic, *not* the changelog topic
+        // * The record offset preserves the original input record's offset, *not* the offset of the changelog record
+
+        assertThat(evicted, is(asList(
+            new KeyValue<>(
+                getBytes("zxcv"),
+                new ContextualRecord("3o4im".getBytes(UTF_8),
+                                     new ProcessorRecordContext(2,
+                                                                0,
+                                                                0,
+                                                                "topic",
+                                                                null))),
+            new KeyValue<>(
+                getBytes("asdf"),
+                new ContextualRecord("qwer".getBytes(UTF_8),
+                                     new ProcessorRecordContext(1,
+                                                                0,
+                                                                0,
+                                                                "topic",
+                                                                null)))
+        )));
+
+        cleanup(context, buffer);
+    }
+
+    @Test
+    public void shouldNotRestoreUnrecognizedVersionRecord() {
+        final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init(context, buffer);
+
+        final RecordBatchingStateRestoreCallback stateRestoreCallback =
+            (RecordBatchingStateRestoreCallback) context.stateRestoreCallback(testName);
+
+        context.setRecordContext(new ProcessorRecordContext(0, 0, 0, ""));
+
+        final RecordHeaders unknownFlagHeaders = new RecordHeaders(new Header[] {new RecordHeader("v", new byte[] {(byte) -1})});
+
+        final byte[] todeleteValue = getRecord("doomed", 0).serialize();
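+        // A record flagged with an unrecognized format version should fail the restore outright
+        // rather than being silently misinterpreted.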
+        try {
+            stateRestoreCallback.restoreBatch(singletonList(
+                new ConsumerRecord<>("changelog-topic",
+                                     0,
+                                     0,
+                                     999,
+                                     TimestampType.CREATE_TIME,
+                                     -1L,
+                                     -1,
+                                     -1,
+                                     "todelete".getBytes(UTF_8),
+                                     ByteBuffer.allocate(Long.BYTES + todeleteValue.length).putLong(0L).put(todeleteValue).array(),
+                                     unknownFlagHeaders)
+            ));
+            fail("expected an exception");
+        } catch (final IllegalArgumentException expected) {
+            // nothing to do.
+        } finally {
+            cleanup(context, buffer);
+        }
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java
index 62a8491..b25aa9e 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java
@@ -16,15 +16,100 @@
  */
 package org.apache.kafka.test;
 
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.streams.processor.MockProcessorContext;
+import org.apache.kafka.streams.processor.StateRestoreCallback;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StreamPartitioner;
+import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
 import org.apache.kafka.streams.processor.internals.ProcessorNode;
 import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
+import org.apache.kafka.streams.processor.internals.RecordCollector;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.internals.ThreadCache;
 
+import java.io.File;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+
+import static java.util.Collections.unmodifiableList;
+
 public class MockInternalProcessorContext extends MockProcessorContext implements InternalProcessorContext {
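+    // A minimal RecordCollector that just captures the serialized records handed to send(), so tests
+    // can assert on what would have been written to a changelog topic.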
+    public static final class MockRecordCollector implements RecordCollector {
+        private final List<ProducerRecord<byte[], byte[]>> collected = new LinkedList<>();
+
+        @Override
+        public <K, V> void send(final String topic,
+                                final K key,
+                                final V value,
+                                final Headers headers,
+                                final Integer partition,
+                                final Long timestamp,
+                                final Serializer<K> keySerializer,
+                                final Serializer<V> valueSerializer) {
+            collected.add(new ProducerRecord<>(topic,
+                                               partition,
+                                               timestamp,
+                                               keySerializer.serialize(topic, key),
+                                               valueSerializer.serialize(topic, value),
+                                               headers));
+        }
+
+        @Override
+        public <K, V> void send(final String topic,
+                                final K key,
+                                final V value,
+                                final Headers headers,
+                                final Long timestamp,
+                                final Serializer<K> keySerializer,
+                                final Serializer<V> valueSerializer,
+                                final StreamPartitioner<? super K, ? super V> partitioner) {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public void init(final Producer<byte[], byte[]> producer) {
+
+        }
+
+        @Override
+        public void flush() {
+
+        }
+
+        @Override
+        public void close() {
+
+        }
+
+        @Override
+        public Map<TopicPartition, Long> offsets() {
+            return null;
+        }
+
+        public List<ProducerRecord<byte[], byte[]>> collected() {
+            return unmodifiableList(collected);
+        }
+    }
+
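+    // Restore callbacks captured via register(), keyed by store name, so tests can drive changelog
+    // restoration manually through stateRestoreCallback(storeName).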
+    private final Map<String, StateRestoreCallback> restoreCallbacks = new LinkedHashMap<>();
     private ProcessorNode currentNode;
+    private RecordCollector recordCollector;
+
+    public MockInternalProcessorContext() {
+    }
+
+    public MockInternalProcessorContext(final Properties config, final TaskId taskId, final File stateDir) {
+        super(config, taskId, stateDir);
+    }
 
     @Override
     public StreamsMetricsImpl metrics() {
@@ -68,7 +153,24 @@ public class MockInternalProcessorContext extends MockProcessorContext implement
     }
 
     @Override
-    public void uninitialize() {
+    public void uninitialize() {}
+
+    @Override
+    public RecordCollector recordCollector() {
+        return recordCollector;
+    }
+
+    public void setRecordCollector(final RecordCollector recordCollector) {
+        this.recordCollector = recordCollector;
+    }
+
+    @Override
+    public void register(final StateStore store, final StateRestoreCallback stateRestoreCallback) {
+        restoreCallbacks.put(store.name(), stateRestoreCallback);
+        super.register(store, stateRestoreCallback);
+    }
 
+    public StateRestoreCallback stateRestoreCallback(final String storeName) {
+        return restoreCallbacks.get(storeName);
     }
 }
\ No newline at end of file