You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2019/05/11 18:20:27 UTC

[kafka] branch trunk updated: [MINOR] Consolidate in-memory/rocksdb unit tests for window & session store (#6677)

This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 5236a3e  [MINOR] Consolidate in-memory/rocksdb unit tests for window & session store (#6677)
5236a3e is described below

commit 5236a3e5ec35b02ad3b40e904997b91b82e9f36d
Author: A. Sophie Blee-Goldman <so...@confluent.io>
AuthorDate: Sat May 11 11:20:15 2019 -0700

    [MINOR] Consolidate in-memory/rocksdb unit tests for window & session store (#6677)
    
    Consolidated the unit tests by having {RocksDB/InMemory}{Window/Session}StoreTest extend {Window/Session}BytesStoreTest. Besides some implementation-specific tests (e.g., those involving segment maintenance), all tests were moved to the abstract XXXBytesStoreTest class. The test coverage is now a superset of the original test coverage for each store type.
    
    The only difference made to existing tests (besides moving them) was to switch from list-based to set-based equality comparison, in order to reflect that the stores make no guarantees regarding the ordering of records returned from a range fetch.
    
    There are some implementation-specific tests that were left in the corresponding test class. The RocksDBWindowStoreTest, for example, had several tests pertaining to segments and/or the underlying filesystem. Another key difference is that the in-memory versions should delete expired records aggressively, while the RocksDB versions should only remove entirely expired segments.
    
    
    Reviewers: John Roesler <jo...@confluent.io>, Guozhang Wang <wa...@gmail.com>
---
 .../state/internals/InMemorySessionStoreTest.java  |  460 +-----
 .../state/internals/InMemoryWindowStoreTest.java   |  597 ++------
 .../state/internals/RocksDBSessionStoreTest.java   |  331 +----
 .../state/internals/RocksDBWindowStoreTest.java    | 1481 +++++---------------
 ...onStoreTest.java => SessionBytesStoreTest.java} |  285 ++--
 .../state/internals/WindowBytesStoreTest.java      | 1104 +++++++++++++++
 6 files changed, 1728 insertions(+), 2530 deletions(-)

diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java
index bbe8d21..6641bcc 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java
@@ -18,323 +18,46 @@ package org.apache.kafka.streams.state.internals;
 
 import static java.time.Duration.ofMillis;
 
-import static org.apache.kafka.common.utils.Utils.mkEntry;
-import static org.apache.kafka.common.utils.Utils.mkMap;
-import static org.apache.kafka.test.StreamsTestUtils.toList;
-import static org.apache.kafka.test.StreamsTestUtils.valuesToList;
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.CoreMatchers.hasItem;
-import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotEquals;
-import static org.junit.Assert.assertTrue;
 
-import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
+import java.util.HashSet;
 
-import java.util.Map;
-import org.apache.kafka.clients.producer.MockProducer;
-import org.apache.kafka.clients.producer.Producer;
-import org.apache.kafka.common.Metric;
-import org.apache.kafka.common.MetricName;
-import org.apache.kafka.common.header.Headers;
-import org.apache.kafka.common.metrics.Metrics;
-import org.apache.kafka.common.serialization.Serdes;
-import org.apache.kafka.common.serialization.Serializer;
-import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.streams.KeyValue;
-import org.apache.kafka.streams.errors.DefaultProductionExceptionHandler;
+import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.kstream.internals.SessionWindow;
-import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
-import org.apache.kafka.streams.processor.internals.RecordCollector;
-import org.apache.kafka.streams.processor.internals.RecordCollectorImpl;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.state.KeyValueIterator;
 import org.apache.kafka.streams.state.SessionStore;
 import org.apache.kafka.streams.state.Stores;
-import org.apache.kafka.test.InternalMockProcessorContext;
-import org.apache.kafka.test.TestUtils;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
-public class InMemorySessionStoreTest {
+public class InMemorySessionStoreTest extends SessionBytesStoreTest {
 
-    private static final String STORE_NAME = "InMemorySessionStore";
-    private static final long RETENTION_PERIOD = 10_000L;
+    private static final String STORE_NAME = "in-memory session store";
 
-    private SessionStore<String, Long> sessionStore;
-    private InternalMockProcessorContext context;
-
-    private final List<KeyValue<byte[], byte[]>> changeLog = new ArrayList<>();
-
-    private final Producer<byte[], byte[]> producer = new MockProducer<>(true,
-        Serdes.ByteArray().serializer(),
-        Serdes.ByteArray().serializer());
-
-    private final RecordCollector recordCollector = new RecordCollectorImpl(
-        STORE_NAME,
-        new LogContext(STORE_NAME),
-        new DefaultProductionExceptionHandler(),
-        new Metrics().sensor("skipped-records")) {
-
-        @Override
-        public <K1, V1> void send(final String topic,
-            final K1 key,
-            final V1 value,
-            final Headers headers,
-            final Integer partition,
-            final Long timestamp,
-            final Serializer<K1> keySerializer,
-            final Serializer<V1> valueSerializer) {
-            changeLog.add(new KeyValue<>(
-                keySerializer.serialize(topic, headers, key),
-                valueSerializer.serialize(topic, headers, value))
-            );
-        }
-    };
-
-    private SessionStore<String, Long> buildSessionStore(final long retentionPeriod) {
+    @Override
+    <K, V> SessionStore<K, V> buildSessionStore(final long retentionPeriod,
+                                                 final Serde<K> keySerde,
+                                                 final Serde<V> valueSerde) {
         return Stores.sessionStoreBuilder(
             Stores.inMemorySessionStore(
                 STORE_NAME,
                 ofMillis(retentionPeriod)),
-            Serdes.String(),
-            Serdes.Long()).build();
-    }
-
-    @Before
-    public void before() {
-        context = new InternalMockProcessorContext(
-            TestUtils.tempDirectory(),
-            Serdes.String(),
-            Serdes.Long(),
-            recordCollector,
-            new ThreadCache(
-                new LogContext("testCache"),
-                0,
-                new MockStreamsMetrics(new Metrics())));
-
-        sessionStore = buildSessionStore(RETENTION_PERIOD);
-
-        sessionStore.init(context, sessionStore);
-        recordCollector.init(producer);
-    }
-
-    @After
-    public void after() {
-        sessionStore.close();
-    }
-
-    @Test
-    public void shouldPutAndFindSessionsInRange() {
-        final String key = "a";
-        final Windowed<String> a1 = new Windowed<>(key, new SessionWindow(10, 10L));
-        final Windowed<String> a2 = new Windowed<>(key, new SessionWindow(500L, 1000L));
-        sessionStore.put(a1, 1L);
-        sessionStore.put(a2, 2L);
-        sessionStore.put(new Windowed<>(key, new SessionWindow(1500L, 2000L)), 1L);
-        sessionStore.put(new Windowed<>(key, new SessionWindow(2500L, 3000L)), 2L);
-
-        final List<KeyValue<Windowed<String>, Long>> expected =
-            Arrays.asList(KeyValue.pair(a1, 1L), KeyValue.pair(a2, 2L));
-
-        try (final KeyValueIterator<Windowed<String>, Long> values =
-            sessionStore.findSessions(key, 0, 1000L)
-        ) {
-            assertEquals(expected, toList(values));
-        }
-
-        final List<KeyValue<Windowed<String>, Long>> expected2 = Collections.singletonList(KeyValue.pair(a2, 2L));
-
-        try (final KeyValueIterator<Windowed<String>, Long> values2 =
-            sessionStore.findSessions(key, 400L, 600L)
-        ) {
-            assertEquals(expected2, toList(values2));
-        }
+            keySerde,
+            valueSerde).build();
     }
 
-    @Test
-    public void shouldFetchAllSessionsWithSameRecordKey() {
-        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L));
-
-        for (final KeyValue<Windowed<String>, Long> kv : expected) {
-            sessionStore.put(kv.key, kv.value);
-        }
-
-        // add one that shouldn't appear in the results
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 0)), 5L);
-
-        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(expected, toList(values));
-        }
-    }
-
-    @Test
-    public void shouldFetchAllSessionsWithinKeyRange() {
-        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
-            KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L),
-            KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L),
-            KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L),
-            KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L));
-
-        for (final KeyValue<Windowed<String>, Long> kv : expected) {
-            sessionStore.put(kv.key, kv.value);
-        }
-
-        // add some that shouldn't appear in the results
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
-        sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L);
-
-        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("aa", "bb")) {
-            assertEquals(expected, toList(values));
-        }
-    }
-
-    @Test
-    public void shouldFetchExactSession() {
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 4)), 1L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 3)), 2L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 4)), 3L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(1, 4)), 4L);
-        sessionStore.put(new Windowed<>("aaa", new SessionWindow(0, 4)), 5L);
-
-        final long result = sessionStore.fetchSession("aa", 0, 4);
-        assertEquals(3L, result);
-    }
-
-    @Test
-    public void shouldFindValuesWithinMergingSessionWindowRange() {
-        final String key = "a";
-        sessionStore.put(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L);
-        sessionStore.put(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L);
-
-        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
-            KeyValue.pair(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L),
-            KeyValue.pair(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L));
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions(key, -1, 1000L)) {
-            assertEquals(expected, toList(results));
-        }
-    }
-
-    @Test
-    public void shouldRemove() {
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), 1L);
-        sessionStore.put(new Windowed<>("a", new SessionWindow(1500, 2500)), 2L);
-
-        sessionStore.remove(new Windowed<>("a", new SessionWindow(0, 1000)));
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 0L, 1000L)) {
-            assertFalse(results.hasNext());
-        }
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 1500L, 2500L)) {
-            assertTrue(results.hasNext());
-        }
-    }
-
-    @Test
-    public void shouldRemoveOnNullAggValue() {
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), 1L);
-        sessionStore.put(new Windowed<>("a", new SessionWindow(1500, 2500)), 2L);
-
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), null);
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 0L, 1000L)) {
-            assertFalse(results.hasNext());
-        }
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 1500L, 2500L)) {
-            assertTrue(results.hasNext());
-        }
-    }
-
-    @Test
-    public void shouldFindSessionsToMerge() {
-        final Windowed<String> session1 = new Windowed<>("a", new SessionWindow(0, 100));
-        final Windowed<String> session2 = new Windowed<>("a", new SessionWindow(101, 200));
-        final Windowed<String> session3 = new Windowed<>("a", new SessionWindow(201, 300));
-        final Windowed<String> session4 = new Windowed<>("a", new SessionWindow(301, 400));
-        final Windowed<String> session5 = new Windowed<>("a", new SessionWindow(401, 500));
-        sessionStore.put(session1, 1L);
-        sessionStore.put(session2, 2L);
-        sessionStore.put(session3, 3L);
-        sessionStore.put(session4, 4L);
-        sessionStore.put(session5, 5L);
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 150, 300)
-        ) {
-            assertEquals(session2, results.next().key);
-            assertEquals(session3, results.next().key);
-            assertFalse(results.hasNext());
-        }
-    }
-
-    @Test
-    public void shouldFetchExactKeys() {
-        sessionStore = buildSessionStore(0x7a00000000000000L);
-        sessionStore.init(context, sessionStore);
-
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
-        sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 20)), 4L);
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0x7a00000000000000L - 2, 0x7a00000000000000L - 1)), 5L);
-
-        try (final KeyValueIterator<Windowed<String>, Long> iterator =
-            sessionStore.findSessions("a", 0, Long.MAX_VALUE)
-        ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(1L, 3L, 5L)));
-        }
-
-        try (final KeyValueIterator<Windowed<String>, Long> iterator =
-            sessionStore.findSessions("aa", 0, Long.MAX_VALUE)
-        ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(2L, 4L)));
-        }
-
-        try (final KeyValueIterator<Windowed<String>, Long> iterator =
-            sessionStore.findSessions("a", "aa", 0, Long.MAX_VALUE)
-        ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(1L, 2L, 3L, 4L, 5L)));
-        }
-
-        try (final KeyValueIterator<Windowed<String>, Long> iterator =
-            sessionStore.findSessions("a", "aa", 10, 0)
-        ) {
-            assertThat(valuesToList(iterator), equalTo(Collections.singletonList(2L)));
-        }
+    @Override
+    String getMetricsScope() {
+        return new InMemorySessionBytesStoreSupplier(null, 0).metricsScope();
     }
 
-    @Test
-    public void testIteratorPeek() {
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
-        sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 20)), 4L);
-
-        final KeyValueIterator<Windowed<String>, Long> iterator = sessionStore.findSessions("a", 0L, 20);
-
-        assertEquals(iterator.peekNextKey(), new Windowed<>("a", new SessionWindow(0L, 0L)));
-        assertEquals(iterator.peekNextKey(), iterator.next().key);
-        assertEquals(iterator.peekNextKey(), iterator.next().key);
-        assertFalse(iterator.hasNext());
+    @Override
+    void setClassLoggerToDebug() {
+        LogCaptureAppender.setClassLoggerToDebug(InMemorySessionStore.class);
     }
 
     @Test
@@ -349,154 +72,27 @@ public class InMemorySessionStoreTest {
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
             sessionStore.findSessions("a", "b", 0L, Long.MAX_VALUE)
         ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(2L, 3L, 4L)));
+            assertEquals(valuesToSet(iterator), new HashSet<>(Arrays.asList(2L, 3L, 4L)));
         }
     }
 
     @Test
-    public void shouldRestore() {
-        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L));
+    public void shouldNotExpireFromOpenIterator() {
 
-        for (final KeyValue<Windowed<String>, Long> kv : expected) {
-            sessionStore.put(kv.key, kv.value);
-        }
-
-        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(expected, toList(values));
-        }
-
-        sessionStore.close();
-
-        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(Collections.emptyList(), toList(values));
-        }
-
-        context.restore(STORE_NAME, changeLog);
-
-        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(expected, toList(values));
-        }
-    }
-
-    @Test
-    public void shouldReturnSameResultsForSingleKeyFindSessionsAndEqualKeyRangeFindSessions() {
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1)), 0L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(2, 3)), 1L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(4, 5)), 2L);
-        sessionStore.put(new Windowed<>("aaa", new SessionWindow(6, 7)), 3L);
-
-        final KeyValueIterator<Windowed<String>, Long> singleKeyIterator = sessionStore.findSessions("aa", 0L, 10L);
-        final KeyValueIterator<Windowed<String>, Long> keyRangeIterator = sessionStore.findSessions("aa", "aa", 0L, 10L);
-
-        assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
-        assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
-        assertFalse(singleKeyIterator.hasNext());
-        assertFalse(keyRangeIterator.hasNext());
-    }
-
-    @Test
-    public void shouldLogAndMeasureExpiredRecords() {
-        LogCaptureAppender.setClassLoggerToDebug(InMemorySessionStore.class);
-        final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
-
-
-        // Advance stream time by inserting record with large enough timestamp that records with timestamp 0 are expired
-        sessionStore.put(new Windowed<>("initial record", new SessionWindow(0, RETENTION_PERIOD)), 0L);
-
-        // Try inserting a record with timestamp 0 -- should be dropped
-        sessionStore.put(new Windowed<>("late record", new SessionWindow(0, 0)), 0L);
-        sessionStore.put(new Windowed<>("another on-time record", new SessionWindow(0, RETENTION_PERIOD)), 0L);
-
-        LogCaptureAppender.unregister(appender);
-
-        final Map<MetricName, ? extends Metric> metrics = context.metrics().metrics();
-
-        final Metric dropTotal = metrics.get(new MetricName(
-            "expired-window-record-drop-total",
-            "stream-in-memory-session-state-metrics",
-            "The total number of occurrence of expired-window-record-drop operations.",
-            mkMap(
-                mkEntry("client-id", "mock"),
-                mkEntry("task-id", "0_0"),
-                mkEntry("in-memory-session-state-id", STORE_NAME)
-            )
-        ));
-
-        final Metric dropRate = metrics.get(new MetricName(
-            "expired-window-record-drop-rate",
-            "stream-in-memory-session-state-metrics",
-            "The average number of occurrence of expired-window-record-drop operation per second.",
-            mkMap(
-                mkEntry("client-id", "mock"),
-                mkEntry("task-id", "0_0"),
-                mkEntry("in-memory-session-state-id", STORE_NAME)
-            )
-        ));
-
-        assertEquals(1.0, dropTotal.metricValue());
-        assertNotEquals(0.0, dropRate.metricValue());
-        final List<String> messages = appender.getMessages();
-        assertThat(messages, hasItem("Skipping record for expired segment."));
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullKey() {
-        sessionStore.findSessions(null, 1L, 2L);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullFromKey() {
-        sessionStore.findSessions(null, "anyKeyTo", 1L, 2L);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullToKey() {
-        sessionStore.findSessions("anyKeyFrom", null, 1L, 2L);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFetchNullFromKey() {
-        sessionStore.fetch(null, "anyToKey");
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFetchNullToKey() {
-        sessionStore.fetch("anyFromKey", null);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFetchNullKey() {
-        sessionStore.fetch(null);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnRemoveNullKey() {
-        sessionStore.remove(null);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnPutNullKey() {
-        sessionStore.put(null, 1L);
-    }
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
+        sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L);
 
-    @Test
-    public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() {
-        LogCaptureAppender.setClassLoggerToDebug(InMemorySessionStore.class);
-        final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
+        final KeyValueIterator<Windowed<String>, Long> iterator = sessionStore.findSessions("a", "b", 0L, RETENTION_PERIOD);
 
-        final String keyFrom = Serdes.String().deserializer().deserialize("", Serdes.Integer().serializer().serialize("", -1));
-        final String keyTo = Serdes.String().deserializer().deserialize("", Serdes.Integer().serializer().serialize("", 1));
+        // Advance stream time to expire the first three records
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(100, 2 * RETENTION_PERIOD)), 4L);
 
-        final KeyValueIterator<Windowed<String>, Long> iterator = sessionStore.findSessions(keyFrom, keyTo, 0L, 10L);
+        assertEquals(valuesToSet(iterator), new HashSet<>(Arrays.asList(1L, 2L, 3L, 4L)));
         assertFalse(iterator.hasNext());
 
-        final List<String> messages = appender.getMessages();
-        assertThat(messages, hasItem("Returning empty iterator for fetch with invalid key range: from > to. "
-            + "This may be due to serdes that don't preserve ordering when lexicographically comparing the serialized bytes. "
-            + "Note that the built-in numerical serdes do not follow this for negative numbers"));
+        iterator.close();
+        assertFalse(sessionStore.findSessions("a", "b", 0L, 20L).hasNext());
     }
+
 }
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryWindowStoreTest.java
index 1524d9c..41ed073 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryWindowStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryWindowStoreTest.java
@@ -17,600 +17,161 @@
 package org.apache.kafka.streams.state.internals;
 
 import static java.time.Duration.ofMillis;
-import static org.apache.kafka.common.utils.Utils.mkEntry;
-import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.streams.state.internals.WindowKeySchema.toStoreKeyBinary;
-import static org.hamcrest.CoreMatchers.hasItem;
-import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotEquals;
 
-import java.io.File;
-import java.util.ArrayList;
 import java.util.LinkedList;
 import java.util.List;
-import java.util.Map;
-import org.apache.kafka.clients.producer.MockProducer;
-import org.apache.kafka.clients.producer.Producer;
-import org.apache.kafka.common.Metric;
-import org.apache.kafka.common.MetricName;
-import org.apache.kafka.common.header.Headers;
-import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.serialization.Serdes;
-import org.apache.kafka.common.serialization.Serializer;
-import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.KeyValue;
-import org.apache.kafka.streams.errors.DefaultProductionExceptionHandler;
 import org.apache.kafka.streams.kstream.Windowed;
-import org.apache.kafka.streams.processor.ProcessorContext;
-import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
-import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
-import org.apache.kafka.streams.processor.internals.RecordCollector;
-import org.apache.kafka.streams.processor.internals.RecordCollectorImpl;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.state.KeyValueIterator;
 import org.apache.kafka.streams.state.StateSerdes;
 import org.apache.kafka.streams.state.Stores;
 import org.apache.kafka.streams.state.WindowStore;
 import org.apache.kafka.streams.state.WindowStoreIterator;
-import org.apache.kafka.test.InternalMockProcessorContext;
-import org.apache.kafka.test.TestUtils;
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
-public class InMemoryWindowStoreTest {
-
-    private static final long DEFAULT_CACHE_SIZE_BYTES = 1024 * 1024L;
-
-    private final String storeName = "InMemoryWindowStore";
-    private final long retentionPeriod = 40L * 1000L;
-    private final long windowSize = 10L;
-
-    private final StateSerdes<Integer, String> serdes = new StateSerdes<>("", Serdes.Integer(), Serdes.String());
-
-    private final List<KeyValue<byte[], byte[]>> changeLog = new ArrayList<>();
-    private final ThreadCache cache = new ThreadCache(new LogContext("TestCache "),
-                                                      DEFAULT_CACHE_SIZE_BYTES,
-                                                      new MockStreamsMetrics(new Metrics()));
-
-    private final Producer<byte[], byte[]> producer =
-        new MockProducer<>(true, Serdes.ByteArray().serializer(), Serdes.ByteArray().serializer());
-    private final RecordCollector recordCollector = new RecordCollectorImpl("InMemoryWindowStoreTestTask",
-                                                                            new LogContext("InMemoryWindowStoreTestTask "),
-                                                                            new DefaultProductionExceptionHandler(),
-                                                                            new Metrics().sensor("skipped-records")) {
-        @Override
-        public <K1, V1> void send(final String topic,
-            final K1 key,
-            final V1 value,
-            final Headers headers,
-            final Integer partition,
-            final Long timestamp,
-            final Serializer<K1> keySerializer,
-            final Serializer<V1> valueSerializer) {
-            changeLog.add(new KeyValue<>(
-                keySerializer.serialize(topic, headers, key),
-                valueSerializer.serialize(topic, headers, value))
-            );
-        }
-    };
-
-    private final File baseDir = TestUtils.tempDirectory("test");
-    private final InternalMockProcessorContext context = new InternalMockProcessorContext(baseDir, Serdes.ByteArray(), Serdes.ByteArray(), recordCollector, cache);
-    private WindowStore<Integer, String> windowStore;
-
-    private WindowStore<Integer, String> createInMemoryWindowStore(final ProcessorContext context, final boolean retainDuplicates) {
-        final WindowStore<Integer, String> store = Stores.windowStoreBuilder(Stores.inMemoryWindowStore(
-                                                                             storeName,
-                                                                             ofMillis(retentionPeriod),
-                                                                             ofMillis(windowSize),
-                                                                             retainDuplicates),
-            Serdes.Integer(),
-            Serdes.String()).build();
-
-        store.init(context, store);
-        return store;
-    }
-
-    @Before
-    public void initRecordCollector() {
-        recordCollector.init(producer);
-    }
-
-    @After
-    public void closeStore() {
-        if (windowStore != null) {
-            windowStore.close();
-        }
-    }
-
-    private void setCurrentTime(final long currentTime) {
-        context.setRecordContext(createRecordContext(currentTime));
-    }
-
-    private ProcessorRecordContext createRecordContext(final long time) {
-        return new ProcessorRecordContext(time, 0, 0, "topic", null);
-    }
-
-    private <K, V> KeyValue<Windowed<K>, V> windowedPair(final K key, final V value, final long timestamp) {
-        return windowedPair(key, value, timestamp, windowSize);
-    }
-
-    private static <K, V> KeyValue<Windowed<K>, V> windowedPair(final K key, final V value, final long timestamp, final long windowSize) {
-        return KeyValue.pair(new Windowed<>(key, WindowKeySchema.timeWindowForSize(timestamp, windowSize)), value);
-    }
-
-    @Test
-    public void testSingleFetch() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        long currentTime = 0;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "one");
-
-        currentTime += windowSize;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "two");
-
-        currentTime += 3 * windowSize;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "three");
-
-        assertEquals("one", windowStore.fetch(1, 0));
-        assertEquals("two", windowStore.fetch(1, windowSize));
-        assertEquals("three", windowStore.fetch(1, 4 * windowSize));
-    }
-
-    @Test
-    public void testDeleteAndUpdate() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        final long currentTime = 0;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "one");
-        windowStore.put(1, "one v2");
+public class InMemoryWindowStoreTest extends WindowBytesStoreTest {
 
-        WindowStoreIterator<String> iterator = windowStore.fetch(1, 0, currentTime);
-        assertEquals(new KeyValue<>(currentTime, "one v2"), iterator.next());
+    private final static String STORE_NAME = "InMemoryWindowStore";
 
-        windowStore.put(1, null);
-        iterator = windowStore.fetch(1, 0, currentTime);
-        assertFalse(iterator.hasNext());
+    @Override
+    <K, V> WindowStore<K, V> buildWindowStore(final long retentionPeriod,
+        final long windowSize,
+        final boolean retainDuplicates,
+        final Serde<K> keySerde,
+        final Serde<V> valueSerde) {
+        return Stores.windowStoreBuilder(
+            Stores.inMemoryWindowStore(
+                STORE_NAME,
+                ofMillis(retentionPeriod),
+                ofMillis(windowSize),
+                retainDuplicates),
+            keySerde,
+            valueSerde)
+            .build();
     }
 
-    @Test
-    public void testFetchAll() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        long currentTime = 0;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "one");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "two");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "three");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(2, "four");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(2, "five");
-
-        final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.fetchAll(windowSize * 10, windowSize * 30);
-
-        assertEquals(windowedPair(1, "two", windowSize * 10), iterator.next());
-        assertEquals(windowedPair(1, "three", windowSize * 20), iterator.next());
-        assertEquals(windowedPair(2, "four", windowSize * 30), iterator.next());
-        assertFalse(iterator.hasNext());
+    @Override
+    String getMetricsScope() {
+        return new InMemoryWindowBytesStoreSupplier(null, 0, 0, false).metricsScope();
     }
 
