You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by ab...@apache.org on 2021/09/29 00:53:45 UTC

[kafka] branch 3.0 updated: KAFKA-13309: fix InMemorySessionStore#fetch/backwardFetch order issue (#11337)

This is an automated email from the ASF dual-hosted git repository.

ableegoldman pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.0 by this push:
     new cc89c9f  KAFKA-13309: fix InMemorySessionStore#fetch/backwardFetch order issue (#11337)
cc89c9f is described below

commit cc89c9fe5f8fea22e72331130a92218296f39323
Author: Luke Chen <sh...@gmail.com>
AuthorDate: Wed Sep 29 08:31:29 2021 +0800

    KAFKA-13309: fix InMemorySessionStore#fetch/backwardFetch order issue (#11337)
    
    In #9139, we added a backward iterator to SessionStore. However, there is a bug: when calling fetch/backwardFetch over a key range, if there are multiple records in the same session window, the data is not returned in the correct order.
    
    Reviewers: Anna Sophie Blee-Goldman <ableegoldman@apache.org>
---
 .../state/internals/InMemorySessionStore.java      |   4 +-
 .../internals/AbstractSessionBytesStoreTest.java   | 147 +++++++------
 .../internals/CachingInMemorySessionStoreTest.java |  11 +-
 .../state/internals/SessionStoreFetchTest.java     | 232 +++++++++++++++++++++
 4 files changed, 323 insertions(+), 71 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
index 722ed43..a14d3e1 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
@@ -281,7 +281,7 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
         removeExpiredSegments();
 
 
-        return registerNewIterator(keyFrom, keyTo, Long.MAX_VALUE, endTimeMap.entrySet().iterator(), false);
+        return registerNewIterator(keyFrom, keyTo, Long.MAX_VALUE, endTimeMap.entrySet().iterator(), true);
     }
 
     @Override
@@ -292,7 +292,7 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
         removeExpiredSegments();
 
         return registerNewIterator(
-            keyFrom, keyTo, Long.MAX_VALUE, endTimeMap.descendingMap().entrySet().iterator(), true);
+            keyFrom, keyTo, Long.MAX_VALUE, endTimeMap.descendingMap().entrySet().iterator(), false);
     }
 
     @Override
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java
index 6a4fb71..e240cce 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java
@@ -47,15 +47,15 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
-import java.util.Set;
 
 import static java.util.Arrays.asList;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
-import static org.apache.kafka.test.StreamsTestUtils.toSet;
+import static org.apache.kafka.common.utils.Utils.toList;
 import static org.apache.kafka.test.StreamsTestUtils.valuesToSet;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.hasItem;
@@ -121,7 +121,7 @@ public abstract class AbstractSessionBytesStoreTest {
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.findSessions(key, 0, 1000L)
         ) {
-            assertEquals(new HashSet<>(expected), toSet(values));
+            assertEquals(expected, toList(values));
         }
 
         final List<KeyValue<Windowed<String>, Long>> expected2 =
@@ -129,7 +129,7 @@ public abstract class AbstractSessionBytesStoreTest {
 
         try (final KeyValueIterator<Windowed<String>, Long> values2 = sessionStore.findSessions(key, 400L, 600L)
         ) {
-            assertEquals(new HashSet<>(expected2), toSet(values2));
+            assertEquals(expected2, toList(values2));
         }
     }
 
