You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2022/03/12 01:53:55 UTC

[kafka] branch trunk updated: KIP-825: Part 1, add new RocksDBTimeOrderedWindowStore (#11802)

This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 63ea5db  KIP-825: Part 1, add new RocksDBTimeOrderedWindowStore (#11802)
63ea5db is described below

commit 63ea5db9ec2e93ded483025649f48565f056242d
Author: Hao Li <11...@users.noreply.github.com>
AuthorDate: Fri Mar 11 17:51:10 2022 -0800

    KIP-825: Part 1, add new RocksDBTimeOrderedWindowStore (#11802)
    
    Initial state store implementation for time windows and sliding windows.
    
    RocksDBTimeOrderedWindowStore.java contains one RocksDBTimeOrderedSegmentedBytesStore, which holds both a base schema and an index schema.
    
    PrefixedWindowKeySchemas.java implements the key schemas for the time-ordered base store and the key-ordered index store.
    
    Reviewers: James Hughes, Guozhang Wang <wa...@gmail.com>
---
 ...stractDualSchemaRocksDBSegmentedBytesStore.java |  300 ++++++
 .../AbstractRocksDBSegmentedBytesStore.java        |   12 +-
 .../state/internals/CachingSessionStore.java       |   12 +-
 .../state/internals/CachingWindowStore.java        |   12 +-
 .../state/internals/PrefixedWindowKeySchemas.java  |  385 +++++++
 .../RocksDBTimeOrderedSegmentedBytesStore.java     |  335 ++++++
 .../internals/RocksDBTimeOrderedWindowStore.java   |  171 +++
 ...IndexedTimeOrderedWindowBytesStoreSupplier.java |  131 +++
 .../state/internals/SegmentedBytesStore.java       |    3 +-
 .../streams/state/internals/SessionKeySchema.java  |    8 +-
 .../streams/state/internals/WindowKeySchema.java   |    9 +-
 .../internals/WindowStoreIteratorWrapper.java      |   35 +-
 ...ctDualSchemaRocksDBSegmentedBytesStoreTest.java | 1128 ++++++++++++++++++++
 .../RocksDBTimeOrderedSegmentedBytesStoreTest.java |   74 ++
 .../state/internals/RocksDBWindowStoreTest.java    |   64 +-
 .../state/internals/SessionKeySchemaTest.java      |    6 +-
 .../state/internals/WindowKeySchemaTest.java       |  357 ++++++-
 17 files changed, 2950 insertions(+), 92 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStore.java
new file mode 100644
index 0000000..39bfa6a2
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStore.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import java.util.Optional;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.ProcessorStateException;
+import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.ProcessorContextUtils;
+import org.apache.kafka.streams.processor.internals.RecordBatchingStateRestoreCallback;
+import org.apache.kafka.streams.processor.internals.StoreToProcessorContextAdapter;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
+import org.apache.kafka.streams.query.Position;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.rocksdb.RocksDBException;
+import org.rocksdb.WriteBatch;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.kafka.streams.StreamsConfig.InternalConfig.IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED;
+/** Segmented bytes store that writes each record under a base key schema and, optionally, an index key schema. */
+public abstract class AbstractDualSchemaRocksDBSegmentedBytesStore<S extends Segment> implements SegmentedBytesStore {
+    private static final Logger LOG = LoggerFactory.getLogger(AbstractDualSchemaRocksDBSegmentedBytesStore.class);
+    // Configuration fixed at construction time.
+    private final String name;
+    protected final AbstractSegments<S> segments;
+    private final String metricScope;
+    protected final KeySchema baseKeySchema;
+    protected final Optional<KeySchema> indexKeySchema;
+
+    // Runtime state; most of it is populated by init().
+    protected ProcessorContext context;
+    private StateStoreContext stateStoreContext;
+    private Sensor expiredRecordSensor;
+    protected long observedStreamTime = ConsumerRecord.NO_TIMESTAMP;
+    protected boolean consistencyEnabled = false;
+    protected Position position;
+    protected OffsetCheckpoint positionCheckpoint;
+    private volatile boolean open;
+    // Only records its arguments; existing segments are opened later, in init().
+    AbstractDualSchemaRocksDBSegmentedBytesStore(final String name,
+                                                 final String metricScope,
+                                                 final KeySchema baseKeySchema,
+                                                 final Optional<KeySchema> indexKeySchema,
+                                                 final AbstractSegments<S> segments) {
+        this.name = name;
+        this.metricScope = metricScope;
+        this.baseKeySchema = baseKeySchema;
+        this.indexKeySchema = indexKeySchema;
+        this.segments = segments;
+    }
+    // Scans all segments forward over the base schema's full time range [0, Long.MAX_VALUE].
+    @Override
+    public KeyValueIterator<Bytes, byte[]> all() {
+        final List<S> searchSpace = segments.allSegments(true);
+        final Bytes from = baseKeySchema.lowerRange(null, 0);
+        final Bytes to = baseKeySchema.upperRange(null, Long.MAX_VALUE);
+
+        return new SegmentIterator<>(
+                searchSpace.iterator(),
+                baseKeySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE, true),
+                from,
+                to,
+                true);
+    }
+    // Same bounds as all(), but segments and records are visited in reverse order.
+    @Override
+    public KeyValueIterator<Bytes, byte[]> backwardAll() {
+        final List<S> searchSpace = segments.allSegments(false);
+        final Bytes from = baseKeySchema.lowerRange(null, 0);
+        final Bytes to = baseKeySchema.upperRange(null, Long.MAX_VALUE);
+
+        return new SegmentIterator<>(
+                searchSpace.iterator(),
+                baseKeySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE, false),
+                from,
+                to,
+                false);
+    }
+    // Deletes the base entry and, if an index schema exists, the matching index entry in the same segment.
+    @Override
+    public void remove(final Bytes rawBaseKey) {
+        final long timestamp = baseKeySchema.segmentTimestamp(rawBaseKey);
+        observedStreamTime = Math.max(observedStreamTime, timestamp);
+        final S segment = segments.getSegmentForTimestamp(timestamp);
+        if (segment == null) {
+            return;
+        }
+        segment.delete(rawBaseKey);
+        // Only the derived index key is used here, so no base value is passed.
+        if (hasIndex()) {
+            final KeyValue<Bytes, byte[]> kv = getIndexKeyValue(rawBaseKey, null);
+            segment.delete(kv.key);
+        }
+    }
+    // Derives the index key/value for a base key; baseValue may be null (see remove()).
+    abstract protected KeyValue<Bytes, byte[]> getIndexKeyValue(final Bytes baseKey, final byte[] baseValue);
+
+    // Visible for testing: writes a raw index entry without a corresponding base entry.
+    void putIndex(final Bytes indexKey, final byte[] value) {
+        if (!hasIndex()) {
+            throw new IllegalStateException("Index store doesn't exist");
+        }
+
+        final long timestamp = indexKeySchema.get().segmentTimestamp(indexKey);
+        final long segmentId = segments.segmentId(timestamp);
+        final S segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
+
+        if (segment != null) {
+            segment.put(indexKey, value);
+        }
+    }
+    // Visible for testing: reads a raw index entry; note this may create the segment if it is still live.
+    byte[] getIndex(final Bytes indexKey) {
+        if (!hasIndex()) {
+            throw new IllegalStateException("Index store doesn't exist");
+        }
+
+        final long timestamp = indexKeySchema.get().segmentTimestamp(indexKey);
+        final long segmentId = segments.segmentId(timestamp);
+        final S segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
+
+        if (segment != null) {
+            return segment.get(indexKey);
+        }
+        return null;
+    }
+    // Visible for testing: deletes a raw index entry without touching the base entry.
+    void removeIndex(final Bytes indexKey) {
+        if (!hasIndex()) {
+            throw new IllegalStateException("Index store doesn't exist");
+        }
+
+        final long timestamp = indexKeySchema.get().segmentTimestamp(indexKey);
+        final long segmentId = segments.segmentId(timestamp);
+        final S segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
+
+        if (segment != null) {
+            segment.delete(indexKey);
+        }
+    }
+    // Writes to the segment for the key's timestamp; records for which no live segment is returned are dropped.
+    @Override
+    public void put(final Bytes rawBaseKey,
+                    final byte[] value) {
+        final long timestamp = baseKeySchema.segmentTimestamp(rawBaseKey);
+        observedStreamTime = Math.max(observedStreamTime, timestamp);
+        final long segmentId = segments.segmentId(timestamp);
+        final S segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
+        // A null segment is treated as expired: the record is counted as dropped and skipped.
+        if (segment == null) {
+            expiredRecordSensor.record(1.0d, ProcessorContextUtils.currentSystemTime(context));
+            LOG.warn("Skipping record for expired segment.");
+        } else {
+            StoreQueryUtils.updatePosition(position, stateStoreContext);
+            segment.put(rawBaseKey, value);
+            // The index entry is written to the same segment as the base entry.
+            if (hasIndex()) {
+                final KeyValue<Bytes, byte[]> indexKeyValue = getIndexKeyValue(rawBaseKey, value);
+                segment.put(indexKeyValue.key, indexKeyValue.value);
+            }
+        }
+    }
+    // Point lookup through the base key schema only; the index is not consulted.
+    @Override
+    public byte[] get(final Bytes rawKey) {
+        final S segment = segments.getSegmentForTimestamp(baseKeySchema.segmentTimestamp(rawKey));
+        if (segment == null) {
+            return null;
+        }
+        return segment.get(rawKey);
+    }
+
+    @Override
+    public String name() {
+        return name;
+    }
+    // Legacy init: sets up metrics, opens existing segments, loads the position checkpoint and registers the store.
+    @Deprecated
+    @Override
+    public void init(final ProcessorContext context,
+                     final StateStore root) {
+        this.context = context;
+
+        final StreamsMetricsImpl metrics = ProcessorContextUtils.getMetricsImpl(context);
+        final String threadId = Thread.currentThread().getName();
+        final String taskName = context.taskId().toString();
+
+        expiredRecordSensor = TaskMetrics.droppedRecordsSensor(
+            threadId,
+            taskName,
+            metrics
+        );
+
+        segments.openExisting(context, observedStreamTime);
+        // The query position is checkpointed to "<name>.position" in the state directory.
+        final File positionCheckpointFile = new File(context.stateDir(), name() + ".position");
+        this.positionCheckpoint = new OffsetCheckpoint(positionCheckpointFile);
+        this.position = StoreQueryUtils.readPositionFromCheckpoint(positionCheckpoint);
+        // NOTE(review): stateStoreContext is only set by init(StateStoreContext, StateStore); if this overload is called alone it is null here.
+        // Register the store: restored records are applied in batches by restoreAllInternal; the callback checkpoints the position.
+        stateStoreContext.register(
+            root,
+            (RecordBatchingStateRestoreCallback) this::restoreAllInternal,
+            () -> StoreQueryUtils.checkpointPosition(positionCheckpoint, position)
+        );
+
+        open = true;
+        // Consistency tracking is opt-in through an internal config and defaults to false.
+        consistencyEnabled = StreamsConfig.InternalConfig.getBoolean(
+            context.appConfigs(),
+            IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED,
+            false
+        );
+    }
+    // Preferred init: keeps the StateStoreContext, then delegates to the legacy overload via an adapter.
+    @Override
+    public void init(final StateStoreContext context, final StateStore root) {
+        this.stateStoreContext = context;
+        init(StoreToProcessorContextAdapter.adapt(context), root);
+    }
+
+    @Override
+    public void flush() {
+        segments.flush();
+    }
+    // Marks the store closed before closing its segments.
+    @Override
+    public void close() {
+        open = false;
+        segments.close();
+    }
+
+    @Override
+    public boolean persistent() {
+        return true;
+    }
+
+    @Override
+    public boolean isOpen() {
+        return open;
+    }
+
+    // Visible for testing: returns all segments in allSegments(false) order, the order backwardAll() uses.
+    List<S> getSegments() {
+        return segments.allSegments(false);
+    }
+    // NOTE(review): if segment.write throws, this batch and any not-yet-written batches are never closed.
+    // Visible for testing: applies restored records using one WriteBatch per segment.
+    void restoreAllInternal(final Collection<ConsumerRecord<byte[], byte[]>> records) {
+        try {
+            final Map<S, WriteBatch> writeBatchMap = getWriteBatches(records);
+            for (final Map.Entry<S, WriteBatch> entry : writeBatchMap.entrySet()) {
+                final S segment = entry.getKey();
+                final WriteBatch batch = entry.getValue();
+                segment.write(batch);
+                batch.close();
+            }
+        } catch (final RocksDBException e) {
+            throw new ProcessorStateException("Error restoring batch to store " + this.name, e);
+        }
+    }
+    // Groups restored records into a WriteBatch per target segment.
+    abstract Map<S, WriteBatch> getWriteBatches(final Collection<ConsumerRecord<byte[], byte[]>> records);
+
+    @Override
+    public Position getPosition() {
+        return position;
+    }
+    // True when an index key schema was supplied at construction.
+    public boolean hasIndex() {
+        return indexKeySchema.isPresent();
+    }
+}
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java
index bbe8c54..13f914d 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java
@@ -98,7 +98,7 @@ public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements Se
 
         return new SegmentIterator<>(
                 searchSpace.iterator(),
-                keySchema.hasNextCondition(key, key, from, to),
+                keySchema.hasNextCondition(key, key, from, to, forward),
                 binaryFrom,
                 binaryTo,
                 forward);
@@ -140,7 +140,7 @@ public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements Se
 
         return new SegmentIterator<>(
                 searchSpace.iterator(),
-                keySchema.hasNextCondition(keyFrom, keyTo, from, to),
+                keySchema.hasNextCondition(keyFrom, keyTo, from, to, forward),
                 binaryFrom,
                 binaryTo,
                 forward);
@@ -152,7 +152,7 @@ public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements Se
 
         return new SegmentIterator<>(
                 searchSpace.iterator(),
-                keySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE),
+                keySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE, true),
                 null,
                 null,
                 true);
@@ -164,7 +164,7 @@ public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements Se
 
         return new SegmentIterator<>(
                 searchSpace.iterator(),
-                keySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE),
+                keySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE, false),
                 null,
                 null,
                 false);
@@ -177,7 +177,7 @@ public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements Se
 
         return new SegmentIterator<>(
                 searchSpace.iterator(),
-                keySchema.hasNextCondition(null, null, timeFrom, timeTo),
+                keySchema.hasNextCondition(null, null, timeFrom, timeTo, true),
                 null,
                 null,
                 true);
@@ -190,7 +190,7 @@ public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements Se
 
         return new SegmentIterator<>(
                 searchSpace.iterator(),
-                keySchema.hasNextCondition(null, null, timeFrom, timeTo),
+                keySchema.hasNextCondition(null, null, timeFrom, timeTo, false),
                 null,
                 null,
                 false);
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
index a701e6d..59d2a0e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
@@ -177,7 +177,8 @@ class CachingSessionStore
         final HasNextCondition hasNextCondition = keySchema.hasNextCondition(key,
                                                                              key,
                                                                              earliestSessionEndTime,
-                                                                             latestSessionStartTime);
+                                                                             latestSessionStartTime,
+                                                                             true);
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
         return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction, true);
@@ -207,7 +208,8 @@ class CachingSessionStore
             key,
             key,
             earliestSessionEndTime,
-            latestSessionStartTime
+            latestSessionStartTime,
+            false
         );
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
@@ -236,7 +238,8 @@ class CachingSessionStore
         final HasNextCondition hasNextCondition = keySchema.hasNextCondition(keyFrom,
                                                                              keyTo,
                                                                              earliestSessionEndTime,
-                                                                             latestSessionStartTime);
+                                                                             latestSessionStartTime,
+                                                                     true);
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
         return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction, true);
@@ -264,7 +267,8 @@ class CachingSessionStore
             keyFrom,
             keyTo,
             earliestSessionEndTime,
-            latestSessionStartTime
+            latestSessionStartTime,
+            false
         );
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
index ee9dbf9..8a1f886 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
@@ -213,7 +213,7 @@ class CachingWindowStore
                 cacheFunction.cacheKey(keySchema.upperRangeFixedSize(key, timeTo))
             );
 
-        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(key, key, timeFrom, timeTo);
+        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(key, key, timeFrom, timeTo, true);
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
 
@@ -241,7 +241,7 @@ class CachingWindowStore
                 cacheFunction.cacheKey(keySchema.upperRangeFixedSize(key, timeTo))
             );
 
-        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(key, key, timeFrom, timeTo);
+        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(key, key, timeFrom, timeTo, false);
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
 
@@ -279,7 +279,7 @@ class CachingWindowStore
                 keyTo == null ? null : cacheFunction.cacheKey(keySchema.upperRange(keyTo, timeTo))
             );
 
-        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(keyFrom, keyTo, timeFrom, timeTo);
+        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(keyFrom, keyTo, timeFrom, timeTo, true);
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
 
@@ -323,7 +323,7 @@ class CachingWindowStore
                 keyTo == null ? null : cacheFunction.cacheKey(keySchema.upperRange(keyTo, timeTo))
             );
 
-        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(keyFrom, keyTo, timeFrom, timeTo);
+        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(keyFrom, keyTo, timeFrom, timeTo, false);
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
 
@@ -345,7 +345,7 @@ class CachingWindowStore
         final KeyValueIterator<Windowed<Bytes>, byte[]> underlyingIterator = wrapped().fetchAll(timeFrom, timeTo);
         final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().all(cacheName);
 