-    @Test
-    public void testAll() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        long currentTime = 0;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "one");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "two");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "three");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(2, "four");
-
-        final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.all();
-
-        assertEquals(windowedPair(1, "one", 0), iterator.next());
-        assertEquals(windowedPair(1, "two", windowSize * 10), iterator.next());
-        assertEquals(windowedPair(1, "three", windowSize * 20), iterator.next());
-        assertEquals(windowedPair(2, "four", windowSize * 30), iterator.next());
-        assertFalse(iterator.hasNext());
+    @Override
+    void setClassLoggerToDebug() {
+        LogCaptureAppender.setClassLoggerToDebug(InMemoryWindowStore.class);
     }
 
     @Test
-    public void testTimeRangeFetch() {
-
-        windowStore = createInMemoryWindowStore(context, false);
-
-        long currentTime = 0;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "one");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "two");
+    public void shouldRestore() {
+        // should be empty initially
+        assertFalse(windowStore.all().hasNext());
 
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "three");
+        final StateSerdes<Integer, String> serdes = new StateSerdes<>("", Serdes.Integer(),
+            Serdes.String());
 
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "four");
+        final List<KeyValue<byte[], byte[]>> restorableEntries = new LinkedList<>();
 
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "five");
+        restorableEntries
+            .add(new KeyValue<>(toStoreKeyBinary(1, 0L, 0, serdes).get(), serdes.rawValue("one")));
+        restorableEntries.add(new KeyValue<>(toStoreKeyBinary(2, WINDOW_SIZE, 0, serdes).get(),
+            serdes.rawValue("two")));
+        restorableEntries.add(new KeyValue<>(toStoreKeyBinary(3, 2 * WINDOW_SIZE, 0, serdes).get(),
+            serdes.rawValue("three")));
 
-        final WindowStoreIterator<String> iterator = windowStore.fetch(1, windowSize * 10, 3 * windowSize * 10);
+        context.restore(STORE_NAME, restorableEntries);
+        final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore
+            .fetchAll(0L, 2 * WINDOW_SIZE);
 
-        // should return only the middle three records
-        assertEquals(new KeyValue<>(windowSize * 10, "two"), iterator.next());
-        assertEquals(new KeyValue<>(2 * windowSize * 10, "three"), iterator.next());
-        assertEquals(new KeyValue<>(3 * windowSize * 10, "four"), iterator.next());
+        assertEquals(windowedPair(1, "one", 0L), iterator.next());
+        assertEquals(windowedPair(2, "two", WINDOW_SIZE), iterator.next());
+        assertEquals(windowedPair(3, "three", 2 * WINDOW_SIZE), iterator.next());
         assertFalse(iterator.hasNext());
     }
 
     @Test
-    public void testKeyRangeFetch() {
-
-        windowStore = createInMemoryWindowStore(context, false);
-
-        long currentTime = 0;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "one");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(2, "two");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(3, "three");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(4, "four");
+    public void shouldNotExpireFromOpenIterator() {
 
-        windowStore.put(5, "five");
+        windowStore.put(1, "one", 0L);
+        windowStore.put(1, "two", 10L);
 
-        final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.fetch(1, 4, 0L, currentTime);
+        windowStore.put(2, "one", 5L);
+        windowStore.put(2, "two", 15L);
 
-        // should return only the first four keys
-        assertEquals(windowedPair(1, "one", 0), iterator.next());
-        assertEquals(windowedPair(2, "two", windowSize * 10), iterator.next());
-        assertEquals(windowedPair(3, "three", windowSize * 20), iterator.next());
-        assertEquals(windowedPair(4, "four", windowSize * 30), iterator.next());
-        assertFalse(iterator.hasNext());
-    }
+        final WindowStoreIterator<String> iterator1 = windowStore.fetch(1, 0L, 50L);
+        final WindowStoreIterator<String> iterator2 = windowStore.fetch(2, 0L, 50L);
 
-    @Test
-    public void testFetchDuplicates() {
-        windowStore = createInMemoryWindowStore(context, true);
+        // This put expires all four previous records, but they should still be returned from already open iterators
+        windowStore.put(1, "four", 2 * RETENTION_PERIOD);
 
-        long currentTime = 0;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "one");
-        windowStore.put(1, "one-2");
+        assertEquals(new KeyValue<>(0L, "one"), iterator1.next());
+        assertEquals(new KeyValue<>(5L, "one"), iterator2.next());
 
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "two");
-        windowStore.put(1, "two-2");
+        assertEquals(new KeyValue<>(15L, "two"), iterator2.next());
+        assertEquals(new KeyValue<>(10L, "two"), iterator1.next());
 
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "three");
-        windowStore.put(1, "three-2");
+        assertFalse(iterator1.hasNext());
+        assertFalse(iterator2.hasNext());
 
-        final WindowStoreIterator<String> iterator = windowStore.fetch(1, 0, windowSize * 10);
+        iterator1.close();
+        iterator2.close();
 
-        assertEquals(new KeyValue<>(0L, "one"), iterator.next());
-        assertEquals(new KeyValue<>(0L, "one-2"), iterator.next());
-        assertEquals(new KeyValue<>(windowSize * 10, "two"), iterator.next());
-        assertEquals(new KeyValue<>(windowSize * 10, "two-2"), iterator.next());
-        assertFalse(iterator.hasNext());
+        // Make sure expired records are removed now that open iterators are closed
+        assertFalse(windowStore.fetch(1, 0L, 50L).hasNext());
     }
 
     @Test
-    public void testSegmentExpiration() {
-        windowStore = createInMemoryWindowStore(context, false);
+    public void testExpiration() {
 
         long currentTime = 0;
         setCurrentTime(currentTime);
         windowStore.put(1, "one");
 
-        currentTime += retentionPeriod / 4;
+        currentTime += RETENTION_PERIOD / 4;
         setCurrentTime(currentTime);
         windowStore.put(1, "two");
 
-        currentTime += retentionPeriod / 4;
+        currentTime += RETENTION_PERIOD / 4;
         setCurrentTime(currentTime);
         windowStore.put(1, "three");
 
-        currentTime += retentionPeriod / 4;
+        currentTime += RETENTION_PERIOD / 4;
         setCurrentTime(currentTime);
         windowStore.put(1, "four");
 
-        // increase current time to the full retentionPeriod to expire first record
-        currentTime = currentTime + retentionPeriod / 4;
+        // increase current time to the full RETENTION_PERIOD to expire the first record
+        currentTime = currentTime + RETENTION_PERIOD / 4;
         setCurrentTime(currentTime);
         windowStore.put(1, "five");
 
-        KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.fetchAll(0L, currentTime);
+        KeyValueIterator<Windowed<Integer>, String> iterator = windowStore
+            .fetchAll(0L, currentTime);
 
         // effect of this put (expires next oldest record, adds new one) should not be reflected in the already fetched results
-        currentTime = currentTime + retentionPeriod / 4;
+        currentTime = currentTime + RETENTION_PERIOD / 4;
         setCurrentTime(currentTime);
         windowStore.put(1, "six");
 
         // should only have middle 4 values, as (only) the first record was expired at the time of the fetch
         // and the last was inserted after the fetch
-        assertEquals(windowedPair(1, "two", retentionPeriod / 4), iterator.next());
-        assertEquals(windowedPair(1, "three", retentionPeriod / 2), iterator.next());
-        assertEquals(windowedPair(1, "four", 3 * (retentionPeriod / 4)), iterator.next());
-        assertEquals(windowedPair(1, "five", retentionPeriod), iterator.next());
+        assertEquals(windowedPair(1, "two", RETENTION_PERIOD / 4), iterator.next());
+        assertEquals(windowedPair(1, "three", RETENTION_PERIOD / 2), iterator.next());
+        assertEquals(windowedPair(1, "four", 3 * (RETENTION_PERIOD / 4)), iterator.next());
+        assertEquals(windowedPair(1, "five", RETENTION_PERIOD), iterator.next());
         assertFalse(iterator.hasNext());
 
         iterator = windowStore.fetchAll(0L, currentTime);
 
         // If we fetch again after the last put, the second oldest record should have expired and newest should appear in results
-        assertEquals(windowedPair(1, "three", retentionPeriod / 2), iterator.next());
-        assertEquals(windowedPair(1, "four", 3 * (retentionPeriod / 4)), iterator.next());
-        assertEquals(windowedPair(1, "five", retentionPeriod), iterator.next());
-        assertEquals(windowedPair(1, "six", 5 * (retentionPeriod / 4)), iterator.next());
-        assertFalse(iterator.hasNext());
-    }
-
-    @Test
-    public void testWindowIteratorPeek() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        final long currentTime = 0;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "one");
-
-        final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.fetchAll(0L, currentTime);
-
-        assertEquals(iterator.peekNextKey(), iterator.next().key);
-        assertFalse(iterator.hasNext());
-    }
-
-    @Test
-    public void testValueIteratorPeek() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        final long currentTime = 0;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "one");
-
-        final WindowStoreIterator<String> iterator = windowStore.fetch(1, 0L, currentTime);
-
-        assertEquals(iterator.peekNextKey(), iterator.next().key);
-        assertFalse(iterator.hasNext());
-    }
-
-    @Test
-    public void shouldRestore() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        // should be empty initially
-        assertFalse(windowStore.all().hasNext());
-
-        final List<KeyValue<byte[], byte[]>> restorableEntries = new LinkedList<>();
-
-        restorableEntries.add(new KeyValue<>(toStoreKeyBinary(1, 0L, 0, serdes).get(), serdes.rawValue("one")));
-        restorableEntries.add(new KeyValue<>(toStoreKeyBinary(2, windowSize, 0, serdes).get(), serdes.rawValue("two")));
-        restorableEntries.add(new KeyValue<>(toStoreKeyBinary(3, 2 * windowSize, 0, serdes).get(), serdes.rawValue("three")));
-
-        context.restore(storeName, restorableEntries);
-        final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.fetchAll(0L, 2 * windowSize);
-
-        assertEquals(windowedPair(1, "one", 0L), iterator.next());
-        assertEquals(windowedPair(2, "two", windowSize), iterator.next());
-        assertEquals(windowedPair(3, "three", 2 * windowSize), iterator.next());
-        assertFalse(iterator.hasNext());
-    }
-
-    @Test
-    public void shouldLogAndMeasureExpiredRecords() {
-        LogCaptureAppender.setClassLoggerToDebug(InMemoryWindowStore.class);
-        final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
-
-        windowStore = createInMemoryWindowStore(context, false);
-        setCurrentTime(retentionPeriod);
-
-        // Advance stream time by inserting record with large enough timestamp that records with timestamp 0 are expired
-        windowStore.put(1, "initial record");
-
-        // Try inserting a record with timestamp 0 -- should be dropped
-        windowStore.put(1, "late record", 0L);
-        windowStore.put(1, "another on-time record");
-
-        LogCaptureAppender.unregister(appender);
-
-        final Map<MetricName, ? extends Metric> metrics = context.metrics().metrics();
-
-        final Metric dropTotal = metrics.get(new MetricName(
-            "expired-window-record-drop-total",
-            "stream-in-memory-window-state-metrics",
-            "The total number of occurrence of expired-window-record-drop operations.",
-            mkMap(
-                mkEntry("client-id", "mock"),
-                mkEntry("task-id", "0_0"),
-                mkEntry("in-memory-window-state-id", storeName)
-            )
-        ));
-
-        final Metric dropRate = metrics.get(new MetricName(
-            "expired-window-record-drop-rate",
-            "stream-in-memory-window-state-metrics",
-            "The average number of occurrence of expired-window-record-drop operation per second.",
-            mkMap(
-                mkEntry("client-id", "mock"),
-                mkEntry("task-id", "0_0"),
-                mkEntry("in-memory-window-state-id", storeName)
-            )
-        ));
-
-        assertEquals(1.0, dropTotal.metricValue());
-        assertNotEquals(0.0, dropRate.metricValue());
-        final List<String> messages = appender.getMessages();
-        assertThat(messages, hasItem("Skipping record for expired segment."));
-    }
-
-    @Test
-    public void testIteratorMultiplePeekAndHasNext() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        long currentTime = 0;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "one");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(2, "two");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(3, "three");
-
-        final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.fetch(1, 4, 0L, currentTime);
-
-        assertFalse(!iterator.hasNext());
-        assertFalse(!iterator.hasNext());
-        assertEquals(new Windowed<>(1, WindowKeySchema.timeWindowForSize(0L, windowSize)), iterator.peekNextKey());
-        assertEquals(new Windowed<>(1, WindowKeySchema.timeWindowForSize(0L, windowSize)), iterator.peekNextKey());
-
-        assertEquals(windowedPair(1, "one", 0), iterator.next());
-        assertEquals(windowedPair(2, "two", windowSize * 10), iterator.next());
-        assertEquals(windowedPair(3, "three", windowSize * 20), iterator.next());
+        assertEquals(windowedPair(1, "three", RETENTION_PERIOD / 2), iterator.next());
+        assertEquals(windowedPair(1, "four", 3 * (RETENTION_PERIOD / 4)), iterator.next());
+        assertEquals(windowedPair(1, "five", RETENTION_PERIOD), iterator.next());
+        assertEquals(windowedPair(1, "six", 5 * (RETENTION_PERIOD / 4)), iterator.next());
         assertFalse(iterator.hasNext());
     }
-
-    @Test
-    public void shouldNotThrowConcurrentModificationException() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        long currentTime = 0;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "one");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "two");
-
-        final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.all();
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(1, "three");
-
-        currentTime += windowSize * 10;
-        setCurrentTime(currentTime);
-        windowStore.put(2, "four");
-
-        // Iterator should return all records in store and not throw exception b/c some were added after fetch
-        assertEquals(windowedPair(1, "one", 0), iterator.next());
-        assertEquals(windowedPair(1, "two", windowSize * 10), iterator.next());
-        assertEquals(windowedPair(1, "three", windowSize * 20), iterator.next());
-        assertEquals(windowedPair(2, "four", windowSize * 30), iterator.next());
-        assertFalse(iterator.hasNext());
-    }
-
-    @Test
-    public void shouldNotExpireFromOpenIterator() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        windowStore.put(1, "one", 0L);
-        windowStore.put(1, "two", 10L);
-
-        windowStore.put(2, "one", 5L);
-        windowStore.put(2, "two", 15L);
-
-        final WindowStoreIterator<String> iterator1 = windowStore.fetch(1, 0L, 50L);
-        final WindowStoreIterator<String> iterator2 = windowStore.fetch(2, 0L, 50L);
-
-        // This put expires all four previous records, but they should still be returned from already open iterators
-        windowStore.put(1, "four", retentionPeriod + 50L);
-
-        assertEquals(new KeyValue<>(0L, "one"), iterator1.next());
-        assertEquals(new KeyValue<>(5L, "one"), iterator2.next());
-
-        assertEquals(new KeyValue<>(15L, "two"), iterator2.next());
-        assertEquals(new KeyValue<>(10L, "two"), iterator1.next());
-
-        assertFalse(iterator1.hasNext());
-        assertFalse(iterator2.hasNext());
-    }
-
-    @Test
-    public void shouldReturnSameResultsForSingleKeyFetchAndEqualKeyRangeFetch() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        windowStore.put(1, "one", 0L);
-        windowStore.put(2, "two", 1L);
-        windowStore.put(2, "two", 2L);
-        windowStore.put(3, "three", 3L);
-
-        final WindowStoreIterator<String> singleKeyIterator = windowStore.fetch(2, 0L, 5L);
-        final KeyValueIterator<Windowed<Integer>, String> keyRangeIterator = windowStore.fetch(2, 2, 0L, 5L);
-
-        assertEquals(singleKeyIterator.next().value, keyRangeIterator.next().value);
-        assertEquals(singleKeyIterator.next().value, keyRangeIterator.next().value);
-        assertFalse(singleKeyIterator.hasNext());
-        assertFalse(keyRangeIterator.hasNext());
-    }
-
-    @Test
-    public void shouldNotThrowExceptionWhenFetchRangeIsExpired() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        windowStore.put(1, "one", 0L);
-        windowStore.put(1, "two", retentionPeriod);
-
-        final WindowStoreIterator<String> iterator = windowStore.fetch(1, 0L, 10L);
-
-        assertFalse(iterator.hasNext());
-    }
-
-    @Test
-    public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() {
-        windowStore = createInMemoryWindowStore(context, false);
-
-        LogCaptureAppender.setClassLoggerToDebug(InMemoryWindowStore.class);
-        final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
-
-        final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.fetch(-1, 1, 0L, 10L);
-        assertFalse(iterator.hasNext());
-
-        final List<String> messages = appender.getMessages();
-        assertThat(messages, hasItem("Returning empty iterator for fetch with invalid key range: from > to. "
-            + "This may be due to serdes that don't preserve ordering when lexicographically comparing the serialized bytes. "
-            + "Note that the built-in numerical serdes do not follow this for negative numbers"));
-    }
+    
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
index 41abdad..091f90f 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
@@ -16,337 +16,60 @@
  */
 package org.apache.kafka.streams.state.internals;
 
-import org.apache.kafka.common.metrics.Metrics;
-import org.apache.kafka.common.serialization.Serdes;
-import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.streams.KeyValue;
+import java.util.HashSet;
+import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.kstream.internals.SessionWindow;
-import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.state.KeyValueIterator;
 import org.apache.kafka.streams.state.SessionStore;
 import org.apache.kafka.streams.state.Stores;
-import org.apache.kafka.test.InternalMockProcessorContext;
-import org.apache.kafka.test.NoOpRecordCollector;
-import org.apache.kafka.test.TestUtils;
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
 import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
 
 import static java.time.Duration.ofMillis;
-import static org.apache.kafka.test.StreamsTestUtils.toList;
-import static org.apache.kafka.test.StreamsTestUtils.valuesToList;
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.CoreMatchers.hasItem;
-import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
 
