You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2021/10/12 23:16:03 UTC

[kafka] branch trunk updated: KAFKA-13212: add support infinite query for session store (#11234)

This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 65b01a0  KAFKA-13212: add support infinite query for session store (#11234)
65b01a0 is described below

commit 65b01a0464d9bc4cc185cfb99ea80dc66cfeb2ec
Author: Luke Chen <sh...@gmail.com>
AuthorDate: Wed Oct 13 07:14:38 2021 +0800

    KAFKA-13212: add support infinite query for session store (#11234)
    
    Add support for infinite range query for SessionStore.
    
    Reviewers: Patrick Stuedi <ps...@apache.org>, Guozhang Wang <wa...@gmail.com>
---
 .../kafka/streams/state/ReadOnlySessionStore.java  |  18 ++-
 .../state/internals/CachingSessionStore.java       |  30 ++---
 .../internals/CompositeReadOnlySessionStore.java   |   8 --
 .../state/internals/InMemorySessionStore.java      |  57 ++++-----
 .../state/internals/MeteredSessionStore.java       |  10 +-
 .../state/internals/MeteredWindowStore.java        |   8 +-
 .../internals/AbstractSessionBytesStoreTest.java   | 106 ++++++++++++----
 .../internals/CachingInMemorySessionStoreTest.java | 128 +++++++++++++++----
 .../CachingPersistentSessionStoreTest.java         | 128 +++++++++++++++----
 .../CompositeReadOnlySessionStoreTest.java         |  45 +++++--
 .../state/internals/SessionStoreFetchTest.java     | 137 ++++++++++++++++++---
 .../kafka/test/ReadOnlySessionStoreStub.java       |  31 ++++-
 12 files changed, 529 insertions(+), 177 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlySessionStore.java
index 0ade242..049c1a4 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlySessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlySessionStore.java
@@ -126,14 +126,15 @@ public interface ReadOnlySessionStore<K, AGG> {
      * This iterator must be closed after use.
      *
      * @param keyFrom                The first key that could be in the range
+     * A null value indicates a starting position from the first element in the store.
      * @param keyTo                  The last key that could be in the range
+     * A null value indicates that the range ends with the last element in the store.
      * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where
      *                               iteration starts.
      * @param latestSessionStartTime the end timestamp of the latest session to search for, where
      *                               iteration ends.
      * @return iterator of sessions with the matching keys and aggregated values, from earliest to
      * latest session time.
-     * @throws NullPointerException If null is used for any key.
      */
     default KeyValueIterator<Windowed<K>, AGG> findSessions(final K keyFrom,
                                                             final K keyTo,
@@ -151,14 +152,15 @@ public interface ReadOnlySessionStore<K, AGG> {
      * This iterator must be closed after use.
      *
      * @param keyFrom                The first key that could be in the range
+     * A null value indicates a starting position from the first element in the store.
      * @param keyTo                  The last key that could be in the range
+     * A null value indicates that the range ends with the last element in the store.
      * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where
      *                               iteration starts.
      * @param latestSessionStartTime the end timestamp of the latest session to search for, where
      *                               iteration ends.
      * @return iterator of sessions with the matching keys and aggregated values, from earliest to
      * latest session time.
-     * @throws NullPointerException If null is used for any key.
      */
     default KeyValueIterator<Windowed<K>, AGG> findSessions(final K keyFrom,
                                                             final K keyTo,
@@ -176,14 +178,15 @@ public interface ReadOnlySessionStore<K, AGG> {
      * This iterator must be closed after use.
      *
      * @param keyFrom                The first key that could be in the range
+     * A null value indicates a starting position from the first element in the store.
      * @param keyTo                  The last key that could be in the range
+     * A null value indicates that the range ends with the last element in the store.
      * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where
      *                               iteration ends.
      * @param latestSessionStartTime the end timestamp of the latest session to search for, where
      *                               iteration starts.
      * @return backward iterator of sessions with the matching keys and aggregated values, from
      * latest to earliest session time.
-     * @throws NullPointerException If null is used for any key.
      */
     default KeyValueIterator<Windowed<K>, AGG> backwardFindSessions(final K keyFrom,
                                                                     final K keyTo,
@@ -201,14 +204,15 @@ public interface ReadOnlySessionStore<K, AGG> {
      * This iterator must be closed after use.
      *
      * @param keyFrom                The first key that could be in the range
+     * A null value indicates a starting position from the first element in the store.
      * @param keyTo                  The last key that could be in the range
+     * A null value indicates that the range ends with the last element in the store.
      * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where
      *                               iteration ends.
      * @param latestSessionStartTime the end timestamp of the latest session to search for, where
      *                               iteration starts.
      * @return backward iterator of sessions with the matching keys and aggregated values, from
      * latest to earliest session time.
-     * @throws NullPointerException If null is used for any key.
      */
     default KeyValueIterator<Windowed<K>, AGG> backwardFindSessions(final K keyFrom,
                                                                     final K keyTo,
@@ -289,10 +293,11 @@ public interface ReadOnlySessionStore<K, AGG> {
      * available session to the newest/latest session.
      *
      * @param keyFrom first key in the range to find aggregated session values for
+     * A null value indicates a starting position from the first element in the store.
      * @param keyTo   last key in the range to find aggregated session values for
+     * A null value indicates that the range ends with the last element in the store.
      * @return KeyValueIterator containing all sessions for the provided key, from oldest to newest
      * session.
-     * @throws NullPointerException If null is used for any of the keys.
      */
     KeyValueIterator<Windowed<K>, AGG> fetch(final K keyFrom, final K keyTo);
 
@@ -304,10 +309,11 @@ public interface ReadOnlySessionStore<K, AGG> {
      * available session to the oldest/earliest session.
      *
      * @param keyFrom first key in the range to find aggregated session values for
+     * A null value indicates a starting position from the first element in the store.
      * @param keyTo   last key in the range to find aggregated session values for
+     * A null value indicates that the range ends with the last element in the store.
      * @return backward KeyValueIterator containing all sessions for the provided key, from newest
      * to oldest session.
-     * @throws NullPointerException If null is used for any of the keys.
      */
     default KeyValueIterator<Windowed<K>, AGG> backwardFetch(final K keyFrom, final K keyTo) {
         throw new UnsupportedOperationException(
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
index 1cfb8ce..9b07fe8 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
@@ -47,6 +47,10 @@ class CachingSessionStore
 
     private final SessionKeySchema keySchema;
     private final SegmentedCacheFunction cacheFunction;
+    private static final String INVALID_RANGE_WARN_MSG = "Returning empty iterator for fetch with invalid key range: from > to. " +
+        "This may be due to range arguments set in the wrong order, " +
+        "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " +
+        "Note that the built-in numerical serdes do not follow this for negative numbers";
 
     private String cacheName;
     private InternalProcessorContext context;
@@ -212,18 +216,15 @@ class CachingSessionStore
                                                                   final Bytes keyTo,
                                                                   final long earliestSessionEndTime,
                                                                   final long latestSessionStartTime) {
-        if (keyFrom.compareTo(keyTo) > 0) {
-            LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " +
-                "This may be due to range arguments set in the wrong order, " +
-                "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " +
-                "Note that the built-in numerical serdes do not follow this for negative numbers");
+        if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) {
+            LOG.warn(INVALID_RANGE_WARN_MSG);
             return KeyValueIterators.emptyIterator();
         }
 
         validateStoreOpen();
 
-        final Bytes cacheKeyFrom = cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, earliestSessionEndTime));
-        final Bytes cacheKeyTo = cacheFunction.cacheKey(keySchema.upperRange(keyTo, latestSessionStartTime));
+        final Bytes cacheKeyFrom = keyFrom == null ? null : cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, earliestSessionEndTime));
+        final Bytes cacheKeyTo = keyTo == null ? null : cacheFunction.cacheKey(keySchema.upperRange(keyTo, latestSessionStartTime));
         final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo);
 
         final KeyValueIterator<Windowed<Bytes>, byte[]> storeIterator = wrapped().findSessions(
@@ -243,18 +244,15 @@ class CachingSessionStore
                                                                           final Bytes keyTo,
                                                                           final long earliestSessionEndTime,
                                                                           final long latestSessionStartTime) {
-        if (keyFrom.compareTo(keyTo) > 0) {
-            LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " +
-                         "This may be due to range arguments set in the wrong order, " +
-                         "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " +
-                         "Note that the built-in numerical serdes do not follow this for negative numbers");
+        if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) {
+            LOG.warn(INVALID_RANGE_WARN_MSG);
             return KeyValueIterators.emptyIterator();
         }
 
         validateStoreOpen();
 
-        final Bytes cacheKeyFrom = cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, earliestSessionEndTime));
-        final Bytes cacheKeyTo = cacheFunction.cacheKey(keySchema.upperRange(keyTo, latestSessionStartTime));
+        final Bytes cacheKeyFrom = keyFrom == null ? null : cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, earliestSessionEndTime));
+        final Bytes cacheKeyTo = keyTo == null ? null : cacheFunction.cacheKey(keySchema.upperRange(keyTo, latestSessionStartTime));
         final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().reverseRange(cacheName, cacheKeyFrom, cacheKeyTo);
 
         final KeyValueIterator<Windowed<Bytes>, byte[]> storeIterator =
@@ -304,16 +302,12 @@ class CachingSessionStore
     @Override
     public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes keyFrom,
                                                            final Bytes keyTo) {
-        Objects.requireNonNull(keyFrom, "keyFrom cannot be null");
-        Objects.requireNonNull(keyTo, "keyTo cannot be null");
         return findSessions(keyFrom, keyTo, 0, Long.MAX_VALUE);
     }
 
     @Override
     public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes keyFrom,
                                                                    final Bytes keyTo) {
-        Objects.requireNonNull(keyFrom, "keyFrom cannot be null");
-        Objects.requireNonNull(keyTo, "keyTo cannot be null");
         return backwardFindSessions(keyFrom, keyTo, 0, Long.MAX_VALUE);
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStore.java
index fb5fb61..0f153cc 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStore.java
@@ -101,8 +101,6 @@ public class CompositeReadOnlySessionStore<K, V> implements ReadOnlySessionStore
                                                          final K keyTo,
                                                          final long earliestSessionEndTime,
                                                          final long latestSessionStartTime) {
-        Objects.requireNonNull(keyFrom, "from can't be null");
-        Objects.requireNonNull(keyTo, "to can't be null");
         final List<ReadOnlySessionStore<K, V>> stores = storeProvider.stores(storeName, queryableStoreType);
         for (final ReadOnlySessionStore<K, V> store : stores) {
             try {
@@ -130,8 +128,6 @@ public class CompositeReadOnlySessionStore<K, V> implements ReadOnlySessionStore
                                                                  final K keyTo,
                                                                  final long earliestSessionEndTime,
                                                                  final long latestSessionStartTime) {
-        Objects.requireNonNull(keyFrom, "from can't be null");
-        Objects.requireNonNull(keyTo, "to can't be null");
         final List<ReadOnlySessionStore<K, V>> stores = storeProvider.stores(storeName, queryableStoreType);
         for (final ReadOnlySessionStore<K, V> store : stores) {
             try {
@@ -221,8 +217,6 @@ public class CompositeReadOnlySessionStore<K, V> implements ReadOnlySessionStore
 
     @Override
     public KeyValueIterator<Windowed<K>, V> fetch(final K keyFrom, final K keyTo) {
-        Objects.requireNonNull(keyFrom, "keyFrom can't be null");
-        Objects.requireNonNull(keyTo, "keyTo can't be null");
         final NextIteratorFunction<Windowed<K>, V, ReadOnlySessionStore<K, V>> nextIteratorFunction =
             store -> store.fetch(keyFrom, keyTo);
         return new DelegatingPeekingKeyValueIterator<>(storeName,
@@ -233,8 +227,6 @@ public class CompositeReadOnlySessionStore<K, V> implements ReadOnlySessionStore
 
     @Override
     public KeyValueIterator<Windowed<K>, V> backwardFetch(final K keyFrom, final K keyTo) {
-        Objects.requireNonNull(keyFrom, "keyFrom can't be null");
-        Objects.requireNonNull(keyTo, "keyTo can't be null");
         final NextIteratorFunction<Windowed<K>, V, ReadOnlySessionStore<K, V>> nextIteratorFunction =
             store -> store.backwardFetch(keyFrom, keyTo);
         return new DelegatingPeekingKeyValueIterator<>(
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
index a14d3e1..bc8cda6 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
@@ -54,6 +54,12 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
 
     private final long retentionPeriod;
 
+    private final static String INVALID_RANGE_WARN_MSG =
+        "Returning empty iterator for fetch with invalid key range: from > to. " +
+        "This may be due to range arguments set in the wrong order, " +
+        "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " +
+        "Note that the built-in numerical serdes do not follow this for negative numbers";
+
     private final ConcurrentNavigableMap<Long, ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>>> endTimeMap = new ConcurrentSkipListMap<>();
     private final Set<InMemorySessionStoreIterator> openIterators  = ConcurrentHashMap.newKeySet();
 
@@ -205,16 +211,10 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
                                                                   final Bytes keyTo,
                                                                   final long earliestSessionEndTime,
                                                                   final long latestSessionStartTime) {
-        Objects.requireNonNull(keyFrom, "from key cannot be null");
-        Objects.requireNonNull(keyTo, "to key cannot be null");
-
         removeExpiredSegments();
 
-        if (keyFrom.compareTo(keyTo) > 0) {
-            LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " +
-                "This may be due to range arguments set in the wrong order, " +
-                "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " +
-                "Note that the built-in numerical serdes do not follow this for negative numbers");
+        if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) {
+            LOG.warn(INVALID_RANGE_WARN_MSG);
             return KeyValueIterators.emptyIterator();
         }
 
@@ -230,16 +230,10 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
                                                                           final Bytes keyTo,
                                                                           final long earliestSessionEndTime,
                                                                           final long latestSessionStartTime) {
-        Objects.requireNonNull(keyFrom, "from key cannot be null");
-        Objects.requireNonNull(keyTo, "to key cannot be null");
-
         removeExpiredSegments();
 
-        if (keyFrom.compareTo(keyTo) > 0) {
-            LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " +
-                "This may be due to range arguments set in the wrong order, " +
-                "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " +
-                "Note that the built-in numerical serdes do not follow this for negative numbers");
+        if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) {
+            LOG.warn(INVALID_RANGE_WARN_MSG);
             return KeyValueIterators.emptyIterator();
         }
 
@@ -274,21 +268,13 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
 
     @Override
     public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes keyFrom, final Bytes keyTo) {
-
-        Objects.requireNonNull(keyFrom, "from key cannot be null");
-        Objects.requireNonNull(keyTo, "to key cannot be null");
-
         removeExpiredSegments();
 
-
         return registerNewIterator(keyFrom, keyTo, Long.MAX_VALUE, endTimeMap.entrySet().iterator(), true);
     }
 
     @Override
     public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes keyFrom, final Bytes keyTo) {
-        Objects.requireNonNull(keyFrom, "from key cannot be null");
-        Objects.requireNonNull(keyTo, "to key cannot be null");
-
         removeExpiredSegments();
 
         return registerNewIterator(
@@ -457,17 +443,22 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
             while (endTimeIterator.hasNext()) {
                 final Entry<Long, ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>>> nextEndTimeEntry = endTimeIterator.next();
                 currentEndTime = nextEndTimeEntry.getKey();
+
+                final ConcurrentNavigableMap<Bytes, ConcurrentNavigableMap<Long, byte[]>> subKVMap;
+                if (keyFrom == null && keyTo == null) {
+                    subKVMap = nextEndTimeEntry.getValue();
+                } else if (keyFrom == null) {
+                    subKVMap = nextEndTimeEntry.getValue().headMap(keyTo, true);
+                } else if (keyTo == null) {
+                    subKVMap = nextEndTimeEntry.getValue().tailMap(keyFrom, true);
+                } else {
+                    subKVMap = nextEndTimeEntry.getValue().subMap(keyFrom, true, keyTo, true);
+                }
+
                 if (forward) {
-                    keyIterator = nextEndTimeEntry.getValue()
-                                                  .subMap(keyFrom, true, keyTo, true)
-                                                  .entrySet()
-                                                  .iterator();
+                    keyIterator = subKVMap.entrySet().iterator();
                 } else {
-                    keyIterator = nextEndTimeEntry.getValue()
-                                                  .subMap(keyFrom, true, keyTo, true)
-                                                  .descendingMap()
-                                                  .entrySet()
-                                                  .iterator();
+                    keyIterator = subKVMap.descendingMap().entrySet().iterator();
                 }
 
                 if (setInnerIterators()) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
index f717899..041f391 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
@@ -265,8 +265,6 @@ public class MeteredSessionStore<K, V>
     @Override
     public KeyValueIterator<Windowed<K>, V> fetch(final K keyFrom,
                                                   final K keyTo) {
-        Objects.requireNonNull(keyFrom, "keyFrom cannot be null");
-        Objects.requireNonNull(keyTo, "keyTo cannot be null");
         return new MeteredWindowedKeyValueIterator<>(
             wrapped().fetch(keyBytes(keyFrom), keyBytes(keyTo)),
             fetchSensor,
@@ -278,8 +276,6 @@ public class MeteredSessionStore<K, V>
     @Override
     public KeyValueIterator<Windowed<K>, V> backwardFetch(final K keyFrom,
                                                           final K keyTo) {
-        Objects.requireNonNull(keyFrom, "keyFrom cannot be null");
-        Objects.requireNonNull(keyTo, "keyTo cannot be null");
         return new MeteredWindowedKeyValueIterator<>(
             wrapped().backwardFetch(keyBytes(keyFrom), keyBytes(keyTo)),
             fetchSensor,
@@ -330,8 +326,6 @@ public class MeteredSessionStore<K, V>
                                                          final K keyTo,
                                                          final long earliestSessionEndTime,
                                                          final long latestSessionStartTime) {
-        Objects.requireNonNull(keyFrom, "keyFrom cannot be null");
-        Objects.requireNonNull(keyTo, "keyTo cannot be null");
         final Bytes bytesKeyFrom = keyBytes(keyFrom);
         final Bytes bytesKeyTo = keyBytes(keyTo);
         return new MeteredWindowedKeyValueIterator<>(
@@ -351,8 +345,6 @@ public class MeteredSessionStore<K, V>
                                                                  final K keyTo,
                                                                  final long earliestSessionEndTime,
                                                                  final long latestSessionStartTime) {
-        Objects.requireNonNull(keyFrom, "keyFrom cannot be null");
-        Objects.requireNonNull(keyTo, "keyTo cannot be null");
         final Bytes bytesKeyFrom = keyBytes(keyFrom);
         final Bytes bytesKeyTo = keyBytes(keyTo);
         return new MeteredWindowedKeyValueIterator<>(
@@ -384,7 +376,7 @@ public class MeteredSessionStore<K, V>
     }
 
     private Bytes keyBytes(final K key) {
-        return Bytes.wrap(serdes.rawKey(key));
+        return key == null ? null : Bytes.wrap(serdes.rawKey(key));
     }
 
     private void maybeRecordE2ELatency() {
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java
index 48b42fb..1a45b55 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java
@@ -247,8 +247,8 @@ public class MeteredWindowStore<K, V>
                                                   final long timeTo) {
         return new MeteredWindowedKeyValueIterator<>(
             wrapped().fetch(
-                keyFrom == null ? null : keyBytes(keyFrom),
-                keyTo == null ? null : keyBytes(keyTo),
+                keyBytes(keyFrom),
+                keyBytes(keyTo),
                 timeFrom,
                 timeTo),
             fetchSensor,
@@ -264,8 +264,8 @@ public class MeteredWindowStore<K, V>
                                                           final long timeTo) {
         return new MeteredWindowedKeyValueIterator<>(
             wrapped().backwardFetch(
-                keyFrom == null ? null : keyBytes(keyFrom),
-                keyTo == null ? null : keyBytes(keyTo),
+                keyBytes(keyFrom),
+                keyBytes(keyTo),
                 timeFrom,
                 timeTo),
             fetchSensor,
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java
index b2a1022..78b9d63 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java
@@ -212,7 +212,7 @@ public abstract class AbstractSessionBytesStoreTest {
             sessionStore.put(kv.key, kv.value);
         }
 
-        // add some that shouldn't appear in the results
+        // add some that should only be fetched in infinite fetch
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
         sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L);
 
@@ -223,6 +223,29 @@ public abstract class AbstractSessionBytesStoreTest {
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.findSessions("aa", "bb", 0L, Long.MAX_VALUE)) {
             assertEquals(expected, toList(values));
         }
+
+        // infinite keyFrom fetch case
+        expected.add(0, KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L));
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch(null, "bb")) {
+            assertEquals(expected, toList(values));
+        }
+
+        // remove the one added for unlimited start fetch case
+        expected.remove(0);
+        // infinite keyTo fetch case
+        expected.add(KeyValue.pair(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L));
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("aa", null)) {
+            assertEquals(expected, toList(values));
+        }
+
+        // fetch all case
+        expected.add(0, KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L));
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch(null, null)) {
+            assertEquals(expected, toList(values));
+        }
     }
 
     @Test
@@ -238,7 +261,7 @@ public abstract class AbstractSessionBytesStoreTest {
             sessionStore.put(kv.key, kv.value);
         }
 
-        // add some that shouldn't appear in the results
+        // add some that should only be fetched in infinite fetch
         sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L);
         sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L);
 
@@ -249,6 +272,29 @@ public abstract class AbstractSessionBytesStoreTest {
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.backwardFindSessions("aa", "bb", 0L, Long.MAX_VALUE)) {
             assertEquals(toList(expected.descendingIterator()), toList(values));
         }
+
+        // infinite keyFrom fetch case
+        expected.add(0, KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L));
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.backwardFetch(null, "bb")) {
+            assertEquals(toList(expected.descendingIterator()), toList(values));
+        }
+
+        // remove the one added for unlimited start fetch case
+        expected.remove(0);
+        // infinite keyTo fetch case
+        expected.add(KeyValue.pair(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L));
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.backwardFetch("aa", null)) {
+            assertEquals(toList(expected.descendingIterator()), toList(values));
+        }
+
+        // fetch all case
+        expected.add(0, KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L));
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.backwardFetch(null, null)) {
+            assertEquals(toList(expected.descendingIterator()), toList(values));
+        }
     }
 
     @Test
@@ -408,6 +454,24 @@ public abstract class AbstractSessionBytesStoreTest {
         ) {
             assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(2L))));
         }
+
+        try (final KeyValueIterator<Windowed<String>, Long> iterator =
+                 sessionStore.findSessions(null, "aa", 0, Long.MAX_VALUE)
+        ) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L))));
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> iterator =
+                 sessionStore.findSessions("a", null, 0, Long.MAX_VALUE)
+        ) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L))));
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> iterator =
+                 sessionStore.findSessions(null, null, 0, Long.MAX_VALUE)
+        ) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L))));
+        }
     }
 
     @Test