-        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(null, null, timeFrom, timeTo);
+        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(null, null, timeFrom, timeTo, true);
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
         return new MergedSortedCacheWindowStoreKeyValueIterator(
@@ -366,7 +366,7 @@ class CachingWindowStore
         final KeyValueIterator<Windowed<Bytes>, byte[]> underlyingIterator = wrapped().backwardFetchAll(timeFrom, timeTo);
         final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().reverseAll(cacheName);
 
-        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(null, null, timeFrom, timeTo);
+        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(null, null, timeFrom, timeTo, false);
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/PrefixedWindowKeySchemas.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/PrefixedWindowKeySchemas.java
new file mode 100644
index 0000000..4f94ca9
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/PrefixedWindowKeySchemas.java
@@ -0,0 +1,385 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import java.util.Arrays;
+import org.apache.kafka.common.serialization.Deserializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.kstream.Window;
+import org.apache.kafka.streams.kstream.Windowed;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+import org.apache.kafka.streams.state.StateSerdes;
+import org.apache.kafka.streams.state.internals.SegmentedBytesStore.KeySchema;
+
+import static org.apache.kafka.streams.state.StateSerdes.TIMESTAMP_SIZE;
+import static org.apache.kafka.streams.state.internals.WindowKeySchema.timeWindowForSize;
+
+public class PrefixedWindowKeySchemas {
+    // Store keys carry a one-byte prefix (TIME_FIRST_PREFIX or KEY_FIRST_PREFIX) identifying their schema.
+    private static final int PREFIX_SIZE = 1;
+    private static final byte TIME_FIRST_PREFIX = 0;
+    private static final byte KEY_FIRST_PREFIX = 1;
+    private static final int SEQNUM_SIZE = 4; // size of the int sequence number appended to store keys
+    // Returns the leading prefix byte, which identifies the schema that produced the key.
+    private static byte extractPrefix(final byte[] binaryBytes) {
+        return binaryBytes[0];
+    }
+
+    public static class TimeFirstWindowKeySchema implements RocksDBSegmentedBytesStore.KeySchema {
+        // Upper bound at time `to`: all-0xFF key of the same length plus seqnum Integer.MAX_VALUE; null key -> next prefix byte.
+        @Override
+        public Bytes upperRange(final Bytes key, final long to) {
+            if (key == null) {
+                // Use the next prefix (rather than null) as the bound so that a backward scan
+                // starts right after the last time-first key
+                final byte nextPrefix = TIME_FIRST_PREFIX + 1;
+                return Bytes.wrap(ByteBuffer.allocate(PREFIX_SIZE).put(nextPrefix).array());
+            }
+            final byte[] maxKey = new byte[key.get().length];
+            Arrays.fill(maxKey, (byte) 0xFF);
+            return Bytes.wrap(ByteBuffer.allocate(PREFIX_SIZE + TIMESTAMP_SIZE + maxKey.length + SEQNUM_SIZE)
+                .put(TIME_FIRST_PREFIX)
+                .putLong(to)
+                .put(maxKey).putInt(Integer.MAX_VALUE)
+                .array());
+        }
+        // Lower bound [prefix][from][key] with no seqnum appended; a null key yields [prefix][from].
+        @Override
+        public Bytes lowerRange(final Bytes key, final long from) {
+            if (key == null) {
+                return Bytes.wrap(ByteBuffer.allocate(PREFIX_SIZE + TIMESTAMP_SIZE)
+                    .put(TIME_FIRST_PREFIX)
+                    .putLong(from)
+                    .array());
+            }
+
+            /*
+             * A stored key with a larger timestamp, or an equal timestamp and larger key, can't sort below this lower range:
+             *     1. Timestamp is fixed length (with big endian byte order). Since we put timestamp
+             *        first, larger timestamp will have larger byte order.
+             *     2. If timestamp is the same but key (k1) is larger than this lower range key (k2):
+             *         a. If k2 is not a prefix of k1, they differ at some byte where k1 is larger,
+             *            so k1 sorts after k2 no matter what seqnum is appended to k1.
+             *         b. If k2 is a prefix of k1, (k1 + seqnum) is strictly longer than k2, which has
+             *            no seqnum appended here, so it always sorts after k2.
+             */
+            return Bytes.wrap(ByteBuffer.allocate(PREFIX_SIZE + TIMESTAMP_SIZE + key.get().length)
+                .put(TIME_FIRST_PREFIX)
+                .putLong(from)
+                .put(key.get())
+                .array());
+        }
+        // Fixed-size lower bound: a full store key at max(0, from) with seqnum 0.
+        @Override
+        public Bytes lowerRangeFixedSize(final Bytes key, final long from) {
+            return toStoreKeyBinary(key, Math.max(0, from), 0);
+        }
+        // Fixed-size upper bound: a full store key at `to` with seqnum Integer.MAX_VALUE.
+        @Override
+        public Bytes upperRangeFixedSize(final Bytes key, final long to) {
+            return toStoreKeyBinary(key, to, Integer.MAX_VALUE);
+        }
+        // The segment is chosen by the timestamp stored right after the prefix byte.
+        @Override
+        public long segmentTimestamp(final Bytes key) {
+            return extractStoreTimestamp(key.get());
+        }
+        // Filters to in-range keys/times; stops early at a foreign prefix or when time exits the range in scan order.
+        @Override
+        public HasNextCondition hasNextCondition(final Bytes binaryKeyFrom,
+            final Bytes binaryKeyTo, final long from, final long to, final boolean forward) {
+            return iterator -> {
+                while (iterator.hasNext()) {
+                    final Bytes bytes = iterator.peekNextKey();
+                    final byte prefix = extractPrefix(bytes.get());
+                    // Any other prefix belongs to a different schema, so there is nothing more to return.
+                    if (prefix != TIME_FIRST_PREFIX) {
+                        return false;
+                    }
+
+                    final long time = TimeFirstWindowKeySchema.extractStoreTimestamp(bytes.get());
+
+                    // Keys are sorted by time, so once a forward scan passes `to` (or a backward
+                    // scan drops below `from`) no remaining key can be within the time range.
+                    if (forward && time > to) {
+                        return false;
+                    }
+                    if (!forward && time < from) {
+                        return false;
+                    }
+                    // Otherwise the entry matches only if both the key range and the time range hold.
+                    final Bytes keyBytes = Bytes.wrap(
+                        TimeFirstWindowKeySchema.extractStoreKeyBytes(bytes.get()));
+                    if ((binaryKeyFrom == null || keyBytes.compareTo(binaryKeyFrom) >= 0)
+                        && (binaryKeyTo == null || keyBytes.compareTo(binaryKeyTo) <= 0)
+                        && time >= from && time <= to) {
+                        return true;
+                    }
+                    iterator.next();
+                }
+                return false;
+            };
+        }
+        // Segments are selected by time alone; the key range does not narrow the search.
+        @Override
+        public <S extends Segment> List<S> segmentsToSearch(final Segments<S> segments,
+            final long from,
+            final long to,
+            final boolean forward) {
+            return segments.segments(from, to, forward);
+        }
+        // Store key layout is [prefix][timestamp][key][seqnum]; this returns the key bytes.
+        static byte[] extractStoreKeyBytes(final byte[] binaryKey) {
+            final byte[] bytes = new byte[binaryKey.length - TIMESTAMP_SIZE - SEQNUM_SIZE - PREFIX_SIZE];
+            System.arraycopy(binaryKey, TIMESTAMP_SIZE + PREFIX_SIZE, bytes, 0, bytes.length);
+            return bytes;
+        }
+        // Reads the big-endian timestamp that follows the one-byte prefix.
+        static long extractStoreTimestamp(final byte[] binaryKey) {
+            return ByteBuffer.wrap(binaryKey).getLong(PREFIX_SIZE);
+        }
+        // Encodes an already-serialized windowed key, using the window start as its timestamp.
+        public static Bytes toStoreKeyBinary(final Windowed<Bytes> timeKey,
+                                             final int seqnum) {
+            return toStoreKeyBinary(timeKey.key().get(), timeKey.window().start(), seqnum);
+        }
+        // Decodes a store key; the window is rebuilt from the stored start timestamp and windowSize.
+        public static <K> Windowed<K> fromStoreKey(final byte[] binaryKey,
+                                                   final long windowSize,
+                                                   final Deserializer<K> deserializer,
+                                                   final String topic) {
+            final K key = deserializer.deserialize(topic, extractStoreKeyBytes(binaryKey));
+            final Window window = extractStoreWindow(binaryKey, windowSize);
+            return new Windowed<>(key, window);
+        }
+
+        /** Serializes the record key with {@code serdes} and builds a time-first store key. */
+        public static <K> Bytes toStoreKeyBinary(final Windowed<K> timeKey,
+                                                 final int seqnum,
+                                                 final StateSerdes<K, ?> serdes) {
+            return toStoreKeyBinary(serdes.rawKey(timeKey.key()), timeKey.window().start(), seqnum);
+        }
+
+        /** Convenience overload for store serdes: builds a time-first key from {@link Bytes}. */
+        public static Bytes toStoreKeyBinary(final Bytes key,
+                                             final long timestamp,
+                                             final int seqnum) {
+            final byte[] raw = key.get();
+            return toStoreKeyBinary(raw, timestamp, seqnum);
+        }
+
+        static Bytes toStoreKeyBinary(final byte[] serializedKey,
+                                      final long timestamp,
+                                      final int seqnum) {
+            final ByteBuffer buf = ByteBuffer.allocate(
+                PREFIX_SIZE + TIMESTAMP_SIZE + serializedKey.length + SEQNUM_SIZE);
+            buf.put(TIME_FIRST_PREFIX);
+            buf.putLong(timestamp);
+            buf.put(serializedKey);
+            buf.putInt(seqnum);
+
+            return Bytes.wrap(buf.array());
+        }
+
+        /** Decodes a time-first store key into a {@link Windowed} key holding raw {@link Bytes}. */
+        public static Windowed<Bytes> fromStoreBytesKey(final byte[] binaryKey,
+                                                        final long windowSize) {
+            return new Windowed<>(
+                Bytes.wrap(extractStoreKeyBytes(binaryKey)),
+                extractStoreWindow(binaryKey, windowSize));
+        }
+
+        /** Builds the window whose start is the stored timestamp, sized by {@code windowSize}. */
+        static Window extractStoreWindow(final byte[] binaryKey,
+                                         final long windowSize) {
+            return timeWindowForSize(extractStoreTimestamp(binaryKey), windowSize);
+        }
+
+        static int extractStoreSequence(final byte[] binaryKey) {
+            return ByteBuffer.wrap(binaryKey).getInt(binaryKey.length - SEQNUM_SIZE);
+        }
+
+        /**
+         * Converts an un-prefixed key-first window key {@code [key][timestamp][seqnum]} into the
+         * time-first layout {@code [TIME_FIRST_PREFIX][timestamp][key][seqnum]}.
+         */
+        public static byte[] fromNonPrefixWindowKey(final byte[] binaryKey) {
+            // The record key occupies everything before the timestamp, so its length is also
+            // the timestamp's offset.
+            final int keyLength = binaryKey.length - SEQNUM_SIZE - TIMESTAMP_SIZE;
+            final int seqnumOffset = binaryKey.length - SEQNUM_SIZE;
+            return ByteBuffer.allocate(PREFIX_SIZE + binaryKey.length)
+                .put(TIME_FIRST_PREFIX)
+                .put(binaryKey, keyLength, TIMESTAMP_SIZE)
+                .put(binaryKey, 0, keyLength)
+                .put(binaryKey, seqnumOffset, SEQNUM_SIZE)
+                .array();
+        }
+    }
+
+    /**
+     * Key schema for the index store. Keys are encoded as
+     * {@code [KEY_FIRST_PREFIX][serializedKey][timestamp][seqnum]}, i.e. a {@link WindowKeySchema}
+     * key with a leading prefix. Range bounds are computed with {@link WindowKeySchema} and then
+     * prefixed by {@code wrapPrefix}.
+     */
+    public static class KeyFirstWindowKeySchema implements KeySchema {
+
+        @Override
+        public Bytes upperRange(final Bytes key, final long to) {
+            final Bytes noPrefixBytes = new WindowKeySchema().upperRange(key, to);
+            return wrapPrefix(noPrefixBytes, true);
+        }
+
+        @Override
+        public Bytes lowerRange(final Bytes key, final long from) {
+            final Bytes noPrefixBytes = new WindowKeySchema().lowerRange(key, from);
+            // Wrap with at least the prefix even when the key is null
+            return wrapPrefix(noPrefixBytes, false);
+        }
+
+        @Override
+        public Bytes lowerRangeFixedSize(final Bytes key, final long from) {
+            final Bytes noPrefixBytes = WindowKeySchema.toStoreKeyBinary(key, Math.max(0, from), 0);
+            return wrapPrefix(noPrefixBytes, false);
+        }
+
+        @Override
+        public Bytes upperRangeFixedSize(final Bytes key, final long to) {
+            final Bytes noPrefixBytes = WindowKeySchema.toStoreKeyBinary(key, to, Integer.MAX_VALUE);
+            return wrapPrefix(noPrefixBytes, true);
+        }
+
+        @Override
+        public long segmentTimestamp(final Bytes key) {
+            return KeyFirstWindowKeySchema.extractStoreTimestamp(key.get());
+        }
+
+        /**
+         * Skips entries until one has the key-first prefix, a record key within
+         * {@code [binaryKeyFrom, binaryKeyTo]} (null bounds are open) and a timestamp within
+         * {@code [from, to]}. Returns false as soon as an entry with a different prefix is seen.
+         */
+        @Override
+        public HasNextCondition hasNextCondition(final Bytes binaryKeyFrom,
+                                                 final Bytes binaryKeyTo,
+                                                 final long from,
+                                                 final long to,
+                                                 final boolean forward) {
+            return iterator -> {
+                while (iterator.hasNext()) {
+                    final byte[] rawKey = iterator.peekNextKey().get();
+                    final byte prefix = extractPrefix(rawKey);
+
+                    if (prefix != KEY_FIRST_PREFIX) {
+                        // Not an index-store key: stop iterating.
+                        return false;
+                    }
+
+                    // Check the timestamp first so the key bytes are only copied for in-range entries.
+                    final long time = KeyFirstWindowKeySchema.extractStoreTimestamp(rawKey);
+                    if (time >= from && time <= to) {
+                        final Bytes keyBytes = Bytes.wrap(KeyFirstWindowKeySchema.extractStoreKeyBytes(rawKey));
+                        if ((binaryKeyFrom == null || keyBytes.compareTo(binaryKeyFrom) >= 0)
+                            && (binaryKeyTo == null || keyBytes.compareTo(binaryKeyTo) <= 0)) {
+                            return true;
+                        }
+                    }
+                    iterator.next();
+                }
+                return false;
+            };
+        }
+
+        @Override
+        public <S extends Segment> List<S> segmentsToSearch(final Segments<S> segments,
+                                                            final long from,
+                                                            final long to,
+                                                            final boolean forward) {
+            return segments.segments(from, to, forward);
+        }
+
+        public static Bytes toStoreKeyBinary(final Windowed<Bytes> timeKey,
+                                             final int seqnum) {
+            return toStoreKeyBinary(timeKey.key().get(), timeKey.window().start(), seqnum);
+        }
+
+        public static <K> Bytes toStoreKeyBinary(final Windowed<K> timeKey,
+                                                 final int seqnum,
+                                                 final StateSerdes<K, ?> serdes) {
+            final byte[] serializedKey = serdes.rawKey(timeKey.key());
+            return toStoreKeyBinary(serializedKey, timeKey.window().start(), seqnum);
+        }
+
+        public static Bytes toStoreKeyBinary(final Bytes key,
+                                             final long timestamp,
+                                             final int seqnum) {
+            return toStoreKeyBinary(key.get(), timestamp, seqnum);
+        }
+
+        /**
+         * Encodes {@code [KEY_FIRST_PREFIX][serializedKey][timestamp][seqnum]}.
+         */
+        public static Bytes toStoreKeyBinary(final byte[] serializedKey,
+                                             final long timestamp,
+                                             final int seqnum) {
+            final ByteBuffer buf = ByteBuffer.allocate(PREFIX_SIZE + serializedKey.length + TIMESTAMP_SIZE + SEQNUM_SIZE);
+            buf.put(KEY_FIRST_PREFIX);
+            buf.put(serializedKey);
+            buf.putLong(timestamp);
+            buf.putInt(seqnum);
+
+            return Bytes.wrap(buf.array());
+        }
+
+        static byte[] extractStoreKeyBytes(final byte[] binaryKey) {
+            final byte[] bytes = new byte[binaryKey.length - TIMESTAMP_SIZE - SEQNUM_SIZE - PREFIX_SIZE];
+            System.arraycopy(binaryKey, PREFIX_SIZE, bytes, 0, bytes.length);
+            return bytes;
+        }
+
+        public static Windowed<Bytes> fromStoreBytesKey(final byte[] binaryKey,
+                                                        final long windowSize) {
+            final Bytes key = Bytes.wrap(extractStoreKeyBytes(binaryKey));
+            final Window window = extractStoreWindow(binaryKey, windowSize);
+            return new Windowed<>(key, window);
+        }
+
+        static long extractStoreTimestamp(final byte[] binaryKey) {
+            return ByteBuffer.wrap(binaryKey).getLong(binaryKey.length - TIMESTAMP_SIZE - SEQNUM_SIZE);
+        }
+
+        static int extractStoreSequence(final byte[] binaryKey) {
+            return ByteBuffer.wrap(binaryKey).getInt(binaryKey.length - SEQNUM_SIZE);
+        }
+
+        static Window extractStoreWindow(final byte[] binaryKey,
+                                         final long windowSize) {
+            final long start = KeyFirstWindowKeySchema.extractStoreTimestamp(binaryKey);
+            return timeWindowForSize(start, windowSize);
+        }
+
+        public static <K> Windowed<K> fromStoreKey(final byte[] binaryKey,
+                                                   final long windowSize,
+                                                   final Deserializer<K> deserializer,
+                                                   final String topic) {
+            final K key = deserializer.deserialize(topic, extractStoreKeyBytes(binaryKey));
+            final Window window = extractStoreWindow(binaryKey, windowSize);
+            return new Windowed<>(key, window);
+        }
+
+        /**
+         * Prepends {@code KEY_FIRST_PREFIX} to an un-prefixed key. A null key (open range)
+         * becomes just the prefix byte; for an upper bound the next prefix value is used so the
+         * range covers every key-first entry.
+         */
+        private static Bytes wrapPrefix(final Bytes noPrefixKey, final boolean upperRange) {
+            // Need to scan from prefix even key is null
+            if (noPrefixKey == null) {
+                final byte prefix = upperRange ? KEY_FIRST_PREFIX + 1 : KEY_FIRST_PREFIX;
+                final byte[] ret = ByteBuffer.allocate(PREFIX_SIZE)
+                    .put(prefix)
+                    .array();
+                return Bytes.wrap(ret);
+            }
+            final byte[] ret = ByteBuffer.allocate(PREFIX_SIZE + noPrefixKey.get().length)
+                .put(KEY_FIRST_PREFIX)
+                .put(noPrefixKey.get())
+                .array();
+            return Bytes.wrap(ret);
+        }
+
+        /** Adds the key-first prefix to an un-prefixed {@link WindowKeySchema} key. */
+        public static byte[] fromNonPrefixWindowKey(final byte[] binaryKey) {
+            return wrapPrefix(Bytes.wrap(binaryKey), false).get();
+        }
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStore.java
new file mode 100644
index 0000000..e87af87
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStore.java
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Optional;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.errors.ProcessorStateException;
+import org.apache.kafka.streams.processor.internals.ChangelogRecordDeserializationHelper;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.KeyFirstWindowKeySchema;
+import org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.TimeFirstWindowKeySchema;
+import org.rocksdb.RocksDBException;
+import org.rocksdb.WriteBatch;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * RocksDB store backed by two SegmentedBytesStores which can optimize scan by time as well as window
+ * lookup for a specific key.
+ *
+ * Schema for first SegmentedBytesStore (base store) is as below:
+ *     Key schema: | timestamp + recordkey |
+ *     Value schema: | value |. Value here is determined by caller.
+ *
+ * Schema for second SegmentedBytesStore (index store) is as below:
+ *     Key schema: | record + timestamp |
+ *     Value schema: ||
+ *
+ * Operations:
+ *     Put: 1. Put to index store. 2. Put to base store.
+ *     Delete: 1. Delete from base store. 2. Delete from index store.
+ * Since we need to update two stores, failure can happen in the middle. We put in index store first
+ * to make sure if a failure happens in second step and the view is inconsistent, we can't get the
+ * value for the key. We delete from base store first to make sure if a failure happens in second step
+ * and the view is inconsistent, we can't get the value for the key.
+ *
+ * Note:
+ *     Index store can be optional if we can construct the timestamp in base store instead of looking
+ *     them up from index store.
+ *
+ */
+public class RocksDBTimeOrderedSegmentedBytesStore extends AbstractDualSchemaRocksDBSegmentedBytesStore<KeyValueSegment> {
+    private static final Logger LOG = LoggerFactory.getLogger(AbstractDualSchemaRocksDBSegmentedBytesStore.class);
+
+    private class IndexToBaseStoreIterator implements KeyValueIterator<Bytes, byte[]> {
+        private final KeyValueIterator<Bytes, byte[]> indexIterator;
+        private byte[] cachedValue;
+
+
+        IndexToBaseStoreIterator(final KeyValueIterator<Bytes, byte[]> indexIterator) {
+            this.indexIterator = indexIterator;
+        }
+
+        @Override
+        public void close() {
+            indexIterator.close();
+        }
+
+        @Override
+        public Bytes peekNextKey() {
+            if (!hasNext()) {
+                throw new NoSuchElementException();
+            }
+            return getBaseKey(indexIterator.peekNextKey());
+        }
+
+        @Override
+        public boolean hasNext() {
+            while (indexIterator.hasNext()) {
+                final Bytes key = indexIterator.peekNextKey();
+                final Bytes baseKey = getBaseKey(key);
+
+                cachedValue = get(baseKey);
+                if (cachedValue == null) {
+                    // Key not in base store, inconsistency happened and remove from index.
+                    indexIterator.next();
+                    RocksDBTimeOrderedSegmentedBytesStore.this.removeIndex(key);
+                } else {
+                    return true;
+                }
+            }
+            return false;
+        }
+
+        @Override
+        public KeyValue<Bytes, byte[]> next() {
+            if (cachedValue == null && !hasNext()) {
+                throw new NoSuchElementException();
+            }
+            final KeyValue<Bytes, byte[]> ret = indexIterator.next();
+            final byte[] value = cachedValue;
+            cachedValue = null;
+            return KeyValue.pair(getBaseKey(ret.key), value);
+        }
+
+        private Bytes getBaseKey(final Bytes indexKey) {
+            final byte[] keyBytes = KeyFirstWindowKeySchema.extractStoreKeyBytes(indexKey.get());
+            final long timestamp = KeyFirstWindowKeySchema.extractStoreTimestamp(indexKey.get());
+            final int seqnum = KeyFirstWindowKeySchema.extractStoreSequence(indexKey.get());
+            return TimeFirstWindowKeySchema.toStoreKeyBinary(keyBytes, timestamp, seqnum);
+        }
+    }
+
+    RocksDBTimeOrderedSegmentedBytesStore(final String name,
+                                          final String metricsScope,
+                                          final long retention,
+                                          final long segmentInterval,
+                                          final boolean withIndex) {
+        super(name, metricsScope, new TimeFirstWindowKeySchema(),
+            Optional.ofNullable(withIndex ? new KeyFirstWindowKeySchema() : null),
+            new KeyValueSegments(name, metricsScope, retention, segmentInterval));
+    }
+
+    public void put(final Bytes key, final long timestamp, final int seqnum, final byte[] value) {
+        final Bytes baseKey = TimeFirstWindowKeySchema.toStoreKeyBinary(key, timestamp, seqnum);
+        put(baseKey, value);
+    }
+
+    byte[] fetch(final Bytes key, final long timestamp, final int seqnum) {
+        return get(TimeFirstWindowKeySchema.toStoreKeyBinary(key, timestamp, seqnum));
+    }
+
+    @Override
+    protected KeyValue<Bytes, byte[]> getIndexKeyValue(final Bytes baseKey, final byte[] baseValue) {
+        final byte[] key = TimeFirstWindowKeySchema.extractStoreKeyBytes(baseKey.get());
+        final long timestamp = TimeFirstWindowKeySchema.extractStoreTimestamp(baseKey.get());
+        final int seqnum = TimeFirstWindowKeySchema.extractStoreSequence(baseKey.get());
+
+        return KeyValue.pair(KeyFirstWindowKeySchema.toStoreKeyBinary(key, timestamp, seqnum), new byte[0]);
+    }
+
+    @Override
+    Map<KeyValueSegment, WriteBatch> getWriteBatches(
+        final Collection<ConsumerRecord<byte[], byte[]>> records) {
+        // advance stream time to the max timestamp in the batch
+        for (final ConsumerRecord<byte[], byte[]> record : records) {
+            final long timestamp = WindowKeySchema.extractStoreTimestamp(record.key());
+            observedStreamTime = Math.max(observedStreamTime, timestamp);
+        }
+
+        final Map<KeyValueSegment, WriteBatch> writeBatchMap = new HashMap<>();
+        for (final ConsumerRecord<byte[], byte[]> record : records) {
+            final long timestamp = WindowKeySchema.extractStoreTimestamp(record.key());
+            final long segmentId = segments.segmentId(timestamp);
+            final KeyValueSegment segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
+            if (segment != null) {
+                ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(
+                    record,
+                    consistencyEnabled,
+                    position
+                );
+                try {
+                    final WriteBatch batch = writeBatchMap.computeIfAbsent(segment, s -> new WriteBatch());
+
+                    // Assuming changelog record is serialized using WindowKeySchema
+                    // from ChangeLoggingTimestampedWindowBytesStore. Reconstruct key/value to restore
+                    if (hasIndex()) {
+                        final byte[] indexKey = KeyFirstWindowKeySchema.fromNonPrefixWindowKey(record.key());
+                        // Take care of tombstone
+                        final byte[] value = record.value() == null ? null : new byte[0];
+                        segment.addToBatch(new KeyValue<>(indexKey, value), batch);
+                    }
+
+                    final byte[] baseKey = TimeFirstWindowKeySchema.fromNonPrefixWindowKey(record.key());
+                    segment.addToBatch(new KeyValue<>(baseKey, record.value()), batch);
+                } catch (final RocksDBException e) {
+                    throw new ProcessorStateException("Error restoring batch to store " + name(), e);
+                }
+            }
+        }
+        return writeBatchMap;
+    }
+
+    @Override
+    public KeyValueIterator<Bytes, byte[]> fetch(final Bytes key,
+                                                 final long from,
+                                                 final long to) {
+        return fetch(key, from, to, true);
+    }
+
+    @Override
+    public KeyValueIterator<Bytes, byte[]> backwardFetch(final Bytes key,
+                                                         final long from,
+                                                         final long to) {
+        return fetch(key, from, to, false);
+    }
+
+    KeyValueIterator<Bytes, byte[]> fetch(final Bytes key,
+                                          final long from,
+                                          final long to,
+                                          final boolean forward) {
+        if (indexKeySchema.isPresent()) {
+            final List<KeyValueSegment> searchSpace = indexKeySchema.get().segmentsToSearch(segments, from, to,
+                forward);
+
+            final Bytes binaryFrom = indexKeySchema.get().lowerRangeFixedSize(key, from);
+            final Bytes binaryTo = indexKeySchema.get().upperRangeFixedSize(key, to);
+
+            return new IndexToBaseStoreIterator(new SegmentIterator<>(
+                searchSpace.iterator(),
+                indexKeySchema.get().hasNextCondition(key, key, from, to, forward),
+                binaryFrom,
+                binaryTo,
+                forward));
+        }
+
+        final List<KeyValueSegment> searchSpace = baseKeySchema.segmentsToSearch(segments, from, to,
+            forward);
+
+        final Bytes binaryFrom = baseKeySchema.lowerRangeFixedSize(key, from);
+        final Bytes binaryTo = baseKeySchema.upperRangeFixedSize(key, to);
+
+        return new SegmentIterator<>(
+            searchSpace.iterator(),
+            baseKeySchema.hasNextCondition(key, key, from, to, forward),
+            binaryFrom,
+            binaryTo,
+            forward);
+    }
+
+    @Override
+    public KeyValueIterator<Bytes, byte[]> fetch(final Bytes keyFrom,
+                                                 final Bytes keyTo,
+                                                 final long from,
+                                                 final long to) {
+        return fetch(keyFrom, keyTo, from, to, true);
+    }
+
+    @Override
+    public KeyValueIterator<Bytes, byte[]> backwardFetch(final Bytes keyFrom,
+                                                         final Bytes keyTo,
+                                                         final long from,
+                                                         final long to) {
+        return fetch(keyFrom, keyTo, from, to, false);
+    }
+
+    KeyValueIterator<Bytes, byte[]> fetch(final Bytes keyFrom,
+                                          final Bytes keyTo,
+                                          final long from,
+                                          final long to,
+                                          final boolean forward) {
+        if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) {
+            LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " +
+                    "This may be due to range arguments set in the wrong order, " +
+                    "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " +
+                    "Note that the built-in numerical serdes do not follow this for negative numbers");
+            return KeyValueIterators.emptyIterator();
+        }
+
+        if (indexKeySchema.isPresent()) {
+            final List<KeyValueSegment> searchSpace = indexKeySchema.get().segmentsToSearch(segments, from, to,
+                forward);
+
+            final Bytes binaryFrom = indexKeySchema.get().lowerRange(keyFrom, from);
+            final Bytes binaryTo = indexKeySchema.get().upperRange(keyTo, to);
+
+            return new IndexToBaseStoreIterator(new SegmentIterator<>(
+                searchSpace.iterator(),
+                indexKeySchema.get().hasNextCondition(keyFrom, keyTo, from, to, forward),
+                binaryFrom,
+                binaryTo,
+                forward));
+        }
+
+        final List<KeyValueSegment> searchSpace = baseKeySchema.segmentsToSearch(segments, from, to,
+            forward);
+
+        final Bytes binaryFrom = baseKeySchema.lowerRange(keyFrom, from);
+        final Bytes binaryTo = baseKeySchema.upperRange(keyTo, to);
+
+        return new SegmentIterator<>(
+            searchSpace.iterator(),
+            baseKeySchema.hasNextCondition(keyFrom, keyTo, from, to, forward),
+            binaryFrom,
+            binaryTo,
+            forward);
+    }
+
+
+    @Override
+    public void remove(final Bytes key, final long timestamp) {
+        throw new UnsupportedOperationException("Not supported operation");
+    }
+
+    @Override
+    public KeyValueIterator<Bytes, byte[]> fetchAll(final long timeFrom,
+                                                    final long timeTo) {
+        final List<KeyValueSegment> searchSpace = segments.segments(timeFrom, timeTo, true);
+        final Bytes binaryFrom = baseKeySchema.lowerRange(null, timeFrom);
+        final Bytes binaryTo = baseKeySchema.upperRange(null, timeTo);
+
+        return new SegmentIterator<>(
+                searchSpace.iterator(),
+                baseKeySchema.hasNextCondition(null, null, timeFrom, timeTo, true),
+                binaryFrom,
+                binaryTo,
+                true);
+    }
+
+    @Override
+    public KeyValueIterator<Bytes, byte[]> backwardFetchAll(final long timeFrom,
+                                                            final long timeTo) {
+        final List<KeyValueSegment> searchSpace = segments.segments(timeFrom, timeTo, false);
+        final Bytes binaryFrom = baseKeySchema.lowerRange(null, timeFrom);
+        final Bytes binaryTo = baseKeySchema.upperRange(null, timeTo);
+
+        return new SegmentIterator<>(
+                searchSpace.iterator(),
+                baseKeySchema.hasNextCondition(null, null, timeFrom, timeTo, false),
+                binaryFrom,
+                binaryTo,
+                false);
+    }
+}
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowStore.java
new file mode 100644
index 0000000..a174e51
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowStore.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import java.util.Objects;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.kstream.Windowed;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.WindowStore;
+import org.apache.kafka.streams.state.WindowStoreIterator;
+import org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.TimeFirstWindowKeySchema;
+
+
+/**
+ * {@link WindowStore} backed by a {@link RocksDBTimeOrderedSegmentedBytesStore}, whose base-store
+ * keys put the window start timestamp ahead of the record key. Every record is written with a
+ * sequence number, which is advanced on each put only when duplicates are retained.
+ */
+public class RocksDBTimeOrderedWindowStore
+    extends WrappedStateStore<RocksDBTimeOrderedSegmentedBytesStore, Object, Object>
+    implements WindowStore<Bytes, byte[]> {
+
+    private final boolean retainDuplicates;
+    private final long windowSize;
+    // Sequence number used for the most recent put; stays 0 unless retainDuplicates is set.
+    private int seqnum = 0;
+
+    RocksDBTimeOrderedWindowStore(
+        final RocksDBTimeOrderedSegmentedBytesStore store,
+        final boolean retainDuplicates,
+        final long windowSize
+    ) {
+        // Validate before the store is handed to the wrapper.
+        super(Objects.requireNonNull(store, "store is null"));
+        this.retainDuplicates = retainDuplicates;
+        this.windowSize = windowSize;
+    }
+
+    @Override
+    public void init(final StateStoreContext context, final StateStore root) {
+        wrapped().init(context, root);
+    }
+
+    @Override
+    public void flush() {
+        wrapped().flush();
+    }
+
+    @Override
+    public void close() {
+        wrapped().close();
+    }
+
+    @Override
+    public boolean persistent() {
+        return wrapped().persistent();
+    }
+
+    @Override
+    public boolean isOpen() {
+        return wrapped().isOpen();
+    }
+
+    @Override
+    public void put(final Bytes key, final byte[] value, final long windowStartTimestamp) {
+        // Skip if value is null and duplicates are allowed since this delete is a no-op
+        if (!(value == null && retainDuplicates)) {
+            maybeUpdateSeqnumForDups();
+            wrapped().put(key, windowStartTimestamp, seqnum, value);
+        }
+    }
+
+    /**
+     * Point lookup for {@code key} at window start {@code timestamp}, using the current
+     * {@code seqnum}.
+     * NOTE(review): with retainDuplicates the current seqnum is the one assigned by the most
+     * recent put (of any key), so this generally only matches that record — confirm intended.
+     */
+    @Override
+    public byte[] fetch(final Bytes key, final long timestamp) {
+        return wrapped().fetch(key, timestamp, seqnum);
+    }
+
+    @Override
+    public WindowStoreIterator<byte[]> fetch(final Bytes key, final long timeFrom, final long timeTo) {
+        return wrapIterator(wrapped().fetch(key, timeFrom, timeTo)).valuesIterator();
+    }
+
+    @Override
+    public WindowStoreIterator<byte[]> backwardFetch(final Bytes key, final long timeFrom, final long timeTo) {
+        return wrapIterator(wrapped().backwardFetch(key, timeFrom, timeTo)).valuesIterator();
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes keyFrom,
+                                                           final Bytes keyTo,
+                                                           final long timeFrom,
+                                                           final long timeTo) {
+        return wrapIterator(wrapped().fetch(keyFrom, keyTo, timeFrom, timeTo)).keyValueIterator();
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes keyFrom,
+                                                                   final Bytes keyTo,
+                                                                   final long timeFrom,
+                                                                   final long timeTo) {
+        return wrapIterator(wrapped().backwardFetch(keyFrom, keyTo, timeFrom, timeTo)).keyValueIterator();
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> all() {
+        return wrapIterator(wrapped().all()).keyValueIterator();
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardAll() {
+        return wrapIterator(wrapped().backwardAll()).keyValueIterator();
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> fetchAll(final long timeFrom, final long timeTo) {
+        return wrapIterator(wrapped().fetchAll(timeFrom, timeTo)).keyValueIterator();
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetchAll(final long timeFrom, final long timeTo) {
+        return wrapIterator(wrapped().backwardFetchAll(timeFrom, timeTo)).keyValueIterator();
+    }
+
+    // Adapts a raw iterator over time-first store keys into window-store form.
+    private WindowStoreIteratorWrapper wrapIterator(final KeyValueIterator<Bytes, byte[]> bytesIterator) {
+        return new WindowStoreIteratorWrapper(bytesIterator,
+            windowSize,
+            TimeFirstWindowKeySchema::extractStoreTimestamp,
+            TimeFirstWindowKeySchema::fromStoreBytesKey);
+    }
+
+    // Advances the sequence number (kept non-negative on wrap-around) when duplicates are retained.
+    private void maybeUpdateSeqnumForDups() {
+        if (retainDuplicates) {
+            seqnum = (seqnum + 1) & 0x7FFFFFFF;
+        }
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbIndexedTimeOrderedWindowBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbIndexedTimeOrderedWindowBytesStoreSupplier.java
new file mode 100644
index 0000000..84d8a80
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbIndexedTimeOrderedWindowBytesStoreSupplier.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.state.WindowBytesStoreSupplier;
+import org.apache.kafka.streams.state.WindowStore;
+
+import java.util.Objects;
+
+/**
+ * {@link WindowBytesStoreSupplier} that creates a {@link RocksDBTimeOrderedWindowStore} backed by a
+ * {@link RocksDBTimeOrderedSegmentedBytesStore}, with or without that store's index depending on
+ * the configured {@link WindowStoreTypes}.
+ */
+public class RocksDbIndexedTimeOrderedWindowBytesStoreSupplier implements WindowBytesStoreSupplier {
+    /** Selects whether {@link #get()} creates the segmented bytes store with an index. */
+    public enum WindowStoreTypes {
+        DEFAULT_WINDOW_STORE,
+        INDEXED_WINDOW_STORE
+    }
+
+    private final String name;
+    private final long retentionPeriod;
+    private final long segmentInterval;
+    private final long windowSize;
+    private final boolean retainDuplicates;
+    private final WindowStoreTypes windowStoreType;
+
+    /**
+     * Convenience constructor that maps a boolean flag onto {@link WindowStoreTypes}.
+     *
+     * @param withIndex {@code true} selects {@link WindowStoreTypes#INDEXED_WINDOW_STORE},
+     *                  {@code false} selects {@link WindowStoreTypes#DEFAULT_WINDOW_STORE}
+     */
+    public RocksDbIndexedTimeOrderedWindowBytesStoreSupplier(final String name,
+                                                             final long retentionPeriod,
+                                                             final long segmentInterval,
+                                                             final long windowSize,
+                                                             final boolean retainDuplicates,
+                                                             final boolean withIndex) {
+        this(name, retentionPeriod, segmentInterval, windowSize, retainDuplicates,
+            withIndex
+                ? WindowStoreTypes.INDEXED_WINDOW_STORE
+                : WindowStoreTypes.DEFAULT_WINDOW_STORE);
+    }
+
+    /**
+     * Creates a supplier for the given {@link WindowStoreTypes}.
+     *
+     * @param name             store name, forwarded to the segmented bytes store
+     * @param retentionPeriod  retention period, forwarded to the segmented bytes store
+     * @param segmentInterval  segment interval, forwarded to the segmented bytes store
+     * @param windowSize       window size, forwarded to {@link RocksDBTimeOrderedWindowStore}
+     * @param retainDuplicates duplicate-retention flag, forwarded to {@link RocksDBTimeOrderedWindowStore}
+     * @param windowStoreType  whether the segmented bytes store is created with an index
+     * @throws NullPointerException if {@code name} or {@code windowStoreType} is {@code null}
+     */
+    public RocksDbIndexedTimeOrderedWindowBytesStoreSupplier(final String name,
+                                                             final long retentionPeriod,
+                                                             final long segmentInterval,
+                                                             final long windowSize,
+                                                             final boolean retainDuplicates,
+                                                             final WindowStoreTypes windowStoreType) {
+        this.name = Objects.requireNonNull(name, "name cannot be null");
+        this.retentionPeriod = retentionPeriod;
+        this.segmentInterval = segmentInterval;
+        this.windowSize = windowSize;
+        this.retainDuplicates = retainDuplicates;
+        // Fail fast: a null type would otherwise only surface as an NPE from the switch in get().
+        this.windowStoreType = Objects.requireNonNull(windowStoreType, "windowStoreType cannot be null");
+    }
+
+    /**
+     * Creates a new {@link RocksDBTimeOrderedWindowStore}. The underlying
+     * {@link RocksDBTimeOrderedSegmentedBytesStore} is created with its index enabled only for
+     * {@link WindowStoreTypes#INDEXED_WINDOW_STORE}.
+     */
+    @Override
+    public WindowStore<Bytes, byte[]> get() {
+        final boolean withIndex = indexEnabled();
+        return new RocksDBTimeOrderedWindowStore(
+            new RocksDBTimeOrderedSegmentedBytesStore(
+                name,
+                metricsScope(),
+                retentionPeriod,
+                segmentInterval,
+                withIndex),
+            retainDuplicates,
+            windowSize);
+    }
+
+    // Translates the configured store type into the segmented bytes store's index flag.
+    private boolean indexEnabled() {
+        switch (windowStoreType) {
+            case DEFAULT_WINDOW_STORE:
+                return false;
+            case INDEXED_WINDOW_STORE:
+                return true;
+            default:
+                throw new IllegalArgumentException("invalid window store type: " + windowStoreType);
+        }
+    }
+
+    @Override
+    public String metricsScope() {
+        return "rocksdb-window";
+    }
+
+    @Override
+    public long segmentIntervalMs() {
+        return segmentInterval;
+    }
+
+    @Override
+    public long windowSize() {
+        return windowSize;
+    }
+
+    @Override
+    public boolean retainDuplicates() {
+        return retainDuplicates;
+    }
+
+    @Override
+    public long retentionPeriod() {
+        return retentionPeriod;
+    }
+
+    @Override
+    public String toString() {
+        return "RocksDbIndexedTimeOrderedWindowBytesStoreSupplier{" +
+                   "name='" + name + '\'' +
+                   ", retentionPeriod=" + retentionPeriod +
+                   ", segmentInterval=" + segmentInterval +
+                   ", windowSize=" + windowSize +
+                   ", retainDuplicates=" + retainDuplicates +
+                   ", windowStoreType=" + windowStoreType +
+                   '}';
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentedBytesStore.java
index 4519929..80b5a91 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentedBytesStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentedBytesStore.java
@@ -210,9 +210,10 @@ public interface SegmentedBytesStore extends StateStore {
          * @param binaryKeyTo   the last key in the range
          * @param from          starting time range
          * @param to            ending time range
+         * @param forward       forward or backward
          * @return
          */
-        HasNextCondition hasNextCondition(final Bytes binaryKeyFrom, final Bytes binaryKeyTo, final long from, final long to);
+        HasNextCondition hasNextCondition(final Bytes binaryKeyFrom, final Bytes binaryKeyTo, final long from, final long to, final boolean forward);
 
         /**
          * Used during {@link SegmentedBytesStore#fetch(Bytes, long, long)} operations to determine
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionKeySchema.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionKeySchema.java
index 8bb50e5..d4196a9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionKeySchema.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionKeySchema.java
@@ -48,6 +48,9 @@ public class SessionKeySchema implements SegmentedBytesStore.KeySchema {
 
     @Override
     public Bytes upperRange(final Bytes key, final long to) {
+        if (key == null) {
+            return null;
+        }
         final byte[] maxSuffix = ByteBuffer.allocate(SUFFIX_SIZE)
             // the end timestamp can be as large as possible as long as it's larger than start time
             .putLong(Long.MAX_VALUE)
@@ -59,6 +62,9 @@ public class SessionKeySchema implements SegmentedBytesStore.KeySchema {
 
     @Override
     public Bytes lowerRange(final Bytes key, final long from) {
+        if (key == null) { // no key, no bound: return null (presumably treated as open-ended by callers)
+            return null;
+        }
         return OrderedBytes.lowerRange(key, MIN_SUFFIX);
     }
 
@@ -68,7 +74,7 @@ public class SessionKeySchema implements SegmentedBytesStore.KeySchema {
     }
 
     @Override
-    public HasNextCondition hasNextCondition(final Bytes binaryKeyFrom, final Bytes binaryKeyTo, final long from, final long to) {
+    public HasNextCondition hasNextCondition(final Bytes binaryKeyFrom, final Bytes binaryKeyTo, final long from, final long to, final boolean forward) {
         return iterator -> {
             while (iterator.hasNext()) {
                 final Bytes bytes = iterator.peekNextKey();
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowKeySchema.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowKeySchema.java
index 5834f94..f263eff 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowKeySchema.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowKeySchema.java
@@ -41,6 +41,9 @@ public class WindowKeySchema implements RocksDBSegmentedBytesStore.KeySchema {
 
     @Override
     public Bytes upperRange(final Bytes key, final long to) {
+        if (key == null) {
+            return null;
+        }
         final byte[] maxSuffix = ByteBuffer.allocate(SUFFIX_SIZE)
             .putLong(to)
             .putInt(Integer.MAX_VALUE)
@@ -51,6 +54,9 @@ public class WindowKeySchema implements RocksDBSegmentedBytesStore.KeySchema {
 
     @Override
     public Bytes lowerRange(final Bytes key, final long from) {
+        if (key == null) { // no key, no bound: return null (presumably treated as open-ended by callers)
+            return null;
+        }
         return OrderedBytes.lowerRange(key, MIN_SUFFIX);
     }
 
@@ -73,7 +79,8 @@ public class WindowKeySchema implements RocksDBSegmentedBytesStore.KeySchema {
     public HasNextCondition hasNextCondition(final Bytes binaryKeyFrom,
                                              final Bytes binaryKeyTo,
                                              final long from,
-                                             final long to) {
+                                             final long to,
+                                             final boolean forward) {
         return iterator -> {
             while (iterator.hasNext()) {
                 final Bytes bytes = iterator.peekNextKey();
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowStoreIteratorWrapper.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowStoreIteratorWrapper.java
index 14acb13..99a68a2 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowStoreIteratorWrapper.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowStoreIteratorWrapper.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.streams.state.internals;
 
+import java.util.function.BiFunction;
+import java.util.function.Function;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.kstream.Windowed;
@@ -26,32 +28,46 @@ class WindowStoreIteratorWrapper {
 
     private final KeyValueIterator<Bytes, byte[]> bytesIterator;
     private final long windowSize;
+    private final Function<byte[], Long> timestampExtractor; // raw store-key bytes -> timestamp
+    private final BiFunction<byte[], Long, Windowed<Bytes>> windowConstructor; // (raw key, windowSize) -> window
 
     WindowStoreIteratorWrapper(final KeyValueIterator<Bytes, byte[]> bytesIterator,
                                final long windowSize) {
+        this(bytesIterator, windowSize, WindowKeySchema::extractStoreTimestamp, WindowKeySchema::fromStoreBytesKey);
+    }
+    // Variant taking explicit key decoders; the two-argument constructor above uses WindowKeySchema's.
+    WindowStoreIteratorWrapper(final KeyValueIterator<Bytes, byte[]> bytesIterator,
+                               final long windowSize,
+                               final Function<byte[], Long> timestampExtractor,
+                               final BiFunction<byte[], Long, Windowed<Bytes>> windowConstructor) {
         this.bytesIterator = bytesIterator;
         this.windowSize = windowSize;
+        this.timestampExtractor = timestampExtractor;
+        this.windowConstructor = windowConstructor;
     }
 
     public WindowStoreIterator<byte[]> valuesIterator() {
-        return new WrappedWindowStoreIterator(bytesIterator);
+        return new WrappedWindowStoreIterator(bytesIterator, timestampExtractor);
     }
 
     public KeyValueIterator<Windowed<Bytes>, byte[]> keyValueIterator() {
-        return new WrappedKeyValueIterator(bytesIterator, windowSize);
+        return new WrappedKeyValueIterator(bytesIterator, windowSize, windowConstructor);
     }
 
     private static class WrappedWindowStoreIterator implements WindowStoreIterator<byte[]> {
         final KeyValueIterator<Bytes, byte[]> bytesIterator;
+        final Function<byte[], Long> timestampExtractor; // raw store-key bytes -> timestamp
 
         WrappedWindowStoreIterator(
-            final KeyValueIterator<Bytes, byte[]> bytesIterator) {
+            final KeyValueIterator<Bytes, byte[]> bytesIterator,
+            final Function<byte[], Long> timestampExtractor) {
             this.bytesIterator = bytesIterator;
+            this.timestampExtractor = timestampExtractor;
         }
 
         @Override
         public Long peekNextKey() {
-            return WindowKeySchema.extractStoreTimestamp(bytesIterator.peekNextKey().get());
+            return timestampExtractor.apply(bytesIterator.peekNextKey().get()); // decoded via the injected extractor
         }
 
         @Override
@@ -62,7 +78,7 @@ class WindowStoreIteratorWrapper {
         @Override
         public KeyValue<Long, byte[]> next() {
             final KeyValue<Bytes, byte[]> next = bytesIterator.next();
-            final long timestamp = WindowKeySchema.extractStoreTimestamp(next.key.get());
+            final long timestamp = timestampExtractor.apply(next.key.get()); // unboxes: a null result would NPE
             return KeyValue.pair(timestamp, next.value);
         }
 
@@ -75,17 +91,20 @@ class WindowStoreIteratorWrapper {
     private static class WrappedKeyValueIterator implements KeyValueIterator<Windowed<Bytes>, byte[]> {
         final KeyValueIterator<Bytes, byte[]> bytesIterator;
         final long windowSize;
+        final BiFunction<byte[], Long, Windowed<Bytes>> windowConstructor; // (raw key, windowSize) -> window
 
         WrappedKeyValueIterator(final KeyValueIterator<Bytes, byte[]> bytesIterator,
-                                final long windowSize) {
+                                final long windowSize,
+                                final BiFunction<byte[], Long, Windowed<Bytes>> windowConstructor) {
             this.bytesIterator = bytesIterator;
             this.windowSize = windowSize;
+            this.windowConstructor = windowConstructor;
         }
 
         @Override
         public Windowed<Bytes> peekNextKey() {
             final byte[] nextKey = bytesIterator.peekNextKey().get();
-            return WindowKeySchema.fromStoreBytesKey(nextKey, windowSize);
+            return windowConstructor.apply(nextKey, windowSize); // windowSize is boxed to Long here
         }
 
         @Override
@@ -96,7 +115,7 @@ class WindowStoreIteratorWrapper {
         @Override
         public KeyValue<Windowed<Bytes>, byte[]> next() {
             final KeyValue<Bytes, byte[]> next = bytesIterator.next();
-            return KeyValue.pair(WindowKeySchema.fromStoreBytesKey(next.key.get(), windowSize), next.value);
+            return KeyValue.pair(windowConstructor.apply(next.key.get(), windowSize), next.value); // boxes windowSize
         }
 
         @Override
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStoreTest.java
new file mode 100644
index 0000000..e8d578d
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStoreTest.java
@@ -0,0 +1,1128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.Metric;
+import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.header.internals.RecordHeader;
+import org.apache.kafka.common.header.internals.RecordHeaders;
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.record.RecordBatch;
+import org.apache.kafka.common.record.TimestampType;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.SystemTime;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.StreamsConfig.InternalConfig;
+import org.apache.kafka.streams.kstream.Window;
+import org.apache.kafka.streams.kstream.Windowed;
+import org.apache.kafka.streams.kstream.internals.TimeWindow;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.ChangelogRecordDeserializationHelper;
+import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
+import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
+import org.apache.kafka.streams.processor.internals.Task.TaskType;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
+import org.apache.kafka.streams.query.Position;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.StateSerdes;
+import org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.KeyFirstWindowKeySchema;
+import org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.TimeFirstWindowKeySchema;
+import org.apache.kafka.streams.state.internals.SegmentedBytesStore.KeySchema;
+import org.apache.kafka.test.InternalMockProcessorContext;
+import org.apache.kafka.test.MockRecordCollector;
+import org.apache.kafka.test.StreamsTestUtils;
+import org.apache.kafka.test.TestUtils;
+import org.hamcrest.Matchers;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.rocksdb.WriteBatch;
+
+import java.io.File;
+import java.text.SimpleDateFormat;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Date;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Properties;
+import java.util.Set;
+import java.util.SimpleTimeZone;
+
+import static java.util.Arrays.asList;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.streams.state.internals.WindowKeySchema.timeWindowForSize;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.hasItem;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.CoreMatchers.nullValue;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.hasEntry;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public abstract class AbstractDualSchemaRocksDBSegmentedBytesStoreTest<S extends Segment> {
+
+    private final long windowSizeForTimeWindow = 500; // window size passed to timeWindowForSize(...) in before()
+    private InternalMockProcessorContext context; // NOTE(review): raw type; before() assigns a diamond-typed instance
+    private AbstractDualSchemaRocksDBSegmentedBytesStore<S> bytesStore;
+    private File stateDir;
+    private final Window[] windows = new Window[4]; // filled in before() only for TimeFirstWindowKeySchema
+    private Window nextSegmentWindow; // set in before() from segmentInterval + retention
+
+    final long retention = 1000;
+    final long segmentInterval = 60_000L;
+    final String storeName = "bytes-store";
+
+    @Before
+    public void before() {
+        if (getBaseSchema() instanceof TimeFirstWindowKeySchema) { // NOTE(review): other base schemas leave windows null
+            windows[0] = timeWindowForSize(10L, windowSizeForTimeWindow);
+            windows[1] = timeWindowForSize(500L, windowSizeForTimeWindow);
+            windows[2] = timeWindowForSize(1_000L, windowSizeForTimeWindow);
+            windows[3] = timeWindowForSize(60_000L, windowSizeForTimeWindow);
+            // Intended: all four windows above share one segment. NOTE(review): if segment ids are
+            // timestamp / segmentInterval, windows[3] (60_000) lands in a later segment than the rest.
+            // nextSegmentWindow is computed to be late enough that writing it to the segment store advances
+            // stream time past the first segment's retention time and expires it.
+            nextSegmentWindow = timeWindowForSize(segmentInterval + retention, windowSizeForTimeWindow);
+        }
+
+        bytesStore = getBytesStore();
+
+        stateDir = TestUtils.tempDirectory();
+        context = new InternalMockProcessorContext<>(
+            stateDir,
+            Serdes.String(),
+            Serdes.Long(),
+            new MockRecordCollector(),
+            new ThreadCache(new LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics()))
+        );
+        bytesStore.init((StateStoreContext) context, bytesStore);
+    }
+
+    @After
+    public void close() {
+        bytesStore.close(); // counterpart to bytesStore.init(...) in before()
+    }
+
+    abstract AbstractDualSchemaRocksDBSegmentedBytesStore<S> getBytesStore(); // store under test, created in before()
+
+    abstract AbstractSegments<S> newSegments();
+
+    abstract KeySchema getBaseSchema(); // also decides whether before() populates the windows
+
+    abstract KeySchema getIndexSchema();
+
+    @Test
+    public void shouldPutAndFetch() {
+        final String keyA = "a";
+        final String keyB = "b";
+        final String keyC = "c";
+        bytesStore.put(serializeKey(new Windowed<>(keyA, windows[0])), serializeValue(10));
+        bytesStore.put(serializeKey(new Windowed<>(keyA, windows[1])), serializeValue(50));
+        bytesStore.put(serializeKey(new Windowed<>(keyB, windows[2])), serializeValue(100));
+        bytesStore.put(serializeKey(new Windowed<>(keyC, windows[3])), serializeValue(200));
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyA.getBytes()), 0, windows[2].start())) { // single key "a"
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L),
+                KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyA.getBytes()), Bytes.wrap(keyB.getBytes()), 0, windows[2].start())) { // keys "a".."b"
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L),
+                KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L),
+                KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            null, Bytes.wrap(keyB.getBytes()), 0, windows[2].start())) { // null keyFrom: no lower key bound
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L),
+                KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L),
+                KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyB.getBytes()), null, 0, windows[3].start())) { // null keyTo: no upper key bound
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L),
+                KeyValue.pair(new Windowed<>(keyC, windows[3]), 200L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            null, null, 0, windows[3].start())) { // both key bounds open
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L),
+                KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L),
+                KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L),
+                KeyValue.pair(new Windowed<>(keyC, windows[3]), 200L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+    }
+
+    @Test
+    public void shouldPutAndBackwardFetch() {
+        final String keyA = "a";
+        final String keyB = "b";
+        final String keyC = "c";
+
+        bytesStore.put(serializeKey(new Windowed<>(keyA, windows[0])), serializeValue(10));
+        bytesStore.put(serializeKey(new Windowed<>(keyA, windows[1])), serializeValue(50));
+        bytesStore.put(serializeKey(new Windowed<>(keyB, windows[2])), serializeValue(100));
+        bytesStore.put(serializeKey(new Windowed<>(keyC, windows[3])), serializeValue(200));
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.backwardFetch(
+            Bytes.wrap(keyA.getBytes()), 0, windows[2].start())) { // single key "a", newest first
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L),
+                KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.backwardFetch(
+            Bytes.wrap(keyA.getBytes()), Bytes.wrap(keyB.getBytes()), 0, windows[2].start())) { // keys "a".."b"
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L),
+                KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L),
+                KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.backwardFetch(
+            null, Bytes.wrap(keyB.getBytes()), 0, windows[2].start())) { // null keyFrom: no lower key bound
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L),
+                KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L),
+                KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.backwardFetch(
+            Bytes.wrap(keyB.getBytes()), null, 0, windows[3].start())) { // null keyTo: no upper key bound
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyC, windows[3]), 200L),
+                KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.backwardFetch(
+            null, null, 0, windows[3].start())) { // both key bounds open
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyC, windows[3]), 200L),
+                KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L),
+                KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L),
+                KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+    }
+
+    @Test
+    public void shouldPutAndFetchWithPrefixKey() {
+        final String keyA = "a";
+        final String keyB = "aa";
+        final String keyC = "aaa";
+
+        final Window maxWindow = new TimeWindow(Long.MAX_VALUE - 1, Long.MAX_VALUE);
+        final Bytes serializedKeyA = serializeKey(new Windowed<>(keyA, maxWindow), false, Integer.MAX_VALUE);
+        final Bytes serializedKeyB = serializeKey(new Windowed<>(keyB, maxWindow), false, Integer.MAX_VALUE);
+        final Bytes serializedKeyC = serializeKey(new Windowed<>(keyC, maxWindow), false, Integer.MAX_VALUE);
+
+        // Keys increase ("a" < "aa" < "aaa") but their serialized base-store keys decrease
+        assertTrue(serializedKeyA.compareTo(serializedKeyB) > 0);
+        assertTrue(serializedKeyB.compareTo(serializedKeyC) > 0);
+
+        bytesStore.put(serializedKeyA, serializeValue(10));
+        bytesStore.put(serializedKeyB, serializeValue(50));
+        bytesStore.put(serializedKeyC, serializeValue(100));
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyA.getBytes()), 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, maxWindow), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyA.getBytes()), Bytes.wrap(keyB.getBytes()), 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, maxWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyA, maxWindow), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        // KeyC must be excluded and KeyA included, even though in storage KeyC sorts before KeyB
+        // and KeyA sorts after KeyB
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            null, Bytes.wrap(keyB.getBytes()), 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, maxWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyA, maxWindow), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        // KeyC must be included even though in storage it sorts before KeyB
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyB.getBytes()), null, 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyC, maxWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyB, maxWindow), 50L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            null, null, 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyC, maxWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyB, maxWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyA, maxWindow), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+    }
+
+    @Test
+    public void shouldPutAndBackwardFetchWithPrefix() {
+        // keyA is a prefix of keyB, which in turn is a prefix of keyC.
+        final String keyA = "a";
+        final String keyB = "aa";
+        final String keyC = "aaa";
+
+        final Window maxWindow = new TimeWindow(Long.MAX_VALUE - 1, Long.MAX_VALUE);
+        final Bytes serializedKeyA = serializeKey(new Windowed<>(keyA, maxWindow), false, Integer.MAX_VALUE);
+        final Bytes serializedKeyB = serializeKey(new Windowed<>(keyB, maxWindow), false, Integer.MAX_VALUE);
+        final Bytes serializedKeyC = serializeKey(new Windowed<>(keyC, maxWindow), false, Integer.MAX_VALUE);
+
+        // The String keys are in increasing order ("a" < "aa" < "aaa"), but their serialized
+        // base-store binary keys are in DECREASING order (presumably because the seqnum bytes
+        // written after the shorter key compare greater than the next key character).
+        assertTrue(serializedKeyA.compareTo(serializedKeyB) > 0);
+        assertTrue(serializedKeyB.compareTo(serializedKeyC) > 0);
+
+        bytesStore.put(serializedKeyA, serializeValue(10));
+        bytesStore.put(serializedKeyB, serializeValue(50));
+        bytesStore.put(serializedKeyC, serializeValue(100));
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.backwardFetch(
+            Bytes.wrap(keyA.getBytes()), 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, maxWindow), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.backwardFetch(
+            Bytes.wrap(keyA.getBytes()), Bytes.wrap(keyB.getBytes()), 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, maxWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyB, maxWindow), 50L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.backwardFetch(
+            null, Bytes.wrap(keyB.getBytes()), 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, maxWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyB, maxWindow), 50L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.backwardFetch(
+            Bytes.wrap(keyB.getBytes()), null, 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, maxWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyC, maxWindow), 100L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.backwardFetch(
+            null, null, 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, maxWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyB, maxWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyC, maxWindow), 100L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+    }
+
+    @Test
+    public void shouldSkipAndRemoveDanglingIndex() {
+        final String keyA = "a";
+        final String keyB = "b";
+
+        // Without an index schema, any write to the index is rejected.
+        if (getIndexSchema() == null) {
+            assertThrows(
+                IllegalStateException.class,
+                () -> bytesStore.putIndex(Bytes.wrap(keyA.getBytes()), new byte[0])
+            );
+            return;
+        }
+
+        // Write an index entry (keyA, windows[1]) that has no matching base-store record.
+        final Bytes danglingIndexKey = serializeKeyForIndex(new Windowed<>(keyA, windows[1]));
+        bytesStore.putIndex(danglingIndexKey, new byte[0]);
+        final byte[] indexValueBeforeFetch = bytesStore.getIndex(danglingIndexKey);
+        assertThat(Bytes.wrap(indexValueBeforeFetch), is(Bytes.wrap(new byte[0])));
+
+        bytesStore.put(serializeKey(new Windowed<>(keyA, windows[0])), serializeValue(10L));
+        bytesStore.put(serializeKey(new Windowed<>(keyB, windows[2])), serializeValue(20L));
+
+        // The range fetch must return only the two real records and skip the dangling index entry.
+        try (final KeyValueIterator<Bytes, byte[]> rangeIterator = bytesStore.fetch(
+            Bytes.wrap(keyA.getBytes()), Bytes.wrap(keyB.getBytes()), 1, 2000)) {
+            assertEquals(
+                asList(
+                    KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L),
+                    KeyValue.pair(new Windowed<>(keyB, windows[2]), 20L)
+                ),
+                toList(rangeIterator)
+            );
+        }
+
+        // After the fetch, the dangling index entry is expected to be gone.
+        final byte[] indexValueAfterFetch = bytesStore.getIndex(danglingIndexKey);
+        assertThat(indexValueAfterFetch, is(nullValue()));
+    }
+
+    @Test
+    public void shouldFindValuesWithinRange() {
+        final String key = "a";
+        // Value stored in windows[i] for the single key.
+        final long[] storedValues = {10, 50, 100};
+        for (int i = 0; i < storedValues.length; i++) {
+            bytesStore.put(serializeKey(new Windowed<>(key, windows[i])), serializeValue(storedValues[i]));
+        }
+
+        // Querying the time range [1, 999] is expected to yield only the first two windows.
+        final List<KeyValue<Windowed<String>, Long>> expected = asList(
+            KeyValue.pair(new Windowed<>(key, windows[0]), 10L),
+            KeyValue.pair(new Windowed<>(key, windows[1]), 50L)
+        );
+        try (final KeyValueIterator<Bytes, byte[]> inRange = bytesStore.fetch(Bytes.wrap(key.getBytes()), 1, 999)) {
+            assertEquals(expected, toList(inRange));
+        }
+    }
+
+    @Test
+    public void shouldRemove() {
+        final String key = "a";
+        final Windowed<String> firstWindowKey = new Windowed<>(key, windows[0]);
+
+        bytesStore.put(serializeKey(firstWindowKey), serializeValue(30));
+        bytesStore.put(serializeKey(new Windowed<>(key, windows[1])), serializeValue(50));
+
+        // After removing the windows[0] record, nothing is expected in the range [0, 100].
+        bytesStore.remove(serializeKey(firstWindowKey));
+        try (final KeyValueIterator<Bytes, byte[]> remaining = bytesStore.fetch(Bytes.wrap(key.getBytes()), 0, 100)) {
+            assertFalse(remaining.hasNext());
+        }
+
+        // When an index schema is configured, the matching index entry must be removed too.
+        if (getIndexSchema() == null) {
+            return;
+        }
+        final byte[] indexValue = bytesStore.getIndex(serializeKeyForIndex(firstWindowKey));
+        assertThat(indexValue, is(nullValue()));
+    }
+
+    @Test
+    public void shouldRollSegments() {
+        // just to validate directories
+        final AbstractSegments<S> segments = newSegments();
+        try {
+            final String key = "a";
+
+            bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50));
+            bytesStore.put(serializeKey(new Windowed<>(key, windows[1])), serializeValue(100));
+            bytesStore.put(serializeKey(new Windowed<>(key, windows[2])), serializeValue(500));
+            assertEquals(Collections.singleton(segments.segmentName(0)), segmentDirs());
+
+            // windows[3] lands in a second segment.
+            bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(1000));
+            assertEquals(Utils.mkSet(segments.segmentName(0), segments.segmentName(1)), segmentDirs());
+
+            // The fetched iterator must be closed, so it is consumed inside try-with-resources.
+            try (final KeyValueIterator<Bytes, byte[]> iterator = bytesStore.fetch(Bytes.wrap(key.getBytes()), 0, 1500)) {
+                assertEquals(
+                    asList(
+                        KeyValue.pair(new Windowed<>(key, windows[0]), 50L),
+                        KeyValue.pair(new Windowed<>(key, windows[1]), 100L),
+                        KeyValue.pair(new Windowed<>(key, windows[2]), 500L)
+                    ),
+                    toList(iterator)
+                );
+            }
+        } finally {
+            // Release the helper segments even when an assertion fails.
+            segments.close();
+        }
+    }
+
+    @Test
+    public void shouldGetAllSegments() {
+        // just to validate directories
+        final AbstractSegments<S> segments = newSegments();
+        try {
+            final String keyA = "a";
+            final String keyB = "b";
+
+            bytesStore.put(serializeKey(new Windowed<>(keyA, windows[0])), serializeValue(50L));
+            assertEquals(Collections.singleton(segments.segmentName(0)), segmentDirs());
+
+            bytesStore.put(serializeKey(new Windowed<>(keyB, windows[3])), serializeValue(100L));
+            assertEquals(
+                Utils.mkSet(
+                    segments.segmentName(0),
+                    segments.segmentName(1)
+                ),
+                segmentDirs()
+            );
+
+            // all() returns an iterator that must be closed; consume it inside try-with-resources.
+            try (final KeyValueIterator<Bytes, byte[]> iterator = bytesStore.all()) {
+                assertEquals(
+                    asList(
+                        KeyValue.pair(new Windowed<>(keyA, windows[0]), 50L),
+                        KeyValue.pair(new Windowed<>(keyB, windows[3]), 100L)
+                    ),
+                    toList(iterator)
+                );
+            }
+        } finally {
+            // Release the helper segments even when an assertion fails.
+            segments.close();
+        }
+    }
+
+    @Test
+    public void shouldGetAllBackwards() {
+        // just to validate directories
+        final AbstractSegments<S> segments = newSegments();
+        try {
+            final String keyA = "a";
+            final String keyB = "b";
+
+            bytesStore.put(serializeKey(new Windowed<>(keyA, windows[0])), serializeValue(50L));
+            assertEquals(Collections.singleton(segments.segmentName(0)), segmentDirs());
+
+            bytesStore.put(serializeKey(new Windowed<>(keyB, windows[3])), serializeValue(100L));
+            assertEquals(
+                Utils.mkSet(
+                    segments.segmentName(0),
+                    segments.segmentName(1)
+                ),
+                segmentDirs()
+            );
+
+            // backwardAll() returns an iterator that must be closed; consume it inside try-with-resources.
+            try (final KeyValueIterator<Bytes, byte[]> iterator = bytesStore.backwardAll()) {
+                assertEquals(
+                    asList(
+                        KeyValue.pair(new Windowed<>(keyB, windows[3]), 100L),
+                        KeyValue.pair(new Windowed<>(keyA, windows[0]), 50L)
+                    ),
+                    toList(iterator)
+                );
+            }
+        } finally {
+            // Release the helper segments even when an assertion fails.
+            segments.close();
+        }
+    }
+
+    @Test
+    public void shouldFetchAllSegments() {
+        // just to validate directories
+        final AbstractSegments<S> segments = newSegments();
+        try {
+            final String key = "a";
+
+            bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
+            assertEquals(Collections.singleton(segments.segmentName(0)), segmentDirs());
+
+            bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L));
+            assertEquals(
+                Utils.mkSet(
+                    segments.segmentName(0),
+                    segments.segmentName(1)
+                ),
+                segmentDirs()
+            );
+
+            // fetchAll() returns an iterator that must be closed; consume it inside try-with-resources.
+            try (final KeyValueIterator<Bytes, byte[]> iterator = bytesStore.fetchAll(0L, 60_000L)) {
+                assertEquals(
+                    asList(
+                        KeyValue.pair(new Windowed<>(key, windows[0]), 50L),
+                        KeyValue.pair(new Windowed<>(key, windows[3]), 100L)
+                    ),
+                    toList(iterator)
+                );
+            }
+        } finally {
+            // Release the helper segments even when an assertion fails.
+            segments.close();
+        }
+    }
+
+    @Test
+    public void shouldLoadSegmentsWithOldStyleDateFormattedName() {
+        final AbstractSegments<S> segments = newSegments();
+        try {
+            final String key = "a";
+
+            bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
+            bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L));
+            bytesStore.close();
+
+            // Rename the first segment directory from "<name>.<id>" to the legacy
+            // "<name>-<yyyyMMddHHmm>" form (UTC), derived from id * segmentInterval.
+            final String firstSegmentName = segments.segmentName(0);
+            final String[] nameParts = firstSegmentName.split("\\.");
+            final long segmentId = Long.parseLong(nameParts[1]);
+            final SimpleDateFormat formatter = new SimpleDateFormat("yyyyMMddHHmm");
+            formatter.setTimeZone(new SimpleTimeZone(0, "UTC"));
+            final String formatted = formatter.format(new Date(segmentId * segmentInterval));
+            final File parent = new File(stateDir, storeName);
+            final File oldStyleName = new File(parent, nameParts[0] + "-" + formatted);
+            assertTrue(new File(parent, firstSegmentName).renameTo(oldStyleName));
+
+            // A freshly opened store must still find both records.
+            bytesStore = getBytesStore();
+            bytesStore.init((StateStoreContext) context, bytesStore);
+            try (final KeyValueIterator<Bytes, byte[]> iterator =
+                     bytesStore.fetch(Bytes.wrap(key.getBytes()), 0L, 60_000L)) {
+                assertThat(
+                    toList(iterator),
+                    equalTo(
+                        asList(
+                            KeyValue.pair(new Windowed<>(key, windows[0]), 50L),
+                            KeyValue.pair(new Windowed<>(key, windows[3]), 100L)
+                        )
+                    )
+                );
+            }
+        } finally {
+            // Release the helper segments even when an assertion fails.
+            segments.close();
+        }
+    }
+
+    @Test
+    public void shouldLoadSegmentsWithOldStyleColonFormattedName() {
+        final AbstractSegments<S> segments = newSegments();
+        try {
+            final String key = "a";
+
+            bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
+            bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L));
+            bytesStore.close();
+
+            // Rename the first segment directory from "<name>.<id>" to the legacy "<name>:<id>" form.
+            final String firstSegmentName = segments.segmentName(0);
+            final String[] nameParts = firstSegmentName.split("\\.");
+            final File parent = new File(stateDir, storeName);
+            final File oldStyleName = new File(parent, nameParts[0] + ":" + Long.parseLong(nameParts[1]));
+            assertTrue(new File(parent, firstSegmentName).renameTo(oldStyleName));
+
+            // A freshly opened store must still find both records.
+            bytesStore = getBytesStore();
+            bytesStore.init((StateStoreContext) context, bytesStore);
+            try (final KeyValueIterator<Bytes, byte[]> iterator =
+                     bytesStore.fetch(Bytes.wrap(key.getBytes()), 0L, 60_000L)) {
+                assertThat(
+                    toList(iterator),
+                    equalTo(
+                        asList(
+                            KeyValue.pair(new Windowed<>(key, windows[0]), 50L),
+                            KeyValue.pair(new Windowed<>(key, windows[3]), 100L)
+                        )
+                    )
+                );
+            }
+        } finally {
+            // Release the helper segments even when an assertion fails.
+            segments.close();
+        }
+    }
+
+    @Test
+    public void shouldBeAbleToWriteToReInitializedStore() {
+        final String key = "a";
+        final Windowed<String> firstWindowKey = new Windowed<>(key, windows[0]);
+        final Windowed<String> secondWindowKey = new Windowed<>(key, windows[1]);
+
+        // Write once so a segment exists before the store is closed and re-initialized.
+        bytesStore.put(serializeKey(firstWindowKey), serializeValue(50));
+
+        bytesStore.close();
+        bytesStore.init((StateStoreContext) context, bytesStore);
+
+        // Writing to the re-initialized store must succeed.
+        bytesStore.put(serializeKey(secondWindowKey), serializeValue(100));
+    }
+
+    @Test
+    public void shouldCreateWriteBatches() {
+        final String key = "a";
+        final byte[] firstChangelogKey = serializeKey(new Windowed<>(key, windows[0]), true).get();
+        final byte[] secondChangelogKey = serializeKey(new Windowed<>(key, windows[3]), true).get();
+
+        final Collection<ConsumerRecord<byte[], byte[]>> records = new ArrayList<>();
+        records.add(new ConsumerRecord<>("", 0, 0L, firstChangelogKey, serializeValue(50L)));
+        records.add(new ConsumerRecord<>("", 0, 0L, secondChangelogKey, serializeValue(100L)));
+
+        // One write batch per segment; the two records are expected to land in two segments.
+        final Map<S, WriteBatch> batchesBySegment = bytesStore.getWriteBatches(records);
+        assertEquals(2, batchesBySegment.size());
+
+        // Each batch holds the base record, plus one index record when an index schema is configured.
+        final int recordsPerBatch = getIndexSchema() != null ? 2 : 1;
+        batchesBySegment.values().forEach(batch -> assertEquals(recordsPerBatch, batch.count()));
+    }
+
+    @Test
+    public void shouldRestoreToByteStoreForActiveTask() {
+        // Unlike the standby variant, the shared context is used as-is.
+        final TaskType activeTask = TaskType.ACTIVE;
+        shouldRestoreToByteStore(activeTask);
+    }
+
+    @Test
+    public void shouldRestoreToByteStoreForStandbyTask() {
+        // Switch the shared context to standby before running the common restore checks.
+        context.transitionToStandby(null);
+        final TaskType standbyTask = TaskType.STANDBY;
+        shouldRestoreToByteStore(standbyTask);
+    }
+
+    private void shouldRestoreToByteStore(final TaskType taskType) {
+        bytesStore.init((StateStoreContext) context, bytesStore);
+        // 0 segments initially.
+        assertEquals(0, bytesStore.getSegments().size());
+        final String key = "a";
+        final Collection<ConsumerRecord<byte[], byte[]>> records = new ArrayList<>();
+        records.add(new ConsumerRecord<>("", 0, 0L, serializeKey(new Windowed<>(key, windows[0]), true).get(), serializeValue(50L)));
+        records.add(new ConsumerRecord<>("", 0, 0L, serializeKey(new Windowed<>(key, windows[3]), true).get(), serializeValue(100L)));
+        bytesStore.restoreAllInternal(records);
+
+        // 2 segments are created during restoration.
+        assertEquals(2, bytesStore.getSegments().size());
+
+        final List<KeyValue<Windowed<String>, Long>> expected = new ArrayList<>();
+        expected.add(new KeyValue<>(new Windowed<>(key, windows[0]), 50L));
+        expected.add(new KeyValue<>(new Windowed<>(key, windows[3]), 100L));
+
+        final List<KeyValue<Windowed<String>, Long>> results = toList(bytesStore.all());
+        assertEquals(expected, results);
+    }
+
+    @Test
+    public void shouldMatchPositionAfterPut() {
+        bytesStore.init((StateStoreContext) context, bytesStore);
+
+        final String keyA = "a";
+        final String keyB = "b";
+        final String keyC = "c";
+
+        // Row i is written to windows[i] with record-context offset i + 1.
+        final String[] keys = {keyA, keyA, keyB, keyC};
+        final long[] values = {10, 50, 100, 200};
+        for (int i = 0; i < keys.length; i++) {
+            context.setRecordContext(new ProcessorRecordContext(0, i + 1, 0, "", new RecordHeaders()));
+            bytesStore.put(serializeKey(new Windowed<>(keys[i], windows[i])), serializeValue(values[i]));
+        }
+
+        // Only the latest offset written for topic "" / partition 0 is expected in the position.
+        final Position expected = Position.fromMap(mkMap(mkEntry("", mkMap(mkEntry(0, 4L)))));
+        assertEquals(expected, bytesStore.getPosition());
+    }
+
+    @Test
+    public void shouldRestoreRecordsAndConsistencyVectorSingleTopic() {
+        final Properties props = StreamsTestUtils.getStreamsConfig();
+        props.put(InternalConfig.IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED, true);
+        final File dir = TestUtils.tempDirectory();
+        context = new InternalMockProcessorContext<>(
+                dir,
+                Serdes.String(),
+                Serdes.String(),
+                new StreamsMetricsImpl(new Metrics(), "mock", StreamsConfig.METRICS_LATEST, new MockTime()),
+                new StreamsConfig(props),
+                MockRecordCollector::new,
+                new ThreadCache(new LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics())),
+                Time.SYSTEM
+        );
+        bytesStore = getBytesStore();
+        bytesStore.init((StateStoreContext) context, bytesStore);
+        // 0 segments initially.
+        assertEquals(0, bytesStore.getSegments().size());
+
+        bytesStore.restoreAllInternal(getChangelogRecords());
+        // 2 segments are created during restoration.
+        assertEquals(2, bytesStore.getSegments().size());
+
+        final String key = "a";
+        final List<KeyValue<Windowed<String>, Long>> expected = new ArrayList<>();
+        expected.add(new KeyValue<>(new Windowed<>(key, windows[0]), 50L));
+        expected.add(new KeyValue<>(new Windowed<>(key, windows[2]), 100L));
+        expected.add(new KeyValue<>(new Windowed<>(key, windows[3]), 200L));
+
+        // all() returns an iterator that must be closed; consume it inside try-with-resources.
+        try (final KeyValueIterator<Bytes, byte[]> iterator = bytesStore.all()) {
+            assertEquals(expected, toList(iterator));
+        }
+        assertThat(bytesStore.getPosition(), Matchers.notNullValue());
+        assertThat(bytesStore.getPosition().getPartitionPositions(""), Matchers.notNullValue());
+        assertThat(bytesStore.getPosition().getPartitionPositions(""), hasEntry(0, 3L));
+    }
+
+    @Test
+    public void shouldRestoreRecordsAndConsistencyVectorMultipleTopics() {
+        final Properties props = StreamsTestUtils.getStreamsConfig();
+        props.put(InternalConfig.IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED, true);
+        final File dir = TestUtils.tempDirectory();
+        context = new InternalMockProcessorContext<>(
+                dir,
+                Serdes.String(),
+                Serdes.String(),
+                new StreamsMetricsImpl(new Metrics(), "mock", StreamsConfig.METRICS_LATEST, new MockTime()),
+                new StreamsConfig(props),
+                MockRecordCollector::new,
+                new ThreadCache(new LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics())),
+                Time.SYSTEM
+        );
+        bytesStore = getBytesStore();
+        bytesStore.init((StateStoreContext) context, bytesStore);
+        // 0 segments initially.
+        assertEquals(0, bytesStore.getSegments().size());
+
+        bytesStore.restoreAllInternal(getChangelogRecordsMultipleTopics());
+        // 2 segments are created during restoration.
+        assertEquals(2, bytesStore.getSegments().size());
+
+        final String key = "a";
+        final List<KeyValue<Windowed<String>, Long>> expected = new ArrayList<>();
+        expected.add(new KeyValue<>(new Windowed<>(key, windows[0]), 50L));
+        expected.add(new KeyValue<>(new Windowed<>(key, windows[2]), 100L));
+        expected.add(new KeyValue<>(new Windowed<>(key, windows[3]), 200L));
+
+        // all() returns an iterator that must be closed; consume it inside try-with-resources.
+        try (final KeyValueIterator<Bytes, byte[]> iterator = bytesStore.all()) {
+            assertEquals(expected, toList(iterator));
+        }
+        assertThat(bytesStore.getPosition(), Matchers.notNullValue());
+        assertThat(bytesStore.getPosition().getPartitionPositions("A"), Matchers.notNullValue());
+        assertThat(bytesStore.getPosition().getPartitionPositions("A"), hasEntry(0, 3L));
+        assertThat(bytesStore.getPosition().getPartitionPositions("B"), Matchers.notNullValue());
+        assertThat(bytesStore.getPosition().getPartitionPositions("B"), hasEntry(0, 2L));
+    }
+
+    @Test
+    public void shouldHandleTombstoneRecords() {
+        final Properties props = StreamsTestUtils.getStreamsConfig();
+        props.put(InternalConfig.IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED, true);
+        final File dir = TestUtils.tempDirectory();
+        context = new InternalMockProcessorContext<>(
+                dir,
+                Serdes.String(),
+                Serdes.String(),
+                new StreamsMetricsImpl(new Metrics(), "mock", StreamsConfig.METRICS_LATEST, new MockTime()),
+                new StreamsConfig(props),
+                MockRecordCollector::new,
+                new ThreadCache(new LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics())),
+                Time.SYSTEM
+        );
+        bytesStore = getBytesStore();
+        bytesStore.init((StateStoreContext) context, bytesStore);
+        // 0 segments initially.
+        assertEquals(0, bytesStore.getSegments().size());
+
+        bytesStore.restoreAllInternal(getChangelogRecordsWithTombstones());
+        // 1 segment is created during restoration.
+        assertEquals(1, bytesStore.getSegments().size());
+        final String key = "a";
+        final List<KeyValue<Windowed<String>, Long>> expected = new ArrayList<>();
+        expected.add(new KeyValue<>(new Windowed<>(key, windows[0]), 50L));
+
+        // all() returns an iterator that must be closed; consume it inside try-with-resources.
+        try (final KeyValueIterator<Bytes, byte[]> iterator = bytesStore.all()) {
+            assertEquals(expected, toList(iterator));
+        }
+        assertThat(bytesStore.getPosition(), Matchers.notNullValue());
+        assertThat(bytesStore.getPosition().getPartitionPositions("A"), hasEntry(0, 2L));
+    }
+
+    @Test
+    public void shouldNotThrowWhenRestoringOnMissingHeaders() {
+        final Properties config = StreamsTestUtils.getStreamsConfig();
+        config.put(InternalConfig.IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED, true);
+        final File stateDirectory = TestUtils.tempDirectory();
+        final StreamsMetricsImpl streamsMetrics =
+            new StreamsMetricsImpl(new Metrics(), "mock", StreamsConfig.METRICS_LATEST, new MockTime());
+        final StreamsConfig streamsConfig = new StreamsConfig(config);
+        final ThreadCache cache =
+            new ThreadCache(new LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics()));
+        context = new InternalMockProcessorContext<>(
+                stateDirectory,
+                Serdes.String(),
+                Serdes.String(),
+                streamsMetrics,
+                streamsConfig,
+                MockRecordCollector::new,
+                cache,
+                Time.SYSTEM
+        );
+
+        bytesStore = getBytesStore();
+        bytesStore.init((StateStoreContext) context, bytesStore);
+
+        // Records without position headers must restore cleanly and leave the position empty.
+        bytesStore.restoreAllInternal(getChangelogRecordsWithoutHeaders());
+        assertThat(bytesStore.getPosition(), is(Position.emptyPosition()));
+    }
+
+    /**
+     * Three changelog records for key "a" (windows 0, 2, 3) on topic "" partition 0, carrying
+     * positions 1, 2 and 3 respectively.
+     */
+    private List<ConsumerRecord<byte[], byte[]>> getChangelogRecords() {
+        final List<ConsumerRecord<byte[], byte[]>> records = new ArrayList<>();
+
+        Position position1 = Position.emptyPosition();
+        position1 = position1.withComponent("", 0, 1);
+        records.add(changelogRecordWithPosition(windows[0], serializeValue(50L), position1));
+
+        position1 = position1.withComponent("", 0, 2);
+        records.add(changelogRecordWithPosition(windows[2], serializeValue(100L), position1));
+
+        position1 = position1.withComponent("", 0, 3);
+        records.add(changelogRecordWithPosition(windows[3], serializeValue(200L), position1));
+
+        return records;
+    }
+
+    /**
+     * Three changelog records for key "a" (windows 0, 2, 3) whose positions accumulate across
+     * topics: {A:1}, then {A:1, B:2}, then {A:3, B:2}.
+     */
+    private List<ConsumerRecord<byte[], byte[]>> getChangelogRecordsMultipleTopics() {
+        final List<ConsumerRecord<byte[], byte[]>> records = new ArrayList<>();
+        Position position1 = Position.emptyPosition();
+
+        position1 = position1.withComponent("A", 0, 1);
+        records.add(changelogRecordWithPosition(windows[0], serializeValue(50L), position1));
+
+        position1 = position1.withComponent("B", 0, 2);
+        records.add(changelogRecordWithPosition(windows[2], serializeValue(100L), position1));
+
+        position1 = position1.withComponent("A", 0, 3);
+        records.add(changelogRecordWithPosition(windows[3], serializeValue(200L), position1));
+
+        return records;
+    }
+
+    /**
+     * A value record for key "a" in windows[0] (position A:1) followed by a tombstone
+     * (null value) for key "a" in windows[2] (position A:2).
+     */
+    private List<ConsumerRecord<byte[], byte[]>> getChangelogRecordsWithTombstones() {
+        final List<ConsumerRecord<byte[], byte[]>> records = new ArrayList<>();
+        Position position = Position.emptyPosition();
+
+        position = position.withComponent("A", 0, 1);
+        records.add(changelogRecordWithPosition(windows[0], serializeValue(50L), position));
+
+        position = position.withComponent("A", 0, 2);
+        records.add(changelogRecordWithPosition(windows[2], null, position));
+
+        return records;
+    }
+
+    /**
+     * A single changelog record (key "a", windows[2]) without any headers. Its key uses the
+     * changelog key format, like every other changelog record built in this test.
+     */
+    private List<ConsumerRecord<byte[], byte[]>> getChangelogRecordsWithoutHeaders() {
+        final List<ConsumerRecord<byte[], byte[]>> records = new ArrayList<>();
+        records.add(new ConsumerRecord<>("", 0, 0L, serializeKey(new Windowed<>("a", windows[2]), true).get(), serializeValue(50L)));
+        return records;
+    }
+
+    /**
+     * Builds a changelog record for key "a" in {@code window} whose headers carry the consistency
+     * version header and the serialized {@code position}.
+     * <p>
+     * Every call creates a fresh {@link RecordHeaders}, so records never share one mutable headers
+     * instance and each record keeps the position it was built with.
+     *
+     * @param window   window of the record key
+     * @param value    serialized value, or {@code null} for a tombstone
+     * @param position position to encode in the position header
+     * @return the changelog record
+     */
+    private ConsumerRecord<byte[], byte[]> changelogRecordWithPosition(final Window window,
+                                                                      final byte[] value,
+                                                                      final Position position) {
+        final Headers headers = new RecordHeaders();
+        headers.add(ChangelogRecordDeserializationHelper.CHANGELOG_VERSION_HEADER_RECORD_CONSISTENCY);
+        headers.add(new RecordHeader(
+                ChangelogRecordDeserializationHelper.CHANGELOG_POSITION_HEADER_KEY,
+                PositionSerde.serialize(position).array())
+        );
+        return new ConsumerRecord<>("", 0, 0L,  RecordBatch.NO_TIMESTAMP, TimestampType.NO_TIMESTAMP_TYPE, -1, -1,
+                serializeKey(new Windowed<>("a", window), true).get(), value, headers, Optional.empty());
+    }
+
+
+
+    @Test
+    public void shouldLogAndMeasureExpiredRecords() {
+        final Properties streamsConfig = StreamsTestUtils.getStreamsConfig();
+        // Use a dedicated store and context instead of the shared bytesStore/context fields.
+        final AbstractDualSchemaRocksDBSegmentedBytesStore<S> store = getBytesStore();
+        final InternalMockProcessorContext<?, ?> expiryContext = new InternalMockProcessorContext<>(
+            TestUtils.tempDirectory(),
+            new StreamsConfig(streamsConfig)
+        );
+        final Time time = new SystemTime();
+        expiryContext.setSystemTimeMs(time.milliseconds());
+        store.init((StateStoreContext) expiryContext, store);
+
+        try {
+            try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister()) {
+                // write a record to advance stream time, with a high enough timestamp
+                // that the subsequent record in windows[0] will already be expired.
+                store.put(serializeKey(new Windowed<>("dummy", nextSegmentWindow)), serializeValue(0));
+
+                final Bytes key = serializeKey(new Windowed<>("a", windows[0]));
+                final byte[] value = serializeValue(5);
+                store.put(key, value);
+
+                final List<String> messages = appender.getMessages();
+                assertThat(messages, hasItem("Skipping record for expired segment."));
+            }
+
+            // Both drop metrics share the same task-level tags.
+            final Map<MetricName, ? extends Metric> metrics = expiryContext.metrics().metrics();
+            final Map<String, String> taskTags = mkMap(
+                mkEntry("thread-id", Thread.currentThread().getName()),
+                mkEntry("task-id", "0_0")
+            );
+            final Metric dropTotal = metrics.get(
+                new MetricName("dropped-records-total", "stream-task-metrics", "", taskTags));
+            final Metric dropRate = metrics.get(
+                new MetricName("dropped-records-rate", "stream-task-metrics", "", taskTags));
+
+            assertEquals(1.0, dropTotal.metricValue());
+            assertNotEquals(0.0, dropRate.metricValue());
+        } finally {
+            // Close the dedicated store even if an assertion above fails.
+            store.close();
+        }
+    }
+
+    /**
+     * Returns the names of all entries in the store directory {@code <stateDir>/<storeName>}.
+     * Throws {@link NullPointerException} if that directory cannot be listed.
+     */
+    private Set<String> segmentDirs() {
+        final String[] entries = Objects.requireNonNull(new File(stateDir, storeName).list());
+        return Utils.mkSet(entries);
+    }
+
+    /** Serializes {@code key} in base-store format with seqnum 0. */
+    private Bytes serializeKey(final Windowed<String> key) {
+        return serializeKey(key, false, 0);
+    }
+
+    /** Serializes {@code key} with seqnum 0, in changelog format if {@code changeLog} is set. */
+    private Bytes serializeKey(final Windowed<String> key, final boolean changeLog) {
+        final int defaultSeqnum = 0;
+        return serializeKey(key, changeLog, defaultSeqnum);
+    }
+
+    /**
+     * Serializes {@code key} with the given seqnum.
+     * <p>
+     * Changelog keys use {@link WindowKeySchema}. Base-store keys use {@code TimeFirstWindowKeySchema}
+     * and are only supported when that is the configured base schema.
+     *
+     * @throws IllegalStateException for a base-store key under any other base schema
+     */
+    private Bytes serializeKey(final Windowed<String> key, final boolean changeLog, final int seq) {
+        final StateSerdes<String, Long> serdes = StateSerdes.withBuiltinTypes("dummy", String.class, Long.class);
+        if (changeLog) {
+            return WindowKeySchema.toStoreKeyBinary(key, seq, serdes);
+        }
+        if (!(getBaseSchema() instanceof TimeFirstWindowKeySchema)) {
+            throw new IllegalStateException("Unrecognized serde schema");
+        }
+        return TimeFirstWindowKeySchema.toStoreKeyBinary(key, seq, serdes);
+    }
+
+    /**
+     * Serializes {@code key} (seqnum 0) in the key-first index format.
+     *
+     * @throws IllegalStateException if the index schema is not {@code KeyFirstWindowKeySchema}
+     */
+    private Bytes serializeKeyForIndex(final Windowed<String> key) {
+        if (!(getIndexSchema() instanceof KeyFirstWindowKeySchema)) {
+            throw new IllegalStateException("Unrecognized serde schema");
+        }
+        final StateSerdes<String, Long> serdes = StateSerdes.withBuiltinTypes("dummy", String.class, Long.class);
+        return KeyFirstWindowKeySchema.toStoreKeyBinary(key, 0, serdes);
+    }
+
+    private byte[] serializeValue(final long value) {
+        return Serdes.Long().serializer().serialize("", value);
+    }
+
+    /**
+     * Drains {@code iterator}, decoding each base-store entry into a windowed String key and a
+     * Long value. The iterator is not closed here; callers remain responsible for closing it.
+     *
+     * @throws IllegalStateException when an entry is read and the base schema is not
+     *                               {@code TimeFirstWindowKeySchema}
+     */
+    private List<KeyValue<Windowed<String>, Long>> toList(final KeyValueIterator<Bytes, byte[]> iterator) {
+        final List<KeyValue<Windowed<String>, Long>> decoded = new ArrayList<>();
+        final StateSerdes<String, Long> serdes = StateSerdes.withBuiltinTypes("dummy", String.class, Long.class);
+        while (iterator.hasNext()) {
+            final KeyValue<Bytes, byte[]> entry = iterator.next();
+            if (!(getBaseSchema() instanceof TimeFirstWindowKeySchema)) {
+                throw new IllegalStateException("Unrecognized serde schema");
+            }
+            final Windowed<String> windowedKey = TimeFirstWindowKeySchema.fromStoreKey(
+                entry.key.get(),
+                windowSizeForTimeWindow,
+                serdes.keyDeserializer(),
+                serdes.topic()
+            );
+            final Long value = serdes.valueDeserializer().deserialize("dummy", entry.value);
+            decoded.add(KeyValue.pair(windowedKey, value));
+        }
+        return decoded;
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStoreTest.java
new file mode 100644
index 0000000..0d5b016
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStoreTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import static java.util.Arrays.asList;
+
+import java.util.Collection;
+import org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.KeyFirstWindowKeySchema;
+import org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.TimeFirstWindowKeySchema;
+import org.apache.kafka.streams.state.internals.SegmentedBytesStore.KeySchema;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+
+/**
+ * Runs the {@link AbstractDualSchemaRocksDBSegmentedBytesStoreTest} suite against
+ * {@link RocksDBTimeOrderedSegmentedBytesStore}: always with a time-first base schema,
+ * and once each with and without a key-first index schema.
+ */
+@RunWith(Parameterized.class)
+public class RocksDBTimeOrderedSegmentedBytesStoreTest
+    extends AbstractDualSchemaRocksDBSegmentedBytesStoreTest<KeyValueSegment> {
+
+    private static final String METRICS_SCOPE = "metrics-scope";
+
+    // Display name only; rendered by the "{0}" pattern in @Parameters.
+    @Parameter
+    public String name;
+
+    // Passed to the store and selects whether getIndexSchema() returns a key-first schema or null.
+    @Parameter(1)
+    public boolean hasIndex;
+
+    @Parameterized.Parameters(name = "{0}")
+    public static Collection<Object[]> getKeySchema() {
+        return asList(new Object[][] {
+            {"WindowSchemaWithIndex", true},
+            {"WindowSchemaWithoutIndex", false}
+        });
+    }
+
+    @Override
+    AbstractDualSchemaRocksDBSegmentedBytesStore<KeyValueSegment> getBytesStore() {
+        return new RocksDBTimeOrderedSegmentedBytesStore(
+            storeName,
+            METRICS_SCOPE,
+            retention,
+            segmentInterval,
+            hasIndex
+        );
+    }
+
+    @Override
+    KeyValueSegments newSegments() {
+        return new KeyValueSegments(storeName, METRICS_SCOPE, retention, segmentInterval);
+    }
+
+    @Override
+    KeySchema getBaseSchema() {
+        return new TimeFirstWindowKeySchema();
+    }
+
+    @Override
+    KeySchema getIndexSchema() {
+        return hasIndex ? new KeyFirstWindowKeySchema() : null;
+    }
+
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
index 9390174..5abfd06 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
@@ -18,6 +18,7 @@ package org.apache.kafka.streams.state.internals;
 
 import java.io.File;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.HashSet;
@@ -37,6 +38,9 @@ import org.apache.kafka.streams.state.Stores;
 import org.apache.kafka.streams.state.WindowStore;
 import org.apache.kafka.streams.state.WindowStoreIterator;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
 
 import static java.time.Duration.ofMillis;
 import static java.time.Instant.ofEpochMilli;
@@ -48,6 +52,7 @@ import static org.apache.kafka.test.StreamsTestUtils.valuesToSet;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 
+@RunWith(Parameterized.class)
 public class RocksDBWindowStoreTest extends AbstractWindowBytesStoreTest {
 
     private static final String STORE_NAME = "rocksDB window store";
@@ -56,21 +61,56 @@ public class RocksDBWindowStoreTest extends AbstractWindowBytesStoreTest {
     private final KeyValueSegments segments =
         new KeyValueSegments(STORE_NAME, METRICS_SCOPE, RETENTION_PERIOD, SEGMENT_INTERVAL);
 
+    // Store implementations that buildWindowStore() can construct.
+    enum StoreType {
+        RocksDBWindowStore,
+        RocksDBTimeOrderedWindowStoreWithIndex,
+        RocksDBTimeOrderedWindowStoreWithoutIndex
+    }
+
+    // Display name only; rendered by the "{0}" pattern in @Parameters.
+    @Parameter
+    public String name;
+
+    @Parameter(1)
+    public StoreType storeType;
+
+    @Parameterized.Parameters(name = "{0}")
+    public static Collection<Object[]> getKeySchema() {
+        return asList(new Object[][] {
+            {"RocksDBWindowStore", StoreType.RocksDBWindowStore},
+            {"RocksDBTimeOrderedWindowStoreWithIndex", StoreType.RocksDBTimeOrderedWindowStoreWithIndex},
+            {"RocksDBTimeOrderedWindowStoreWithoutIndex", StoreType.RocksDBTimeOrderedWindowStoreWithoutIndex}
+        });
+    }
+
     @Override
     <K, V> WindowStore<K, V> buildWindowStore(final long retentionPeriod,
                                               final long windowSize,
                                               final boolean retainDuplicates,
                                               final Serde<K> keySerde,
                                               final Serde<V> valueSerde) {
-        return Stores.windowStoreBuilder(
-            Stores.persistentWindowStore(
-                STORE_NAME,
-                ofMillis(retentionPeriod),
-                ofMillis(windowSize),
-                retainDuplicates),
-            keySerde,
-            valueSerde)
-            .build();
+        if (storeType == StoreType.RocksDBWindowStore) {
+            return Stores.windowStoreBuilder(
+                    Stores.persistentWindowStore(
+                        STORE_NAME,
+                        ofMillis(retentionPeriod),
+                        ofMillis(windowSize),
+                        retainDuplicates),
+                    keySerde,
+                    valueSerde)
+                .build();
+        }
+
+        // Both time-ordered variants use the same supplier; only its final boolean argument differs.
+        final boolean withIndex = storeType == StoreType.RocksDBTimeOrderedWindowStoreWithIndex;
+        final long defaultSegmentInterval = Math.max(retentionPeriod / 2, 60_000L);
+        return Stores.windowStoreBuilder(
+            new RocksDbIndexedTimeOrderedWindowBytesStoreSupplier(STORE_NAME,
+                retentionPeriod, defaultSegmentInterval, windowSize, retainDuplicates, withIndex),
+            keySerde,
+            valueSerde
+        ).build();
     }
 
     @Test
@@ -646,7 +690,7 @@ public class RocksDBWindowStoreTest extends AbstractWindowBytesStoreTest {
     public void shouldMatchPositionAfterPut() {
         final MeteredWindowStore<Integer, String> meteredSessionStore = (MeteredWindowStore<Integer, String>) windowStore;
         final ChangeLoggingWindowBytesStore changeLoggingSessionBytesStore = (ChangeLoggingWindowBytesStore) meteredSessionStore.wrapped();
-        final RocksDBWindowStore rocksDBWindowStore = (RocksDBWindowStore) changeLoggingSessionBytesStore.wrapped();
+        final WrappedStateStore rocksDBWindowStore = (WrappedStateStore) changeLoggingSessionBytesStore.wrapped();
 
         context.setRecordContext(new ProcessorRecordContext(0, 1, 0, "", new RecordHeaders()));
         windowStore.put(0, "0", SEGMENT_INTERVAL);
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionKeySchemaTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionKeySchemaTest.java
index 40b06c0..0482f01 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionKeySchemaTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionKeySchemaTest.java
@@ -76,21 +76,21 @@ public class SessionKeySchemaTest {
     @Test
     public void shouldFetchExactKeysSkippingLongerKeys() {
         final Bytes key = Bytes.wrap(new byte[]{0});
-        final List<Integer> result = getValues(sessionKeySchema.hasNextCondition(key, key, 0, Long.MAX_VALUE));
+        final List<Integer> result = getValues(sessionKeySchema.hasNextCondition(key, key, 0, Long.MAX_VALUE, true)); // NOTE(review): "true" is the new 5th argument, presumably forward iteration - confirm
         assertThat(result, equalTo(Arrays.asList(2, 4)));
     }
 
     @Test
     public void shouldFetchExactKeySkippingShorterKeys() {
         final Bytes key = Bytes.wrap(new byte[]{0, 0});
-        final HasNextCondition hasNextCondition = sessionKeySchema.hasNextCondition(key, key, 0, Long.MAX_VALUE);
+        final HasNextCondition hasNextCondition = sessionKeySchema.hasNextCondition(key, key, 0, Long.MAX_VALUE, true); // NOTE(review): presumably forward iteration - confirm
         final List<Integer> results = getValues(hasNextCondition);
         assertThat(results, equalTo(Arrays.asList(1, 5)));
     }
 
     @Test
     public void shouldFetchAllKeysUsingNullKeys() {
-        final HasNextCondition hasNextCondition = sessionKeySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE);
+        final HasNextCondition hasNextCondition = sessionKeySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE, true); // null key bounds; NOTE(review): "true" presumably forward iteration - confirm
         final List<Integer> results = getValues(hasNextCondition);
         assertThat(results, equalTo(Arrays.asList(1, 2, 3, 4, 5, 6)));
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowKeySchemaTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowKeySchemaTest.java
index dc88410..e936053 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowKeySchemaTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowKeySchemaTest.java
@@ -17,6 +17,11 @@
 
 package org.apache.kafka.streams.state.internals;
 
+import java.nio.ByteBuffer;
+import java.util.Collection;
+import java.util.Map;
+import java.util.function.BiFunction;
+import java.util.function.Function;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.utils.Bytes;
@@ -27,21 +32,93 @@ import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.kstream.WindowedSerdes;
 import org.apache.kafka.streams.kstream.internals.TimeWindow;
 import org.apache.kafka.streams.state.StateSerdes;
+import org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.KeyFirstWindowKeySchema;
+import org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.TimeFirstWindowKeySchema;
+import org.apache.kafka.streams.state.internals.SegmentedBytesStore.KeySchema;
 import org.apache.kafka.test.KeyValueIteratorStub;
 import org.junit.Test;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
+import static java.util.Arrays.asList;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.IsEqual.equalTo;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 
+@RunWith(Parameterized.class)
 public class WindowKeySchemaTest {
 
+    private static final Map<SchemaType, KeySchema> SCHEMA_TYPE_MAP = mkMap( // KeySchema instance under test for each SchemaType
+        mkEntry(SchemaType.WindowKeySchema, new WindowKeySchema()),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, new KeyFirstWindowKeySchema()),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, new TimeFirstWindowKeySchema())
+    );
+
+    private static final Map<SchemaType, Function<byte[], byte[]>> EXTRACT_STORE_KEY_MAP = mkMap( // this and the maps below: same-named static helper per schema
+        mkEntry(SchemaType.WindowKeySchema, WindowKeySchema::extractStoreKeyBytes),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, KeyFirstWindowKeySchema::extractStoreKeyBytes),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, TimeFirstWindowKeySchema::extractStoreKeyBytes)
+    );
+
+    private static final Map<SchemaType, BiFunction<byte[], Long, Windowed<Bytes>>> FROM_STORAGE_BYTES_KEY = mkMap(
+        mkEntry(SchemaType.WindowKeySchema, WindowKeySchema::fromStoreBytesKey),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, KeyFirstWindowKeySchema::fromStoreBytesKey),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, TimeFirstWindowKeySchema::fromStoreBytesKey)
+    );
+
+    private static final Map<SchemaType, BiFunction<Windowed<Bytes>, Integer, Bytes>> WINDOW_TO_STORE_BINARY_MAP = mkMap(
+        mkEntry(SchemaType.WindowKeySchema, WindowKeySchema::toStoreKeyBinary),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, KeyFirstWindowKeySchema::toStoreKeyBinary),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, TimeFirstWindowKeySchema::toStoreKeyBinary)
+    );
+
+    private static final Map<SchemaType, BiFunction<byte[], Long, Window>> EXTRACT_STORE_WINDOW_MAP = mkMap(
+        mkEntry(SchemaType.WindowKeySchema, WindowKeySchema::extractStoreWindow),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, KeyFirstWindowKeySchema::extractStoreWindow),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, TimeFirstWindowKeySchema::extractStoreWindow)
+    );
+
+    @FunctionalInterface
+    interface TriFunction<A, B, C, R> { // three-argument analogue of BiFunction, used for the 3-arg toStoreKeyBinary overloads
+        R apply(A a, B b, C c);
+    }
+
+    private static final Map<SchemaType, TriFunction<byte[], Long, Integer, Bytes>> BYTES_TO_STORE_BINARY_MAP = mkMap(
+        mkEntry(SchemaType.WindowKeySchema, WindowKeySchema::toStoreKeyBinary),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, KeyFirstWindowKeySchema::toStoreKeyBinary),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, TimeFirstWindowKeySchema::toStoreKeyBinary)
+    );
+
+    private static final Map<SchemaType, TriFunction<Windowed<String>, Integer, StateSerdes<String, byte[]>, Bytes>> SERDE_TO_STORE_BINARY_MAP = mkMap(
+        mkEntry(SchemaType.WindowKeySchema, WindowKeySchema::toStoreKeyBinary),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, KeyFirstWindowKeySchema::toStoreKeyBinary),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, TimeFirstWindowKeySchema::toStoreKeyBinary)
+    );
+
+    private static final Map<SchemaType, Function<byte[], Long>> EXTRACT_TS_MAP = mkMap(
+        mkEntry(SchemaType.WindowKeySchema, WindowKeySchema::extractStoreTimestamp),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, KeyFirstWindowKeySchema::extractStoreTimestamp),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, TimeFirstWindowKeySchema::extractStoreTimestamp)
+    );
+
+    private static final Map<SchemaType, Function<byte[], Integer>> EXTRACT_SEQ_MAP = mkMap(
+        mkEntry(SchemaType.WindowKeySchema, WindowKeySchema::extractStoreSequence),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, KeyFirstWindowKeySchema::extractStoreSequence),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, TimeFirstWindowKeySchema::extractStoreSequence)
+    );
+
+    private static final Map<SchemaType, Function<byte[], byte[]>> FROM_WINDOW_KEY_MAP = mkMap( // no WindowKeySchema entry, so get() yields null for it
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, KeyFirstWindowKeySchema::fromNonPrefixWindowKey),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, TimeFirstWindowKeySchema::fromNonPrefixWindowKey)
+    );
+
     final private String key = "key";
     final private String topic = "topic";
     final private long startTime = 50L;