@@ -143,28 +143,29 @@ public abstract class AbstractSessionBytesStoreTest {
         sessionStore.put(new Windowed<>(key, new SessionWindow(1500L, 2000L)), 1L);
         sessionStore.put(new Windowed<>(key, new SessionWindow(2500L, 3000L)), 2L);
 
-        final List<KeyValue<Windowed<String>, Long>> expected =
-            asList(KeyValue.pair(a1, 1L), KeyValue.pair(a2, 2L));
+        final LinkedList<KeyValue<Windowed<String>, Long>> expected = new LinkedList<>();
+        expected.add(KeyValue.pair(a1, 1L));
+        expected.add(KeyValue.pair(a2, 2L));
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.backwardFindSessions(key, 0, 1000L)) {
-            assertEquals(new HashSet<>(expected), toSet(values));
+            assertEquals(toList(expected.descendingIterator()), toList(values));
         }
 
         final List<KeyValue<Windowed<String>, Long>> expected2 =
             Collections.singletonList(KeyValue.pair(a2, 2L));
 
         try (final KeyValueIterator<Windowed<String>, Long> values2 = sessionStore.backwardFindSessions(key, 400L, 600L)) {
-            assertEquals(new HashSet<>(expected2), toSet(values2));
+            assertEquals(expected2, toList(values2));
         }
     }
 
     @Test
     public void shouldFetchAllSessionsWithSameRecordKey() {
-        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L));
+        final LinkedList<KeyValue<Windowed<String>, Long>> expected = new LinkedList<>();
+        expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L));
+        expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L));
+        expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L));
+        expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L));
 
         for (final KeyValue<Windowed<String>, Long> kv : expected) {
             sessionStore.put(kv.key, kv.value);
@@ -174,18 +175,17 @@ public abstract class AbstractSessionBytesStoreTest {
         sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 0)), 5L);
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(new HashSet<>(expected), toSet(values));
+            assertEquals(expected, toList(values));
         }
     }
 
     @Test
     public void shouldBackwardFetchAllSessionsWithSameRecordKey() {
-        final List<KeyValue<Windowed<String>, Long>> expected = asList(
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L),
-            KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L)
-        );
+        final LinkedList<KeyValue<Windowed<String>, Long>> expected = new LinkedList<>();
+        expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L));
+        expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L));
+        expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L));
+        expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L));
 
         for (final KeyValue<Windowed<String>, Long> kv : expected) {
             sessionStore.put(kv.key, kv.value);
@@ -195,18 +195,18 @@ public abstract class AbstractSessionBytesStoreTest {
         sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 0)), 5L);
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.backwardFetch("a")) {
-            assertEquals(new HashSet<>(expected), toSet(values));
+            assertEquals(toList(expected.descendingIterator()), toList(values));
         }
     }
 
     @Test
     public void shouldFetchAllSessionsWithinKeyRange() {
-        final List<KeyValue<Windowed<String>, Long>> expected = Arrays.asList(
-            KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L),
-            KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L),
-
-            KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L),
-            KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L));
+        final List<KeyValue<Windowed<String>, Long>> expected = new LinkedList<>();
+        expected.add(KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L));
+        expected.add(KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L));
+        expected.add(KeyValue.pair(new Windowed<>("aaaa", new SessionWindow(100, 100)), 6L));
+        expected.add(KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L));
+        expected.add(KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L));
 
         for (final KeyValue<Windowed<String>, Long> kv : expected) {
             sessionStore.put(kv.key, kv.value);
@@ -217,19 +217,22 @@ public abstract class AbstractSessionBytesStoreTest {
         sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L);
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("aa", "bb")) {
-            assertEquals(new HashSet<>(expected), toSet(values));
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.findSessions("aa", "bb", 0L, Long.MAX_VALUE)) {
+            assertEquals(expected, toList(values));
         }
     }
 
     @Test
     public void shouldBackwardFetchAllSessionsWithinKeyRange() {
-        final List<KeyValue<Windowed<String>, Long>> expected = asList(
-            KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L),
-            KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L),
-
-            KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L),
-            KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L)
-        );
+        final LinkedList<KeyValue<Windowed<String>, Long>> expected = new LinkedList<>();
+        expected.add(KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L));
+        expected.add(KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L));
+        expected.add(KeyValue.pair(new Windowed<>("aaaa", new SessionWindow(100, 100)), 6L));
+        expected.add(KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L));
+        expected.add(KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L));
 
         for (final KeyValue<Windowed<String>, Long> kv : expected) {
             sessionStore.put(kv.key, kv.value);
@@ -240,7 +243,11 @@ public abstract class AbstractSessionBytesStoreTest {
         sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L);
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.backwardFetch("aa", "bb")) {
-            assertEquals(new HashSet<>(expected), toSet(values));
+            assertEquals(toList(expected.descendingIterator()), toList(values));
+        }
+
+        try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.backwardFindSessions("aa", "bb", 0L, Long.MAX_VALUE)) {
+            assertEquals(toList(expected.descendingIterator()), toList(values));
         }
     }
 
