You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by vv...@apache.org on 2020/10/08 13:23:22 UTC

[kafka] branch 2.7 updated: KAFKA-9929: Support backward iterator on SessionStore (#9139)

This is an automated email from the ASF dual-hosted git repository.

vvcephei pushed a commit to branch 2.7
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.7 by this push:
     new 9390181  KAFKA-9929: Support backward iterator on SessionStore (#9139)
9390181 is described below

commit 939018174d79742de3e01cda251faffd740177c4
Author: Jorge Esteban Quilcate Otoya <qu...@gmail.com>
AuthorDate: Thu Oct 8 14:08:24 2020 +0100

    KAFKA-9929: Support backward iterator on SessionStore (#9139)
    
    Implements KIP-617 for `SessionStore`
    
    Reviewers: A. Sophie Blee-Goldman <so...@confluent.io>, John Roesler <vv...@apache.org>
---
 .../kafka/streams/state/ReadOnlySessionStore.java  | 130 +++++++-
 .../apache/kafka/streams/state/SessionStore.java   |  40 ---
 .../state/internals/CachingSessionStore.java       | 148 +++++++--
 .../internals/ChangeLoggingSessionBytesStore.java  |  24 ++
 .../internals/CompositeReadOnlySessionStore.java   | 169 +++++++++++
 .../state/internals/InMemorySessionStore.java      | 123 +++++++-
 .../MergedSortedCacheSessionStoreIterator.java     |   5 +-
 .../state/internals/MeteredSessionStore.java       |  68 +++++
 .../state/internals/RocksDBSessionStore.java       |  36 +++
 .../internals/AbstractSessionBytesStoreTest.java   | 187 ++++++++++++
 .../state/internals/CacheFlushListenerStub.java    |  49 +++
 ....java => CachingInMemoryKeyValueStoreTest.java} |  29 +-
 ...t.java => CachingInMemorySessionStoreTest.java} | 212 ++++++++++---
 ...java => CachingPersistentSessionStoreTest.java} | 330 +++++++++++++++------
 ....java => CachingPersistentWindowStoreTest.java} |   6 +-
 .../ChangeLoggingSessionBytesStoreTest.java        |  40 +++
 ...SortedCacheWrappedSessionStoreIteratorTest.java |  63 +++-
 .../state/internals/MeteredSessionStoreTest.java   |  80 +++++
 .../kafka/test/ReadOnlySessionStoreStub.java       |  72 ++++-
 19 files changed, 1571 insertions(+), 240 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlySessionStore.java