@@ -50,39 +127,94 @@ public class WindowKeySchemaTest {
 
     final private Window window = new TimeWindow(startTime, endTime);
     final private Windowed<String> windowedKey = new Windowed<>(key, window);
-    final private WindowKeySchema windowKeySchema = new WindowKeySchema();
+    final private KeySchema keySchema; // looked up from SCHEMA_TYPE_MAP in the constructor
     final private Serde<Windowed<String>> keySerde = new WindowedSerdes.TimeWindowedSerde<>(serde, Long.MAX_VALUE);
     final private StateSerdes<String, byte[]> stateSerdes = new StateSerdes<>("dummy", serde, Serdes.ByteArray());
+    final private SchemaType schemaType; // selects which schema's static helpers the accessors below return
+
+    private enum SchemaType {
+        WindowKeySchema,
+        PrefixedTimeFirstSchema,
+        PrefixedKeyFirstSchema
+    }
+
+    @Parameterized.Parameters(name = "{0}")
+    public static Collection<Object[]> data() { // rows of {display name, SchemaType}
+        return asList(new Object[][] {
+            {"WindowKeySchema", SchemaType.WindowKeySchema},
+            {"PrefixedTimeFirstSchema", SchemaType.PrefixedTimeFirstSchema},
+            {"PrefixedKeyFirstSchema", SchemaType.PrefixedKeyFirstSchema}
+        });
+    }
+
+    public WindowKeySchemaTest(final String name, final SchemaType type) { // name is only the "{0}" display label
+        schemaType = type;
+        keySchema = SCHEMA_TYPE_MAP.get(type);
+    }
+
+    private BiFunction<byte[], Long, Windowed<Bytes>> getFromStorageKey() {
+        return FROM_STORAGE_BYTES_KEY.get(schemaType);
+    }
+
+    private BiFunction<byte[], Long, Window> getExtractStoreWindow() {
+        return EXTRACT_STORE_WINDOW_MAP.get(schemaType);
+    }
+
+    private Function<byte[], byte[]> getExtractStorageKey() {
+        return EXTRACT_STORE_KEY_MAP.get(schemaType);
+    }
+
+    private BiFunction<Windowed<Bytes>, Integer, Bytes> getToStoreKeyBinaryWindowParam() {
+        return WINDOW_TO_STORE_BINARY_MAP.get(schemaType);
+    }
+
+    private TriFunction<byte[], Long, Integer, Bytes> getToStoreKeyBinaryBytesParam() {
+        return BYTES_TO_STORE_BINARY_MAP.get(schemaType);
+    }
+
+    private Function<byte[], Long> getExtractTimestampFunc() {
+        return EXTRACT_TS_MAP.get(schemaType);
+    }
+
+    private Function<byte[], Integer> getExtractSeqFunc() {
+        return EXTRACT_SEQ_MAP.get(schemaType);
+    }
+
+    private TriFunction<Windowed<String>, Integer, StateSerdes<String, byte[]>, Bytes> getSerdeToStoreKey() {
+        return SERDE_TO_STORE_BINARY_MAP.get(schemaType);
+    }
 
     @Test
     public void testHasNextConditionUsingNullKeys() {
-        final List<KeyValue<Bytes, Integer>> keys = Arrays.asList(
-            KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0}), new TimeWindow(0, 1)), 0), 1),
-            KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0}), new TimeWindow(0, 1)), 0), 2),
-            KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0, 0}), new TimeWindow(0, 1)), 0), 3),
-            KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0}), new TimeWindow(10, 20)), 4), 4),
-            KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0}), new TimeWindow(10, 20)), 5), 5),
-            KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0, 0}), new TimeWindow(10, 20)), 6), 6));
+        final BiFunction<Windowed<Bytes>, Integer, Bytes> toStoreKeyBinary = getToStoreKeyBinaryWindowParam();
+        final List<KeyValue<Bytes, Integer>> keys = asList(
+            KeyValue.pair(toStoreKeyBinary.apply(new Windowed<>(Bytes.wrap(new byte[] {0, 0}), new TimeWindow(0, 1)), 0), 1),
+            KeyValue.pair(toStoreKeyBinary.apply(new Windowed<>(Bytes.wrap(new byte[] {0}), new TimeWindow(0, 1)), 0), 2),
+            KeyValue.pair(toStoreKeyBinary.apply(new Windowed<>(Bytes.wrap(new byte[] {0, 0, 0}), new TimeWindow(0, 1)), 0), 3),
+            KeyValue.pair(toStoreKeyBinary.apply(new Windowed<>(Bytes.wrap(new byte[] {0}), new TimeWindow(10, 20)), 4), 4),
+            KeyValue.pair(toStoreKeyBinary.apply(new Windowed<>(Bytes.wrap(new byte[] {0, 0}), new TimeWindow(10, 20)), 5), 5),
+            KeyValue.pair(toStoreKeyBinary.apply(new Windowed<>(Bytes.wrap(new byte[] {0, 0, 0}), new TimeWindow(10, 20)), 6), 6));
         try (final DelegatingPeekingKeyValueIterator<Bytes, Integer> iterator = new DelegatingPeekingKeyValueIterator<>("foo", new KeyValueIteratorStub<>(keys.iterator()))) {
 
-            final HasNextCondition hasNextCondition = windowKeySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE);
+            final HasNextCondition hasNextCondition = keySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE, true); // null key bounds over the full time range: all six entries are expected
             final List<Integer> results = new ArrayList<>();
             while (hasNextCondition.hasNext(iterator)) {
                 results.add(iterator.next().value);
             }
 