-public class RocksDBSessionStoreTest {
+public class RocksDBSessionStoreTest extends SessionBytesStoreTest {
 
-    private SessionStore<String, Long> sessionStore;
-    private InternalMockProcessorContext context;
+    private static final String STORE_NAME = "rocksDB session store";
 
-    @Before
-    public void before() {
-        sessionStore = Stores.sessionStoreBuilder(
+    @Override
+    <K, V> SessionStore<K, V> buildSessionStore(final long retentionPeriod,
+                                                 final Serde<K> keySerde,
+                                                 final Serde<V> valueSerde) {
+        return Stores.sessionStoreBuilder(
             Stores.persistentSessionStore(
-                "session-store",
-                ofMillis(10_000L)),
-            Serdes.String(),
-            Serdes.Long()).build();
-
-        context = new InternalMockProcessorContext(
-            TestUtils.tempDirectory(),
-            Serdes.String(),
-            Serdes.Long(),
-            new NoOpRecordCollector(),
-            new ThreadCache(
-                new LogContext("testCache "),
-                0,
-                new MockStreamsMetrics(new Metrics())));
-
-        sessionStore.init(context, sessionStore);
-    }
-
-    @After
-    public void close() {
-        sessionStore.close();
+                STORE_NAME,
+                ofMillis(retentionPeriod)),
+            keySerde,
+            valueSerde).build();
     }
 
-    @Test
-    public void shouldPutAndFindSessionsInRange() {
-        final String key = "a";
-        final Windowed<String> a1 = new Windowed<>(key, new SessionWindow(10, 10L));
-        final Windowed<String> a2 = new Windowed<>(key, new SessionWindow(500L, 1000L));
-        sessionStore.put(a1, 1L);
-        sessionStore.put(a2, 2L);
-        sessionStore.put(new Windowed<>(key, new SessionWindow(1500L, 2000L)), 1L);
-        sessionStore.put(new Windowed<>(key, new SessionWindow(2500L, 3000L)), 2L);
-
-        final List<KeyValue<Windowed<String>, Long>> expected =
-            Arrays.asList(KeyValue.pair(a1, 1L), KeyValue.pair(a2, 2L));
-
-        try (final KeyValueIterator<Windowed<String>, Long> values =
-                 sessionStore.findSessions(key, 0, 1000L)
-        ) {
-            assertEquals(expected, toList(values));
-        }
-
-        final List<KeyValue<Windowed<String>, Long>> expected2 = Collections.singletonList(KeyValue.pair(a2, 2L));
-
-        try (final KeyValueIterator<Windowed<String>, Long> values2 =
-                 sessionStore.findSessions(key, 400L, 600L)
-        ) {
-            assertEquals(expected2, toList(values2));
-        }
+    @Override
+    String getMetricsScope() {
+        return new RocksDbSessionBytesStoreSupplier(null, 0).metricsScope();
     }
 
-    @Test
-    public void shouldFetchAllSessionsWithSameRecordKey() {
-        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L));
-
-        for (final KeyValue<Windowed<String>, Long> kv : expected) {
-            sessionStore.put(kv.key, kv.value);
-        }
-
-        // add one that shouldn't appear in the results
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 0)), 5L);
-
-        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(expected, toList(values));
-        }
+    @Override
+    void setClassLoggerToDebug() {
+        LogCaptureAppender.setClassLoggerToDebug(AbstractRocksDBSegmentedBytesStore.class);
     }
 
     @Test
-    public void shouldFetchAllSessionsWithinKeyRange() {
-        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
-            KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L),
-            KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L),
-            KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L),
-            KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L));
-
-        for (final KeyValue<Windowed<String>, Long> kv : expected) {
-            sessionStore.put(kv.key, kv.value);
-        }
-
-        // add some that shouldn't appear in the results
+    public void shouldRemoveExpired() {
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
-        sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L);
-
-        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("aa", "bb")) {
-            assertEquals(expected, toList(values));
-        }
-    }
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, SEGMENT_INTERVAL)), 2L);
+        sessionStore.put(new Windowed<>("a", new SessionWindow(10, SEGMENT_INTERVAL)), 3L);
 
-    @Test
-    public void shouldFetchExactSession() {
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 4)), 1L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 3)), 2L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 4)), 3L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(1, 4)), 4L);
-        sessionStore.put(new Windowed<>("aaa", new SessionWindow(0, 4)), 5L);
-
-        final long result = sessionStore.fetchSession("aa", 0, 4);
-        assertEquals(3L, result);
-    }
-
-    @Test
-    public void shouldFindValuesWithinMergingSessionWindowRange() {
-        final String key = "a";
-        sessionStore.put(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L);
-        sessionStore.put(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L);
-
-        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
-            KeyValue.pair(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L),
-            KeyValue.pair(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L));
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-                 sessionStore.findSessions(key, -1, 1000L)) {
-            assertEquals(expected, toList(results));
-        }
-    }
-
-    @Test
-    public void shouldRemove() {
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), 1L);
-        sessionStore.put(new Windowed<>("a", new SessionWindow(1500, 2500)), 2L);
-
-        sessionStore.remove(new Windowed<>("a", new SessionWindow(0, 1000)));
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-                 sessionStore.findSessions("a", 0L, 1000L)) {
-            assertFalse(results.hasNext());
-        }
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-                 sessionStore.findSessions("a", 1500L, 2500L)) {
-            assertTrue(results.hasNext());
-        }
-    }
-
-    @Test
-    public void shouldRemoveOnNullAggValue() {
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), 1L);
-        sessionStore.put(new Windowed<>("a", new SessionWindow(1500, 2500)), 2L);
-
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), null);
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 0L, 1000L)) {
-            assertFalse(results.hasNext());
-        }
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 1500L, 2500L)) {
-            assertTrue(results.hasNext());
-        }
-    }
-
-    @Test
-    public void shouldFindSessionsToMerge() {
-        final Windowed<String> session1 = new Windowed<>("a", new SessionWindow(0, 100));
-        final Windowed<String> session2 = new Windowed<>("a", new SessionWindow(101, 200));
-        final Windowed<String> session3 = new Windowed<>("a", new SessionWindow(201, 300));
-        final Windowed<String> session4 = new Windowed<>("a", new SessionWindow(301, 400));
-        final Windowed<String> session5 = new Windowed<>("a", new SessionWindow(401, 500));
-        sessionStore.put(session1, 1L);
-        sessionStore.put(session2, 2L);
-        sessionStore.put(session3, 3L);
-        sessionStore.put(session4, 4L);
-        sessionStore.put(session5, 5L);
-
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-                 sessionStore.findSessions("a", 150, 300)
-        ) {
-            assertEquals(session2, results.next().key);
-            assertEquals(session3, results.next().key);
-            assertFalse(results.hasNext());
-        }
-    }
-
-    @Test
-    public void shouldFetchExactKeys() {
-        sessionStore = Stores.sessionStoreBuilder(
-            Stores.persistentSessionStore(
-                "session-store",
-                ofMillis(0x7a00000000000000L)),
-            Serdes.String(),
-            Serdes.Long()).build();
-
-        sessionStore.init(context, sessionStore);
-
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
-        sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 20)), 4L);
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0x7a00000000000000L - 2, 0x7a00000000000000L - 1)), 5L);
+        // Advance stream time to expire the first record
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 2 * SEGMENT_INTERVAL)), 4L);
 
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
-                 sessionStore.findSessions("a", 0, Long.MAX_VALUE)
+            sessionStore.findSessions("a", "b", 0L, Long.MAX_VALUE)
         ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(1L, 3L, 5L)));
+            assertEquals(valuesToSet(iterator), new HashSet<>(Arrays.asList(2L, 3L, 4L)));
         }