index 230d257..8874908 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlySessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlySessionStore.java
@@ -28,31 +28,153 @@ import org.apache.kafka.streams.kstream.Windowed;
  * @param <AGG> the aggregated value type
  */
 public interface ReadOnlySessionStore<K, AGG> {
+
     /**
-     * Retrieve all aggregated sessions for the provided key.
+     * Fetch any sessions with the matching key whose end is &ge; earliestSessionEndTime and whose
+     * start is &le; latestSessionStartTime iterating from earliest to latest.
+     * <p>
+     * This iterator must be closed after use.
+     *
+     * @param key                    the key to return sessions for
+     * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where iteration starts.
+     * @param latestSessionStartTime the start timestamp of the latest session to search for, where iteration ends.
+     * @return iterator of sessions with the matching key and aggregated values, from earliest to latest session time.
+     * @throws NullPointerException If null is used for key.
+     */
+    default KeyValueIterator<Windowed<K>, AGG> findSessions(final K key,
+                                                            final long earliestSessionEndTime,
+                                                            final long latestSessionStartTime) {
+        throw new UnsupportedOperationException("This API is not supported by this implementation of ReadOnlySessionStore.");
+    }
+
+    /**
+     * Fetch any sessions with the matching key whose end is &ge; earliestSessionEndTime and whose
+     * start is &le; latestSessionStartTime iterating from latest to earliest.
+     * <p>
+     * This iterator must be closed after use.
+     *
+     * @param key                    the key to return sessions for
+     * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where iteration ends.
+     * @param latestSessionStartTime the start timestamp of the latest session to search for, where iteration starts.
+     * @return backward iterator of sessions with the matching key and aggregated values, from latest to earliest session time.
+     * @throws NullPointerException If null is used for key.
+     */
+    default KeyValueIterator<Windowed<K>, AGG> backwardFindSessions(final K key,
+                                                                    final long earliestSessionEndTime,
+                                                                    final long latestSessionStartTime) {
+        throw new UnsupportedOperationException("This API is not supported by this implementation of ReadOnlySessionStore.");
+    }
+
+    /**
+     * Fetch any sessions in the given range of keys whose end is &ge; earliestSessionEndTime and whose
+     * start is &le; latestSessionStartTime iterating from earliest to latest.
+     * <p>
+     * This iterator must be closed after use.
+     *
+     * @param keyFrom                The first key that could be in the range
+     * @param keyTo                  The last key that could be in the range
+     * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where iteration starts.
+     * @param latestSessionStartTime the start timestamp of the latest session to search for, where iteration ends.
+     * @return iterator of sessions with the matching keys and aggregated values, from earliest to latest session time.
+     * @throws NullPointerException If null is used for any key.
+     */
+    default KeyValueIterator<Windowed<K>, AGG> findSessions(final K keyFrom,
+                                                            final K keyTo,
+                                                            final long earliestSessionEndTime,
+                                                            final long latestSessionStartTime) {
+        throw new UnsupportedOperationException("This API is not supported by this implementation of ReadOnlySessionStore.");
+    }
+
+
+    /**
+     * Fetch any sessions in the given range of keys whose end is &ge; earliestSessionEndTime and whose
+     * start is &le; latestSessionStartTime iterating from latest to earliest.
+     * <p>
      * This iterator must be closed after use.
      *
+     * @param keyFrom                The first key that could be in the range
+     * @param keyTo                  The last key that could be in the range
+     * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where iteration ends.
+     * @param latestSessionStartTime the start timestamp of the latest session to search for, where iteration starts.
+     * @return backward iterator of sessions with the matching keys and aggregated values, from latest to earliest session time.
+     * @throws NullPointerException If null is used for any key.
+     */
+    default KeyValueIterator<Windowed<K>, AGG> backwardFindSessions(final K keyFrom,
+                                                                    final K keyTo,
+                                                                    final long earliestSessionEndTime,
+                                                                    final long latestSessionStartTime) {
+        throw new UnsupportedOperationException("This API is not supported by this implementation of ReadOnlySessionStore.");
+    }
+
+    /**
+     * Get the value of key from a single session.
+     *
+     * @param key       the key to fetch
+     * @param startTime start timestamp of the session
+     * @param endTime   end timestamp of the session
+     * @return The value or {@code null} if no session associated with the key can be found
+     * @throws NullPointerException If {@code null} is used for any key.
+     */
+    default AGG fetchSession(final K key, final long startTime, final long endTime) {
+        throw new UnsupportedOperationException("This API is not supported by this implementation of ReadOnlySessionStore.");
+    }
+
+    /**
+     * Retrieve all aggregated sessions for the provided key.
+     * This iterator must be closed after use.
+     * <p>
      * For each key, the iterator guarantees ordering of sessions, starting from the oldest/earliest
      * available session to the newest/latest session.
      *
      * @param    key record key to find aggregated session values for
-     * @return   KeyValueIterator containing all sessions for the provided key.
+     * @return   KeyValueIterator containing all sessions for the provided key, from oldest to newest session.
      * @throws   NullPointerException If null is used for key.
      *
      */
     KeyValueIterator<Windowed<K>, AGG> fetch(final K key);
 
     /**
-     * Retrieve all aggregated sessions for the given range of keys.
+     * Retrieve all aggregated sessions for the provided key.
      * This iterator must be closed after use.
+     * <p>
+     * For each key, the iterator guarantees ordering of sessions, starting from the newest/latest
+     * available session to the oldest/earliest session.
      *
+     * @param key record key to find aggregated session values for
+     * @return backward KeyValueIterator containing all sessions for the provided key, from newest to oldest session.
+     * @throws NullPointerException If null is used for key.
+     */
+    default KeyValueIterator<Windowed<K>, AGG> backwardFetch(final K key) {
+        throw new UnsupportedOperationException("This API is not supported by this implementation of ReadOnlySessionStore.");
+    }
+
+    /**
+     * Retrieve all aggregated sessions for the given range of keys.
+     * This iterator must be closed after use.
+     * <p>
      * For each key, the iterator guarantees ordering of sessions, starting from the oldest/earliest
      * available session to the newest/latest session.
      *
      * @param    from first key in the range to find aggregated session values for
      * @param    to last key in the range to find aggregated session values for
-     * @return   KeyValueIterator containing all sessions for the provided key.
+     * @return   KeyValueIterator containing all sessions for the given range of keys, from oldest to newest session.
      * @throws   NullPointerException If null is used for any of the keys.
      */
     KeyValueIterator<Windowed<K>, AGG> fetch(final K from, final K to);
+
+    /**
+     * Retrieve all aggregated sessions for the given range of keys.
+     * This iterator must be closed after use.
+     * <p>
+     * For each key, the iterator guarantees ordering of sessions, starting from the newest/latest
+     * available session to the oldest/earliest session.
+     *
+     * @param from first key in the range to find aggregated session values for
+     * @param to   last key in the range to find aggregated session values for
+     * @return backward KeyValueIterator containing all sessions for the given range of keys, from newest to oldest session.
+     * @throws NullPointerException If null is used for any of the keys.
+     */
+    default KeyValueIterator<Windowed<K>, AGG> backwardFetch(final K from, final K to) {
+        throw new UnsupportedOperationException("This API is not supported by this implementation of ReadOnlySessionStore.");
+    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/SessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/SessionStore.java
index faaa751..47f48d5 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/SessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/SessionStore.java
@@ -35,46 +35,6 @@ import org.apache.kafka.streams.processor.StateStore;
 public interface SessionStore<K, AGG> extends StateStore, ReadOnlySessionStore<K, AGG> {
 
     /**
-     * Fetch any sessions with the matching key and the sessions end is &ge; earliestSessionEndTime and the sessions
-     * start is &le; latestSessionStartTime
-     *
-     * This iterator must be closed after use.
-     *
-     * @param key the key to return sessions for
-     * @param earliestSessionEndTime the end timestamp of the earliest session to search for
-     * @param latestSessionStartTime the end timestamp of the latest session to search for
-     * @return iterator of sessions with the matching key and aggregated values
-     * @throws NullPointerException If null is used for key.
-     */
-    KeyValueIterator<Windowed<K>, AGG> findSessions(final K key, final long earliestSessionEndTime, final long latestSessionStartTime);
-
-    /**
-     * Fetch any sessions in the given range of keys and the sessions end is &ge; earliestSessionEndTime and the sessions
-     * start is &le; latestSessionStartTime
-     *
-     * This iterator must be closed after use.
-     *
-     * @param keyFrom The first key that could be in the range
-     * @param keyTo The last key that could be in the range
-     * @param earliestSessionEndTime the end timestamp of the earliest session to search for
-     * @param latestSessionStartTime the end timestamp of the latest session to search for
-     * @return iterator of sessions with the matching keys and aggregated values
-     * @throws NullPointerException If null is used for any key.
-     */
-    KeyValueIterator<Windowed<K>, AGG> findSessions(final K keyFrom, final K keyTo, final long earliestSessionEndTime, final long latestSessionStartTime);
-
-    /**
-     * Get the value of key from a single session.
-     *
-     * @param key            the key to fetch
-     * @param startTime      start timestamp of the session
-     * @param endTime        end timestamp of the session
-     * @return The value or {@code null} if no session associated with the key can be found
-     * @throws NullPointerException If {@code null} is used for any key.
-     */
-    AGG fetchSession(final K key, final long startTime, final long endTime);
-
-    /**
      * Remove the session aggregated with provided {@link Windowed} key from the store
      * @param sessionKey key of the session to remove
      * @throws NullPointerException If null is used for sessionKey.
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
index c92123d..d0fe25a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
@@ -76,7 +76,6 @@ class CachingSessionStore
         super.init(context, root);
     }
 
-    @SuppressWarnings("unchecked")
     private void initInternal(final InternalProcessorContext context) {
         this.context = context;
 
@@ -159,7 +158,7 @@ class CachingSessionStore
         validateStoreOpen();
 
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> cacheIterator = wrapped().persistent() ?
-            new CacheIteratorWrapper(key, earliestSessionEndTime, latestSessionStartTime) :
+            new CacheIteratorWrapper(key, earliestSessionEndTime, latestSessionStartTime, true) :
             context.cache().range(cacheName,
                         cacheFunction.cacheKey(keySchema.lowerRangeFixedSize(key, earliestSessionEndTime)),
                         cacheFunction.cacheKey(keySchema.upperRangeFixedSize(key, latestSessionStartTime))
@@ -174,7 +173,38 @@ class CachingSessionStore
                                                                              latestSessionStartTime);
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
-        return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction);
+        return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction, true);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes key,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        validateStoreOpen();
+
+        final PeekingKeyValueIterator<Bytes, LRUCacheEntry> cacheIterator = wrapped().persistent() ?
+            new CacheIteratorWrapper(key, earliestSessionEndTime, latestSessionStartTime, false) :
+            context.cache().reverseRange(
+                cacheName,
+                cacheFunction.cacheKey(keySchema.lowerRangeFixedSize(key, earliestSessionEndTime)),
+                cacheFunction.cacheKey(keySchema.upperRangeFixedSize(key, latestSessionStartTime)
+                )
+            );
+
+        final KeyValueIterator<Windowed<Bytes>, byte[]> storeIterator = wrapped().backwardFindSessions(
+            key,
+            earliestSessionEndTime,
+            latestSessionStartTime
+        );
+        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(
+            key,
+            key,
+            earliestSessionEndTime,
+            latestSessionStartTime
+        );
+        final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
+            new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
+        return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction, false);
     }
 
     @Override
@@ -205,7 +235,39 @@ class CachingSessionStore
                                                                              latestSessionStartTime);
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
             new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
-        return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction);
+        return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction, true);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes keyFrom,
+                                                                          final Bytes keyTo,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        if (keyFrom.compareTo(keyTo) > 0) {
+            LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " +
+                         "This may be due to range arguments set in the wrong order, " +
+                         "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " +
+                         "Note that the built-in numerical serdes do not follow this for negative numbers");
+            return KeyValueIterators.emptyIterator();
+        }
+
+        validateStoreOpen();
+
+        final Bytes cacheKeyFrom = cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, earliestSessionEndTime));
+        final Bytes cacheKeyTo = cacheFunction.cacheKey(keySchema.upperRange(keyTo, latestSessionStartTime));
+        final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().reverseRange(cacheName, cacheKeyFrom, cacheKeyTo);
+
+        final KeyValueIterator<Windowed<Bytes>, byte[]> storeIterator =
+            wrapped().backwardFindSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime);
+        final HasNextCondition hasNextCondition = keySchema.hasNextCondition(
+            keyFrom,
+            keyTo,
+            earliestSessionEndTime,
+            latestSessionStartTime
+        );
+        final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
+            new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction);
+        return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction, false);
     }
 
     @Override
@@ -233,6 +295,12 @@ class CachingSessionStore
     }
 
     @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes key) {
+        Objects.requireNonNull(key, "key cannot be null");
+        return backwardFindSessions(key, 0, Long.MAX_VALUE);
+    }
+
+    @Override
     public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes from,
                                                            final Bytes to) {
         Objects.requireNonNull(from, "from cannot be null");
@@ -240,6 +308,14 @@ class CachingSessionStore
         return findSessions(from, to, 0, Long.MAX_VALUE);
     }
 
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes from,
+                                                                   final Bytes to) {
+        Objects.requireNonNull(from, "from cannot be null");
+        Objects.requireNonNull(to, "to cannot be null");
+        return backwardFindSessions(from, to, 0, Long.MAX_VALUE);
+    }
+
     public void flush() {
         context.cache().flush(cacheName);
         wrapped().flush();
@@ -269,6 +345,8 @@ class CachingSessionStore
         private final Bytes keyFrom;
         private final Bytes keyTo;
         private final long latestSessionStartTime;
+        private final boolean forward;
+
         private long lastSegmentId;
 
         private long currentSegmentId;
@@ -279,25 +357,36 @@ class CachingSessionStore
 
         private CacheIteratorWrapper(final Bytes key,
                                      final long earliestSessionEndTime,
-                                     final long latestSessionStartTime) {
-            this(key, key, earliestSessionEndTime, latestSessionStartTime);
+                                     final long latestSessionStartTime,
+                                     final boolean forward) {
+            this(key, key, earliestSessionEndTime, latestSessionStartTime, forward);
         }
 
         private CacheIteratorWrapper(final Bytes keyFrom,
                                      final Bytes keyTo,
                                      final long earliestSessionEndTime,
-                                     final long latestSessionStartTime) {
+                                     final long latestSessionStartTime,
+                                     final boolean forward) {
             this.keyFrom = keyFrom;
             this.keyTo = keyTo;
             this.latestSessionStartTime = latestSessionStartTime;
-            this.lastSegmentId = cacheFunction.segmentId(maxObservedTimestamp);
             this.segmentInterval = cacheFunction.getSegmentInterval();
+            this.forward = forward;
+
 
-            this.currentSegmentId = cacheFunction.segmentId(earliestSessionEndTime);
+            if (forward) {
+                this.currentSegmentId = cacheFunction.segmentId(earliestSessionEndTime);
+                this.lastSegmentId = cacheFunction.segmentId(maxObservedTimestamp);
 
-            setCacheKeyRange(earliestSessionEndTime, currentSegmentLastTime());
+                setCacheKeyRange(earliestSessionEndTime, currentSegmentLastTime());
+                this.current = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo);
+            } else {
+                this.lastSegmentId = cacheFunction.segmentId(earliestSessionEndTime);
+                this.currentSegmentId = cacheFunction.segmentId(maxObservedTimestamp);
 
-            this.current = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo);
+                setCacheKeyRange(currentSegmentBeginTime(), Math.min(latestSessionStartTime, maxObservedTimestamp));
+                this.current = context.cache().reverseRange(cacheName, cacheKeyFrom, cacheKeyTo);
+            }
         }
 
         @Override
@@ -357,18 +446,35 @@ class CachingSessionStore
         }
 
         private void getNextSegmentIterator() {
-            ++currentSegmentId;
-            lastSegmentId = cacheFunction.segmentId(maxObservedTimestamp);
+            if (forward) {
+                ++currentSegmentId;
+                lastSegmentId = cacheFunction.segmentId(maxObservedTimestamp);
 
-            if (currentSegmentId > lastSegmentId) {
-                current = null;
-                return;
-            }
+                if (currentSegmentId > lastSegmentId) {
+                    current = null;
+                    return;
+                }
 
-            setCacheKeyRange(currentSegmentBeginTime(), currentSegmentLastTime());
+                setCacheKeyRange(currentSegmentBeginTime(), currentSegmentLastTime());
+
+                current.close();
+
+                current = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo);
+            } else {
+                --currentSegmentId;
+
+                if (currentSegmentId < lastSegmentId) {
+                    current = null;
+                    return;
+                }
+
+                setCacheKeyRange(currentSegmentBeginTime(), currentSegmentLastTime());
+
+                current.close();
+
+                current = context.cache().reverseRange(cacheName, cacheKeyFrom, cacheKeyTo);
+            }
 
-            current.close();
-            current = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo);
         }
 
         private void setCacheKeyRange(final long lowerRangeEndTime, final long upperRangeEndTime) {
@@ -376,7 +482,7 @@ class CachingSessionStore
                 throw new IllegalStateException("Error iterating over segments: segment interval has changed");
             }
 
-            if (keyFrom == keyTo) {
+            if (keyFrom.equals(keyTo)) {
                 cacheKeyFrom = cacheFunction.cacheKey(segmentLowerRangeFixedSize(keyFrom, lowerRangeEndTime));
                 cacheKeyTo = cacheFunction.cacheKey(segmentUpperRangeFixedSize(keyTo, upperRangeEndTime));
             } else {
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStore.java
index 0d2133d..556a67e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStore.java
@@ -60,11 +60,25 @@ class ChangeLoggingSessionBytesStore
     }
 
     @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes key,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        return wrapped().backwardFindSessions(key, earliestSessionEndTime, latestSessionStartTime);
+    }
+
+    @Override
     public KeyValueIterator<Windowed<Bytes>, byte[]> findSessions(final Bytes keyFrom, final Bytes keyTo, final long earliestSessionEndTime, final long latestSessionStartTime) {
         return wrapped().findSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime);
     }
 
     @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes keyFrom, final Bytes keyTo,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        return wrapped().backwardFindSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime);
+    }
+
+    @Override
     public void remove(final Windowed<Bytes> sessionKey) {
         wrapped().remove(sessionKey);
         context.logChange(name(), SessionKeySchema.toBinary(sessionKey), null, context.timestamp());
@@ -82,11 +96,21 @@ class ChangeLoggingSessionBytesStore
     }
 
     @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes key) {
+        return wrapped().backwardFetch(key);
+    }
+
+    @Override
     public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes key) {
         return wrapped().fetch(key);
     }
 
     @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes from, final Bytes to) {
+        return wrapped().backwardFetch(from, to);
+    }
+
+    @Override
     public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes from, final Bytes to) {
         return wrapped().fetch(from, to);
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStore.java
index 63d551c..7223312 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStore.java
@@ -43,6 +43,137 @@ public class CompositeReadOnlySessionStore<K, V> implements ReadOnlySessionStore
     }
 
     @Override
+    public KeyValueIterator<Windowed<K>, V> findSessions(final K key,
+                                                         final long earliestSessionEndTime,
+                                                         final long latestSessionStartTime) {
+        Objects.requireNonNull(key, "key can't be null");
+        final List<ReadOnlySessionStore<K, V>> stores = storeProvider.stores(storeName, queryableStoreType);
+        for (final ReadOnlySessionStore<K, V> store : stores) {
+            try {
+                final KeyValueIterator<Windowed<K>, V> result =
+                    store.findSessions(key, earliestSessionEndTime, latestSessionStartTime);
+                // first non-empty store wins; empty iterators are closed before trying the next store
+                if (!result.hasNext()) {
+                    result.close();
+                } else {
+                    return result;
+                }
+            } catch (final InvalidStateStoreException ise) {
+                throw new InvalidStateStoreException(
+                    "State store  [" + storeName + "] is not available anymore" +
+                        " and may have been migrated to another instance; " +
+                        "please re-discover its location from the state metadata.",
+                    ise
+                );
+            }
+        }
+        return KeyValueIterators.emptyIterator();
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFindSessions(final K key,
+                                                                 final long earliestSessionEndTime,
+                                                                 final long latestSessionStartTime) {
+        Objects.requireNonNull(key, "key can't be null");
+        final List<ReadOnlySessionStore<K, V>> stores = storeProvider.stores(storeName, queryableStoreType);
+        for (final ReadOnlySessionStore<K, V> store : stores) {
+            try {
+                final KeyValueIterator<Windowed<K>, V> result = store.backwardFindSessions(key, earliestSessionEndTime, latestSessionStartTime);
+                if (!result.hasNext()) { // empty here: close it and try the next store
+                    result.close();
+                } else {
+                    return result;
+                }
+            } catch (final InvalidStateStoreException ise) {
+                throw new InvalidStateStoreException(
+                    "State store  [" + storeName + "] is not available anymore" +
+                        " and may have been migrated to another instance; " +
+                        "please re-discover its location from the state metadata.",
+                    ise
+                );
+            }
+        }
+        return KeyValueIterators.emptyIterator();
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<K>, V> findSessions(final K keyFrom,
+                                                         final K keyTo,
+                                                         final long earliestSessionEndTime,
+                                                         final long latestSessionStartTime) {
+        Objects.requireNonNull(keyFrom, "from can't be null");
+        Objects.requireNonNull(keyTo, "to can't be null");
+        final List<ReadOnlySessionStore<K, V>> stores = storeProvider.stores(storeName, queryableStoreType);
+        for (final ReadOnlySessionStore<K, V> store : stores) {
+            try {
+                final KeyValueIterator<Windowed<K>, V> result =
+                    store.findSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime);
+                if (!result.hasNext()) { // empty here: close it and try the next store
+                    result.close();
+                } else {
+                    return result;
+                }
+            } catch (final InvalidStateStoreException ise) {
+                throw new InvalidStateStoreException(
+                    "State store  [" + storeName + "] is not available anymore" +
+                        " and may have been migrated to another instance; " +
+                        "please re-discover its location from the state metadata.",
+                    ise
+                );
+            }
+        }
+        return KeyValueIterators.emptyIterator();
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFindSessions(final K keyFrom,
+                                                                 final K keyTo,
+                                                                 final long earliestSessionEndTime,
+                                                                 final long latestSessionStartTime) {
+        Objects.requireNonNull(keyFrom, "from can't be null");
+        Objects.requireNonNull(keyTo, "to can't be null");
+        final List<ReadOnlySessionStore<K, V>> stores = storeProvider.stores(storeName, queryableStoreType);
+        for (final ReadOnlySessionStore<K, V> store : stores) {
+            try {
+                final KeyValueIterator<Windowed<K>, V> result =
+                    store.backwardFindSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime);
+                if (!result.hasNext()) { // empty here: close it and try the next store
+                    result.close();
+                } else {
+                    return result;
+                }
+            } catch (final InvalidStateStoreException ise) {
+                throw new InvalidStateStoreException(
+                    "State store  [" + storeName + "] is not available anymore" +
+                        " and may have been migrated to another instance; " +
+                        "please re-discover its location from the state metadata.",
+                    ise
+                );
+            }
+        }
+        return KeyValueIterators.emptyIterator();
+    }
+
+    @Override
+    public V fetchSession(final K key, final long startTime, final long endTime) {
+        Objects.requireNonNull(key, "key can't be null");
+        final List<ReadOnlySessionStore<K, V>> stores = storeProvider.stores(storeName, queryableStoreType);
+        for (final ReadOnlySessionStore<K, V> store : stores) {
+            try {
+                final V result = store.fetchSession(key, startTime, endTime); // a miss here must not hide a hit in a later store
+                if (result != null) {
+                    return result;
+                }
+            } catch (final InvalidStateStoreException ise) {
+                throw new InvalidStateStoreException("State store  [" + storeName + "] is not available anymore" +
+                    " and may have been migrated to another instance; " +
+                    "please re-discover its location from the state metadata.", ise);
+            }
+        }
+        return null;
+    }
+
+    @Override
     public KeyValueIterator<Windowed<K>, V> fetch(final K key) {
         Objects.requireNonNull(key, "key can't be null");
         final List<ReadOnlySessionStore<K, V>> stores = storeProvider.stores(storeName, queryableStoreType);
@@ -65,6 +196,30 @@ public class CompositeReadOnlySessionStore<K, V> implements ReadOnlySessionStore
     }
 
     @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFetch(final K key) {
+        Objects.requireNonNull(key, "key can't be null");
+        final List<ReadOnlySessionStore<K, V>> stores = storeProvider.stores(storeName, queryableStoreType);
+        for (final ReadOnlySessionStore<K, V> store : stores) {
+            try {
+                final KeyValueIterator<Windowed<K>, V> result = store.backwardFetch(key);
+                if (!result.hasNext()) { // empty here: close it and try the next store
+                    result.close();
+                } else {
+                    return result;
+                }
+            } catch (final InvalidStateStoreException ise) {
+                throw new InvalidStateStoreException(
+                    "State store  [" + storeName + "] is not available anymore" +
+                        " and may have been migrated to another instance; " +
+                        "please re-discover its location from the state metadata.",
+                    ise
+                );
+            }
+        }
+        return KeyValueIterators.emptyIterator();
+    }
+
+    @Override
     public KeyValueIterator<Windowed<K>, V> fetch(final K from, final K to) {
         Objects.requireNonNull(from, "from can't be null");
         Objects.requireNonNull(to, "to can't be null");
@@ -74,4 +229,18 @@ public class CompositeReadOnlySessionStore<K, V> implements ReadOnlySessionStore
                                                                storeProvider.stores(storeName, queryableStoreType).iterator(),
                                                                nextIteratorFunction));
     }
+
+    @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFetch(final K from, final K to) {
+        Objects.requireNonNull(from, "from can't be null");
+        Objects.requireNonNull(to, "to can't be null");
+        final NextIteratorFunction<Windowed<K>, V, ReadOnlySessionStore<K, V>> nextIteratorFunction = store -> store.backwardFetch(from, to);
+        return new DelegatingPeekingKeyValueIterator<>(
+            storeName,
+            new CompositeKeyValueIterator<>(
+                storeProvider.stores(storeName, queryableStoreType).iterator(), // stores visited in provider order
+                nextIteratorFunction
+            )
+        );
+    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
index 2e45b48..46c4de2 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
@@ -168,7 +168,25 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
         return registerNewIterator(key,
                                    key,
                                    latestSessionStartTime,
-                                   endTimeMap.tailMap(earliestSessionEndTime, true).entrySet().iterator());
+                                   endTimeMap.tailMap(earliestSessionEndTime, true).entrySet().iterator(),
+                                   true);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes key,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        Objects.requireNonNull(key, "key cannot be null");
+
+        removeExpiredSegments();
+        // same window as findSessions(key, ...), but end times are walked in descending order (forward=false)
+        return registerNewIterator(
+            key,
+            key,
+            latestSessionStartTime,
+            endTimeMap.tailMap(earliestSessionEndTime, true).descendingMap().entrySet().iterator(),
+            false
+        );
     }
 
     @Override
@@ -192,7 +210,35 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
         return registerNewIterator(keyFrom,
                                    keyTo,
                                    latestSessionStartTime,
-                                   endTimeMap.tailMap(earliestSessionEndTime, true).entrySet().iterator());
+                                   endTimeMap.tailMap(earliestSessionEndTime, true).entrySet().iterator(),
+                                   true);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes keyFrom,
+                                                                          final Bytes keyTo,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        Objects.requireNonNull(keyFrom, "from key cannot be null");
+        Objects.requireNonNull(keyTo, "to key cannot be null");
+
+        removeExpiredSegments();
+
+        if (keyFrom.compareTo(keyTo) > 0) {
+            LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " +
+                "This may be due to range arguments set in the wrong order, " +
+                "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " +
+                "Note that the built-in numerical serdes do not follow this for negative numbers");
+            return KeyValueIterators.emptyIterator();
+        }
+        // end times walked in descending order; forward=false also reverses key order within each end time
+        return registerNewIterator(
+            keyFrom,
+            keyTo,
+            latestSessionStartTime,
+            endTimeMap.tailMap(earliestSessionEndTime, true).descendingMap().entrySet().iterator(),
+            false
+        );
     }
 
     @Override
@@ -202,7 +248,17 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
 
         removeExpiredSegments();
 
-        return registerNewIterator(key, key, Long.MAX_VALUE, endTimeMap.entrySet().iterator());
+        return registerNewIterator(key, key, Long.MAX_VALUE, endTimeMap.entrySet().iterator(), true);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes key) {
+
+        Objects.requireNonNull(key, "key cannot be null");
+
+        removeExpiredSegments();
+        // mirror of fetch(key): unbounded start time, descending end times, forward=false
+        return registerNewIterator(key, key, Long.MAX_VALUE, endTimeMap.descendingMap().entrySet().iterator(), false);
     }
 
     @Override
@@ -214,7 +270,17 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
         removeExpiredSegments();
 
 
-        return registerNewIterator(from, to, Long.MAX_VALUE, endTimeMap.entrySet().iterator());
+        return registerNewIterator(from, to, Long.MAX_VALUE, endTimeMap.entrySet().iterator(), true); // ascending end times => forward
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes from, final Bytes to) {
+        Objects.requireNonNull(from, "from key cannot be null");
+        Objects.requireNonNull(to, "to key cannot be null");
+
+        removeExpiredSegments();
+        // descending end times must pair with forward=false so key order is reversed too
+        return registerNewIterator(from, to, Long.MAX_VALUE, endTimeMap.descendingMap().entrySet().iterator(), false);
     }
 
     @Override
@@ -259,8 +325,17 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
     private InMemorySessionStoreIterator registerNewIterator(final Bytes keyFrom,
                                                              final Bytes keyTo,
                                                              final long latestSessionStartTime,
-                                                             final Iterator<Entry<Long, ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>>>> endTimeIterator) {
-        final InMemorySessionStoreIterator iterator = new InMemorySessionStoreIterator(keyFrom, keyTo, latestSessionStartTime, endTimeIterator, openIterators::remove);
+                                                             final Iterator<Entry<Long, ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>>>> endTimeIterator,
+                                                             final boolean forward) { // traversal direction, handed to the iterator
+        final InMemorySessionStoreIterator iterator =
+            new InMemorySessionStoreIterator(
+                keyFrom,
+                keyTo,
+                latestSessionStartTime,
+                endTimeIterator,
+                openIterators::remove,
+                forward
+            );
         openIterators.add(iterator);
         return iterator;
     }
@@ -285,17 +360,21 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
 
         private final ClosingCallback callback;
 
+        private final boolean forward;
+
         InMemorySessionStoreIterator(final Bytes keyFrom,
                                      final Bytes keyTo,
                                      final long latestSessionStartTime,
                                      final Iterator<Entry<Long, ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>>>> endTimeIterator,
-                                     final ClosingCallback callback) {
+                                     final ClosingCallback callback,
+                                     final boolean forward) {
             this.keyFrom = keyFrom;
             this.keyTo = keyTo;
             this.latestSessionStartTime = latestSessionStartTime;
 
             this.endTimeIterator = endTimeIterator;
             this.callback = callback;
+            this.forward = forward; // selects ascending vs descending key traversal per end time
             setAllIterators();
         }
 
@@ -366,7 +445,18 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
             while (endTimeIterator.hasNext()) {
                 final Entry<Long, ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>>> nextEndTimeEntry = endTimeIterator.next();
                 currentEndTime = nextEndTimeEntry.getKey();
-                keyIterator = nextEndTimeEntry.getValue().subMap(keyFrom, true, keyTo, true).entrySet().iterator();
+                if (forward) {
+                    keyIterator = nextEndTimeEntry.getValue()
+                                                  .subMap(keyFrom, true, keyTo, true)
+                                                  .entrySet()
+                                                  .iterator();
+                } else {
+                    keyIterator = nextEndTimeEntry.getValue()
+                                                  .subMap(keyFrom, true, keyTo, true)
+                                                  .descendingMap()
+                                                  .entrySet()
+                                                  .iterator();
+                }
 
                 if (setInnerIterators()) {
                     return;
@@ -383,9 +473,22 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
                 currentKey = nextKeyEntry.getKey();
 
                 if (latestSessionStartTime == Long.MAX_VALUE) {
-                    recordIterator = nextKeyEntry.getValue().entrySet().iterator();
+                    if (forward) {
+                        recordIterator = nextKeyEntry.getValue().descendingMap().entrySet().iterator();
+                    } else {
+                        recordIterator = nextKeyEntry.getValue().entrySet().iterator();
+                    }
                 } else {
-                    recordIterator = nextKeyEntry.getValue().headMap(latestSessionStartTime, true).entrySet().iterator();
+                    if (forward) {
+                        recordIterator = nextKeyEntry.getValue()
+                                                     .headMap(latestSessionStartTime, true)
+                                                     .descendingMap()
+                                                     .entrySet().iterator();
+                    } else {
+                        recordIterator = nextKeyEntry.getValue()
+                                                     .headMap(latestSessionStartTime, true)
+                                                     .entrySet().iterator();
+                    }
                 }
 
                 if (recordIterator.hasNext()) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheSessionStoreIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheSessionStoreIterator.java
index ff45a41..cd0c0df 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheSessionStoreIterator.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheSessionStoreIterator.java
@@ -32,8 +32,9 @@ class MergedSortedCacheSessionStoreIterator extends AbstractMergedSortedCacheSto
 
     MergedSortedCacheSessionStoreIterator(final PeekingKeyValueIterator<Bytes, LRUCacheEntry> cacheIterator,
                                           final KeyValueIterator<Windowed<Bytes>, byte[]> storeIterator,
-                                          final SegmentedCacheFunction cacheFunction) {
-        super(cacheIterator, storeIterator, true);
+                                          final SegmentedCacheFunction cacheFunction,
+                                          final boolean forward) {
+        super(cacheIterator, storeIterator, forward); // direction now caller-supplied (was hard-coded true)
         this.cacheFunction = cacheFunction;
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
index 8b9256d..f7f25c0 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
@@ -222,6 +222,18 @@ public class MeteredSessionStore<K, V>
     }
 
     @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFetch(final K key) {
+        Objects.requireNonNull(key, "key cannot be null");
+        return new MeteredWindowedKeyValueIterator<>( // records fetchSensor latency and deserializes via serdes
+            wrapped().backwardFetch(keyBytes(key)),
+            fetchSensor,
+            streamsMetrics,
+            serdes,
+            time
+        );
+    }
+
+    @Override
     public KeyValueIterator<Windowed<K>, V> fetch(final K from,
                                                   final K to) {
         Objects.requireNonNull(from, "from cannot be null");
@@ -235,6 +247,20 @@ public class MeteredSessionStore<K, V>
     }
 
     @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFetch(final K from,
+                                                          final K to) {
+        Objects.requireNonNull(from, "from cannot be null");
+        Objects.requireNonNull(to, "to cannot be null");
+        return new MeteredWindowedKeyValueIterator<>( // records fetchSensor latency and deserializes via serdes
+            wrapped().backwardFetch(keyBytes(from), keyBytes(to)),
+            fetchSensor,
+            streamsMetrics,
+            serdes,
+            time
+        );
+    }
+
+    @Override
     public KeyValueIterator<Windowed<K>, V> findSessions(final K key,
                                                          final long earliestSessionEndTime,
                                                          final long latestSessionStartTime) {
@@ -252,6 +278,25 @@ public class MeteredSessionStore<K, V>
     }
 
     @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFindSessions(final K key,
+                                                                 final long earliestSessionEndTime,
+                                                                 final long latestSessionStartTime) {
+        Objects.requireNonNull(key, "key cannot be null");
+        final Bytes bytesKey = keyBytes(key);
+        return new MeteredWindowedKeyValueIterator<>( // records fetchSensor latency and deserializes via serdes
+            wrapped().backwardFindSessions(
+                bytesKey,
+                earliestSessionEndTime,
+                latestSessionStartTime
+            ),
+            fetchSensor,
+            streamsMetrics,
+            serdes,
+            time
+        );
+    }
+
+    @Override
     public KeyValueIterator<Windowed<K>, V> findSessions(final K keyFrom,
                                                          final K keyTo,
                                                          final long earliestSessionEndTime,
@@ -273,6 +318,29 @@ public class MeteredSessionStore<K, V>
     }
 
     @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFindSessions(final K keyFrom,
+                                                                 final K keyTo,
+                                                                 final long earliestSessionEndTime,
+                                                                 final long latestSessionStartTime) {
+        Objects.requireNonNull(keyFrom, "keyFrom cannot be null");
+        Objects.requireNonNull(keyTo, "keyTo cannot be null");
+        final Bytes bytesKeyFrom = keyBytes(keyFrom);
+        final Bytes bytesKeyTo = keyBytes(keyTo);
+        return new MeteredWindowedKeyValueIterator<>( // records fetchSensor latency and deserializes via serdes
+            wrapped().backwardFindSessions(
+                bytesKeyFrom,
+                bytesKeyTo,
+                earliestSessionEndTime,
+                latestSessionStartTime
+            ),
+            fetchSensor,
+            streamsMetrics,
+            serdes,
+            time
+        );
+    }
+
+    @Override
     public void flush() {
         maybeMeasureLatency(super::flush, time, flushSensor);
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSessionStore.java
index 2f7a211..338769a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSessionStore.java
@@ -43,6 +43,18 @@ public class RocksDBSessionStore
     }
 
     @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes key,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        final KeyValueIterator<Bytes, byte[]> bytesIterator = wrapped().backwardFetch(
+            key,
+            earliestSessionEndTime,
+            latestSessionStartTime
+        );
+        return new WrappedSessionStoreIterator(bytesIterator); // adapts Bytes-keyed entries to Windowed<Bytes> keys
+    }
+
+    @Override
     public KeyValueIterator<Windowed<Bytes>, byte[]> findSessions(final Bytes keyFrom,
                                                                   final Bytes keyTo,
                                                                   final long earliestSessionEndTime,
@@ -57,6 +69,20 @@ public class RocksDBSessionStore
     }
 
     @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes keyFrom,
+                                                                          final Bytes keyTo,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        final KeyValueIterator<Bytes, byte[]> bytesIterator = wrapped().backwardFetch(
+            keyFrom,
+            keyTo,
+            earliestSessionEndTime,
+            latestSessionStartTime
+        );
+        return new WrappedSessionStoreIterator(bytesIterator); // adapts Bytes-keyed entries to Windowed<Bytes> keys
+    }
+
+    @Override
     public byte[] fetchSession(final Bytes key, final long startTime, final long endTime) {
         return wrapped().get(SessionKeySchema.toBinary(key, startTime, endTime));
     }
@@ -67,11 +93,21 @@ public class RocksDBSessionStore
     }
 
     @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes key) {
+        return backwardFindSessions(key, 0, Long.MAX_VALUE); // whole time range [0, Long.MAX_VALUE]
+    }
+
+    @Override
     public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes from, final Bytes to) {
         return findSessions(from, to, 0, Long.MAX_VALUE);
     }
 
     @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes from, final Bytes to) {
+        return backwardFindSessions(from, to, 0, Long.MAX_VALUE); // whole time range [0, Long.MAX_VALUE]
+    }
+
+    @Override
     public void remove(final Windowed<Bytes> key) {
         wrapped().remove(SessionKeySchema.toBinary(key));
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java
index bb425a9..b355f0e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java
@@ -135,6 +135,31 @@ public abstract class AbstractSessionBytesStoreTest {
     }
 
     @Test
+    public void shouldPutAndBackwardFindSessionsInRange() {
+        final String key = "a";
+        final Windowed<String> a1 = new Windowed<>(key, new SessionWindow(10, 10L));
+        final Windowed<String> a2 = new Windowed<>(key, new SessionWindow(500L, 1000L));
+        sessionStore.put(a1, 1L);
+        sessionStore.put(a2, 2L);
+        sessionStore.put(new Windowed<>(key, new SessionWindow(1500L, 2000L)), 1L);
+        sessionStore.put(new Windowed<>(key, new SessionWindow(2500L, 3000L)), 2L);
+        // results are compared as sets below, so iteration order is not asserted by this test
+        final List<KeyValue<Windowed<String>, Long>> expected =
+            asList(KeyValue.pair(a1, 1L), KeyValue.pair(a2, 2L));
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.backwardFindSessions(key, 0, 1000L)) {
+            assertEquals(new HashSet<>(expected), toSet(values));
+        }
+
+        final List<KeyValue<Windowed<String>, Long>> expected2 =
+            Collections.singletonList(KeyValue.pair(a2, 2L));
+
+        try (final KeyValueIterator<Windowed<String>, Long> values2 = sessionStore.backwardFindSessions(key, 400L, 600L)) {
+            assertEquals(new HashSet<>(expected2), toSet(values2));
+        }
+    }
+
+    @Test
     public void shouldFetchAllSessionsWithSameRecordKey() {
         final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
             KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L),
@@ -155,6 +180,27 @@ public abstract class AbstractSessionBytesStoreTest {
     }
 
     @Test
+    public void shouldBackwardFetchAllSessionsWithSameRecordKey() {
+        final List<KeyValue<Windowed<String>, Long>> expected = asList(
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L),
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L),
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L),
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L)
+        );
+        // order-insensitive check: the assertion below compares sets
+        for (final KeyValue<Windowed<String>, Long> kv : expected) {
+            sessionStore.put(kv.key, kv.value);
+        }
+
+        // add one that shouldn't appear in the results
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 0)), 5L);
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.backwardFetch("a")) {
+            assertEquals(new HashSet<>(expected), toSet(values));
+        }
+    }
+
+    @Test
     public void shouldFetchAllSessionsWithinKeyRange() {
         final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
             KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L),
@@ -177,6 +223,29 @@ public abstract class AbstractSessionBytesStoreTest {
     }
 
     @Test
+    public void shouldBackwardFetchAllSessionsWithinKeyRange() {
+        final List<KeyValue<Windowed<String>, Long>> expected = asList(
+            KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L),
+            KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L),
+
+            KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L),
+            KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L)
+        );
+        // every expected key lies within ["aa", "bb"]; the assertion below compares sets
+        for (final KeyValue<Windowed<String>, Long> kv : expected) {
+            sessionStore.put(kv.key, kv.value);
+        }
+
+        // add some that shouldn't appear in the results
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
+        sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L);
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.backwardFetch("aa", "bb")) {
+            assertEquals(new HashSet<>(expected), toSet(values));
+        }
+    }
+
+    @Test
     public void shouldFetchExactSession() {
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 4)), 1L);
         sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 3)), 2L);
@@ -209,6 +278,22 @@ public abstract class AbstractSessionBytesStoreTest {
     }
 
     @Test
+    public void shouldBackwardFindValuesWithinMergingSessionWindowRange() {
+        final String key = "a";
+        sessionStore.put(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L);
+        sessionStore.put(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L);
+
+        final List<KeyValue<Windowed<String>, Long>> expected = asList(
+            KeyValue.pair(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L),
+            KeyValue.pair(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L)
+        );
+
+        try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.backwardFindSessions(key, -1, 1000L)) {
+            assertEquals(new HashSet<>(expected), toSet(results));
+        }
+    }
+
+    @Test
     public void shouldRemove() {
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), 1L);
         sessionStore.put(new Windowed<>("a", new SessionWindow(1500, 2500)), 2L);
@@ -262,6 +347,27 @@ public abstract class AbstractSessionBytesStoreTest {
     }
 
     @Test
+    public void shouldBackwardFindSessionsToMerge() {
+        final Windowed<String> session1 = new Windowed<>("a", new SessionWindow(0, 100));
+        final Windowed<String> session2 = new Windowed<>("a", new SessionWindow(101, 200));
+        final Windowed<String> session3 = new Windowed<>("a", new SessionWindow(201, 300));
+        final Windowed<String> session4 = new Windowed<>("a", new SessionWindow(301, 400));
+        final Windowed<String> session5 = new Windowed<>("a", new SessionWindow(401, 500));
+        sessionStore.put(session1, 1L);
+        sessionStore.put(session2, 2L);
+        sessionStore.put(session3, 3L);
+        sessionStore.put(session4, 4L);
+        sessionStore.put(session5, 5L);
+
+        final List<KeyValue<Windowed<String>, Long>> expected =
+            asList(KeyValue.pair(session2, 2L), KeyValue.pair(session3, 3L));
+
+        try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.backwardFindSessions("a", 150, 300)) {
+            assertEquals(new HashSet<>(expected), toSet(results));
+        }
+    }
+
+    @Test
     public void shouldFetchExactKeys() {
         sessionStore = buildSessionStore(0x7a00000000000000L, Serdes.String(), Serdes.Long());
         sessionStore.init((StateStoreContext) context, sessionStore);
@@ -299,6 +405,43 @@ public abstract class AbstractSessionBytesStoreTest {
     }
 
     @Test
+    public void shouldBackwardFetchExactKeys() {
+        sessionStore = buildSessionStore(0x7a00000000000000L, Serdes.String(), Serdes.Long());
+        sessionStore.init((StateStoreContext) context, sessionStore);
+
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
+        sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 20)), 4L);
+        sessionStore.put(new Windowed<>("a",
+            new SessionWindow(0x7a00000000000000L - 2, 0x7a00000000000000L - 1)), 5L);
+
+        try (final KeyValueIterator<Windowed<String>, Long> iterator =
+                 sessionStore.backwardFindSessions("a", 0, Long.MAX_VALUE)
+        ) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 3L, 5L))));
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> iterator =
+                 sessionStore.backwardFindSessions("aa", 0, Long.MAX_VALUE)
+        ) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(2L, 4L))));
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> iterator =
+                 sessionStore.backwardFindSessions("a", "aa", 0, Long.MAX_VALUE)
+        ) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L))));
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> iterator =
+                 sessionStore.backwardFindSessions("a", "aa", 10, 0)
+        ) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(Collections.singletonList(2L))));
+        }
+    }
+
+    @Test
     public void shouldFetchAndIterateOverExactBinaryKeys() {
         final SessionStore<Bytes, String> sessionStore =
             buildSessionStore(RETENTION_PERIOD, Serdes.Bytes(), Serdes.String());
@@ -328,6 +471,35 @@ public abstract class AbstractSessionBytesStoreTest {
     }
 
     @Test
+    public void shouldBackwardFetchAndIterateOverExactBinaryKeys() {
+        final SessionStore<Bytes, String> sessionStore =
+            buildSessionStore(RETENTION_PERIOD, Serdes.Bytes(), Serdes.String());
+
+        sessionStore.init((StateStoreContext) context, sessionStore);
+
+        final Bytes key1 = Bytes.wrap(new byte[] {0});
+        final Bytes key2 = Bytes.wrap(new byte[] {0, 0});
+        final Bytes key3 = Bytes.wrap(new byte[] {0, 0, 0});
+
+        sessionStore.put(new Windowed<>(key1, new SessionWindow(1, 100)), "1");
+        sessionStore.put(new Windowed<>(key2, new SessionWindow(2, 100)), "2");
+        sessionStore.put(new Windowed<>(key3, new SessionWindow(3, 100)), "3");
+        sessionStore.put(new Windowed<>(key1, new SessionWindow(4, 100)), "4");
+        sessionStore.put(new Windowed<>(key2, new SessionWindow(5, 100)), "5");
+        sessionStore.put(new Windowed<>(key3, new SessionWindow(6, 100)), "6");
+        sessionStore.put(new Windowed<>(key1, new SessionWindow(7, 100)), "7");
+        sessionStore.put(new Windowed<>(key2, new SessionWindow(8, 100)), "8");
+        sessionStore.put(new Windowed<>(key3, new SessionWindow(9, 100)), "9");
+
+        final Set<String> expectedKey1 = new HashSet<>(asList("1", "4", "7"));
+        assertThat(valuesToSet(sessionStore.backwardFindSessions(key1, 0L, Long.MAX_VALUE)), equalTo(expectedKey1));
+        final Set<String> expectedKey2 = new HashSet<>(asList("2", "5", "8"));
+        assertThat(valuesToSet(sessionStore.backwardFindSessions(key2, 0L, Long.MAX_VALUE)), equalTo(expectedKey2));
+        final Set<String> expectedKey3 = new HashSet<>(asList("3", "6", "9"));
+        assertThat(valuesToSet(sessionStore.backwardFindSessions(key3, 0L, Long.MAX_VALUE)), equalTo(expectedKey3));
+    }
+
+    @Test
     public void testIteratorPeek() {
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
         sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
@@ -343,6 +515,21 @@ public abstract class AbstractSessionBytesStoreTest {
     }
 
     @Test
+    public void testIteratorPeekBackward() {
+        sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L);
+        sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L);
+        sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 20)), 4L);
+
+        final KeyValueIterator<Windowed<String>, Long> iterator = sessionStore.backwardFindSessions("a", 0L, 20);
+
+        assertEquals(iterator.peekNextKey(), new Windowed<>("a", new SessionWindow(10L, 20L)));
+        assertEquals(iterator.peekNextKey(), iterator.next().key);
+        assertEquals(iterator.peekNextKey(), iterator.next().key);
+        assertFalse(iterator.hasNext());
+    }
+
+    @Test
     public void shouldRestore() {
         final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
             KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L),
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CacheFlushListenerStub.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CacheFlushListenerStub.java
new file mode 100644
index 0000000..ea4b147
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CacheFlushListenerStub.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.serialization.Deserializer;
+import org.apache.kafka.streams.kstream.internals.Change;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class CacheFlushListenerStub<K, V> implements CacheFlushListener<byte[], byte[]> {
+    private final Deserializer<K> keyDeserializer;
+    private final Deserializer<V> valueDeserializer;
+    final Map<K, Change<V>> forwarded = new HashMap<>();
+
+    CacheFlushListenerStub(final Deserializer<K> keyDeserializer,
+                           final Deserializer<V> valueDeserializer) {
+        this.keyDeserializer = keyDeserializer;
+        this.valueDeserializer = valueDeserializer;
+    }
+
+    @Override
+    public void apply(final byte[] key,
+                      final byte[] newValue,
+                      final byte[] oldValue,
+                      final long timestamp) {
+        forwarded.put(
+            keyDeserializer.deserialize(null, key),
+            new Change<>(
+                valueDeserializer.deserialize(null, newValue),
+                valueDeserializer.deserialize(null, oldValue)
+            )
+        );
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingKeyValueStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemoryKeyValueStoreTest.java
similarity index 94%
rename from streams/src/test/java/org/apache/kafka/streams/state/internals/CachingKeyValueStoreTest.java
rename to streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemoryKeyValueStoreTest.java
index 89e2b0e..a9085a6 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingKeyValueStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemoryKeyValueStoreTest.java
@@ -17,14 +17,12 @@
 package org.apache.kafka.streams.state.internals;
 
 import org.apache.kafka.common.metrics.Metrics;
-import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.serialization.StringDeserializer;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.errors.InvalidStateStoreException;
-import org.apache.kafka.streams.kstream.internals.Change;
 import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.processor.StateStoreContext;
 import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
@@ -43,9 +41,7 @@ import org.junit.Test;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 
 import static org.apache.kafka.streams.state.internals.ThreadCacheTest.memoryCacheEntrySize;
 import static org.hamcrest.CoreMatchers.equalTo;
@@ -57,7 +53,7 @@ import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
-public class CachingKeyValueStoreTest extends AbstractKeyValueStoreTest {
+public class CachingInMemoryKeyValueStoreTest extends AbstractKeyValueStoreTest {
 
     private final static String TOPIC = "topic";
     private static final String CACHE_NAMESPACE = "0_0-store-name";
@@ -527,27 +523,4 @@ public class CachingKeyValueStoreTest extends AbstractKeyValueStoreTest {
         return i;
     }
 
-    public static class CacheFlushListenerStub<K, V> implements CacheFlushListener<byte[], byte[]> {
-        final Deserializer<K> keyDeserializer;
-        final Deserializer<V> valueDeserializer;
-        final Map<K, Change<V>> forwarded = new HashMap<>();
-
-        CacheFlushListenerStub(final Deserializer<K> keyDeserializer,
-                               final Deserializer<V> valueDeserializer) {
-            this.keyDeserializer = keyDeserializer;
-            this.valueDeserializer = valueDeserializer;
-        }
-
-        @Override
-        public void apply(final byte[] key,
-                          final byte[] newValue,
-                          final byte[] oldValue,
-                          final long timestamp) {
-            forwarded.put(
-                keyDeserializer.deserialize(null, key),
-                new Change<>(
-                    valueDeserializer.deserialize(null, newValue),
-                    valueDeserializer.deserialize(null, oldValue)));
-        }
-    }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
similarity index 71%
copy from streams/src/test/java/org/apache/kafka/streams/state/internals/CachingSessionStoreTest.java
copy to streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
index 05e97a2..e584e2c 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingSessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
@@ -46,14 +46,11 @@ import org.junit.Test;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Random;
-import java.util.Set;
 
 import static java.util.Arrays.asList;
-import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.test.StreamsTestUtils.toList;
 import static org.apache.kafka.test.StreamsTestUtils.verifyKeyValueList;
 import static org.apache.kafka.test.StreamsTestUtils.verifyWindowedKeyValue;
@@ -68,7 +65,7 @@ import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 @SuppressWarnings("PointlessArithmeticExpression")
-public class CachingSessionStoreTest {
+public class CachingInMemorySessionStoreTest {
 
     private static final int MAX_CACHE_SIZE_BYTES = 600;
     private static final Long DEFAULT_TIMESTAMP = 10L;
@@ -80,14 +77,14 @@ public class CachingSessionStoreTest {
     private final Bytes keyAA = Bytes.wrap("aa".getBytes());
     private final Bytes keyB = Bytes.wrap("b".getBytes());
 
-    private SessionStore<Bytes, byte[]> underlyingStore =
-        new InMemorySessionStore("store-name", Long.MAX_VALUE, "metric-scope");
+    private SessionStore<Bytes, byte[]> underlyingStore;
     private InternalMockProcessorContext context;
     private CachingSessionStore cachingStore;
     private ThreadCache cache;
 
     @Before
     public void before() {
+        underlyingStore = new InMemorySessionStore("store-name", Long.MAX_VALUE, "metric-scope");
         cachingStore = new CachingSessionStore(underlyingStore, SEGMENT_INTERVAL);
         cache = new ThreadCache(new LogContext("testCache "), MAX_CACHE_SIZE_BYTES, new MockStreamsMetrics(new Metrics()));
         context = new InternalMockProcessorContext(TestUtils.tempDirectory(), null, null, null, cache);
@@ -158,6 +155,21 @@ public class CachingSessionStoreTest {
     }
 
     @Test
+    public void shouldPutBackwardFetchAllKeysFromCache() {
+        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
+        cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes());
+        cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes());
+
+        assertEquals(3, cache.size());
+
+        final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.backwardFindSessions(keyA, keyB, 0, 0);
+        verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+        verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+        verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+        assertFalse(all.hasNext());
+    }
+
+    @Test
     public void shouldCloseWrappedStoreAndCacheAfterErrorDuringCacheFlush() {
         setUpCloseTests();
         EasyMock.reset(cache);
@@ -211,7 +223,7 @@ public class CachingSessionStoreTest {
         EasyMock.replay(underlyingStore);
         cachingStore = new CachingSessionStore(underlyingStore, SEGMENT_INTERVAL);
         cache = EasyMock.niceMock(ThreadCache.class);
-        context = new InternalMockProcessorContext(TestUtils.tempDirectory(), null, null, null, cache);
+        final InternalMockProcessorContext context = new InternalMockProcessorContext(TestUtils.tempDirectory(), null, null, null, cache);
         context.setRecordContext(new ProcessorRecordContext(10, 0, 0, TOPIC, null));
         cachingStore.init((StateStoreContext) context, cachingStore);
     }
@@ -231,6 +243,20 @@ public class CachingSessionStoreTest {
     }
 
     @Test
+    public void shouldPutBackwardFetchRangeFromCache() {
+        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
+        cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes());
+        cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes());
+
+        assertEquals(3, cache.size());
+
+        final KeyValueIterator<Windowed<Bytes>, byte[]> some = cachingStore.backwardFindSessions(keyAA, keyB, 0, 0);
+        verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+        verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+        assertFalse(some.hasNext());
+    }
+
+    @Test
     public void shouldFetchAllSessionsWithSameRecordKey() {
         final List<KeyValue<Windowed<Bytes>, byte[]>> expected = asList(
             KeyValue.pair(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()),
@@ -250,6 +276,26 @@ public class CachingSessionStoreTest {
     }
 
     @Test
+    public void shouldBackwardFetchAllSessionsWithSameRecordKey() {
+        final List<KeyValue<Windowed<Bytes>, byte[]>> expected = asList(
+            KeyValue.pair(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()),
+            KeyValue.pair(new Windowed<>(keyA, new SessionWindow(10, 10)), "2".getBytes()),
+            KeyValue.pair(new Windowed<>(keyA, new SessionWindow(100, 100)), "3".getBytes()),
+            KeyValue.pair(new Windowed<>(keyA, new SessionWindow(1000, 1000)), "4".getBytes())
+        );
+        for (final KeyValue<Windowed<Bytes>, byte[]> kv : expected) {
+            cachingStore.put(kv.key, kv.value);
+        }
+
+        // add one that shouldn't appear in the results
+        cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "5".getBytes());
+
+        final List<KeyValue<Windowed<Bytes>, byte[]>> results = toList(cachingStore.backwardFetch(keyA));
+        Collections.reverse(results);
+        verifyKeyValueList(expected, results);
+    }
+
+    @Test
     public void shouldFlushItemsToStoreOnEviction() {
         final List<KeyValue<Windowed<Bytes>, byte[]>> added = addSessionsUntilOverflow("a", "b", "c", "d");
         assertEquals(added.size() - 1, cache.size());
@@ -292,15 +338,50 @@ public class CachingSessionStoreTest {
         final Windowed<Bytes> a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0));
         final Windowed<Bytes> a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1));
         final Windowed<Bytes> a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2));
+        final Windowed<Bytes> a4 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 3, SEGMENT_INTERVAL * 3));
+        final Windowed<Bytes> a5 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 4, SEGMENT_INTERVAL * 4));
+        final Windowed<Bytes> a6 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 5, SEGMENT_INTERVAL * 5));
         cachingStore.put(a1, "1".getBytes());
         cachingStore.put(a2, "2".getBytes());
         cachingStore.put(a3, "3".getBytes());
         cachingStore.flush();
+        cachingStore.put(a4, "4".getBytes());
+        cachingStore.put(a5, "5".getBytes());
+        cachingStore.put(a6, "6".getBytes());
         final KeyValueIterator<Windowed<Bytes>, byte[]> results =
-            cachingStore.findSessions(keyA, 0, SEGMENT_INTERVAL * 2);
+            cachingStore.findSessions(keyA, 0, SEGMENT_INTERVAL * 5);
         assertEquals(a1, results.next().key);
         assertEquals(a2, results.next().key);
         assertEquals(a3, results.next().key);
+        assertEquals(a4, results.next().key);
+        assertEquals(a5, results.next().key);
+        assertEquals(a6, results.next().key);
+        assertFalse(results.hasNext());
+    }
+
+    @Test
+    public void shouldBackwardFetchCorrectlyAcrossSegments() {
+        final Windowed<Bytes> a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0));
+        final Windowed<Bytes> a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1));
+        final Windowed<Bytes> a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2));
+        final Windowed<Bytes> a4 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 3, SEGMENT_INTERVAL * 3));
+        final Windowed<Bytes> a5 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 4, SEGMENT_INTERVAL * 4));
+        final Windowed<Bytes> a6 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 5, SEGMENT_INTERVAL * 5));
+        cachingStore.put(a1, "1".getBytes());
+        cachingStore.put(a2, "2".getBytes());
+        cachingStore.put(a3, "3".getBytes());
+        cachingStore.flush();
+        cachingStore.put(a4, "4".getBytes());
+        cachingStore.put(a5, "5".getBytes());
+        cachingStore.put(a6, "6".getBytes());
+        final KeyValueIterator<Windowed<Bytes>, byte[]> results =
+            cachingStore.backwardFindSessions(keyA, 0, SEGMENT_INTERVAL * 5);
+        assertEquals(a6, results.next().key);
+        assertEquals(a5, results.next().key);
+        assertEquals(a4, results.next().key);
+        assertEquals(a3, results.next().key);
+        assertEquals(a2, results.next().key);
+        assertEquals(a1, results.next().key);
         assertFalse(results.hasNext());
     }
 
@@ -319,12 +400,35 @@ public class CachingSessionStoreTest {
 
         final KeyValueIterator<Windowed<Bytes>, byte[]> rangeResults =
             cachingStore.findSessions(keyA, keyAA, 0, SEGMENT_INTERVAL * 2);
-        final Set<Windowed<Bytes>> keys = new HashSet<>();
+        final List<Windowed<Bytes>> keys = new ArrayList<>();
         while (rangeResults.hasNext()) {
             keys.add(rangeResults.next().key);
         }
         rangeResults.close();
-        assertEquals(mkSet(a1, a2, a3, aa1, aa3), keys);
+        assertEquals(asList(a1, aa1, a2, a3, aa3), keys);
+    }
+
+    @Test
+    public void shouldBackwardFetchRangeCorrectlyAcrossSegments() {
+        final Windowed<Bytes> a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0));
+        final Windowed<Bytes> aa1 = new Windowed<>(keyAA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0));
+        final Windowed<Bytes> a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1));
+        final Windowed<Bytes> a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2));
+        final Windowed<Bytes> aa3 = new Windowed<>(keyAA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2));
+        cachingStore.put(a1, "1".getBytes());
+        cachingStore.put(aa1, "1".getBytes());
+        cachingStore.put(a2, "2".getBytes());
+        cachingStore.put(a3, "3".getBytes());
+        cachingStore.put(aa3, "3".getBytes());
+
+        final KeyValueIterator<Windowed<Bytes>, byte[]> rangeResults =
+            cachingStore.backwardFindSessions(keyA, keyAA, 0, SEGMENT_INTERVAL * 2);
+        final List<Windowed<Bytes>> keys = new ArrayList<>();
+        while (rangeResults.hasNext()) {
+            keys.add(rangeResults.next().key);
+        }
+        rangeResults.close();
+        assertEquals(asList(aa3, a3, a2, aa1, a1), keys);
     }
 
     @Test
@@ -474,6 +578,24 @@ public class CachingSessionStoreTest {
     }
 
     @Test
+    public void shouldReturnSameResultsForSingleKeyFindSessionsBackwardsAndEqualKeyRangeFindSessions() {
+        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 1)), "1".getBytes());
+        cachingStore.put(new Windowed<>(keyAA, new SessionWindow(2, 3)), "2".getBytes());
+        cachingStore.put(new Windowed<>(keyAA, new SessionWindow(4, 5)), "3".getBytes());
+        cachingStore.put(new Windowed<>(keyB, new SessionWindow(6, 7)), "4".getBytes());
+
+        final KeyValueIterator<Windowed<Bytes>, byte[]> singleKeyIterator =
+            cachingStore.backwardFindSessions(keyAA, 0L, 10L);
+        final KeyValueIterator<Windowed<Bytes>, byte[]> keyRangeIterator =
+            cachingStore.backwardFindSessions(keyAA, keyAA, 0L, 10L);
+
+        assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
+        assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
+        assertFalse(singleKeyIterator.hasNext());
+        assertFalse(keyRangeIterator.hasNext());
+    }
+
+    @Test
     public void shouldClearNamespaceCacheOnClose() {
         final Windowed<Bytes> a1 = new Windowed<>(keyA, new SessionWindow(0, 0));
         cachingStore.put(a1, "1".getBytes());
@@ -482,68 +604,90 @@ public class CachingSessionStoreTest {
         assertEquals(0, cache.size());
     }
 
-    @Test(expected = InvalidStateStoreException.class)
+    @Test
     public void shouldThrowIfTryingToFetchFromClosedCachingStore() {
         cachingStore.close();
-        cachingStore.fetch(keyA);
+        assertThrows(InvalidStateStoreException.class, () -> cachingStore.fetch(keyA));
     }
 
-    @Test(expected = InvalidStateStoreException.class)
+    @Test
     public void shouldThrowIfTryingToFindMergeSessionFromClosedCachingStore() {
         cachingStore.close();
-        cachingStore.findSessions(keyA, 0, Long.MAX_VALUE);
+        assertThrows(InvalidStateStoreException.class, () -> cachingStore.findSessions(keyA, 0, Long.MAX_VALUE));
     }
 
-    @Test(expected = InvalidStateStoreException.class)
+    @Test
     public void shouldThrowIfTryingToRemoveFromClosedCachingStore() {
         cachingStore.close();
-        cachingStore.remove(new Windowed<>(keyA, new SessionWindow(0, 0)));
+        assertThrows(InvalidStateStoreException.class, () -> cachingStore.remove(new Windowed<>(keyA, new SessionWindow(0, 0))));
     }
 
-    @Test(expected = InvalidStateStoreException.class)
+    @Test
     public void shouldThrowIfTryingToPutIntoClosedCachingStore() {
         cachingStore.close();
-        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
+        assertThrows(InvalidStateStoreException.class, () -> cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnFindSessionsNullKey() {
-        cachingStore.findSessions(null, 1L, 2L);
+        assertThrows(NullPointerException.class, () -> cachingStore.findSessions(null, 1L, 2L));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnFindSessionsNullFromKey() {
-        cachingStore.findSessions(null, keyA, 1L, 2L);
+        assertThrows(NullPointerException.class, () -> cachingStore.findSessions(null, keyA, 1L, 2L));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnFindSessionsNullToKey() {
-        cachingStore.findSessions(keyA, null, 1L, 2L);
+        assertThrows(NullPointerException.class, () -> cachingStore.findSessions(keyA, null, 1L, 2L));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnFetchNullFromKey() {
-        cachingStore.fetch(null, keyA);
+        assertThrows(NullPointerException.class, () -> cachingStore.fetch(null, keyA));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnFetchNullToKey() {
-        cachingStore.fetch(keyA, null);
+        assertThrows(NullPointerException.class, () -> cachingStore.fetch(keyA, null));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnFetchNullKey() {
-        cachingStore.fetch(null);
+        assertThrows(NullPointerException.class, () -> cachingStore.fetch(null));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnRemoveNullKey() {
-        cachingStore.remove(null);
+        assertThrows(NullPointerException.class, () -> cachingStore.remove(null));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnPutNullKey() {
-        cachingStore.put(null, "1".getBytes());
+        assertThrows(NullPointerException.class, () -> cachingStore.put(null, "1".getBytes()));
+    }
+
+    @Test
+    public void shouldNotThrowInvalidRangeExceptionWhenBackwardWithNegativeFromKey() {
+        final Bytes keyFrom = Bytes.wrap(Serdes.Integer().serializer().serialize("", -1));
+        final Bytes keyTo = Bytes.wrap(Serdes.Integer().serializer().serialize("", 1));
+
+        try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(CachingSessionStore.class)) {
+            final KeyValueIterator<Windowed<Bytes>, byte[]> iterator = cachingStore.backwardFindSessions(keyFrom, keyTo, 0L, 10L);
+            assertFalse(iterator.hasNext());
+
+            final List<String> messages = appender.getMessages();
+            assertThat(
+                messages,
+                hasItem(
+                    "Returning empty iterator for fetch with invalid key range: from > to." +
+                        " This may be due to range arguments set in the wrong order, " +
+                        "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." +
+                        " Note that the built-in numerical serdes do not follow this for negative numbers"
+                )
+            );
+        }
     }
 
     @Test
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java
similarity index 63%
rename from streams/src/test/java/org/apache/kafka/streams/state/internals/CachingSessionStoreTest.java
rename to streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java
index 05e97a2..d472c7f5 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingSessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java
@@ -29,7 +29,6 @@ import org.apache.kafka.streams.kstream.SessionWindowedDeserializer;
 import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.kstream.internals.Change;
 import org.apache.kafka.streams.kstream.internals.SessionWindow;
-import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.processor.StateStoreContext;
 import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
 import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
@@ -46,14 +45,11 @@ import org.junit.Test;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Random;
-import java.util.Set;
 
 import static java.util.Arrays.asList;
-import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.test.StreamsTestUtils.toList;
 import static org.apache.kafka.test.StreamsTestUtils.verifyKeyValueList;
 import static org.apache.kafka.test.StreamsTestUtils.verifyWindowedKeyValue;
@@ -67,8 +63,7 @@ import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
-@SuppressWarnings("PointlessArithmeticExpression")
-public class CachingSessionStoreTest {
+public class CachingPersistentSessionStoreTest {
 
     private static final int MAX_CACHE_SIZE_BYTES = 600;
     private static final Long DEFAULT_TIMESTAMP = 10L;
@@ -80,17 +75,24 @@ public class CachingSessionStoreTest {
     private final Bytes keyAA = Bytes.wrap("aa".getBytes());
     private final Bytes keyB = Bytes.wrap("b".getBytes());
 
-    private SessionStore<Bytes, byte[]> underlyingStore =
-        new InMemorySessionStore("store-name", Long.MAX_VALUE, "metric-scope");
-    private InternalMockProcessorContext context;
+    private SessionStore<Bytes, byte[]> underlyingStore;
     private CachingSessionStore cachingStore;
     private ThreadCache cache;
 
     @Before
     public void before() {
+        final RocksDBSegmentedBytesStore segmented = new RocksDBSegmentedBytesStore( // persistent segmented bytes store backing the cache under test
+            "store-name",
+            "metric-scope",
+            Long.MAX_VALUE, // NOTE(review): presumably the retention period (effectively unbounded) — confirm against the constructor
+            SEGMENT_INTERVAL, // same segment interval that is handed to CachingSessionStore below
+            new SessionKeySchema()
+        );
+        underlyingStore = new RocksDBSessionStore(segmented); // session-store view over the segmented bytes store
         cachingStore = new CachingSessionStore(underlyingStore, SEGMENT_INTERVAL);
         cache = new ThreadCache(new LogContext("testCache "), MAX_CACHE_SIZE_BYTES, new MockStreamsMetrics(new Metrics()));
-        context = new InternalMockProcessorContext(TestUtils.tempDirectory(), null, null, null, cache);
+        final InternalMockProcessorContext context = // only needed while initializing the store, so kept local
+            new InternalMockProcessorContext(TestUtils.tempDirectory(), null, null, null, cache);
         context.setRecordContext(new ProcessorRecordContext(DEFAULT_TIMESTAMP, 0, 0, TOPIC, null));
         cachingStore.init((StateStoreContext) context, cachingStore);
     }
@@ -100,31 +102,6 @@ public class CachingSessionStoreTest {
         cachingStore.close();
     }
 
-    @SuppressWarnings("deprecation")
-    @Test
-    public void shouldDelegateDeprecatedInit() {
-        final SessionStore<Bytes, byte[]> inner = EasyMock.mock(InMemorySessionStore.class);
-        final CachingSessionStore outer = new CachingSessionStore(inner, SEGMENT_INTERVAL);
-        EasyMock.expect(inner.name()).andStubReturn("store");
-        inner.init((ProcessorContext) context, outer);
-        EasyMock.expectLastCall();
-        EasyMock.replay(inner);
-        outer.init((ProcessorContext) context, outer);
-        EasyMock.verify(inner);
-    }
-
-    @Test
-    public void shouldDelegateInit() {
-        final SessionStore<Bytes, byte[]> inner = EasyMock.mock(InMemorySessionStore.class);
-        final CachingSessionStore outer = new CachingSessionStore(inner, SEGMENT_INTERVAL);
-        EasyMock.expect(inner.name()).andStubReturn("store");
-        inner.init((StateStoreContext) context, outer);
-        EasyMock.expectLastCall();
-        EasyMock.replay(inner);
-        outer.init((StateStoreContext) context, outer);
-        EasyMock.verify(inner);
-    }
-
     @Test
     public void shouldPutFetchFromCache() {
         cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
@@ -133,8 +110,10 @@ public class CachingSessionStoreTest {
 
         assertEquals(3, cache.size());
 
-        final KeyValueIterator<Windowed<Bytes>, byte[]> a = cachingStore.findSessions(keyA, 0, 0);
-        final KeyValueIterator<Windowed<Bytes>, byte[]> b = cachingStore.findSessions(keyB, 0, 0);
+        final KeyValueIterator<Windowed<Bytes>, byte[]> a =
+            cachingStore.findSessions(keyA, 0, 0);
+        final KeyValueIterator<Windowed<Bytes>, byte[]> b =
+            cachingStore.findSessions(keyB, 0, 0);
 
         verifyWindowedKeyValue(a.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
         verifyWindowedKeyValue(b.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
@@ -150,7 +129,8 @@ public class CachingSessionStoreTest {
 
         assertEquals(3, cache.size());
 
-        final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.findSessions(keyA, keyB, 0, 0);
+        final KeyValueIterator<Windowed<Bytes>, byte[]> all =
+            cachingStore.findSessions(keyA, keyB, 0, 0);
         verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
         verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
         verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
@@ -158,6 +138,22 @@ public class CachingSessionStoreTest {
     }
 
     @Test
+    public void shouldPutBackwardFetchAllKeysFromCache() {
+        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
+        cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes());
+        cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes());
+
+        assertEquals(3, cache.size()); // all three entries are still held in the cache
+        // backward range fetch over [keyA, keyB] must return keys in descending order: b, aa, a
+        final KeyValueIterator<Windowed<Bytes>, byte[]> all =
+            cachingStore.backwardFindSessions(keyA, keyB, 0, 0);
+        verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+        verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+        verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+        assertFalse(all.hasNext());
+    }
+
+    @Test
     public void shouldCloseWrappedStoreAndCacheAfterErrorDuringCacheFlush() {
         setUpCloseTests();
         EasyMock.reset(cache);
@@ -211,7 +207,8 @@ public class CachingSessionStoreTest {
         EasyMock.replay(underlyingStore);
         cachingStore = new CachingSessionStore(underlyingStore, SEGMENT_INTERVAL);
         cache = EasyMock.niceMock(ThreadCache.class);
-        context = new InternalMockProcessorContext(TestUtils.tempDirectory(), null, null, null, cache);
+        final InternalMockProcessorContext context =
+            new InternalMockProcessorContext(TestUtils.tempDirectory(), null, null, null, cache);
         context.setRecordContext(new ProcessorRecordContext(10, 0, 0, TOPIC, null));
         cachingStore.init((StateStoreContext) context, cachingStore);
     }
@@ -224,13 +221,29 @@ public class CachingSessionStoreTest {
 
         assertEquals(3, cache.size());
 
-        final KeyValueIterator<Windowed<Bytes>, byte[]> some = cachingStore.findSessions(keyAA, keyB, 0, 0);
+        final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+            cachingStore.findSessions(keyAA, keyB, 0, 0);
         verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
         verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
         assertFalse(some.hasNext());
     }
 
     @Test
+    public void shouldPutBackwardFetchRangeFromCache() {
+        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
+        cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes());
+        cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes());
+
+        assertEquals(3, cache.size());
+        // keyA lies outside [keyAA, keyB]; the remaining keys must come back in descending order
+        final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+            cachingStore.backwardFindSessions(keyAA, keyB, 0, 0);
+        verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+        verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+        assertFalse(some.hasNext());
+    }
+
+    @Test
     public void shouldFetchAllSessionsWithSameRecordKey() {
         final List<KeyValue<Windowed<Bytes>, byte[]>> expected = asList(
             KeyValue.pair(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()),
@@ -250,10 +263,31 @@ public class CachingSessionStoreTest {
     }
 
     @Test
+    public void shouldBackwardFetchAllSessionsWithSameRecordKey() {
+        final List<KeyValue<Windowed<Bytes>, byte[]>> expected = asList( // listed in ascending session-time order
+            KeyValue.pair(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()),
+            KeyValue.pair(new Windowed<>(keyA, new SessionWindow(10, 10)), "2".getBytes()),
+            KeyValue.pair(new Windowed<>(keyA, new SessionWindow(100, 100)), "3".getBytes()),
+            KeyValue.pair(new Windowed<>(keyA, new SessionWindow(1000, 1000)), "4".getBytes())
+        );
+        for (final KeyValue<Windowed<Bytes>, byte[]> kv : expected) {
+            cachingStore.put(kv.key, kv.value);
+        }
+
+        // add one that shouldn't appear in the results
+        cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "5".getBytes());
+
+        final List<KeyValue<Windowed<Bytes>, byte[]>> results = toList(cachingStore.backwardFetch(keyA)); // latest session first
+        Collections.reverse(results); // back to ascending order so it can be compared with expected
+        verifyKeyValueList(expected, results);
+    }
+
+    @Test
     public void shouldFlushItemsToStoreOnEviction() {
         final List<KeyValue<Windowed<Bytes>, byte[]>> added = addSessionsUntilOverflow("a", "b", "c", "d");
         assertEquals(added.size() - 1, cache.size());
-        final KeyValueIterator<Windowed<Bytes>, byte[]> iterator = cachingStore.findSessions(added.get(0).key.key(), 0, 0);
+        final KeyValueIterator<Windowed<Bytes>, byte[]> iterator =
+            cachingStore.findSessions(added.get(0).key.key(), 0, 0);
         final KeyValue<Windowed<Bytes>, byte[]> next = iterator.next();
         assertEquals(added.get(0).key, next.key);
         assertArrayEquals(added.get(0).value, next.value);
@@ -265,7 +299,8 @@ public class CachingSessionStoreTest {
         final KeyValueIterator<Windowed<Bytes>, byte[]> iterator = cachingStore.findSessions(
             Bytes.wrap("a".getBytes(StandardCharsets.UTF_8)),
             0,
-            added.size() * 10);
+            added.size() * 10L
+        );
         final List<KeyValue<Windowed<Bytes>, byte[]>> actual = toList(iterator);
         verifyKeyValueList(added, actual);
     }
@@ -292,15 +327,50 @@ public class CachingSessionStoreTest {
         final Windowed<Bytes> a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0));
         final Windowed<Bytes> a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1));
         final Windowed<Bytes> a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2));
+        final Windowed<Bytes> a4 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 3, SEGMENT_INTERVAL * 3));
+        final Windowed<Bytes> a5 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 4, SEGMENT_INTERVAL * 4));
+        final Windowed<Bytes> a6 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 5, SEGMENT_INTERVAL * 5));
         cachingStore.put(a1, "1".getBytes());
         cachingStore.put(a2, "2".getBytes());
         cachingStore.put(a3, "3".getBytes());
         cachingStore.flush();
+        cachingStore.put(a4, "4".getBytes());
+        cachingStore.put(a5, "5".getBytes());
+        cachingStore.put(a6, "6".getBytes());
         final KeyValueIterator<Windowed<Bytes>, byte[]> results =
-            cachingStore.findSessions(keyA, 0, SEGMENT_INTERVAL * 2);
+            cachingStore.findSessions(keyA, 0, SEGMENT_INTERVAL * 5);
         assertEquals(a1, results.next().key);
         assertEquals(a2, results.next().key);
         assertEquals(a3, results.next().key);
+        assertEquals(a4, results.next().key);
+        assertEquals(a5, results.next().key);
+        assertEquals(a6, results.next().key);
+        assertFalse(results.hasNext());
+    }
+
+    @Test
+    public void shouldBackwardFetchCorrectlyAcrossSegments() {
+        final Windowed<Bytes> a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0));
+        final Windowed<Bytes> a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1));
+        final Windowed<Bytes> a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2));
+        final Windowed<Bytes> a4 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 3, SEGMENT_INTERVAL * 3));
+        final Windowed<Bytes> a5 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 4, SEGMENT_INTERVAL * 4));
+        final Windowed<Bytes> a6 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 5, SEGMENT_INTERVAL * 5));
+        cachingStore.put(a1, "1".getBytes());
+        cachingStore.put(a2, "2".getBytes());
+        cachingStore.put(a3, "3".getBytes());
+        cachingStore.flush(); // a1..a3 are flushed; a4..a6 are written afterwards, so the fetch must merge both sources
+        cachingStore.put(a4, "4".getBytes());
+        cachingStore.put(a5, "5".getBytes());
+        cachingStore.put(a6, "6".getBytes());
+        final KeyValueIterator<Windowed<Bytes>, byte[]> results =
+            cachingStore.backwardFindSessions(keyA, 0, SEGMENT_INTERVAL * 5); // expect descending session start times
+        assertEquals(a6, results.next().key);
+        assertEquals(a5, results.next().key);
+        assertEquals(a4, results.next().key);
+        assertEquals(a3, results.next().key);
+        assertEquals(a2, results.next().key);
+        assertEquals(a1, results.next().key);
         assertFalse(results.hasNext());
     }
 
@@ -319,12 +389,35 @@ public class CachingSessionStoreTest {
 
         final KeyValueIterator<Windowed<Bytes>, byte[]> rangeResults =
             cachingStore.findSessions(keyA, keyAA, 0, SEGMENT_INTERVAL * 2);
-        final Set<Windowed<Bytes>> keys = new HashSet<>();
+        final List<Windowed<Bytes>> keys = new ArrayList<>();
+        while (rangeResults.hasNext()) {
+            keys.add(rangeResults.next().key);
+        }
+        rangeResults.close();
+        assertEquals(asList(a1, aa1, a2, a3, aa3), keys);
+    }
+
+    @Test
+    public void shouldBackwardFetchRangeCorrectlyAcrossSegments() {
+        final Windowed<Bytes> a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0));
+        final Windowed<Bytes> aa1 = new Windowed<>(keyAA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0));
+        final Windowed<Bytes> a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1));
+        final Windowed<Bytes> a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2));
+        final Windowed<Bytes> aa3 = new Windowed<>(keyAA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2));
+        cachingStore.put(a1, "1".getBytes());
+        cachingStore.put(aa1, "1".getBytes());
+        cachingStore.put(a2, "2".getBytes());
+        cachingStore.put(a3, "3".getBytes());
+        cachingStore.put(aa3, "3".getBytes());
+        // sessions for keys a and aa spread over three segment intervals
+        final KeyValueIterator<Windowed<Bytes>, byte[]> rangeResults =
+            cachingStore.backwardFindSessions(keyA, keyAA, 0, SEGMENT_INTERVAL * 2);
+        final List<Windowed<Bytes>> keys = new ArrayList<>(); // a List (not a Set) so iteration order is asserted
         while (rangeResults.hasNext()) {
             keys.add(rangeResults.next().key);
         }
         rangeResults.close();
-        assertEquals(mkSet(a1, a2, a3, aa1, aa3), keys);
+        assertEquals(asList(aa3, a3, a2, aa1, a1), keys); // exact reverse of the forward range-fetch order asserted above
     }
 
     @Test
@@ -342,7 +435,8 @@ public class CachingSessionStoreTest {
         final CacheFlushListenerStub<Windowed<String>, String> flushListener =
             new CacheFlushListenerStub<>(
                 new SessionWindowedDeserializer<>(new StringDeserializer()),
-                new StringDeserializer());
+                new StringDeserializer()
+            );
         cachingStore.setFlushListener(flushListener, true);
 
         cachingStore.put(b, "1".getBytes());
@@ -353,7 +447,9 @@ public class CachingSessionStoreTest {
                 new KeyValueTimestamp<>(
                     bDeserialized,
                     new Change<>("1", null),
-                    DEFAULT_TIMESTAMP)),
+                    DEFAULT_TIMESTAMP
+                )
+            ),
             flushListener.forwarded
         );
         flushListener.forwarded.clear();
@@ -366,7 +462,9 @@ public class CachingSessionStoreTest {
                 new KeyValueTimestamp<>(
                     aDeserialized,
                     new Change<>("1", null),
-                    DEFAULT_TIMESTAMP)),
+                    DEFAULT_TIMESTAMP
+                )
+            ),
             flushListener.forwarded
         );
         flushListener.forwarded.clear();
@@ -379,7 +477,9 @@ public class CachingSessionStoreTest {
                 new KeyValueTimestamp<>(
                     aDeserialized,
                     new Change<>("2", "1"),
-                    DEFAULT_TIMESTAMP)),
+                    DEFAULT_TIMESTAMP
+                )
+            ),
             flushListener.forwarded
         );
         flushListener.forwarded.clear();
@@ -392,7 +492,9 @@ public class CachingSessionStoreTest {
                 new KeyValueTimestamp<>(
                     aDeserialized,
                     new Change<>(null, "2"),
-                    DEFAULT_TIMESTAMP)),
+                    DEFAULT_TIMESTAMP
+                )
+            ),
             flushListener.forwarded
         );
         flushListener.forwarded.clear();
@@ -429,18 +531,23 @@ public class CachingSessionStoreTest {
         cachingStore.flush();
 
         assertEquals(
-            asList(new KeyValueTimestamp<>(
+            asList(
+                new KeyValueTimestamp<>(
                     aDeserialized,
                     new Change<>("1", null),
-                    DEFAULT_TIMESTAMP),
+                    DEFAULT_TIMESTAMP
+                ),
                 new KeyValueTimestamp<>(
                     aDeserialized,
                     new Change<>("2", null),
-                    DEFAULT_TIMESTAMP),
+                    DEFAULT_TIMESTAMP
+                ),
                 new KeyValueTimestamp<>(
                     aDeserialized,
                     new Change<>(null, null),
-                    DEFAULT_TIMESTAMP)),
+                    DEFAULT_TIMESTAMP
+                )
+            ),
             flushListener.forwarded
         );
         flushListener.forwarded.clear();
@@ -464,8 +571,28 @@ public class CachingSessionStoreTest {
         cachingStore.put(new Windowed<>(keyAA, new SessionWindow(4, 5)), "3".getBytes());
         cachingStore.put(new Windowed<>(keyB, new SessionWindow(6, 7)), "4".getBytes());
 
-        final KeyValueIterator<Windowed<Bytes>, byte[]> singleKeyIterator = cachingStore.findSessions(keyAA, 0L, 10L);
-        final KeyValueIterator<Windowed<Bytes>, byte[]> keyRangeIterator = cachingStore.findSessions(keyAA, keyAA, 0L, 10L);
+        final KeyValueIterator<Windowed<Bytes>, byte[]> singleKeyIterator =
+            cachingStore.findSessions(keyAA, 0L, 10L);
+        final KeyValueIterator<Windowed<Bytes>, byte[]> keyRangeIterator =
+            cachingStore.findSessions(keyAA, keyAA, 0L, 10L);
+
+        assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
+        assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
+        assertFalse(singleKeyIterator.hasNext());
+        assertFalse(keyRangeIterator.hasNext());
+    }
+
+    @Test
+    public void shouldReturnSameResultsForSingleKeyFindSessionsBackwardsAndEqualKeyRangeFindSessions() {
+        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 1)), "1".getBytes());
+        cachingStore.put(new Windowed<>(keyAA, new SessionWindow(2, 3)), "2".getBytes());
+        cachingStore.put(new Windowed<>(keyAA, new SessionWindow(4, 5)), "3".getBytes());
+        cachingStore.put(new Windowed<>(keyB, new SessionWindow(6, 7)), "4".getBytes());
+
+        final KeyValueIterator<Windowed<Bytes>, byte[]> singleKeyIterator =
+            cachingStore.backwardFindSessions(keyAA, 0L, 10L);
+        final KeyValueIterator<Windowed<Bytes>, byte[]> keyRangeIterator =
+            cachingStore.backwardFindSessions(keyAA, keyAA, 0L, 10L);
 
         assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
         assertEquals(singleKeyIterator.next(), keyRangeIterator.next());
@@ -482,68 +609,91 @@ public class CachingSessionStoreTest {
         assertEquals(0, cache.size());
     }
 
-    @Test(expected = InvalidStateStoreException.class)
+    @Test // assertThrows ties the expected exception to the call after close(), unlike @Test(expected = ...)
     public void shouldThrowIfTryingToFetchFromClosedCachingStore() {
         cachingStore.close();
-        cachingStore.fetch(keyA);
+        assertThrows(InvalidStateStoreException.class, () -> cachingStore.fetch(keyA));
     }
 
-    @Test(expected = InvalidStateStoreException.class)
+    @Test
     public void shouldThrowIfTryingToFindMergeSessionFromClosedCachingStore() {
         cachingStore.close();
-        cachingStore.findSessions(keyA, 0, Long.MAX_VALUE);
+        assertThrows(InvalidStateStoreException.class, () -> cachingStore.findSessions(keyA, 0, Long.MAX_VALUE));
     }
 
-    @Test(expected = InvalidStateStoreException.class)
+    @Test
     public void shouldThrowIfTryingToRemoveFromClosedCachingStore() {
         cachingStore.close();
-        cachingStore.remove(new Windowed<>(keyA, new SessionWindow(0, 0)));
+        assertThrows(InvalidStateStoreException.class, () -> cachingStore.remove(new Windowed<>(keyA, new SessionWindow(0, 0))));
     }
 
-    @Test(expected = InvalidStateStoreException.class)
+    @Test
     public void shouldThrowIfTryingToPutIntoClosedCachingStore() {
         cachingStore.close();
-        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
+        assertThrows(InvalidStateStoreException.class, () -> cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test // null-argument checks: assertThrows pins the NullPointerException to the store call itself
     public void shouldThrowNullPointerExceptionOnFindSessionsNullKey() {
-        cachingStore.findSessions(null, 1L, 2L);
+        assertThrows(NullPointerException.class, () -> cachingStore.findSessions(null, 1L, 2L));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnFindSessionsNullFromKey() {
-        cachingStore.findSessions(null, keyA, 1L, 2L);
+        assertThrows(NullPointerException.class, () -> cachingStore.findSessions(null, keyA, 1L, 2L));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnFindSessionsNullToKey() {
-        cachingStore.findSessions(keyA, null, 1L, 2L);
+        assertThrows(NullPointerException.class, () -> cachingStore.findSessions(keyA, null, 1L, 2L));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnFetchNullFromKey() {
-        cachingStore.fetch(null, keyA);
+        assertThrows(NullPointerException.class, () -> cachingStore.fetch(null, keyA));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnFetchNullToKey() {
-        cachingStore.fetch(keyA, null);
+        assertThrows(NullPointerException.class, () -> cachingStore.fetch(keyA, null));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnFetchNullKey() {
-        cachingStore.fetch(null);
+        assertThrows(NullPointerException.class, () -> cachingStore.fetch(null));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnRemoveNullKey() {
-        cachingStore.remove(null);
+        assertThrows(NullPointerException.class, () -> cachingStore.remove(null));
     }
 
-    @Test(expected = NullPointerException.class)
+    @Test
     public void shouldThrowNullPointerExceptionOnPutNullKey() {
-        cachingStore.put(null, "1".getBytes());
+        assertThrows(NullPointerException.class, () -> cachingStore.put(null, "1".getBytes()));
+    }
+
+    @Test
+    public void shouldNotThrowInvalidRangeExceptionWhenBackwardWithNegativeFromKey() {
+        final Bytes keyFrom = Bytes.wrap(Serdes.Integer().serializer().serialize("", -1));
+        final Bytes keyTo = Bytes.wrap(Serdes.Integer().serializer().serialize("", 1));
+        // -1 serializes to 0xFFFFFFFF, which sorts after 1 byte-wise, so this range is from > to
+        try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(CachingSessionStore.class);
+             final KeyValueIterator<Windowed<Bytes>, byte[]> iterator =
+                 cachingStore.backwardFindSessions(keyFrom, keyTo, 0L, 10L)) {
+            assertFalse(iterator.hasNext());
+
+            final List<String> messages = appender.getMessages();
+            assertThat(
+                messages,
+                hasItem(
+                    "Returning empty iterator for fetch with invalid key range: from > to." +
+                        " This may be due to range arguments set in the wrong order, " +
+                        "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." +
+                        " Note that the built-in numerical serdes do not follow this for negative numbers"
+                )
+            );
+        }
     }
 
     @Test
@@ -558,10 +708,12 @@ public class CachingSessionStoreTest {
             final List<String> messages = appender.getMessages();
             assertThat(
                 messages,
-                hasItem("Returning empty iterator for fetch with invalid key range: from > to." +
-                    " This may be due to range arguments set in the wrong order, " +
-                    "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." +
-                    " Note that the built-in numerical serdes do not follow this for negative numbers")
+                hasItem(
+                    "Returning empty iterator for fetch with invalid key range: from > to." +
+                        " This may be due to range arguments set in the wrong order, " +
+                        "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." +
+                        " Note that the built-in numerical serdes do not follow this for negative numbers"
+                )
             );
         }
     }
@@ -585,9 +737,9 @@ public class CachingSessionStoreTest {
     }
 
     public static class CacheFlushListenerStub<K, V> implements CacheFlushListener<byte[], byte[]> {
-        final Deserializer<K> keyDeserializer;
-        final Deserializer<V> valueDesializer;
-        final List<KeyValueTimestamp<K, Change<V>>> forwarded = new LinkedList<>();
+        private final Deserializer<K> keyDeserializer;
+        private final Deserializer<V> valueDesializer;
+        private final List<KeyValueTimestamp<K, Change<V>>> forwarded = new LinkedList<>();
 
         CacheFlushListenerStub(final Deserializer<K> keyDeserializer,
                                final Deserializer<V> valueDesializer) {
@@ -606,7 +758,9 @@ public class CachingSessionStoreTest {
                     new Change<>(
                         valueDesializer.deserialize(null, newValue),
                         valueDesializer.deserialize(null, oldValue)),
-                    timestamp));
+                    timestamp
+                )
+            );
         }
     }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentWindowStoreTest.java
similarity index 99%
rename from streams/src/test/java/org/apache/kafka/streams/state/internals/CachingWindowStoreTest.java
rename to streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentWindowStoreTest.java
index 2a04c48..86ee164 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingWindowStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentWindowStoreTest.java
@@ -75,7 +75,7 @@ import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
-public class CachingWindowStoreTest {
+public class CachingPersistentWindowStoreTest {
 
     private static final int MAX_CACHE_SIZE_BYTES = 150;
     private static final long DEFAULT_TIMESTAMP = 10L;
@@ -88,7 +88,7 @@ public class CachingWindowStoreTest {
     private RocksDBSegmentedBytesStore bytesStore;
     private WindowStore<Bytes, byte[]> underlyingStore;
     private CachingWindowStore cachingStore;
-    private CachingKeyValueStoreTest.CacheFlushListenerStub<Windowed<String>, String> cacheListener;
+    private CacheFlushListenerStub<Windowed<String>, String> cacheListener;
     private ThreadCache cache;
     private WindowKeySchema keySchema;
 
@@ -99,7 +99,7 @@ public class CachingWindowStoreTest {
         underlyingStore = new RocksDBWindowStore(bytesStore, false, WINDOW_SIZE);
         final TimeWindowedDeserializer<String> keyDeserializer = new TimeWindowedDeserializer<>(new StringDeserializer(), WINDOW_SIZE);
         keyDeserializer.setIsChangelogTopic(true);
-        cacheListener = new CachingKeyValueStoreTest.CacheFlushListenerStub<>(keyDeserializer, new StringDeserializer());
+        cacheListener = new CacheFlushListenerStub<>(keyDeserializer, new StringDeserializer());
         cachingStore = new CachingWindowStore(underlyingStore, WINDOW_SIZE, SEGMENT_INTERVAL);
         cachingStore.setFlushListener(cacheListener, false);
         cache = new ThreadCache(new LogContext("testCache "), MAX_CACHE_SIZE_BYTES, new MockStreamsMetrics(new Metrics()));
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStoreTest.java
index c55c4e15..8fdbd33 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStoreTest.java
@@ -131,6 +131,16 @@ public class ChangeLoggingSessionBytesStoreTest {
     }
 
     @Test
+    public void shouldDelegateToUnderlyingStoreWhenBackwardFetching() {
+        EasyMock.expect(inner.backwardFetch(bytesKey)).andReturn(KeyValueIterators.emptyIterator());
+
+        init();
+
+        store.backwardFetch(bytesKey);
+        EasyMock.verify(inner); // fails unless inner.backwardFetch(bytesKey) was invoked
+    }
+
+    @Test
     public void shouldDelegateToUnderlyingStoreWhenFetchingRange() {
         EasyMock.expect(inner.fetch(bytesKey, bytesKey)).andReturn(KeyValueIterators.emptyIterator());
 
@@ -141,6 +151,16 @@ public class ChangeLoggingSessionBytesStoreTest {
     }
 
     @Test
+    public void shouldDelegateToUnderlyingStoreWhenBackwardFetchingRange() {
+        EasyMock.expect(inner.backwardFetch(bytesKey, bytesKey)).andReturn(KeyValueIterators.emptyIterator());
+
+        init();
+
+        store.backwardFetch(bytesKey, bytesKey);
+        EasyMock.verify(inner);
+    }
+
+    @Test
     public void shouldDelegateToUnderlyingStoreWhenFindingSessions() {
         EasyMock.expect(inner.findSessions(bytesKey, 0, 1)).andReturn(KeyValueIterators.emptyIterator());
 
@@ -151,6 +171,16 @@ public class ChangeLoggingSessionBytesStoreTest {
     }
 
     @Test
+    public void shouldDelegateToUnderlyingStoreWhenBackwardFindingSessions() {
+        EasyMock.expect(inner.backwardFindSessions(bytesKey, 0, 1)).andReturn(KeyValueIterators.emptyIterator());
+
+        init();
+
+        store.backwardFindSessions(bytesKey, 0, 1);
+        EasyMock.verify(inner);
+    }
+
+    @Test
     public void shouldDelegateToUnderlyingStoreWhenFindingSessionRange() {
         EasyMock.expect(inner.findSessions(bytesKey, bytesKey, 0, 1)).andReturn(KeyValueIterators.emptyIterator());
 
@@ -161,6 +191,16 @@ public class ChangeLoggingSessionBytesStoreTest {
     }
 
     @Test
+    public void shouldDelegateToUnderlyingStoreWhenBackwardFindingSessionRange() {
+        EasyMock.expect(inner.backwardFindSessions(bytesKey, bytesKey, 0, 1)).andReturn(KeyValueIterators.emptyIterator());
+
+        init();
+
+        store.backwardFindSessions(bytesKey, bytesKey, 0, 1);
+        EasyMock.verify(inner);
+    }
+
+    @Test
     public void shouldFlushUnderlyingStore() {
         inner.flush();
         EasyMock.expectLastCall();
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedSessionStoreIteratorTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedSessionStoreIteratorTest.java
index 617ff36..4bd125a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedSessionStoreIteratorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedSessionStoreIteratorTest.java
@@ -55,56 +55,101 @@ public class MergedSortedCacheWrappedSessionStoreIteratorTest {
 
     @Test
     public void shouldHaveNextFromStore() {
-        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator());
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), true);
+        assertTrue(mergeIterator.hasNext());
+    }
+
+    @Test
+    public void shouldHaveNextFromReverseStore() {
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), false);
         assertTrue(mergeIterator.hasNext());
     }
 
     @Test
     public void shouldGetNextFromStore() {
-        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator());
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), true);
+        assertThat(mergeIterator.next(), equalTo(KeyValue.pair(new Windowed<>(storeKey, storeWindow), storeKey.get())));
+    }
+
+    @Test
+    public void shouldGetNextFromReverseStore() {
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), false);
         assertThat(mergeIterator.next(), equalTo(KeyValue.pair(new Windowed<>(storeKey, storeWindow), storeKey.get())));
     }
 
     @Test
     public void shouldPeekNextKeyFromStore() {
-        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator());
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), true);
+        assertThat(mergeIterator.peekNextKey(), equalTo(new Windowed<>(storeKey, storeWindow)));
+    }
+
+    @Test
+    public void shouldPeekNextKeyFromReverseStore() {
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), false);
         assertThat(mergeIterator.peekNextKey(), equalTo(new Windowed<>(storeKey, storeWindow)));
     }
 
     @Test
     public void shouldHaveNextFromCache() {
-        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs);
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, true);
+        assertTrue(mergeIterator.hasNext());
+    }
+
+    @Test
+    public void shouldHaveNextFromReverseCache() {
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, false);
         assertTrue(mergeIterator.hasNext());
     }
 
     @Test
     public void shouldGetNextFromCache() {
-        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs);
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, true);
+        assertThat(mergeIterator.next(), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey.get())));
+    }
+
+    @Test
+    public void shouldGetNextFromReverseCache() {
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, false);
         assertThat(mergeIterator.next(), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey.get())));
     }
 
     @Test
     public void shouldPeekNextKeyFromCache() {
-        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs);
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, true);
+        assertThat(mergeIterator.peekNextKey(), equalTo(new Windowed<>(cacheKey, cacheWindow)));
+    }
+
+    @Test
+    public void shouldPeekNextKeyFromReverseCache() {
+        final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, false);
         assertThat(mergeIterator.peekNextKey(), equalTo(new Windowed<>(cacheKey, cacheWindow)));
     }
 
     @Test
     public void shouldIterateBothStoreAndCache() {
-        final MergedSortedCacheSessionStoreIterator iterator = createIterator(storeKvs, cacheKvs);
+        final MergedSortedCacheSessionStoreIterator iterator = createIterator(storeKvs, cacheKvs, true);
         assertThat(iterator.next(), equalTo(KeyValue.pair(new Windowed<>(storeKey, storeWindow), storeKey.get())));
         assertThat(iterator.next(), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey.get())));
         assertFalse(iterator.hasNext());
     }
 
+    @Test
+    public void shouldReverseIterateBothStoreAndCache() {
+        final MergedSortedCacheSessionStoreIterator iterator = createIterator(storeKvs, cacheKvs, false);
+        assertThat(iterator.next(), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey.get())));
+        assertThat(iterator.next(), equalTo(KeyValue.pair(new Windowed<>(storeKey, storeWindow), storeKey.get())));
+        assertFalse(iterator.hasNext());
+    }
+
     private MergedSortedCacheSessionStoreIterator createIterator(final Iterator<KeyValue<Windowed<Bytes>, byte[]>> storeKvs,
-                                                                 final Iterator<KeyValue<Bytes, LRUCacheEntry>> cacheKvs) {
+                                                                 final Iterator<KeyValue<Bytes, LRUCacheEntry>> cacheKvs,
+                                                                 final boolean forward) { // true = forward merge (store then cache above), false = reverse (cache then store)
         final DelegatingPeekingKeyValueIterator<Windowed<Bytes>, byte[]> storeIterator =
             new DelegatingPeekingKeyValueIterator<>("store", new KeyValueIteratorStub<>(storeKvs));
 
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> cacheIterator =
             new DelegatingPeekingKeyValueIterator<>("cache", new KeyValueIteratorStub<>(cacheKvs));
-        return new MergedSortedCacheSessionStoreIterator(cacheIterator, storeIterator, SINGLE_SEGMENT_CACHE_FUNCTION);
+        return new MergedSortedCacheSessionStoreIterator(cacheIterator, storeIterator, SINGLE_SEGMENT_CACHE_FUNCTION, forward);
     }
 
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java
index a77dd07..6b1889a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java
@@ -303,6 +303,26 @@ public class MeteredSessionStoreTest {
     }
 
     @Test
+    public void shouldBackwardFindSessionsFromStoreAndRecordFetchMetric() { // inner backwardFindSessions bytes result must surface as VALUE and bump fetch-rate
+        expect(innerStore.backwardFindSessions(KEY_BYTES, 0, 0))
+            .andReturn(
+                new KeyValueIteratorStub<>(
+                    Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator()
+                )
+            );
+        init();
+
+        final KeyValueIterator<Windowed<String>, String> iterator = store.backwardFindSessions(KEY, 0, 0);
+        assertThat(iterator.next().value, equalTo(VALUE));
+        assertFalse(iterator.hasNext());
+        iterator.close();
+
+        final KafkaMetric metric = metric("fetch-rate");
+        assertTrue((Double) metric.metricValue() > 0);
+        verify(innerStore);
+    }
+
+    @Test
     public void shouldFindSessionRangeFromStoreAndRecordFetchMetric() {
         expect(innerStore.findSessions(KEY_BYTES, KEY_BYTES, 0, 0))
                 .andReturn(new KeyValueIteratorStub<>(
@@ -320,6 +340,26 @@ public class MeteredSessionStoreTest {
     }
 
     @Test
+    public void shouldBackwardFindSessionRangeFromStoreAndRecordFetchMetric() { // key-range variant: inner result must surface as VALUE and bump fetch-rate
+        expect(innerStore.backwardFindSessions(KEY_BYTES, KEY_BYTES, 0, 0))
+            .andReturn(
+                new KeyValueIteratorStub<>(
+                    Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator()
+                )
+            );
+        init();
+
+        final KeyValueIterator<Windowed<String>, String> iterator = store.backwardFindSessions(KEY, KEY, 0, 0);
+        assertThat(iterator.next().value, equalTo(VALUE));
+        assertFalse(iterator.hasNext());
+        iterator.close();
+
+        final KafkaMetric metric = metric("fetch-rate");
+        assertTrue((Double) metric.metricValue() > 0);
+        verify(innerStore);
+    }
+
+    @Test
     public void shouldRemoveFromStoreAndRecordRemoveMetric() {
         innerStore.remove(WINDOWED_KEY_BYTES);
         expectLastCall();
@@ -351,6 +391,26 @@ public class MeteredSessionStoreTest {
     }
 
     @Test
+    public void shouldBackwardFetchForKeyAndRecordFetchMetric() { // single-key backwardFetch: inner result must surface as VALUE and bump fetch-rate
+        expect(innerStore.backwardFetch(KEY_BYTES))
+            .andReturn(
+                new KeyValueIteratorStub<>(
+                    Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator()
+                )
+            );
+        init();
+
+        final KeyValueIterator<Windowed<String>, String> iterator = store.backwardFetch(KEY);
+        assertThat(iterator.next().value, equalTo(VALUE));
+        assertFalse(iterator.hasNext());
+        iterator.close();
+
+        final KafkaMetric metric = metric("fetch-rate");
+        assertTrue((Double) metric.metricValue() > 0);
+        verify(innerStore);
+    }
+
+    @Test
     public void shouldFetchRangeFromStoreAndRecordFetchMetric() {
         expect(innerStore.fetch(KEY_BYTES, KEY_BYTES))
                 .andReturn(new KeyValueIteratorStub<>(
@@ -368,6 +428,26 @@ public class MeteredSessionStoreTest {
     }
 
     @Test
+    public void shouldBackwardFetchRangeFromStoreAndRecordFetchMetric() { // key-range backwardFetch: inner result must surface as VALUE and bump fetch-rate
+        expect(innerStore.backwardFetch(KEY_BYTES, KEY_BYTES))
+            .andReturn(
+                new KeyValueIteratorStub<>(
+                    Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator()
+                )
+            );
+        init();
+
+        final KeyValueIterator<Windowed<String>, String> iterator = store.backwardFetch(KEY, KEY);
+        assertThat(iterator.next().value, equalTo(VALUE));
+        assertFalse(iterator.hasNext());
+        iterator.close();
+
+        final KafkaMetric metric = metric("fetch-rate");
+        assertTrue((Double) metric.metricValue() > 0);
+        verify(innerStore);
+    }
+
+    @Test
     public void shouldRecordRestoreTimeOnInit() {
         init();
         final KafkaMetric metric = metric("restore-rate");
diff --git a/streams/src/test/java/org/apache/kafka/test/ReadOnlySessionStoreStub.java b/streams/src/test/java/org/apache/kafka/test/ReadOnlySessionStoreStub.java
index a2924fc..ff37e25 100644
--- a/streams/src/test/java/org/apache/kafka/test/ReadOnlySessionStoreStub.java
+++ b/streams/src/test/java/org/apache/kafka/test/ReadOnlySessionStoreStub.java
@@ -32,7 +32,7 @@ import java.util.NavigableMap;
 import java.util.TreeMap;
 
 public class ReadOnlySessionStoreStub<K, V> implements ReadOnlySessionStore<K, V>, StateStore {
-    private NavigableMap<K, List<KeyValue<Windowed<K>, V>>> sessions = new TreeMap<>();
+    private final NavigableMap<K, List<KeyValue<Windowed<K>, V>>> sessions = new TreeMap<>();
     private boolean open = true;
 
     public void put(final Windowed<K> sessionKey, final V value) {
@@ -43,6 +43,31 @@ public class ReadOnlySessionStoreStub<K, V> implements ReadOnlySessionStore<K, V
     }
 
     @Override
+    public KeyValueIterator<Windowed<K>, V> findSessions(final K key, final long earliestSessionEndTime, final long latestSessionStartTime) {
+        throw new UnsupportedOperationException("Moved from Session Store. Implement if needed");
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFindSessions(final K key, final long earliestSessionEndTime, final long latestSessionStartTime) {
+        throw new UnsupportedOperationException("Moved from Session Store. Implement if needed");
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<K>, V> findSessions(final K keyFrom, final K keyTo, final long earliestSessionEndTime, final long latestSessionStartTime) {
+        throw new UnsupportedOperationException("Moved from Session Store. Implement if needed");
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFindSessions(final K keyFrom, final K keyTo, final long earliestSessionEndTime, final long latestSessionStartTime) {
+        throw new UnsupportedOperationException("Moved from Session Store. Implement if needed");
+    }
+
+    @Override
+    public V fetchSession(final K key, final long startTime, final long endTime) {
+        throw new UnsupportedOperationException("Moved from Session Store. Implement if needed");
+    }
+
+    @Override
     public KeyValueIterator<Windowed<K>, V> fetch(final K key) {
         if (!open) {
             throw new InvalidStateStoreException("not open");
@@ -54,6 +79,18 @@ public class ReadOnlySessionStoreStub<K, V> implements ReadOnlySessionStore<K, V
     }
 
     @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFetch(final K key) {
+        if (!open) {
+            throw new InvalidStateStoreException("not open");
+        }
+        if (!sessions.containsKey(key)) {
+            return new KeyValueIteratorStub<>(Collections.emptyIterator());
+        }
+        // Walk this key's session list from its last entry to its first.
+        return new KeyValueIteratorStub<>(new java.util.ArrayDeque<>(sessions.get(key)).descendingIterator());
+    }
+
+    @Override
     public KeyValueIterator<Windowed<K>, V> fetch(final K from, final K to) {
         if (!open) {
             throw new InvalidStateStoreException("not open");
@@ -88,6 +125,44 @@ public class ReadOnlySessionStoreStub<K, V> implements ReadOnlySessionStore<K, V
     }
 
     @Override
+    public KeyValueIterator<Windowed<K>, V> backwardFetch(final K from, final K to) {
+        if (!open) {
+            throw new InvalidStateStoreException("not open");
+        }
+        if (sessions.subMap(from, true, to, true).isEmpty()) {
+            return new KeyValueIteratorStub<>(Collections.emptyIterator());
+        }
+        final Iterator<List<KeyValue<Windowed<K>, V>>> keysIterator =
+            sessions.subMap(from, true, to, true).descendingMap().values().iterator();
+        return new KeyValueIteratorStub<>(
+            new Iterator<KeyValue<Windowed<K>, V>>() {
+
+                Iterator<KeyValue<Windowed<K>, V>> it;
+
+                @Override
+                public boolean hasNext() {
+                    while (it == null || !it.hasNext()) {
+                        if (!keysIterator.hasNext()) {
+                            return false;
+                        }
+                        // Keys already arrive in descending order; walk each key's sessions last-to-first as well.
+                        it = new java.util.ArrayDeque<>(keysIterator.next()).descendingIterator();
+                    }
+                    return true;
+                }
+
+                @Override
+                public KeyValue<Windowed<K>, V> next() {
+                    if (!hasNext()) {
+                        throw new java.util.NoSuchElementException();
+                    }
+                    return it.next();
+                }
+            }
+        );
+    }
+
+    @Override
     public String name() {
         return "";
     }