-            assertThat(results, equalTo(Arrays.asList(1, 2, 3, 4, 5, 6)));
+            assertThat(results, equalTo(asList(1, 2, 3, 4, 5, 6)));
         }
     }
 
     @Test
     public void testUpperBoundWithLargeTimestamps() {
-        final Bytes upper = windowKeySchema.upperRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), Long.MAX_VALUE);
+        final Bytes upper = keySchema.upperRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), Long.MAX_VALUE);
+        final TriFunction<byte[], Long, Integer, Bytes> toStoreKeyBinary = getToStoreKeyBinaryBytesParam();
 
         assertThat(
             "shorter key with max timestamp should be in range",
             upper.compareTo(
-                WindowKeySchema.toStoreKeyBinary(
+                toStoreKeyBinary.apply(
                     new byte[] {0xA},
                     Long.MAX_VALUE,
                     Integer.MAX_VALUE
@@ -93,7 +225,7 @@ public class WindowKeySchemaTest {
         assertThat(
             "shorter key with max timestamp should be in range",
             upper.compareTo(
-                WindowKeySchema.toStoreKeyBinary(
+                toStoreKeyBinary.apply(
                     new byte[] {0xA, 0xB},
                     Long.MAX_VALUE,
                     Integer.MAX_VALUE
@@ -101,17 +233,24 @@ public class WindowKeySchemaTest {
             ) >= 0
         );
 
-        assertThat(upper, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA}, Long.MAX_VALUE, Integer.MAX_VALUE)));
+        if (schemaType == SchemaType.PrefixedTimeFirstSchema) {
+            assertThat(upper, equalTo(
+                toStoreKeyBinary.apply(new byte[]{(byte) 0xFF, (byte) 0xFF, (byte) 0xFF}, Long.MAX_VALUE, Integer.MAX_VALUE)));
+        } else {
+            assertThat(upper, equalTo(
+                toStoreKeyBinary.apply(new byte[]{0xA}, Long.MAX_VALUE, Integer.MAX_VALUE)));
+        }
     }
 
     @Test
     public void testUpperBoundWithKeyBytesLargerThanFirstTimestampByte() {
-        final Bytes upper = windowKeySchema.upperRange(Bytes.wrap(new byte[] {0xA, (byte) 0x8F, (byte) 0x9F}), Long.MAX_VALUE);
+        final Bytes upper = keySchema.upperRange(Bytes.wrap(new byte[] {0xA, (byte) 0x8F, (byte) 0x9F}), Long.MAX_VALUE);
+        final TriFunction<byte[], Long, Integer, Bytes> toStoreKeyBinary = getToStoreKeyBinaryBytesParam();
 
         assertThat(
             "shorter key with max timestamp should be in range",
             upper.compareTo(
-                WindowKeySchema.toStoreKeyBinary(
+                toStoreKeyBinary.apply(
                     new byte[] {0xA, (byte) 0x8F},
                     Long.MAX_VALUE,
                     Integer.MAX_VALUE
@@ -119,62 +258,136 @@ public class WindowKeySchemaTest {
             ) >= 0
         );
 
-        assertThat(upper, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA, (byte) 0x8F, (byte) 0x9F}, Long.MAX_VALUE, Integer.MAX_VALUE)));
+        if (schemaType == SchemaType.PrefixedTimeFirstSchema) {
+            assertThat(upper, equalTo(
+                toStoreKeyBinary.apply(new byte[]{(byte) 0xFF, (byte) 0xFF, (byte) 0xFF}, Long.MAX_VALUE, Integer.MAX_VALUE)));
+        } else {
+            assertThat(upper, equalTo(
+                toStoreKeyBinary.apply(new byte[]{0xA, (byte) 0x8F, (byte) 0x9F}, Long.MAX_VALUE,
+                    Integer.MAX_VALUE)));
+        }
     }
 
 
     @Test
     public void testUpperBoundWithKeyBytesLargerAndSmallerThanFirstTimestampByte() {
-        final Bytes upper = windowKeySchema.upperRange(Bytes.wrap(new byte[] {0xC, 0xC, 0x9}), 0x0AffffffffffffffL);
+        final Bytes upper = keySchema.upperRange(Bytes.wrap(new byte[] {0xC, 0xC, 0x9}), 0x0AffffffffffffffL);
+        final TriFunction<byte[], Long, Integer, Bytes> toStoreKeyBinary = getToStoreKeyBinaryBytesParam();
 
         assertThat(
             "shorter key with customized timestamp should be in range",
             upper.compareTo(
-                WindowKeySchema.toStoreKeyBinary(
+                toStoreKeyBinary.apply(
                     new byte[] {0xC, 0xC},
                     0x0AffffffffffffffL,
                     Integer.MAX_VALUE
                 )
             ) >= 0
         );
-
-        assertThat(upper, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xC, 0xC}, 0x0AffffffffffffffL, Integer.MAX_VALUE)));
+        if (schemaType == SchemaType.PrefixedTimeFirstSchema) { // time-first: expected bound uses 0xFF key bytes instead of the queried key
+            assertThat(upper, equalTo(
+                toStoreKeyBinary.apply(new byte[]{(byte) 0xFF, (byte) 0xFF, (byte) 0xFF}, 0x0AffffffffffffffL, Integer.MAX_VALUE)));
+        } else {
+            assertThat(upper, equalTo(
+                toStoreKeyBinary.apply(new byte[]{0xC, 0xC}, 0x0AffffffffffffffL,
+                    Integer.MAX_VALUE)));
+        }
     }
 
     @Test
     public void testUpperBoundWithZeroTimestamp() {
-        final Bytes upper = windowKeySchema.upperRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), 0);
-        assertThat(upper, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA, 0xB, 0xC}, 0, Integer.MAX_VALUE)));
+        final Bytes upper = keySchema.upperRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), 0);
+        final TriFunction<byte[], Long, Integer, Bytes> toStoreKeyBinary = getToStoreKeyBinaryBytesParam();
+
+        if (schemaType == SchemaType.PrefixedTimeFirstSchema) { // time-first: expected bound uses 0xFF key bytes at timestamp 0
+            assertThat(upper, equalTo(
+                toStoreKeyBinary.apply(new byte[]{(byte) 0xFF, (byte) 0xFF, (byte) 0xFF}, 0x0L, Integer.MAX_VALUE)));
+        } else {
+            assertThat(upper,
+                equalTo(toStoreKeyBinary.apply(new byte[]{0xA, 0xB, 0xC}, 0L, Integer.MAX_VALUE)));
+        }
     }
 
     @Test
     public void testLowerBoundWithZeroTimestamp() {
-        final Bytes lower = windowKeySchema.lowerRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), 0);
-        assertThat(lower, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA, 0xB, 0xC}, 0, 0)));
+        final Bytes lower = keySchema.lowerRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), 0);
+        final TriFunction<byte[], Long, Integer, Bytes> toStoreKeyBinary = getToStoreKeyBinaryBytesParam();
+        assertThat(
+            "Larger key prefix should be in range.",
+            lower.compareTo(
+                toStoreKeyBinary.apply(
+                    new byte[] {0xA, 0xB, 0xC, 0x0},
+                    0L,
+                    0
+                )
+            ) < 0
+        );
+
+        if (schemaType == SchemaType.PrefixedTimeFirstSchema) {
+            final Bytes expected = Bytes.wrap(ByteBuffer.allocate(1 + 8 + 3) // leading 0x0 byte + 8-byte timestamp + 3 key bytes, no sequence suffix
+                .put((byte) 0x0)
+                .putLong(0)
+                .put(new byte[] {0xA, 0xB, 0xC})
+                .array());
+            assertThat(lower, equalTo(expected));
+        } else {
+            assertThat(lower, equalTo(toStoreKeyBinary.apply(new byte[]{0xA, 0xB, 0xC}, 0L, 0)));
+        }
     }
 
     @Test