-
-        try (final KeyValueIterator<Windowed<String>, Long> iterator =
-                 sessionStore.findSessions("aa", 0, Long.MAX_VALUE)
-        ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(2L, 4L)));
-        }
-
-        try (final KeyValueIterator<Windowed<String>, Long> iterator =
-                 sessionStore.findSessions("a", "aa", 0, Long.MAX_VALUE)
-        ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(1L, 3L, 2L, 4L, 5L)));
-        }
-
-        try (final KeyValueIterator<Windowed<String>, Long> iterator =
-                 sessionStore.findSessions("a", "aa", 10, 0)
-        ) {
-            assertThat(valuesToList(iterator), equalTo(Collections.singletonList(2L)));
-        }
-    }
-
-    @Test
-    public void shouldReturnSameResultsForSingleKeyFindSessionsAndEqualKeyRangeFindSessions() {
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1)), 0L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(2, 3)), 1L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(4, 5)), 2L);
-        sessionStore.put(new Windowed<>("aaa", new SessionWindow(6, 7)), 3L);
-
-        final KeyValueIterator<Windowed<String>, Long> singleKeyIterator = sessionStore.findSessions("aa", 0L, 10L);
-        final KeyValueIterator<Windowed<String>, Long> keyRangeIterator = sessionStore.findSessions("aa", "aa", 0L, 10L);
-
-        assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
-        assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
-        assertFalse(singleKeyIterator.hasNext());
-        assertFalse(keyRangeIterator.hasNext());
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullKey() {
-        sessionStore.findSessions(null, 1L, 2L);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullFromKey() {
-        sessionStore.findSessions(null, "anyKeyTo", 1L, 2L);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullToKey() {
-        sessionStore.findSessions("anyKeyFrom", null, 1L, 2L);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFetchNullFromKey() {
-        sessionStore.fetch(null, "anyToKey");
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFetchNullToKey() {
-        sessionStore.fetch("anyFromKey", null);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnFetchNullKey() {
-        sessionStore.fetch(null);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnRemoveNullKey() {
-        sessionStore.remove(null);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnPutNullKey() {
-        sessionStore.put(null, 1L);
-    }
-
-    @Test
-    public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() {
-        LogCaptureAppender.setClassLoggerToDebug(InMemoryWindowStore.class);
-        final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
-
-        final String keyFrom = Serdes.String().deserializer().deserialize("", Serdes.Integer().serializer().serialize("", -1));
-        final String keyTo = Serdes.String().deserializer().deserialize("", Serdes.Integer().serializer().serialize("", 1));
-
-        final KeyValueIterator<Windowed<String>, Long> iterator = sessionStore.findSessions(keyFrom, keyTo, 0L, 10L);
-        assertFalse(iterator.hasNext());
-
-        final List<String> messages = appender.getMessages();
-        assertThat(messages, hasItem("Returning empty iterator for fetch with invalid key range: from > to. "
-            + "This may be due to serdes that don't preserve ordering when lexicographically comparing the serialized bytes. "
-            + "Note that the built-in numerical serdes do not follow this for negative numbers"));
     }
-}
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
index 3342207..daf309a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
@@ -16,730 +16,97 @@
  */
 package org.apache.kafka.streams.state.internals;
 
-import org.apache.kafka.clients.producer.MockProducer;
-import org.apache.kafka.clients.producer.Producer;
-import org.apache.kafka.common.header.Headers;
-import org.apache.kafka.common.metrics.Metrics;
+import java.io.File;
+import java.util.Collections;
+import java.util.List;
+import java.util.HashSet;
+import java.util.Set;
+
+import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.serialization.Serdes;
-import org.apache.kafka.common.serialization.Serializer;
-import org.apache.kafka.common.utils.Bytes;
-import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.KeyValue;
-import org.apache.kafka.streams.errors.DefaultProductionExceptionHandler;
-import org.apache.kafka.streams.kstream.Windowed;
-import org.apache.kafka.streams.processor.ProcessorContext;
-import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
-import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
-import org.apache.kafka.streams.processor.internals.RecordCollector;
-import org.apache.kafka.streams.processor.internals.RecordCollectorImpl;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
-import org.apache.kafka.streams.state.KeyValueIterator;
-import org.apache.kafka.streams.state.StateSerdes;
 import org.apache.kafka.streams.state.Stores;
 import org.apache.kafka.streams.state.WindowStore;
 import org.apache.kafka.streams.state.WindowStoreIterator;
-import org.apache.kafka.test.InternalMockProcessorContext;
-import org.apache.kafka.test.StreamsTestUtils;
-import org.apache.kafka.test.TestUtils;
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
-import java.io.File;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
 import static java.time.Duration.ofMillis;
 import static java.time.Instant.ofEpochMilli;
 import static java.util.Arrays.asList;
 import static java.util.Objects.requireNonNull;
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.CoreMatchers.hasItem;
-import static org.hamcrest.MatcherAssert.assertThat;
+
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
 
 @SuppressWarnings("PointlessArithmeticExpression")
-public class RocksDBWindowStoreTest {
-
-    private static final long DEFAULT_CACHE_SIZE_BYTES = 1024 * 1024L;
-
-    private final int numSegments = 3;
-    private final long windowSize = 3L;
-    private final long segmentInterval = 60_000L;
-    private final long retentionPeriod = segmentInterval * (numSegments - 1);
-    private final String windowName = "window";
-    private final KeyValueSegments segments = new KeyValueSegments(windowName, retentionPeriod, segmentInterval);
-    private final StateSerdes<Integer, String> serdes = new StateSerdes<>("", Serdes.Integer(), Serdes.String());
-
-    private final List<KeyValue<byte[], byte[]>> changeLog = new ArrayList<>();
-    private final ThreadCache cache = new ThreadCache(
-        new LogContext("TestCache "),
-        DEFAULT_CACHE_SIZE_BYTES,
-        new MockStreamsMetrics(new Metrics()));
-
-    private final Producer<byte[], byte[]> producer =
-        new MockProducer<>(true, Serdes.ByteArray().serializer(), Serdes.ByteArray().serializer());
-    private final RecordCollector recordCollector = new RecordCollectorImpl(
-        "RocksDBWindowStoreTestTask",
-        new LogContext("RocksDBWindowStoreTestTask "),
-        new DefaultProductionExceptionHandler(),
-        new Metrics().sensor("skipped-records")
-    ) {
-        @Override
-        public <K1, V1> void send(final String topic,
-                                  final K1 key,
-                                  final V1 value,
-                                  final Headers headers,
-                                  final Integer partition,
-                                  final Long timestamp,
-                                  final Serializer<K1> keySerializer,
-                                  final Serializer<V1> valueSerializer) {
-            changeLog.add(new KeyValue<>(
-                keySerializer.serialize(topic, headers, key),
-                valueSerializer.serialize(topic, headers, value))
-            );
-        }
-    };
+public class RocksDBWindowStoreTest extends WindowBytesStoreTest {
 
-    private final File baseDir = TestUtils.tempDirectory("test");
-    private final InternalMockProcessorContext context =
-        new InternalMockProcessorContext(baseDir, Serdes.ByteArray(), Serdes.ByteArray(), recordCollector, cache);
-    private WindowStore<Integer, String> windowStore;
+    private static final String STORE_NAME = "rocksDB window store";
 
-    private WindowStore<Integer, String> createWindowStore(final ProcessorContext context, final boolean retainDuplicates) {
-        final WindowStore<Integer, String> store = Stores.windowStoreBuilder(
+    private final KeyValueSegments segments = new KeyValueSegments(STORE_NAME, RETENTION_PERIOD, SEGMENT_INTERVAL);
+
+    @Override
+    <K, V> WindowStore<K, V> buildWindowStore(final long retentionPeriod,
+                                              final long windowSize,
+                                              final boolean retainDuplicates,
+                                              final Serde<K> keySerde,
+                                              final Serde<V> valueSerde) {
+        return Stores.windowStoreBuilder(
             Stores.persistentWindowStore(
-                windowName,
+                STORE_NAME,
                 ofMillis(retentionPeriod),
                 ofMillis(windowSize),
                 retainDuplicates),
-            Serdes.Integer(),
-            Serdes.String()).build();
-
-        store.init(context, store);
-        return store;
+            keySerde,
+            valueSerde)
+            .build();
     }
 
-    @Before
-    public void initRecordCollector() {
-        recordCollector.init(producer);
+    @Override
+    String getMetricsScope() {
+        return new RocksDbWindowBytesStoreSupplier(null, 0, 0, 0, false, false).metricsScope();
     }
 
-    @After
-    public void closeStore() {
-        if (windowStore != null) {
-            windowStore.close();
-        }
+    @Override
+    void setClassLoggerToDebug() {
+        LogCaptureAppender.setClassLoggerToDebug(AbstractRocksDBSegmentedBytesStore.class);
     }
 
     @Test
     public void shouldOnlyIterateOpenSegments() {
-        windowStore = createWindowStore(context, false);
         long currentTime = 0;
         setCurrentTime(currentTime);
         windowStore.put(1, "one");
 
-        currentTime = currentTime + segmentInterval;
+        currentTime = currentTime + SEGMENT_INTERVAL;
         setCurrentTime(currentTime);
         windowStore.put(1, "two");
-        currentTime = currentTime + segmentInterval;
+        currentTime = currentTime + SEGMENT_INTERVAL;
 
         setCurrentTime(currentTime);
         windowStore.put(1, "three");
 
-        final WindowStoreIterator<String> iterator = windowStore.fetch(1, ofEpochMilli(0), ofEpochMilli(currentTime));
+        final WindowStoreIterator<String> iterator = windowStore.fetch(1, 0L, currentTime);
 
         // roll to the next segment that will close the first
-        currentTime = currentTime + segmentInterval;
+        currentTime = currentTime + SEGMENT_INTERVAL;
         setCurrentTime(currentTime);
         windowStore.put(1, "four");
 
         // should only have 2 values as the first segment is no longer open
-        assertEquals(new KeyValue<>(segmentInterval, "two"), iterator.next());
-        assertEquals(new KeyValue<>(2 * segmentInterval, "three"), iterator.next());
+        assertEquals(new KeyValue<>(SEGMENT_INTERVAL, "two"), iterator.next());
+        assertEquals(new KeyValue<>(2 * SEGMENT_INTERVAL, "three"), iterator.next());
         assertFalse(iterator.hasNext());
     }
 
-    private void setCurrentTime(final long currentTime) {
-        context.setRecordContext(createRecordContext(currentTime));
-    }
-
-    private ProcessorRecordContext createRecordContext(final long time) {
-        return new ProcessorRecordContext(time, 0, 0, "topic", null);
-    }
-
-    @Test
-    public void testRangeAndSinglePointFetch() {
-        windowStore = createWindowStore(context, false);
-        final long startTime = segmentInterval - 4L;
-
-        putFirstBatch(windowStore, startTime, context);
-
-        assertEquals("zero", windowStore.fetch(0, startTime));
-        assertEquals("one", windowStore.fetch(1, startTime + 1L));
-        assertEquals("two", windowStore.fetch(2, startTime + 2L));
-        assertEquals("four", windowStore.fetch(4, startTime + 4L));
-        assertEquals("five", windowStore.fetch(5, startTime + 5L));
-
-        assertEquals(
-            Collections.singletonList("zero"),
-            toList(windowStore.fetch(
-                0,
-                ofEpochMilli(startTime + 0 - windowSize),
-                ofEpochMilli(startTime + 0 + windowSize))));
-
-        putSecondBatch(windowStore, startTime, context);
-
-        assertEquals("two+1", windowStore.fetch(2, startTime + 3L));
-        assertEquals("two+2", windowStore.fetch(2, startTime + 4L));
-        assertEquals("two+3", windowStore.fetch(2, startTime + 5L));
-        assertEquals("two+4", windowStore.fetch(2, startTime + 6L));
-        assertEquals("two+5", windowStore.fetch(2, startTime + 7L));
-        assertEquals("two+6", windowStore.fetch(2, startTime + 8L));
-
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime - 2L - windowSize),
-                ofEpochMilli(startTime - 2L + windowSize))));
-        assertEquals(
-            Collections.singletonList("two"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime - 1L - windowSize),
-                ofEpochMilli(startTime - 1L + windowSize))));
-        assertEquals(
-            asList("two", "two+1"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime - windowSize),
-                ofEpochMilli(startTime + windowSize))));
-        assertEquals(
-            asList("two", "two+1", "two+2"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 1L - windowSize),
-                ofEpochMilli(startTime + 1L + windowSize))));
-        assertEquals(
-            asList("two", "two+1", "two+2", "two+3"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 2L - windowSize),
-                ofEpochMilli(startTime + 2L + windowSize))));
-        assertEquals(
-            asList("two", "two+1", "two+2", "two+3", "two+4"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 3L - windowSize),
-                ofEpochMilli(startTime + 3L + windowSize))));
-        assertEquals(
-            asList("two", "two+1", "two+2", "two+3", "two+4", "two+5"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 4L - windowSize),
-                ofEpochMilli(startTime + 4L + windowSize))));
-        assertEquals(
-            asList("two", "two+1", "two+2", "two+3", "two+4", "two+5", "two+6"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 5L - windowSize),
-                ofEpochMilli(startTime + 5L + windowSize))));
-        assertEquals(
-            asList("two+1", "two+2", "two+3", "two+4", "two+5", "two+6"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 6L - windowSize),
-                ofEpochMilli(startTime + 6L + windowSize))));
-        assertEquals(
-            asList("two+2", "two+3", "two+4", "two+5", "two+6"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 7L - windowSize),
-                ofEpochMilli(startTime + 7L + windowSize))));
-        assertEquals(
-            asList("two+3", "two+4", "two+5", "two+6"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 8L - windowSize),
-                ofEpochMilli(startTime + 8L + windowSize))));
-        assertEquals(
-            asList("two+4", "two+5", "two+6"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 9L - windowSize),
-                ofEpochMilli(startTime + 9L + windowSize))));
-        assertEquals(
-            asList("two+5", "two+6"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 10L - windowSize),
-                ofEpochMilli(startTime + 10L + windowSize))));
-        assertEquals(
-            Collections.singletonList("two+6"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 11L - windowSize),
-                ofEpochMilli(startTime + 11L + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 12L - windowSize),
-                ofEpochMilli(startTime + 12L + windowSize))));
-
-        // Flush the store and verify all current entries were properly flushed ...
-        windowStore.flush();
-
-        final Map<Integer, Set<String>> entriesByKey = entriesByKey(changeLog, startTime);
-
-        assertEquals(Utils.mkSet("zero@0"), entriesByKey.get(0));
-        assertEquals(Utils.mkSet("one@1"), entriesByKey.get(1));
-        assertEquals(Utils.mkSet("two@2", "two+1@3", "two+2@4", "two+3@5", "two+4@6", "two+5@7", "two+6@8"), entriesByKey.get(2));
-        assertNull(entriesByKey.get(3));
-        assertEquals(Utils.mkSet("four@4"), entriesByKey.get(4));
-        assertEquals(Utils.mkSet("five@5"), entriesByKey.get(5));
-        assertNull(entriesByKey.get(6));
-    }
-
-    @Test
-    public void shouldGetAll() {
-        windowStore = createWindowStore(context, false);
-        final long startTime = segmentInterval - 4L;
-
-        putFirstBatch(windowStore, startTime, context);
-
-        final KeyValue<Windowed<Integer>, String> zero = windowedPair(0, "zero", startTime + 0);
-        final KeyValue<Windowed<Integer>, String> one = windowedPair(1, "one", startTime + 1);
-        final KeyValue<Windowed<Integer>, String> two = windowedPair(2, "two", startTime + 2);
-        final KeyValue<Windowed<Integer>, String> four = windowedPair(4, "four", startTime + 4);
-        final KeyValue<Windowed<Integer>, String> five = windowedPair(5, "five", startTime + 5);
-
-        assertEquals(
-            asList(zero, one, two, four, five),
-            StreamsTestUtils.toList(windowStore.all())
-        );
-    }
-
-    @Test
-    public void shouldFetchAllInTimeRange() {
-        windowStore = createWindowStore(context, false);
-        final long startTime = segmentInterval - 4L;
-
-        putFirstBatch(windowStore, startTime, context);
-
-        final KeyValue<Windowed<Integer>, String> zero = windowedPair(0, "zero", startTime + 0);
-        final KeyValue<Windowed<Integer>, String> one = windowedPair(1, "one", startTime + 1);
-        final KeyValue<Windowed<Integer>, String> two = windowedPair(2, "two", startTime + 2);
-        final KeyValue<Windowed<Integer>, String> four = windowedPair(4, "four", startTime + 4);
-        final KeyValue<Windowed<Integer>, String> five = windowedPair(5, "five", startTime + 5);
-
-        assertEquals(
-            asList(one, two, four),
-            StreamsTestUtils.toList(windowStore.fetchAll(ofEpochMilli(startTime + 1), ofEpochMilli(startTime + 4)))
-        );
-        assertEquals(
-            asList(zero, one, two),
-            StreamsTestUtils.toList(windowStore.fetchAll(ofEpochMilli(startTime + 0), ofEpochMilli(startTime + 3)))
-        );
-        assertEquals(
-            asList(one, two, four, five),
-            StreamsTestUtils.toList(windowStore.fetchAll(ofEpochMilli(startTime + 1), ofEpochMilli(startTime + 5)))
-        );
-    }
-
-    @Test
-    public void testFetchRange() {
-        windowStore = createWindowStore(context, false);
-        final long startTime = segmentInterval - 4L;
-
-        putFirstBatch(windowStore, startTime, context);
-
-        final KeyValue<Windowed<Integer>, String> zero = windowedPair(0, "zero", startTime + 0);
-        final KeyValue<Windowed<Integer>, String> one = windowedPair(1, "one", startTime + 1);
-        final KeyValue<Windowed<Integer>, String> two = windowedPair(2, "two", startTime + 2);
-        final KeyValue<Windowed<Integer>, String> four = windowedPair(4, "four", startTime + 4);
-        final KeyValue<Windowed<Integer>, String> five = windowedPair(5, "five", startTime + 5);
-
-        assertEquals(
-            asList(zero, one),
-            StreamsTestUtils.toList(windowStore.fetch(
-                0,
-                1,
-                ofEpochMilli(startTime + 0L - windowSize),
-                ofEpochMilli(startTime + 0L + windowSize)))
-        );
-        assertEquals(
-            Collections.singletonList(one),
-            StreamsTestUtils.toList(windowStore.fetch(
-                1,
-                1,
-                ofEpochMilli(startTime + 0L - windowSize),
-                ofEpochMilli(startTime + 0L + windowSize)))
-        );
-        assertEquals(
-            asList(one, two),
-            StreamsTestUtils.toList(windowStore.fetch(
-                1,
-                3,
-                ofEpochMilli(startTime + 0L - windowSize),
-                ofEpochMilli(startTime + 0L + windowSize)))
-        );
-        assertEquals(
-            asList(zero, one, two),
-            StreamsTestUtils.toList(windowStore.fetch(
-                0,
-                5,
-                ofEpochMilli(startTime + 0L - windowSize),
-                ofEpochMilli(startTime + 0L + windowSize)))
-        );
-        assertEquals(
-            asList(zero, one, two, four, five),
-            StreamsTestUtils.toList(windowStore.fetch(
-                0,
-                5,
-                ofEpochMilli(startTime + 0L - windowSize),
-                ofEpochMilli(startTime + 0L + windowSize + 5L)))
-        );
-        assertEquals(
-            asList(two, four, five),
-            StreamsTestUtils.toList(windowStore.fetch(
-                0,
-                5,
-                ofEpochMilli(startTime + 2L),
-                ofEpochMilli(startTime + 0L + windowSize + 5L)))
-        );
-        assertEquals(
-            Collections.emptyList(),
-            StreamsTestUtils.toList(windowStore.fetch(
-                4,
-                5,
-                ofEpochMilli(startTime + 2L),
-                ofEpochMilli(startTime + windowSize)))
-        );
-        assertEquals(
-            Collections.emptyList(),
-            StreamsTestUtils.toList(windowStore.fetch(
-                0,
-                3,
-                ofEpochMilli(startTime + 3L),
-                ofEpochMilli(startTime + windowSize + 5)))
-        );
-    }
-
-    @Test
-    public void testPutAndFetchBefore() {
-        windowStore = createWindowStore(context, false);
-        final long startTime = segmentInterval - 4L;
-
-        putFirstBatch(windowStore, startTime, context);
-
-        assertEquals(
-            Collections.singletonList("zero"),
-            toList(windowStore.fetch(
-                0,
-                ofEpochMilli(startTime + 0L - windowSize),
-                ofEpochMilli(startTime + 0L))));
-        assertEquals(
-            Collections.singletonList("one"),
-            toList(windowStore.fetch(
-                1,
-                ofEpochMilli(startTime + 1L - windowSize),
-                ofEpochMilli(startTime + 1L))));
-        assertEquals(
-            Collections.singletonList("two"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 2L - windowSize),
-                ofEpochMilli(startTime + 2L))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                3,
-                ofEpochMilli(startTime + 3L - windowSize),
-                ofEpochMilli(startTime + 3L))));
-        assertEquals(
-            Collections.singletonList("four"),
-            toList(windowStore.fetch(
-                4,
-                ofEpochMilli(startTime + 4L - windowSize),
-                ofEpochMilli(startTime + 4L))));
-        assertEquals(
-            Collections.singletonList("five"),
-            toList(windowStore.fetch(
-                5,
-                ofEpochMilli(startTime + 5L - windowSize),
-                ofEpochMilli(startTime + 5L))));
-
-        putSecondBatch(windowStore, startTime, context);
-
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime - 1L - windowSize),
-                ofEpochMilli(startTime - 1L))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 0L - windowSize),
-                ofEpochMilli(startTime + 0L))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 1L - windowSize),
-                ofEpochMilli(startTime + 1L))));
-        assertEquals(
-            Collections.singletonList("two"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 2L - windowSize),
-                ofEpochMilli(startTime + 2L))));
-        assertEquals(
-            asList("two", "two+1"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 3L - windowSize),
-                ofEpochMilli(startTime + 3L))));
-        assertEquals(
-            asList("two", "two+1", "two+2"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 4L - windowSize),
-                ofEpochMilli(startTime + 4L))));
-        assertEquals(
-            asList("two", "two+1", "two+2", "two+3"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 5L - windowSize),
-                ofEpochMilli(startTime + 5L))));
-        assertEquals(
-            asList("two+1", "two+2", "two+3", "two+4"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 6L - windowSize),
-                ofEpochMilli(startTime + 6L))));
-        assertEquals(
-            asList("two+2", "two+3", "two+4", "two+5"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 7L - windowSize),
-                ofEpochMilli(startTime + 7L))));
-        assertEquals(
-            asList("two+3", "two+4", "two+5", "two+6"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 8L - windowSize),
-                ofEpochMilli(startTime + 8L))));
-        assertEquals(
-            asList("two+4", "two+5", "two+6"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 9L - windowSize),
-                ofEpochMilli(startTime + 9L))));
-        assertEquals(
-            asList("two+5", "two+6"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 10L - windowSize),
-                ofEpochMilli(startTime + 10L))));
-        assertEquals(
-            Collections.singletonList("two+6"),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 11L - windowSize),
-                ofEpochMilli(startTime + 11L))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 12L - windowSize),
-                ofEpochMilli(startTime + 12L))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + 13L - windowSize),
-                ofEpochMilli(startTime + 13L))));
-
-        // Flush the store and verify all current entries were properly flushed ...
-        windowStore.flush();
-
-        final Map<Integer, Set<String>> entriesByKey = entriesByKey(changeLog, startTime);
-        assertEquals(Utils.mkSet("zero@0"), entriesByKey.get(0));
-        assertEquals(Utils.mkSet("one@1"), entriesByKey.get(1));
-        assertEquals(
-            Utils.mkSet("two@2", "two+1@3", "two+2@4", "two+3@5", "two+4@6", "two+5@7", "two+6@8"),
-            entriesByKey.get(2));
-        assertNull(entriesByKey.get(3));
-        assertEquals(Utils.mkSet("four@4"), entriesByKey.get(4));
-        assertEquals(Utils.mkSet("five@5"), entriesByKey.get(5));
-        assertNull(entriesByKey.get(6));
-    }
-
-    @Test
-    public void testPutAndFetchAfter() {
-        windowStore = createWindowStore(context, false);
-        final long startTime = segmentInterval - 4L;
-
-        putFirstBatch(windowStore, startTime, context);
-
-        assertEquals(
-            Collections.singletonList("zero"),
-            toList(windowStore.fetch(0, ofEpochMilli(startTime + 0L), ofEpochMilli(startTime + 0L + windowSize))));
-        assertEquals(
-            Collections.singletonList("one"),
-            toList(windowStore.fetch(1, ofEpochMilli(startTime + 1L), ofEpochMilli(startTime + 1L + windowSize))));
-        assertEquals(
-            Collections.singletonList("two"),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 2L), ofEpochMilli(startTime + 2L + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(3, ofEpochMilli(startTime + 3L), ofEpochMilli(startTime + 3L + windowSize))));
-        assertEquals(
-            Collections.singletonList("four"),
-            toList(windowStore.fetch(4, ofEpochMilli(startTime + 4L), ofEpochMilli(startTime + 4L + windowSize))));
-        assertEquals(
-            Collections.singletonList("five"),
-            toList(windowStore.fetch(5, ofEpochMilli(startTime + 5L), ofEpochMilli(startTime + 5L + windowSize))));
-
-        putSecondBatch(windowStore, startTime, context);
-
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime - 2L), ofEpochMilli(startTime - 2L + windowSize))));
-        assertEquals(
-            Collections.singletonList("two"),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime - 1L), ofEpochMilli(startTime - 1L + windowSize))));
-        assertEquals(
-            asList("two", "two+1"),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime), ofEpochMilli(startTime + windowSize))));
-        assertEquals(
-            asList("two", "two+1", "two+2"),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 1L), ofEpochMilli(startTime + 1L + windowSize))));
-        assertEquals(
-            asList("two", "two+1", "two+2", "two+3"),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 2L), ofEpochMilli(startTime + 2L + windowSize))));
-        assertEquals(
-            asList("two+1", "two+2", "two+3", "two+4"),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 3L), ofEpochMilli(startTime + 3L + windowSize))));
-        assertEquals(
-            asList("two+2", "two+3", "two+4", "two+5"),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 4L), ofEpochMilli(startTime + 4L + windowSize))));
-        assertEquals(
-            asList("two+3", "two+4", "two+5", "two+6"),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 5L), ofEpochMilli(startTime + 5L + windowSize))));
-        assertEquals(
-            asList("two+4", "two+5", "two+6"),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 6L), ofEpochMilli(startTime + 6L + windowSize))));
-        assertEquals(
-            asList("two+5", "two+6"),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 7L), ofEpochMilli(startTime + 7L + windowSize))));
-        assertEquals(
-            Collections.singletonList("two+6"),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 8L), ofEpochMilli(startTime + 8L + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 9L), ofEpochMilli(startTime + 9L + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 10L), ofEpochMilli(startTime + 10L + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 11L), ofEpochMilli(startTime + 11L + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(2, ofEpochMilli(startTime + 12L), ofEpochMilli(startTime + 12L + windowSize))));
-
-        // Flush the store and verify all current entries were properly flushed ...
-        windowStore.flush();
-
-        final Map<Integer, Set<String>> entriesByKey = entriesByKey(changeLog, startTime);
-
-        assertEquals(Utils.mkSet("zero@0"), entriesByKey.get(0));
-        assertEquals(Utils.mkSet("one@1"), entriesByKey.get(1));
-        assertEquals(
-            Utils.mkSet("two@2", "two+1@3", "two+2@4", "two+3@5", "two+4@6", "two+5@7", "two+6@8"),
-            entriesByKey.get(2));
-        assertNull(entriesByKey.get(3));
-        assertEquals(Utils.mkSet("four@4"), entriesByKey.get(4));
-        assertEquals(Utils.mkSet("five@5"), entriesByKey.get(5));
-        assertNull(entriesByKey.get(6));
-    }
-
-    @Test
-    public void testPutSameKeyTimestamp() {
-        windowStore = createWindowStore(context, true);
-        final long startTime = segmentInterval - 4L;
-
-        setCurrentTime(startTime);
-        windowStore.put(0, "zero");
-
-        assertEquals(
-            Collections.singletonList("zero"),
-            toList(windowStore.fetch(0, ofEpochMilli(startTime - windowSize), ofEpochMilli(startTime + windowSize))));
-
-        windowStore.put(0, "zero");
-        windowStore.put(0, "zero+");
-        windowStore.put(0, "zero++");
-
-        assertEquals(
-            asList("zero", "zero", "zero+", "zero++"),
-            toList(windowStore.fetch(
-                0,
-                ofEpochMilli(startTime - windowSize),
-                ofEpochMilli(startTime + windowSize))));
-        assertEquals(
-            asList("zero", "zero", "zero+", "zero++"),
-            toList(windowStore.fetch(
-                0,
-                ofEpochMilli(startTime + 1L - windowSize),
-                ofEpochMilli(startTime + 1L + windowSize))));
-        assertEquals(
-            asList("zero", "zero", "zero+", "zero++"),
-            toList(windowStore.fetch(
-                0,
-                ofEpochMilli(startTime + 2L - windowSize),
-                ofEpochMilli(startTime + 2L + windowSize))));
-        assertEquals(
-            asList("zero", "zero", "zero+", "zero++"),
-            toList(windowStore.fetch(
-                0,
-                ofEpochMilli(startTime + 3L - windowSize),
-                ofEpochMilli(startTime + 3L + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                0,
-                ofEpochMilli(startTime + 4L - windowSize),
-                ofEpochMilli(startTime + 4L + windowSize))));
-
-        // Flush the store and verify all current entries were properly flushed ...
-        windowStore.flush();
-
-        final Map<Integer, Set<String>> entriesByKey = entriesByKey(changeLog, startTime);
-
-        assertEquals(Utils.mkSet("zero@0", "zero@0", "zero+@0", "zero++@0"), entriesByKey.get(0));
-    }
-
     @Test
     public void testRolling() {
-        windowStore = createWindowStore(context, false);
 
         // to validate segments
-        final long startTime = segmentInterval * 2;
-        final long increment = segmentInterval / 2;
+        final long startTime = SEGMENT_INTERVAL * 2;
+        final long increment = SEGMENT_INTERVAL / 2;
         setCurrentTime(startTime);
         windowStore.put(0, "zero");
         assertEquals(Utils.mkSet(segments.segmentName(2)), segmentDirs(baseDir));
@@ -781,41 +148,41 @@ public class RocksDBWindowStoreTest {
         );
 
         assertEquals(
-            Collections.singletonList("zero"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("zero")),
+            toSet(windowStore.fetch(
                 0,
-                ofEpochMilli(startTime - windowSize),
-                ofEpochMilli(startTime + windowSize))));
+                ofEpochMilli(startTime - WINDOW_SIZE),
+                ofEpochMilli(startTime + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("one"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("one")),
+            toSet(windowStore.fetch(
                 1,
-                ofEpochMilli(startTime + increment - windowSize),
-                ofEpochMilli(startTime + increment + windowSize))));
+                ofEpochMilli(startTime + increment - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("two"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("two")),
+            toSet(windowStore.fetch(
                 2,
-                ofEpochMilli(startTime + increment * 2 - windowSize),
-                ofEpochMilli(startTime + increment * 2 + windowSize))));
+                ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE))));
         assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
                 3,
-                ofEpochMilli(startTime + increment * 3 - windowSize),
-                ofEpochMilli(startTime + increment * 3 + windowSize))));
+                ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("four"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("four")),
+            toSet(windowStore.fetch(
                 4,
-                ofEpochMilli(startTime + increment * 4 - windowSize),
-                ofEpochMilli(startTime + increment * 4 + windowSize))));
+                ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("five"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("five")),
+            toSet(windowStore.fetch(
                 5,
-                ofEpochMilli(startTime + increment * 5 - windowSize),
-                ofEpochMilli(startTime + increment * 5 + windowSize))));
+                ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE))));
 
         setCurrentTime(startTime + increment * 6);
         windowStore.put(6, "six");
@@ -829,46 +196,47 @@ public class RocksDBWindowStoreTest {
         );
 
         assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
                 0,
-                ofEpochMilli(startTime - windowSize),
-                ofEpochMilli(startTime + windowSize))));
+                ofEpochMilli(startTime - WINDOW_SIZE),
+                ofEpochMilli(startTime + WINDOW_SIZE))));
         assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
                 1,
-                ofEpochMilli(startTime + increment - windowSize),
-                ofEpochMilli(startTime + increment + windowSize))));
+                ofEpochMilli(startTime + increment - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("two"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("two")),
+            toSet(windowStore.fetch(
                 2,
-                ofEpochMilli(startTime + increment * 2 - windowSize),
-                ofEpochMilli(startTime + increment * 2 + windowSize))));
+                ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE))));
         assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
                 3,
-                ofEpochMilli(startTime + increment * 3 - windowSize),
-                ofEpochMilli(startTime + increment * 3 + windowSize))));
-        assertEquals(Collections.singletonList("four"),
-            toList(windowStore.fetch(
+                ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.singletonList("four")),
+            toSet(windowStore.fetch(
                 4,
-                ofEpochMilli(startTime + increment * 4 - windowSize),
-                ofEpochMilli(startTime + increment * 4 + windowSize))));
+                ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("five"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("five")),
+            toSet(windowStore.fetch(
                 5,
-                ofEpochMilli(startTime + increment * 5 - windowSize),
-                ofEpochMilli(startTime + increment * 5 + windowSize))));
+                ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("six"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("six")),
+            toSet(windowStore.fetch(
                 6,
-                ofEpochMilli(startTime + increment * 6 - windowSize),
-                ofEpochMilli(startTime + increment * 6 + windowSize))));
+                ofEpochMilli(startTime + increment * 6 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 6 + WINDOW_SIZE))));
 
         setCurrentTime(startTime + increment * 7);
         windowStore.put(7, "seven");
@@ -882,53 +250,53 @@ public class RocksDBWindowStoreTest {
         );
 
         assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
                 0,
-                ofEpochMilli(startTime - windowSize),
-                ofEpochMilli(startTime + windowSize))));
+                ofEpochMilli(startTime - WINDOW_SIZE),
+                ofEpochMilli(startTime + WINDOW_SIZE))));
         assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
                 1,
-                ofEpochMilli(startTime + increment - windowSize),
-                ofEpochMilli(startTime + increment + windowSize))));
+                ofEpochMilli(startTime + increment - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("two"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("two")),
+            toSet(windowStore.fetch(
                 2,
-                ofEpochMilli(startTime + increment * 2 - windowSize),
-                ofEpochMilli(startTime + increment * 2 + windowSize))));
+                ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE))));
         assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
                 3,
-                ofEpochMilli(startTime + increment * 3 - windowSize),
-                ofEpochMilli(startTime + increment * 3 + windowSize))));
+                ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("four"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("four")),
+            toSet(windowStore.fetch(
                 4,
-                ofEpochMilli(startTime + increment * 4 - windowSize),
-                ofEpochMilli(startTime + increment * 4 + windowSize))));
+                ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("five"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("five")),
+            toSet(windowStore.fetch(
                 5,
-                ofEpochMilli(startTime + increment * 5 - windowSize),
-                ofEpochMilli(startTime + increment * 5 + windowSize))));
+                ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("six"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("six")),
+            toSet(windowStore.fetch(
                 6,
-                ofEpochMilli(startTime + increment * 6 - windowSize),
-                ofEpochMilli(startTime + increment * 6 + windowSize))));
+                ofEpochMilli(startTime + increment * 6 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 6 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("seven"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("seven")),
+            toSet(windowStore.fetch(
                 7,
-                ofEpochMilli(startTime + increment * 7 - windowSize),
-                ofEpochMilli(startTime + increment * 7 + windowSize))));
+                ofEpochMilli(startTime + increment * 7 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 7 + WINDOW_SIZE))));
 
         setCurrentTime(startTime + increment * 8);
         windowStore.put(8, "eight");