@@ -272,7 +279,7 @@ public abstract class AbstractSessionBytesStoreTest {
             KeyValue.pair(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L));
 
         try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.findSessions(key, -1, 1000L)) {
-            assertEquals(new HashSet<>(expected), toSet(results));
+            assertEquals(expected, toList(results));
         }
     }
 
@@ -282,13 +289,12 @@ public abstract class AbstractSessionBytesStoreTest {
         sessionStore.put(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L);
         sessionStore.put(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L);
 
-        final List<KeyValue<Windowed<String>, Long>> expected = asList(
-            KeyValue.pair(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L),
-            KeyValue.pair(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L)
-        );
+        final LinkedList<KeyValue<Windowed<String>, Long>> expected = new LinkedList<>();
+        expected.add(KeyValue.pair(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L));
+        expected.add(KeyValue.pair(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L));
 
         try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.backwardFindSessions(key, -1, 1000L)) {
-            assertEquals(new HashSet<>(expected), toSet(results));
+            assertEquals(toList(expected.descendingIterator()), toList(results));
         }
     }
 
@@ -341,7 +347,7 @@ public abstract class AbstractSessionBytesStoreTest {
             Arrays.asList(KeyValue.pair(session2, 2L), KeyValue.pair(session3, 3L));
 
         try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.findSessions("a", 150, 300)) {
-            assertEquals(new HashSet<>(expected), toSet(results));
+            assertEquals(expected, toList(results));
         }
     }
 
@@ -359,10 +365,10 @@ public abstract class AbstractSessionBytesStoreTest {
         sessionStore.put(session5, 5L);
 
         final List<KeyValue<Windowed<String>, Long>> expected =
-            asList(KeyValue.pair(session2, 2L), KeyValue.pair(session3, 3L));
+            asList(KeyValue.pair(session3, 3L), KeyValue.pair(session2, 2L));
 
         try (final KeyValueIterator<Windowed<String>, Long> results = sessionStore.backwardFindSessions("a", 150, 300)) {
-            assertEquals(new HashSet<>(expected), toSet(results));
+            assertEquals(expected, toList(results));
         }
     }
 
@@ -400,7 +406,7 @@ public abstract class AbstractSessionBytesStoreTest {
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
                  sessionStore.findSessions("a", "aa", 10, 0)
         ) {
-            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(Collections.singletonList(2L))));
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(2L))));
         }
     }
 
@@ -438,7 +444,7 @@ public abstract class AbstractSessionBytesStoreTest {
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
                  sessionStore.backwardFindSessions("a", "aa", 10, 0)
         ) {
-            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(Collections.singletonList(2L))));
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(2L))));
         }
     }
 
@@ -463,12 +469,20 @@ public abstract class AbstractSessionBytesStoreTest {
         sessionStore.put(new Windowed<>(key2, new SessionWindow(8, 100)), "8");
         sessionStore.put(new Windowed<>(key3, new SessionWindow(9, 100)), "9");
 
-        final Set<String> expectedKey1 = new HashSet<>(asList("1", "4", "7"));
-        assertThat(valuesToSet(sessionStore.findSessions(key1, 0L, Long.MAX_VALUE)), equalTo(expectedKey1));
-        final Set<String> expectedKey2 = new HashSet<>(asList("2", "5", "8"));
-        assertThat(valuesToSet(sessionStore.findSessions(key2, 0L, Long.MAX_VALUE)), equalTo(expectedKey2));
-        final Set<String> expectedKey3 = new HashSet<>(asList("3", "6", "9"));
-        assertThat(valuesToSet(sessionStore.findSessions(key3, 0L, Long.MAX_VALUE)), equalTo(expectedKey3));
+        final List<String> expectedKey1 = asList("1", "4", "7");
+        try (KeyValueIterator<Windowed<Bytes>, String> iterator = sessionStore.findSessions(key1, 0L, Long.MAX_VALUE)) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(expectedKey1)));
+        }
+
+        final List<String> expectedKey2 = asList("2", "5", "8");
+        try (KeyValueIterator<Windowed<Bytes>, String> iterator = sessionStore.findSessions(key2, 0L, Long.MAX_VALUE)) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(expectedKey2)));
+        }
+
+        final List<String> expectedKey3 = asList("3", "6", "9");
+        try (KeyValueIterator<Windowed<Bytes>, String> iterator = sessionStore.findSessions(key3, 0L, Long.MAX_VALUE)) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(expectedKey3)));
+        }
 
         sessionStore.close();
     }