-    public void testLowerBoundWithMonZeroTimestamp() {
-        final Bytes lower = windowKeySchema.lowerRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), 42);
-        assertThat(lower, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA, 0xB, 0xC}, 0, 0)));
+    public void testLowerBoundWithNonZeroTimestamp() {
+        final Bytes lower = keySchema.lowerRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), 42);
+        final TriFunction<byte[], Long, Integer, Bytes> toStoreKeyBinary = getToStoreKeyBinaryBytesParam();
+
+        assertThat(
+            "Larger timestamp should be in range",
+            lower.compareTo(
+                toStoreKeyBinary.apply(
+                    new byte[] {0xA, 0xB, 0xC, 0x0},
+                    43L,
+                    0
+                )
+            ) < 0
+        );
+
+        if (schemaType == SchemaType.PrefixedTimeFirstSchema) { // time-first keeps the requested timestamp (42); the other schemas expect timestamp 0
+            final Bytes expected = Bytes.wrap(ByteBuffer.allocate(1 + 8 + 3)
+                .put((byte) 0x0)
+                .putLong(42)
+                .put(new byte[] {0xA, 0xB, 0xC})
+                .array());
+            assertThat(lower, equalTo(expected));
+        } else {
+            assertThat(lower, equalTo(toStoreKeyBinary.apply(new byte[]{0xA, 0xB, 0xC}, 0L, 0)));
+        }
     }
 
     @Test
     public void testLowerBoundMatchesTrailingZeros() {
-        final Bytes lower = windowKeySchema.lowerRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), Long.MAX_VALUE - 1);
+        final Bytes lower = keySchema.lowerRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), Long.MAX_VALUE - 1);
+        final TriFunction<byte[], Long, Integer, Bytes> toStoreKeyBinary = getToStoreKeyBinaryBytesParam();
 
         assertThat(
             "appending zeros to key should still be in range",
             lower.compareTo(
-                WindowKeySchema.toStoreKeyBinary(
+                toStoreKeyBinary.apply(
                     new byte[] {0xA, 0xB, 0xC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
                     Long.MAX_VALUE - 1,
                     0
                 )
             ) < 0
         );
-
-        assertThat(lower, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA, 0xB, 0xC}, 0, 0)));
+        if (schemaType == SchemaType.PrefixedTimeFirstSchema) { // time-first keeps the requested timestamp; the other schemas expect timestamp 0
+            final Bytes expected = Bytes.wrap(ByteBuffer.allocate(1 + 8 + 3)
+                .put((byte) 0x0)
+                .putLong(Long.MAX_VALUE - 1)
+                .put(new byte[] {0xA, 0xB, 0xC})
+                .array());
+            assertThat(lower, equalTo(expected));
+        } else {
+            assertThat(lower, equalTo(toStoreKeyBinary.apply(new byte[]{0xA, 0xB, 0xC}, 0L, 0)));
+        }
     }
 
     @Test