@@ -942,59 +310,59 @@ public class RocksDBWindowStoreTest {
         );
 
         assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
                 0,
-                ofEpochMilli(startTime - windowSize),
-                ofEpochMilli(startTime + windowSize))));
+                ofEpochMilli(startTime - WINDOW_SIZE),
+                ofEpochMilli(startTime + WINDOW_SIZE))));
         assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
                 1,
-                ofEpochMilli(startTime + increment - windowSize),
-                ofEpochMilli(startTime + increment + windowSize))));
+                ofEpochMilli(startTime + increment - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment + WINDOW_SIZE))));
         assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
                 2,
-                ofEpochMilli(startTime + increment * 2 - windowSize),
-                ofEpochMilli(startTime + increment * 2 + windowSize))));
+                ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE))));
         assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
                 3,
-                ofEpochMilli(startTime + increment * 3 - windowSize),
-                ofEpochMilli(startTime + increment * 3 + windowSize))));
+                ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("four"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("four")),
+            toSet(windowStore.fetch(
                 4,
-                ofEpochMilli(startTime + increment * 4 - windowSize),
-                ofEpochMilli(startTime + increment * 4 + windowSize))));
+                ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("five"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("five")),
+            toSet(windowStore.fetch(
                 5,
-                ofEpochMilli(startTime + increment * 5 - windowSize),
-                ofEpochMilli(startTime + increment * 5 + windowSize))));
+                ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("six"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("six")),
+            toSet(windowStore.fetch(
                 6,
-                ofEpochMilli(startTime + increment * 6 - windowSize),
-                ofEpochMilli(startTime + increment * 6 + windowSize))));
+                ofEpochMilli(startTime + increment * 6 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 6 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("seven"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("seven")),
+            toSet(windowStore.fetch(
                 7,
-                ofEpochMilli(startTime + increment * 7 - windowSize),
-                ofEpochMilli(startTime + increment * 7 + windowSize))));
+                ofEpochMilli(startTime + increment * 7 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 7 + WINDOW_SIZE))));
         assertEquals(
-            Collections.singletonList("eight"),
-            toList(windowStore.fetch(
+            new HashSet<>(Collections.singletonList("eight")),
+            toSet(windowStore.fetch(
                 8,
-                ofEpochMilli(startTime + increment * 8 - windowSize),
-                ofEpochMilli(startTime + increment * 8 + windowSize))));
+                ofEpochMilli(startTime + increment * 8 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 8 + WINDOW_SIZE))));
 
         // check segment directories
         windowStore.flush();
@@ -1006,168 +374,15 @@ public class RocksDBWindowStoreTest {
             ),
             segmentDirs(baseDir)
         );
-
-
     }
 
     @Test
-    public void testRestore() throws Exception {
-        final long startTime = segmentInterval * 2;
-        final long increment = segmentInterval / 2;
-
-        windowStore = createWindowStore(context, false);
-        setCurrentTime(startTime);
-        windowStore.put(0, "zero");
-        setCurrentTime(startTime + increment);
-        windowStore.put(1, "one");
-        setCurrentTime(startTime + increment * 2);
-        windowStore.put(2, "two");
-        setCurrentTime(startTime + increment * 3);
-        windowStore.put(3, "three");
-        setCurrentTime(startTime + increment * 4);
-        windowStore.put(4, "four");
-        setCurrentTime(startTime + increment * 5);
-        windowStore.put(5, "five");
-        setCurrentTime(startTime + increment * 6);
-        windowStore.put(6, "six");
-        setCurrentTime(startTime + increment * 7);
-        windowStore.put(7, "seven");
-        setCurrentTime(startTime + increment * 8);
-        windowStore.put(8, "eight");
-        windowStore.flush();
-
-        windowStore.close();
-
-        // remove local store image
-        Utils.delete(baseDir);
-
-        windowStore = createWindowStore(context, false);
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                0,
-                ofEpochMilli(startTime - windowSize),
-                ofEpochMilli(startTime + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                1,
-                ofEpochMilli(startTime + increment - windowSize),
-                ofEpochMilli(startTime + increment + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + increment * 2 - windowSize),
-                ofEpochMilli(startTime + increment * 2 + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                3,
-                ofEpochMilli(startTime + increment * 3 - windowSize),
-                ofEpochMilli(startTime + increment * 3 + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                4,
-                ofEpochMilli(startTime + increment * 4 - windowSize),
-                ofEpochMilli(startTime + increment * 4 + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                5,
-                ofEpochMilli(startTime + increment * 5 - windowSize),
-                ofEpochMilli(startTime + increment * 5 + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                6,
-                ofEpochMilli(startTime + increment * 6 - windowSize),
-                ofEpochMilli(startTime + increment * 6 + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                7,
-                ofEpochMilli(startTime + increment * 7 - windowSize),
-                ofEpochMilli(startTime + increment * 7 + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                8,
-                ofEpochMilli(startTime + increment * 8 - windowSize),
-                ofEpochMilli(startTime + increment * 8 + windowSize))));
-
-        context.restore(windowName, changeLog);
-
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                0,
-                ofEpochMilli(startTime - windowSize),
-                ofEpochMilli(startTime + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                1,
-                ofEpochMilli(startTime + increment - windowSize),
-                ofEpochMilli(startTime + increment + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                2,
-                ofEpochMilli(startTime + increment * 2 - windowSize),
-                ofEpochMilli(startTime + increment * 2 + windowSize))));
-        assertEquals(
-            Collections.emptyList(),
-            toList(windowStore.fetch(
-                3,
-                ofEpochMilli(startTime + increment * 3 - windowSize),
-                ofEpochMilli(startTime + increment * 3 + windowSize))));
-        assertEquals(
-            Collections.singletonList("four"),
-            toList(windowStore.fetch(
-                4,
-                ofEpochMilli(startTime + increment * 4 - windowSize),
-                ofEpochMilli(startTime + increment * 4 + windowSize))));
-        assertEquals(
-            Collections.singletonList("five"),
-            toList(windowStore.fetch(
-                5,
-                ofEpochMilli(startTime + increment * 5 - windowSize),
-                ofEpochMilli(startTime + increment * 5 + windowSize))));
-        assertEquals(
-            Collections.singletonList("six"),
-            toList(windowStore.fetch(
-                6,
-                ofEpochMilli(startTime + increment * 6 - windowSize),
-                ofEpochMilli(startTime + increment * 6 + windowSize))));
-        assertEquals(
-            Collections.singletonList("seven"),
-            toList(windowStore.fetch(
-                7,
-                ofEpochMilli(startTime + increment * 7 - windowSize),
-                ofEpochMilli(startTime + increment * 7 + windowSize))));
-        assertEquals(
-            Collections.singletonList("eight"),
-            toList(windowStore.fetch(
-                8,
-                ofEpochMilli(startTime + increment * 8 - windowSize),
-                ofEpochMilli(startTime + increment * 8 + windowSize))));
+    public void testSegmentMaintenance() {
 
-        // check segment directories
-        windowStore.flush();
-        assertEquals(
-            Utils.mkSet(
-                segments.segmentName(4L),
-                segments.segmentName(5L),
-                segments.segmentName(6L)),
-            segmentDirs(baseDir)
-        );
-    }
+        windowStore = buildWindowStore(RETENTION_PERIOD, WINDOW_SIZE, true, Serdes.Integer(),
+            Serdes.String());
+        windowStore.init(context, windowStore);
 
-    @Test
-    public void testSegmentMaintenance() {
-        windowStore = createWindowStore(context, true);
         context.setTime(0L);
         setCurrentTime(0);
         windowStore.put(0, "v");
@@ -1176,7 +391,7 @@ public class RocksDBWindowStoreTest {
             segmentDirs(baseDir)
         );
 
-        setCurrentTime(segmentInterval - 1);
+        setCurrentTime(SEGMENT_INTERVAL - 1);
         windowStore.put(0, "v");
         windowStore.put(0, "v");
         assertEquals(
@@ -1184,7 +399,7 @@ public class RocksDBWindowStoreTest {
             segmentDirs(baseDir)
         );
 
-        setCurrentTime(segmentInterval);
+        setCurrentTime(SEGMENT_INTERVAL);
         windowStore.put(0, "v");
         assertEquals(
             Utils.mkSet(segments.segmentName(0L), segments.segmentName(1L)),
@@ -1194,7 +409,7 @@ public class RocksDBWindowStoreTest {
         WindowStoreIterator iter;
         int fetchedCount;
 
-        iter = windowStore.fetch(0, ofEpochMilli(0L), ofEpochMilli(segmentInterval * 4));
+        iter = windowStore.fetch(0, ofEpochMilli(0L), ofEpochMilli(SEGMENT_INTERVAL * 4));
         fetchedCount = 0;
         while (iter.hasNext()) {
             iter.next();
@@ -1207,10 +422,10 @@ public class RocksDBWindowStoreTest {
             segmentDirs(baseDir)
         );
 
-        setCurrentTime(segmentInterval * 3);
+        setCurrentTime(SEGMENT_INTERVAL * 3);
         windowStore.put(0, "v");
 
-        iter = windowStore.fetch(0, ofEpochMilli(0L), ofEpochMilli(segmentInterval * 4));
+        iter = windowStore.fetch(0, ofEpochMilli(0L), ofEpochMilli(SEGMENT_INTERVAL * 4));
         fetchedCount = 0;
         while (iter.hasNext()) {
             iter.next();
@@ -1223,10 +438,10 @@ public class RocksDBWindowStoreTest {
             segmentDirs(baseDir)
         );
 
-        setCurrentTime(segmentInterval * 5);
+        setCurrentTime(SEGMENT_INTERVAL * 5);
         windowStore.put(0, "v");
 
-        iter = windowStore.fetch(0, ofEpochMilli(segmentInterval * 4), ofEpochMilli(segmentInterval * 10));
+        iter = windowStore.fetch(0, ofEpochMilli(SEGMENT_INTERVAL * 4), ofEpochMilli(SEGMENT_INTERVAL * 10));
         fetchedCount = 0;
         while (iter.hasNext()) {
             iter.next();
@@ -1244,9 +459,7 @@ public class RocksDBWindowStoreTest {
     @SuppressWarnings("ResultOfMethodCallIgnored")
     @Test
     public void testInitialLoading() {
-        final File storeDir = new File(baseDir, windowName);
-
-        windowStore = createWindowStore(context, false);
+        final File storeDir = new File(baseDir, STORE_NAME);
 
         new File(storeDir, segments.segmentName(0L)).mkdir();
         new File(storeDir, segments.segmentName(1L)).mkdir();
@@ -1257,10 +470,11 @@ public class RocksDBWindowStoreTest {
         new File(storeDir, segments.segmentName(6L)).mkdir();
         windowStore.close();
 
-        windowStore = createWindowStore(context, false);
+        windowStore = buildWindowStore(RETENTION_PERIOD, WINDOW_SIZE, false, Serdes.Integer(), Serdes.String());
+        windowStore.init(context, windowStore);
 
         // put something in the store to advance its stream time and expire the old segments
-        windowStore.put(1, "v", 6L * segmentInterval);
+        windowStore.put(1, "v", 6L * SEGMENT_INTERVAL);
 
         final List<String> expected = asList(
             segments.segmentName(4L),
@@ -1289,220 +503,169 @@ public class RocksDBWindowStoreTest {
     }
 
     @Test
-    public void shouldCloseOpenIteratorsWhenStoreIsClosedAndNotThrowInvalidStateStoreExceptionOnHasNext() {
-        windowStore = createWindowStore(context, false);
-        setCurrentTime(0);
-        windowStore.put(1, "one", 1L);
-        windowStore.put(1, "two", 2L);
-        windowStore.put(1, "three", 3L);
-
-        final WindowStoreIterator<String> iterator = windowStore.fetch(1, ofEpochMilli(1L), ofEpochMilli(3L));
-        assertTrue(iterator.hasNext());
-        windowStore.close();
-
-        assertFalse(iterator.hasNext());
-
-    }
-
-    @Test
-    public void shouldFetchAndIterateOverExactKeys() {
-        final long windowSize = 0x7a00000000000000L;
-        final long retentionPeriod = 0x7a00000000000000L;
-
-        final WindowStore<String, String> windowStore = Stores.windowStoreBuilder(
-            Stores.persistentWindowStore(windowName, ofMillis(retentionPeriod), ofMillis(windowSize), true),
-            Serdes.String(),
-            Serdes.String()).build();
-
-        windowStore.init(context, windowStore);
-
-        windowStore.put("a", "0001", 0);
-        windowStore.put("aa", "0002", 0);
-        windowStore.put("a", "0003", 1);
-        windowStore.put("aa", "0004", 1);
-        windowStore.put("a", "0005", 0x7a00000000000000L - 1);
-
-
-        final List expected = asList("0001", "0003", "0005");
-        assertThat(toList(windowStore.fetch("a", ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))), equalTo(expected));
-
-        List<KeyValue<Windowed<String>, String>> list =
-            StreamsTestUtils.toList(windowStore.fetch("a", "a", ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE)));
-        assertThat(list, equalTo(asList(
-            windowedPair("a", "0001", 0, windowSize),
-            windowedPair("a", "0003", 1, windowSize),
-            windowedPair("a", "0005", 0x7a00000000000000L - 1, windowSize)
-        )));
-
-        list = StreamsTestUtils.toList(windowStore.fetch("aa", "aa", ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE)));
-        assertThat(list, equalTo(asList(
-            windowedPair("aa", "0002", 0, windowSize),
-            windowedPair("aa", "0004", 1, windowSize)
-        )));
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnPutNullKey() {
-        windowStore = createWindowStore(context, false);
-        windowStore.put(null, "anyValue");
-    }
-
-    @Test
-    public void shouldNotThrowNullPointerExceptionOnPutNullValue() {
-        windowStore = createWindowStore(context, false);
-        windowStore.put(1, null);
-    }
-
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnGetNullKey() {
-        windowStore = createWindowStore(context, false);
-        windowStore.fetch(null, ofEpochMilli(1L), ofEpochMilli(2L));
-    }
+    public void testRestore() throws Exception {
+        final long startTime = SEGMENT_INTERVAL * 2;
+        final long increment = SEGMENT_INTERVAL / 2;
 
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnRangeNullFromKey() {
-        windowStore = createWindowStore(context, false);
-        windowStore.fetch(null, 2, ofEpochMilli(1L), ofEpochMilli(2L));
-    }
+        setCurrentTime(startTime);
+        windowStore.put(0, "zero");
+        setCurrentTime(startTime + increment);
+        windowStore.put(1, "one");
+        setCurrentTime(startTime + increment * 2);
+        windowStore.put(2, "two");
+        setCurrentTime(startTime + increment * 3);
+        windowStore.put(3, "three");
+        setCurrentTime(startTime + increment * 4);
+        windowStore.put(4, "four");
+        setCurrentTime(startTime + increment * 5);
+        windowStore.put(5, "five");
+        setCurrentTime(startTime + increment * 6);
+        windowStore.put(6, "six");
+        setCurrentTime(startTime + increment * 7);
+        windowStore.put(7, "seven");
+        setCurrentTime(startTime + increment * 8);
+        windowStore.put(8, "eight");
+        windowStore.flush();
 
-    @Test(expected = NullPointerException.class)
-    public void shouldThrowNullPointerExceptionOnRangeNullToKey() {
-        windowStore = createWindowStore(context, false);
-        windowStore.fetch(1, null, ofEpochMilli(1L), ofEpochMilli(2L));
-    }
+        windowStore.close();
 
-    @Test
-    public void shouldFetchAndIterateOverExactBinaryKeys() {
-        final WindowStore<Bytes, String> windowStore = Stores.windowStoreBuilder(
-            Stores.persistentWindowStore(windowName, ofMillis(60_000L), ofMillis(60_000L), true),
-            Serdes.Bytes(),
-            Serdes.String()).build();
+        // remove local store image
+        Utils.delete(baseDir);
 
+        windowStore = buildWindowStore(RETENTION_PERIOD,
+                                       WINDOW_SIZE,
+                                       false,
+                                       Serdes.Integer(),
+                                       Serdes.String());
         windowStore.init(context, windowStore);
 
-        final Bytes key1 = Bytes.wrap(new byte[] {0});
-        final Bytes key2 = Bytes.wrap(new byte[] {0, 0});
-        final Bytes key3 = Bytes.wrap(new byte[] {0, 0, 0});
-        windowStore.put(key1, "1", 0);
-        windowStore.put(key2, "2", 0);
-        windowStore.put(key3, "3", 0);
-        windowStore.put(key1, "4", 1);
-        windowStore.put(key2, "5", 1);
-        windowStore.put(key3, "6", 59999);
-        windowStore.put(key1, "7", 59999);
-        windowStore.put(key2, "8", 59999);
-        windowStore.put(key3, "9", 59999);
-
-        final List expectedKey1 = asList("1", "4", "7");
-        assertThat(toList(windowStore.fetch(key1, ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))), equalTo(expectedKey1));
-        final List expectedKey2 = asList("2", "5", "8");
-        assertThat(toList(windowStore.fetch(key2, ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))), equalTo(expectedKey2));
-        final List expectedKey3 = asList("3", "6", "9");
-        assertThat(toList(windowStore.fetch(key3, ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))), equalTo(expectedKey3));
-    }
-
-    @Test
-    public void shouldReturnSameResultsForSingleKeyFetchAndEqualKeyRangeFetch() {
-        windowStore = createWindowStore(context, false);
-
-        windowStore.put(1, "one", 0L);
-        windowStore.put(2, "two", 1L);
-        windowStore.put(2, "two", 2L);
-        windowStore.put(3, "three", 3L);
-
-        final WindowStoreIterator<String> singleKeyIterator = windowStore.fetch(2, 0L, 5L);
-        final KeyValueIterator<Windowed<Integer>, String> keyRangeIterator = windowStore.fetch(2, 2, 0L, 5L);
-
-        assertEquals(singleKeyIterator.next().value, keyRangeIterator.next().value);
-        assertEquals(singleKeyIterator.next().value, keyRangeIterator.next().value);
-        assertFalse(singleKeyIterator.hasNext());
-        assertFalse(keyRangeIterator.hasNext());
-    }
-
-    @Test
-    public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() {
-        windowStore = createWindowStore(context, false);
-
-        LogCaptureAppender.setClassLoggerToDebug(InMemoryWindowStore.class);
-        final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
-
-        final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.fetch(-1, 1, 0L, 10L);
-        assertFalse(iterator.hasNext());
-
-        final List<String> messages = appender.getMessages();
-        assertThat(messages, hasItem("Returning empty iterator for fetch with invalid key range: from > to. "
-            + "This may be due to serdes that don't preserve ordering when lexicographically comparing the serialized bytes. "
-            + "Note that the built-in numerical serdes do not follow this for negative numbers"));
-    }
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                0,
+                ofEpochMilli(startTime - WINDOW_SIZE),
+                ofEpochMilli(startTime + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                1,
+                ofEpochMilli(startTime + increment - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                3,
+                ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                4,
+                ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                5,
+                ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                6,
+                ofEpochMilli(startTime + increment * 6 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 6 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                7,
+                ofEpochMilli(startTime + increment * 7 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 7 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                8,
+                ofEpochMilli(startTime + increment * 8 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 8 + WINDOW_SIZE))));
 
-    private void putFirstBatch(final WindowStore<Integer, String> store,
-                               @SuppressWarnings("SameParameterValue") final long startTime,
-                               final InternalMockProcessorContext context) {
-        context.setRecordContext(createRecordContext(startTime));
-        store.put(0, "zero");
-        context.setRecordContext(createRecordContext(startTime + 1L));
-        store.put(1, "one");
-        context.setRecordContext(createRecordContext(startTime + 2L));
-        store.put(2, "two");
-        context.setRecordContext(createRecordContext(startTime + 4L));
-        store.put(4, "four");
-        context.setRecordContext(createRecordContext(startTime + 5L));
-        store.put(5, "five");
-    }
+        context.restore(STORE_NAME, changeLog);
 
-    private void putSecondBatch(final WindowStore<Integer, String> store,
-                                @SuppressWarnings("SameParameterValue") final long startTime,
-                                final InternalMockProcessorContext context) {
-        context.setRecordContext(createRecordContext(startTime + 3L));
-        store.put(2, "two+1");
-        context.setRecordContext(createRecordContext(startTime + 4L));
-        store.put(2, "two+2");
-        context.setRecordContext(createRecordContext(startTime + 5L));
-        store.put(2, "two+3");
-        context.setRecordContext(createRecordContext(startTime + 6L));
-        store.put(2, "two+4");
-        context.setRecordContext(createRecordContext(startTime + 7L));
-        store.put(2, "two+5");
-        context.setRecordContext(createRecordContext(startTime + 8L));
-        store.put(2, "two+6");
-    }
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                0,
+                ofEpochMilli(startTime - WINDOW_SIZE),
+                ofEpochMilli(startTime + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                1,
+                ofEpochMilli(startTime + increment - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                3,
+                ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.singletonList("four")),
+            toSet(windowStore.fetch(
+                4,
+                ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.singletonList("five")),
+            toSet(windowStore.fetch(
+                5,
+                ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.singletonList("six")),
+            toSet(windowStore.fetch(
+                6,
+                ofEpochMilli(startTime + increment * 6 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 6 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.singletonList("seven")),
+            toSet(windowStore.fetch(
+                7,
+                ofEpochMilli(startTime + increment * 7 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 7 + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.singletonList("eight")),
+            toSet(windowStore.fetch(
+                8,
+                ofEpochMilli(startTime + increment * 8 - WINDOW_SIZE),
+                ofEpochMilli(startTime + increment * 8 + WINDOW_SIZE))));
 
-    private <E> List<E> toList(final WindowStoreIterator<E> iterator) {
-        final ArrayList<E> list = new ArrayList<>();
-        while (iterator.hasNext()) {
-            list.add(iterator.next().value);
-        }
-        return list;
+        // check segment directories
+        windowStore.flush();
+        assertEquals(
+            Utils.mkSet(
+                segments.segmentName(4L),
+                segments.segmentName(5L),
+                segments.segmentName(6L)),
+            segmentDirs(baseDir)
+        );
     }
 
     private Set<String> segmentDirs(final File baseDir) {
-        final File windowDir = new File(baseDir, windowName);
+        final File windowDir = new File(baseDir, windowStore.name());
 
         return new HashSet<>(asList(requireNonNull(windowDir.list())));
     }
 
-    private Map<Integer, Set<String>> entriesByKey(final List<KeyValue<byte[], byte[]>> changeLog,
-                                                   @SuppressWarnings("SameParameterValue") final long startTime) {
-        final HashMap<Integer, Set<String>> entriesByKey = new HashMap<>();
-
-        for (final KeyValue<byte[], byte[]> entry : changeLog) {
-            final long timestamp = WindowKeySchema.extractStoreTimestamp(entry.key);
-
-            final Integer key = WindowKeySchema.extractStoreKey(entry.key, serdes);
-            final String value = entry.value == null ? null : serdes.valueFrom(entry.value);
-
-            final Set<String> entries = entriesByKey.computeIfAbsent(key, k -> new HashSet<>());
-            entries.add(value + "@" + (timestamp - startTime));
-        }
-
-        return entriesByKey;
-    }
-
-    private <K, V> KeyValue<Windowed<K>, V> windowedPair(final K key, final V value, final long timestamp) {
-        return windowedPair(key, value, timestamp, windowSize);
-    }
-
-    private static <K, V> KeyValue<Windowed<K>, V> windowedPair(final K key, final V value, final long timestamp, final long windowSize) {
-        return KeyValue.pair(new Windowed<>(key, WindowKeySchema.timeWindowForSize(timestamp, windowSize)), value);
-    }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionBytesStoreTest.java
similarity index 66%
copy from streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java
copy to streams/src/test/java/org/apache/kafka/streams/state/internals/SessionBytesStoreTest.java
index bbe8d21..a6cef81 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionBytesStoreTest.java
@@ -16,34 +16,36 @@
  */
 package org.apache.kafka.streams.state.internals;
 
-import static java.time.Duration.ofMillis;
-
+import static java.util.Arrays.asList;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
-import static org.apache.kafka.test.StreamsTestUtils.toList;
-import static org.apache.kafka.test.StreamsTestUtils.valuesToList;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.hasItem;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
 import java.util.List;
-
 import java.util.Map;
+import java.util.Set;
 import org.apache.kafka.clients.producer.MockProducer;
 import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.serialization.Serializer;
+import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.errors.DefaultProductionExceptionHandler;
@@ -55,21 +57,19 @@ import org.apache.kafka.streams.processor.internals.RecordCollectorImpl;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.state.KeyValueIterator;
 import org.apache.kafka.streams.state.SessionStore;
-import org.apache.kafka.streams.state.Stores;
 import org.apache.kafka.test.InternalMockProcessorContext;
 import org.apache.kafka.test.TestUtils;
-
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
-public class InMemorySessionStoreTest {
+public abstract class SessionBytesStoreTest {
 
-    private static final String STORE_NAME = "InMemorySessionStore";
-    private static final long RETENTION_PERIOD = 10_000L;
+    protected static final long SEGMENT_INTERVAL = 60_000L;
+    protected static final long RETENTION_PERIOD = 10_000L;
 
-    private SessionStore<String, Long> sessionStore;
-    private InternalMockProcessorContext context;
+    protected SessionStore<String, Long> sessionStore;
+    protected InternalMockProcessorContext context;
 
     private final List<KeyValue<byte[], byte[]>> changeLog = new ArrayList<>();
 
@@ -77,39 +77,43 @@ public class InMemorySessionStoreTest {
         Serdes.ByteArray().serializer(),
         Serdes.ByteArray().serializer());
 
-    private final RecordCollector recordCollector = new RecordCollectorImpl(
-        STORE_NAME,
-        new LogContext(STORE_NAME),
-        new DefaultProductionExceptionHandler(),
-        new Metrics().sensor("skipped-records")) {
-
-        @Override
-        public <K1, V1> void send(final String topic,
-            final K1 key,
-            final V1 value,
-            final Headers headers,
-            final Integer partition,
-            final Long timestamp,
-            final Serializer<K1> keySerializer,
-            final Serializer<V1> valueSerializer) {
-            changeLog.add(new KeyValue<>(
-                keySerializer.serialize(topic, headers, key),
-                valueSerializer.serialize(topic, headers, value))
-            );
-        }
-    };
-
-    private SessionStore<String, Long> buildSessionStore(final long retentionPeriod) {
-        return Stores.sessionStoreBuilder(
-            Stores.inMemorySessionStore(
-                STORE_NAME,
-                ofMillis(retentionPeriod)),
-            Serdes.String(),
-            Serdes.Long()).build();
+    abstract <K, V> SessionStore<K, V> buildSessionStore(final long retentionPeriod,
+                                                          final Serde<K> keySerde,
+                                                          final Serde<V> valueSerde);
+
+    abstract String getMetricsScope();
+
+    abstract void setClassLoggerToDebug();
+
+    private RecordCollectorImpl createRecordCollector(final String name) {
+        return new RecordCollectorImpl(name,
+            new LogContext(name),
+            new DefaultProductionExceptionHandler(),
+            new Metrics().sensor("skipped-records")) {
+            @Override
+            public <K1, V1> void send(final String topic,
+                final K1 key,
+                final V1 value,
+                final Headers headers,
+                final Integer partition,
+                final Long timestamp,
+                final Serializer<K1> keySerializer,
+                final Serializer<V1> valueSerializer) {
+                changeLog.add(new KeyValue<>(
+                    keySerializer.serialize(topic, headers, key),
+                    valueSerializer.serialize(topic, headers, value))
+                );
+            }
+        };
     }
 
     @Before
-    public void before() {
+    public void setUp() {
+        sessionStore = buildSessionStore(RETENTION_PERIOD, Serdes.String(), Serdes.Long());
+
+        final RecordCollector recordCollector = createRecordCollector(sessionStore.name());
+        recordCollector.init(producer);
+
         context = new InternalMockProcessorContext(
             TestUtils.tempDirectory(),
             Serdes.String(),
@@ -120,10 +124,7 @@ public class InMemorySessionStoreTest {
                 0,
                 new MockStreamsMetrics(new Metrics())));
 
-        sessionStore = buildSessionStore(RETENTION_PERIOD);
-
         sessionStore.init(context, sessionStore);
-        recordCollector.init(producer);
     }
 
     @After
@@ -144,18 +145,17 @@ public class InMemorySessionStoreTest {
         final List<KeyValue<Windowed<String>, Long>> expected =
             Arrays.asList(KeyValue.pair(a1, 1L), KeyValue.pair(a2, 2L));
 
-        try (final KeyValueIterator<Windowed<String>, Long> values =
-            sessionStore.findSessions(key, 0, 1000L)
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.findSessions(key, 0, 1000L)
         ) {
-            assertEquals(expected, toList(values));
+            assertEquals(new HashSet<>(expected), toSet(values));
         }
 
-        final List<KeyValue<Windowed<String>, Long>> expected2 = Collections.singletonList(KeyValue.pair(a2, 2L));
+        final List<KeyValue<Windowed<String>, Long>> expected2 =
+            Collections.singletonList(KeyValue.pair(a2, 2L));
 
-        try (final KeyValueIterator<Windowed<String>, Long> values2 =
-            sessionStore.findSessions(key, 400L, 600L)
+        try (final KeyValueIterator<Windowed<String>, Long> values2 = sessionStore.findSessions(key, 400L, 600L)
         ) {
-            assertEquals(expected2, toList(values2));
+            assertEquals(new HashSet<>(expected2), toSet(values2));
         }
     }
 
@@ -175,7 +175,7 @@ public class InMemorySessionStoreTest {
         sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 0)), 5L);
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(expected, toList(values));
+            assertEquals(new HashSet<>(expected), toSet(values));
         }
     }
 
@@ -183,8 +183,9 @@ public class InMemorySessionStoreTest {
     public void shouldFetchAllSessionsWithinKeyRange() {
         final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
             KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L),
-            KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L),
             KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L),
+
+            KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L),
             KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L));
 
         for (final KeyValue<Windowed<String>, Long> kv : expected) {
@@ -196,7 +197,7 @@ public class InMemorySessionStoreTest {
         sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L);
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("aa", "bb")) {
-            assertEquals(expected, toList(values));
+            assertEquals(new HashSet<>(expected), toSet(values));
         }
     }
 
@@ -213,6 +214,11 @@ public class InMemorySessionStoreTest {
     }
 
     @Test
+    public void shouldReturnNullOnSessionNotFound() {
+        assertNull(sessionStore.fetchSession("any key", 0L, 5L));
+    }
+
+    @Test
     public void shouldFindValuesWithinMergingSessionWindowRange() {
         final String key = "a";
         sessionStore.put(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L);
@@ -222,9 +228,8 @@ public class InMemorySessionStoreTest {
             KeyValue.pair(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L),
             KeyValue.pair(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L));
 
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions(key, -1, 1000L)) {
-            assertEquals(expected, toList(results));
+        try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.findSessions(key, -1, 1000L)) {
+            assertEquals(new HashSet<>(expected), toSet(results));
         }
     }
 
@@ -235,13 +240,11 @@ public class InMemorySessionStoreTest {
 
         sessionStore.remove(new Windowed<>("a", new SessionWindow(0, 1000)));
 
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 0L, 1000L)) {
+        try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.findSessions("a", 0L, 1000L)) {
             assertFalse(results.hasNext());
         }
 
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 1500L, 2500L)) {
+        try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.findSessions("a", 1500L, 2500L)) {
             assertTrue(results.hasNext());
         }
     }
@@ -253,13 +256,11 @@ public class InMemorySessionStoreTest {
 
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), null);
 
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 0L, 1000L)) {
+        try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.findSessions("a", 0L, 1000L)) {
             assertFalse(results.hasNext());
         }
 
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 1500L, 2500L)) {
+        try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.findSessions("a", 1500L, 2500L)) {
             assertTrue(results.hasNext());
         }
     }
@@ -277,52 +278,81 @@ public class InMemorySessionStoreTest {
         sessionStore.put(session4, 4L);
         sessionStore.put(session5, 5L);
 
-        try (final KeyValueIterator<Windowed<String>, Long> results =
-            sessionStore.findSessions("a", 150, 300)
-        ) {
-            assertEquals(session2, results.next().key);
-            assertEquals(session3, results.next().key);
-            assertFalse(results.hasNext());
+        final List<KeyValue<Windowed<String>, Long>> expected =
+            Arrays.asList(KeyValue.pair(session2, 2L), KeyValue.pair(session3, 3L));
+
+        try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.findSessions("a", 150, 300)) {
+            assertEquals(new HashSet<>(expected), toSet(results));
         }
     }
 
     @Test
     public void shouldFetchExactKeys() {
-        sessionStore = buildSessionStore(0x7a00000000000000L);
+        sessionStore = buildSessionStore(0x7a00000000000000L, Serdes.String(), Serdes.Long());
         sessionStore.init(context, sessionStore);
 
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
         sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
         sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L);
         sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 20)), 4L);
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0x7a00000000000000L - 2, 0x7a00000000000000L - 1)), 5L);
+        sessionStore.put(new Windowed<>("a",
+            new SessionWindow(0x7a00000000000000L - 2, 0x7a00000000000000L - 1)), 5L);
 
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
             sessionStore.findSessions("a", 0, Long.MAX_VALUE)
         ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(1L, 3L, 5L)));
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 3L, 5L))));
         }
 
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
             sessionStore.findSessions("aa", 0, Long.MAX_VALUE)
         ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(2L, 4L)));
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(2L, 4L))));
         }
 
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
             sessionStore.findSessions("a", "aa", 0, Long.MAX_VALUE)
         ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(1L, 2L, 3L, 4L, 5L)));
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L))));
         }
 
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
             sessionStore.findSessions("a", "aa", 10, 0)
         ) {
-            assertThat(valuesToList(iterator), equalTo(Collections.singletonList(2L)));
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(Collections.singletonList(2L))));
         }
     }
 
     @Test