@@ -494,12 +508,21 @@ public abstract class AbstractSessionBytesStoreTest {
         sessionStore.put(new Windowed<>(key2, new SessionWindow(8, 100)), "8");
         sessionStore.put(new Windowed<>(key3, new SessionWindow(9, 100)), "9");
 
-        final Set<String> expectedKey1 = new HashSet<>(asList("1", "4", "7"));
-        assertThat(valuesToSet(sessionStore.backwardFindSessions(key1, 0L, Long.MAX_VALUE)), equalTo(expectedKey1));
-        final Set<String> expectedKey2 = new HashSet<>(asList("2", "5", "8"));
-        assertThat(valuesToSet(sessionStore.backwardFindSessions(key2, 0L, Long.MAX_VALUE)), equalTo(expectedKey2));
-        final Set<String> expectedKey3 = new HashSet<>(asList("3", "6", "9"));
-        assertThat(valuesToSet(sessionStore.backwardFindSessions(key3, 0L, Long.MAX_VALUE)), equalTo(expectedKey3));
+
+        final List<String> expectedKey1 = asList("7", "4", "1");
+        try (KeyValueIterator<Windowed<Bytes>, String> iterator = sessionStore.backwardFindSessions(key1, 0L, Long.MAX_VALUE)) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(expectedKey1)));
+        }
+
+        final List<String> expectedKey2 = asList("8", "5", "2");
+        try (KeyValueIterator<Windowed<Bytes>, String> iterator = sessionStore.backwardFindSessions(key2, 0L, Long.MAX_VALUE)) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(expectedKey2)));
+        }
+
+        final List<String> expectedKey3 = asList("9", "6", "3");
+        try (KeyValueIterator<Windowed<Bytes>, String> iterator = sessionStore.backwardFindSessions(key3, 0L, Long.MAX_VALUE)) {
+            assertThat(valuesToSet(iterator), equalTo(new HashSet<>(expectedKey3)));
+        }
 
         sessionStore.close();
     }
@@ -548,13 +571,13 @@ public abstract class AbstractSessionBytesStoreTest {
         }
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(new HashSet<>(expected), toSet(values));
+            assertEquals(expected, toList(values));
         }
 
         sessionStore.close();
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(Collections.emptySet(), toSet(values));
+            assertEquals(Collections.emptyList(), toList(values));
         }
 
 
