You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by vv...@apache.org on 2022/03/24 03:36:40 UTC

[kafka] branch 3.2 updated: KAFKA-13714: Fix cache flush position (#11926)

This is an automated email from the ASF dual-hosted git repository.

vvcephei pushed a commit to branch 3.2
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.2 by this push:
     new bde8cf3  KAFKA-13714: Fix cache flush position (#11926)
bde8cf3 is described below

commit bde8cf30ae75ae0758a554336ce71841208c050f
Author: John Roesler <vv...@users.noreply.github.com>
AuthorDate: Wed Mar 23 22:09:05 2022 -0500

    KAFKA-13714: Fix cache flush position (#11926)
    
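    The caching store layers were passing writes down into the lower store layers upon eviction without first setting the processor context to the evicted record's context. Instead, the lower layers saw the context of whatever unrelated record was being processed at the time, so the position they recorded on flush did not correspond to the records actually being flushed.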
    
    Reviewers: Matthias J. Sax <mj...@apache.org>
---
 README.md                                          |  9 ++-
 build.gradle                                       | 12 +++-
 .../state/internals/CachingKeyValueStore.java      | 13 ++--
 .../state/internals/CachingSessionStore.java       | 12 +++-
 .../state/internals/CachingWindowStore.java        | 12 +++-
 .../streams/state/internals/RocksDBStore.java      |  2 +-
 .../streams/state/internals/StoreQueryUtils.java   |  4 +-
 .../integration/IQv2StoreIntegrationTest.java      | 73 ++++++++++++----------
 .../CachingInMemoryKeyValueStoreTest.java          | 39 ++++++++++--
 .../internals/CachingInMemorySessionStoreTest.java | 43 +++++++++++++
 .../CachingPersistentSessionStoreTest.java         | 45 ++++++++++++-
 .../CachingPersistentWindowStoreTest.java          | 43 +++++++++++++
 12 files changed, 250 insertions(+), 57 deletions(-)

diff --git a/README.md b/README.md
index 5e409f8..6dafd44 100644
--- a/README.md
+++ b/README.md
@@ -37,13 +37,16 @@ Follow instructions in https://kafka.apache.org/quickstart
     ./gradlew integrationTest
     
 ### Force re-running tests without code change ###
-    ./gradlew cleanTest test
-    ./gradlew cleanTest unitTest
-    ./gradlew cleanTest integrationTest
+    ./gradlew -Prerun-tests test
+    ./gradlew -Prerun-tests unitTest
+    ./gradlew -Prerun-tests integrationTest
 
 ### Running a particular unit/integration test ###
     ./gradlew clients:test --tests RequestResponseTest
 
+### Repeatedly running a particular unit/integration test ###
+    I=0; while ./gradlew clients:test -Prerun-tests --tests RequestResponseTest --fail-fast; do (( I=$I+1 )); echo "Completed run: $I"; sleep 1; done
+
 ### Running a particular test method within a unit/integration test ###
     ./gradlew core:test --tests kafka.api.ProducerFailureHandlingTest.testCannotSendToInternalTopic
     ./gradlew clients:test --tests org.apache.kafka.clients.MetadataTest.testTimeToNextUpdate
diff --git a/build.gradle b/build.gradle
index ff4bab9..3363a43 100644
--- a/build.gradle
+++ b/build.gradle
@@ -207,7 +207,7 @@ if (file('.git').exists()) {
 } else {
   rat.enabled = false
 }
-println("Starting build with version $version (commit id ${commitId.take(8)}) using Gradle $gradleVersion, Java ${JavaVersion.current()} and Scala ${versions.scala}")
+println("Starting build with version $version (commit id ${commitId == null ? "null" : commitId.take(8)}) using Gradle $gradleVersion, Java ${JavaVersion.current()} and Scala ${versions.scala}")
 println("Build properties: maxParallelForks=$maxTestForks, maxScalacThreads=$maxScalacThreads, maxTestRetries=$userMaxTestRetries")
 
 subprojects {
@@ -435,6 +435,11 @@ subprojects {
       maxRetries = userMaxTestRetries
       maxFailures = userMaxTestRetryFailures
     }
+
+    // Allows devs to run tests in a loop to debug flaky tests. See README.
+    if (project.hasProperty("rerun-tests")) {
+      outputs.upToDateWhen { false }
+    }
   }
 
   task integrationTest(type: Test, dependsOn: compileJava) {
@@ -468,6 +473,11 @@ subprojects {
       maxRetries = userMaxTestRetries
       maxFailures = userMaxTestRetryFailures
     }
+
+    // Allows devs to run tests in a loop to debug flaky tests. See README.
+    if (project.hasProperty("rerun-tests")) {
+      outputs.upToDateWhen { false }
+    }
   }
 
   task unitTest(type: Test, dependsOn: compileJava) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
index 1d08d20..04f2a0c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
@@ -226,10 +226,9 @@ public class CachingKeyValueStore
             if (rawNewValue != null || rawOldValue != null) {
                 // we need to get the old values if needed, and then put to store, and then flush
                 final ProcessorRecordContext current = context.recordContext();
-                context.setRecordContext(entry.entry().context());
-                wrapped().put(entry.key(), entry.newValue());
-
                 try {
+                    context.setRecordContext(entry.entry().context());
+                    wrapped().put(entry.key(), entry.newValue());
                     flushListener.apply(
                         new Record<>(
                             entry.key().get(),
@@ -241,7 +240,13 @@ public class CachingKeyValueStore
                 }
             }
         } else {
-            wrapped().put(entry.key(), entry.newValue());
+            final ProcessorRecordContext current = context.recordContext();
+            try {
+                context.setRecordContext(entry.entry().context());
+                wrapped().put(entry.key(), entry.newValue());
+            } finally {
+                context.setRecordContext(current);
+            }
         }
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
index 59d2a0e..cff10da 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
@@ -105,11 +105,11 @@ class CachingSessionStore
             // we can skip flushing to downstream as well as writing to underlying store
             if (newValueBytes != null || oldValueBytes != null) {
                 // we need to get the old values if needed, and then put to store, and then flush
-                wrapped().put(bytesKey, entry.newValue());
 
                 final ProcessorRecordContext current = context.recordContext();
-                context.setRecordContext(entry.entry().context());
                 try {
+                    context.setRecordContext(entry.entry().context());
+                    wrapped().put(bytesKey, entry.newValue());
                     flushListener.apply(
                         new Record<>(
                             binaryKey.get(),
@@ -121,7 +121,13 @@ class CachingSessionStore
                 }
             }
         } else {
-            wrapped().put(bytesKey, entry.newValue());
+            final ProcessorRecordContext current = context.recordContext();
+            try {
+                context.setRecordContext(entry.entry().context());
+                wrapped().put(bytesKey, entry.newValue());
+            } finally {
+                context.setRecordContext(current);
+            }
         }
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
index 8a1f886..50ede9c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
@@ -122,11 +122,11 @@ class CachingWindowStore
             // we can skip flushing to downstream as well as writing to underlying store
             if (rawNewValue != null || rawOldValue != null) {
                 // we need to get the old values if needed, and then put to store, and then flush
-                wrapped().put(binaryKey, entry.newValue(), windowStartTimestamp);
 
                 final ProcessorRecordContext current = context.recordContext();
-                context.setRecordContext(entry.entry().context());
                 try {
+                    context.setRecordContext(entry.entry().context());
+                    wrapped().put(binaryKey, entry.newValue(), windowStartTimestamp);
                     flushListener.apply(
                         new Record<>(
                             binaryWindowKey,
@@ -138,7 +138,13 @@ class CachingWindowStore
                 }
             }
         } else {
-            wrapped().put(binaryKey, entry.newValue(), windowStartTimestamp);
+            final ProcessorRecordContext current = context.recordContext();
+            try {
+                context.setRecordContext(entry.entry().context());
+                wrapped().put(binaryKey, entry.newValue(), windowStartTimestamp);
+            } finally {
+                context.setRecordContext(current);
+            }
         }
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java
index 919c440..1eb9a70 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java
@@ -326,8 +326,8 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]>, BatchWritingS
     public void putAll(final List<KeyValue<Bytes, byte[]>> entries) {
         try (final WriteBatch batch = new WriteBatch()) {
             dbAccessor.prepareBatch(entries, batch);
-            StoreQueryUtils.updatePosition(position, context);
             write(batch);
+            StoreQueryUtils.updatePosition(position, context);
         } catch (final RocksDBException e) {
             throw new ProcessorStateException("Error while batch writing to store " + name, e);
         }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/StoreQueryUtils.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/StoreQueryUtils.java
index 06b3713..4630195 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/StoreQueryUtils.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/StoreQueryUtils.java
@@ -142,7 +142,9 @@ public final class StoreQueryUtils {
 
         if (stateStoreContext != null && stateStoreContext.recordMetadata().isPresent()) {
             final RecordMetadata meta = stateStoreContext.recordMetadata().get();
-            position.withComponent(meta.topic(), meta.partition(), meta.offset());
+            if (meta.topic() != null) {
+                position.withComponent(meta.topic(), meta.partition(), meta.offset());
+            }
         }
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/IQv2StoreIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/IQv2StoreIntegrationTest.java
index f534d6d..1c828c7 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/IQv2StoreIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/IQv2StoreIntegrationTest.java
@@ -760,45 +760,50 @@ public class IQv2StoreIntegrationTest {
 
     @Test
     public void verifyStore() {
-        if (storeToTest.global()) {
-            // See KAFKA-13523
-            globalShouldRejectAllQueries();
-        } else {
-            shouldRejectUnknownQuery();
-            shouldCollectExecutionInfo();
-            shouldCollectExecutionInfoUnderFailure();
-
-            if (storeToTest.keyValue()) {
-                if (storeToTest.timestamped()) {
-                    final Function<ValueAndTimestamp<Integer>, Integer> valueExtractor =
-                        ValueAndTimestamp::value;
-                    shouldHandleKeyQuery(2, valueExtractor, 2);
-                    shouldHandleRangeQueries(valueExtractor);
-                } else {
-                    final Function<Integer, Integer> valueExtractor = Function.identity();
-                    shouldHandleKeyQuery(2, valueExtractor, 2);
-                    shouldHandleRangeQueries(valueExtractor);
+        try {
+            if (storeToTest.global()) {
+                // See KAFKA-13523
+                globalShouldRejectAllQueries();
+            } else {
+                shouldRejectUnknownQuery();
+                shouldCollectExecutionInfo();
+                shouldCollectExecutionInfoUnderFailure();
+
+                if (storeToTest.keyValue()) {
+                    if (storeToTest.timestamped()) {
+                        final Function<ValueAndTimestamp<Integer>, Integer> valueExtractor =
+                            ValueAndTimestamp::value;
+                        shouldHandleKeyQuery(2, valueExtractor, 2);
+                        shouldHandleRangeQueries(valueExtractor);
+                    } else {
+                        final Function<Integer, Integer> valueExtractor = Function.identity();
+                        shouldHandleKeyQuery(2, valueExtractor, 2);
+                        shouldHandleRangeQueries(valueExtractor);
+                    }
                 }
-            }
 
-            if (storeToTest.isWindowed()) {
-                if (storeToTest.timestamped()) {
-                    final Function<ValueAndTimestamp<Integer>, Integer> valueExtractor =
+                if (storeToTest.isWindowed()) {
+                    if (storeToTest.timestamped()) {
+                        final Function<ValueAndTimestamp<Integer>, Integer> valueExtractor =
                             ValueAndTimestamp::value;
-                    shouldHandleWindowKeyQueries(valueExtractor);
-                    shouldHandleWindowRangeQueries(valueExtractor);
-                } else {
-                    final Function<Integer, Integer> valueExtractor = Function.identity();
-                    shouldHandleWindowKeyQueries(valueExtractor);
-                    shouldHandleWindowRangeQueries(valueExtractor);
+                        shouldHandleWindowKeyQueries(valueExtractor);
+                        shouldHandleWindowRangeQueries(valueExtractor);
+                    } else {
+                        final Function<Integer, Integer> valueExtractor = Function.identity();
+                        shouldHandleWindowKeyQueries(valueExtractor);
+                        shouldHandleWindowRangeQueries(valueExtractor);
+                    }
                 }
-            }
 
-            if (storeToTest.isSession()) {
-                // Note there's no "timestamped" differentiation here.
-                // Idiosyncratically, SessionStores are _never_ timestamped.
-                shouldHandleSessionKeyQueries();
+                if (storeToTest.isSession()) {
+                    // Note there's no "timestamped" differentiation here.
+                    // Idiosyncratically, SessionStores are _never_ timestamped.
+                    shouldHandleSessionKeyQueries();
+                }
             }
+        } catch (final AssertionError e) {
+            LOG.error("Failed assertion", e);
+            throw e;
         }
     }
 
@@ -1350,7 +1355,7 @@ public class IQv2StoreIntegrationTest {
                                                    final String supplier, final String kind) {
         final String safeTestName =
             IQv2StoreIntegrationTest.class.getName() + "-" + cache + "-" + log + "-" + supplier
-                + "-" + kind;
+                + "-" + kind + "-" + RANDOM.nextInt();
         final Properties config = new Properties();
         config.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE);
         config.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName);
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemoryKeyValueStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemoryKeyValueStoreTest.java
index 13d78ec..f11f854 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemoryKeyValueStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemoryKeyValueStoreTest.java
@@ -221,17 +221,44 @@ public class CachingInMemoryKeyValueStoreTest extends AbstractKeyValueStoreTest
     }
 
     @Test
-    public void shouldMatchPositionAfterPut() {
+    public void shouldMatchPositionAfterPutWithFlushListener() {
+        store.setFlushListener(record -> { }, false);
+        shouldMatchPositionAfterPut();
+    }
+
+    @Test
+    public void shouldMatchPositionAfterPutWithoutFlushListener() {
+        store.setFlushListener(null, false);
+        shouldMatchPositionAfterPut();
+    }
+
+    private void shouldMatchPositionAfterPut() {
         context.setRecordContext(new ProcessorRecordContext(0, 1, 0, "", new RecordHeaders()));
         store.put(bytesKey("key1"), bytesValue("value1"));
         context.setRecordContext(new ProcessorRecordContext(0, 2, 0, "", new RecordHeaders()));
         store.put(bytesKey("key2"), bytesValue("value2"));
-        context.setRecordContext(new ProcessorRecordContext(0, 3, 0, "", new RecordHeaders()));
-        store.put(bytesKey("key3"), bytesValue("value3"));
 
-        final Position expected = Position.fromMap(mkMap(mkEntry("", mkMap(mkEntry(0, 3L)))));
-        final Position actual = store.getPosition();
-        assertEquals(expected, actual);
+        // Position should correspond to the last record's context, not the current context.
+        context.setRecordContext(
+            new ProcessorRecordContext(0, 3, 0, "", new RecordHeaders())
+        );
+
+        assertEquals(
+            Position.fromMap(mkMap(mkEntry("", mkMap(mkEntry(0, 2L))))),
+            store.getPosition()
+        );
+        assertEquals(Position.emptyPosition(), underlyingStore.getPosition());
+
+        store.flush();
+
+        assertEquals(
+            Position.fromMap(mkMap(mkEntry("", mkMap(mkEntry(0, 2L))))),
+            store.getPosition()
+        );
+        assertEquals(
+            Position.fromMap(mkMap(mkEntry("", mkMap(mkEntry(0, 2L))))),
+            underlyingStore.getPosition()
+        );
     }
 
     private byte[] bytesValue(final String value) {
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
index 0de2321..d5aa667 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
@@ -36,6 +36,7 @@ import org.apache.kafka.streams.processor.api.Record;
 import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
 import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
+import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.state.KeyValueIterator;
 import org.apache.kafka.streams.state.SessionStore;
 import org.apache.kafka.test.InternalMockProcessorContext;
@@ -53,6 +54,8 @@ import java.util.List;
 import java.util.Random;
 
 import static java.util.Arrays.asList;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.test.StreamsTestUtils.toList;
 import static org.apache.kafka.test.StreamsTestUtils.verifyKeyValueList;
 import static org.apache.kafka.test.StreamsTestUtils.verifyWindowedKeyValue;
@@ -143,6 +146,46 @@ public class CachingInMemorySessionStoreTest {
     }
 
     @Test
+    public void shouldMatchPositionAfterPutWithFlushListener() {
+        cachingStore.setFlushListener(record -> { }, false);
+        shouldMatchPositionAfterPut();
+    }
+
+    @Test
+    public void shouldMatchPositionAfterPutWithoutFlushListener() {
+        cachingStore.setFlushListener(null, false);
+        shouldMatchPositionAfterPut();
+    }
+
+    private void shouldMatchPositionAfterPut() {
+        context.setRecordContext(new ProcessorRecordContext(0, 1, 0, "", new RecordHeaders()));
+        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
+        context.setRecordContext(new ProcessorRecordContext(0, 2, 0, "", new RecordHeaders()));
+        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
+
+        // Position should correspond to the last record's context, not the current context.
+        context.setRecordContext(
+            new ProcessorRecordContext(0, 3, 0, "", new RecordHeaders())
+        );
+
+        // the caching session store doesn't maintain a separate
+        // position because it never serves queries from the cache
+        assertEquals(Position.emptyPosition(), cachingStore.getPosition());
+        assertEquals(Position.emptyPosition(), underlyingStore.getPosition());
+
+        cachingStore.flush();
+
+        assertEquals(
+            Position.fromMap(mkMap(mkEntry("", mkMap(mkEntry(0, 2L))))),
+            cachingStore.getPosition()
+        );
+        assertEquals(
+            Position.fromMap(mkMap(mkEntry("", mkMap(mkEntry(0, 2L))))),
+            underlyingStore.getPosition()
+        );
+    }
+
+    @Test
     public void shouldPutFetchAllKeysFromCache() {
         cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
         cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes());
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java
index 6a622dc..50fd88a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java
@@ -35,6 +35,7 @@ import org.apache.kafka.streams.processor.api.Record;
 import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
 import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
+import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.state.KeyValueIterator;
 import org.apache.kafka.streams.state.SessionStore;
 import org.apache.kafka.test.InternalMockProcessorContext;
@@ -52,6 +53,8 @@ import java.util.List;
 import java.util.Random;
 
 import static java.util.Arrays.asList;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.test.StreamsTestUtils.toList;
 import static org.apache.kafka.test.StreamsTestUtils.verifyKeyValueList;
 import static org.apache.kafka.test.StreamsTestUtils.verifyWindowedKeyValue;
@@ -80,6 +83,7 @@ public class CachingPersistentSessionStoreTest {
     private SessionStore<Bytes, byte[]> underlyingStore;
     private CachingSessionStore cachingStore;
     private ThreadCache cache;
+    private InternalMockProcessorContext<Object, Object> context;
 
     @Before
     public void before() {
@@ -93,7 +97,7 @@ public class CachingPersistentSessionStoreTest {
         underlyingStore = new RocksDBSessionStore(segmented);
         cachingStore = new CachingSessionStore(underlyingStore, SEGMENT_INTERVAL);
         cache = new ThreadCache(new LogContext("testCache "), MAX_CACHE_SIZE_BYTES, new MockStreamsMetrics(new Metrics()));
-        final InternalMockProcessorContext context =
+        this.context =
             new InternalMockProcessorContext<>(TestUtils.tempDirectory(), null, null, null, cache);
         context.setRecordContext(new ProcessorRecordContext(DEFAULT_TIMESTAMP, 0, 0, TOPIC, new RecordHeaders()));
         cachingStore.init((StateStoreContext) context, cachingStore);
@@ -123,6 +127,45 @@ public class CachingPersistentSessionStoreTest {
             assertFalse(b.hasNext());
         }
     }
+    @Test
+    public void shouldMatchPositionAfterPutWithFlushListener() {
+        cachingStore.setFlushListener(record -> { }, false);
+        shouldMatchPositionAfterPut();
+    }
+
+    @Test
+    public void shouldMatchPositionAfterPutWithoutFlushListener() {
+        cachingStore.setFlushListener(null, false);
+        shouldMatchPositionAfterPut();
+    }
+
+    private void shouldMatchPositionAfterPut() {
+        context.setRecordContext(new ProcessorRecordContext(0, 1, 0, "", new RecordHeaders()));
+        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
+        context.setRecordContext(new ProcessorRecordContext(0, 2, 0, "", new RecordHeaders()));
+        cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes());
+
+        // Position should correspond to the last record's context, not the current context.
+        context.setRecordContext(
+            new ProcessorRecordContext(0, 3, 0, "", new RecordHeaders())
+        );
+
+        // the caching session store doesn't maintain a separate
+        // position because it never serves queries from the cache
+        assertEquals(Position.emptyPosition(), cachingStore.getPosition());
+        assertEquals(Position.emptyPosition(), underlyingStore.getPosition());
+
+        cachingStore.flush();
+
+        assertEquals(
+            Position.fromMap(mkMap(mkEntry("", mkMap(mkEntry(0, 2L))))),
+            cachingStore.getPosition()
+        );
+        assertEquals(
+            Position.fromMap(mkMap(mkEntry("", mkMap(mkEntry(0, 2L))))),
+            underlyingStore.getPosition()
+        );
+    }
 
     @Test
     public void shouldPutFetchAllKeysFromCache() {
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentWindowStoreTest.java
index 2d64a44..3426c3e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentWindowStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentWindowStoreTest.java
@@ -39,6 +39,7 @@ import org.apache.kafka.streams.processor.StateStoreContext;
 import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
 import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
+import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.state.KeyValueIterator;
 import org.apache.kafka.streams.state.StoreBuilder;
 import org.apache.kafka.streams.state.Stores;
@@ -63,6 +64,8 @@ import static java.time.Duration.ofHours;
 import static java.time.Duration.ofMinutes;
 import static java.time.Instant.ofEpochMilli;
 import static java.util.Arrays.asList;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.streams.state.internals.ThreadCacheTest.memoryCacheEntrySize;
 import static org.apache.kafka.test.StreamsTestUtils.toList;
 import static org.apache.kafka.test.StreamsTestUtils.verifyAllWindowedKeyValues;
@@ -260,6 +263,46 @@ public class CachingPersistentWindowStoreTest {
         }
     }
 
+    @Test
+    public void shouldMatchPositionAfterPutWithFlushListener() {
+        cachingStore.setFlushListener(record -> { }, false);
+        shouldMatchPositionAfterPut();
+    }
+
+    @Test
+    public void shouldMatchPositionAfterPutWithoutFlushListener() {
+        cachingStore.setFlushListener(null, false);
+        shouldMatchPositionAfterPut();
+    }
+
+    private void shouldMatchPositionAfterPut() {
+        context.setRecordContext(new ProcessorRecordContext(0, 1, 0, "", new RecordHeaders()));
+        cachingStore.put(bytesKey("key1"), bytesValue("value1"), DEFAULT_TIMESTAMP);
+        context.setRecordContext(new ProcessorRecordContext(0, 2, 0, "", new RecordHeaders()));
+        cachingStore.put(bytesKey("key2"), bytesValue("value2"), DEFAULT_TIMESTAMP);
+
+        // Position should correspond to the last record's context, not the current context.
+        context.setRecordContext(
+            new ProcessorRecordContext(0, 3, 0, "", new RecordHeaders())
+        );
+
+        // the caching window store doesn't maintain a separate
+        // position because it never serves queries from the cache
+        assertEquals(Position.emptyPosition(), cachingStore.getPosition());
+        assertEquals(Position.emptyPosition(), underlyingStore.getPosition());
+
+        cachingStore.flush();
+
+        assertEquals(
+            Position.fromMap(mkMap(mkEntry("", mkMap(mkEntry(0, 2L))))),
+            cachingStore.getPosition()
+        );
+        assertEquals(
+            Position.fromMap(mkMap(mkEntry("", mkMap(mkEntry(0, 2L))))),
+            underlyingStore.getPosition()
+        );
+    }
+
     private void verifyKeyValue(final KeyValue<Long, byte[]> next,
                                 final long expectedKey,
                                 final String expectedValue) {