@@ -446,6 +510,24 @@ public abstract class AbstractSessionBytesStoreTest {
         ) {
             assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(2L))));
         }
+
+        try (final KeyValueIterator<Windowed<String>, Long> iterator =
+                 sessionStore.backwardFindSessions(null, "aa", 0, Long.MAX_VALUE)
+        ) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L))));
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> iterator =
+                 sessionStore.backwardFindSessions("a", null, 0, Long.MAX_VALUE)
+        ) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L))));
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> iterator =
+                 sessionStore.backwardFindSessions(null, null, 0, Long.MAX_VALUE)
+        ) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L))));
+        }
     }
 
     @Test
@@ -693,26 +775,6 @@ public abstract class AbstractSessionBytesStoreTest {
     }
 
     @Test
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullFromKey() {
-        assertThrows(NullPointerException.class, () -> sessionStore.findSessions(null, "anyKeyTo", 1L, 2L));
-    }
-
-    @Test
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullToKey() {
-        assertThrows(NullPointerException.class, () -> sessionStore.findSessions("anyKeyFrom", null, 1L, 2L));
-    }
-
-    @Test
-    public void shouldThrowNullPointerExceptionOnFetchNullFromKey() {
-        assertThrows(NullPointerException.class, () -> sessionStore.fetch(null, "anyToKey"));
-    }
-
-    @Test
-    public void shouldThrowNullPointerExceptionOnFetchNullToKey() {
-        assertThrows(NullPointerException.class, () -> sessionStore.fetch("anyFromKey", null));
-    }
-
-    @Test
     public void shouldThrowNullPointerExceptionOnFetchNullKey() {
         assertThrows(NullPointerException.class, () -> sessionStore.fetch(null));
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
index d105976..4116df5 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
@@ -150,7 +150,31 @@ public class CachingInMemorySessionStoreTest {
 
         assertEquals(3, cache.size());
 
-        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.findSessions(keyA, keyB, 0, 0)) {
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.fetch(keyA, keyB)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyFrom fetch
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.fetch(null, keyB)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyTo fetch (keyTo == null, lower bound keyA)
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.fetch(keyA, null)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyFrom and keyTo fetch (both bounds null)
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.fetch(null, null)) {
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
@@ -166,7 +190,31 @@ public class CachingInMemorySessionStoreTest {
 
         assertEquals(3, cache.size());
 
-        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.backwardFindSessions(keyA, keyB, 0, 0)) {
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.backwardFetch(keyA, keyB)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyFrom fetch
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.backwardFetch(null, keyB)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyTo fetch (keyTo == null, lower bound keyA)
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.backwardFetch(keyA, null)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyFrom and keyTo fetch
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all = cachingStore.backwardFetch(null, null)) {
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
@@ -241,7 +289,33 @@ public class CachingInMemorySessionStoreTest {
 
         assertEquals(3, cache.size());
 
-        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some = cachingStore.findSessions(keyAA, keyB, 0, 0)) {
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.findSessions(keyAA, keyB, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
+
+        // infinite keyFrom case
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.findSessions(null, keyAA, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
+
+        // infinite keyTo case (keyTo == null, lower bound keyAA)
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.findSessions(keyAA, null, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
+
+        // infinite keyFrom and keyTo case
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.findSessions(null, null, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
             verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
             verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
             assertFalse(some.hasNext());
@@ -256,9 +330,35 @@ public class CachingInMemorySessionStoreTest {
 
         assertEquals(3, cache.size());
 
-        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some = cachingStore.backwardFindSessions(keyAA, keyB, 0, 0)) {
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.backwardFindSessions(keyAA, keyB, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
+
+        // infinite keyFrom case
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.backwardFindSessions(null, keyAA, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
+
+        // infinite keyTo case (keyTo == null, lower bound keyAA)
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.backwardFindSessions(keyAA, null, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
+
+        // infinite keyFrom and keyTo case
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.backwardFindSessions(null, null, 0, 0)) {
             verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
             verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
             assertFalse(some.hasNext());
         }
     }
@@ -645,26 +745,6 @@ public class CachingInMemorySessionStoreTest {
     }
 
     @Test
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullFromKey() {
-        assertThrows(NullPointerException.class, () -> cachingStore.findSessions(null, keyA, 1L, 2L));
-    }
-
-    @Test
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullToKey() {
-        assertThrows(NullPointerException.class, () -> cachingStore.findSessions(keyA, null, 1L, 2L));
-    }
-
-    @Test
-    public void shouldThrowNullPointerExceptionOnFetchNullFromKey() {
-        assertThrows(NullPointerException.class, () -> cachingStore.fetch(null, keyA));
-    }
-
-    @Test
-    public void shouldThrowNullPointerExceptionOnFetchNullToKey() {
-        assertThrows(NullPointerException.class, () -> cachingStore.fetch(keyA, null));
-    }
-
-    @Test
     public void shouldThrowNullPointerExceptionOnFetchNullKey() {
         assertThrows(NullPointerException.class, () -> cachingStore.fetch(null));
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java
index 95ebf77..83a514b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java
@@ -133,7 +133,34 @@ public class CachingPersistentSessionStoreTest {
         assertEquals(3, cache.size());
 
         try (final KeyValueIterator<Windowed<Bytes>, byte[]> all =
-                 cachingStore.findSessions(keyA, keyB, 0, 0)) {
+                 cachingStore.fetch(keyA, keyB)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyFrom fetch
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all =
+                 cachingStore.fetch(null, keyB)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyTo fetch
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all =
+                 cachingStore.fetch(keyA, null)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyFrom and keyTo fetch
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all =
+                 cachingStore.fetch(null, null)) {
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
@@ -150,7 +177,34 @@ public class CachingPersistentSessionStoreTest {
         assertEquals(3, cache.size());
 
         try (final KeyValueIterator<Windowed<Bytes>, byte[]> all =
-                 cachingStore.backwardFindSessions(keyA, keyB, 0, 0)) {
+                 cachingStore.backwardFetch(keyA, keyB)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyFrom fetch
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all =
+                 cachingStore.backwardFetch(null, keyB)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyTo fetch
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all =
+                 cachingStore.backwardFetch(keyA, null)) {
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            assertFalse(all.hasNext());
+        }
+
+        // infinite keyFrom and keyTo fetch
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> all =
+                 cachingStore.backwardFetch(null, null)) {
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
             verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
@@ -233,6 +287,31 @@ public class CachingPersistentSessionStoreTest {
             verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
             assertFalse(some.hasNext());
         }
+
+        // infinite keyFrom case
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.findSessions(null, keyAA, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
+
+        // infinite keyTo case (keyTo == null, lower bound keyAA)
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.findSessions(keyAA, null, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
+
+        // infinite keyFrom and keyTo case
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.findSessions(null, null, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
     }
 
     @Test
@@ -249,6 +328,31 @@ public class CachingPersistentSessionStoreTest {
             verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
             assertFalse(some.hasNext());
         }
+
+        // infinite keyFrom case
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.backwardFindSessions(null, keyAA, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
+
+        // infinite keyTo case (keyTo == null, lower bound keyAA)
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.backwardFindSessions(keyAA, null, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
+
+        // infinite keyFrom and keyTo case
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> some =
+                 cachingStore.backwardFindSessions(null, null, 0, 0)) {
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1");
+            verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1");
+            assertFalse(some.hasNext());
+        }
     }
 
     @Test
@@ -652,26 +756,6 @@ public class CachingPersistentSessionStoreTest {
     }
 
     @Test
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullFromKey() {
-        assertThrows(NullPointerException.class, () -> cachingStore.findSessions(null, keyA, 1L, 2L));
-    }
-
-    @Test
-    public void shouldThrowNullPointerExceptionOnFindSessionsNullToKey() {
-        assertThrows(NullPointerException.class, () -> cachingStore.findSessions(keyA, null, 1L, 2L));
-    }
-
-    @Test
-    public void shouldThrowNullPointerExceptionOnFetchNullFromKey() {
-        assertThrows(NullPointerException.class, () -> cachingStore.fetch(null, keyA));
-    }
-
-    @Test
-    public void shouldThrowNullPointerExceptionOnFetchNullToKey() {
-        assertThrows(NullPointerException.class, () -> cachingStore.fetch(keyA, null));
-    }
-
-    @Test
     public void shouldThrowNullPointerExceptionOnFetchNullKey() {
         assertThrows(NullPointerException.class, () -> cachingStore.fetch(null));
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStoreTest.java
index 66ae243..c2d38de 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStoreTest.java
@@ -146,21 +146,52 @@ public class CompositeReadOnlySessionStoreTest {
         underlyingSessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 0L);
         secondUnderlying.put(new Windowed<>("b", new SessionWindow(0, 0)), 10L);
         final List<KeyValue<Windowed<String>, Long>> results = StreamsTestUtils.toList(sessionStore.fetch("a", "b"));
-        assertThat(results.size(), equalTo(2));
+        assertThat(results, equalTo(Arrays.asList(
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 0L),
+            KeyValue.pair(new Windowed<>("b", new SessionWindow(0, 0)), 10L))));
     }
 
     @Test
-    public void shouldThrowNPEIfKeyIsNull() {
-        assertThrows(NullPointerException.class, () -> underlyingSessionStore.fetch(null));
+    public void shouldFetchKeyRangeAcrossStoresWithNullKeyFrom() {
+        final ReadOnlySessionStoreStub<String, Long> secondUnderlying = new
+            ReadOnlySessionStoreStub<>();
+        stubProviderTwo.addStore(storeName, secondUnderlying);
+        underlyingSessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 0L);
+        secondUnderlying.put(new Windowed<>("b", new SessionWindow(0, 0)), 10L);
+        final List<KeyValue<Windowed<String>, Long>> results = StreamsTestUtils.toList(sessionStore.fetch(null, "b"));
+        assertThat(results, equalTo(Arrays.asList(
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 0L),
+            KeyValue.pair(new Windowed<>("b", new SessionWindow(0, 0)), 10L))));
     }
 
     @Test
-    public void shouldThrowNPEIfFromKeyIsNull() {
-        assertThrows(NullPointerException.class, () -> underlyingSessionStore.fetch(null, "a"));
+    public void shouldFetchKeyRangeAcrossStoresWithNullKeyTo() {
+        final ReadOnlySessionStoreStub<String, Long> secondUnderlying = new
+            ReadOnlySessionStoreStub<>();
+        stubProviderTwo.addStore(storeName, secondUnderlying);
+        underlyingSessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 0L);
+        secondUnderlying.put(new Windowed<>("b", new SessionWindow(0, 0)), 10L);
+        final List<KeyValue<Windowed<String>, Long>> results = StreamsTestUtils.toList(sessionStore.fetch("a", null));
+        assertThat(results, equalTo(Arrays.asList(
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 0L),
+            KeyValue.pair(new Windowed<>("b", new SessionWindow(0, 0)), 10L))));
     }
 
     @Test
-    public void shouldThrowNPEIfToKeyIsNull() {
-        assertThrows(NullPointerException.class, () -> underlyingSessionStore.fetch("a", null));
+    public void shouldFetchKeyRangeAcrossStoresWithNullKeyFromKeyTo() {
+        final ReadOnlySessionStoreStub<String, Long> secondUnderlying = new
+            ReadOnlySessionStoreStub<>();
+        stubProviderTwo.addStore(storeName, secondUnderlying);
+        underlyingSessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 0L);
+        secondUnderlying.put(new Windowed<>("b", new SessionWindow(0, 0)), 10L);
+        final List<KeyValue<Windowed<String>, Long>> results = StreamsTestUtils.toList(sessionStore.fetch(null, null));
+        assertThat(results, equalTo(Arrays.asList(
+            KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 0L),
+            KeyValue.pair(new Windowed<>("b", new SessionWindow(0, 0)), 10L))));
+    }
+
+    @Test
+    public void shouldThrowNPEIfKeyIsNull() {
+        assertThrows(NullPointerException.class, () -> underlyingSessionStore.fetch(null));
     }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreFetchTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreFetchTest.java
index 1e274a6..fbb6e00 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreFetchTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreFetchTest.java
@@ -20,6 +20,7 @@ package org.apache.kafka.streams.state.internals;
 import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.serialization.StringSerializer;
 import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
@@ -38,6 +39,7 @@ import org.apache.kafka.streams.state.SessionBytesStoreSupplier;
 import org.apache.kafka.streams.state.SessionStore;
 import org.apache.kafka.streams.state.Stores;
 import org.apache.kafka.test.TestUtils;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -53,12 +55,15 @@ import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Properties;
+import java.util.function.Predicate;
 import java.util.function.Supplier;
 
 import static java.time.Duration.ofMillis;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkProperties;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
 
 @RunWith(Parameterized.class)
 public class SessionStoreFetchTest {
@@ -76,6 +81,13 @@ public class SessionStoreFetchTest {
     private LinkedList<KeyValue<Windowed<String>, Long>> expectedRecords;
     private LinkedList<KeyValue<String, String>> records;
     private Properties streamsConfig;
+    private String low;
+    private String high;
+    private String middle;
+    private String innerLow;
+    private String innerHigh;
+    private String innerLowBetween;
+    private String innerHighBetween;
 
     public SessionStoreFetchTest(final StoreType storeType, final boolean enableLogging, final boolean enableCaching, final boolean forward) {
         this.storeType = storeType;
@@ -95,7 +107,32 @@ public class SessionStoreFetchTest {
             final KeyValue<String, String> r2 = new KeyValue<>(key2, value);
             records.add(r);
             records.add(r2);
+            high = key;
+            if (low == null) {
+                low = key;
+            }
+            if (i == m) {
+                middle = key;
+            }
+            if (i == 1) {
+                innerLow = key;
+                final int index = i * 2 - 1;
+                innerLowBetween = "key-" + index;
+            }
+            if (i == DATA_SIZE - 2) {
+                innerHigh = key;
+                final int index = i * 2 + 1;
+                innerHighBetween = "key-" + index;
+            }
         }
+        Assert.assertNotNull(low);
+        Assert.assertNotNull(high);
+        Assert.assertNotNull(middle);
+        Assert.assertNotNull(innerLow);
+        Assert.assertNotNull(innerHigh);
+        Assert.assertNotNull(innerLowBetween);
+        Assert.assertNotNull(innerHighBetween);
+
         expectedRecords.add(new KeyValue<>(new Windowed<>("key-a", new SessionWindow(0, 500)), 4L));
         expectedRecords.add(new KeyValue<>(new Windowed<>("key-aa", new SessionWindow(0, 500)), 4L));
         expectedRecords.add(new KeyValue<>(new Windowed<>("key-b", new SessionWindow(1500, 2000)), 6L));
@@ -121,10 +158,66 @@ public class SessionStoreFetchTest {
         ));
     }
 
+    private void verifyNormalQuery(final SessionStore<String, Long> stateStore) {
+        try (final KeyValueIterator<Windowed<String>, Long> scanIterator = forward ?
+            stateStore.fetch("key-a", "key-bb") :
+            stateStore.backwardFetch("key-a", "key-bb")) {
+
+            final Iterator<KeyValue<Windowed<String>, Long>> dataIterator = forward ?
+                expectedRecords.iterator() :
+                expectedRecords.descendingIterator();
+
+            TestUtils.checkEquals(scanIterator, dataIterator);
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> scanIterator = forward ?
+            stateStore.findSessions("key-a", "key-bb", 0L, Long.MAX_VALUE) :
+            stateStore.backwardFindSessions("key-a", "key-bb", 0L, Long.MAX_VALUE)) {
+
+            final Iterator<KeyValue<Windowed<String>, Long>> dataIterator = forward ?
+                expectedRecords.iterator() :
+                expectedRecords.descendingIterator();
+
+            TestUtils.checkEquals(scanIterator, dataIterator);
+        }
+    }
+
+    private void verifyInfiniteQuery(final SessionStore<String, Long> stateStore) {
+        try (final KeyValueIterator<Windowed<String>, Long> scanIterator = forward ?
+            stateStore.fetch(null, null) :
+            stateStore.backwardFetch(null, null)) {
+
+            final Iterator<KeyValue<Windowed<String>, Long>> dataIterator = forward ?
+                expectedRecords.iterator() :
+                expectedRecords.descendingIterator();
+
+            TestUtils.checkEquals(scanIterator, dataIterator);
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> scanIterator = forward ?
+            stateStore.findSessions(null, null, 0L, Long.MAX_VALUE) :
+            stateStore.backwardFindSessions(null, null, 0L, Long.MAX_VALUE)) {
+
+            final Iterator<KeyValue<Windowed<String>, Long>> dataIterator = forward ?
+                expectedRecords.iterator() :
+                expectedRecords.descendingIterator();
+
+            TestUtils.checkEquals(scanIterator, dataIterator);
+        }
+    }
+
+    private void verifyRangeQuery(final SessionStore<String, Long> stateStore) {
+        testRange("range", stateStore, innerLow, innerHigh, forward);
+        testRange("until", stateStore, null, middle, forward);
+        testRange("from", stateStore, middle, null, forward);
+
+        testRange("untilBetween", stateStore, null, innerHighBetween, forward);
+        testRange("fromBetween", stateStore, innerLowBetween, null, forward);
+    }
+
     @Test
     public void testStoreConfig() {
         final Materialized<String, Long, SessionStore<Bytes, byte[]>> stateStoreConfig = getStoreConfig(storeType, STORE_NAME, enableLogging, enableCaching);
-        //Create topology: table from input topic
         final StreamsBuilder builder = new StreamsBuilder();
 
         final KStream<String, String> stream = builder.stream("input", Consumed.with(Serdes.String(), Serdes.String()));
@@ -152,28 +245,36 @@ public class SessionStoreFetchTest {
                 input.pipeInput(kv.key, kv.value, windowStartTime + WINDOW_SIZE);
             }
 
-            // query the state store
-            try (final KeyValueIterator<Windowed<String>, Long> scanIterator = forward ?
-                stateStore.fetch("key-a", "key-bb") :
-                stateStore.backwardFetch("key-a", "key-bb")) {
+            verifyNormalQuery(stateStore);
+            verifyInfiniteQuery(stateStore);
+            verifyRangeQuery(stateStore);
+        }
+    }
 
-                final Iterator<KeyValue<Windowed<String>, Long>> dataIterator = forward ?
-                    expectedRecords.iterator() :
-                    expectedRecords.descendingIterator();
 
-                TestUtils.checkEquals(scanIterator, dataIterator);
+    // Collects the entries of iterator accepted by a key-range predicate: record key within [from, to], null bound = open.
+    private List<KeyValue<Windowed<String>, Long>> filterList(final KeyValueIterator<Windowed<String>, Long> iterator, final String from, final String to) {
+        final Predicate<KeyValue<Windowed<String>, Long>> pred = new Predicate<KeyValue<Windowed<String>, Long>>() {
+            @Override
+            public boolean test(final KeyValue<Windowed<String>, Long> elem) {
+                if (elem == null) {
+                    return false; // reject null before any dereference; checking it last could never be false
+                }
+                final boolean belowFrom = from != null && elem.key.key().compareTo(from) < 0;
+                final boolean aboveTo = to != null && elem.key.key().compareTo(to) > 0;
+                return !belowFrom && !aboveTo;
             }
+        };
 
-            try (final KeyValueIterator<Windowed<String>, Long> scanIterator = forward ?
-                stateStore.findSessions("key-a", "key-bb", 0L, Long.MAX_VALUE) :
-                stateStore.backwardFindSessions("key-a", "key-bb", 0L, Long.MAX_VALUE)) {
-
-                final Iterator<KeyValue<Windowed<String>, Long>> dataIterator = forward ?
-                    expectedRecords.iterator() :
-                    expectedRecords.descendingIterator();
+        return Utils.toList(iterator, pred);
+    }
 
-                TestUtils.checkEquals(scanIterator, dataIterator);
-            }
+    // Asserts that fetching [from, to] equals an unbounded fetch filtered to that range; name labels a failure.
+    private void testRange(final String name, final SessionStore<String, Long> store, final String from, final String to, final boolean forward) {
+        try (final KeyValueIterator<Windowed<String>, Long> resultIterator = forward ? store.fetch(from, to) : store.backwardFetch(from, to);
+             final KeyValueIterator<Windowed<String>, Long> expectedIterator = forward ? store.fetch(null, null) : store.backwardFetch(null, null)) {
+            final List<KeyValue<Windowed<String>, Long>> expected = filterList(expectedIterator, from, to);
+            assertThat(name, Utils.toList(resultIterator), is(expected));
         }
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/test/ReadOnlySessionStoreStub.java b/streams/src/test/java/org/apache/kafka/test/ReadOnlySessionStoreStub.java
index 0664055..61ea1ff 100644
--- a/streams/src/test/java/org/apache/kafka/test/ReadOnlySessionStoreStub.java
+++ b/streams/src/test/java/org/apache/kafka/test/ReadOnlySessionStoreStub.java
@@ -94,11 +94,13 @@ public class ReadOnlySessionStoreStub<K, V> implements ReadOnlySessionStore<K, V
         if (!open) {
             throw new InvalidStateStoreException("not open");
         }
-        if (sessions.subMap(keyFrom, true, keyTo, true).isEmpty()) {
+
+        NavigableMap<K, List<KeyValue<Windowed<K>, V>>> subSessionsMap = getSubSessionsMap(keyFrom, keyTo);
+
+        if (subSessionsMap.isEmpty()) {
             return new KeyValueIteratorStub<>(Collections.<KeyValue<Windowed<K>, V>>emptyIterator());
         }
-        final Iterator<List<KeyValue<Windowed<K>, V>>> keysIterator = sessions.subMap(keyFrom, true,
-            keyTo, true).values().iterator();
+        final Iterator<List<KeyValue<Windowed<K>, V>>> keysIterator = subSessionsMap.values().iterator();
         return new KeyValueIteratorStub<>(
             new Iterator<KeyValue<Windowed<K>, V>>() {
 
@@ -124,16 +126,33 @@ public class ReadOnlySessionStoreStub<K, V> implements ReadOnlySessionStore<K, V
         );
     }
 
+    // Selects the part of sessions within [keyFrom, keyTo], both inclusive; a null bound leaves that side open.
+    private NavigableMap<K, List<KeyValue<Windowed<K>, V>>> getSubSessionsMap(final K keyFrom, final K keyTo) {
+        if (keyFrom != null && keyTo != null) {
+            return sessions.subMap(keyFrom, true, keyTo, true);
+        }
+        if (keyFrom != null) {
+            return sessions.tailMap(keyFrom, true);
+        }
+        if (keyTo != null) {
+            return sessions.headMap(keyTo, true);
+        }
+        return sessions; // fetch all
+    }
+
     @Override
     public KeyValueIterator<Windowed<K>, V> backwardFetch(K keyFrom, K keyTo) {
         if (!open) {
             throw new InvalidStateStoreException("not open");
         }
-        if (sessions.subMap(keyFrom, true, keyTo, true).isEmpty()) {
+
+        NavigableMap<K, List<KeyValue<Windowed<K>, V>>> subSessionsMap = getSubSessionsMap(keyFrom, keyTo);
+
+        if (subSessionsMap.isEmpty()) {
             return new KeyValueIteratorStub<>(Collections.emptyIterator());
         }
-        final Iterator<List<KeyValue<Windowed<K>, V>>> keysIterator =
-            sessions.subMap(keyFrom, true, keyTo, true).descendingMap().values().iterator();
+
+        final Iterator<List<KeyValue<Windowed<K>, V>>> keysIterator = subSessionsMap.descendingMap().values().iterator();
         return new KeyValueIteratorStub<>(
             new Iterator<KeyValue<Windowed<K>, V>>() {