@@ -566,7 +589,7 @@ public abstract class AbstractSessionBytesStoreTest {
         context.restore(sessionStore.name(), changeLog);
 
         try (final KeyValueIterator<Windowed<String>, Long> values = sessionStore.fetch("a")) {
-            assertEquals(new HashSet<>(expected), toSet(values));
+            assertEquals(expected, toList(values));
         }
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
index 67258e7..1b44b91 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
@@ -326,13 +326,10 @@ public class CachingInMemorySessionStoreTest {
         cachingStore.put(b, "2".getBytes());
         cachingStore.remove(a);
 
-        final KeyValueIterator<Windowed<Bytes>, byte[]> rangeIter =
-            cachingStore.findSessions(keyA, 0, 0);
-        assertFalse(rangeIter.hasNext());
-
-        assertNull(cachingStore.fetchSession(keyA, 0, 0));
-        assertThat(cachingStore.fetchSession(keyB, 0, 0), equalTo("2".getBytes()));
-
+        try (final KeyValueIterator<Windowed<Bytes>, byte[]> rangeIter =
+                 cachingStore.findSessions(keyA, 0, 0)) {
+            assertFalse(rangeIter.hasNext());
+        }
     }
 
     @Test
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreFetchTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreFetchTest.java
new file mode 100644
index 0000000..1e274a6
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreFetchTest.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsBuilder;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TestInputTopic;
+import org.apache.kafka.streams.Topology;
+import org.apache.kafka.streams.TopologyTestDriver;
+import org.apache.kafka.streams.kstream.Consumed;
+import org.apache.kafka.streams.kstream.Grouped;
+import org.apache.kafka.streams.kstream.KStream;
+import org.apache.kafka.streams.kstream.Materialized;
+import org.apache.kafka.streams.kstream.SessionWindows;
+import org.apache.kafka.streams.kstream.Windowed;
+import org.apache.kafka.streams.kstream.internals.SessionWindow;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.SessionBytesStoreSupplier;
+import org.apache.kafka.streams.state.SessionStore;
+import org.apache.kafka.streams.state.Stores;
+import org.apache.kafka.test.TestUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TestName;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Properties;
+import java.util.function.Supplier;
+
+import static java.time.Duration.ofMillis;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.common.utils.Utils.mkProperties;
+
+/**
+ * Checks that key-range queries on a session store return the expected sessions in order:
+ * ascending for fetch/findSessions and descending for backwardFetch/backwardFindSessions.
+ * Parameterized over store type, logging, caching and iteration direction.
+ */
+@RunWith(Parameterized.class)
+public class SessionStoreFetchTest {
+    private enum StoreType { InMemory, RocksDB }
+    private static final String STORE_NAME = "store";
+    private static final int DATA_SIZE = 5;
+    private static final long WINDOW_SIZE = 500L;
+    private static final long RETENTION_MS = 10000L;
+
+    private final StoreType storeType;
+    private final boolean enableLogging;
+    private final boolean enableCaching;
+    private final boolean forward;
+
+    private final LinkedList<KeyValue<Windowed<String>, Long>> expectedRecords;
+    private final LinkedList<KeyValue<String, String>> records;
+    private Properties streamsConfig;
+
+    public SessionStoreFetchTest(final StoreType storeType, final boolean enableLogging, final boolean enableCaching, final boolean forward) {
+        this.storeType = storeType;
+        this.enableLogging = enableLogging;
+        this.enableCaching = enableCaching;
+        this.forward = forward;
+
+        this.records = new LinkedList<>();
+        this.expectedRecords = new LinkedList<>();
+        // The first DATA_SIZE / 2 values go to keys "key-a"/"key-aa", the rest to "key-b"/"key-bb".
+        final int m = DATA_SIZE / 2;
+        for (int i = 0; i < DATA_SIZE; i++) {
+            final String keyStr = i < m ? "a" : "b";
+            final String value = "val-" + i;
+            records.add(new KeyValue<>("key-" + keyStr, value));
+            records.add(new KeyValue<>("key-" + keyStr + keyStr, value));
+        }
+        // Each record is piped twice in testStoreConfig(), hence counts of 2 * 2 and 3 * 2.
+        expectedRecords.add(new KeyValue<>(new Windowed<>("key-a", new SessionWindow(0, 500)), 4L));
+        expectedRecords.add(new KeyValue<>(new Windowed<>("key-aa", new SessionWindow(0, 500)), 4L));
+        expectedRecords.add(new KeyValue<>(new Windowed<>("key-b", new SessionWindow(1500, 2000)), 6L));
+        expectedRecords.add(new KeyValue<>(new Windowed<>("key-bb", new SessionWindow(1500, 2000)), 6L));
+    }
+
+    @Rule
+    public TestName testName = new TestName();
+
+    @Parameterized.Parameters(name = "storeType={0}, enableLogging={1}, enableCaching={2}, forward={3}")
+    public static Collection<Object[]> data() {
+        final List<StoreType> types = Arrays.asList(StoreType.InMemory, StoreType.RocksDB);
+        final List<Boolean> logging = Arrays.asList(true, false);
+        final List<Boolean> caching = Arrays.asList(true, false);
+        final List<Boolean> forward = Arrays.asList(true, false);
+        return buildParameters(types, logging, caching, forward);
+    }
+
+    @Before
+    public void setup() {
+        streamsConfig = mkProperties(mkMap(
+                mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath())
+        ));
+    }
+
+    @Test
+    public void testStoreConfig() {
+        // getStoreConfig expects (type, name, cachingEnabled, loggingEnabled): caching comes first.
+        final Materialized<String, Long, SessionStore<Bytes, byte[]>> stateStoreConfig = getStoreConfig(storeType, STORE_NAME, enableCaching, enableLogging);
+        // Create topology: count the input records per key and session window
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final KStream<String, String> stream = builder.stream("input", Consumed.with(Serdes.String(), Serdes.String()));
+        stream.groupByKey(Grouped.with(Serdes.String(), Serdes.String()))
+            .windowedBy(SessionWindows.ofInactivityGapWithNoGrace(ofMillis(WINDOW_SIZE)))
+            .count(stateStoreConfig)
+            .toStream()
+            .to("output");
+
+        final Topology topology = builder.build();
+
+        // Hand streamsConfig to the driver so the state dir configured in setup() is actually used.
+        try (final TopologyTestDriver driver = new TopologyTestDriver(topology, streamsConfig)) {
+            // get input topic and stateStore
+            final TestInputTopic<String, String> input = driver
+                    .createInputTopic("input", new StringSerializer(), new StringSerializer());
+            final SessionStore<String, Long> stateStore = driver.getSessionStore(STORE_NAME);
+
+            // write some data: every record is sent twice, WINDOW_SIZE apart
+            final int medium = DATA_SIZE / 2 * 2;
+            for (int i = 0; i < records.size(); i++) {
+                final KeyValue<String, String> kv = records.get(i);
+                final long windowStartTime = i < medium ? 0 : 1500;
+                input.pipeInput(kv.key, kv.value, windowStartTime);
+                input.pipeInput(kv.key, kv.value, windowStartTime + WINDOW_SIZE);
+            }
+
+            // query the state store
+            try (final KeyValueIterator<Windowed<String>, Long> scanIterator = forward ?
+                stateStore.fetch("key-a", "key-bb") :
+                stateStore.backwardFetch("key-a", "key-bb")) {
+
+                TestUtils.checkEquals(scanIterator, expectedIterator());
+            }
+
+            try (final KeyValueIterator<Windowed<String>, Long> scanIterator = forward ?
+                stateStore.findSessions("key-a", "key-bb", 0L, Long.MAX_VALUE) :
+                stateStore.backwardFindSessions("key-a", "key-bb", 0L, Long.MAX_VALUE)) {
+
+                TestUtils.checkEquals(scanIterator, expectedIterator());
+            }
+        }
+    }
+
+    /** Returns the expected records in the order matching the iteration direction under test. */
+    private Iterator<KeyValue<Windowed<String>, Long>> expectedIterator() {
+        return forward ? expectedRecords.iterator() : expectedRecords.descendingIterator();
+    }
+
+    /** Builds the cartesian product of the given option lists, one Object[] per combination. */
+    private static Collection<Object[]> buildParameters(final List<?>... argOptions) {
+        List<Object[]> result = new LinkedList<>();
+        result.add(new Object[0]);
+
+        for (final List<?> argOption : argOptions) {
+            result = times(result, argOption);
+        }
+
+        return result;
+    }
+
+    /** Appends every element of {@code right} to every argument array in {@code left}. */
+    private static List<Object[]> times(final List<Object[]> left, final List<?> right) {
+        final List<Object[]> result = new LinkedList<>();
+        for (final Object[] args : left) {
+            for (final Object rightElem : right) {
+                final Object[] resArgs = new Object[args.length + 1];
+                System.arraycopy(args, 0, resArgs, 0, args.length);
+                resArgs[args.length] = rightElem;
+                result.add(resArgs);
+            }
+        }
+        return result;
+    }
+
+    /**
+     * Creates the materialization config for a session store with the given name; the store is
+     * persistent for RocksDB and in-memory otherwise, with caching and logging set as requested.
+     */
+    private Materialized<String, Long, SessionStore<Bytes, byte[]>> getStoreConfig(final StoreType type, final String name, final boolean cachingEnabled, final boolean loggingEnabled) {
+        final Supplier<SessionBytesStoreSupplier> createStore = () -> type == StoreType.RocksDB
+            ? Stores.persistentSessionStore(name, Duration.ofMillis(RETENTION_MS))
+            : Stores.inMemorySessionStore(name, Duration.ofMillis(RETENTION_MS));
+
+        final Materialized<String, Long, SessionStore<Bytes, byte[]>> stateStoreConfig = Materialized
+                .<String, Long>as(createStore.get())
+                .withKeySerde(Serdes.String())
+                .withValueSerde(Serdes.Long());
+        if (cachingEnabled) {
+            stateStoreConfig.withCachingEnabled();
+        } else {
+            stateStoreConfig.withCachingDisabled();
+        }
+        if (loggingEnabled) {
+            stateStoreConfig.withLoggingEnabled(new HashMap<>());
+        } else {
+            stateStoreConfig.withLoggingDisabled();
+        }
+        return stateStoreConfig;
+    }
+}