Posted to commits@kafka.apache.org by vv...@apache.org on 2022/01/27 15:27:16 UTC

[kafka] branch trunk updated: KAFKA-13605: checkpoint position in state stores (#11676)

This is an automated email from the ASF dual-hosted git repository.

vvcephei pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 1a21892  KAFKA-13605: checkpoint position in state stores (#11676)
1a21892 is described below

commit 1a218926636d2d0f31b4ad402f8d26c2ad8d9838
Author: Patrick Stuedi <ps...@confluent.io>
AuthorDate: Thu Jan 27 16:25:04 2022 +0100

    KAFKA-13605: checkpoint position in state stores (#11676)
    
    There are cases in which a state store has neither built up an in-memory position nor gone through the state restoration process. If a store is persistent (e.g., RocksDB) and we stop and restart Streams, neither of those continuity mechanisms is available.
    
    This patch:
    * adds a test to verify that all stores correctly recover their position after a restart
    * implements storage and recovery of the position for persistent stores alongside on-disk state
    
    Reviewers: Vicky Papavasileiou <vp...@confluent.io>, Matthias J. Sax <mj...@apache.org>, Guozhang Wang <gu...@apache.org>, John Roesler <vv...@apache.org>
---
 .../kafka/streams/processor/CommitCallback.java    |  31 +
 .../streams/processor/StateRestoreListener.java    |   3 +-
 .../apache/kafka/streams/processor/StateStore.java |  20 +-
 .../kafka/streams/processor/StateStoreContext.java |  21 +
 .../internals/AbstractProcessorContext.java        |  10 +-
 .../ChangelogRecordDeserializationHelper.java      |  10 +-
 .../internals/GlobalStateManagerImpl.java          |   5 +-
 .../processor/internals/ProcessorStateManager.java |  28 +-
 .../streams/processor/internals/StateManager.java  |   5 +-
 .../AbstractRocksDBSegmentedBytesStore.java        |  23 +-
 .../state/internals/CachingKeyValueStore.java      |   3 +-
 .../state/internals/InMemoryKeyValueStore.java     |  26 +-
 .../state/internals/InMemorySessionStore.java      |  26 +-
 .../state/internals/InMemoryWindowStore.java       |  30 +-
 .../streams/state/internals/MemoryLRUCache.java    |  31 +-
 .../state/internals/RocksDBSessionStore.java       |  13 +-
 .../streams/state/internals/RocksDBStore.java      |  82 +--
 .../state/internals/RocksDBWindowStore.java        |  13 +-
 .../streams/state/internals/StoreQueryUtils.java   |  46 ++
 .../PositionRestartIntegrationTest.java            | 692 +++++++++++++++++++++
 .../internals/GlobalStateManagerImplTest.java      |  56 +-
 .../internals/ProcessorStateManagerTest.java       | 236 +++++--
 .../processor/internals/StateManagerStub.java      |   4 +-
 .../processor/internals/StreamTaskTest.java        |   8 +-
 .../apache/kafka/test/GlobalStateManagerStub.java  |   5 +-
 .../kafka/test/InternalMockProcessorContext.java   |   6 +-
 .../test/MockInternalNewProcessorContext.java      |  11 +-
 .../kafka/test/MockInternalProcessorContext.java   |  12 +-
 .../apache/kafka/test/NoOpProcessorContext.java    |   4 +-
 .../processor/api/MockProcessorContext.java        |  11 +-
 30 files changed, 1281 insertions(+), 190 deletions(-)
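
In brief, the mechanism works in three steps: on init, a persistent store reads its last checkpointed Position from a per-store ".position" file; during restoration, it merges Position data carried in changelog record headers; and on successful task commit, the new CommitCallback writes the current Position back to disk. Below is a minimal sketch of that read/write cycle using the helpers added in this patch; the store name and state directory are hypothetical:

    import org.apache.kafka.streams.query.Position;
    import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
    import org.apache.kafka.streams.state.internals.StoreQueryUtils;

    import java.io.File;

    public class PositionCheckpointSketch {
        public static void main(final String[] args) {
            final File stateDir = new File("/tmp/kafka-streams-sketch"); // hypothetical state dir
            stateDir.mkdirs();
            final OffsetCheckpoint checkpoint =
                new OffsetCheckpoint(new File(stateDir, "my-store.position")); // hypothetical store name

            // On init: recover the position persisted by the last successful commit
            // (an empty Position if no checkpoint file exists yet).
            final Position position = StoreQueryUtils.readPositionFromCheckpoint(checkpoint);

            // During processing/restoration the store advances its position...
            position.withComponent("input-topic", 0, 42L);

            // On successful commit: persist the position alongside the on-disk state.
            StoreQueryUtils.checkpointPosition(checkpoint, position);
        }
    }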

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/CommitCallback.java b/streams/src/main/java/org/apache/kafka/streams/processor/CommitCallback.java
new file mode 100644
index 0000000..581173d
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/CommitCallback.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor;
+
+import org.apache.kafka.common.annotation.InterfaceStability.Evolving;
+
+import java.io.IOException;
+
+/**
+ * Stores can register this callback to be notified upon successful commit.
+ */
+@Evolving
+@FunctionalInterface
+public interface CommitCallback {
+    void onCommit() throws IOException;
+}
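
For illustration, a hedged sketch of how a custom persistent store might use this callback together with the new three-argument StateStoreContext#register (added further below): read the position on init, checkpoint it on commit. The file name convention mirrors the built-in stores in this patch; the store class and its map are hypothetical:

    import org.apache.kafka.common.utils.Bytes;
    import org.apache.kafka.streams.processor.StateStore;
    import org.apache.kafka.streams.processor.StateStoreContext;
    import org.apache.kafka.streams.query.Position;
    import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
    import org.apache.kafka.streams.state.internals.StoreQueryUtils;

    import java.io.File;
    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical custom store, reduced to its registration logic.
    public class MyPositionAwareStore {
        private final Map<Bytes, byte[]> data = new HashMap<>();
        private Position position = Position.emptyPosition();
        private OffsetCheckpoint positionCheckpoint;

        public String name() {
            return "my-store";
        }

        public void init(final StateStoreContext context, final StateStore root) {
            final File checkpointFile = new File(context.stateDir(), name() + ".position");
            positionCheckpoint = new OffsetCheckpoint(checkpointFile);
            position = StoreQueryUtils.readPositionFromCheckpoint(positionCheckpoint);

            context.register(
                root,
                (key, value) -> data.put(Bytes.wrap(key), value), // restore callback
                () -> StoreQueryUtils.checkpointPosition(positionCheckpoint, position) // commit callback
            );
        }
    }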
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/StateRestoreListener.java b/streams/src/main/java/org/apache/kafka/streams/processor/StateRestoreListener.java
index 210a5de..6ba794f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/StateRestoreListener.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/StateRestoreListener.java
@@ -31,7 +31,8 @@ import org.apache.kafka.common.TopicPartition;
  *
  * Note that this listener is only registered at the per-client level and users can rely on the {@code storeName}
  * parameter to define specific monitoring for different {@link StateStore}s. There is another
- * {@link StateRestoreCallback} interface which is registered via the {@link ProcessorContext#register(StateStore, StateRestoreCallback)}
+ * {@link StateRestoreCallback} interface which is registered via the
+ * {@link StateStoreContext#register(StateStore, StateRestoreCallback, CommitCallback)}
  * function per-store, and it is used to apply the fetched changelog records into the local state store during restoration.
  * These two interfaces serve different restoration purposes and users should not try to implement both of them in a single
  * class during state store registration.
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/StateStore.java b/streams/src/main/java/org/apache/kafka/streams/processor/StateStore.java
index 647fa8e..8504e88 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/StateStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/StateStore.java
@@ -42,7 +42,7 @@ import org.apache.kafka.streams.query.QueryResult;
  * Furthermore, Kafka Streams relies on using the store name as store directory name to perform internal cleanup tasks.
  * <p>
  * This interface does not specify any query capabilities, which, of course,
- * would be query engine specific. Instead it just specifies the minimum
+ * would be query engine specific. Instead, it just specifies the minimum
  * functionality required to reload a storage engine from its changelog as well
  * as basic lifecycle management.
  */
@@ -83,7 +83,7 @@ public interface StateStore {
      * Initializes this state store.
      * <p>
      * The implementation of this function must register the root store in the context via the
-     * {@link StateStoreContext#register(StateStore, StateRestoreCallback)} function, where the
+     * {@link StateStoreContext#register(StateStore, StateRestoreCallback, CommitCallback)} function, where the
      * first {@link StateStore} parameter should always be the passed-in {@code root} object, and
      * the second parameter should be an object of user's implementation
      * of the {@link StateRestoreCallback} interface used for restoring the state store from the changelog.
@@ -149,16 +149,20 @@ public interface StateStore {
      */
     @Evolving
     default <R> QueryResult<R> query(
-        Query<R> query,
-        PositionBound positionBound,
-        QueryConfig config) {
+        final Query<R> query,
+        final PositionBound positionBound,
+        final QueryConfig config) {
         // If a store doesn't implement a query handler, then all queries are unknown.
         return QueryResult.forUnknownQueryType(query, this);
     }
 
     /**
-     * Returns the position the state store is at
-     * @return
+     * Returns the position the state store is at with respect to its input topic/partitions.
      */
-    Position getPosition();
+    @Evolving
+    default Position getPosition() {
+        throw new UnsupportedOperationException(
+            "getPosition is not implemented by this StateStore (" + getClass() + ")"
+        );
+    }
 }
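
With this default in place, getPosition becomes opt-in: a store that does not track a position surfaces a clear UnsupportedOperationException instead of returning something misleading, while a store that does track one simply overrides the method. A minimal hedged sketch of the override:

    import org.apache.kafka.streams.query.Position;

    // Sketch of a store overriding the new default; the class is hypothetical
    // and omits everything except the position plumbing.
    public class PositionTrackingStoreSketch {
        private final Position position = Position.emptyPosition();

        public Position getPosition() {
            return position;
        }
    }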
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/StateStoreContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/StateStoreContext.java
index f6f1446..35e3e58 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/StateStoreContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/StateStoreContext.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor;
 
+import org.apache.kafka.common.annotation.InterfaceStability.Evolving;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.streams.StreamsMetrics;
 import org.apache.kafka.streams.errors.StreamsException;
@@ -97,6 +98,26 @@ public interface StateStoreContext {
                   final StateRestoreCallback stateRestoreCallback);
 
     /**
+     * Registers and possibly restores the specified storage engine.
+     *
+     * @param store the storage engine
+     * @param stateRestoreCallback the restoration callback logic for log-backed state stores upon restart
+     * @param commitCallback a callback to be invoked upon successful task commit, in case the store
+     *                           needs to perform any state tracking when the task is known to be in
+     *                           a consistent state. If the store has no such state to track, it may
+     *                           use {@link StateStoreContext#register(StateStore, StateRestoreCallback)} instead.
+     *                           Persistent stores provided by Kafka Streams use this method to save
+     *                           their Position information to local disk, for example.
+     *
+     * @throws IllegalStateException if the store is registered after initialization has already finished
+     * @throws StreamsException if the store's change log does not contain the partition
+     */
+    @Evolving
+    void register(final StateStore store,
+                  final StateRestoreCallback stateRestoreCallback,
+                  final CommitCallback commitCallback);
+
+    /**
      * Returns all the application config properties as key/value pairs.
      *
      * <p> The config properties are defined in the {@link org.apache.kafka.streams.StreamsConfig}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
index 1bc8034..4140f12 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
@@ -20,6 +20,7 @@ import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
@@ -112,11 +113,18 @@ public abstract class AbstractProcessorContext<KOut, VOut> implements InternalPr
     @Override
     public void register(final StateStore store,
                          final StateRestoreCallback stateRestoreCallback) {
+        register(store, stateRestoreCallback, () -> { });
+    }
+
+    @Override
+    public void register(final StateStore store,
+                         final StateRestoreCallback stateRestoreCallback,
+                         final CommitCallback checkpoint) {
         if (initialized) {
             throw new IllegalStateException("Can only create state stores during initialization.");
         }
         Objects.requireNonNull(store, "store must not be null");
-        stateManager().registerStore(store, stateRestoreCallback);
+        stateManager().registerStore(store, stateRestoreCallback, checkpoint);
     }
 
     @Override
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRecordDeserializationHelper.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRecordDeserializationHelper.java
index 32c9485..306d77b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRecordDeserializationHelper.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRecordDeserializationHelper.java
@@ -40,19 +40,18 @@ public class ChangelogRecordDeserializationHelper {
     public static final RecordHeader CHANGELOG_VERSION_HEADER_RECORD_CONSISTENCY = new RecordHeader(
             CHANGELOG_VERSION_HEADER_KEY, V_0_CHANGELOG_VERSION_HEADER_VALUE);
 
-    public static Position applyChecksAndUpdatePosition(
+    public static void applyChecksAndUpdatePosition(
             final ConsumerRecord<byte[], byte[]> record,
             final boolean consistencyEnabled,
             final Position position
     ) {
         if (!consistencyEnabled) {
-            return position;
+            return;
         }
-        Position restoredPosition = Position.emptyPosition();
         final Header versionHeader = record.headers().lastHeader(
                 ChangelogRecordDeserializationHelper.CHANGELOG_VERSION_HEADER_KEY);
         if (versionHeader == null) {
-            return position;
+            return;
         } else {
             switch (versionHeader.value()[0]) {
                 case 0:
@@ -61,14 +60,13 @@ public class ChangelogRecordDeserializationHelper {
                         throw new StreamsException("This should not happen. Consistency is enabled but the changelog "
                                 + "contains records without consistency information.");
                     }
-                    restoredPosition = position.merge(PositionSerde.deserialize(ByteBuffer.wrap(vectorHeader.value())));
+                    position.merge(PositionSerde.deserialize(ByteBuffer.wrap(vectorHeader.value())));
                     break;
                 default:
                     log.warn("Changelog records have been encoded using a larger version than this server understands. " +
                             "Please upgrade your server.");
             }
         }
-        return restoredPosition;
     }
 
 
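The helper above relies on Position#merge being monotone: merging keeps the maximum offset per topic/partition component (and, as the change above shows, mutates the receiver in place), so applying changelog headers can only move the restored position forward. A small sketch of those semantics, with made-up topic names and offsets:

    import org.apache.kafka.streams.query.Position;

    public class PositionMergeSketch {
        public static void main(final String[] args) {
            final Position local = Position.emptyPosition()
                .withComponent("input-topic", 0, 5L);
            final Position fromHeader = Position.emptyPosition()
                .withComponent("input-topic", 0, 3L)
                .withComponent("input-topic", 1, 7L);

            // merge mutates `local`, keeping the max offset per component
            local.merge(fromHeader);
            System.out.println(local.getPartitionPositions("input-topic")); // {0=5, 1=7}
        }
    }
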
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
index 090621d..69a94d0 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
@@ -30,6 +30,7 @@ import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.processor.StateStore;
@@ -167,7 +168,9 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
     }
 
     @Override
-    public void registerStore(final StateStore store, final StateRestoreCallback stateRestoreCallback) {
+    public void registerStore(final StateStore store,
+                              final StateRestoreCallback stateRestoreCallback,
+                              final CommitCallback ignored) {
         log.info("Restoring state for global store {}", store.name());
 
         // TODO (KAFKA-12887): we should not trigger user's exception handler for illegal-argument but always
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
index c2916da..3c8c40f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
@@ -24,6 +24,7 @@ import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.processor.StateStore;
@@ -76,6 +77,8 @@ public class ProcessorStateManager implements StateManager {
         // could be used for both active restoration and standby
         private final StateRestoreCallback restoreCallback;
 
+        private final CommitCallback commitCallback;
+
         // record converters used for restoration and standby
         private final RecordConverter recordConverter;
 
@@ -95,8 +98,10 @@ public class ProcessorStateManager implements StateManager {
         // corrupted state store should not be included in checkpointing
         private boolean corrupted;
 
-        private StateStoreMetadata(final StateStore stateStore) {
+        private StateStoreMetadata(final StateStore stateStore,
+                                   final CommitCallback commitCallback) {
             this.stateStore = stateStore;
+            this.commitCallback = commitCallback;
             this.restoreCallback = null;
             this.recordConverter = null;
             this.changelogPartition = null;
@@ -107,6 +112,7 @@ public class ProcessorStateManager implements StateManager {
         private StateStoreMetadata(final StateStore stateStore,
                                    final TopicPartition changelogPartition,
                                    final StateRestoreCallback restoreCallback,
+                                   final CommitCallback commitCallback,
                                    final RecordConverter recordConverter) {
             if (restoreCallback == null) {
                 throw new IllegalStateException("Log enabled store should always provide a restore callback upon registration");
@@ -115,6 +121,7 @@ public class ProcessorStateManager implements StateManager {
             this.stateStore = stateStore;
             this.changelogPartition = changelogPartition;
             this.restoreCallback = restoreCallback;
+            this.commitCallback = commitCallback;
             this.recordConverter = recordConverter;
             this.offset = null;
         }
@@ -307,7 +314,9 @@ public class ProcessorStateManager implements StateManager {
     }
 
     @Override
-    public void registerStore(final StateStore store, final StateRestoreCallback stateRestoreCallback) {
+    public void registerStore(final StateStore store,
+                              final StateRestoreCallback stateRestoreCallback,
+                              final CommitCallback commitCallback) {
         final String storeName = store.name();
 
         // TODO (KAFKA-12887): we should not trigger user's exception handler for illegal-argument but always
@@ -333,8 +342,9 @@ public class ProcessorStateManager implements StateManager {
                 store,
                 getStorePartition(storeName),
                 stateRestoreCallback,
+                commitCallback,
                 converterForStore(store)) :
-            new StateStoreMetadata(store);
+            new StateStoreMetadata(store, commitCallback);
 
         // register the store first, so that if later an exception is thrown then eventually while we call `close`
         // on the state manager this state store would be closed as well
@@ -598,6 +608,18 @@ public class ProcessorStateManager implements StateManager {
         // checkpoint those stores that are only logged and persistent to the checkpoint file
         final Map<TopicPartition, Long> checkpointingOffsets = new HashMap<>();
         for (final StateStoreMetadata storeMetadata : stores.values()) {
+            if (storeMetadata.commitCallback != null && !storeMetadata.corrupted) {
+                try {
+                    storeMetadata.commitCallback.onCommit();
+                } catch (final IOException e) {
+                    throw new ProcessorStateException(
+                            format("%sException caught while trying to checkpoint store, " +
+                                    "changelog partition %s", logPrefix, storeMetadata.changelogPartition),
+                            e
+                    );
+                }
+            }
+
             // store is logged, persistent, not corrupted, and has a valid current offset
             if (storeMetadata.changelogPartition != null &&
                 storeMetadata.stateStore.persistent() &&
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManager.java
index ad5c3cb..e08f42b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManager.java
@@ -18,6 +18,7 @@ package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 
@@ -34,7 +35,9 @@ public interface StateManager {
      * (e.g., when it conflicts with the names of internal topics, like the checkpoint file name)
      * @throws StreamsException if the store's change log does not contain the partition
      */
-    void registerStore(final StateStore store, final StateRestoreCallback stateRestoreCallback);
+    void registerStore(final StateStore store,
+                       final StateRestoreCallback stateRestoreCallback,
+                       final CommitCallback checkpoint);
 
     StateStore getStore(final String name);
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java
index 40fa5a8..bbe8c54 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java
@@ -38,6 +38,7 @@ import org.rocksdb.WriteBatch;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.File;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
@@ -59,7 +60,7 @@ public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements Se
     private long observedStreamTime = ConsumerRecord.NO_TIMESTAMP;
     private boolean consistencyEnabled = false;
     private Position position;
-
+    protected OffsetCheckpoint positionCheckpoint;
     private volatile boolean open;
 
     AbstractRocksDBSegmentedBytesStore(final String name,
@@ -70,7 +71,6 @@ public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements Se
         this.metricScope = metricScope;
         this.keySchema = keySchema;
         this.segments = segments;
-        this.position = Position.emptyPosition();
     }
 
     @Override
@@ -264,8 +264,16 @@ public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements Se
 
         segments.openExisting(this.context, observedStreamTime);
 
+        final File positionCheckpointFile = new File(context.stateDir(), name() + ".position");
+        this.positionCheckpoint = new OffsetCheckpoint(positionCheckpointFile);
+        this.position = StoreQueryUtils.readPositionFromCheckpoint(positionCheckpoint);
+
         // register and possibly restore the state from the logs
-        context.register(root, (RecordBatchingStateRestoreCallback) this::restoreAllInternal);
+        stateStoreContext.register(
+            root,
+            (RecordBatchingStateRestoreCallback) this::restoreAllInternal,
+            () -> StoreQueryUtils.checkpointPosition(positionCheckpoint, position)
+        );
 
         open = true;
 
@@ -277,8 +285,8 @@ public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements Se
 
     @Override
     public void init(final StateStoreContext context, final StateStore root) {
-        init(StoreToProcessorContextAdapter.adapt(context), root);
         this.stateStoreContext = context;
+        init(StoreToProcessorContextAdapter.adapt(context), root);
     }
 
     @Override
@@ -336,8 +344,11 @@ public class AbstractRocksDBSegmentedBytesStore<S extends Segment> implements Se
             final long segmentId = segments.segmentId(timestamp);
             final S segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
             if (segment != null) {
-                position = ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(
-                        record, consistencyEnabled, position);
+                ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(
+                    record,
+                    consistencyEnabled,
+                    position
+                );
                 try {
                     final WriteBatch batch = writeBatchMap.computeIfAbsent(segment, s -> new WriteBatch());
                     segment.addToBatch(new KeyValue<>(record.key(), record.value()), batch);
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
index c01feac..1d08d20 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
@@ -63,7 +63,7 @@ public class CachingKeyValueStore
     private InternalProcessorContext<?, ?> context;
     private Thread streamThread;
     private final ReadWriteLock lock = new ReentrantReadWriteLock();
-    private Position position;
+    private final Position position;
     private final boolean timestampedSchema;
 
     @FunctionalInterface
@@ -94,6 +94,7 @@ public class CachingKeyValueStore
         this.timestampedSchema = timestampedSchema;
     }
 
+    @SuppressWarnings("deprecation") // This can be removed when it's removed from the interface.
     @Deprecated
     @Override
     public void init(final ProcessorContext context,
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueStore.java
index 3c86fbe..98f377d 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueStore.java
@@ -16,12 +16,16 @@
  */
 package org.apache.kafka.streams.state.internals;
 
+import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.ChangelogRecordDeserializationHelper;
+import org.apache.kafka.streams.processor.internals.RecordBatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.internals.StoreToProcessorContextAdapter;
 import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.query.PositionBound;
@@ -40,6 +44,8 @@ import java.util.Set;
 import java.util.TreeMap;
 import java.util.TreeSet;
 
+import static org.apache.kafka.streams.StreamsConfig.InternalConfig.IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED;
+
 public class InMemoryKeyValueStore implements KeyValueStore<Bytes, byte[]> {
 
     private static final Logger LOG = LoggerFactory.getLogger(InMemoryKeyValueStore.class);
@@ -64,8 +70,25 @@ public class InMemoryKeyValueStore implements KeyValueStore<Bytes, byte[]> {
     public void init(final ProcessorContext context,
                      final StateStore root) {
         if (root != null) {
+            final boolean consistencyEnabled = StreamsConfig.InternalConfig.getBoolean(
+                context.appConfigs(),
+                IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED,
+                false
+            );
             // register the store
-            context.register(root, (key, value) -> put(Bytes.wrap(key), value));
+            context.register(
+                root,
+                (RecordBatchingStateRestoreCallback) records -> {
+                    for (final ConsumerRecord<byte[], byte[]> record : records) {
+                        put(Bytes.wrap(record.key()), record.value());
+                        ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(
+                            record,
+                            consistencyEnabled,
+                            position
+                        );
+                    }
+                }
+            );
         }
 
         open = true;
@@ -88,6 +111,7 @@ public class InMemoryKeyValueStore implements KeyValueStore<Bytes, byte[]> {
         return open;
     }
 
+    @Override
     public Position getPosition() {
         return position;
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
index 84dc1e3..97984dd 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java
@@ -20,12 +20,15 @@ import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.kstream.internals.SessionWindow;
 import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.ChangelogRecordDeserializationHelper;
 import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
+import org.apache.kafka.streams.processor.internals.RecordBatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.internals.StoreToProcessorContextAdapter;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
@@ -49,6 +52,8 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentNavigableMap;
 import java.util.concurrent.ConcurrentSkipListMap;
 
+import static org.apache.kafka.streams.StreamsConfig.InternalConfig.IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED;
+
 public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
 
     private static final Logger LOG = LoggerFactory.getLogger(InMemorySessionStore.class);
@@ -111,7 +116,24 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
         }
 
         if (root != null) {
-            context.register(root, (key, value) -> put(SessionKeySchema.from(Bytes.wrap(key)), value));
+            final boolean consistencyEnabled = StreamsConfig.InternalConfig.getBoolean(
+                context.appConfigs(),
+                IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED,
+                false
+            );
+            context.register(
+                root,
+                (RecordBatchingStateRestoreCallback) records -> {
+                    for (final ConsumerRecord<byte[], byte[]> record : records) {
+                        put(SessionKeySchema.from(Bytes.wrap(record.key())), record.value());
+                        ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(
+                            record,
+                            consistencyEnabled,
+                            position
+                        );
+                    }
+                }
+            );
         }
         open = true;
     }
@@ -119,8 +141,8 @@ public class InMemorySessionStore implements SessionStore<Bytes, byte[]> {
     @Override
     public void init(final StateStoreContext context,
                      final StateStore root) {
-        init(StoreToProcessorContextAdapter.adapt(context), root);
         this.stateStoreContext = context;
+        init(StoreToProcessorContextAdapter.adapt(context), root);
     }
 
     @Override
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowStore.java
index 57db087..5122789 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowStore.java
@@ -20,12 +20,15 @@ import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.kstream.internals.TimeWindow;
 import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.ChangelogRecordDeserializationHelper;
 import org.apache.kafka.streams.processor.internals.ProcessorContextUtils;
+import org.apache.kafka.streams.processor.internals.RecordBatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.internals.StoreToProcessorContextAdapter;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
@@ -50,6 +53,7 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentNavigableMap;
 import java.util.concurrent.ConcurrentSkipListMap;
 
+import static org.apache.kafka.streams.StreamsConfig.InternalConfig.IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED;
 import static org.apache.kafka.streams.state.internals.WindowKeySchema.extractStoreKeyBytes;
 import static org.apache.kafka.streams.state.internals.WindowKeySchema.extractStoreTimestamp;
 
@@ -111,8 +115,28 @@ public class InMemoryWindowStore implements WindowStore<Bytes, byte[]> {
         );
 
         if (root != null) {
-            context.register(root, (key, value) ->
-                put(Bytes.wrap(extractStoreKeyBytes(key)), value, extractStoreTimestamp(key)));
+            final boolean consistencyEnabled = StreamsConfig.InternalConfig.getBoolean(
+                context.appConfigs(),
+                IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED,
+                false
+            );
+            context.register(
+                root,
+                (RecordBatchingStateRestoreCallback) records -> {
+                    for (final ConsumerRecord<byte[], byte[]> record : records) {
+                        put(
+                            Bytes.wrap(extractStoreKeyBytes(record.key())),
+                            record.value(),
+                            extractStoreTimestamp(record.key())
+                        );
+                        ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(
+                            record,
+                            consistencyEnabled,
+                            position
+                        );
+                    }
+                }
+            );
         }
         open = true;
     }
@@ -120,8 +144,8 @@ public class InMemoryWindowStore implements WindowStore<Bytes, byte[]> {
     @Override
     public void init(final StateStoreContext context,
                      final StateStore root) {
-        init(StoreToProcessorContextAdapter.adapt(context), root);
         this.stateStoreContext = context;
+        init(StoreToProcessorContextAdapter.adapt(context), root);
     }
 
     @Override
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MemoryLRUCache.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MemoryLRUCache.java
index af23140..4e6c1e8 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/MemoryLRUCache.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MemoryLRUCache.java
@@ -16,12 +16,16 @@
  */
 package org.apache.kafka.streams.state.internals;
 
+import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.ChangelogRecordDeserializationHelper;
+import org.apache.kafka.streams.processor.internals.RecordBatchingStateRestoreCallback;
 import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.state.KeyValueIterator;
 import org.apache.kafka.streams.state.KeyValueStore;
@@ -31,6 +35,8 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
+import static org.apache.kafka.streams.StreamsConfig.InternalConfig.IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED;
+
 /**
  * An in-memory LRU cache store based on HashSet and HashMap.
  */
@@ -93,12 +99,27 @@ public class MemoryLRUCache implements KeyValueStore<Bytes, byte[]> {
 
     @Override
     public void init(final StateStoreContext context, final StateStore root) {
+        final boolean consistencyEnabled = StreamsConfig.InternalConfig.getBoolean(
+            context.appConfigs(),
+            IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED,
+            false
+        );
         // register the store
-        context.register(root, (key, value) -> {
-            restoring = true;
-            put(Bytes.wrap(key), value);
-            restoring = false;
-        });
+        context.register(
+            root,
+            (RecordBatchingStateRestoreCallback) records -> {
+                restoring = true;
+                for (final ConsumerRecord<byte[], byte[]> record : records) {
+                    put(Bytes.wrap(record.key()), record.value());
+                    ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(
+                        record,
+                        consistencyEnabled,
+                        position
+                    );
+                }
+                restoring = false;
+            }
+        );
         this.context = context;
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSessionStore.java
index aad425d..7c72b26 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSessionStore.java
@@ -20,7 +20,6 @@ import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
-import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.query.PositionBound;
 import org.apache.kafka.streams.query.Query;
 import org.apache.kafka.streams.query.QueryConfig;
@@ -33,26 +32,19 @@ public class RocksDBSessionStore
     extends WrappedStateStore<SegmentedBytesStore, Object, Object>
     implements SessionStore<Bytes, byte[]> {
 
-    private final Position position;
     private StateStoreContext stateStoreContext;
 
     RocksDBSessionStore(final SegmentedBytesStore bytesStore) {
         super(bytesStore);
-        this.position = Position.emptyPosition();
     }
 
     @Override
     public void init(final StateStoreContext context, final StateStore root) {
-        super.init(context, root);
+        wrapped().init(context, root);
         this.stateStoreContext = context;
     }
 
     @Override
-    public Position getPosition() {
-        return position;
-    }
-
-    @Override
     public <R> QueryResult<R> query(final Query<R> query,
                                     final PositionBound positionBound,
                                     final QueryConfig config) {
@@ -62,7 +54,7 @@ public class RocksDBSessionStore
             positionBound,
             config,
             this,
-            position,
+            getPosition(),
             stateStoreContext
         );
     }
@@ -158,6 +150,5 @@ public class RocksDBSessionStore
     @Override
     public void put(final Windowed<Bytes> sessionKey, final byte[] aggregate) {
         wrapped().put(SessionKeySchema.toBinary(sessionKey), aggregate);
-        StoreQueryUtils.updatePosition(position, stateStoreContext);
     }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java
index 93ac778..919c440 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java
@@ -116,8 +116,8 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]>, BatchWritingS
 
     protected volatile boolean open = false;
     protected StateStoreContext context;
-    // VisibleForTesting
     protected Position position;
+    private OffsetCheckpoint positionCheckpoint;
 
     // VisibleForTesting
     public RocksDBStore(final String name,
@@ -131,7 +131,44 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]>, BatchWritingS
         this.name = name;
         this.parentDir = parentDir;
         this.metricsRecorder = metricsRecorder;
-        this.position = Position.emptyPosition();
+    }
+
+    @Deprecated
+    @Override
+    public void init(final ProcessorContext context,
+                     final StateStore root) {
+        if (context instanceof StateStoreContext) {
+            init((StateStoreContext) context, root);
+        } else {
+            throw new UnsupportedOperationException(
+                "Use RocksDBStore#init(StateStoreContext, StateStore) instead."
+            );
+        }
+    }
+
+    @Override
+    public void init(final StateStoreContext context,
+                     final StateStore root) {
+        // open the DB dir
+        metricsRecorder.init(getMetricsImpl(context), context.taskId());
+        openDB(context.appConfigs(), context.stateDir());
+
+        final File positionCheckpointFile = new File(context.stateDir(), name() + ".position");
+        this.positionCheckpoint = new OffsetCheckpoint(positionCheckpointFile);
+        this.position = StoreQueryUtils.readPositionFromCheckpoint(positionCheckpoint);
+
+        // value getter should always read directly from rocksDB
+        // since it is only for values that are already flushed
+        this.context = context;
+        context.register(
+            root,
+            (RecordBatchingStateRestoreCallback) this::restoreBatch,
+            () -> StoreQueryUtils.checkpointPosition(positionCheckpoint, position)
+        );
+        consistencyEnabled = StreamsConfig.InternalConfig.getBoolean(
+            context.appConfigs(),
+            IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED,
+            false);
     }
 
     @SuppressWarnings("unchecked")
@@ -161,7 +198,7 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]>, BatchWritingS
         userSpecifiedOptions.setInfoLogLevel(InfoLogLevel.ERROR_LEVEL);
         // this is the recommended way to increase parallelism in RocksDb
         // note that the current implementation of setIncreaseParallelism affects the number
-        // of compaction threads but not flush threads (the latter remains one). Also
+        // of compaction threads but not flush threads (the latter remains one). Also,
         // the parallelism value needs to be at least two because of the code in
         // https://github.com/facebook/rocksdb/blob/62ad0a9b19f0be4cefa70b6b32876e764b7f3c11/util/options.cc#L580
         // subtracts one from the value passed to determine the number of compaction threads
@@ -243,36 +280,6 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]>, BatchWritingS
         }
     }
 
-    @Deprecated
-    @Override
-    public void init(final ProcessorContext context,
-                     final StateStore root) {
-        // open the DB dir
-        metricsRecorder.init(getMetricsImpl(context), context.taskId());
-        openDB(context.appConfigs(), context.stateDir());
-
-        // value getter should always read directly from rocksDB
-        // since it is only for values that are already flushed
-        context.register(root, (RecordBatchingStateRestoreCallback) this::restoreBatch);
-    }
-
-    @Override
-    public void init(final StateStoreContext context,
-                     final StateStore root) {
-        // open the DB dir
-        metricsRecorder.init(getMetricsImpl(context), context.taskId());
-        openDB(context.appConfigs(), context.stateDir());
-
-        // value getter should always read directly from rocksDB
-        // since it is only for values that are already flushed
-        this.context = context;
-        context.register(root, (RecordBatchingStateRestoreCallback) this::restoreBatch);
-        consistencyEnabled = StreamsConfig.InternalConfig.getBoolean(
-                context.appConfigs(),
-                IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED,
-                false);
-    }
-
     @Override
     public String name() {
         return name;
@@ -327,7 +334,6 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]>, BatchWritingS
     }
 
     @Override
-    @SuppressWarnings("unchecked")
     public <R> QueryResult<R> query(
         final Query<R> query,
         final PositionBound positionBound,
@@ -472,7 +478,6 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]>, BatchWritingS
         return value < 0;
     }
 
-    @SuppressWarnings("unchecked")
     @Override
     public synchronized void flush() {
         if (db == null) {
@@ -724,8 +729,11 @@ public class RocksDBStore implements KeyValueStore<Bytes, byte[]>, BatchWritingS
         try (final WriteBatch batch = new WriteBatch()) {
             final List<KeyValue<byte[], byte[]>> keyValues = new ArrayList<>();
             for (final ConsumerRecord<byte[], byte[]> record : records) {
-                position = ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(
-                        record, consistencyEnabled, position);
+                ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(
+                    record,
+                    consistencyEnabled,
+                    position
+                );
                 // If version headers are not present or version is V0
                 keyValues.add(new KeyValue<>(record.key(), record.value()));
             }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBWindowStore.java
index 2f50eed..61212048 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBWindowStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBWindowStore.java
@@ -20,7 +20,6 @@ import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
-import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.query.PositionBound;
 import org.apache.kafka.streams.query.Query;
 import org.apache.kafka.streams.query.QueryConfig;
@@ -38,7 +37,6 @@ public class RocksDBWindowStore
 
     private int seqnum = 0;
 
-    private final Position position;
     private StateStoreContext stateStoreContext;
 
     RocksDBWindowStore(final SegmentedBytesStore bytesStore,
@@ -47,27 +45,20 @@ public class RocksDBWindowStore
         super(bytesStore);
         this.retainDuplicates = retainDuplicates;
         this.windowSize = windowSize;
-        this.position = Position.emptyPosition();
     }
 
     @Override
     public void init(final StateStoreContext context, final StateStore root) {
-        super.init(context, root);
+        wrapped().init(context, root);
         this.stateStoreContext = context;
     }
 
     @Override
-    public Position getPosition() {
-        return position;
-    }
-
-    @Override
     public void put(final Bytes key, final byte[] value, final long windowStartTimestamp) {
         // Skip if value is null and duplicates are allowed since this delete is a no-op
         if (!(value == null && retainDuplicates)) {
             maybeUpdateSeqnumForDups();
             wrapped().put(WindowKeySchema.toStoreKeyBinary(key, windowStartTimestamp, seqnum), value);
-            StoreQueryUtils.updatePosition(position, stateStoreContext);
         }
     }
 
@@ -140,7 +131,7 @@ public class RocksDBWindowStore
             positionBound,
             config,
             this,
-            position,
+            getPosition(),
             stateStoreContext
         );
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/StoreQueryUtils.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/StoreQueryUtils.java
index 0a69eb8..06b3713 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/StoreQueryUtils.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/StoreQueryUtils.java
@@ -16,9 +16,11 @@
  */
 package org.apache.kafka.streams.state.internals;
 
+import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
@@ -40,10 +42,14 @@ import org.apache.kafka.streams.state.StateSerdes;
 import org.apache.kafka.streams.state.WindowStore;
 import org.apache.kafka.streams.state.WindowStoreIterator;
 
+import java.io.IOException;
 import java.io.PrintWriter;
 import java.io.StringWriter;
+import java.util.HashMap;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Optional;
+import java.util.Set;
 import java.util.function.Function;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
@@ -338,6 +344,46 @@ public final class StoreQueryUtils {
         return byteArray -> deserializer.deserialize(serdes.topic(), byteArray);
     }
 
+    public static void checkpointPosition(final OffsetCheckpoint checkpointFile,
+                                          final Position position) {
+        try {
+            checkpointFile.write(positionToTopicPartitionMap(position));
+        } catch (final IOException e) {
+            throw new ProcessorStateException("Error writing checkpoint file", e);
+        }
+    }
+
+    public static Position readPositionFromCheckpoint(final OffsetCheckpoint checkpointFile) {
+        try {
+            return topicPartitionMapToPosition(checkpointFile.read());
+        } catch (final IOException e) {
+            throw new ProcessorStateException("Error reading checkpoint file", e);
+        }
+    }
+
+    private static Map<TopicPartition, Long> positionToTopicPartitionMap(final Position position) {
+        final Map<TopicPartition, Long> topicPartitions = new HashMap<>();
+        final Set<String> topics = position.getTopics();
+        for (final String t : topics) {
+            final Map<Integer, Long> partitions = position.getPartitionPositions(t);
+            for (final Entry<Integer, Long> e : partitions.entrySet()) {
+                final TopicPartition tp = new TopicPartition(t, e.getKey());
+                topicPartitions.put(tp, e.getValue());
+            }
+        }
+        return topicPartitions;
+    }
+
+    private static Position topicPartitionMapToPosition(final Map<TopicPartition, Long> topicPartitions) {
+        final Map<String, Map<Integer, Long>> pos = new HashMap<>();
+        for (final Entry<TopicPartition, Long> e : topicPartitions.entrySet()) {
+            pos
+                .computeIfAbsent(e.getKey().topic(), t -> new HashMap<>())
+                .put(e.getKey().partition(), e.getValue());
+        }
+        return Position.fromMap(pos);
+    }
+
     private static <R> String parseStoreException(final Exception e, final StateStore store, final Query<R> query) {
         final StringWriter stringWriter = new StringWriter();
         final PrintWriter printWriter = new PrintWriter(stringWriter);
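
To make the two new helpers concrete: checkpointPosition flattens a Position into the Map<TopicPartition, Long> that OffsetCheckpoint already knows how to write, and readPositionFromCheckpoint rebuilds the Position from that map. A hedged round-trip sketch, with a made-up directory and offsets:

    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.streams.query.Position;
    import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
    import org.apache.kafka.streams.state.internals.StoreQueryUtils;

    import java.io.File;
    import java.io.IOException;
    import java.util.Map;

    public class CheckpointRoundTripSketch {
        public static void main(final String[] args) throws IOException {
            final File dir = new File(System.getProperty("java.io.tmpdir"), "position-sketch");
            dir.mkdirs();
            final OffsetCheckpoint checkpoint =
                new OffsetCheckpoint(new File(dir, "my-store.position"));

            final Position written = Position.emptyPosition()
                .withComponent("input-topic", 0, 42L)
                .withComponent("input-topic", 1, 17L);
            StoreQueryUtils.checkpointPosition(checkpoint, written);

            // The on-disk representation is the flat map OffsetCheckpoint already understands.
            final Map<TopicPartition, Long> flattened = checkpoint.read();
            System.out.println(flattened); // e.g. {input-topic-0=42, input-topic-1=17}

            final Position read = StoreQueryUtils.readPositionFromCheckpoint(checkpoint);
            System.out.println(read.getPartitionPositions("input-topic")); // {0=42, 1=17}
        }
    }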
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/PositionRestartIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/PositionRestartIntegrationTest.java
new file mode 100644
index 0000000..db1f86e
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/PositionRestartIntegrationTest.java
@@ -0,0 +1,692 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.clients.producer.ProducerConfig;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.clients.producer.RecordMetadata;
+import org.apache.kafka.common.serialization.IntegerSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.StreamsBuilder;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.StreamsConfig.InternalConfig;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.kstream.Consumed;
+import org.apache.kafka.streams.kstream.Materialized;
+import org.apache.kafka.streams.kstream.SessionWindows;
+import org.apache.kafka.streams.kstream.TimeWindows;
+import org.apache.kafka.streams.kstream.Windowed;
+import org.apache.kafka.streams.kstream.internals.SessionWindow;
+import org.apache.kafka.streams.processor.api.ContextualProcessor;
+import org.apache.kafka.streams.processor.api.ProcessorSupplier;
+import org.apache.kafka.streams.processor.api.Record;
+import org.apache.kafka.streams.query.Position;
+import org.apache.kafka.streams.query.PositionBound;
+import org.apache.kafka.streams.query.Query;
+import org.apache.kafka.streams.query.RangeQuery;
+import org.apache.kafka.streams.query.StateQueryRequest;
+import org.apache.kafka.streams.query.StateQueryResult;
+import org.apache.kafka.streams.query.WindowKeyQuery;
+import org.apache.kafka.streams.query.WindowRangeQuery;
+import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.SessionBytesStoreSupplier;
+import org.apache.kafka.streams.state.SessionStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.StoreSupplier;
+import org.apache.kafka.streams.state.Stores;
+import org.apache.kafka.streams.state.TimestampedKeyValueStore;
+import org.apache.kafka.streams.state.TimestampedWindowStore;
+import org.apache.kafka.streams.state.ValueAndTimestamp;
+import org.apache.kafka.streams.state.WindowBytesStoreSupplier;
+import org.apache.kafka.streams.state.WindowStore;
+import org.apache.kafka.test.IntegrationTest;
+import org.apache.kafka.test.TestUtils;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Objects;
+import java.util.Properties;
+import java.util.Random;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.apache.kafka.streams.query.StateQueryRequest.inStore;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+
+@Category({IntegrationTest.class})
+@RunWith(value = Parameterized.class)
+public class PositionRestartIntegrationTest {
+
+    private static final Logger LOG = LoggerFactory.getLogger(PositionRestartIntegrationTest.class);
+    private static final long SEED = new Random().nextLong();
+    private static final int NUM_BROKERS = 1;
+    public static final Duration WINDOW_SIZE = Duration.ofMinutes(5);
+    private static int port = 0;
+    private static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final Position INPUT_POSITION = Position.emptyPosition();
+    private static final String STORE_NAME = "kv-store";
+    private static final long RECORD_TIME = System.currentTimeMillis();
+    private static final long WINDOW_START =
+        (RECORD_TIME / WINDOW_SIZE.toMillis()) * WINDOW_SIZE.toMillis();
+    public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS);
+    private final StoresToTest storeToTest;
+    private final String kind;
+    private final boolean cache;
+    private final boolean log;
+    private KafkaStreams kafkaStreams;
+    private final Properties streamsConfig;
+
+    public enum StoresToTest {
+        IN_MEMORY_KV {
+            @Override
+            public StoreSupplier<?> supplier() {
+                return Stores.inMemoryKeyValueStore(STORE_NAME);
+            }
+
+            @Override
+            public boolean keyValue() {
+                return true;
+            }
+        },
+        IN_MEMORY_LRU {
+            @Override
+            public StoreSupplier<?> supplier() {
+                return Stores.lruMap(STORE_NAME, 100);
+            }
+
+            @Override
+            public boolean keyValue() {
+                return true;
+            }
+        },
+        ROCKS_KV {
+            @Override
+            public StoreSupplier<?> supplier() {
+                return Stores.persistentKeyValueStore(STORE_NAME);
+            }
+
+            @Override
+            public boolean timestamped() {
+                return false;
+            }
+
+            @Override
+            public boolean keyValue() {
+                return true;
+            }
+        },
+        TIME_ROCKS_KV {
+            @Override
+            public StoreSupplier<?> supplier() {
+                return Stores.persistentTimestampedKeyValueStore(STORE_NAME);
+            }
+
+            @Override
+            public boolean keyValue() {
+                return true;
+            }
+        },
+        IN_MEMORY_WINDOW {
+            @Override
+            public StoreSupplier<?> supplier() {
+                return Stores.inMemoryWindowStore(STORE_NAME, Duration.ofDays(1), WINDOW_SIZE,
+                                                  false
+                );
+            }
+
+            @Override
+            public boolean isWindowed() {
+                return true;
+            }
+        },
+        ROCKS_WINDOW {
+            @Override
+            public StoreSupplier<?> supplier() {
+                return Stores.persistentWindowStore(STORE_NAME, Duration.ofDays(1), WINDOW_SIZE,
+                                                    false
+                );
+            }
+
+            @Override
+            public boolean isWindowed() {
+                return true;
+            }
+
+            @Override
+            public boolean timestamped() {
+                return false;
+            }
+        },
+        TIME_ROCKS_WINDOW {
+            @Override
+            public StoreSupplier<?> supplier() {
+                return Stores.persistentTimestampedWindowStore(STORE_NAME, Duration.ofDays(1),
+                                                               WINDOW_SIZE, false
+                );
+            }
+
+            @Override
+            public boolean isWindowed() {
+                return true;
+            }
+        },
+        IN_MEMORY_SESSION {
+            @Override
+            public StoreSupplier<?> supplier() {
+                return Stores.inMemorySessionStore(STORE_NAME, Duration.ofDays(1));
+            }
+
+            @Override
+            public boolean isSession() {
+                return true;
+            }
+        },
+        ROCKS_SESSION {
+            @Override
+            public StoreSupplier<?> supplier() {
+                return Stores.persistentSessionStore(STORE_NAME, Duration.ofDays(1));
+            }
+
+            @Override
+            public boolean isSession() {
+                return true;
+            }
+        };
+
+        public abstract StoreSupplier<?> supplier();
+
+        public boolean timestamped() {
+            return true; // most stores are timestamped
+        }
+
+        public boolean keyValue() {
+            return false;
+        }
+
+        public boolean isWindowed() {
+            return false;
+        }
+
+        public boolean isSession() {
+            return false;
+        }
+    }
+
+    @Parameterized.Parameters(name = "cache={0}, log={1}, supplier={2}, kind={3}")
+    public static Collection<Object[]> data() {
+        LOG.info("Generating test cases according to random seed: {}", SEED);
+        final List<Object[]> values = new ArrayList<>();
+        for (final boolean cacheEnabled : Arrays.asList(true, false)) {
+            for (final boolean logEnabled : Arrays.asList(true, false)) {
+                for (final StoresToTest toTest : StoresToTest.values()) {
+                    // We don't need to test if non-persistent stores without logging
+                    // survive restarts, since those are by definition not durable.
+                    if (logEnabled || toTest.supplier().get().persistent()) {
+                        for (final String kind : Arrays.asList("DSL", "PAPI")) {
+                            values.add(new Object[]{cacheEnabled, logEnabled, toTest.name(), kind});
+                        }
+                    }
+                }
+            }
+        }
+        return values;
+    }
+
+    public PositionRestartIntegrationTest(
+        final boolean cache,
+        final boolean log,
+        final String storeToTest,
+        final String kind) {
+        this.cache = cache;
+        this.log = log;
+        this.storeToTest = StoresToTest.valueOf(storeToTest);
+        this.kind = kind;
+        this.streamsConfig = streamsConfiguration(
+            cache,
+            log,
+            storeToTest,
+            kind
+        );
+    }
+
+    @BeforeClass
+    public static void before()
+        throws InterruptedException, IOException, ExecutionException, TimeoutException {
+
+        CLUSTER.start();
+        CLUSTER.deleteAllTopicsAndWait(60 * 1000L);
+        final int partitions = 2;
+        CLUSTER.createTopic(INPUT_TOPIC_NAME, partitions, 1);
+
+        final Properties producerProps = new Properties();
+        producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
+        producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class);
+        producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class);
+
+        final List<Future<RecordMetadata>> futures = new LinkedList<>();
+        try (final Producer<Integer, Integer> producer = new KafkaProducer<>(producerProps)) {
+            for (int i = 0; i < 4; i++) {
+                final Future<RecordMetadata> send = producer.send(
+                    new ProducerRecord<>(
+                        INPUT_TOPIC_NAME,
+                        i % partitions,
+                        RECORD_TIME,
+                        i,
+                        i,
+                        null
+                    )
+                );
+                futures.add(send);
+                Time.SYSTEM.sleep(1L);
+            }
+            producer.flush();
+
+            for (final Future<RecordMetadata> future : futures) {
+                final RecordMetadata recordMetadata = future.get(1, TimeUnit.MINUTES);
+                assertThat(recordMetadata.hasOffset(), is(true));
+                INPUT_POSITION.withComponent(
+                    recordMetadata.topic(),
+                    recordMetadata.partition(),
+                    recordMetadata.offset()
+                );
+            }
+        }
+
+        assertThat(INPUT_POSITION, equalTo(
+            Position
+                .emptyPosition()
+                .withComponent(INPUT_TOPIC_NAME, 0, 1L)
+                .withComponent(INPUT_TOPIC_NAME, 1, 1L)
+        ));
+    }
+
+    @Before
+    public void beforeTest() {
+        beforeTest(true);
+    }
+
+    public void beforeTest(final boolean cleanup) {
+        final StoreSupplier<?> supplier = storeToTest.supplier();
+
+        final StreamsBuilder builder = new StreamsBuilder();
+        if (Objects.equals(kind, "DSL") && supplier instanceof KeyValueBytesStoreSupplier) {
+            setUpKeyValueDSLTopology((KeyValueBytesStoreSupplier) supplier, builder);
+        } else if (Objects.equals(kind, "PAPI") && supplier instanceof KeyValueBytesStoreSupplier) {
+            setUpKeyValuePAPITopology((KeyValueBytesStoreSupplier) supplier, builder);
+        } else if (Objects.equals(kind, "DSL") && supplier instanceof WindowBytesStoreSupplier) {
+            setUpWindowDSLTopology((WindowBytesStoreSupplier) supplier, builder);
+        } else if (Objects.equals(kind, "PAPI") && supplier instanceof WindowBytesStoreSupplier) {
+            setUpWindowPAPITopology((WindowBytesStoreSupplier) supplier, builder);
+        } else if (Objects.equals(kind, "DSL") && supplier instanceof SessionBytesStoreSupplier) {
+            setUpSessionDSLTopology((SessionBytesStoreSupplier) supplier, builder);
+        } else if (Objects.equals(kind, "PAPI") && supplier instanceof SessionBytesStoreSupplier) {
+            setUpSessionPAPITopology((SessionBytesStoreSupplier) supplier, builder);
+        } else {
+            throw new AssertionError("Store supplier is an unrecognized type.");
+        }
+
+        kafkaStreams =
+            IntegrationTestUtils.getStartedStreams(
+                streamsConfig,
+                builder,
+                cleanup
+            );
+    }
+
+    @After
+    public void afterTest() {
+        afterTest(true);
+    }
+
+    public void afterTest(final boolean cleanup) {
+        if (kafkaStreams != null) {
+            kafkaStreams.close();
+            if (cleanup) {
+                kafkaStreams.cleanUp();
+            }
+        }
+    }
+
+    @AfterClass
+    public static void after() {
+        CLUSTER.stop();
+    }
+
+    public void reboot() {
+        afterTest(false);
+        beforeTest(false);
+    }
+
+    @Test
+    public void verifyStore() {
+        final Query<?> query;
+        if (storeToTest.keyValue()) {
+            query = RangeQuery.withNoBounds();
+        } else if (storeToTest.isWindowed()) {
+            query = WindowKeyQuery.withKeyAndWindowStartRange(
+                2,
+                Instant.ofEpochMilli(WINDOW_START),
+                Instant.ofEpochMilli(WINDOW_START)
+            );
+        } else if (storeToTest.isSession()) {
+            query = WindowRangeQuery.withKey(2);
+        } else {
+            throw new AssertionError("Unhandled store type: " + storeToTest);
+        }
+        shouldReachExpectedPosition(query);
+
+        reboot();
+
+        shouldReachExpectedPosition(query);
+    }
+
+    private void shouldReachExpectedPosition(final Query<?> query) {
+        final StateQueryRequest<?> request =
+            inStore(STORE_NAME)
+                .withQuery(query)
+                .withPartitions(mkSet(0, 1))
+                .withPositionBound(PositionBound.at(INPUT_POSITION));
+
+        final StateQueryResult<?> result =
+            IntegrationTestUtils.iqv2WaitForResult(kafkaStreams, request);
+
+        assertThat(result.getPosition(), is(INPUT_POSITION));
+    }
+
+    private void setUpSessionDSLTopology(final SessionBytesStoreSupplier supplier,
+                                         final StreamsBuilder builder) {
+        final Materialized<Integer, Integer, SessionStore<Bytes, byte[]>> materialized =
+            Materialized.as(supplier);
+
+        if (cache) {
+            materialized.withCachingEnabled();
+        } else {
+            materialized.withCachingDisabled();
+        }
+
+        if (log) {
+            materialized.withLoggingEnabled(Collections.emptyMap());
+        } else {
+            materialized.withLoggingDisabled();
+        }
+
+        builder
+            .stream(INPUT_TOPIC_NAME, Consumed.with(Serdes.Integer(), Serdes.Integer()))
+            .groupByKey()
+            .windowedBy(SessionWindows.ofInactivityGapWithNoGrace(WINDOW_SIZE))
+            .aggregate(
+                () -> 0,
+                (key, value, aggregate) -> aggregate + value,
+                (aggKey, aggOne, aggTwo) -> aggOne + aggTwo,
+                materialized
+            );
+    }
+
+    private void setUpWindowDSLTopology(final WindowBytesStoreSupplier supplier,
+                                        final StreamsBuilder builder) {
+        final Materialized<Integer, Integer, WindowStore<Bytes, byte[]>> materialized =
+            Materialized.as(supplier);
+
+        if (cache) {
+            materialized.withCachingEnabled();
+        } else {
+            materialized.withCachingDisabled();
+        }
+
+        if (log) {
+            materialized.withLoggingEnabled(Collections.emptyMap());
+        } else {
+            materialized.withLoggingDisabled();
+        }
+
+        builder
+            .stream(INPUT_TOPIC_NAME, Consumed.with(Serdes.Integer(), Serdes.Integer()))
+            .groupByKey()
+            .windowedBy(TimeWindows.ofSizeWithNoGrace(WINDOW_SIZE))
+            .aggregate(
+                () -> 0,
+                (key, value, aggregate) -> aggregate + value,
+                materialized
+            );
+    }
+
+    private void setUpKeyValueDSLTopology(final KeyValueBytesStoreSupplier supplier,
+                                          final StreamsBuilder builder) {
+        final Materialized<Integer, Integer, KeyValueStore<Bytes, byte[]>> materialized =
+            Materialized.as(supplier);
+
+        if (cache) {
+            materialized.withCachingEnabled();
+        } else {
+            materialized.withCachingDisabled();
+        }
+
+        if (log) {
+            materialized.withLoggingEnabled(Collections.emptyMap());
+        } else {
+            materialized.withLoggingDisabled();
+        }
+
+        builder.table(
+            INPUT_TOPIC_NAME,
+            Consumed.with(Serdes.Integer(), Serdes.Integer()),
+            materialized
+        );
+    }
+
+    private void setUpKeyValuePAPITopology(final KeyValueBytesStoreSupplier supplier,
+                                           final StreamsBuilder builder) {
+        final StoreBuilder<?> keyValueStoreStoreBuilder;
+        final ProcessorSupplier<Integer, Integer, Void, Void> processorSupplier;
+        if (storeToTest.timestamped()) {
+            keyValueStoreStoreBuilder = Stores.timestampedKeyValueStoreBuilder(
+                supplier,
+                Serdes.Integer(),
+                Serdes.Integer()
+            );
+            processorSupplier = () -> new ContextualProcessor<Integer, Integer, Void, Void>() {
+                @Override
+                public void process(final Record<Integer, Integer> record) {
+                    final TimestampedKeyValueStore<Integer, Integer> stateStore =
+                        context().getStateStore(keyValueStoreStoreBuilder.name());
+                    stateStore.put(
+                        record.key(),
+                        ValueAndTimestamp.make(
+                            record.value(), record.timestamp()
+                        )
+                    );
+                }
+            };
+        } else {
+            keyValueStoreStoreBuilder = Stores.keyValueStoreBuilder(
+                supplier,
+                Serdes.Integer(),
+                Serdes.Integer()
+            );
+            processorSupplier =
+                () -> new ContextualProcessor<Integer, Integer, Void, Void>() {
+                    @Override
+                    public void process(final Record<Integer, Integer> record) {
+                        final KeyValueStore<Integer, Integer> stateStore =
+                            context().getStateStore(keyValueStoreStoreBuilder.name());
+                        stateStore.put(record.key(), record.value());
+                    }
+                };
+        }
+        if (cache) {
+            keyValueStoreStoreBuilder.withCachingEnabled();
+        } else {
+            keyValueStoreStoreBuilder.withCachingDisabled();
+        }
+        if (log) {
+            keyValueStoreStoreBuilder.withLoggingEnabled(Collections.emptyMap());
+        } else {
+            keyValueStoreStoreBuilder.withLoggingDisabled();
+        }
+        builder.addStateStore(keyValueStoreStoreBuilder);
+        builder
+            .stream(INPUT_TOPIC_NAME, Consumed.with(Serdes.Integer(), Serdes.Integer()))
+            .process(processorSupplier, keyValueStoreStoreBuilder.name());
+
+    }
+
+    private void setUpWindowPAPITopology(final WindowBytesStoreSupplier supplier,
+                                         final StreamsBuilder builder) {
+        final StoreBuilder<?> windowStoreStoreBuilder;
+        final ProcessorSupplier<Integer, Integer, Void, Void> processorSupplier;
+        if (storeToTest.timestamped()) {
+            windowStoreStoreBuilder = Stores.timestampedWindowStoreBuilder(
+                supplier,
+                Serdes.Integer(),
+                Serdes.Integer()
+            );
+            processorSupplier = () -> new ContextualProcessor<Integer, Integer, Void, Void>() {
+                @Override
+                public void process(final Record<Integer, Integer> record) {
+                    final TimestampedWindowStore<Integer, Integer> stateStore =
+                        context().getStateStore(windowStoreStoreBuilder.name());
+                    stateStore.put(
+                        record.key(),
+                        ValueAndTimestamp.make(
+                            record.value(), record.timestamp()
+                        ),
+                        WINDOW_START
+                    );
+                }
+            };
+        } else {
+            windowStoreStoreBuilder = Stores.windowStoreBuilder(
+                supplier,
+                Serdes.Integer(),
+                Serdes.Integer()
+            );
+            processorSupplier =
+                () -> new ContextualProcessor<Integer, Integer, Void, Void>() {
+                    @Override
+                    public void process(final Record<Integer, Integer> record) {
+                        final WindowStore<Integer, Integer> stateStore =
+                            context().getStateStore(windowStoreStoreBuilder.name());
+                        stateStore.put(record.key(), record.value(), WINDOW_START);
+                    }
+                };
+        }
+        if (cache) {
+            windowStoreStoreBuilder.withCachingEnabled();
+        } else {
+            windowStoreStoreBuilder.withCachingDisabled();
+        }
+        if (log) {
+            windowStoreStoreBuilder.withLoggingEnabled(Collections.emptyMap());
+        } else {
+            windowStoreStoreBuilder.withLoggingDisabled();
+        }
+        builder.addStateStore(windowStoreStoreBuilder);
+        builder
+            .stream(INPUT_TOPIC_NAME, Consumed.with(Serdes.Integer(), Serdes.Integer()))
+            .process(processorSupplier, windowStoreStoreBuilder.name());
+
+    }
+
+    private void setUpSessionPAPITopology(final SessionBytesStoreSupplier supplier,
+                                          final StreamsBuilder builder) {
+        final StoreBuilder<?> sessionStoreStoreBuilder;
+        final ProcessorSupplier<Integer, Integer, Void, Void> processorSupplier;
+        sessionStoreStoreBuilder = Stores.sessionStoreBuilder(
+            supplier,
+            Serdes.Integer(),
+            Serdes.Integer()
+        );
+        processorSupplier = () -> new ContextualProcessor<Integer, Integer, Void, Void>() {
+            @Override
+            public void process(final Record<Integer, Integer> record) {
+                final SessionStore<Integer, Integer> stateStore =
+                    context().getStateStore(sessionStoreStoreBuilder.name());
+                stateStore.put(
+                    new Windowed<>(record.key(), new SessionWindow(WINDOW_START, WINDOW_START)),
+                    record.value()
+                );
+            }
+        };
+        if (cache) {
+            sessionStoreStoreBuilder.withCachingEnabled();
+        } else {
+            sessionStoreStoreBuilder.withCachingDisabled();
+        }
+        if (log) {
+            sessionStoreStoreBuilder.withLoggingEnabled(Collections.emptyMap());
+        } else {
+            sessionStoreStoreBuilder.withLoggingDisabled();
+        }
+        builder.addStateStore(sessionStoreStoreBuilder);
+        builder
+            .stream(INPUT_TOPIC_NAME, Consumed.with(Serdes.Integer(), Serdes.Integer()))
+            .process(processorSupplier, sessionStoreStoreBuilder.name());
+    }
+
+    private static Properties streamsConfiguration(final boolean cache, final boolean log,
+                                                   final String supplier, final String kind) {
+        final String safeTestName =
+            PositionRestartIntegrationTest.class.getName() + "-" + cache + "-" + log + "-"
+                + supplier + "-" + kind;
+        final Properties config = new Properties();
+        config.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE);
+        config.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName);
+        config.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost:" + (++port));
+        config.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
+        config.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
+        config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass());
+        config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass());
+        config.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
+        config.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, 100);
+        config.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 200);
+        config.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 1000);
+        config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L);
+        config.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 1);
+        config.put(InternalConfig.IQ_CONSISTENCY_OFFSET_VECTOR_ENABLED, true);
+        return config;
+    }
+}
\ No newline at end of file
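
The test above drives all of its verification through IQv2: it queries the
store with a position bound equal to the offsets of the produced input,
reboots Streams without cleanup, and runs the same query again. A condensed
sketch of that query pattern, assuming a running KafkaStreams instance; the
queryUpTo helper, class name, and the Integer types are illustrative, and a
single query attempt may not yet satisfy the bound (the test retries via
IntegrationTestUtils.iqv2WaitForResult):

    import static org.apache.kafka.streams.query.StateQueryRequest.inStore;

    import org.apache.kafka.streams.KafkaStreams;
    import org.apache.kafka.streams.query.Position;
    import org.apache.kafka.streams.query.PositionBound;
    import org.apache.kafka.streams.query.RangeQuery;
    import org.apache.kafka.streams.query.StateQueryRequest;
    import org.apache.kafka.streams.query.StateQueryResult;
    import org.apache.kafka.streams.state.KeyValueIterator;

    public class PositionBoundQuerySketch {

        // Scan the whole "kv-store", accepting only state that has caught up
        // to inputPosition, and return the position the store reports. After
        // a restart, a persistent store must recover this position from its
        // checkpoint file rather than from in-memory state.
        public static Position queryUpTo(final KafkaStreams streams,
                                         final Position inputPosition) {
            final StateQueryRequest<KeyValueIterator<Integer, Integer>> request =
                inStore("kv-store")
                    .withQuery(RangeQuery.<Integer, Integer>withNoBounds())
                    .withPositionBound(PositionBound.at(inputPosition));

            final StateQueryResult<KeyValueIterator<Integer, Integer>> result =
                streams.query(request);

            return result.getPosition();
        }
    }
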
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
index 9b50ead..0670fed 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
@@ -231,7 +231,7 @@ public class GlobalStateManagerImplTest {
         stateManager.initialize();
 
         try {
-            stateManager.registerStore(new NoOpReadOnlyStore<>("not-in-topology"), stateRestoreCallback);
+            stateManager.registerStore(new NoOpReadOnlyStore<>("not-in-topology"), stateRestoreCallback, null);
             fail("should have raised an illegal argument exception as store is not in the topology");
         } catch (final IllegalArgumentException e) {
             // pass
@@ -242,9 +242,9 @@ public class GlobalStateManagerImplTest {
     public void shouldThrowIllegalArgumentExceptionIfAttemptingToRegisterStoreTwice() {
         stateManager.initialize();
         initializeConsumer(2, 0, t1);
-        stateManager.registerStore(store1, stateRestoreCallback);
+        stateManager.registerStore(store1, stateRestoreCallback, null);
         try {
-            stateManager.registerStore(store1, stateRestoreCallback);
+            stateManager.registerStore(store1, stateRestoreCallback, null);
             fail("should have raised an illegal argument exception as store has already been registered");
         } catch (final IllegalArgumentException e) {
             // pass
@@ -255,7 +255,7 @@ public class GlobalStateManagerImplTest {
     public void shouldThrowStreamsExceptionIfNoPartitionsFoundForStore() {
         stateManager.initialize();
         try {
-            stateManager.registerStore(store1, stateRestoreCallback);
+            stateManager.registerStore(store1, stateRestoreCallback, null);
             fail("Should have raised a StreamsException as there are no partition for the store");
         } catch (final StreamsException e) {
             // pass
@@ -267,7 +267,7 @@ public class GlobalStateManagerImplTest {
         initializeConsumer(1, 0, t1);
 
         stateManager.initialize();
-        stateManager.registerStore(store1, stateRestoreCallback);
+        stateManager.registerStore(store1, stateRestoreCallback, null);
 
         final KeyValue<byte[], byte[]> restoredRecord = stateRestoreCallback.restored.get(0);
         assertEquals(3, restoredRecord.key.length);
@@ -282,8 +282,8 @@ public class GlobalStateManagerImplTest {
         stateManager.registerStore(
             new WrappedStateStore<NoOpReadOnlyStore<Object, Object>, Object, Object>(store1) {
             },
-            stateRestoreCallback
-        );
+            stateRestoreCallback,
+            null);
 
         final KeyValue<byte[], byte[]> restoredRecord = stateRestoreCallback.restored.get(0);
         assertEquals(3, restoredRecord.key.length);
@@ -295,7 +295,7 @@ public class GlobalStateManagerImplTest {
         initializeConsumer(1, 0, t2);
 
         stateManager.initialize();
-        stateManager.registerStore(store2, stateRestoreCallback);
+        stateManager.registerStore(store2, stateRestoreCallback, null);
 
         final KeyValue<byte[], byte[]> restoredRecord = stateRestoreCallback.restored.get(0);
         assertEquals(3, restoredRecord.key.length);
@@ -310,8 +310,8 @@ public class GlobalStateManagerImplTest {
         stateManager.registerStore(
             new WrappedStateStore<NoOpReadOnlyStore<Object, Object>, Object, Object>(store2) {
             },
-            stateRestoreCallback
-        );
+            stateRestoreCallback,
+            null);
 
         final KeyValue<byte[], byte[]> restoredRecord = stateRestoreCallback.restored.get(0);
         assertEquals(3, restoredRecord.key.length);
@@ -324,7 +324,7 @@ public class GlobalStateManagerImplTest {
 
         stateManager.initialize();
 
-        stateManager.registerStore(store1, stateRestoreCallback);
+        stateManager.registerStore(store1, stateRestoreCallback, null);
         assertEquals(2, stateRestoreCallback.restored.size());
     }
 
@@ -333,7 +333,7 @@ public class GlobalStateManagerImplTest {
         initializeConsumer(5, 1, t1);
         stateManager.initialize();
 
-        stateManager.registerStore(store1, stateRestoreCallback);
+        stateManager.registerStore(store1, stateRestoreCallback, null);
 
         assertThat(stateRestoreListener.restoreStartOffset, equalTo(1L));
         assertThat(stateRestoreListener.restoreEndOffset, equalTo(6L));
@@ -354,7 +354,7 @@ public class GlobalStateManagerImplTest {
         offsetCheckpoint.write(Collections.singletonMap(t1, 5L));
 
         stateManager.initialize();
-        stateManager.registerStore(store1, stateRestoreCallback);
+        stateManager.registerStore(store1, stateRestoreCallback, null);
         assertEquals(5, stateRestoreCallback.restored.size());
     }
 
@@ -364,9 +364,9 @@ public class GlobalStateManagerImplTest {
         stateManager.initialize();
         // register the stores
         initializeConsumer(1, 0, t1);
-        stateManager.registerStore(store1, stateRestoreCallback);
+        stateManager.registerStore(store1, stateRestoreCallback, null);
         initializeConsumer(1, 0, t2);
-        stateManager.registerStore(store2, stateRestoreCallback);
+        stateManager.registerStore(store2, stateRestoreCallback, null);
 
         stateManager.flush();
         assertTrue(store1.flushed);
@@ -383,7 +383,7 @@ public class GlobalStateManagerImplTest {
             public void flush() {
                 throw new RuntimeException("KABOOM!");
             }
-        }, stateRestoreCallback);
+        }, stateRestoreCallback, null);
         assertThrows(StreamsException.class, stateManager::flush);
     }
 
@@ -392,9 +392,9 @@ public class GlobalStateManagerImplTest {
         stateManager.initialize();
         // register the stores
         initializeConsumer(1, 0, t1);
-        stateManager.registerStore(store1, stateRestoreCallback);
+        stateManager.registerStore(store1, stateRestoreCallback, null);
         initializeConsumer(1, 0, t2);
-        stateManager.registerStore(store2, stateRestoreCallback);
+        stateManager.registerStore(store2, stateRestoreCallback, null);
 
         stateManager.close();
         assertFalse(store1.isOpen());
@@ -410,7 +410,7 @@ public class GlobalStateManagerImplTest {
             public void close() {
                 throw new RuntimeException("KABOOM!");
             }
-        }, stateRestoreCallback);
+        }, stateRestoreCallback, null);
 
         assertThrows(ProcessorStateException.class, stateManager::close);
     }
@@ -419,7 +419,7 @@ public class GlobalStateManagerImplTest {
     public void shouldThrowIllegalArgumentExceptionIfCallbackIsNull() {
         stateManager.initialize();
         try {
-            stateManager.registerStore(store1, null);
+            stateManager.registerStore(store1, null, null);
             fail("should have thrown due to null callback");
         } catch (final IllegalArgumentException e) {
             //pass
@@ -438,7 +438,7 @@ public class GlobalStateManagerImplTest {
                 }
                 super.close();
             }
-        }, stateRestoreCallback);
+        }, stateRestoreCallback, null);
         stateManager.close();
 
         stateManager.close();
@@ -455,10 +455,10 @@ public class GlobalStateManagerImplTest {
                 throw new RuntimeException("KABOOM!");
             }
         };
-        stateManager.registerStore(store, stateRestoreCallback);
+        stateManager.registerStore(store, stateRestoreCallback, null);
 
         initializeConsumer(1, 0, t2);
-        stateManager.registerStore(store2, stateRestoreCallback);
+        stateManager.registerStore(store2, stateRestoreCallback, null);
 
         try {
             stateManager.close();
@@ -486,9 +486,9 @@ public class GlobalStateManagerImplTest {
     public void shouldNotRemoveOffsetsOfUnUpdatedTablesDuringCheckpoint() {
         stateManager.initialize();
         initializeConsumer(10, 0, t1);
-        stateManager.registerStore(store1, stateRestoreCallback);
+        stateManager.registerStore(store1, stateRestoreCallback, null);
         initializeConsumer(20, 0, t2);
-        stateManager.registerStore(store2, stateRestoreCallback);
+        stateManager.registerStore(store2, stateRestoreCallback, null);
 
         final Map<TopicPartition, Long> initialCheckpoint = stateManager.changelogOffsets();
         stateManager.updateChangelogOffsets(Collections.singletonMap(t1, 101L));
@@ -515,7 +515,7 @@ public class GlobalStateManagerImplTest {
         consumer.addRecord(new ConsumerRecord<>(t1.topic(), t1.partition(), 2, expectedKey, expectedValue));
 
         stateManager.initialize();
-        stateManager.registerStore(store1, stateRestoreCallback);
+        stateManager.registerStore(store1, stateRestoreCallback, null);
         final KeyValue<byte[], byte[]> restoredKv = stateRestoreCallback.restored.get(0);
         assertThat(stateRestoreCallback.restored, equalTo(Collections.singletonList(KeyValue.pair(restoredKv.key, restoredKv.value))));
     }
@@ -524,7 +524,7 @@ public class GlobalStateManagerImplTest {
     public void shouldCheckpointRestoredOffsetsToFile() throws IOException {
         stateManager.initialize();
         initializeConsumer(10, 0, t1);
-        stateManager.registerStore(store1, stateRestoreCallback);
+        stateManager.registerStore(store1, stateRestoreCallback, null);
         stateManager.checkpoint();
         stateManager.close();
 
@@ -537,7 +537,7 @@ public class GlobalStateManagerImplTest {
     public void shouldSkipGlobalInMemoryStoreOffsetsToFile() throws IOException {
         stateManager.initialize();
         initializeConsumer(10, 0, t3);
-        stateManager.registerStore(store3, stateRestoreCallback);
+        stateManager.registerStore(store3, stateRestoreCallback, null);
         stateManager.close();
 
         assertThat(readOffsetsCheckpoint(), equalTo(Collections.emptyMap()));
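
Every registerStore call in the diff above gains a third argument: the
CommitCallback introduced by this patch (null here, since these tests
exercise registration and restoration rather than position checkpointing). A
minimal sketch of a non-null callback, assuming CommitCallback exposes a
single void onCommit() throws IOException method, as the checkpoint tests
further below suggest; the class name and the callback body are illustrative:

    import java.io.IOException;

    import org.apache.kafka.streams.processor.CommitCallback;

    public class PositionCommitSketch {

        // A callback the state manager invokes when it checkpoints; a
        // persistent store could use it to write its current Position to a
        // ".position" file alongside its on-disk state.
        public static CommitCallback positionCommit() {
            return new CommitCallback() {
                @Override
                public void onCommit() throws IOException {
                    // persist the store's current Position here
                }
            };
        }
    }

With such a callback, a registration would read
stateMgr.registerStore(store, restoreCallback, positionCommit()) instead of
passing null.
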
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
index 3e88ced..5947842 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
@@ -26,14 +26,17 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.processor.internals.ProcessorStateManager.StateStoreMetadata;
+import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.state.TimestampedBytesStore;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
+import org.apache.kafka.streams.state.internals.StoreQueryUtils;
 import org.apache.kafka.test.MockKeyValueStore;
 import org.apache.kafka.test.MockRestoreCallback;
 import org.apache.kafka.test.TestUtils;
@@ -67,9 +70,12 @@ import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.streams.processor.internals.StateManagerUtil.CHECKPOINT_FILE_NAME;
 import static org.easymock.EasyMock.expect;
+import static org.easymock.EasyMock.expectLastCall;
+import static org.easymock.EasyMock.mock;
 import static org.easymock.EasyMock.replay;
 import static org.easymock.EasyMock.reset;
 import static org.easymock.EasyMock.verify;
+import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.notNullValue;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -224,8 +230,8 @@ public class ProcessorStateManagerTest {
             ),
             Collections.emptySet());
 
-        stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
-        stateMgr.registerStore(persistentStoreTwo, persistentStore.stateRestoreCallback);
+        stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
+        stateMgr.registerStore(persistentStoreTwo, persistentStore.stateRestoreCallback, null);
 
         assertThrows(
             IllegalStateException.class,
@@ -242,7 +248,7 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
 
         try {
-            stateMgr.registerStore(persistentStore, restoreCallback);
+            stateMgr.registerStore(persistentStore, restoreCallback, null);
             final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition);
             assertThat(storeMetadata, notNullValue());
 
@@ -262,7 +268,7 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
 
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
             final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition);
             assertThat(storeMetadata, notNullValue());
 
@@ -283,7 +289,7 @@ public class ProcessorStateManagerTest {
         final MockKeyValueStore store = getConverterStore();
 
         try {
-            stateMgr.registerStore(store, store.stateRestoreCallback);
+            stateMgr.registerStore(store, store.stateRestoreCallback, null);
             final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition);
             assertThat(storeMetadata, notNullValue());
 
@@ -314,7 +320,7 @@ public class ProcessorStateManagerTest {
         stateMgr.registerStateStores(singletonList(store), context);
         verify(context, store);
 
-        stateMgr.registerStore(store, noopStateRestoreCallback);
+        stateMgr.registerStore(store, noopStateRestoreCallback, null);
         assertTrue(changelogReader.isPartitionRegistered(persistentStorePartition));
 
         reset(store);
@@ -344,7 +350,7 @@ public class ProcessorStateManagerTest {
         stateMgr.registerStateStores(singletonList(store), context);
         verify(context, store);
 
-        stateMgr.registerStore(store, noopStateRestoreCallback);
+        stateMgr.registerStore(store, noopStateRestoreCallback, null);
         assertTrue(changelogReader.isPartitionRegistered(persistentStorePartition));
 
         stateMgr.recycle();
@@ -367,7 +373,7 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
 
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
             assertTrue(changelogReader.isPartitionRegistered(persistentStorePartition));
         } finally {
             stateMgr.close();
@@ -379,7 +385,7 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
 
         try {
-            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback, null);
             assertTrue(changelogReader.isPartitionRegistered(nonPersistentStorePartition));
         } finally {
             stateMgr.close();
@@ -399,7 +405,7 @@ public class ProcessorStateManagerTest {
             emptySet());
 
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
             assertFalse(changelogReader.isPartitionRegistered(persistentStorePartition));
         } finally {
             stateMgr.close();
@@ -420,9 +426,9 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
 
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
-            stateMgr.registerStore(persistentStoreTwo, persistentStoreTwo.stateRestoreCallback);
-            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
+            stateMgr.registerStore(persistentStoreTwo, persistentStoreTwo.stateRestoreCallback, null);
+            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback, null);
             stateMgr.initializeStoreOffsetsFromCheckpoint(true);
 
             assertTrue(checkpointFile.exists());
@@ -461,9 +467,9 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true);
 
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
-            stateMgr.registerStore(persistentStoreTwo, persistentStoreTwo.stateRestoreCallback);
-            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
+            stateMgr.registerStore(persistentStoreTwo, persistentStoreTwo.stateRestoreCallback, null);
+            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback, null);
             stateMgr.initializeStoreOffsetsFromCheckpoint(true);
 
             assertFalse(checkpointFile.exists());
@@ -492,8 +498,8 @@ public class ProcessorStateManagerTest {
     public void shouldGetRegisteredStore() {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
-            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
+            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback, null);
 
             assertNull(stateMgr.getStore("noSuchStore"));
             assertEquals(persistentStore, stateMgr.getStore(persistentStoreName));
@@ -506,7 +512,7 @@ public class ProcessorStateManagerTest {
     @Test
     public void shouldGetChangelogPartitionForRegisteredStore() {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
-        stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
+        stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
 
         final TopicPartition changelogPartition = stateMgr.registeredChangelogPartitionFor(persistentStoreName);
 
@@ -531,7 +537,7 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
         final String storeName = "store-with-logging-disabled";
         final MockKeyValueStore storeWithLoggingDisabled = new MockKeyValueStore(storeName, true);
-        stateMgr.registerStore(storeWithLoggingDisabled, null);
+        stateMgr.registerStore(storeWithLoggingDisabled, null, null);
 
         assertThrows("Registered state store " + storeName
                 + " does not have a registered changelog partition."
@@ -556,8 +562,8 @@ public class ProcessorStateManagerTest {
             // make sure the checkpoint file is not written yet
             assertFalse(checkpointFile.exists());
 
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
-            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
+            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback, null);
         } finally {
             stateMgr.flush();
 
@@ -590,7 +596,7 @@ public class ProcessorStateManagerTest {
 
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
             stateMgr.initializeStoreOffsetsFromCheckpoint(true);
 
             final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition);
@@ -620,7 +626,7 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
 
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
             stateMgr.initializeStoreOffsetsFromCheckpoint(true);
 
             final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition);
@@ -642,7 +648,7 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
 
         try {
-            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback, null);
             stateMgr.initializeStoreOffsetsFromCheckpoint(true);
 
             final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(nonPersistentStorePartition);
@@ -671,7 +677,7 @@ public class ProcessorStateManagerTest {
             emptySet());
 
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
 
             stateMgr.updateChangelogOffsets(singletonMap(persistentStorePartition, 987L));
             stateMgr.checkpoint();
@@ -688,17 +694,17 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE);
 
         assertThrows(IllegalArgumentException.class, () ->
-            stateManager.registerStore(new MockKeyValueStore(CHECKPOINT_FILE_NAME, true), null));
+            stateManager.registerStore(new MockKeyValueStore(CHECKPOINT_FILE_NAME, true), null, null));
     }
 
     @Test
     public void shouldThrowIllegalArgumentExceptionOnRegisterWhenStoreHasAlreadyBeenRegistered() {
         final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE);
 
-        stateManager.registerStore(persistentStore, persistentStore.stateRestoreCallback);
+        stateManager.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
 
         assertThrows(IllegalArgumentException.class, () ->
-            stateManager.registerStore(persistentStore, persistentStore.stateRestoreCallback));
+            stateManager.registerStore(persistentStore, persistentStore.stateRestoreCallback, null));
     }
 
     @Test
@@ -711,7 +717,7 @@ public class ProcessorStateManagerTest {
                 throw exception;
             }
         };
-        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback);
+        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback, null);
 
         final ProcessorStateException thrown = assertThrows(ProcessorStateException.class, stateManager::flush);
         assertEquals(exception, thrown.getCause());
@@ -727,7 +733,7 @@ public class ProcessorStateManagerTest {
                 throw exception;
             }
         };
-        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback);
+        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback, null);
 
         final StreamsException thrown = assertThrows(StreamsException.class, stateManager::flush);
         assertEquals(exception, thrown);
@@ -743,7 +749,7 @@ public class ProcessorStateManagerTest {
                 throw exception;
             }
         };
-        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback);
+        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback, null);
 
         final ProcessorStateException thrown = assertThrows(ProcessorStateException.class, stateManager::close);
         assertEquals(exception, thrown.getCause());
@@ -759,7 +765,7 @@ public class ProcessorStateManagerTest {
                 throw exception;
             }
         };
-        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback);
+        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback, null);
 
         final StreamsException thrown = assertThrows(StreamsException.class, stateManager::close);
         assertEquals(exception, thrown);
@@ -776,7 +782,7 @@ public class ProcessorStateManagerTest {
     @Test
     public void shouldLogAWarningIfCheckpointThrowsAnIOException() {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
-        stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
+        stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
         stateDirectory.clean();
 
         try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(ProcessorStateManager.class)) {
@@ -805,7 +811,7 @@ public class ProcessorStateManagerTest {
     public void shouldThrowIfLoadCheckpointThrows() throws Exception {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
 
-        stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
+        stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
         final File file = new File(stateMgr.baseDir(), CHECKPOINT_FILE_NAME);
         file.createNewFile();
         final FileWriter writer = new FileWriter(file);
@@ -824,9 +830,13 @@ public class ProcessorStateManagerTest {
     public void shouldThrowIfRestoreCallbackThrows() {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
 
-        stateMgr.registerStore(persistentStore, (key, value) -> {
-            throw new RuntimeException("KABOOM!");
-        });
+        stateMgr.registerStore(
+            persistentStore,
+            (key, value) -> {
+                throw new RuntimeException("KABOOM!");
+            },
+            null
+        );
 
         final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition);
 
@@ -856,8 +866,8 @@ public class ProcessorStateManagerTest {
         };
         final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE);
 
-        stateManager.registerStore(stateStore1, stateStore1.stateRestoreCallback);
-        stateManager.registerStore(stateStore2, stateStore2.stateRestoreCallback);
+        stateManager.registerStore(stateStore1, stateStore1.stateRestoreCallback, null);
+        stateManager.registerStore(stateStore2, stateStore2.stateRestoreCallback, null);
 
         try {
             stateManager.flush();
@@ -884,8 +894,8 @@ public class ProcessorStateManagerTest {
         };
         final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE);
 
-        stateManager.registerStore(stateStore1, stateStore1.stateRestoreCallback);
-        stateManager.registerStore(stateStore2, stateStore2.stateRestoreCallback);
+        stateManager.registerStore(stateStore1, stateStore1.stateRestoreCallback, null);
+        stateManager.registerStore(stateStore2, stateStore2.stateRestoreCallback, null);
 
         try {
             stateManager.close();
@@ -908,9 +918,9 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true);
 
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
-            stateMgr.registerStore(persistentStoreTwo, persistentStoreTwo.stateRestoreCallback);
-            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
+            stateMgr.registerStore(persistentStoreTwo, persistentStoreTwo.stateRestoreCallback, null);
+            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback, null);
 
             final TaskCorruptedException exception = assertThrows(TaskCorruptedException.class,
                 () -> stateMgr.initializeStoreOffsetsFromCheckpoint(false));
@@ -937,8 +947,8 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true);
 
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
-            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
+            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback, null);
 
             stateMgr.initializeStoreOffsetsFromCheckpoint(false);
         } finally {
@@ -951,8 +961,8 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true);
 
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
-            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
+            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback, null);
             stateMgr.initializeStoreOffsetsFromCheckpoint(true);
 
             assertThat(stateMgr.storeMetadata(nonPersistentStorePartition), notNullValue());
@@ -969,8 +979,8 @@ public class ProcessorStateManagerTest {
             assertNull(stateMgr.storeMetadata(nonPersistentStorePartition));
             assertNull(stateMgr.storeMetadata(persistentStorePartition));
 
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
-            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
+            stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback, null);
 
             // This should not throw a TaskCorruptedException!
             stateMgr.initializeStoreOffsetsFromCheckpoint(false);
@@ -986,7 +996,7 @@ public class ProcessorStateManagerTest {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true);
 
         try {
-            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback);
+            stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback, null);
             stateMgr.markChangelogAsCorrupted(mkSet(persistentStorePartition));
 
             final ProcessorStateException thrown = assertThrows(ProcessorStateException.class, () -> stateMgr.initializeStoreOffsetsFromCheckpoint(true));
@@ -1033,6 +1043,132 @@
         assertTrue(checkpointFile.exists());
     }
 
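+    // Verifies the new CommitCallback hook end to end: stateMgr.checkpoint()
+    // invokes the callback, which writes the store's Position to a .position file.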
+    @Test
+    public void shouldWritePositionCheckpointFile() throws IOException {
+        final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
+        final Position persistentPosition =
+            Position.emptyPosition().withComponent(persistentStoreTopicName, 1, 123L);
+        final File persistentFile = new File(
+            stateDirectory.getOrCreateDirectoryForTask(taskId),
+            "shouldWritePositionCheckpointFile.position"
+        );
+        final StateStorePositionCommit persistentCheckpoint = new StateStorePositionCommit(persistentFile, persistentPosition);
+        stateMgr.registerStore(
+            persistentStore,
+            persistentStore.stateRestoreCallback,
+            persistentCheckpoint
+        );
+
+        assertFalse(persistentCheckpoint.getFile().exists());
+
+        stateMgr.checkpoint();
+
+        assertTrue(persistentCheckpoint.getFile().exists());
+
+        // the checkpoint file should contain an offset from the persistent store only.
+        final Map<TopicPartition, Long> persistentOffsets = persistentCheckpoint.getOffsetCheckpoint()
+                                                                                .read();
+        assertThat(
+            persistentOffsets,
+            is(singletonMap(new TopicPartition(persistentStoreTopicName, 1), 123L))
+        );
+
+        assertEquals(
+            persistentCheckpoint.getStateStorePosition(),
+            persistentCheckpoint.getCheckpointedPosition()
+        );
+
+        stateMgr.close();
+
+        assertTrue(persistentStore.closed);
+    }
+
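+    // An IOException from the CommitCallback during checkpoint() should surface
+    // as a ProcessorStateException that names the offending changelog partition.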
+    @Test
+    public void shouldThrowOnFailureToWritePositionCheckpointFile() throws IOException {
+        final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
+        final CommitCallback persistentCheckpoint = mock(CommitCallback.class);
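+        // EasyMock record phase: script the next onCommit() call to throw.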
+        persistentCheckpoint.onCommit();
+        final IOException ioException = new IOException("asdf");
+        expectLastCall().andThrow(ioException);
+        replay(persistentCheckpoint);
+        stateMgr.registerStore(
+            persistentStore,
+            persistentStore.stateRestoreCallback,
+            persistentCheckpoint
+        );
+
+        final ProcessorStateException processorStateException = assertThrows(
+            ProcessorStateException.class,
+            stateMgr::checkpoint
+        );
+
+        assertThat(
+            processorStateException.getMessage(),
+            containsString(
+                "process-state-manager-test Exception caught while trying to checkpoint store,"
+                    + " changelog partition test-application-My-Topology-persistentStore-changelog-1"
+            )
+        );
+        assertThat(processorStateException.getCause(), is(ioException));
+    }
+
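+    // Reading a position checkpoint file that was never written should yield
+    // Position.emptyPosition() rather than failing.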
+    @Test
+    public void shouldLoadMissingFileAsEmptyPosition() {
+        final Position persistentPosition =
+            Position.emptyPosition().withComponent(persistentStoreTopicName, 1, 123L);
+        final File persistentFile = new File(
+            stateDirectory.getOrCreateDirectoryForTask(taskId),
+            "shouldFailWritingPositionCheckpointFile.position"
+        );
+        final StateStorePositionCommit persistentCheckpoint = new StateStorePositionCommit(persistentFile, persistentPosition);
+
+        assertFalse(persistentCheckpoint.getFile().exists());
+
+        assertEquals(Position.emptyPosition(), persistentCheckpoint.getCheckpointedPosition());
+    }
+
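+    // Test helper: a CommitCallback that persists a fixed Position to an
+    // OffsetCheckpoint file on commit, mimicking what persistent stores do.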
+    public static class StateStorePositionCommit implements CommitCallback {
+        private final File file;
+        private final OffsetCheckpoint checkpointFile;
+        private final Position position;
+
+        public StateStorePositionCommit(final File file, final Position position) {
+            this.file = file;
+            this.checkpointFile = new OffsetCheckpoint(file);
+            this.position = position;
+        }
+
+        public OffsetCheckpoint getOffsetCheckpoint() {
+            return checkpointFile;
+        }
+
+        public File getFile() {
+            return file;
+        }
+
+        public Position getStateStorePosition() {
+            return position;
+        }
+
+        public Position getCheckpointedPosition() {
+            return StoreQueryUtils.readPositionFromCheckpoint(checkpointFile);
+        }
+
+        @Override
+        public void onCommit() throws IOException {
+            StoreQueryUtils.checkpointPosition(checkpointFile, position);
+        }
+    }
+
     private ProcessorStateManager getStateManager(final Task.TaskType taskType, final boolean eosEnabled) {
         return new ProcessorStateManager(
             taskId,
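
For reference, a minimal sketch (not part of this patch) of how a persistent store
could wire up the new three-argument register(...) together with the StoreQueryUtils
helpers exercised above. The class name, the ".position" file suffix, and the empty
restore body are illustrative assumptions, not code from this commit:

    import java.io.File;
    import org.apache.kafka.streams.processor.StateStore;
    import org.apache.kafka.streams.processor.StateStoreContext;
    import org.apache.kafka.streams.query.Position;
    import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
    import org.apache.kafka.streams.state.internals.StoreQueryUtils;

    public class PositionCheckpointingStore implements StateStore {
        private final Position position = Position.emptyPosition();
        private OffsetCheckpoint positionCheckpoint;
        private volatile boolean open = false;

        @Override
        public void init(final StateStoreContext context, final StateStore root) {
            final File positionFile = new File(context.stateDir(), name() + ".position");
            positionCheckpoint = new OffsetCheckpoint(positionFile);
            // On restart, recover the position written by the last commit
            // (an empty Position if the file does not exist yet).
            position.merge(StoreQueryUtils.readPositionFromCheckpoint(positionCheckpoint));
            context.register(
                root,
                (key, value) -> { /* apply a restored changelog record locally */ },
                // CommitCallback: persist the position alongside the offset checkpoint.
                () -> StoreQueryUtils.checkpointPosition(positionCheckpoint, position)
            );
            open = true;
        }

        @Deprecated
        @Override
        public void init(final org.apache.kafka.streams.processor.ProcessorContext context,
                         final StateStore root) {
            // Sketch only; real stores adapt the old context instead of throwing.
            throw new UnsupportedOperationException("use init(StateStoreContext, StateStore)");
        }

        @Override
        public String name() { return "position-checkpointing-store"; }

        @Override
        public void flush() { }

        @Override
        public void close() { open = false; }

        @Override
        public boolean persistent() { return true; }

        @Override
        public boolean isOpen() { return open; }
    }

This mirrors what shouldWritePositionCheckpointFile verifies: the position write rides
along with the regular offset checkpoint at commit time.
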
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerStub.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerStub.java
index 122c992..6f0d18e 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerStub.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerStub.java
@@ -18,6 +18,7 @@ package org.apache.kafka.streams.processor.internals;
 
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 
@@ -34,7 +35,8 @@ public class StateManagerStub implements StateManager {
 
     @Override
     public void registerStore(final StateStore store,
-                              final StateRestoreCallback stateRestoreCallback) {}
+                              final StateRestoreCallback stateRestoreCallback,
+                              final CommitCallback checkpoint) {}
 
     @Override
     public void flush() {}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index 9ef68e3..72dc1bc 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -285,7 +285,7 @@ public class StreamTaskTest {
         stateDirectory = EasyMock.createNiceMock(StateDirectory.class);
         EasyMock.expect(stateDirectory.lock(taskId)).andReturn(false);
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet());
-        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback);
+        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback, null);
         EasyMock.expectLastCall();
         EasyMock.replay(stateDirectory, stateManager);
 
@@ -428,7 +428,7 @@ public class StreamTaskTest {
         EasyMock.expect(stateDirectory.lock(taskId)).andReturn(true);
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(singleton(changelogPartition));
         EasyMock.expect(stateManager.changelogOffsets()).andReturn(singletonMap(changelogPartition, 10L));
-        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback);
+        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback, null);
         EasyMock.expectLastCall();
         EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes();
         EasyMock.replay(stateDirectory, stateManager, recordCollector);
@@ -1473,7 +1473,7 @@ public class StreamTaskTest {
             .andReturn(singletonMap(changelogPartition, 0L))
             .andReturn(singletonMap(changelogPartition, 10L))
             .andReturn(singletonMap(changelogPartition, 12000L));
-        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback);
+        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback, null);
         EasyMock.expectLastCall();
         EasyMock.replay(stateManager, recordCollector);
 
@@ -1494,7 +1494,7 @@ public class StreamTaskTest {
     @Test
     public void shouldNotCheckpointOffsetsOnCommitIfEosIsEnabled() {
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(singleton(changelogPartition));
-        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback);
+        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback, null);
         EasyMock.expectLastCall();
         EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes();
         EasyMock.replay(stateManager, recordCollector);
diff --git a/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java b/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java
index 9ea836b..d34b3c8 100644
--- a/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java
+++ b/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.test;
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.internals.GlobalStateManager;
@@ -58,7 +59,9 @@ public class GlobalStateManagerStub implements GlobalStateManager {
     }
 
     @Override
-    public void registerStore(final StateStore store, final StateRestoreCallback stateRestoreCallback) {}
+    public void registerStore(final StateStore store,
+                              final StateRestoreCallback stateRestoreCallback,
+                              final CommitCallback checkpoint) {}
 
     @Override
     public void flush() {}
diff --git a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
index 1d96d88..650c7f7 100644
--- a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
@@ -28,6 +28,7 @@ import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.processor.Cancellable;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.PunctuationType;
 import org.apache.kafka.streams.processor.Punctuator;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
@@ -289,10 +290,11 @@ public class InternalMockProcessorContext<KOut, VOut>
 
     @Override
     public void register(final StateStore store,
-                         final StateRestoreCallback func) {
+                         final StateRestoreCallback func,
+                         final CommitCallback checkpoint) {
         storeMap.put(store.name(), store);
         restoreFuncs.put(store.name(), func);
-        stateManager().registerStore(store, func);
+        stateManager().registerStore(store, func, checkpoint);
     }
 
     @SuppressWarnings("unchecked")
diff --git a/streams/src/test/java/org/apache/kafka/test/MockInternalNewProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/MockInternalNewProcessorContext.java
index 5fff794..7131e86 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockInternalNewProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockInternalNewProcessorContext.java
@@ -19,6 +19,7 @@ package org.apache.kafka.test;
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
@@ -121,7 +122,16 @@ public class MockInternalNewProcessorContext<KOut, VOut> extends MockProcessorCo
     public void uninitialize() {}
 
     @Override
-    public void register(final StateStore store, final StateRestoreCallback stateRestoreCallback) {
+    public void register(final StateStore store,
+                         final StateRestoreCallback stateRestoreCallback) {
+        addStateStore(store);
+    }
+
+    @Override
+    public void register(final StateStore store,
+                         final StateRestoreCallback stateRestoreCallback,
+                         final CommitCallback checkpoint) {
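+        // Both callbacks are ignored; this mock only tracks the store itself.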
         addStateStore(store);
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java
index f982d23..6f8bcd8 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.test;
 
 import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.MockProcessorContext;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
@@ -133,7 +134,17 @@ public class MockInternalProcessorContext extends MockProcessorContext implement
     }
 
     @Override
-    public void register(final StateStore store, final StateRestoreCallback stateRestoreCallback) {
+    public void register(final StateStore store,
+                         final StateRestoreCallback stateRestoreCallback) {
+        restoreCallbacks.put(store.name(), stateRestoreCallback);
+        super.register(store, stateRestoreCallback);
+    }
+
+    @Override
+    public void register(final StateStore store,
+                         final StateRestoreCallback stateRestoreCallback,
+                         final CommitCallback checkpoint) {
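+        // The CommitCallback is dropped here; the two-argument super call registers the store.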
         restoreCallbacks.put(store.name(), stateRestoreCallback);
         super.register(store, stateRestoreCallback);
     }
diff --git a/streams/src/test/java/org/apache/kafka/test/NoOpProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/NoOpProcessorContext.java
index 7da926e..53231ed 100644
--- a/streams/src/test/java/org/apache/kafka/test/NoOpProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/NoOpProcessorContext.java
@@ -20,6 +20,7 @@ import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.processor.Cancellable;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.PunctuationType;
 import org.apache.kafka.streams.processor.Punctuator;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
@@ -117,7 +118,8 @@ public class NoOpProcessorContext extends AbstractProcessorContext<Object, Objec
 
     @Override
     public void register(final StateStore store,
-                         final StateRestoreCallback stateRestoreCallback) {
+                         final StateRestoreCallback stateRestoreCallback,
+                         final CommitCallback checkpoint) {
     }
 
     @Override
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/api/MockProcessorContext.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/api/MockProcessorContext.java
index 52fe3c0..acd946a 100644
--- a/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/api/MockProcessorContext.java
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/api/MockProcessorContext.java
@@ -28,6 +28,7 @@ import org.apache.kafka.streams.TopologyTestDriver;
 import org.apache.kafka.streams.kstream.Transformer;
 import org.apache.kafka.streams.kstream.ValueTransformer;
 import org.apache.kafka.streams.processor.Cancellable;
+import org.apache.kafka.streams.processor.CommitCallback;
 import org.apache.kafka.streams.processor.PunctuationType;
 import org.apache.kafka.streams.processor.Punctuator;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
@@ -485,7 +486,16 @@ public class MockProcessorContext<KForward, VForward> implements ProcessorContex
             }
 
             @Override
-            public void register(final StateStore store, final StateRestoreCallback stateRestoreCallback) {
+            public void register(final StateStore store,
+                                 final StateRestoreCallback stateRestoreCallback) {
+                register(store, stateRestoreCallback, () -> { });
+            }
+
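+            // Restore and commit callbacks are not invoked by this mock; only the store is tracked.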
+            @Override
+            public void register(final StateStore store,
+                                 final StateRestoreCallback stateRestoreCallback,
+                                 final CommitCallback checkpoint) {
                 stateStores.put(store.name(), store);
             }