+    public void shouldFetchAndIterateOverExactBinaryKeys() {
+        final SessionStore<Bytes, String> sessionStore =
+            buildSessionStore(RETENTION_PERIOD, Serdes.Bytes(), Serdes.String());
+
+        sessionStore.init(context, sessionStore);
+
+        final Bytes key1 = Bytes.wrap(new byte[]{0});
+        final Bytes key2 = Bytes.wrap(new byte[]{0, 0});
+        final Bytes key3 = Bytes.wrap(new byte[]{0, 0, 0});
+
+        sessionStore.put(new Windowed<>(key1, new SessionWindow(1, 100)), "1");
+        sessionStore.put(new Windowed<>(key2, new SessionWindow(2, 100)), "2");
+        sessionStore.put(new Windowed<>(key3, new SessionWindow(3, 100)), "3");
+        sessionStore.put(new Windowed<>(key1, new SessionWindow(4, 100)), "4");
+        sessionStore.put(new Windowed<>(key2, new SessionWindow(5, 100)), "5");
+        sessionStore.put(new Windowed<>(key3, new SessionWindow(6, 100)), "6");
+        sessionStore.put(new Windowed<>(key1, new SessionWindow(7, 100)), "7");
+        sessionStore.put(new Windowed<>(key2, new SessionWindow(8, 100)), "8");
+        sessionStore.put(new Windowed<>(key3, new SessionWindow(9, 100)), "9");
+
+        final Set<String> expectedKey1 = new HashSet<>(asList("1", "4", "7"));
+        assertThat(valuesToSet(sessionStore.findSessions(key1, 0L, Long.MAX_VALUE)), equalTo(expectedKey1));
+        final Set<String> expectedKey2 = new HashSet<>(asList("2", "5", "8"));
+        assertThat(valuesToSet(sessionStore.findSessions(key2, 0L, Long.MAX_VALUE)), equalTo(expectedKey2));
+        final Set<String> expectedKey3 = new HashSet<>(asList("3", "6", "9"));
+        assertThat(valuesToSet(sessionStore.findSessions(key3, 0L, Long.MAX_VALUE)), equalTo(expectedKey3));
+    }
+
+    @Test
     public void testIteratorPeek() {
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
         sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
@@ -338,22 +368,6 @@ public class InMemorySessionStoreTest {
     }
 
     @Test
-    public void shouldRemoveExpired() {
-        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
-        sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L);
-
-        // Advance stream time to expire the first record
-        sessionStore.put(new Windowed<>("aa", new SessionWindow(10, RETENTION_PERIOD)), 4L);
-
-        try (final KeyValueIterator<Windowed<String>, Long> iterator =
-            sessionStore.findSessions("a", "b", 0L, Long.MAX_VALUE)
-        ) {
-            assertThat(valuesToList(iterator), equalTo(Arrays.asList(2L, 3L, 4L)));
-        }
-    }
-
-    @Test
     public void shouldRestore() {
         final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
             KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L),
@@ -366,23 +380,36 @@ public class InMemorySessionStoreTest {
         }
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(expected, toList(values));
+            assertEquals(new HashSet<>(expected), toSet(values));
         }
 
         sessionStore.close();
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(Collections.emptyList(), toList(values));
+            assertEquals(Collections.emptySet(), toSet(values));
         }
 
-        context.restore(STORE_NAME, changeLog);
+        context.restore(sessionStore.name(), changeLog);
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(expected, toList(values));
+            assertEquals(new HashSet<>(expected), toSet(values));
         }
     }
 
     @Test
+    public void shouldCloseOpenIteratorsWhenStoreIsClosedAndNotThrowInvalidStateStoreExceptionOnHasNext() {
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
+        sessionStore.put(new Windowed<>("b", new SessionWindow(10, 50)), 2L);
+        sessionStore.put(new Windowed<>("c", new SessionWindow(100, 500)), 3L);
+
+        final KeyValueIterator<Windowed<String>, Long> iterator = sessionStore.fetch("a");
+        assertTrue(iterator.hasNext());
+        sessionStore.close();
+
+        assertFalse(iterator.hasNext());
+    }
+
+    @Test
     public void shouldReturnSameResultsForSingleKeyFindSessionsAndEqualKeyRangeFindSessions() {
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1)), 0L);
         sessionStore.put(new Windowed<>("aa", new SessionWindow(2, 3)), 1L);
