You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") on the original page and is not preserved in this text rendering.
Posted to commits@kafka.apache.org by mj...@apache.org on 2019/03/26 21:23:14 UTC

[kafka] branch trunk updated: KAFKA-3522: Add RocksDBTimestampedSegmentedBytesStore (#6186)

This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new dc0601a  KAFKA-3522: Add RocksDBTimestampedSegmentedBytesStore (#6186)
dc0601a is described below

commit dc0601a1c604bea3f426ed25b6c20176ff444079
Author: Matthias J. Sax <mj...@apache.org>
AuthorDate: Tue Mar 26 14:23:01 2019 -0700

    KAFKA-3522: Add RocksDBTimestampedSegmentedBytesStore (#6186)
    
    Reviewers: Bill Bejeck <bi...@confluent.io>, John Roesler <jo...@confluent.io>, Guozhang Wang <gu...@confluent.io>
---
 .../org/apache/kafka/streams/state/Stores.java     |   8 +-
 ...ava => AbstractRocksDBSegmentedBytesStore.java} | 108 ++---
 .../{Segment.java => BulkLoadingStore.java}        |  20 +-
 .../streams/state/internals/KeyValueSegment.java   |   4 +-
 .../streams/state/internals/KeyValueSegments.java  |   7 +-
 .../internals/RocksDBSegmentedBytesStore.java      | 267 +----------
 .../streams/state/internals/RocksDBStore.java      |  41 +-
 ... => RocksDBTimestampedSegmentedBytesStore.java} |  20 +-
 .../state/internals/RocksDBTimestampedStore.java   |  34 +-
 .../internals/RocksDbWindowBytesStoreSupplier.java |  20 +-
 .../kafka/streams/state/internals/Segment.java     |   4 +-
 ...eyValueSegment.java => TimestampedSegment.java} |  12 +-
 ...ValueSegments.java => TimestampedSegments.java} |  15 +-
 .../org/apache/kafka/streams/state/StoresTest.java |  13 +-
 ...=> AbstractRocksDBSegmentedBytesStoreTest.java} |  87 ++--
 .../state/internals/KeyValueSegmentTest.java       |  99 +++++
 .../internals/RocksDBSegmentedBytesStoreTest.java  | 487 +--------------------
 ...RocksDBTimestampedSegmentedBytesStoreTest.java} |  29 +-
 ...tIteratorTest.java => SegmentIteratorTest.java} |   2 +-
 .../state/internals/TimestampedSegmentTest.java    |  99 +++++
 .../state/internals/TimetampedSegmentsTest.java    | 315 +++++++++++++
 21 files changed, 765 insertions(+), 926 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/state/Stores.java b/streams/src/main/java/org/apache/kafka/streams/state/Stores.java
index ac2a023..70bc15a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/Stores.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/Stores.java
@@ -265,7 +265,13 @@ public class Stores {
                                                    + windowSize + "], retention=[" + retentionPeriod + "]");
         }
 
-        return new RocksDbWindowBytesStoreSupplier(name, retentionPeriod, segmentInterval, windowSize, retainDuplicates);
+        return new RocksDbWindowBytesStoreSupplier(
+            name,
+            retentionPeriod,
+            segmentInterval,
+            windowSize,
+            retainDuplicates,
+            false);
     }
 
     /**
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java
similarity index 72%
copy from streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStore.java
copy to streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java
index f733c80..34639e3 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java
@@ -43,76 +43,89 @@ import java.util.Set;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.EXPIRED_WINDOW_RECORD_DROP;
 import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCount;
 
-public class RocksDBSegmentedBytesStore implements SegmentedBytesStore {
-    private static final Logger LOG = LoggerFactory.getLogger(RocksDBSegmentedBytesStore.class);
+public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements SegmentedBytesStore {
+    private static final Logger LOG = LoggerFactory.getLogger(AbstractRocksDBSegmentedBytesStore.class);
     private final String name;
-    private final KeyValueSegments segments;
+    private final AbstractSegments<S> segments;
     private final String metricScope;
     private final KeySchema keySchema;
     private InternalProcessorContext context;
     private volatile boolean open;
-    private Set<KeyValueSegment> bulkLoadSegments;
+    private Set<S> bulkLoadSegments;
     private Sensor expiredRecordSensor;
     private long observedStreamTime = ConsumerRecord.NO_TIMESTAMP;
 
-    RocksDBSegmentedBytesStore(final String name,
-                               final String metricScope,
-                               final long retention,
-                               final long segmentInterval,
-                               final KeySchema keySchema) {
+    AbstractRocksDBSegmentedBytesStore(final String name,
+                                       final String metricScope,
+                                       final KeySchema keySchema,
+                                       final AbstractSegments<S> segments) {
         this.name = name;
         this.metricScope = metricScope;
         this.keySchema = keySchema;
-        this.segments = new KeyValueSegments(name, retention, segmentInterval);
+        this.segments = segments;
     }
 
     @Override
-    public KeyValueIterator<Bytes, byte[]> fetch(final Bytes key, final long from, final long to) {
-        final List<KeyValueSegment> searchSpace = keySchema.segmentsToSearch(segments, from, to);
+    public KeyValueIterator<Bytes, byte[]> fetch(final Bytes key,
+                                                 final long from,
+                                                 final long to) {
+        final List<S> searchSpace = keySchema.segmentsToSearch(segments, from, to);
 
         final Bytes binaryFrom = keySchema.lowerRangeFixedSize(key, from);
         final Bytes binaryTo = keySchema.upperRangeFixedSize(key, to);
 
-        return new SegmentIterator<>(searchSpace.iterator(),
-                                     keySchema.hasNextCondition(key, key, from, to),
-                                     binaryFrom, binaryTo);
+        return new SegmentIterator<>(
+            searchSpace.iterator(),
+            keySchema.hasNextCondition(key, key, from, to),
+            binaryFrom,
+            binaryTo);
     }
 
     @Override
-    public KeyValueIterator<Bytes, byte[]> fetch(final Bytes keyFrom, final Bytes keyTo, final long from, final long to) {
-        final List<KeyValueSegment> searchSpace = keySchema.segmentsToSearch(segments, from, to);
+    public KeyValueIterator<Bytes, byte[]> fetch(final Bytes keyFrom,
+                                                 final Bytes keyTo,
+                                                 final long from,
+                                                 final long to) {
+        final List<S> searchSpace = keySchema.segmentsToSearch(segments, from, to);
 
         final Bytes binaryFrom = keySchema.lowerRange(keyFrom, from);
         final Bytes binaryTo = keySchema.upperRange(keyTo, to);
 
-        return new SegmentIterator<>(searchSpace.iterator(),
-                                     keySchema.hasNextCondition(keyFrom, keyTo, from, to),
-                                     binaryFrom, binaryTo);
+        return new SegmentIterator<>(
+            searchSpace.iterator(),
+            keySchema.hasNextCondition(keyFrom, keyTo, from, to),
+            binaryFrom,
+            binaryTo);
     }
 
     @Override
     public KeyValueIterator<Bytes, byte[]> all() {
-        final List<KeyValueSegment> searchSpace = segments.allSegments();
+        final List<S> searchSpace = segments.allSegments();
 
-        return new SegmentIterator<>(searchSpace.iterator(),
-                                     keySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE),
-                                     null, null);
+        return new SegmentIterator<>(
+            searchSpace.iterator(),
+            keySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE),
+            null,
+            null);
     }
 
     @Override
-    public KeyValueIterator<Bytes, byte[]> fetchAll(final long timeFrom, final long timeTo) {
-        final List<KeyValueSegment> searchSpace = segments.segments(timeFrom, timeTo);
-
-        return new SegmentIterator<>(searchSpace.iterator(),
-                                     keySchema.hasNextCondition(null, null, timeFrom, timeTo),
-                                     null, null);
+    public KeyValueIterator<Bytes, byte[]> fetchAll(final long timeFrom,
+                                                    final long timeTo) {
+        final List<S> searchSpace = segments.segments(timeFrom, timeTo);
+
+        return new SegmentIterator<>(
+            searchSpace.iterator(),
+            keySchema.hasNextCondition(null, null, timeFrom, timeTo),
+            null,
+            null);
     }
 
     @Override
     public void remove(final Bytes key) {
         final long timestamp = keySchema.segmentTimestamp(key);
         observedStreamTime = Math.max(observedStreamTime, timestamp);
-        final KeyValueSegment segment = segments.getSegmentForTimestamp(timestamp);
+        final S segment = segments.getSegmentForTimestamp(timestamp);
         if (segment == null) {
             return;
         }
@@ -120,11 +133,12 @@ public class RocksDBSegmentedBytesStore implements SegmentedBytesStore {
     }
 
     @Override
-    public void put(final Bytes key, final byte[] value) {
+    public void put(final Bytes key,
+                    final byte[] value) {
         final long timestamp = keySchema.segmentTimestamp(key);
         observedStreamTime = Math.max(observedStreamTime, timestamp);
         final long segmentId = segments.segmentId(timestamp);
-        final KeyValueSegment segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
+        final S segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
         if (segment == null) {
             expiredRecordSensor.record();
             LOG.debug("Skipping record for expired segment.");
@@ -135,7 +149,7 @@ public class RocksDBSegmentedBytesStore implements SegmentedBytesStore {
 
     @Override
     public byte[] get(final Bytes key) {
-        final KeyValueSegment segment = segments.getSegmentForTimestamp(keySchema.segmentTimestamp(key));
+        final S segment = segments.getSegmentForTimestamp(keySchema.segmentTimestamp(key));
         if (segment == null) {
             return null;
         }
@@ -148,11 +162,11 @@ public class RocksDBSegmentedBytesStore implements SegmentedBytesStore {
     }
 
     @Override
-    public void init(final ProcessorContext context, final StateStore root) {
+    public void init(final ProcessorContext context,
+                     final StateStore root) {
         this.context = (InternalProcessorContext) context;
 
         final StreamsMetricsImpl metrics = this.context.metrics();
-
         final String taskName = context.taskId().toString();
 
         expiredRecordSensor = metrics.storeLevelSensor(
@@ -200,16 +214,16 @@ public class RocksDBSegmentedBytesStore implements SegmentedBytesStore {
     }
 
     // Visible for testing
-    List<KeyValueSegment> getSegments() {
+    List<S> getSegments() {
         return segments.allSegments();
     }
 
     // Visible for testing
     void restoreAllInternal(final Collection<KeyValue<byte[], byte[]>> records) {
         try {
-            final Map<KeyValueSegment, WriteBatch> writeBatchMap = getWriteBatches(records);
-            for (final Map.Entry<KeyValueSegment, WriteBatch> entry : writeBatchMap.entrySet()) {
-                final KeyValueSegment segment = entry.getKey();
+            final Map<S, WriteBatch> writeBatchMap = getWriteBatches(records);
+            for (final Map.Entry<S, WriteBatch> entry : writeBatchMap.entrySet()) {
+                final S segment = entry.getKey();
                 final WriteBatch batch = entry.getValue();
                 segment.write(batch);
             }
@@ -219,18 +233,18 @@ public class RocksDBSegmentedBytesStore implements SegmentedBytesStore {
     }
 
     // Visible for testing
-    Map<KeyValueSegment, WriteBatch> getWriteBatches(final Collection<KeyValue<byte[], byte[]>> records) {
+    Map<S, WriteBatch> getWriteBatches(final Collection<KeyValue<byte[], byte[]>> records) {
         // advance stream time to the max timestamp in the batch
         for (final KeyValue<byte[], byte[]> record : records) {
             final long timestamp = keySchema.segmentTimestamp(Bytes.wrap(record.key));
             observedStreamTime = Math.max(observedStreamTime, timestamp);
         }
 
-        final Map<KeyValueSegment, WriteBatch> writeBatchMap = new HashMap<>();
+        final Map<S, WriteBatch> writeBatchMap = new HashMap<>();
         for (final KeyValue<byte[], byte[]> record : records) {
             final long timestamp = keySchema.segmentTimestamp(Bytes.wrap(record.key));
             final long segmentId = segments.segmentId(timestamp);
-            final KeyValueSegment segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
+            final S segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
             if (segment != null) {
                 // This handles the case that state store is moved to a new client and does not
                 // have the local RocksDB instance for the segment. In this case, toggleDBForBulkLoading
@@ -245,11 +259,7 @@ public class RocksDBSegmentedBytesStore implements SegmentedBytesStore {
                 }
                 try {
                     final WriteBatch batch = writeBatchMap.computeIfAbsent(segment, s -> new WriteBatch());
-                    if (record.value == null) {
-                        batch.delete(record.key);
-                    } else {
-                        batch.put(record.key, record.value);
-                    }
+                    segment.addToBatch(record, batch);
                 } catch (final RocksDBException e) {
                     throw new ProcessorStateException("Error restoring batch to store " + this.name, e);
                 }
@@ -259,7 +269,7 @@ public class RocksDBSegmentedBytesStore implements SegmentedBytesStore {
     }
 
     private void toggleForBulkLoading(final boolean prepareForBulkload) {
-        for (final KeyValueSegment segment : segments.allSegments()) {
+        for (final S segment : segments.allSegments()) {
             segment.toggleDbForBulkLoading(prepareForBulkload);
         }
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/BulkLoadingStore.java
similarity index 67%
copy from streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java
copy to streams/src/main/java/org/apache/kafka/streams/state/internals/BulkLoadingStore.java
index 8687ffc..1e27cc2 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/BulkLoadingStore.java
@@ -16,17 +16,13 @@
  */
 package org.apache.kafka.streams.state.internals;
 