@@ -203,8 +416,12 @@ public class WindowKeySchemaTest {
 
     @Test
     public void shouldSerializeDeserializeExpectedChangelogWindowSize() {
+        // Changelog keys are serialized with WindowKeySchema only: report other schemas as skipped, not vacuously passed.
+        org.junit.Assume.assumeTrue("Changelog key is serialized using WindowKeySchema",
+            schemaType == SchemaType.WindowKeySchema);
+
         // Key-value containing serialized store key binary and the key's window size
-        final List<KeyValue<Bytes, Integer>> keys = Arrays.asList(
+        final List<KeyValue<Bytes, Integer>> keys = asList(
             KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0}), new TimeWindow(0, 1)), 0), 1),
             KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0}), new TimeWindow(0, 10)), 0), 10),
             KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0, 0}), new TimeWindow(10, 30)), 6), 20));
@@ -218,7 +435,7 @@ public class WindowKeySchemaTest {
             results.add(resultWindow.end() - resultWindow.start());
         }
 
-        assertThat(results, equalTo(Arrays.asList(1L, 10L, 20L)));
+        assertThat(results, equalTo(asList(1L, 10L, 20L)));
     }
 
     @Test
@@ -238,45 +455,81 @@ public class WindowKeySchemaTest {
 
     @Test
     public void shouldConvertToBinaryAndBack() {
-        final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes);
-        final Windowed<String> result = WindowKeySchema.fromStoreKey(serialized.get(), endTime - startTime, stateSerdes.keyDeserializer(), stateSerdes.topic());
+        final TriFunction<Windowed<String>, Integer, StateSerdes<String, byte[]>, Bytes> serializer = getSerdeToStoreKey();
+        final Bytes serialized = serializer.apply(windowedKey, 0, stateSerdes);
+        final byte[] storeKey = serialized.get();
+        final long windowSize = endTime - startTime;
+        final Windowed<String> result;
+        if (schemaType == SchemaType.WindowKeySchema) {
+            result = WindowKeySchema.fromStoreKey(storeKey, windowSize, stateSerdes.keyDeserializer(), stateSerdes.topic());
+        } else if (schemaType == SchemaType.PrefixedTimeFirstSchema) {
+            result = TimeFirstWindowKeySchema.fromStoreKey(storeKey, windowSize, stateSerdes.keyDeserializer(), stateSerdes.topic());
+        } else {
+            // Every remaining schema type is decoded with KeyFirstWindowKeySchema.
+            result = KeyFirstWindowKeySchema.fromStoreKey(storeKey, windowSize, stateSerdes.keyDeserializer(), stateSerdes.topic());
+        }
         assertEquals(windowedKey, result);
     }
 
     @Test
     public void shouldExtractSequenceFromBinary() {
-        final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes);
-        assertEquals(0, WindowKeySchema.extractStoreSequence(serialized.get()));
+        final TriFunction<Windowed<String>, Integer, StateSerdes<String, byte[]>, Bytes> serializer = getSerdeToStoreKey();
+        final byte[] storeKey = serializer.apply(windowedKey, 0, stateSerdes).get();
+        final int sequence = getExtractSeqFunc().apply(storeKey);
+        assertEquals(0, sequence); // the key was serialized with sequence number 0 above
     }
 
     @Test
     public void shouldExtractStartTimeFromBinary() {
-        final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes);
-        assertEquals(startTime, WindowKeySchema.extractStoreTimestamp(serialized.get()));
+        final TriFunction<Windowed<String>, Integer, StateSerdes<String, byte[]>, Bytes> serializer = getSerdeToStoreKey();
+        final byte[] storeKey = serializer.apply(windowedKey, 0, stateSerdes).get();
+        final long timestamp = getExtractTimestampFunc().apply(storeKey);
+        assertEquals(startTime, timestamp);
     }
 
     @Test
     public void shouldExtractWindowFromBinary() {
-        final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes);
-        assertEquals(window, WindowKeySchema.extractStoreWindow(serialized.get(), endTime - startTime));
+        final TriFunction<Windowed<String>, Integer, StateSerdes<String, byte[]>, Bytes> serializer = getSerdeToStoreKey();
+        final byte[] storeKey = serializer.apply(windowedKey, 0, stateSerdes).get();
+        final Window extracted = getExtractStoreWindow().apply(storeKey, endTime - startTime);
+        assertEquals(window, extracted);
     }
 
     @Test
     public void shouldExtractKeyBytesFromBinary() {
-        final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes);
-        assertArrayEquals(key.getBytes(), WindowKeySchema.extractStoreKeyBytes(serialized.get()));
+        final TriFunction<Windowed<String>, Integer, StateSerdes<String, byte[]>, Bytes> serializer = getSerdeToStoreKey();
+        final byte[] storeKey = serializer.apply(windowedKey, 0, stateSerdes).get();
+        final byte[] extractedKey = getExtractStorageKey().apply(storeKey);
+        assertArrayEquals(key.getBytes(), extractedKey);
     }
 
     @Test