@@ -390,50 +417,52 @@ public class InMemorySessionStoreTest {
         sessionStore.put(new Windowed<>("aaa", new SessionWindow(6, 7)), 3L);
 
         final KeyValueIterator<Windowed<String>, Long> singleKeyIterator = sessionStore.findSessions("aa", 0L, 10L);
-        final KeyValueIterator<Windowed<String>, Long> keyRangeIterator = sessionStore.findSessions("aa", "aa", 0L, 10L);
+        final KeyValueIterator<Windowed<String>, Long> rangeIterator = sessionStore.findSessions("aa", "aa", 0L, 10L);
 
-        assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
-        assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
+        assertEquals(singleKeyIterator.next(), rangeIterator.next());
+        assertEquals(singleKeyIterator.next(), rangeIterator.next());
         assertFalse(singleKeyIterator.hasNext());
-        assertFalse(keyRangeIterator.hasNext());
+        assertFalse(rangeIterator.hasNext());
     }
 
     @Test
     public void shouldLogAndMeasureExpiredRecords() {
-        LogCaptureAppender.setClassLoggerToDebug(InMemorySessionStore.class);
+        setClassLoggerToDebug();
         final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
 
-
         // Advance stream time by inserting record with large enough timestamp that records with timestamp 0 are expired
-        sessionStore.put(new Windowed<>("initial record", new SessionWindow(0, RETENTION_PERIOD)), 0L);
+        // Note that rocksdb will only expire segments at a time (where segment interval = 60,000 for this retention period)
+        sessionStore.put(new Windowed<>("initial record", new SessionWindow(0, 2 * SEGMENT_INTERVAL)), 0L);
 
         // Try inserting a record with timestamp 0 -- should be dropped
         sessionStore.put(new Windowed<>("late record", new SessionWindow(0, 0)), 0L);
-        sessionStore.put(new Windowed<>("another on-time record", new SessionWindow(0, RETENTION_PERIOD)), 0L);
+        sessionStore.put(new Windowed<>("another on-time record", new SessionWindow(0, 2 * SEGMENT_INTERVAL)), 0L);
 
         LogCaptureAppender.unregister(appender);
 
         final Map<MetricName, ? extends Metric> metrics = context.metrics().metrics();
 
+        final String metricScope = getMetricsScope();
+
         final Metric dropTotal = metrics.get(new MetricName(
             "expired-window-record-drop-total",
-            "stream-in-memory-session-state-metrics",
+            "stream-" + metricScope + "-metrics",
             "The total number of occurrence of expired-window-record-drop operations.",
             mkMap(
                 mkEntry("client-id", "mock"),
                 mkEntry("task-id", "0_0"),
-                mkEntry("in-memory-session-state-id", STORE_NAME)
+                mkEntry(metricScope + "-id", sessionStore.name())
             )
         ));
 
         final Metric dropRate = metrics.get(new MetricName(
             "expired-window-record-drop-rate",
-            "stream-in-memory-session-state-metrics",
+            "stream-" + metricScope + "-metrics",
             "The average number of occurrence of expired-window-record-drop operation per second.",
             mkMap(
                 mkEntry("client-id", "mock"),
                 mkEntry("task-id", "0_0"),
-                mkEntry("in-memory-session-state-id", STORE_NAME)
+                mkEntry(metricScope + "-id", sessionStore.name())
             )
         ));
 
@@ -485,18 +514,40 @@ public class InMemorySessionStoreTest {
 
     @Test
     public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() {
-        LogCaptureAppender.setClassLoggerToDebug(InMemorySessionStore.class);
+        setClassLoggerToDebug();
         final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
 
-        final String keyFrom = Serdes.String().deserializer().deserialize("", Serdes.Integer().serializer().serialize("", -1));
-        final String keyTo = Serdes.String().deserializer().deserialize("", Serdes.Integer().serializer().serialize("", 1));
+        final String keyFrom = Serdes.String().deserializer()
+            .deserialize("", Serdes.Integer().serializer().serialize("", -1));
+        final String keyTo = Serdes.String().deserializer()
+            .deserialize("", Serdes.Integer().serializer().serialize("", 1));
 
         final KeyValueIterator<Windowed<String>, Long> iterator = sessionStore.findSessions(keyFrom, keyTo, 0L, 10L);
         assertFalse(iterator.hasNext());
 
         final List<String> messages = appender.getMessages();
-        assertThat(messages, hasItem("Returning empty iterator for fetch with invalid key range: from > to. "
-            + "This may be due to serdes that don't preserve ordering when lexicographically comparing the serialized bytes. "
-            + "Note that the built-in numerical serdes do not follow this for negative numbers"));
+        assertThat(messages,
+            hasItem("Returning empty iterator for fetch with invalid key range: from > to. "
+                + "This may be due to serdes that don't preserve ordering when lexicographically comparing the serialized bytes. "
+                + "Note that the built-in numerical serdes do not follow this for negative numbers"));
+    }
+
+    protected static <K, V> Set<V> valuesToSet(final Iterator<KeyValue<K, V>> iterator) {
+        final Set<V> results = new HashSet<>();
+
+        while (iterator.hasNext()) {
+            results.add(iterator.next().value);
+        }
+        return results;
     }
-}
\ No newline at end of file
+
+    protected static <K, V> Set<KeyValue<K, V>> toSet(final Iterator<KeyValue<K, V>> iterator) {
+        final Set<KeyValue<K, V>> results = new HashSet<>();
+
+        while (iterator.hasNext()) {
+            results.add(iterator.next());
+        }
+        return results;
+    }
+
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowBytesStoreTest.java
new file mode 100644
index 0000000..5177079
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowBytesStoreTest.java
@@ -0,0 +1,1104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import static java.time.Instant.ofEpochMilli;
+import static java.util.Arrays.asList;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.hasItem;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.kafka.clients.producer.MockProducer;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.common.Metric;
+import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.serialization.Serde;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.Serializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.errors.DefaultProductionExceptionHandler;
+import org.apache.kafka.streams.kstream.Windowed;
+import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
+import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
+import org.apache.kafka.streams.processor.internals.RecordCollector;
+import org.apache.kafka.streams.processor.internals.RecordCollectorImpl;
+import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.StateSerdes;
+import org.apache.kafka.streams.state.WindowStore;
+import org.apache.kafka.streams.state.WindowStoreIterator;
+import org.apache.kafka.test.InternalMockProcessorContext;
+import org.apache.kafka.test.TestUtils;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public abstract class WindowBytesStoreTest {
+
+    static final long WINDOW_SIZE = 3L;
+    static final long SEGMENT_INTERVAL = 60_000L;
+    static final long RETENTION_PERIOD = 2 * SEGMENT_INTERVAL;
+
+    WindowStore<Integer, String> windowStore;
+    InternalMockProcessorContext context;
+    final File baseDir = TestUtils.tempDirectory("test");
+
+    private final StateSerdes<Integer, String> serdes = new StateSerdes<>("", Serdes.Integer(), Serdes.String());
+
+    final List<KeyValue<byte[], byte[]>> changeLog = new ArrayList<>();
+
+    private final Producer<byte[], byte[]> producer = new MockProducer<>(true,
+        Serdes.ByteArray().serializer(),
+        Serdes.ByteArray().serializer());
+
+    abstract <K, V> WindowStore<K, V> buildWindowStore(final long retentionPeriod,
+                                                               final long windowSize,
+                                                               final boolean retainDuplicates,
+                                                               final Serde<K> keySerde,
+                                                               final Serde<V> valueSerde);
+
+    abstract String getMetricsScope();
+
+    abstract void setClassLoggerToDebug();
+
+    private RecordCollectorImpl createRecordCollector(final String name) {
+        return new RecordCollectorImpl(name,
+            new LogContext(name),
+            new DefaultProductionExceptionHandler(),
+            new Metrics().sensor("skipped-records")) {
+            @Override
+            public <K1, V1> void send(final String topic,
+                final K1 key,
+                final V1 value,
+                final Headers headers,
+                final Integer partition,
+                final Long timestamp,
+                final Serializer<K1> keySerializer,
+                final Serializer<V1> valueSerializer) {
+                changeLog.add(new KeyValue<>(
+                    keySerializer.serialize(topic, headers, key),
+                    valueSerializer.serialize(topic, headers, value))
+                );
+            }
+        };
+    }
+
+    @Before
+    public void setup() {
+        windowStore = buildWindowStore(RETENTION_PERIOD, WINDOW_SIZE, false, Serdes.Integer(), Serdes.String());
+
+        final RecordCollector recordCollector = createRecordCollector(windowStore.name());
+        recordCollector.init(producer);
+
+        context = new InternalMockProcessorContext(
+            baseDir,
+            Serdes.String(),
+            Serdes.Integer(),
+            recordCollector,
+            new ThreadCache(
+                new LogContext("testCache"),
+                0,
+                new MockStreamsMetrics(new Metrics())));
+
+        windowStore.init(context, windowStore);
+    }
+
+    @After
+    public void after() {
+        windowStore.close();
+    }
+
+    @Test
+    public void testRangeAndSinglePointFetch() {
+        final long startTime = SEGMENT_INTERVAL - 4L;
+
+        putFirstBatch(windowStore, startTime, context);
+
+        assertEquals("zero", windowStore.fetch(0, startTime));
+        assertEquals("one", windowStore.fetch(1, startTime + 1L));
+        assertEquals("two", windowStore.fetch(2, startTime + 2L));
+        assertEquals("four", windowStore.fetch(4, startTime + 4L));
+        assertEquals("five", windowStore.fetch(5, startTime + 5L));
+
+        assertEquals(
+            new HashSet<>(Collections.singletonList("zero")),
+            toSet(windowStore.fetch(
+                0,
+                ofEpochMilli(startTime + 0 - WINDOW_SIZE),
+                ofEpochMilli(startTime + 0 + WINDOW_SIZE))));
+
+        putSecondBatch(windowStore, startTime, context);
+
+        assertEquals("two+1", windowStore.fetch(2, startTime + 3L));
+        assertEquals("two+2", windowStore.fetch(2, startTime + 4L));
+        assertEquals("two+3", windowStore.fetch(2, startTime + 5L));
+        assertEquals("two+4", windowStore.fetch(2, startTime + 6L));
+        assertEquals("two+5", windowStore.fetch(2, startTime + 7L));
+        assertEquals("two+6", windowStore.fetch(2, startTime + 8L));
+
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime - 2L - WINDOW_SIZE),
+                ofEpochMilli(startTime - 2L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.singletonList("two")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime - 1L - WINDOW_SIZE),
+                ofEpochMilli(startTime - 1L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(asList("two", "two+1")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime - WINDOW_SIZE),
+                ofEpochMilli(startTime + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(asList("two", "two+1", "two+2")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 1L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 1L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(asList("two", "two+1", "two+2", "two+3")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 2L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 2L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(asList("two", "two+1", "two+2", "two+3", "two+4")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 3L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 3L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(asList("two", "two+1", "two+2", "two+3", "two+4", "two+5")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 4L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 4L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(asList("two", "two+1", "two+2", "two+3", "two+4", "two+5", "two+6")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 5L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 5L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(asList("two+1", "two+2", "two+3", "two+4", "two+5", "two+6")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 6L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 6L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(asList("two+2", "two+3", "two+4", "two+5", "two+6")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 7L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 7L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(asList("two+3", "two+4", "two+5", "two+6")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 8L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 8L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(asList("two+4", "two+5", "two+6")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 9L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 9L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(asList("two+5", "two+6")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 10L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 10L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.singletonList("two+6")),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 11L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 11L + WINDOW_SIZE))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                2,
+                ofEpochMilli(startTime + 12L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 12L + WINDOW_SIZE))));
+
+        // Flush the store and verify all current entries were properly flushed ...
+        windowStore.flush();
+
+        final Map<Integer, Set<String>> entriesByKey = entriesByKey(changeLog, startTime);
+
+        assertEquals(Utils.mkSet("zero@0"), entriesByKey.get(0));
+        assertEquals(Utils.mkSet("one@1"), entriesByKey.get(1));
+        assertEquals(
+            Utils.mkSet("two@2", "two+1@3", "two+2@4", "two+3@5", "two+4@6", "two+5@7", "two+6@8"),
+            entriesByKey.get(2));
+        assertNull(entriesByKey.get(3));
+        assertEquals(Utils.mkSet("four@4"), entriesByKey.get(4));
+        assertEquals(Utils.mkSet("five@5"), entriesByKey.get(5));
+        assertNull(entriesByKey.get(6));
+    }
+
+    /**
+     * Writes the first batch of records and checks that {@code all()} returns every one of them.
+     * The comparison is set-based, so the order in which the store iterates is not asserted.
+     */
+    @Test
+    public void shouldGetAll() {
+        final long startTime = SEGMENT_INTERVAL - 4L;
+        putFirstBatch(windowStore, startTime, context);
+
+        // One expected entry per record written by putFirstBatch (nothing is written for key 3).
+        final Set<KeyValue<Windowed<Integer>, String>> expected = new HashSet<>();
+        expected.add(windowedPair(0, "zero", startTime));
+        expected.add(windowedPair(1, "one", startTime + 1));
+        expected.add(windowedPair(2, "two", startTime + 2));
+        expected.add(windowedPair(4, "four", startTime + 4));
+        expected.add(windowedPair(5, "five", startTime + 5));
+
+        final Set<KeyValue<Windowed<Integer>, String>> actual = toSet(windowStore.all());
+        assertEquals(expected, actual);
+    }
+
+    /**
+     * Writes the first batch of records and checks that {@code fetchAll(from, to)} returns exactly the
+     * records whose timestamps fall inside the requested range (both bounds inclusive in these cases).
+     */
+    @Test
+    public void shouldFetchAllInTimeRange() {
+        final long startTime = SEGMENT_INTERVAL - 4L;
+        putFirstBatch(windowStore, startTime, context);
+
+        final KeyValue<Windowed<Integer>, String> zero = windowedPair(0, "zero", startTime);
+        final KeyValue<Windowed<Integer>, String> one = windowedPair(1, "one", startTime + 1);
+        final KeyValue<Windowed<Integer>, String> two = windowedPair(2, "two", startTime + 2);
+        final KeyValue<Windowed<Integer>, String> four = windowedPair(4, "four", startTime + 4);
+        final KeyValue<Windowed<Integer>, String> five = windowedPair(5, "five", startTime + 5);
+
+        // [startTime + 1, startTime + 4]
+        final Set<KeyValue<Windowed<Integer>, String>> oneToFour =
+            toSet(windowStore.fetchAll(ofEpochMilli(startTime + 1), ofEpochMilli(startTime + 4)));
+        assertEquals(new HashSet<>(asList(one, two, four)), oneToFour);
+
+        // [startTime, startTime + 3]
+        final Set<KeyValue<Windowed<Integer>, String>> zeroToThree =
+            toSet(windowStore.fetchAll(ofEpochMilli(startTime), ofEpochMilli(startTime + 3)));
+        assertEquals(new HashSet<>(asList(zero, one, two)), zeroToThree);
+
+        // [startTime + 1, startTime + 5]
+        final Set<KeyValue<Windowed<Integer>, String>> oneToFive =
+            toSet(windowStore.fetchAll(ofEpochMilli(startTime + 1), ofEpochMilli(startTime + 5)));
+        assertEquals(new HashSet<>(asList(one, two, four, five)), oneToFive);
+    }
+
+    /**
+     * Writes the first batch of records and exercises the key-range fetch
+     * {@code fetch(fromKey, toKey, timeFrom, timeTo)} with several key and time ranges,
+     * comparing the results as sets.
+     */
+    @Test
+    public void testFetchRange() {
+        final long startTime = SEGMENT_INTERVAL - 4L;
+        putFirstBatch(windowStore, startTime, context);
+
+        final KeyValue<Windowed<Integer>, String> zero = windowedPair(0, "zero", startTime);
+        final KeyValue<Windowed<Integer>, String> one = windowedPair(1, "one", startTime + 1);
+        final KeyValue<Windowed<Integer>, String> two = windowedPair(2, "two", startTime + 2);
+        final KeyValue<Windowed<Integer>, String> four = windowedPair(4, "four", startTime + 4);
+        final KeyValue<Windowed<Integer>, String> five = windowedPair(5, "five", startTime + 5);
+
+        // Time bounds reused by the queries below.
+        final long windowBefore = startTime - WINDOW_SIZE;
+        final long windowAfter = startTime + WINDOW_SIZE;
+        final long windowAfterPlusFive = windowAfter + 5L;
+
+        // Varying key ranges over the window centred on startTime.
+        assertEquals(
+            new HashSet<>(asList(zero, one)),
+            toSet(windowStore.fetch(0, 1, ofEpochMilli(windowBefore), ofEpochMilli(windowAfter))));
+        assertEquals(
+            new HashSet<>(Collections.singletonList(one)),
+            toSet(windowStore.fetch(1, 1, ofEpochMilli(windowBefore), ofEpochMilli(windowAfter))));
+        assertEquals(
+            new HashSet<>(asList(one, two)),
+            toSet(windowStore.fetch(1, 3, ofEpochMilli(windowBefore), ofEpochMilli(windowAfter))));
+        assertEquals(
+            new HashSet<>(asList(zero, one, two)),
+            toSet(windowStore.fetch(0, 5, ofEpochMilli(windowBefore), ofEpochMilli(windowAfter))));
+
+        // Widening the upper time bound brings in the later records.
+        assertEquals(
+            new HashSet<>(asList(zero, one, two, four, five)),
+            toSet(windowStore.fetch(0, 5, ofEpochMilli(windowBefore), ofEpochMilli(windowAfterPlusFive))));
+
+        // Raising the lower time bound drops the earlier records.
+        assertEquals(
+            new HashSet<>(asList(two, four, five)),
+            toSet(windowStore.fetch(0, 5, ofEpochMilli(startTime + 2L), ofEpochMilli(windowAfterPlusFive))));
+
+        // Key/time combinations that match no record.
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(4, 5, ofEpochMilli(startTime + 2L), ofEpochMilli(windowAfter))));
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(0, 3, ofEpochMilli(startTime + 3L), ofEpochMilli(windowAfterPlusFive))));
+    }
+
+    /**
+     * Checks fetches over the window that ends at a given timestamp, i.e.
+     * {@code [t - WINDOW_SIZE, t]}, first after the initial batch and then after the second batch
+     * has added more records for key 2. Finally flushes the store and checks the change log contents.
+     */
+    @Test
+    public void testPutAndFetchBefore() {
+        final long startTime = SEGMENT_INTERVAL - 4L;
+
+        putFirstBatch(windowStore, startTime, context);
+
+        // Index i is the expected result for key i when fetching [startTime + i - WINDOW_SIZE, startTime + i].
+        // Nothing is written for key 3.
+        final List<List<String>> expectedPerKey = asList(
+            Collections.singletonList("zero"),
+            Collections.singletonList("one"),
+            Collections.singletonList("two"),
+            Collections.<String>emptyList(),
+            Collections.singletonList("four"),
+            Collections.singletonList("five"));
+        for (int key = 0; key < expectedPerKey.size(); key++) {
+            final long timeTo = startTime + key;
+            assertEquals(
+                new HashSet<>(expectedPerKey.get(key)),
+                toSet(windowStore.fetch(key, ofEpochMilli(timeTo - WINDOW_SIZE), ofEpochMilli(timeTo))));
+        }
+
+        putSecondBatch(windowStore, startTime, context);
+
+        // Expected results for key 2; index i corresponds to the window ending at startTime - 1 + i.
+        final List<List<String>> expectedForKeyTwo = asList(
+            Collections.<String>emptyList(),                 // startTime - 1
+            Collections.<String>emptyList(),                 // startTime + 0
+            Collections.<String>emptyList(),                 // startTime + 1
+            Collections.singletonList("two"),                // startTime + 2
+            asList("two", "two+1"),                          // startTime + 3
+            asList("two", "two+1", "two+2"),                 // startTime + 4
+            asList("two", "two+1", "two+2", "two+3"),        // startTime + 5
+            asList("two+1", "two+2", "two+3", "two+4"),      // startTime + 6
+            asList("two+2", "two+3", "two+4", "two+5"),      // startTime + 7
+            asList("two+3", "two+4", "two+5", "two+6"),      // startTime + 8
+            asList("two+4", "two+5", "two+6"),               // startTime + 9
+            asList("two+5", "two+6"),                        // startTime + 10
+            Collections.singletonList("two+6"),              // startTime + 11
+            Collections.<String>emptyList(),                 // startTime + 12
+            Collections.<String>emptyList());                // startTime + 13
+        for (int i = 0; i < expectedForKeyTwo.size(); i++) {
+            final long timeTo = startTime - 1L + i;
+            assertEquals(
+                new HashSet<>(expectedForKeyTwo.get(i)),
+                toSet(windowStore.fetch(2, ofEpochMilli(timeTo - WINDOW_SIZE), ofEpochMilli(timeTo))));
+        }
+
+        // Flush the store and verify all current entries were properly flushed ...
+        windowStore.flush();
+
+        final Map<Integer, Set<String>> flushed = entriesByKey(changeLog, startTime);
+        assertEquals(Utils.mkSet("zero@0"), flushed.get(0));
+        assertEquals(Utils.mkSet("one@1"), flushed.get(1));
+        assertEquals(
+            Utils.mkSet("two@2", "two+1@3", "two+2@4", "two+3@5", "two+4@6", "two+5@7", "two+6@8"),
+            flushed.get(2));
+        assertNull(flushed.get(3));
+        assertEquals(Utils.mkSet("four@4"), flushed.get(4));
+        assertEquals(Utils.mkSet("five@5"), flushed.get(5));
+        assertNull(flushed.get(6));
+    }
+
+    /**
+     * Checks fetches over the window that starts at a given timestamp, i.e.
+     * {@code [t, t + WINDOW_SIZE]}, first after the initial batch and then after the second batch
+     * has added more records for key 2. Finally flushes the store and checks the change log contents.
+     */
+    @Test
+    public void testPutAndFetchAfter() {
+        final long startTime = SEGMENT_INTERVAL - 4L;
+
+        putFirstBatch(windowStore, startTime, context);
+
+        // Index i is the expected result for key i when fetching [startTime + i, startTime + i + WINDOW_SIZE].
+        // Nothing is written for key 3.
+        final List<List<String>> expectedPerKey = asList(
+            Collections.singletonList("zero"),
+            Collections.singletonList("one"),
+            Collections.singletonList("two"),
+            Collections.<String>emptyList(),
+            Collections.singletonList("four"),
+            Collections.singletonList("five"));
+        for (int key = 0; key < expectedPerKey.size(); key++) {
+            final long timeFrom = startTime + key;
+            assertEquals(
+                new HashSet<>(expectedPerKey.get(key)),
+                toSet(windowStore.fetch(key, ofEpochMilli(timeFrom), ofEpochMilli(timeFrom + WINDOW_SIZE))));
+        }
+
+        putSecondBatch(windowStore, startTime, context);
+
+        // Expected results for key 2; index i corresponds to the window starting at startTime - 2 + i.
+        final List<List<String>> expectedForKeyTwo = asList(
+            Collections.<String>emptyList(),                 // startTime - 2
+            Collections.singletonList("two"),                // startTime - 1
+            asList("two", "two+1"),                          // startTime + 0
+            asList("two", "two+1", "two+2"),                 // startTime + 1
+            asList("two", "two+1", "two+2", "two+3"),        // startTime + 2
+            asList("two+1", "two+2", "two+3", "two+4"),      // startTime + 3
+            asList("two+2", "two+3", "two+4", "two+5"),      // startTime + 4
+            asList("two+3", "two+4", "two+5", "two+6"),      // startTime + 5
+            asList("two+4", "two+5", "two+6"),               // startTime + 6
+            asList("two+5", "two+6"),                        // startTime + 7
+            Collections.singletonList("two+6"),              // startTime + 8
+            Collections.<String>emptyList(),                 // startTime + 9
+            Collections.<String>emptyList(),                 // startTime + 10
+            Collections.<String>emptyList(),                 // startTime + 11
+            Collections.<String>emptyList());                // startTime + 12
+        for (int i = 0; i < expectedForKeyTwo.size(); i++) {
+            final long timeFrom = startTime - 2L + i;
+            assertEquals(
+                new HashSet<>(expectedForKeyTwo.get(i)),
+                toSet(windowStore.fetch(2, ofEpochMilli(timeFrom), ofEpochMilli(timeFrom + WINDOW_SIZE))));
+        }
+
+        // Flush the store and verify all current entries were properly flushed ...
+        windowStore.flush();
+
+        final Map<Integer, Set<String>> flushed = entriesByKey(changeLog, startTime);
+
+        assertEquals(Utils.mkSet("zero@0"), flushed.get(0));
+        assertEquals(Utils.mkSet("one@1"), flushed.get(1));
+        assertEquals(
+            Utils.mkSet("two@2", "two+1@3", "two+2@4", "two+3@5", "two+4@6", "two+5@7", "two+6@8"),
+            flushed.get(2));
+        assertNull(flushed.get(3));
+        assertEquals(Utils.mkSet("four@4"), flushed.get(4));
+        assertEquals(Utils.mkSet("five@5"), flushed.get(5));
+        assertNull(flushed.get(6));
+    }
+
+    /**
+     * Rebuilds the store with the third {@code buildWindowStore} argument set to {@code true}
+     * (presumably "retain duplicates" -- confirm against buildWindowStore), puts several values for
+     * the same key at the same timestamp, and checks what is fetched and what reaches the change log.
+     *
+     * NOTE: the comparisons are set-based, so the repeated "zero" in the expected values collapses to
+     * a single element; this test therefore does not assert how many copies of a value are returned.
+     */
+    @Test
+    public void testPutSameKeyTimestamp() {
+        windowStore = buildWindowStore(RETENTION_PERIOD, WINDOW_SIZE, true, Serdes.Integer(), Serdes.String());
+        windowStore.init(context, windowStore);
+
+        final long startTime = SEGMENT_INTERVAL - 4L;
+        setCurrentTime(startTime);
+
+        windowStore.put(0, "zero");
+        assertEquals(
+            new HashSet<>(Collections.singletonList("zero")),
+            toSet(windowStore.fetch(
+                0,
+                ofEpochMilli(startTime - WINDOW_SIZE),
+                ofEpochMilli(startTime + WINDOW_SIZE))));
+
+        for (final String value : asList("zero", "zero+", "zero++")) {
+            windowStore.put(0, value);
+        }
+
+        // Windows centred on startTime + 0 .. startTime + 3 all see every value written above.
+        final Set<String> everyValue = new HashSet<>(asList("zero", "zero", "zero+", "zero++"));
+        for (long offset = 0L; offset <= 3L; offset++) {
+            final long centre = startTime + offset;
+            assertEquals(
+                everyValue,
+                toSet(windowStore.fetch(
+                    0,
+                    ofEpochMilli(centre - WINDOW_SIZE),
+                    ofEpochMilli(centre + WINDOW_SIZE))));
+        }
+
+        // The window centred on startTime + 4 no longer reaches back to startTime.
+        assertEquals(
+            new HashSet<>(Collections.emptyList()),
+            toSet(windowStore.fetch(
+                0,
+                ofEpochMilli(startTime + 4L - WINDOW_SIZE),
+                ofEpochMilli(startTime + 4L + WINDOW_SIZE))));
+
+        // Flush the store and verify all current entries were properly flushed ...
+        windowStore.flush();
+
+        final Map<Integer, Set<String>> flushed = entriesByKey(changeLog, startTime);
+
+        assertEquals(Utils.mkSet("zero@0", "zero@0", "zero+@0", "zero++@0"), flushed.get(0));
+    }
+
+    /**
+     * Opens an iterator over three records, closes the store while the iterator is still open, and
+     * checks that {@code hasNext()} then reports {@code false}.
+     */
+    @Test
+    public void shouldCloseOpenIteratorsWhenStoreIsClosedAndNotThrowInvalidStateStoreExceptionOnHasNext() {
+        setCurrentTime(0);
+
+        // Records for key 1 at timestamps 1, 2 and 3.
+        final String[] values = {"one", "two", "three"};
+        for (int i = 0; i < values.length; i++) {
+            windowStore.put(1, values[i], i + 1L);
+        }
+
+        final WindowStoreIterator<String> openIterator = windowStore.fetch(1, ofEpochMilli(1L), ofEpochMilli(3L));
+        assertTrue(openIterator.hasNext());
+
+        windowStore.close();
+
+        assertFalse(openIterator.hasNext());
+    }
+
+    /**
+     * Uses a very large window size and retention period and checks that single-key and
+     * equal-key-range fetches for "a" and "aa" only return records for exactly the requested key,
+     * even though "a" is a prefix of "aa".
+     *
+     * The store built here is local to the test (it shadows the {@code windowStore} field), so it is
+     * closed explicitly in a {@code finally} block to avoid leaking it when an assertion fails.
+     */
+    @Test
+    public void shouldFetchAndIterateOverExactKeys() {
+        final long windowSize = 0x7a00000000000000L;
+        final long retentionPeriod = 0x7a00000000000000L;
+        final WindowStore<String, String> windowStore = buildWindowStore(retentionPeriod,
+                                                                         windowSize,
+                                                                         false,
+                                                                         Serdes.String(),
+                                                                         Serdes.String());
+
+        windowStore.init(context, windowStore);
+
+        try {
+            windowStore.put("a", "0001", 0);
+            windowStore.put("aa", "0002", 0);
+            windowStore.put("a", "0003", 1);
+            windowStore.put("aa", "0004", 1);
+            windowStore.put("a", "0005", 0x7a00000000000000L - 1);
+
+            // Single-key fetch: only the values stored under "a".
+            final Set<String> expected = new HashSet<>(asList("0001", "0003", "0005"));
+            assertThat(toSet(windowStore.fetch("a", ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))), equalTo(expected));
+
+            // Key-range fetch with from == to must behave like the single-key fetch.
+            Set<KeyValue<Windowed<String>, String>> set =
+                toSet(windowStore.fetch("a", "a", ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE)));
+            assertThat(set, equalTo(new HashSet<>(asList(
+                windowedPair("a", "0001", 0, windowSize),
+                windowedPair("a", "0003", 1, windowSize),
+                windowedPair("a", "0005", 0x7a00000000000000L - 1, windowSize)
+            ))));
+
+            set = toSet(windowStore.fetch("aa", "aa", ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE)));
+            assertThat(set, equalTo(new HashSet<>(asList(
+                windowedPair("aa", "0002", 0, windowSize),
+                windowedPair("aa", "0004", 1, windowSize)
+            ))));
+        } finally {
+            windowStore.close();
+        }
+    }
+
+    /**
+     * Puts two values for the same key at the same timestamp and checks that the second one is
+     * what a fetch returns, then puts {@code null} and checks that the fetch comes back empty.
+     */
+    @Test
+    public void testDeleteAndUpdate() {
+        final long now = 0;
+        setCurrentTime(now);
+
+        // Second put at the same key/timestamp replaces the first.
+        windowStore.put(1, "one");
+        windowStore.put(1, "one v2");
+
+        final WindowStoreIterator<String> afterUpdate = windowStore.fetch(1, 0, now);
+        assertEquals(new KeyValue<>(now, "one v2"), afterUpdate.next());
+
+        // A null value removes the record.
+        windowStore.put(1, null);
+
+        final WindowStoreIterator<String> afterDelete = windowStore.fetch(1, 0, now);
+        assertFalse(afterDelete.hasNext());
+    }
+
+    /** A point lookup for a key/timestamp that was never written returns {@code null}. */
+    @Test
+    public void shouldReturnNullOnWindowNotFound() {
+        final String missing = windowStore.fetch(1, 0L);
+        assertNull(missing);
+    }
+
+    /** Putting a record with a {@code null} key must throw a {@link NullPointerException}. */
+    @Test(expected = NullPointerException.class)
+    public void shouldThrowNullPointerExceptionOnPutNullKey() {
+        final Integer nullKey = null;
+        windowStore.put(nullKey, "anyValue");
+    }
+
+    /** A single-key fetch with a {@code null} key must throw a {@link NullPointerException}. */
+    @Test(expected = NullPointerException.class)
+    public void shouldThrowNullPointerExceptionOnGetNullKey() {
+        final Integer nullKey = null;
+        windowStore.fetch(nullKey, ofEpochMilli(1L), ofEpochMilli(2L));
+    }
+
+    /** A key-range fetch with a {@code null} lower key must throw a {@link NullPointerException}. */
+    @Test(expected = NullPointerException.class)
+    public void shouldThrowNullPointerExceptionOnRangeNullFromKey() {
+        final Integer nullFromKey = null;
+        windowStore.fetch(nullFromKey, 2, ofEpochMilli(1L), ofEpochMilli(2L));
+    }
+
+    /** A key-range fetch with a {@code null} upper key must throw a {@link NullPointerException}. */
+    @Test(expected = NullPointerException.class)
+    public void shouldThrowNullPointerExceptionOnRangeNullToKey() {
+        final Integer nullToKey = null;
+        windowStore.fetch(1, nullToKey, ofEpochMilli(1L), ofEpochMilli(2L));
+    }
+
+    /**
+     * Stores records under three binary keys that are prefixes of one another ({0}, {0, 0} and
+     * {0, 0, 0}) and checks that a single-key fetch only returns the values put under exactly that key.
+     *
+     * The store built here is local to the test (it shadows the {@code windowStore} field), so it is
+     * closed explicitly in a {@code finally} block to avoid leaking it when an assertion fails.
+     */
+    @Test
+    public void shouldFetchAndIterateOverExactBinaryKeys() {
+        final WindowStore<Bytes, String> windowStore = buildWindowStore(RETENTION_PERIOD,
+                                                                        WINDOW_SIZE,
+                                                                        true,
+                                                                        Serdes.Bytes(),
+                                                                        Serdes.String());
+        windowStore.init(context, windowStore);
+
+        try {
+            final Bytes key1 = Bytes.wrap(new byte[]{0});
+            final Bytes key2 = Bytes.wrap(new byte[]{0, 0});
+            final Bytes key3 = Bytes.wrap(new byte[]{0, 0, 0});
+            windowStore.put(key1, "1", 0);
+            windowStore.put(key2, "2", 0);
+            windowStore.put(key3, "3", 0);
+            windowStore.put(key1, "4", 1);
+            windowStore.put(key2, "5", 1);
+            windowStore.put(key3, "6", 59999);
+            windowStore.put(key1, "7", 59999);
+            windowStore.put(key2, "8", 59999);
+            windowStore.put(key3, "9", 59999);
+
+            final Set<String> expectedKey1 = new HashSet<>(asList("1", "4", "7"));
+            assertThat(toSet(windowStore.fetch(key1, ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))),
+                equalTo(expectedKey1));
+            final Set<String> expectedKey2 = new HashSet<>(asList("2", "5", "8"));
+            assertThat(toSet(windowStore.fetch(key2, ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))),
+                equalTo(expectedKey2));
+            final Set<String> expectedKey3 = new HashSet<>(asList("3", "6", "9"));
+            assertThat(toSet(windowStore.fetch(key3, ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))),
+                equalTo(expectedKey3));
+        } finally {
+            windowStore.close();
+        }
+    }
+
+    /**
+     * A single-key fetch for key 2 and a key-range fetch with from == to == 2 over the same time
+     * range must yield the same two values, in the same order, and then both be exhausted.
+     */
+    @Test
+    public void shouldReturnSameResultsForSingleKeyFetchAndEqualKeyRangeFetch() {
+        windowStore.put(1, "one", 0L);
+        windowStore.put(2, "two", 1L);
+        windowStore.put(2, "two", 2L);
+        windowStore.put(3, "three", 3L);
+
+        final WindowStoreIterator<String> byKey = windowStore.fetch(2, 0L, 5L);
+        final KeyValueIterator<Windowed<Integer>, String> byRange = windowStore.fetch(2, 2, 0L, 5L);
+
+        // Key 2 has exactly two records inside [0, 5].
+        for (int i = 0; i < 2; i++) {
+            assertEquals(byKey.next().value, byRange.next().value);
+        }
+        assertFalse(byKey.hasNext());
+        assertFalse(byRange.hasNext());
+    }
+
+    /**
+     * A key-range fetch from -1 to 1 must not throw; it returns an empty iterator and logs a message
+     * explaining that the serialized key range is inverted.
+     *
+     * The log appender is unregistered in a {@code finally} block so that it does not stay attached
+     * (and keep capturing messages) for later tests, even if the assertion on the iterator fails.
+     */
+    @Test
+    public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() {
+        setClassLoggerToDebug();
+        final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
+
+        try {
+            final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.fetch(-1, 1, 0L, 10L);
+            assertFalse(iterator.hasNext());
+        } finally {
+            LogCaptureAppender.unregister(appender);
+        }
+
+        // Messages captured while the appender was registered remain available after unregistering.
+        final List<String> messages = appender.getMessages();
+        assertThat(messages,
+            hasItem("Returning empty iterator for fetch with invalid key range: from > to. "
+                + "This may be due to serdes that don't preserve ordering when lexicographically comparing the serialized bytes. "
+                + "Note that the built-in numerical serdes do not follow this for negative numbers"));
+    }
+
+    /**
+     * Advances stream time far enough that a record with timestamp 0 is expired, then checks that
+     * putting such a record is counted by the expired-window-record-drop metrics and logged.
+     */
+    @Test
+    public void shouldLogAndMeasureExpiredRecords() {
+        setClassLoggerToDebug();
+        final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
+
+        // Advance stream time by inserting record with large enough timestamp that records with timestamp 0 are expired
+        windowStore.put(1, "initial record", 2 * RETENTION_PERIOD);
+
+        // Try inserting a record with timestamp 0 -- should be dropped
+        windowStore.put(1, "late record", 0L);
+        windowStore.put(1, "another on-time record", RETENTION_PERIOD + 1);
+
+        LogCaptureAppender.unregister(appender);
+
+        final Map<MetricName, ? extends Metric> metrics = context.metrics().metrics();
+
+        // Both metrics share the same group and tags; only the name and description differ.
+        final String metricScope = getMetricsScope();
+        final String metricGroup = "stream-" + metricScope + "-metrics";
+        final Map<String, String> metricTags = mkMap(
+            mkEntry("client-id", "mock"),
+            mkEntry("task-id", "0_0"),
+            mkEntry(metricScope + "-id", windowStore.name())
+        );
+
+        final MetricName dropTotalName = new MetricName(
+            "expired-window-record-drop-total",
+            metricGroup,
+            "The total number of occurrence of expired-window-record-drop operations.",
+            metricTags
+        );
+        final MetricName dropRateName = new MetricName(
+            "expired-window-record-drop-rate",
+            metricGroup,
+            "The average number of occurrence of expired-window-record-drop operation per second.",
+            metricTags
+        );
+
+        final Metric dropTotal = metrics.get(dropTotalName);
+        final Metric dropRate = metrics.get(dropRateName);
+
+        // Exactly one record (the late one) was dropped.
+        assertEquals(1.0, dropTotal.metricValue());
+        assertNotEquals(0.0, dropRate.metricValue());
+
+        final List<String> messages = appender.getMessages();
+        assertThat(messages, hasItem("Skipping record for expired segment."));
+    }
+
+    /**
+     * After a record with a much later timestamp has been put, fetching a time range around the
+     * first record must return an empty iterator rather than throw.
+     */
+    @Test
+    public void shouldNotThrowExceptionWhenFetchRangeIsExpired() {
+        windowStore.put(1, "one", 0L);
+        // Second record is four retention periods later than the first.
+        windowStore.put(1, "two", 4 * RETENTION_PERIOD);
+
+        final WindowStoreIterator<String> expiredRange = windowStore.fetch(1, 0L, 10L);
+        assertFalse(expiredRange.hasNext());
+    }
+
+    /**
+     * {@code peekNextKey()} on a {@code fetchAll} iterator must be repeatable, must agree with the key
+     * returned by the following {@code next()}, and must not consume the record.
+     */
+    @Test
+    public void testWindowIteratorPeek() {
+        final long now = 0;
+        setCurrentTime(now);
+        windowStore.put(1, "one");
+
+        final KeyValueIterator<Windowed<Integer>, String> records = windowStore.fetchAll(0L, now);
+        assertTrue(records.hasNext());
+
+        final Windowed<Integer> peeked = records.peekNextKey();
+
+        // Peeking again returns the same key ...
+        assertEquals(records.peekNextKey(), peeked);
+        // ... and next() returns the record that was peeked, after which nothing is left.
+        assertEquals(records.peekNextKey(), records.next().key);
+        assertFalse(records.hasNext());
+    }
+
+    /**
+     * {@code peekNextKey()} on a single-key fetch iterator must be repeatable, must agree with the key
+     * returned by the following {@code next()}, and must not consume the record.
+     */
+    @Test
+    public void testValueIteratorPeek() {
+        windowStore.put(1, "one", 0L);
+
+        final WindowStoreIterator<String> values = windowStore.fetch(1, 0L, 10L);
+        assertTrue(values.hasNext());
+
+        final Long peeked = values.peekNextKey();
+
+        // Peeking again returns the same key ...
+        assertEquals(values.peekNextKey(), peeked);
+        // ... and next() returns the record that was peeked, after which nothing is left.
+        assertEquals(values.peekNextKey(), values.next().key);
+        assertFalse(values.hasNext());
+    }
+
+    /**
+     * Opens an {@code all()} iterator, puts further records afterwards, and checks that iterating
+     * does not throw and also returns the records added after the iterator was created.
+     */
+    @Test
+    public void shouldNotThrowConcurrentModificationException() {
+        // Records are spaced ten window sizes apart.
+        final long step = WINDOW_SIZE * 10;
+
+        setCurrentTime(0);
+        windowStore.put(1, "one");
+
+        setCurrentTime(step);
+        windowStore.put(1, "two");
+
+        final KeyValueIterator<Windowed<Integer>, String> iterator = windowStore.all();
+
+        // These two puts happen while the iterator is already open.
+        setCurrentTime(2 * step);
+        windowStore.put(1, "three");
+
+        setCurrentTime(3 * step);
+        windowStore.put(2, "four");
+
+        // Iterator should return all records in store and not throw exception b/c some were added after fetch
+        assertEquals(windowedPair(1, "one", 0), iterator.next());
+        assertEquals(windowedPair(1, "two", step), iterator.next());
+        assertEquals(windowedPair(1, "three", 2 * step), iterator.next());
+        assertEquals(windowedPair(2, "four", 3 * step), iterator.next());
+        assertFalse(iterator.hasNext());
+    }
+
+    /**
+     * Rebuilds the store with the third {@code buildWindowStore} argument set to {@code true}
+     * (presumably "retain duplicates" -- confirm against buildWindowStore), puts two values per
+     * timestamp for the same key, and checks that a fetch returns both values for every timestamp
+     * inside the range, in insertion order.
+     */
+    @Test
+    public void testFetchDuplicates() {
+        windowStore = buildWindowStore(RETENTION_PERIOD, WINDOW_SIZE, true, Serdes.Integer(), Serdes.String());
+        windowStore.init(context, windowStore);
+
+        // Timestamps are spaced ten window sizes apart.
+        final long step = WINDOW_SIZE * 10;
+
+        setCurrentTime(0);
+        windowStore.put(1, "one");
+        windowStore.put(1, "one-2");
+
+        setCurrentTime(step);
+        windowStore.put(1, "two");
+        windowStore.put(1, "two-2");
+
+        setCurrentTime(2 * step);
+        windowStore.put(1, "three");
+        windowStore.put(1, "three-2");
+
+        // The range [0, step] covers the first two timestamps only.
+        final WindowStoreIterator<String> fetched = windowStore.fetch(1, 0, step);
+
+        assertEquals(new KeyValue<>(0L, "one"), fetched.next());
+        assertEquals(new KeyValue<>(0L, "one-2"), fetched.next());
+        assertEquals(new KeyValue<>(step, "two"), fetched.next());
+        assertEquals(new KeyValue<>(step, "two-2"), fetched.next());
+        assertFalse(fetched.hasNext());
+    }
+
+
+    private void putFirstBatch(final WindowStore<Integer, String> store,
+        @SuppressWarnings("SameParameterValue") final long startTime,
+        final InternalMockProcessorContext context) {
+        context.setRecordContext(createRecordContext(startTime));
+        store.put(0, "zero");
+        context.setRecordContext(createRecordContext(startTime + 1L));
+        store.put(1, "one");
+        context.setRecordContext(createRecordContext(startTime + 2L));
+        store.put(2, "two");
+        context.setRecordContext(createRecordContext(startTime + 4L));
+        store.put(4, "four");
+        context.setRecordContext(createRecordContext(startTime + 5L));
+        store.put(5, "five");
+    }
+
+    /**
+     * Writes six further records for key 2 into {@code store}, with values "two+1" through
+     * "two+6" at timestamps {@code startTime + 3} through {@code startTime + 8}. The record
+     * context of {@code context} is updated before each put.
+     *
+     * @param store     the store receiving the records
+     * @param startTime base timestamp the per-record offsets are added to
+     * @param context   the processor context whose record context is updated before every put
+     */
+    private void putSecondBatch(final WindowStore<Integer, String> store,
+        @SuppressWarnings("SameParameterValue") final long startTime,
+        final InternalMockProcessorContext context) {
+        // The n-th record (n = 1..6) carries value "two+n" at offset n + 2.
+        for (int n = 1; n <= 6; n++) {
+            context.setRecordContext(createRecordContext(startTime + 2L + n));
+            store.put(2, "two+" + n);
+        }
+    }
+
+    /**
+     * Drains the remaining entries of {@code iterator} and collects their values, discarding
+     * the keys, into a set. The iterator is not closed by this method.
+     *
+     * @param iterator the iterator to drain
+     * @param <E>      the value type
+     * @return the distinct values that were still available from the iterator
+     */
+    protected static <E> Set<E> toSet(final WindowStoreIterator<E> iterator) {
+        final Set<E> values = new HashSet<>();
+        iterator.forEachRemaining(entry -> values.add(entry.value));
+        return values;
+    }
+
+    /**
+     * Drains the remaining key-value pairs of {@code iterator} into a set, so that callers can
+     * compare results without depending on iteration order.
+     *
+     * @param iterator the iterator to drain
+     * @param <K>      the key type
+     * @param <V>      the value type
+     * @return the distinct pairs that were still available from the iterator
+     */
+    protected static <K, V> Set<KeyValue<K, V>> toSet(final Iterator<KeyValue<K, V>> iterator) {
+        final Set<KeyValue<K, V>> pairs = new HashSet<>();
+        iterator.forEachRemaining(pairs::add);
+        return pairs;
+    }
+
+    /**
+     * Groups serialized change log records by their deserialized store key. Each record is
+     * rendered as {@code value + "@" + (timestamp - startTime)}, where the timestamp is read
+     * from the serialized key; a record with a {@code null} value is rendered with the text
+     * "null" in place of the value.
+     *
+     * @param changeLog the serialized key/value records to group
+     * @param startTime the timestamp subtracted from each record's timestamp
+     * @return a map from store key to the set of rendered entries for that key
+     */
+    private Map<Integer, Set<String>> entriesByKey(final List<KeyValue<byte[], byte[]>> changeLog,
+        @SuppressWarnings("SameParameterValue") final long startTime) {
+        final Map<Integer, Set<String>> grouped = new HashMap<>();
+
+        for (final KeyValue<byte[], byte[]> logEntry : changeLog) {
+            final long offset = WindowKeySchema.extractStoreTimestamp(logEntry.key) - startTime;
+            final Integer storeKey = WindowKeySchema.extractStoreKey(logEntry.key, serdes);
+
+            final String storeValue;
+            if (logEntry.value == null) {
+                storeValue = null;
+            } else {
+                storeValue = serdes.valueFrom(logEntry.value);
+            }
+
+            grouped.computeIfAbsent(storeKey, ignored -> new HashSet<>()).add(storeValue + "@" + offset);
+        }
+
+        return grouped;
+    }
+
+    protected static <K, V> KeyValue<Windowed<K>, V> windowedPair(final K key, final V value, final long timestamp) {
+        return windowedPair(key, value, timestamp, WINDOW_SIZE);
+    }
+
+    private static <K, V> KeyValue<Windowed<K>, V> windowedPair(final K key, final V value, final long timestamp, final long windowSize) {
+        return KeyValue.pair(new Windowed<>(key, WindowKeySchema.timeWindowForSize(timestamp, windowSize)), value);
+    }
+
+    protected void setCurrentTime(final long currentTime) {
+        context.setRecordContext(createRecordContext(currentTime));
+    }
+
+    private ProcessorRecordContext createRecordContext(final long time) {
+        return new ProcessorRecordContext(time, 0, 0, "topic", null);
+    }
+
+}