-import org.apache.kafka.common.utils.Bytes;
-import org.apache.kafka.streams.processor.StateStore;
-import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.KeyValue;
+import org.rocksdb.RocksDBException;
+import org.rocksdb.WriteBatch;
 
-import java.io.IOException;
-
-public interface Segment extends StateStore {
-
-    void destroy() throws IOException;
-
-    KeyValueIterator<Bytes, byte[]> all();
-
-    KeyValueIterator<Bytes, byte[]> range(final Bytes from, final Bytes to);
+public interface BulkLoadingStore {
+    void toggleDbForBulkLoading(final boolean prepareForBulkload);
+    void addToBatch(final KeyValue<byte[], byte[]> record,
+                    final WriteBatch batch) throws RocksDBException;
+    void write(final WriteBatch batch) throws RocksDBException;
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegment.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegment.java
index 697b67a..79e6110 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegment.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegment.java
@@ -25,7 +25,9 @@ import java.util.Objects;
 class KeyValueSegment extends RocksDBStore implements Comparable<KeyValueSegment>, Segment {
     public final long id;
 
-    KeyValueSegment(final String segmentName, final String windowName, final long id) {
+    KeyValueSegment(final String segmentName,
+                    final String windowName,
+                    final long id) {
         super(segmentName, windowName);
         this.id = id;
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegments.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegments.java
index 0664551..9145d44 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegments.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegments.java
@@ -23,12 +23,15 @@ import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
  */
 class KeyValueSegments extends AbstractSegments<KeyValueSegment> {
 
-    KeyValueSegments(final String name, final long retentionPeriod, final long segmentInterval) {
+    KeyValueSegments(final String name,
+                     final long retentionPeriod,
+                     final long segmentInterval) {
         super(name, retentionPeriod, segmentInterval);
     }
 
     @Override
-    public KeyValueSegment getOrCreateSegment(final long segmentId, final InternalProcessorContext context) {
+    public KeyValueSegment getOrCreateSegment(final long segmentId,
+                                              final InternalProcessorContext context) {
         if (segments.containsKey(segmentId)) {
             return segments.get(segmentId);
         } else {
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStore.java
index f733c80..b3de6e8 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStore.java
@@ -16,274 +16,13 @@
  */
 package org.apache.kafka.streams.state.internals;
 
-import org.apache.kafka.clients.consumer.ConsumerRecord;
-import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.common.metrics.Sensor;
-import org.apache.kafka.common.utils.Bytes;
-import org.apache.kafka.streams.KeyValue;
-import org.apache.kafka.streams.errors.ProcessorStateException;
-import org.apache.kafka.streams.processor.AbstractNotifyingBatchingRestoreCallback;
-import org.apache.kafka.streams.processor.ProcessorContext;
-import org.apache.kafka.streams.processor.StateStore;
-import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
-import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
-import org.apache.kafka.streams.state.KeyValueIterator;
-import org.rocksdb.RocksDBException;
-import org.rocksdb.WriteBatch;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.EXPIRED_WINDOW_RECORD_DROP;
-import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCount;
-
-public class RocksDBSegmentedBytesStore implements SegmentedBytesStore {
-    private static final Logger LOG = LoggerFactory.getLogger(RocksDBSegmentedBytesStore.class);
-    private final String name;
-    private final KeyValueSegments segments;
-    private final String metricScope;
-    private final KeySchema keySchema;
-    private InternalProcessorContext context;
-    private volatile boolean open;
-    private Set<KeyValueSegment> bulkLoadSegments;
-    private Sensor expiredRecordSensor;
-    private long observedStreamTime = ConsumerRecord.NO_TIMESTAMP;
+public class RocksDBSegmentedBytesStore extends AbstractRocksDBSegmentedBytesStore<KeyValueSegment> {
 
     RocksDBSegmentedBytesStore(final String name,
                                final String metricScope,
                                final long retention,
                                final long segmentInterval,
                                final KeySchema keySchema) {
-        this.name = name;
-        this.metricScope = metricScope;
-        this.keySchema = keySchema;
-        this.segments = new KeyValueSegments(name, retention, segmentInterval);
-    }
-
-    @Override
-    public KeyValueIterator<Bytes, byte[]> fetch(final Bytes key, final long from, final long to) {
-        final List<KeyValueSegment> searchSpace = keySchema.segmentsToSearch(segments, from, to);
-
-        final Bytes binaryFrom = keySchema.lowerRangeFixedSize(key, from);
-        final Bytes binaryTo = keySchema.upperRangeFixedSize(key, to);
-
-        return new SegmentIterator<>(searchSpace.iterator(),
-                                     keySchema.hasNextCondition(key, key, from, to),
-                                     binaryFrom, binaryTo);
-    }
-
-    @Override
-    public KeyValueIterator<Bytes, byte[]> fetch(final Bytes keyFrom, final Bytes keyTo, final long from, final long to) {
-        final List<KeyValueSegment> searchSpace = keySchema.segmentsToSearch(segments, from, to);
-
-        final Bytes binaryFrom = keySchema.lowerRange(keyFrom, from);
-        final Bytes binaryTo = keySchema.upperRange(keyTo, to);
-
-        return new SegmentIterator<>(searchSpace.iterator(),
-                                     keySchema.hasNextCondition(keyFrom, keyTo, from, to),
-                                     binaryFrom, binaryTo);
-    }
-
-    @Override
-    public KeyValueIterator<Bytes, byte[]> all() {
-        final List<KeyValueSegment> searchSpace = segments.allSegments();
-
-        return new SegmentIterator<>(searchSpace.iterator(),
-                                     keySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE),
-                                     null, null);
-    }
-
-    @Override
-    public KeyValueIterator<Bytes, byte[]> fetchAll(final long timeFrom, final long timeTo) {
-        final List<KeyValueSegment> searchSpace = segments.segments(timeFrom, timeTo);
-
-        return new SegmentIterator<>(searchSpace.iterator(),
-                                     keySchema.hasNextCondition(null, null, timeFrom, timeTo),
-                                     null, null);
-    }
-
-    @Override
-    public void remove(final Bytes key) {
-        final long timestamp = keySchema.segmentTimestamp(key);
-        observedStreamTime = Math.max(observedStreamTime, timestamp);
-        final KeyValueSegment segment = segments.getSegmentForTimestamp(timestamp);
-        if (segment == null) {
-            return;
-        }
-        segment.delete(key);
-    }
-
-    @Override
-    public void put(final Bytes key, final byte[] value) {
-        final long timestamp = keySchema.segmentTimestamp(key);
-        observedStreamTime = Math.max(observedStreamTime, timestamp);
-        final long segmentId = segments.segmentId(timestamp);
-        final KeyValueSegment segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
-        if (segment == null) {
-            expiredRecordSensor.record();
-            LOG.debug("Skipping record for expired segment.");
-        } else {
-            segment.put(key, value);
-        }
-    }
-
-    @Override
-    public byte[] get(final Bytes key) {
-        final KeyValueSegment segment = segments.getSegmentForTimestamp(keySchema.segmentTimestamp(key));
-        if (segment == null) {
-            return null;
-        }
-        return segment.get(key);
-    }
-
-    @Override
-    public String name() {
-        return name;
-    }
-
-    @Override
-    public void init(final ProcessorContext context, final StateStore root) {
-        this.context = (InternalProcessorContext) context;
-
-        final StreamsMetricsImpl metrics = this.context.metrics();
-
-        final String taskName = context.taskId().toString();
-
-        expiredRecordSensor = metrics.storeLevelSensor(
-            taskName,
-            name(),
-            EXPIRED_WINDOW_RECORD_DROP,
-            Sensor.RecordingLevel.INFO
-        );
-        addInvocationRateAndCount(
-            expiredRecordSensor,
-            "stream-" + metricScope + "-metrics",
-            metrics.tagMap("task-id", taskName, metricScope + "-id", name()),
-            EXPIRED_WINDOW_RECORD_DROP
-        );
-
-        segments.openExisting(this.context, observedStreamTime);
-
-        bulkLoadSegments = new HashSet<>(segments.allSegments());
-
-        // register and possibly restore the state from the logs
-        context.register(root, new RocksDBSegmentsBatchingRestoreCallback());
-
-        open = true;
-    }
-
-    @Override
-    public void flush() {
-        segments.flush();
-    }
-
-    @Override
-    public void close() {
-        open = false;
-        segments.close();
-    }
-
-    @Override
-    public boolean persistent() {
-        return true;
-    }
-
-    @Override
-    public boolean isOpen() {
-        return open;
-    }
-
-    // Visible for testing
-    List<KeyValueSegment> getSegments() {
-        return segments.allSegments();
-    }
-
-    // Visible for testing
-    void restoreAllInternal(final Collection<KeyValue<byte[], byte[]>> records) {
-        try {
-            final Map<KeyValueSegment, WriteBatch> writeBatchMap = getWriteBatches(records);
-            for (final Map.Entry<KeyValueSegment, WriteBatch> entry : writeBatchMap.entrySet()) {
-                final KeyValueSegment segment = entry.getKey();
-                final WriteBatch batch = entry.getValue();
-                segment.write(batch);
-            }
-        } catch (final RocksDBException e) {
-            throw new ProcessorStateException("Error restoring batch to store " + this.name, e);
-        }
-    }
-
-    // Visible for testing
-    Map<KeyValueSegment, WriteBatch> getWriteBatches(final Collection<KeyValue<byte[], byte[]>> records) {
-        // advance stream time to the max timestamp in the batch
-        for (final KeyValue<byte[], byte[]> record : records) {
-            final long timestamp = keySchema.segmentTimestamp(Bytes.wrap(record.key));
-            observedStreamTime = Math.max(observedStreamTime, timestamp);
-        }
-
-        final Map<KeyValueSegment, WriteBatch> writeBatchMap = new HashMap<>();
-        for (final KeyValue<byte[], byte[]> record : records) {
-            final long timestamp = keySchema.segmentTimestamp(Bytes.wrap(record.key));
-            final long segmentId = segments.segmentId(timestamp);
-            final KeyValueSegment segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
-            if (segment != null) {
-                // This handles the case that state store is moved to a new client and does not
-                // have the local RocksDB instance for the segment. In this case, toggleDBForBulkLoading
-                // will only close the database and open it again with bulk loading enabled.
-                if (!bulkLoadSegments.contains(segment)) {
-                    segment.toggleDbForBulkLoading(true);
-                    // If the store does not exist yet, the getOrCreateSegmentIfLive will call openDB that
-                    // makes the open flag for the newly created store.
-                    // if the store does exist already, then toggleDbForBulkLoading will make sure that
-                    // the store is already open here.
-                    bulkLoadSegments = new HashSet<>(segments.allSegments());
-                }
-                try {
-                    final WriteBatch batch = writeBatchMap.computeIfAbsent(segment, s -> new WriteBatch());
-                    if (record.value == null) {
-                        batch.delete(record.key);
-                    } else {
-                        batch.put(record.key, record.value);
-                    }
-                } catch (final RocksDBException e) {
-                    throw new ProcessorStateException("Error restoring batch to store " + this.name, e);
-                }
-            }
-        }
-        return writeBatchMap;
-    }
-
-    private void toggleForBulkLoading(final boolean prepareForBulkload) {
-        for (final KeyValueSegment segment : segments.allSegments()) {
-            segment.toggleDbForBulkLoading(prepareForBulkload);
-        }
-    }
-
-    private class RocksDBSegmentsBatchingRestoreCallback extends AbstractNotifyingBatchingRestoreCallback {
-
-        @Override
-        public void restoreAll(final Collection<KeyValue<byte[], byte[]>> records) {
-            restoreAllInternal(records);
-        }
-
-        @Override
-        public void onRestoreStart(final TopicPartition topicPartition,
-                                   final String storeName,
-                                   final long startingOffset,
-                                   final long endingOffset) {
-            toggleForBulkLoading(true);
-        }
-
-        @Override
-        public void onRestoreEnd(final TopicPartition topicPartition,
-                                 final String storeName,
-                                 final long totalRestored) {
-            toggleForBulkLoading(false);
-        }
+        super(name, metricScope, keySchema, new KeyValueSegments(name, retention, segmentInterval));
     }
-}
+}
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java
index 2ca3ad3..3e3e478 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java
@@ -65,7 +65,7 @@ import java.util.regex.Pattern;
 /**
  * A persistent key-value store based on RocksDB.
  */
-public class RocksDBStore implements KeyValueStore<Bytes, byte[]> {
+public class RocksDBStore implements KeyValueStore<Bytes, byte[]>, BulkLoadingStore {
     private static final Logger log = LoggerFactory.getLogger(RocksDBStore.class);
 
     private static final Pattern SST_FILE_EXTENSION = Pattern.compile(".*\\.sst");
@@ -344,7 +344,8 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]> {
         }
     }
 
-    void toggleDbForBulkLoading(final boolean prepareForBulkload) {
+    @Override
+    public void toggleDbForBulkLoading(final boolean prepareForBulkload) {
         if (prepareForBulkload) {
             // if the store is not empty, we need to compact to get around the num.levels check for bulk loading
             final String[] sstFileNames = dbDir.list((dir, name) -> SST_FILE_EXTENSION.matcher(name).matches());
@@ -359,7 +360,14 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]> {
         openDB(internalProcessorContext);
     }
 
-    void write(final WriteBatch batch) throws RocksDBException {
+    @Override
+    public void addToBatch(final KeyValue<byte[], byte[]> record,
+                           final WriteBatch batch) throws RocksDBException {
+        dbAccessor.addToBatch(record.key, record.value, batch);
+    }
+
+    @Override
+    public void write(final WriteBatch batch) throws RocksDBException {
         db.write(wOptions, batch);
     }
 
@@ -428,6 +436,10 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]> {
         void prepareBatchForRestore(final Collection<KeyValue<byte[], byte[]>> records,
                                     final WriteBatch batch) throws RocksDBException;
 
+        void addToBatch(final byte[] key,
+                        final byte[] value,
+                        final WriteBatch batch) throws RocksDBException;
+
         void close();
 
         void toggleDbForBulkLoading();
@@ -465,11 +477,7 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]> {
                                  final WriteBatch batch) throws RocksDBException {
             for (final KeyValue<Bytes, byte[]> entry : entries) {
                 Objects.requireNonNull(entry.key, "key cannot be null");
-                if (entry.value == null) {
-                    batch.delete(columnFamily, entry.key.get());
-                } else {
-                    batch.put(columnFamily, entry.key.get(), entry.value);
-                }
+                addToBatch(entry.key.get(), entry.value, batch);
             }
         }
 
@@ -515,11 +523,18 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]> {
         public void prepareBatchForRestore(final Collection<KeyValue<byte[], byte[]>> records,
                                            final WriteBatch batch) throws RocksDBException {
             for (final KeyValue<byte[], byte[]> record : records) {
-                if (record.value == null) {
-                    batch.delete(columnFamily, record.key);
-                } else {
-                    batch.put(columnFamily, record.key, record.value);
-                }
+                addToBatch(record.key, record.value, batch);
+            }
+        }
+
+        @Override
+        public void addToBatch(final byte[] key,
+                               final byte[] value,
+                               final WriteBatch batch) throws RocksDBException {
+            if (value == null) {
+                batch.delete(columnFamily, key);
+            } else {
+                batch.put(columnFamily, key, value);
             }
         }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStore.java
similarity index 60%
copy from streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java
copy to streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStore.java
index 8687ffc..630124c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStore.java
@@ -16,17 +16,13 @@
  */
 package org.apache.kafka.streams.state.internals;
 
-import org.apache.kafka.common.utils.Bytes;
-import org.apache.kafka.streams.processor.StateStore;
-import org.apache.kafka.streams.state.KeyValueIterator;
+public class RocksDBTimestampedSegmentedBytesStore extends AbstractRocksDBSegmentedBytesStore<TimestampedSegment> {
 
-import java.io.IOException;
-
-public interface Segment extends StateStore {
-
-    void destroy() throws IOException;
-
-    KeyValueIterator<Bytes, byte[]> all();
-
-    KeyValueIterator<Bytes, byte[]> range(final Bytes from, final Bytes to);
+    RocksDBTimestampedSegmentedBytesStore(final String name,
+                                          final String metricScope,
+                                          final long retention,
+                                          final long segmentInterval,
+                                          final KeySchema keySchema) {
+        super(name, metricScope, keySchema, new TimestampedSegments(name, retention, segmentInterval));
+    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedStore.java
index 6d477bd..f52033b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedStore.java
@@ -54,6 +54,11 @@ public class RocksDBTimestampedStore extends RocksDBStore {
         super(name);
     }
 
+    RocksDBTimestampedStore(final String name,
+                            final String parentDir) {
+        super(name, parentDir);
+    }
+
     @Override
     void openRocksDB(final DBOptions dbOptions,
                      final ColumnFamilyOptions columnFamilyOptions) {
@@ -142,13 +147,7 @@ public class RocksDBTimestampedStore extends RocksDBStore {
                                  final WriteBatch batch) throws RocksDBException {
             for (final KeyValue<Bytes, byte[]> entry : entries) {
                 Objects.requireNonNull(entry.key, "key cannot be null");
-                if (entry.value == null) {
-                    batch.delete(oldColumnFamily, entry.key.get());
-                    batch.delete(newColumnFamily, entry.key.get());
-                } else {
-                    batch.delete(oldColumnFamily, entry.key.get());
-                    batch.put(newColumnFamily, entry.key.get(), entry.value);
-                }
+                addToBatch(entry.key.get(), entry.value, batch);
             }
         }
 
@@ -223,13 +222,20 @@ public class RocksDBTimestampedStore extends RocksDBStore {
         public void prepareBatchForRestore(final Collection<KeyValue<byte[], byte[]>> records,
                                            final WriteBatch batch) throws RocksDBException {
             for (final KeyValue<byte[], byte[]> record : records) {
-                if (record.value == null) {
-                    batch.delete(oldColumnFamily, record.key);
-                    batch.delete(newColumnFamily, record.key);
-                } else {
-                    batch.delete(oldColumnFamily, record.key);
-                    batch.put(newColumnFamily, record.key, record.value);
-                }
+                addToBatch(record.key, record.value, batch);
+            }
+        }
+
+        @Override
+        public void addToBatch(final byte[] key,
+                               final byte[] value,
+                               final WriteBatch batch) throws RocksDBException {
+            // deletes and puts alike first remove the key from the old column family
+            batch.delete(oldColumnFamily, key);
+            if (value == null) {
+                batch.delete(newColumnFamily, key);
+            } else {
+                batch.put(newColumnFamily, key, value);
             }
         }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbWindowBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbWindowBytesStoreSupplier.java
index ecdfad2..b2e8c11 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbWindowBytesStoreSupplier.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbWindowBytesStoreSupplier.java
@@ -26,17 +26,20 @@ public class RocksDbWindowBytesStoreSupplier implements WindowBytesStoreSupplier
     private final long segmentInterval;
     private final long windowSize;
     private final boolean retainDuplicates;
+    private final boolean returnTimestampedStore;
 
     public RocksDbWindowBytesStoreSupplier(final String name,
                                            final long retentionPeriod,
                                            final long segmentInterval,
                                            final long windowSize,
-                                           final boolean retainDuplicates) {
+                                           final boolean retainDuplicates,
+                                           final boolean returnTimestampedStore) {
         this.name = name;
         this.retentionPeriod = retentionPeriod;
         this.segmentInterval = segmentInterval;
         this.windowSize = windowSize;
         this.retainDuplicates = retainDuplicates;
+        this.returnTimestampedStore = returnTimestampedStore;
     }
 
     @Override
@@ -46,13 +49,24 @@ public class RocksDbWindowBytesStoreSupplier implements WindowBytesStoreSupplier
 
     @Override
     public WindowStore<Bytes, byte[]> get() {
-        final RocksDBSegmentedBytesStore segmentedBytesStore = new RocksDBSegmentedBytesStore(
+        final SegmentedBytesStore segmentedBytesStore;
+        if (!returnTimestampedStore) {
+            segmentedBytesStore = new RocksDBSegmentedBytesStore(
                 name,
                 metricsScope(),
                 retentionPeriod,
                 segmentInterval,
                 new WindowKeySchema()
-        );
+            );
+        } else {
+            segmentedBytesStore = new RocksDBTimestampedSegmentedBytesStore(
+                name,
+                metricsScope(),
+                retentionPeriod,
+                segmentInterval,
+                new WindowKeySchema()
+            );
+        }
         return new RocksDBWindowStore(
             segmentedBytesStore,
             retainDuplicates,
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java
index 8687ffc..fe1fc33 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java
@@ -17,12 +17,12 @@
 package org.apache.kafka.streams.state.internals;
 
 import org.apache.kafka.common.utils.Bytes;
-import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.KeyValueStore;
 
 import java.io.IOException;
 
-public interface Segment extends StateStore {
+public interface Segment extends KeyValueStore<Bytes, byte[]>, BulkLoadingStore {
 
     void destroy() throws IOException;
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegment.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegment.java
similarity index 79%
copy from streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegment.java
copy to streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegment.java
index 697b67a..ba7b64d 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegment.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegment.java
@@ -22,10 +22,12 @@ import org.apache.kafka.streams.processor.ProcessorContext;
 import java.io.IOException;
 import java.util.Objects;
 
-class KeyValueSegment extends RocksDBStore implements Comparable<KeyValueSegment>, Segment {
+class TimestampedSegment extends RocksDBTimestampedStore implements Comparable<TimestampedSegment>, Segment {
     public final long id;
 
-    KeyValueSegment(final String segmentName, final String windowName, final long id) {
+    TimestampedSegment(final String segmentName,
+                       final String windowName,
+                       final long id) {
         super(segmentName, windowName);
         this.id = id;
     }
@@ -36,7 +38,7 @@ class KeyValueSegment extends RocksDBStore implements Comparable<KeyValueSegment
     }
 
     @Override
-    public int compareTo(final KeyValueSegment segment) {
+    public int compareTo(final TimestampedSegment segment) {
         return Long.compare(id, segment.id);
     }
 
@@ -49,7 +51,7 @@ class KeyValueSegment extends RocksDBStore implements Comparable<KeyValueSegment
 
     @Override
     public String toString() {
-        return "KeyValueSegment(id=" + id + ", name=" + name() + ")";
+        return "TimestampedSegment(id=" + id + ", name=" + name() + ")";
     }
 
     @Override
@@ -57,7 +59,7 @@ class KeyValueSegment extends RocksDBStore implements Comparable<KeyValueSegment
         if (obj == null || getClass() != obj.getClass()) {
             return false;
         }
-        final KeyValueSegment segment = (KeyValueSegment) obj;
+        final TimestampedSegment segment = (TimestampedSegment) obj;
         return id == segment.id;
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegments.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegments.java
similarity index 64%
copy from streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegments.java
copy to streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegments.java
index 0664551..3e45017 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegments.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegments.java
@@ -19,23 +19,26 @@ package org.apache.kafka.streams.state.internals;
 import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
 
 /**
- * Manages the {@link KeyValueSegment}s that are used by the {@link RocksDBSegmentedBytesStore}
+ * Manages the {@link TimestampedSegment}s that are used by the {@link RocksDBTimestampedSegmentedBytesStore}
  */
-class KeyValueSegments extends AbstractSegments<KeyValueSegment> {
+class TimestampedSegments extends AbstractSegments<TimestampedSegment> {
 
-    KeyValueSegments(final String name, final long retentionPeriod, final long segmentInterval) {
+    TimestampedSegments(final String name,
+                        final long retentionPeriod,
+                        final long segmentInterval) {
         super(name, retentionPeriod, segmentInterval);
     }
 
     @Override
-    public KeyValueSegment getOrCreateSegment(final long segmentId, final InternalProcessorContext context) {
+    public TimestampedSegment getOrCreateSegment(final long segmentId,
+                                                 final InternalProcessorContext context) {
         if (segments.containsKey(segmentId)) {
             return segments.get(segmentId);
         } else {
-            final KeyValueSegment newSegment = new KeyValueSegment(segmentName(segmentId), name, segmentId);
+            final TimestampedSegment newSegment = new TimestampedSegment(segmentName(segmentId), name, segmentId);
 
             if (segments.put(segmentId, newSegment) != null) {
-                throw new IllegalStateException("KeyValueSegment already exists. Possible concurrent access.");
+                throw new IllegalStateException("TimestampedSegment already exists. Possible concurrent access.");
             }
 
             newSegment.openDB(context);
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/StoresTest.java b/streams/src/test/java/org/apache/kafka/streams/state/StoresTest.java
index 9cc1280..4819ac1 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/StoresTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/StoresTest.java
@@ -17,12 +17,16 @@
 package org.apache.kafka.streams.state;
 
 import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.state.internals.InMemoryKeyValueStore;
 import org.apache.kafka.streams.state.internals.MemoryNavigableLRUCache;
+import org.apache.kafka.streams.state.internals.RocksDBSegmentedBytesStore;
 import org.apache.kafka.streams.state.internals.RocksDBSessionStore;
 import org.apache.kafka.streams.state.internals.RocksDBStore;
+import org.apache.kafka.streams.state.internals.RocksDBTimestampedSegmentedBytesStore;
 import org.apache.kafka.streams.state.internals.RocksDBTimestampedStore;
 import org.apache.kafka.streams.state.internals.RocksDBWindowStore;
+import org.apache.kafka.streams.state.internals.WrappedStateStore;
 import org.junit.Test;
 
 import static java.time.Duration.ZERO;
@@ -115,12 +119,17 @@ public class StoresTest {
 
     @Test
     public void shouldCreateRocksDbStore() {
-        assertThat(Stores.persistentKeyValueStore("store").get(), allOf(not(instanceOf(RocksDBTimestampedStore.class)), instanceOf(RocksDBStore.class)));
+        assertThat(
+            Stores.persistentKeyValueStore("store").get(),
+            allOf(not(instanceOf(RocksDBTimestampedStore.class)), instanceOf(RocksDBStore.class)));
     }
 
     @Test
     public void shouldCreateRocksDbWindowStore() {
-        assertThat(Stores.persistentWindowStore("store", ofMillis(1L), ofMillis(1L), false).get(), instanceOf(RocksDBWindowStore.class));
+        final WindowStore<?, ?> store = Stores.persistentWindowStore("store", ofMillis(1L), ofMillis(1L), false).get();
+        assertThat(store, instanceOf(RocksDBWindowStore.class)); // check the type before the cast below
+        final StateStore wrapped = ((WrappedStateStore) store).wrapped();
+        assertThat(wrapped, allOf(not(instanceOf(RocksDBTimestampedSegmentedBytesStore.class)), instanceOf(RocksDBSegmentedBytesStore.class)));
     }
 
     @Test
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStoreTest.java
similarity index 90%
copy from streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStoreTest.java
copy to streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStoreTest.java
index d0dd133..004f181 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStoreTest.java
@@ -42,6 +42,7 @@ import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameter;
 import org.junit.runners.Parameterized.Parameters;
+import org.rocksdb.Options;
 import org.rocksdb.WriteBatch;
 
 import java.io.File;
@@ -68,20 +69,20 @@ import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertTrue;
 
-
 @RunWith(Parameterized.class)
-public class RocksDBSegmentedBytesStoreTest {
+public abstract class AbstractRocksDBSegmentedBytesStoreTest<S extends Segment> {
 
     private final long windowSizeForTimeWindow = 500;
-    private final long retention = 1000;
-    private final long segmentInterval = 60_000L;
     private InternalMockProcessorContext context;
-    private final String storeName = "bytes-store";
-    private RocksDBSegmentedBytesStore bytesStore;
+    private AbstractRocksDBSegmentedBytesStore<S> bytesStore;
     private File stateDir;
     private final Window[] windows = new Window[4];
     private Window nextSegmentWindow;
 
+    final long retention = 1000;
+    final long segmentInterval = 60_000L;
+    final String storeName = "bytes-store";
+
     @Parameter
     public SegmentedBytesStore.KeySchema schema;
 
@@ -92,7 +93,6 @@ public class RocksDBSegmentedBytesStoreTest {
 
     @Before
     public void before() {
-
         if (schema instanceof SessionKeySchema) {
             windows[0] = new SessionWindow(10L, 10L);
             windows[1] = new SessionWindow(500L, 1000L);
@@ -116,14 +116,7 @@ public class RocksDBSegmentedBytesStoreTest {
             nextSegmentWindow = timeWindowForSize(segmentInterval + retention, windowSizeForTimeWindow);
         }
 
-
-        bytesStore = new RocksDBSegmentedBytesStore(
-            storeName,
-            "metrics-scope",
-            retention,
-            segmentInterval,
-            schema
-        );
+        bytesStore = getBytesStore();
 
         stateDir = TestUtils.tempDirectory();
         context = new InternalMockProcessorContext(
@@ -141,6 +134,12 @@ public class RocksDBSegmentedBytesStoreTest {
         bytesStore.close();
     }
 
+    abstract AbstractRocksDBSegmentedBytesStore<S> getBytesStore();
+
+    abstract AbstractSegments<S> newSegments();
+
+    abstract Options getOptions(S segment);
+
     @Test
     public void shouldPutAndFetch() {
         final String key = "a";
@@ -173,7 +172,6 @@ public class RocksDBSegmentedBytesStoreTest {
         assertEquals(expected, toList(results));
     }
 
-
     @Test
     public void shouldRemove() {
         bytesStore.put(serializeKey(new Windowed<>("a", windows[0])), serializeValue(30));
@@ -184,11 +182,10 @@ public class RocksDBSegmentedBytesStoreTest {
         assertFalse(value.hasNext());
     }
 
-
     @Test
     public void shouldRollSegments() {
         // just to validate directories
-        final KeyValueSegments segments = new KeyValueSegments(storeName, retention, segmentInterval);
+        final AbstractSegments<S> segments = newSegments();
         final String key = "a";
 
         bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50));
@@ -209,14 +206,12 @@ public class RocksDBSegmentedBytesStoreTest {
             ),
             results
         );
-
     }
 
-
     @Test
     public void shouldGetAllSegments() {
         // just to validate directories
-        final KeyValueSegments segments = new KeyValueSegments(storeName, retention, segmentInterval);
+        final AbstractSegments<S> segments = newSegments();
         final String key = "a";
 
         bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
@@ -239,13 +234,12 @@ public class RocksDBSegmentedBytesStoreTest {
             ),
             results
         );
-
     }
 
     @Test
     public void shouldFetchAllSegments() {
         // just to validate directories
-        final KeyValueSegments segments = new KeyValueSegments(storeName, retention, segmentInterval);
+        final AbstractSegments<S> segments = newSegments();
         final String key = "a";
 
         bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
@@ -268,12 +262,11 @@ public class RocksDBSegmentedBytesStoreTest {
             ),
             results
         );
-
     }
 
     @Test
     public void shouldLoadSegmentsWithOldStyleDateFormattedName() {
-        final KeyValueSegments segments = new KeyValueSegments(storeName, retention, segmentInterval);
+        final AbstractSegments<S> segments = newSegments();
         final String key = "a";
 
         bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
@@ -290,13 +283,7 @@ public class RocksDBSegmentedBytesStoreTest {
         final File oldStyleName = new File(parent, nameParts[0] + "-" + formatted);
         assertTrue(new File(parent, firstSegmentName).renameTo(oldStyleName));
 
-        bytesStore = new RocksDBSegmentedBytesStore(
-            storeName,
-            "metrics-scope",
-            retention,
-            segmentInterval,
-            schema
-        );
+        bytesStore = getBytesStore();
 
         bytesStore.init(context, bytesStore);
         final List<KeyValue<Windowed<String>, Long>> results = toList(bytesStore.fetch(Bytes.wrap(key.getBytes()), 0L, 60_000L));
@@ -311,10 +298,9 @@ public class RocksDBSegmentedBytesStoreTest {
         );
     }
 
-
     @Test
     public void shouldLoadSegmentsWithOldStyleColonFormattedName() {
-        final KeyValueSegments segments = new KeyValueSegments(storeName, retention, segmentInterval);
+        final AbstractSegments<S> segments = newSegments();
         final String key = "a";
 
         bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
@@ -327,13 +313,7 @@ public class RocksDBSegmentedBytesStoreTest {
         final File oldStyleName = new File(parent, nameParts[0] + ":" + Long.parseLong(nameParts[1]));
         assertTrue(new File(parent, firstSegmentName).renameTo(oldStyleName));
 
-        bytesStore = new RocksDBSegmentedBytesStore(
-            storeName,
-            "metrics-scope",
-            retention,
-            segmentInterval,
-            schema
-        );
+        bytesStore = getBytesStore();
 
         bytesStore.init(context, bytesStore);
         final List<KeyValue<Windowed<String>, Long>> results = toList(bytesStore.fetch(Bytes.wrap(key.getBytes()), 0L, 60_000L));
@@ -348,7 +328,6 @@ public class RocksDBSegmentedBytesStoreTest {
         );
     }
 
-
     @Test
     public void shouldBeAbleToWriteToReInitializedStore() {
         final String key = "a";
@@ -365,7 +344,7 @@ public class RocksDBSegmentedBytesStoreTest {
         final Collection<KeyValue<byte[], byte[]>> records = new ArrayList<>();
         records.add(new KeyValue<>(serializeKey(new Windowed<>(key, windows[0])).get(), serializeValue(50L)));
         records.add(new KeyValue<>(serializeKey(new Windowed<>(key, windows[3])).get(), serializeValue(100L)));
-        final Map<KeyValueSegment, WriteBatch> writeBatchMap = bytesStore.getWriteBatches(records);
+        final Map<S, WriteBatch> writeBatchMap = bytesStore.getWriteBatches(records);
         assertEquals(2, writeBatchMap.size());
         for (final WriteBatch batch : writeBatchMap.values()) {
             assertEquals(1, batch.count());
@@ -386,8 +365,8 @@ public class RocksDBSegmentedBytesStoreTest {
         assertEquals(2, bytesStore.getSegments().size());
 
         // Bulk loading is enabled during recovery.
-        for (final KeyValueSegment segment : bytesStore.getSegments()) {
-            assertThat(segment.getOptions().level0FileNumCompactionTrigger(), equalTo(1 << 30));
+        for (final S segment : bytesStore.getSegments()) {
+            assertThat(getOptions(segment).level0FileNumCompactionTrigger(), equalTo(1 << 30));
         }
 
         final List<KeyValue<Windowed<String>, Long>> expected = new ArrayList<>();
@@ -410,19 +389,19 @@ public class RocksDBSegmentedBytesStoreTest {
 
         restoreListener.onRestoreStart(null, bytesStore.name(), 0L, 0L);
 
-        for (final KeyValueSegment segment : bytesStore.getSegments()) {
-            assertThat(segment.getOptions().level0FileNumCompactionTrigger(), equalTo(1 << 30));
+        for (final S segment : bytesStore.getSegments()) {
+            assertThat(getOptions(segment).level0FileNumCompactionTrigger(), equalTo(1 << 30));
         }
 
         restoreListener.onRestoreEnd(null, bytesStore.name(), 0L);
-        for (final KeyValueSegment segment : bytesStore.getSegments()) {
-            assertThat(segment.getOptions().level0FileNumCompactionTrigger(), equalTo(4));
+        for (final S segment : bytesStore.getSegments()) {
+            assertThat(getOptions(segment).level0FileNumCompactionTrigger(), equalTo(4));
         }
     }
 
     @Test
     public void shouldLogAndMeasureExpiredRecords() {
-        LogCaptureAppender.setClassLoggerToDebug(RocksDBSegmentedBytesStore.class);
+        LogCaptureAppender.setClassLoggerToDebug(AbstractRocksDBSegmentedBytesStore.class);
         final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
 
         // write a record to advance stream time, with a high enough timestamp
@@ -471,10 +450,6 @@ public class RocksDBSegmentedBytesStoreTest {
         return Utils.mkSet(Objects.requireNonNull(windowDir.list()));
     }
 
-    private byte[] serializeValue(final long value) {
-        return Serdes.Long().serializer().serialize("", value);
-    }
-
     private Bytes serializeKey(final Windowed<String> key) {
         final StateSerdes<String, Long> stateSerdes = StateSerdes.withBuiltinTypes("dummy", String.class, Long.class);
         if (schema instanceof SessionKeySchema) {
@@ -484,6 +459,10 @@ public class RocksDBSegmentedBytesStoreTest {
         }
     }
 
+    private byte[] serializeValue(final long value) {
+        return Serdes.Long().serializer().serialize("", value);
+    }
+
     private List<KeyValue<Windowed<String>, Long>> toList(final KeyValueIterator<Bytes, byte[]> iterator) {
         final List<KeyValue<Windowed<String>, Long>> results = new ArrayList<>();
         final StateSerdes<String, Long> stateSerdes = StateSerdes.withBuiltinTypes("dummy", String.class, Long.class);
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentTest.java
new file mode 100644
index 0000000..55654b9
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.test.TestUtils;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.HashSet;
+import java.util.Set;
+
+import static java.util.Collections.emptyMap;
+import static org.easymock.EasyMock.expect;
+import static org.easymock.EasyMock.mock;
+import static org.easymock.EasyMock.replay;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.not;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/** Tests for {@link KeyValueSegment}: on-disk directory lifecycle and id-based equals/hashCode/compareTo. */
+public class KeyValueSegmentTest {
+
+    @Test
+    public void shouldDeleteStateDirectoryOnDestroy() throws Exception {
+        final KeyValueSegment segment = new KeyValueSegment("segment", "window", 0L);
+        final File directory = TestUtils.tempDirectory();
+        final File windowDir = new File(directory, "window");
+        final File segmentDir = new File(windowDir, "segment");
+        final ProcessorContext mockContext = mock(ProcessorContext.class);
+        expect(mockContext.appConfigs()).andReturn(emptyMap());
+        expect(mockContext.stateDir()).andReturn(directory);
+        replay(mockContext);
+
+        segment.openDB(mockContext);
+        assertTrue(windowDir.exists());
+        assertTrue(segmentDir.isDirectory()); // also guarantees list() below does not return null
+        assertTrue(segmentDir.list().length > 0);
+        segment.destroy();
+        assertFalse(segmentDir.exists());
+        assertTrue(windowDir.exists());
+    }
+
+    @Test
+    public void shouldBeEqualIfIdIsEqual() {
+        final KeyValueSegment segment = new KeyValueSegment("anyName", "anyName", 0L);
+        final KeyValueSegment segmentSameId = new KeyValueSegment("someOtherName", "someOtherName", 0L);
+        final KeyValueSegment segmentDifferentId = new KeyValueSegment("anyName", "anyName", 1L);
+
+        assertThat(segment, equalTo(segment));
+        assertThat(segment, equalTo(segmentSameId));
+        assertThat(segment, not(equalTo(segmentDifferentId)));
+        assertThat(segment, not(equalTo(null)));
+        assertThat(segment, not(equalTo("anyName")));
+    }
+
+    @Test
+    public void shouldHashOnSegmentIdOnly() {
+        final KeyValueSegment segment = new KeyValueSegment("anyName", "anyName", 0L);
+        final KeyValueSegment segmentSameId = new KeyValueSegment("someOtherName", "someOtherName", 0L);
+        final KeyValueSegment segmentDifferentId = new KeyValueSegment("anyName", "anyName", 1L);
+
+        final Set<KeyValueSegment> set = new HashSet<>();
+        assertTrue(set.add(segment));
+        assertFalse(set.add(segmentSameId));
+        assertTrue(set.add(segmentDifferentId));
+    }
+
+    @Test
+    public void shouldCompareSegmentIdOnly() {
+        final KeyValueSegment segment1 = new KeyValueSegment("a", "C", 50L);
+        final KeyValueSegment segment2 = new KeyValueSegment("b", "B", 100L);
+        final KeyValueSegment segment3 = new KeyValueSegment("c", "A", 0L);
+        // the Comparable contract only fixes the sign of compareTo, so assert on signum
+        assertThat(segment1.compareTo(segment1), equalTo(0));
+        assertThat(Integer.signum(segment1.compareTo(segment2)), equalTo(-1));
+        assertThat(Integer.signum(segment2.compareTo(segment1)), equalTo(1));
+        assertThat(Integer.signum(segment1.compareTo(segment3)), equalTo(1));
+        assertThat(Integer.signum(segment3.compareTo(segment1)), equalTo(-1));
+        assertThat(Integer.signum(segment2.compareTo(segment3)), equalTo(1));
+        assertThat(Integer.signum(segment3.compareTo(segment2)), equalTo(-1));
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStoreTest.java
index d0dd133..7511677 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStoreTest.java
@@ -16,493 +16,28 @@
  */
 package org.apache.kafka.streams.state.internals;
 
-import org.apache.kafka.common.Metric;
-import org.apache.kafka.common.MetricName;
-import org.apache.kafka.common.metrics.Metrics;
-import org.apache.kafka.common.serialization.Serdes;
-import org.apache.kafka.common.utils.Bytes;
-import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.common.utils.Utils;
-import org.apache.kafka.streams.KeyValue;
-import org.apache.kafka.streams.kstream.Window;
-import org.apache.kafka.streams.kstream.Windowed;
-import org.apache.kafka.streams.kstream.internals.SessionWindow;
-import org.apache.kafka.streams.processor.StateRestoreListener;
-import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
-import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
-import org.apache.kafka.streams.state.KeyValueIterator;
-import org.apache.kafka.streams.state.StateSerdes;
-import org.apache.kafka.test.InternalMockProcessorContext;
-import org.apache.kafka.test.NoOpRecordCollector;
-import org.apache.kafka.test.TestUtils;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
-import org.junit.runners.Parameterized.Parameter;
-import org.junit.runners.Parameterized.Parameters;
-import org.rocksdb.WriteBatch;
+import org.rocksdb.Options;
 
-import java.io.File;
-import java.text.SimpleDateFormat;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Date;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
-import java.util.SimpleTimeZone;
+public class RocksDBSegmentedBytesStoreTest extends AbstractRocksDBSegmentedBytesStoreTest<KeyValueSegment> {
 
-import static org.apache.kafka.common.utils.Utils.mkEntry;
-import static org.apache.kafka.common.utils.Utils.mkMap;
-import static org.apache.kafka.streams.state.internals.WindowKeySchema.timeWindowForSize;
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.CoreMatchers.hasItem;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotEquals;
-import static org.junit.Assert.assertTrue;
-
-
-@RunWith(Parameterized.class)
-public class RocksDBSegmentedBytesStoreTest {
-
-    private final long windowSizeForTimeWindow = 500;
-    private final long retention = 1000;
-    private final long segmentInterval = 60_000L;
-    private InternalMockProcessorContext context;
-    private final String storeName = "bytes-store";
-    private RocksDBSegmentedBytesStore bytesStore;
-    private File stateDir;
-    private final Window[] windows = new Window[4];
-    private Window nextSegmentWindow;
-
-    @Parameter
-    public SegmentedBytesStore.KeySchema schema;
-
-    @Parameters(name = "{0}")
-    public static Object[] getKeySchemas() {
-        return new Object[] {new SessionKeySchema(), new WindowKeySchema()};
-    }
-
-    @Before
-    public void before() {
-
-        if (schema instanceof SessionKeySchema) {
-            windows[0] = new SessionWindow(10L, 10L);
-            windows[1] = new SessionWindow(500L, 1000L);
-            windows[2] = new SessionWindow(1_000L, 1_500L);
-            windows[3] = new SessionWindow(30_000L, 60_000L);
-            // All four of the previous windows will go into segment 1.
-            // The nextSegmentWindow is computed be a high enough time that when it gets written
-            // to the segment store, it will advance stream time past the first segment's retention time and
-            // expire it.
-            nextSegmentWindow = new SessionWindow(segmentInterval + retention, segmentInterval + retention);
-        }
-        if (schema instanceof WindowKeySchema) {
-            windows[0] = timeWindowForSize(10L, windowSizeForTimeWindow);
-            windows[1] = timeWindowForSize(500L, windowSizeForTimeWindow);
-            windows[2] = timeWindowForSize(1_000L, windowSizeForTimeWindow);
-            windows[3] = timeWindowForSize(60_000L, windowSizeForTimeWindow);
-            // All four of the previous windows will go into segment 1.
-            // The nextSegmentWindow is computed be a high enough time that when it gets written
-            // to the segment store, it will advance stream time past the first segment's retention time and
-            // expire it.
-            nextSegmentWindow = timeWindowForSize(segmentInterval + retention, windowSizeForTimeWindow);
-        }
-
-
-        bytesStore = new RocksDBSegmentedBytesStore(
-            storeName,
-            "metrics-scope",
-            retention,
-            segmentInterval,
-            schema
-        );
-
-        stateDir = TestUtils.tempDirectory();
-        context = new InternalMockProcessorContext(
-            stateDir,
-            Serdes.String(),
-            Serdes.Long(),
-            new NoOpRecordCollector(),
-            new ThreadCache(new LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics()))
-        );
-        bytesStore.init(context, bytesStore);
-    }
-
-    @After
-    public void close() {
-        bytesStore.close();
-    }
-
-    @Test
-    public void shouldPutAndFetch() {
-        final String key = "a";
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(10));
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[1])), serializeValue(50));
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[2])), serializeValue(100));
-
-        final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(Bytes.wrap(key.getBytes()), 0, 500);
-
-        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
-            KeyValue.pair(new Windowed<>(key, windows[0]), 10L),
-            KeyValue.pair(new Windowed<>(key, windows[1]), 50L)
-        );
-
-        assertEquals(expected, toList(values));
-    }
-
-    @Test
-    public void shouldFindValuesWithinRange() {
-        final String key = "a";
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(10));
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[1])), serializeValue(50));
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[2])), serializeValue(100));
-        final KeyValueIterator<Bytes, byte[]> results = bytesStore.fetch(Bytes.wrap(key.getBytes()), 1, 999);
-        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
-            KeyValue.pair(new Windowed<>(key, windows[0]), 10L),
-            KeyValue.pair(new Windowed<>(key, windows[1]), 50L)
-        );
-
-        assertEquals(expected, toList(results));
-    }
-
-
-    @Test
-    public void shouldRemove() {
-        bytesStore.put(serializeKey(new Windowed<>("a", windows[0])), serializeValue(30));
-        bytesStore.put(serializeKey(new Windowed<>("a", windows[1])), serializeValue(50));
-
-        bytesStore.remove(serializeKey(new Windowed<>("a", windows[0])));
-        final KeyValueIterator<Bytes, byte[]> value = bytesStore.fetch(Bytes.wrap("a".getBytes()), 0, 100);
-        assertFalse(value.hasNext());
-    }
-
-
-    @Test
-    public void shouldRollSegments() {
-        // just to validate directories
-        final KeyValueSegments segments = new KeyValueSegments(storeName, retention, segmentInterval);
-        final String key = "a";
-
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50));
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[1])), serializeValue(100));
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[2])), serializeValue(500));
-        assertEquals(Collections.singleton(segments.segmentName(0)), segmentDirs());
-
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(1000));
-        assertEquals(Utils.mkSet(segments.segmentName(0), segments.segmentName(1)), segmentDirs());
-
-        final List<KeyValue<Windowed<String>, Long>> results = toList(bytesStore.fetch(Bytes.wrap(key.getBytes()), 0, 1500));
-
-        assertEquals(
-            Arrays.asList(
-                KeyValue.pair(new Windowed<>(key, windows[0]), 50L),
-                KeyValue.pair(new Windowed<>(key, windows[1]), 100L),
-                KeyValue.pair(new Windowed<>(key, windows[2]), 500L)
-            ),
-            results
-        );
-
-    }
-
-
-    @Test
-    public void shouldGetAllSegments() {
-        // just to validate directories
-        final KeyValueSegments segments = new KeyValueSegments(storeName, retention, segmentInterval);
-        final String key = "a";
-
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
-        assertEquals(Collections.singleton(segments.segmentName(0)), segmentDirs());
-
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L));
-        assertEquals(
-            Utils.mkSet(
-                segments.segmentName(0),
-                segments.segmentName(1)
-            ),
-            segmentDirs()
-        );
-
-        final List<KeyValue<Windowed<String>, Long>> results = toList(bytesStore.all());
-        assertEquals(
-            Arrays.asList(
-                KeyValue.pair(new Windowed<>(key, windows[0]), 50L),
-                KeyValue.pair(new Windowed<>(key, windows[3]), 100L)
-            ),
-            results
-        );
-
-    }
-
-    @Test
-    public void shouldFetchAllSegments() {
-        // just to validate directories
-        final KeyValueSegments segments = new KeyValueSegments(storeName, retention, segmentInterval);
-        final String key = "a";
-
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
-        assertEquals(Collections.singleton(segments.segmentName(0)), segmentDirs());
-
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L));
-        assertEquals(
-            Utils.mkSet(
-                segments.segmentName(0),
-                segments.segmentName(1)
-            ),
-            segmentDirs()
-        );
-
-        final List<KeyValue<Windowed<String>, Long>> results = toList(bytesStore.fetchAll(0L, 60_000L));
-        assertEquals(
-            Arrays.asList(
-                KeyValue.pair(new Windowed<>(key, windows[0]), 50L),
-                KeyValue.pair(new Windowed<>(key, windows[3]), 100L)
-            ),
-            results
-        );
-
-    }
-
-    @Test
-    public void shouldLoadSegmentsWithOldStyleDateFormattedName() {
-        final KeyValueSegments segments = new KeyValueSegments(storeName, retention, segmentInterval);
-        final String key = "a";
-
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L));
-        bytesStore.close();
-
-        final String firstSegmentName = segments.segmentName(0);
-        final String[] nameParts = firstSegmentName.split("\\.");
-        final Long segmentId = Long.parseLong(nameParts[1]);
-        final SimpleDateFormat formatter = new SimpleDateFormat("yyyyMMddHHmm");
-        formatter.setTimeZone(new SimpleTimeZone(0, "UTC"));
-        final String formatted = formatter.format(new Date(segmentId * segmentInterval));
-        final File parent = new File(stateDir, storeName);
-        final File oldStyleName = new File(parent, nameParts[0] + "-" + formatted);
-        assertTrue(new File(parent, firstSegmentName).renameTo(oldStyleName));
-
-        bytesStore = new RocksDBSegmentedBytesStore(
-            storeName,
-            "metrics-scope",
-            retention,
-            segmentInterval,
-            schema
-        );
-
-        bytesStore.init(context, bytesStore);
-        final List<KeyValue<Windowed<String>, Long>> results = toList(bytesStore.fetch(Bytes.wrap(key.getBytes()), 0L, 60_000L));
-        assertThat(
-            results,
-            equalTo(
-                Arrays.asList(
-                    KeyValue.pair(new Windowed<>(key, windows[0]), 50L),
-                    KeyValue.pair(new Windowed<>(key, windows[3]), 100L)
-                )
-            )
-        );
-    }
-
-
-    @Test
-    public void shouldLoadSegmentsWithOldStyleColonFormattedName() {
-        final KeyValueSegments segments = new KeyValueSegments(storeName, retention, segmentInterval);
-        final String key = "a";
-
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L));
-        bytesStore.close();
-
-        final String firstSegmentName = segments.segmentName(0);
-        final String[] nameParts = firstSegmentName.split("\\.");
-        final File parent = new File(stateDir, storeName);
-        final File oldStyleName = new File(parent, nameParts[0] + ":" + Long.parseLong(nameParts[1]));
-        assertTrue(new File(parent, firstSegmentName).renameTo(oldStyleName));
-
-        bytesStore = new RocksDBSegmentedBytesStore(
+    @Override
+    RocksDBSegmentedBytesStore getBytesStore() {
+        return new RocksDBSegmentedBytesStore(
             storeName,
             "metrics-scope",
             retention,
             segmentInterval,
             schema
         );
-
-        bytesStore.init(context, bytesStore);
-        final List<KeyValue<Windowed<String>, Long>> results = toList(bytesStore.fetch(Bytes.wrap(key.getBytes()), 0L, 60_000L));
-        assertThat(
-            results,
-            equalTo(
-                Arrays.asList(
-                    KeyValue.pair(new Windowed<>(key, windows[0]), 50L),
-                    KeyValue.pair(new Windowed<>(key, windows[3]), 100L)
-                )
-            )
-        );
-    }
-
-
-    @Test
-    public void shouldBeAbleToWriteToReInitializedStore() {
-        final String key = "a";
-        // need to create a segment so we can attempt to write to it again.
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50));
-        bytesStore.close();
-        bytesStore.init(context, bytesStore);
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[1])), serializeValue(100));
-    }
-
-    @Test
-    public void shouldCreateWriteBatches() {
-        final String key = "a";
-        final Collection<KeyValue<byte[], byte[]>> records = new ArrayList<>();
-        records.add(new KeyValue<>(serializeKey(new Windowed<>(key, windows[0])).get(), serializeValue(50L)));
-        records.add(new KeyValue<>(serializeKey(new Windowed<>(key, windows[3])).get(), serializeValue(100L)));
-        final Map<KeyValueSegment, WriteBatch> writeBatchMap = bytesStore.getWriteBatches(records);
-        assertEquals(2, writeBatchMap.size());
-        for (final WriteBatch batch : writeBatchMap.values()) {
-            assertEquals(1, batch.count());
-        }
-    }
-
-    @Test
-    public void shouldRestoreToByteStore() {
-        // 0 segments initially.
-        assertEquals(0, bytesStore.getSegments().size());
-        final String key = "a";
-        final Collection<KeyValue<byte[], byte[]>> records = new ArrayList<>();
-        records.add(new KeyValue<>(serializeKey(new Windowed<>(key, windows[0])).get(), serializeValue(50L)));
-        records.add(new KeyValue<>(serializeKey(new Windowed<>(key, windows[3])).get(), serializeValue(100L)));
-        bytesStore.restoreAllInternal(records);
-
-        // 2 segments are created during restoration.
-        assertEquals(2, bytesStore.getSegments().size());
-
-        // Bulk loading is enabled during recovery.
-        for (final KeyValueSegment segment : bytesStore.getSegments()) {
-            assertThat(segment.getOptions().level0FileNumCompactionTrigger(), equalTo(1 << 30));
-        }
-
-        final List<KeyValue<Windowed<String>, Long>> expected = new ArrayList<>();
-        expected.add(new KeyValue<>(new Windowed<>(key, windows[0]), 50L));
-        expected.add(new KeyValue<>(new Windowed<>(key, windows[3]), 100L));
-
-        final List<KeyValue<Windowed<String>, Long>> results = toList(bytesStore.all());
-        assertEquals(expected, results);
-    }
-
-    @Test
-    public void shouldRespectBulkLoadOptionsDuringInit() {
-        bytesStore.init(context, bytesStore);
-        final String key = "a";
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L));
-        bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L));
-        assertEquals(2, bytesStore.getSegments().size());
-
-        final StateRestoreListener restoreListener = context.getRestoreListener(bytesStore.name());
-
-        restoreListener.onRestoreStart(null, bytesStore.name(), 0L, 0L);
-
-        for (final KeyValueSegment segment : bytesStore.getSegments()) {
-            assertThat(segment.getOptions().level0FileNumCompactionTrigger(), equalTo(1 << 30));
-        }
-
-        restoreListener.onRestoreEnd(null, bytesStore.name(), 0L);
-        for (final KeyValueSegment segment : bytesStore.getSegments()) {
-            assertThat(segment.getOptions().level0FileNumCompactionTrigger(), equalTo(4));
-        }
-    }
-
-    @Test
-    public void shouldLogAndMeasureExpiredRecords() {
-        LogCaptureAppender.setClassLoggerToDebug(RocksDBSegmentedBytesStore.class);
-        final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
-
-        // write a record to advance stream time, with a high enough timestamp
-        // that the subsequent record in windows[0] will already be expired.
-        bytesStore.put(serializeKey(new Windowed<>("dummy", nextSegmentWindow)), serializeValue(0));
-
-        final Bytes key = serializeKey(new Windowed<>("a", windows[0]));
-        final byte[] value = serializeValue(5);
-        bytesStore.put(key, value);
-
-        LogCaptureAppender.unregister(appender);
-
-        final Map<MetricName, ? extends Metric> metrics = context.metrics().metrics();
-
-        final Metric dropTotal = metrics.get(new MetricName(
-            "expired-window-record-drop-total",
-            "stream-metrics-scope-metrics",
-            "The total number of occurrence of expired-window-record-drop operations.",
-            mkMap(
-                mkEntry("client-id", "mock"),
-                mkEntry("task-id", "0_0"),
-                mkEntry("metrics-scope-id", "bytes-store")
-            )
-        ));
-
-        final Metric dropRate = metrics.get(new MetricName(
-            "expired-window-record-drop-rate",
-            "stream-metrics-scope-metrics",
-            "The average number of occurrence of expired-window-record-drop operation per second.",
-            mkMap(
-                mkEntry("client-id", "mock"),
-                mkEntry("task-id", "0_0"),
-                mkEntry("metrics-scope-id", "bytes-store")
-            )
-        ));
-
-        assertEquals(1.0, dropTotal.metricValue());
-        assertNotEquals(0.0, dropRate.metricValue());
-        final List<String> messages = appender.getMessages();
-        assertThat(messages, hasItem("Skipping record for expired segment."));
-    }
-
-    private Set<String> segmentDirs() {
-        final File windowDir = new File(stateDir, storeName);
-
-        return Utils.mkSet(Objects.requireNonNull(windowDir.list()));
-    }
-
-    private byte[] serializeValue(final long value) {
-        return Serdes.Long().serializer().serialize("", value);
     }
 
-    private Bytes serializeKey(final Windowed<String> key) {
-        final StateSerdes<String, Long> stateSerdes = StateSerdes.withBuiltinTypes("dummy", String.class, Long.class);
-        if (schema instanceof SessionKeySchema) {
-            return Bytes.wrap(SessionKeySchema.toBinary(key, stateSerdes.keySerializer(), "dummy"));
-        } else {
-            return WindowKeySchema.toStoreKeyBinary(key, 0, stateSerdes);
-        }
+    @Override
+    KeyValueSegments newSegments() {
+        return new KeyValueSegments(storeName, retention, segmentInterval);
     }
 
-    private List<KeyValue<Windowed<String>, Long>> toList(final KeyValueIterator<Bytes, byte[]> iterator) {
-        final List<KeyValue<Windowed<String>, Long>> results = new ArrayList<>();
-        final StateSerdes<String, Long> stateSerdes = StateSerdes.withBuiltinTypes("dummy", String.class, Long.class);
-        while (iterator.hasNext()) {
-            final KeyValue<Bytes, byte[]> next = iterator.next();
-            if (schema instanceof WindowKeySchema) {
-                final KeyValue<Windowed<String>, Long> deserialized = KeyValue.pair(
-                    WindowKeySchema.fromStoreKey(next.key.get(), windowSizeForTimeWindow, stateSerdes.keyDeserializer(), stateSerdes.topic()),
-                    stateSerdes.valueDeserializer().deserialize("dummy", next.value)
-                );
-                results.add(deserialized);
-            } else {
-                final KeyValue<Windowed<String>, Long> deserialized = KeyValue.pair(
-                    SessionKeySchema.from(next.key.get(), stateSerdes.keyDeserializer(), "dummy"),
-                    stateSerdes.valueDeserializer().deserialize("dummy", next.value)
-                );
-                results.add(deserialized);
-            }
-        }
-        return results;
+    @Override
+    Options getOptions(final KeyValueSegment segment) {
+        return segment.getOptions();
     }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStoreTest.java
similarity index 56%
copy from streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java
copy to streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStoreTest.java
index 8687ffc..5ab0482 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStoreTest.java
@@ -16,17 +16,28 @@
  */
 package org.apache.kafka.streams.state.internals;
 
-import org.apache.kafka.common.utils.Bytes;
-import org.apache.kafka.streams.processor.StateStore;
-import org.apache.kafka.streams.state.KeyValueIterator;
+import org.rocksdb.Options;
 
-import java.io.IOException;
+public class RocksDBTimestampedSegmentedBytesStoreTest
+    extends AbstractRocksDBSegmentedBytesStoreTest<TimestampedSegment> {
 
-public interface Segment extends StateStore {
+    RocksDBTimestampedSegmentedBytesStore getBytesStore() {
+        return new RocksDBTimestampedSegmentedBytesStore(
+            storeName,
+            "metrics-scope",
+            retention,
+            segmentInterval,
+            schema
+        );
+    }
 
-    void destroy() throws IOException;
+    @Override
+    TimestampedSegments newSegments() {
+        return new TimestampedSegments(storeName, retention, segmentInterval);
+    }
 
-    KeyValueIterator<Bytes, byte[]> all();
-
-    KeyValueIterator<Bytes, byte[]> range(final Bytes from, final Bytes to);
+    @Override
+    Options getOptions(final TimestampedSegment segment) {
+        return segment.getOptions();
+    }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentIteratorTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/SegmentIteratorTest.java
similarity index 99%
rename from streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentIteratorTest.java
rename to streams/src/test/java/org/apache/kafka/streams/state/internals/SegmentIteratorTest.java
index 68bd815..3c64bad 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentIteratorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/SegmentIteratorTest.java
@@ -38,7 +38,7 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
-public class KeyValueSegmentIteratorTest {
+public class SegmentIteratorTest {
 
     private final KeyValueSegment segmentOne = new KeyValueSegment("one", "one", 0);
     private final KeyValueSegment segmentTwo = new KeyValueSegment("two", "window", 1);
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedSegmentTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedSegmentTest.java
new file mode 100644
index 0000000..10ed56e
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedSegmentTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.test.TestUtils;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.HashSet;
+import java.util.Set;
+
+import static java.util.Collections.emptyMap;
+import static org.easymock.EasyMock.expect;
+import static org.easymock.EasyMock.mock;
+import static org.easymock.EasyMock.replay;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.not;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class TimestampedSegmentTest {
+
+    @Test
+    public void shouldDeleteStateDirectoryOnDestroy() throws Exception {
+        final TimestampedSegment segment = new TimestampedSegment("segment", "window", 0L);
+        final String directoryPath = TestUtils.tempDirectory().getAbsolutePath();
+        final File directory = new File(directoryPath);
+
+        final ProcessorContext mockContext = mock(ProcessorContext.class);
+        expect(mockContext.appConfigs()).andReturn(emptyMap());
+        expect(mockContext.stateDir()).andReturn(directory);
+        replay(mockContext);
+
+        segment.openDB(mockContext);
+
+        assertTrue(new File(directoryPath, "window").exists());
+        assertTrue(new File(directoryPath + File.separator + "window", "segment").exists());
+        assertTrue(new File(directoryPath + File.separator + "window", "segment").list().length > 0);
+        segment.destroy();
+        assertFalse(new File(directoryPath + File.separator + "window", "segment").exists());
+        assertTrue(new File(directoryPath, "window").exists());
+    }
+
+    @Test
+    public void shouldBeEqualIfIdIsEqual() {
+        final TimestampedSegment segment = new TimestampedSegment("anyName", "anyName", 0L);
+        final TimestampedSegment segmentSameId = new TimestampedSegment("someOtherName", "someOtherName", 0L);
+        final TimestampedSegment segmentDifferentId = new TimestampedSegment("anyName", "anyName", 1L);
+
+        assertThat(segment, equalTo(segment));
+        assertThat(segment, equalTo(segmentSameId));
+        assertThat(segment, not(equalTo(segmentDifferentId)));
+        assertThat(segment, not(equalTo(null)));
+        assertThat(segment, not(equalTo("anyName")));
+    }
+
+    @Test
+    public void shouldHashOnSegmentIdOnly() {
+        final TimestampedSegment segment = new TimestampedSegment("anyName", "anyName", 0L);
+        final TimestampedSegment segmentSameId = new TimestampedSegment("someOtherName", "someOtherName", 0L);
+        final TimestampedSegment segmentDifferentId = new TimestampedSegment("anyName", "anyName", 1L);
+
+        final Set<TimestampedSegment> set = new HashSet<>();
+        assertTrue(set.add(segment));
+        assertFalse(set.add(segmentSameId));
+        assertTrue(set.add(segmentDifferentId));
+    }
+
+    @Test
+    public void shouldCompareSegmentIdOnly() {
+        final TimestampedSegment segment1 = new TimestampedSegment("a", "C", 50L);
+        final TimestampedSegment segment2 = new TimestampedSegment("b", "B", 100L);
+        final TimestampedSegment segment3 = new TimestampedSegment("c", "A", 0L);
+
+        assertThat(segment1.compareTo(segment1), equalTo(0));
+        assertThat(segment1.compareTo(segment2), equalTo(-1));
+        assertThat(segment2.compareTo(segment1), equalTo(1));
+        assertThat(segment1.compareTo(segment3), equalTo(1));
+        assertThat(segment3.compareTo(segment1), equalTo(-1));
+        assertThat(segment2.compareTo(segment3), equalTo(1));
+        assertThat(segment3.compareTo(segment2), equalTo(-1));
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimetampedSegmentsTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimetampedSegmentsTest.java
new file mode 100644
index 0000000..c519887
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimetampedSegmentsTest.java
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
+import org.apache.kafka.test.InternalMockProcessorContext;
+import org.apache.kafka.test.NoOpRecordCollector;
+import org.apache.kafka.test.TestUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.File;
+import java.text.SimpleDateFormat;
+import java.util.Date;
+import java.util.List;
+import java.util.SimpleTimeZone;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.CoreMatchers.nullValue;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for {@link TimestampedSegments}: segment id/name computation, creation, expiry,
+ * range lookup, rolling, reopening, and migration of legacy segment directory names.
+ */
+public class TimetampedSegmentsTest {
+
+    private static final int NUM_SEGMENTS = 5;
+    private static final long SEGMENT_INTERVAL = 100L;
+    private static final long RETENTION_PERIOD = 4 * SEGMENT_INTERVAL;
+    private InternalMockProcessorContext context;
+    private TimestampedSegments segments;
+    private File stateDirectory;
+    private final String storeName = "test";
+
+    @Before
+    public void createContext() {
+        stateDirectory = TestUtils.tempDirectory();
+        context = new InternalMockProcessorContext(
+            stateDirectory,
+            Serdes.String(),
+            Serdes.Long(),
+            new NoOpRecordCollector(),
+            new ThreadCache(new LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics()))
+        );
+        segments = new TimestampedSegments(storeName, RETENTION_PERIOD, SEGMENT_INTERVAL);
+    }
+
+    @After
+    public void close() {
+        segments.close();
+    }
+
+    @Test
+    public void shouldGetSegmentIdsFromTimestamp() {
+        assertEquals(0, segments.segmentId(0));
+        assertEquals(1, segments.segmentId(SEGMENT_INTERVAL));
+        assertEquals(2, segments.segmentId(2 * SEGMENT_INTERVAL));
+        assertEquals(3, segments.segmentId(3 * SEGMENT_INTERVAL));
+    }
+
+    @Test
+    public void shouldBaseSegmentIntervalOnRetentionAndNumSegments() {
+        // This suite covers TimestampedSegments, so the custom instance must be of that type too.
+        // A dedicated local instance is used so the fixture's interval is unaffected; it is
+        // closed here because @After only closes the fixture field.
+        final TimestampedSegments customSegments =
+            new TimestampedSegments("test", 8 * SEGMENT_INTERVAL, 2 * SEGMENT_INTERVAL);
+        try {
+            assertEquals(0, customSegments.segmentId(0));
+            assertEquals(0, customSegments.segmentId(SEGMENT_INTERVAL));
+            assertEquals(1, customSegments.segmentId(2 * SEGMENT_INTERVAL));
+        } finally {
+            customSegments.close();
+        }
+    }
+
+    @Test
+    public void shouldGetSegmentNameFromId() {
+        assertEquals("test.0", segments.segmentName(0));
+        assertEquals("test." + SEGMENT_INTERVAL, segments.segmentName(1));
+        assertEquals("test." + 2 * SEGMENT_INTERVAL, segments.segmentName(2));
+    }
+
+    @Test
+    public void shouldCreateSegments() {
+        final TimestampedSegment segment1 = segments.getOrCreateSegmentIfLive(0, context, -1L);
+        final TimestampedSegment segment2 = segments.getOrCreateSegmentIfLive(1, context, -1L);
+        final TimestampedSegment segment3 = segments.getOrCreateSegmentIfLive(2, context, -1L);
+        assertTrue(new File(context.stateDir(), "test/test.0").isDirectory());
+        assertTrue(new File(context.stateDir(), "test/test." + SEGMENT_INTERVAL).isDirectory());
+        assertTrue(new File(context.stateDir(), "test/test." + 2 * SEGMENT_INTERVAL).isDirectory());
+        assertTrue(segment1.isOpen());
+        assertTrue(segment2.isOpen());
+        assertTrue(segment3.isOpen());
+    }
+
+    @Test
+    public void shouldNotCreateSegmentThatIsAlreadyExpired() {
+        final long streamTime = updateStreamTimeAndCreateSegment(7);
+        assertNull(segments.getOrCreateSegmentIfLive(0, context, streamTime));
+        assertFalse(new File(context.stateDir(), "test/test.0").exists());
+    }
+
+    @Test
+    public void shouldCleanupSegmentsThatHaveExpired() {
+        final TimestampedSegment segment1 = segments.getOrCreateSegmentIfLive(0, context, -1L);
+        final TimestampedSegment segment2 = segments.getOrCreateSegmentIfLive(1, context, -1L);
+        final TimestampedSegment segment3 = segments.getOrCreateSegmentIfLive(7, context, SEGMENT_INTERVAL * 7L);
+        assertFalse(segment1.isOpen());
+        assertFalse(segment2.isOpen());
+        assertTrue(segment3.isOpen());
+        assertFalse(new File(context.stateDir(), "test/test.0").exists());
+        assertFalse(new File(context.stateDir(), "test/test." + SEGMENT_INTERVAL).exists());
+        assertTrue(new File(context.stateDir(), "test/test." + 7 * SEGMENT_INTERVAL).exists());
+    }
+
+    @Test
+    public void shouldGetSegmentForTimestamp() {
+        final TimestampedSegment segment = segments.getOrCreateSegmentIfLive(0, context, -1L);
+        segments.getOrCreateSegmentIfLive(1, context, -1L);
+        assertEquals(segment, segments.getSegmentForTimestamp(0L));
+    }
+
+    @Test
+    public void shouldGetCorrectSegmentString() {
+        final TimestampedSegment segment = segments.getOrCreateSegmentIfLive(0, context, -1L);
+        assertEquals("TimestampedSegment(id=0, name=test.0)", segment.toString());
+    }
+
+    @Test
+    public void shouldCloseAllOpenSegments() {
+        final TimestampedSegment first = segments.getOrCreateSegmentIfLive(0, context, -1L);
+        final TimestampedSegment second = segments.getOrCreateSegmentIfLive(1, context, -1L);
+        final TimestampedSegment third = segments.getOrCreateSegmentIfLive(2, context, -1L);
+        segments.close();
+
+        assertFalse(first.isOpen());
+        assertFalse(second.isOpen());
+        assertFalse(third.isOpen());
+    }
+
+    @Test
+    public void shouldOpenExistingSegments() {
+        // release the fixture-created instance before replacing the field
+        segments.close();
+        segments = new TimestampedSegments("test", 4, 1);
+        segments.getOrCreateSegmentIfLive(0, context, -1L);
+        segments.getOrCreateSegmentIfLive(1, context, -1L);
+        segments.getOrCreateSegmentIfLive(2, context, -1L);
+        segments.getOrCreateSegmentIfLive(3, context, -1L);
+        segments.getOrCreateSegmentIfLive(4, context, -1L);
+        // close existing.
+        segments.close();
+
+        segments = new TimestampedSegments("test", 4, 1);
+        segments.openExisting(context, -1L);
+
+        assertTrue(segments.getSegmentForTimestamp(0).isOpen());
+        assertTrue(segments.getSegmentForTimestamp(1).isOpen());
+        assertTrue(segments.getSegmentForTimestamp(2).isOpen());
+        assertTrue(segments.getSegmentForTimestamp(3).isOpen());
+        assertTrue(segments.getSegmentForTimestamp(4).isOpen());
+    }
+
+    @Test
+    public void shouldGetSegmentsWithinTimeRange() {
+        updateStreamTimeAndCreateSegment(0);
+        updateStreamTimeAndCreateSegment(1);
+        updateStreamTimeAndCreateSegment(2);
+        updateStreamTimeAndCreateSegment(3);
+        final long streamTime = updateStreamTimeAndCreateSegment(4);
+        segments.getOrCreateSegmentIfLive(0, context, streamTime);
+        segments.getOrCreateSegmentIfLive(1, context, streamTime);
+        segments.getOrCreateSegmentIfLive(2, context, streamTime);
+        segments.getOrCreateSegmentIfLive(3, context, streamTime);
+        segments.getOrCreateSegmentIfLive(4, context, streamTime);
+
+        final List<TimestampedSegment> inRange = segments.segments(0, 2 * SEGMENT_INTERVAL);
+        assertEquals(3, inRange.size());
+        assertEquals(0, inRange.get(0).id);
+        assertEquals(1, inRange.get(1).id);
+        assertEquals(2, inRange.get(2).id);
+    }
+
+    @Test
+    public void shouldGetSegmentsWithinTimeRangeOutOfOrder() {
+        updateStreamTimeAndCreateSegment(4);
+        updateStreamTimeAndCreateSegment(2);
+        updateStreamTimeAndCreateSegment(0);
+        updateStreamTimeAndCreateSegment(1);
+        updateStreamTimeAndCreateSegment(3);
+
+        final List<TimestampedSegment> inRange = segments.segments(0, 2 * SEGMENT_INTERVAL);
+        assertEquals(3, inRange.size());
+        assertEquals(0, inRange.get(0).id);
+        assertEquals(1, inRange.get(1).id);
+        assertEquals(2, inRange.get(2).id);
+    }
+
+    @Test
+    public void shouldRollSegments() {
+        updateStreamTimeAndCreateSegment(0);
+        verifyCorrectSegments(0, 1);
+        updateStreamTimeAndCreateSegment(1);
+        verifyCorrectSegments(0, 2);
+        updateStreamTimeAndCreateSegment(2);
+        verifyCorrectSegments(0, 3);
+        updateStreamTimeAndCreateSegment(3);
+        verifyCorrectSegments(0, 4);
+        updateStreamTimeAndCreateSegment(4);
+        verifyCorrectSegments(0, 5);
+        updateStreamTimeAndCreateSegment(5);
+        verifyCorrectSegments(1, 5);
+        updateStreamTimeAndCreateSegment(6);
+        verifyCorrectSegments(2, 5);
+    }
+
+    @Test
+    public void futureEventsShouldNotCauseSegmentRoll() {
+        updateStreamTimeAndCreateSegment(0);
+        verifyCorrectSegments(0, 1);
+        updateStreamTimeAndCreateSegment(1);
+        verifyCorrectSegments(0, 2);
+        updateStreamTimeAndCreateSegment(2);
+        verifyCorrectSegments(0, 3);
+        updateStreamTimeAndCreateSegment(3);
+        verifyCorrectSegments(0, 4);
+        final long streamTime = updateStreamTimeAndCreateSegment(4);
+        verifyCorrectSegments(0, 5);
+        // stream time stays fixed, so creating later segments must not expire earlier ones
+        segments.getOrCreateSegmentIfLive(5, context, streamTime);
+        verifyCorrectSegments(0, 6);
+        segments.getOrCreateSegmentIfLive(6, context, streamTime);
+        verifyCorrectSegments(0, 7);
+    }
+
+    @Test
+    public void shouldUpdateSegmentFileNameFromOldDateFormatToNewFormat() throws Exception {
+        final long segmentInterval = 60_000L; // the old segment file's naming system maxes out at 1 minute granularity.
+
+        // release the fixture-created instance before replacing the field
+        segments.close();
+        segments = new TimestampedSegments(storeName, NUM_SEGMENTS * segmentInterval, segmentInterval);
+
+        final String storeDirectoryPath = stateDirectory.getAbsolutePath() + File.separator + storeName;
+        final File storeDirectory = new File(storeDirectoryPath);
+        //noinspection ResultOfMethodCallIgnored
+        storeDirectory.mkdirs();
+
+        final SimpleDateFormat formatter = new SimpleDateFormat("yyyyMMddHHmm");
+        formatter.setTimeZone(new SimpleTimeZone(0, "UTC"));
+
+        // lay down files using the legacy "<store>-<yyyyMMddHHmm>" naming scheme
+        for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) {
+            final File oldSegment = new File(storeDirectoryPath + File.separator + storeName + "-" + formatter.format(new Date(segmentId * segmentInterval)));
+            //noinspection ResultOfMethodCallIgnored
+            oldSegment.createNewFile();
+        }
+
+        segments.openExisting(context, -1L);
+
+        for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) {
+            final String segmentName = storeName + "." + (long) segmentId * segmentInterval;
+            final File newSegment = new File(storeDirectoryPath + File.separator + segmentName);
+            assertTrue(newSegment.exists());
+        }
+    }
+
+    @Test
+    public void shouldUpdateSegmentFileNameFromOldColonFormatToNewFormat() throws Exception {
+        final String storeDirectoryPath = stateDirectory.getAbsolutePath() + File.separator + storeName;
+        final File storeDirectory = new File(storeDirectoryPath);
+        //noinspection ResultOfMethodCallIgnored
+        storeDirectory.mkdirs();
+
+        // lay down files using the legacy "<store>:<timestamp>" naming scheme
+        for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) {
+            final File oldSegment = new File(storeDirectoryPath + File.separator + storeName + ":" + segmentId * (RETENTION_PERIOD / (NUM_SEGMENTS - 1)));
+            //noinspection ResultOfMethodCallIgnored
+            oldSegment.createNewFile();
+        }
+
+        segments.openExisting(context, -1L);
+
+        for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) {
+            final File newSegment = new File(storeDirectoryPath + File.separator + storeName + "." + segmentId * (RETENTION_PERIOD / (NUM_SEGMENTS - 1)));
+            assertTrue(newSegment.exists());
+        }
+    }
+
+    @Test
+    public void shouldClearSegmentsOnClose() {
+        segments.getOrCreateSegmentIfLive(0, context, -1L);
+        segments.close();
+        assertThat(segments.getSegmentForTimestamp(0), is(nullValue()));
+    }
+
+    /**
+     * Requests segment {@code segment} via {@code getOrCreateSegmentIfLive}, passing that
+     * segment's start time as the stream time.
+     *
+     * @param segment the segment id to request
+     * @return the stream time that was passed, i.e. {@code SEGMENT_INTERVAL * segment}
+     */
+    private long updateStreamTimeAndCreateSegment(final int segment) {
+        final long streamTime = SEGMENT_INTERVAL * segment;
+        segments.getOrCreateSegmentIfLive(segment, context, streamTime);
+        return streamTime;
+    }
+
+    /**
+     * Asserts that {@code segments(0, Long.MAX_VALUE)} returns exactly {@code numSegments}
+     * segments whose ids are consecutive, starting at {@code first}.
+     *
+     * @param first       the expected id of the first returned segment
+     * @param numSegments the expected number of returned segments
+     */
+    private void verifyCorrectSegments(final long first, final int numSegments) {
+        final List<TimestampedSegment> result = this.segments.segments(0, Long.MAX_VALUE);
+        assertEquals(numSegments, result.size());
+        for (int i = 0; i < numSegments; i++) {
+            assertEquals(i + first, result.get(i).id);
+        }
+    }
+}