-    public void shouldExtractKeyFromBinary() {
-        final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes);
-        assertEquals(windowedKey, WindowKeySchema.fromStoreKey(serialized.get(), endTime - startTime, stateSerdes.keyDeserializer(), stateSerdes.topic()));
+    public void shouldExtractBytesKeyFromBinary() {
+        final Bytes rawKey = Bytes.wrap(key.getBytes());
+        final Windowed<Bytes> windowedBytesKey = new Windowed<>(rawKey, window);
+        final byte[] storeKey = getToStoreKeyBinaryWindowParam().apply(windowedBytesKey, 0).get();
+        final Windowed<Bytes> roundTripped = getFromStorageKey().apply(storeKey, endTime - startTime);
+        assertEquals(windowedBytesKey, roundTripped);
     }
 
     @Test
-    public void shouldExtractBytesKeyFromBinary() {
-        final Windowed<Bytes> windowedBytesKey = new Windowed<>(Bytes.wrap(key.getBytes()), window);
-        final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedBytesKey, 0);
-        assertEquals(windowedBytesKey, WindowKeySchema.fromStoreBytesKey(serialized.get(), endTime - startTime));
+    public void shouldConvertFromNonPrefixWindowKey() {
+        final Function<byte[], byte[]> convertFromWindowKey = FROM_WINDOW_KEY_MAP.get(schemaType);
+        final TriFunction<byte[], Long, Integer, Bytes> windowKeySerializer = BYTES_TO_STORE_BINARY_MAP.get(SchemaType.WindowKeySchema);
+        if (convertFromWindowKey == null) {
+            return; // FROM_WINDOW_KEY_MAP has no converter for this schema type, so there is nothing to verify
+        }
+        final Bytes windowKeyFormatted = windowKeySerializer.apply(key.getBytes(), startTime, 0);
+        final byte[] converted = convertFromWindowKey.apply(windowKeyFormatted.get());
+        final Function<byte[], Long> extractTimestamp = getExtractTimestampFunc();
+        final Function<byte[], Integer> extractSequence = getExtractSeqFunc();
+        final Function<byte[], byte[]> extractKey = getExtractStorageKey();
+        // Decode each component of the converted key with this schema's own extractors.
+        final byte[] decodedKey = extractKey.apply(converted);
+        final long decodedTimestamp = extractTimestamp.apply(converted);
+        final int decodedSequence = extractSequence.apply(converted);
+        assertEquals(0, decodedSequence);
+        assertEquals(startTime, decodedTimestamp);
+        assertEquals(Bytes.wrap(key.getBytes()), Bytes.wrap(decodedKey));
     }
 }