You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2020/05/29 18:39:06 UTC

[kafka] branch 2.6 updated: KAFKA-9501: convert between active and standby without closing stores (#8248)

This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch 2.6
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.6 by this push:
     new 2d66972  KAFKA-9501: convert between active and standby without closing stores (#8248)
2d66972 is described below

commit 2d66972345b2c4b7d81870cb68a5fd0063f92aa9
Author: A. Sophie Blee-Goldman <so...@confluent.io>
AuthorDate: Fri May 29 10:48:03 2020 -0700

    KAFKA-9501: convert between active and standby without closing stores (#8248)
    
    This PR has gone through several significant transitions of its own, but here's the latest:
    
    * TaskManager just collects the tasks to transition and delegates to the active/standby task creator to close and recycle the old task and create the new one. If we ever hit an exception during the close, we bail and close all the remaining tasks as dirty.
    
    * The task creators tell the task to "close but recycle state". If this succeeds, the creator tells the recycled processor context and state manager to transition to the new task type.
    
    * During "close and recycle" the task just does a normal clean close, but instead of closing the state manager it tells the state manager to recycle itself: it keeps all of its store information (most importantly the current store offsets) but unregisters the changelogs from the changelog reader.
    
    * The new task will (re-)register its changelogs during initialization, but skip re-registering any stores. It will still read the checkpoint file, but only use the written offsets if the store offsets were not already initialized before the transition.
    
    * To ensure we don't end up with manual compaction disabled for standbys, we have to call the state restore listener's onRestoreEnd for any active stores that are still restoring when they switch to standby.
    
    Reviewers: John Roesler <vv...@apache.org>, Guozhang Wang <wa...@gmail.com>
---
 checkstyle/suppressions.xml                        |   9 +-
 .../org/apache/kafka/streams/StreamsConfig.java    |   4 +-
 .../internals/AbstractProcessorContext.java        |   3 +-
 .../processor/internals/ActiveTaskCreator.java     | 136 ++++++++++-----
 .../processor/internals/ChangelogReader.java       |   7 -
 .../processor/internals/ChangelogRegister.java     |   8 +
 .../internals/GlobalProcessorContextImpl.java      |  18 +-
 .../internals/InternalProcessorContext.java        |  16 ++
 .../processor/internals/ProcessorContextImpl.java  |  72 ++++----
 .../processor/internals/ProcessorStateManager.java | 121 ++++++++++---
 .../streams/processor/internals/StandbyTask.java   |  40 ++++-
 .../processor/internals/StandbyTaskCreator.java    |  74 ++++++--
 .../processor/internals/StateManagerUtil.java      |  20 +--
 .../processor/internals/StoreChangelogReader.java  |  33 +++-
 .../streams/processor/internals/StreamTask.java    |  47 +++--
 .../kafka/streams/processor/internals/Task.java    |   5 +
 .../streams/processor/internals/TaskManager.java   |  49 ++++--
 .../state/internals/CachingKeyValueStore.java      |  22 ++-
 .../state/internals/CachingSessionStore.java       |  24 ++-
 .../state/internals/CachingWindowStore.java        |  32 ++--
 .../InMemoryTimeOrderedKeyValueBuffer.java         |  41 +++--
 .../kafka/streams/TopologyTestDriverWrapper.java   |   2 +-
 .../integration/InternalTopicIntegrationTest.java  |   4 +-
 .../OptimizedKTableIntegrationTest.java            |  39 +----
 .../integration/RestoreIntegrationTest.java        | 188 +++++++++++++++-----
 .../integration/utils/IntegrationTestUtils.java    | 189 +++++++++++++++------
 .../internals/AbstractProcessorContextTest.java    |  13 ++
 .../processor/internals/MockChangelogReader.java   |   2 +-
 .../internals/ProcessorContextImplTest.java        |  14 +-
 .../processor/internals/ProcessorContextTest.java  |   3 +-
 .../internals/ProcessorStateManagerTest.java       |  86 +++++++++-
 .../processor/internals/StandbyTaskTest.java       |  39 ++++-
 .../processor/internals/StateManagerUtilTest.java  |  27 +--
 .../internals/StoreChangelogReaderTest.java        |   6 +-
 .../processor/internals/StreamTaskTest.java        | 100 ++++++++---
 .../processor/internals/TaskManagerTest.java       |  22 +--
 .../AbstractRocksDBSegmentedBytesStoreTest.java    |   2 +-
 .../StreamThreadStateStoreProviderTest.java        |  18 +-
 .../kafka/test/InternalMockProcessorContext.java   |  32 +++-
 .../kafka/test/MockInternalProcessorContext.java   |  14 ++
 .../apache/kafka/test/NoOpProcessorContext.java    |  17 ++
 .../apache/kafka/streams/TopologyTestDriver.java   |  16 +-
 42 files changed, 1149 insertions(+), 465 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 0241e6e..fe7f716 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -143,7 +143,7 @@
               files="(TopologyBuilder|KafkaStreams|KStreamImpl|KTableImpl|StreamThread|StreamTask).java"/>
 
     <suppress checks="MethodLength"
-              files="(KTableImpl|StreamsPartitionAssignor|EosBetaUpgradeIntegrationTest).java"/>
+              files="(EosBetaUpgradeIntegrationTest|KTableImpl|TaskManager).java"/>
 
     <suppress checks="ParameterNumber"
               files="StreamTask.java"/>
@@ -166,12 +166,11 @@
     <suppress checks="CyclomaticComplexity"
               files="EosBetaUpgradeIntegrationTest.java"/>
 
-    <suppress checks="JavaNCSS"
-              files="StreamsPartitionAssignor.java"/>
-    <suppress checks="JavaNCSS"
-              files="EosBetaUpgradeIntegrationTest.java"/>
     <suppress checks="StaticVariableName"
               files="StreamsMetricsImpl.java"/>
+    
+    <suppress checks="JavaNCSS"
+              files="(EosBetaUpgradeIntegrationTest|StreamsPartitionAssignor|TaskManager).java"/>
 
     <suppress checks="NPathComplexity"
               files="(AssignorConfiguration|EosBetaUpgradeIntegrationTest|InternalTopologyBuilder|KafkaStreams|ProcessorStateManager|StreamsPartitionAssignor|StreamThread|TaskManager).java"/>
diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
index 4278cd8..f9e75fe 100644
--- a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
+++ b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
@@ -200,8 +200,8 @@ public class StreamsConfig extends AbstractConfig {
 
     /**
      * Prefix used to isolate {@link Admin admin} configs from other client configs.
-     * It is recommended to use {@link #adminClientPrefix(String)} to add this prefix to {@link ProducerConfig producer
-     * properties}.
+     * It is recommended to use {@link #adminClientPrefix(String)} to add this prefix to {@link AdminClientConfig admin
+     * client properties}.
      */
     @SuppressWarnings("WeakerAccess")
     public static final String ADMIN_CLIENT_PREFIX = "admin.";
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
index 1132708..cfa91f5 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
@@ -39,13 +39,14 @@ public abstract class AbstractProcessorContext implements InternalProcessorConte
     private final StreamsConfig config;
     private final StreamsMetricsImpl metrics;
     private final Serde<?> keySerde;
-    private final ThreadCache cache;
     private final Serde<?> valueSerde;
     private boolean initialized;
     protected ProcessorRecordContext recordContext;
     protected ProcessorNode<?, ?> currentNode;
     private long currentSystemTimeMs;
+
     final StateManager stateManager;
+    protected ThreadCache cache;
 
     public AbstractProcessorContext(final TaskId taskId,
                                     final StreamsConfig config,
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
index 5473d7a..0a1f47e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
@@ -27,6 +27,7 @@ import org.apache.kafka.streams.KafkaClientSupplier;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
 import org.apache.kafka.streams.state.internals.ThreadCache;
@@ -138,9 +139,7 @@ class ActiveTaskCreator {
             final TaskId taskId = newTaskAndPartitions.getKey();
             final Set<TopicPartition> partitions = newTaskAndPartitions.getValue();
 
-            final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName());
-            final String logPrefix = threadIdPrefix + String.format("%s [%s] ", "task", taskId);
-            final LogContext logContext = new LogContext(logPrefix);
+            final LogContext logContext = getLogContext(taskId);
 
             final ProcessorTopology topology = builder.buildSubtopology(taskId.topicGroupId);
 
@@ -155,50 +154,100 @@ class ActiveTaskCreator {
                 partitions
             );
 
-            final StreamsProducer streamsProducer;
-            if (processingMode == StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA) {
-                log.info("Creating producer client for task {}", taskId);
-                streamsProducer = new StreamsProducer(
-                    config,
-                    threadId,
-                    clientSupplier,
-                    taskId,
-                    null,
-                    logContext);
-                taskProducers.put(taskId, streamsProducer);
-            } else {
-                streamsProducer = threadProducer;
-            }
-
-            final RecordCollector recordCollector = new RecordCollectorImpl(
-                logContext,
+            final InternalProcessorContext context = new ProcessorContextImpl(
                 taskId,
-                streamsProducer,
-                config.defaultProductionExceptionHandler(),
-                streamsMetrics
-            );
-
-            final Task task = new StreamTask(
-                taskId,
-                partitions,
-                topology,
-                consumer,
                 config,
-                streamsMetrics,
-                stateDirectory,
-                cache,
-                time,
                 stateManager,
-                recordCollector
+                streamsMetrics,
+                cache
             );
 
-            log.trace("Created task {} with assigned partitions {}", taskId, partitions);
-            createdTasks.add(task);
-            createTaskSensor.record();
+            createdTasks.add(
+                createActiveTask(
+                    taskId,
+                    partitions,
+                    consumer,
+                    logContext,
+                    topology,
+                    stateManager,
+                    context
+                )
+            );
         }
         return createdTasks;
     }
 
+    StreamTask createActiveTaskFromStandby(final StandbyTask standbyTask,
+                                           final Set<TopicPartition> partitions,
+                                           final Consumer<byte[], byte[]> consumer) {
+        final InternalProcessorContext context = standbyTask.processorContext();
+        final ProcessorStateManager stateManager = standbyTask.stateMgr;
+        final LogContext logContext = getLogContext(standbyTask.id);
+
+        standbyTask.closeAndRecycleState();
+        stateManager.transitionTaskType(TaskType.ACTIVE, logContext);
+
+        return createActiveTask(
+            standbyTask.id,
+            partitions,
+            consumer,
+            logContext,
+            builder.buildSubtopology(standbyTask.id.topicGroupId),
+            stateManager,
+            context
+        );
+    }
+
+    private StreamTask createActiveTask(final TaskId taskId,
+                                        final Set<TopicPartition> partitions,
+                                        final Consumer<byte[], byte[]> consumer,
+                                        final LogContext logContext,
+                                        final ProcessorTopology topology,
+                                        final ProcessorStateManager stateManager,
+                                        final InternalProcessorContext context) {
+        final StreamsProducer streamsProducer;
+        if (processingMode == StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA) {
+            log.info("Creating producer client for task {}", taskId);
+            streamsProducer = new StreamsProducer(
+                config,
+                threadId,
+                clientSupplier,
+                taskId,
+                null,
+                logContext);
+            taskProducers.put(taskId, streamsProducer);
+        } else {
+            streamsProducer = threadProducer;
+        }
+
+        final RecordCollector recordCollector = new RecordCollectorImpl(
+            logContext,
+            taskId,
+            streamsProducer,
+            config.defaultProductionExceptionHandler(),
+            streamsMetrics
+        );
+
+        final StreamTask task = new StreamTask(
+            taskId,
+            partitions,
+            topology,
+            consumer,
+            config,
+            streamsMetrics,
+            stateDirectory,
+            cache,
+            time,
+            stateManager,
+            recordCollector,
+            context
+        );
+
+        log.trace("Created task {} with assigned partitions {}", taskId, partitions);
+        createTaskSensor.record();
+        return task;
+    }
+
     void closeThreadProducerIfNeeded() {
         if (threadProducer != null) {
             try {
@@ -225,8 +274,8 @@ class ActiveTaskCreator {
         // and the producer object passed in here will be null. We would then iterate through
         // all the active tasks and add their metrics to the output metrics map.
         final Collection<StreamsProducer> producers = threadProducer != null ?
-                Collections.singleton(threadProducer) :
-                taskProducers.values();
+            Collections.singleton(threadProducer) :
+            taskProducers.values();
         return ClientUtils.producerMetrics(producers);
     }
 
@@ -240,4 +289,11 @@ class ActiveTaskCreator {
                                 .collect(Collectors.toSet());
         }
     }
+
+    private LogContext getLogContext(final TaskId taskId) {
+        final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName());
+        final String logPrefix = threadIdPrefix + String.format("%s [%s] ", "task", taskId);
+        return new LogContext(logPrefix);
+    }
+
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
index 815b842..2211591 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java
@@ -18,7 +18,6 @@ package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.TopicPartition;
 
-import java.util.Collection;
 import java.util.Set;
 
 /**
@@ -46,12 +45,6 @@ public interface ChangelogReader extends ChangelogRegister {
     Set<TopicPartition> completedChangelogs();
 
     /**
-     * Removes the passed in partitions from the set of changelogs
-     * @param revokedPartitions the set of partitions to remove
-     */
-    void remove(Collection<TopicPartition> revokedPartitions);
-
-    /**
      * Clear all partitions
      */
     void clear();
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRegister.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRegister.java
index ecdb265..cdddd20 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRegister.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRegister.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.util.Collection;
 import org.apache.kafka.common.TopicPartition;
 
 /**
@@ -29,4 +30,11 @@ interface ChangelogRegister {
      * @param stateManager the state manager used for restoring (one per task)
      */
     void register(final TopicPartition partition, final ProcessorStateManager stateManager);
+
+    /**
+     * Unregisters and removes the passed in partitions from the set of changelogs
+     * @param removedPartitions the set of partitions to remove
+     * @param triggerOnRestoreEnd whether to trigger the onRestoreEnd callback
+     */
+    void unregister(final Collection<TopicPartition> removedPartitions, final boolean triggerOnRestoreEnd);
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalProcessorContextImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalProcessorContextImpl.java
index 81169d3..480e4a0 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalProcessorContextImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalProcessorContextImpl.java
@@ -30,6 +30,7 @@ import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.internals.ThreadCache;
 
 import java.time.Duration;
+import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener;
 
 public class GlobalProcessorContextImpl extends AbstractProcessorContext {
 
@@ -51,7 +52,7 @@ public class GlobalProcessorContextImpl extends AbstractProcessorContext {
     public <K, V> void forward(final K key, final V value) {
         final ProcessorNode<?, ?> previousNode = currentNode();
         try {
-            for (final ProcessorNode<?, ?> child :  currentNode().children()) {
+            for (final ProcessorNode<?, ?> child : currentNode().children()) {
                 setCurrentNode(child);
                 ((ProcessorNode<K, V>) child).process(key, value);
             }
@@ -117,4 +118,19 @@ public class GlobalProcessorContextImpl extends AbstractProcessorContext {
                           final long timestamp) {
         throw new UnsupportedOperationException("this should not happen: logChange() not supported in global processor context.");
     }
+
+    @Override
+    public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) {
+        throw new UnsupportedOperationException("this should not happen: transitionToActive() not supported in global processor context.");
+    }
+
+    @Override
+    public void transitionToStandby(final ThreadCache newCache) {
+        throw new UnsupportedOperationException("this should not happen: transitionToStandby() not supported in global processor context.");
+    }
+
+    @Override
+    public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) {
+        cache.addDirtyEntryFlushListener(namespace, listener);
+    }
 }
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
index db5cfc9..572f18f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
@@ -26,6 +26,7 @@ import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.StoreBuilder;
 import org.apache.kafka.streams.state.internals.ThreadCache;
+import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener;
 
 /**
  * For internal use so we can update the {@link RecordContext} and current
@@ -91,6 +92,21 @@ public interface InternalProcessorContext extends ProcessorContext {
     TaskType taskType();
 
     /**
+     * Transition to active task and register a new task and cache to this processor context
+     */
+    void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache);
+
+    /**
+     * Transition to standby task and register a dummy cache to this processor context
+     */
+    void transitionToStandby(final ThreadCache newCache);
+
+    /**
+     * Register a dirty entry flush listener for a particular namespace
+     */
+    void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener);
+
+    /**
      * Get a correctly typed state store, given a handle on the original builder.
      */
     @SuppressWarnings("unchecked")
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
index 19bbbdf..a58a862 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
@@ -19,7 +19,6 @@ package org.apache.kafka.streams.processor.internals;
 import java.util.HashMap;
 import java.util.Map;
 import org.apache.kafka.common.utils.Bytes;
-import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.internals.ApiUtils;
@@ -36,54 +35,65 @@ import org.apache.kafka.streams.state.internals.ThreadCache;
 
 import java.time.Duration;
 import java.util.List;
+import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener;
 
 import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix;
 import static org.apache.kafka.streams.processor.internals.AbstractReadOnlyDecorator.getReadOnlyStore;
 import static org.apache.kafka.streams.processor.internals.AbstractReadWriteDecorator.getReadWriteStore;
 
 public class ProcessorContextImpl extends AbstractProcessorContext implements RecordCollector.Supplier {
-    // The below are both null for standby tasks
-    private final StreamTask streamTask;
-    private final RecordCollector collector;
+    // the below are null for standby tasks
+    private StreamTask streamTask;
+    private RecordCollector collector;
 
     private final ToInternal toInternal = new ToInternal();
     private final static To SEND_TO_ALL = To.all();
 
     final Map<String, String> storeToChangelogTopic = new HashMap<>();
+    final Map<String, DirtyEntryFlushListener> cacheNameToFlushListener = new HashMap<>();
 
-    ProcessorContextImpl(final TaskId id,
-                         final StreamTask streamTask,
-                         final StreamsConfig config,
-                         final RecordCollector collector,
-                         final ProcessorStateManager stateMgr,
-                         final StreamsMetricsImpl metrics,
-                         final ThreadCache cache) {
+    public ProcessorContextImpl(final TaskId id,
+                                final StreamsConfig config,
+                                final ProcessorStateManager stateMgr,
+                                final StreamsMetricsImpl metrics,
+                                final ThreadCache cache) {
         super(id, config, metrics, stateMgr, cache);
+    }
+
+    @Override
+    public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) {
+        if (stateManager.taskType() != TaskType.ACTIVE) {
+            throw new IllegalStateException("Tried to transition processor context to active but the state manager's " +
+                                                "type was " + stateManager.taskType());
+        }
         this.streamTask = streamTask;
-        this.collector = collector;
+        this.collector = recordCollector;
+        this.cache = newCache;
+        addAllFlushListenersToNewCache();
+    }
 
-        if (streamTask == null && taskType() == TaskType.ACTIVE) {
-            throw new IllegalStateException("Tried to create context for active task but the streamtask was null");
+    @Override
+    public void transitionToStandby(final ThreadCache newCache) {
+        if (stateManager.taskType() != TaskType.STANDBY) {
+            throw new IllegalStateException("Tried to transition processor context to standby but the state manager's " +
+                                                "type was " + stateManager.taskType());
         }
+        this.streamTask = null;
+        this.collector = null;
+        this.cache = newCache;
+        addAllFlushListenersToNewCache();
     }
 
-    ProcessorContextImpl(final TaskId id,
-                         final StreamsConfig config,
-                         final ProcessorStateManager stateMgr,
-                         final StreamsMetricsImpl metrics) {
-        this(
-            id,
-            null,
-            config,
-            null,
-            stateMgr,
-            metrics,
-            new ThreadCache(
-                new LogContext(String.format("stream-thread [%s] ", Thread.currentThread().getName())),
-                0,
-                metrics
-            )
-        );
+    @Override
+    public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) {
+        cacheNameToFlushListener.put(namespace, listener);
+        cache.addDirtyEntryFlushListener(namespace, listener);
+    }
+
+    private void addAllFlushListenersToNewCache() {
+        for (final Map.Entry<String, DirtyEntryFlushListener> cacheEntry : cacheNameToFlushListener.entrySet()) {
+            cache.addDirtyEntryFlushListener(cacheEntry.getKey(), cacheEntry.getValue());
+        }
     }
 
     public ProcessorStateManager stateManager() {
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
index dbd2b72..f00284f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.util.ArrayList;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.FixedOrderMap;
@@ -142,10 +143,10 @@ public class ProcessorStateManager implements StateManager {
 
     private static final String STATE_CHANGELOG_TOPIC_SUFFIX = "-changelog";
 
-    private final Logger log;
+    private Logger log;
+    private String logPrefix;
+
     private final TaskId taskId;
-    private final String logPrefix;
-    private final TaskType taskType;
     private final boolean eosEnabled;
     private final ChangelogRegister changelogReader;
     private final Map<String, String> storeToChangelogTopic;
@@ -158,6 +159,8 @@ public class ProcessorStateManager implements StateManager {
     private final File baseDir;
     private final OffsetCheckpoint checkpointFile;
 
+    private TaskType taskType;
+
     public static String storeChangelogTopic(final String applicationId, final String storeName) {
         return applicationId + "-" + storeName + STATE_CHANGELOG_TOPIC_SUFFIX;
     }
@@ -189,6 +192,18 @@ public class ProcessorStateManager implements StateManager {
         log.debug("Created state store manager for task {}", taskId);
     }
 
+    void registerStateStores(final List<StateStore> allStores, final InternalProcessorContext processorContext) {
+        processorContext.uninitialize();
+        for (final StateStore store : allStores) {
+            if (stores.containsKey(store.name())) {
+                maybeRegisterStoreWithChangelogReader(store.name());
+            } else {
+                store.init(processorContext, store);
+            }
+            log.trace("Registered state store {}", store.name());
+        }
+    }
+
     void registerGlobalStateStores(final List<StateStore> stateStores) {
         log.debug("Register global stores {}", stateStores);
         for (final StateStore stateStore : stateStores) {
@@ -211,12 +226,12 @@ public class ProcessorStateManager implements StateManager {
             for (final StateStoreMetadata store : stores.values()) {
                 if (store.changelogPartition == null) {
                     log.info("State store {} is not logged and hence would not be restored", store.stateStore.name());
-                } else {
+                } else if (store.offset() == null) {
                     if (loadedCheckpoints.containsKey(store.changelogPartition)) {
                         store.setOffset(loadedCheckpoints.remove(store.changelogPartition));
 
                         log.debug("State store {} initialized from checkpoint with offset {} at changelog {}",
-                            store.stateStore.name(), store.offset, store.changelogPartition);
+                                  store.stateStore.name(), store.offset, store.changelogPartition);
                     } else {
                         // with EOS, if the previous run did not shutdown gracefully, we may lost the checkpoint file
                         // and hence we are uncertain that the current local state only contains committed data;
@@ -234,6 +249,9 @@ public class ProcessorStateManager implements StateManager {
                                 store.stateStore.name(), store.changelogPartition);
                         }
                     }
+                }  else {
+                    log.debug("Skipping re-initialization of offset from checkpoint for recycled store {}",
+                              store.stateStore.name());
                 }
             }
 
@@ -251,6 +269,22 @@ public class ProcessorStateManager implements StateManager {
         }
     }
 
+    private void maybeRegisterStoreWithChangelogReader(final String storeName) {
+        if (isLoggingEnabled(storeName)) {
+            changelogReader.register(getStorePartition(storeName), this);
+        }
+    }
+
+    private List<TopicPartition> getAllChangelogTopicPartitions() {
+        final List<TopicPartition> allChangelogPartitions = new ArrayList<>();
+        for (final StateStoreMetadata storeMetadata : stores.values()) {
+            if (storeMetadata.changelogPartition != null) {
+                allChangelogPartitions.add(storeMetadata.changelogPartition);
+            }
+        }
+        return allChangelogPartitions;
+    }
+
     @Override
     public File baseDir() {
         return baseDir;
@@ -269,26 +303,18 @@ public class ProcessorStateManager implements StateManager {
             throw new IllegalArgumentException(format("%sStore %s has already been registered.", logPrefix, storeName));
         }
 
-        final String topic = storeToChangelogTopic.get(storeName);
-
-        // if the store name does not exist in the changelog map, it means the underlying store
-        // is not log enabled (including global stores), and hence it does not need to be restored
-        if (topic != null) {
-            // NOTE we assume the partition of the topic can always be inferred from the task id;
-            // if user ever use a custom partition grouper (deprecated in KIP-528) this would break and
-            // it is not a regression (it would always break anyways)
-            final TopicPartition storePartition = new TopicPartition(topic, taskId.partition);
-            final StateStoreMetadata storeMetadata = new StateStoreMetadata(
+        // stores that are not logged (including global stores) have no changelog and hence need no restoration
+        final StateStoreMetadata storeMetadata = isLoggingEnabled(storeName) ?
+            new StateStoreMetadata(
                 store,
-                storePartition,
+                getStorePartition(storeName),
                 stateRestoreCallback,
-                converterForStore(store));
-            stores.put(storeName, storeMetadata);
+                converterForStore(store)) :
+            new StateStoreMetadata(store);
 
-            changelogReader.register(storePartition, this);
-        } else {
-            stores.put(storeName, new StateStoreMetadata(store));
-        }
+        stores.put(storeName, storeMetadata);
+
+        maybeRegisterStoreWithChangelogReader(storeName);
 
         log.debug("Registered state store {} to its state manager", storeName);
     }
@@ -399,8 +425,9 @@ public class ProcessorStateManager implements StateManager {
         // attempting to flush the stores
         if (!stores.isEmpty()) {
             log.debug("Flushing all stores registered in the state manager: {}", stores);
-            for (final Map.Entry<String, StateStoreMetadata> entry : stores.entrySet()) {
-                final StateStore store = entry.getValue().stateStore;
+            // only the stores themselves are needed, so iterate over the metadata values rather than the entries
+            for (final StateStoreMetadata storeMetadata : stores.values()) {
+                final StateStore store = storeMetadata.stateStore;
                 log.trace("Flushing store {}", store.name());
                 try {
                     store.flush();
@@ -425,17 +451,27 @@ public class ProcessorStateManager implements StateManager {
 
     /**
      * {@link StateStore#close() Close} all stores (even in case of failure).
-     * Log all exception and re-throw the first exception that did occur at the end.
+     * Log all exceptions and re-throw the first exception that occurred at the end.
      *
      * @throws ProcessorStateException if any error happens when closing the state stores
      */
     @Override
     public void close() throws ProcessorStateException {
+        log.debug("Closing its state manager and all the registered state stores: {}", stores);
+
         RuntimeException firstException = null;
+        try {
+            changelogReader.unregister(getAllChangelogTopicPartitions(), false);
+        } catch (final RuntimeException exception) {
+            // keep going so that the stores below still get closed; the error is kept as the first
+            // exception and hence re-thrown at the end
+            log.error("Failed to unregister the changelogs of the state stores: ", exception);
+            firstException = exception;
+        }
+
         // attempting to close the stores, just in case they
         // are not closed by a ProcessorNode yet
         if (!stores.isEmpty()) {
-            log.debug("Closing its state manager and all the registered state stores: {}", stores);
-            for (final Map.Entry<String, StateStoreMetadata> entry : stores.entrySet()) {
-                final StateStore store = entry.getValue().stateStore;
+            for (final StateStoreMetadata metadata : stores.values()) {
+                final StateStore store = metadata.stateStore;
                 log.trace("Closing store {}", store.name());
@@ -462,6 +491,39 @@ public class ProcessorStateManager implements StateManager {
         }
     }
 
+    /**
+     * Alternative to {@link #close()} that just resets the changelogs without closing any of the underlying state
+     * or unregistering the stores themselves.
+     * <p>
+     * For an active task the changelog reader is also asked to invoke {@code onRestoreEnd} for stores that are
+     * still restoring, so that e.g. bulk loading gets disabled again before the stores move to a standby task.
+     */
+    void recycle() {
+        log.debug("Recycling state for {} task {}.", taskType, taskId);
+
+        final boolean triggerOnRestoreEnd = taskType.equals(TaskType.ACTIVE);
+        changelogReader.unregister(getAllChangelogTopicPartitions(), triggerOnRestoreEnd);
+    }
+
+    /**
+     * Switch this (recycled) state manager over to the given task type and adopt the given log context.
+     *
+     * @throws IllegalStateException if {@code newType} is the same as the current task type
+     */
+    void transitionTaskType(final TaskType newType, final LogContext logContext) {
+        if (taskType.equals(newType)) {
+            throw new IllegalStateException("Tried to recycle state for task type conversion but new type was the same.");
+        }
+
+        // remember the previous type so that the log line below reports the actual transition
+        final TaskType oldType = taskType;
+        taskType = newType;
+        log = logContext.logger(ProcessorStateManager.class);
+        logPrefix = logContext.logPrefix();
+
+        log.debug("Transitioning state manager for {} task {} to {}", oldType, taskId, newType);
+    }
+
     @Override
     public void checkpoint(final Map<TopicPartition, Long> writtenOffsets) {
         // first update each state store's current offset, then checkpoint
@@ -497,6 +553,26 @@ public class ProcessorStateManager implements StateManager {
         }
     }
 
+    /**
+     * Only meaningful for logged stores (see {@link #isLoggingEnabled(String)}); for any other store the
+     * changelog topic lookup yields {@code null}.
+     */
+    private TopicPartition getStorePartition(final String storeName) {
+        // NOTE we assume the partition of the topic can always be inferred from the task id;
+        // if users ever use a custom partition grouper (deprecated in KIP-528) this would break and
+        // it is not a regression (it would always break anyway)
+        return new TopicPartition(storeToChangelogTopic.get(storeName), taskId.partition);
+    }
+
+    /**
+     * @return whether the store has a changelog topic, i.e. whether it is logging-enabled
+     */
+    private boolean isLoggingEnabled(final String storeName) {
+        // if the store name does not exist in the changelog map, it means the underlying store
+        // is not log enabled (including global stores)
+        return storeToChangelogTopic.containsKey(storeName);
+    }
+
     private StateStoreMetadata findStore(final TopicPartition changelogPartition) {
         final List<StateStoreMetadata> found = stores.values().stream()
             .filter(metadata -> changelogPartition.equals(metadata.changelogPartition))
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
index b4abd79..cc922ab 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
@@ -27,6 +27,7 @@ import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
+import org.apache.kafka.streams.state.internals.ThreadCache;
 import org.slf4j.Logger;
 
 import java.util.Collections;
@@ -62,15 +63,19 @@ public class StandbyTask extends AbstractTask implements Task {
                 final StreamsConfig config,
                 final StreamsMetricsImpl metrics,
                 final ProcessorStateManager stateMgr,
-                final StateDirectory stateDirectory) {
+                final StateDirectory stateDirectory,
+                final ThreadCache cache,
+                final InternalProcessorContext processorContext) {
         super(id, topology, stateDirectory, stateMgr, partitions);
+        // the context is either newly created or recycled from an active task; either way switch it to standby mode
+        this.processorContext = processorContext;
+        processorContext.transitionToStandby(cache);
 
         final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName());
         logPrefix = threadIdPrefix + String.format("%s [%s] ", "standby-task", id);
         final LogContext logContext = new LogContext(logPrefix);
         log = logContext.logger(getClass());
 
-        processorContext = new ProcessorContextImpl(id, config, stateMgr, metrics);
         closeTaskSensor = ThreadMetrics.closeTaskSensor(Thread.currentThread().getName(), metrics);
         eosEnabled = StreamThread.eosEnabled(config);
     }
@@ -201,6 +205,32 @@ public class StandbyTask extends AbstractTask implements Task {
         log.info("Closed dirty");
     }
 
+    /**
+     * Close this standby task cleanly and checkpoint its current offsets, but recycle its state manager
+     * instead of closing the underlying state stores.
+     *
+     * @throws IllegalStateException if the task is neither CREATED nor RUNNING
+     */
+    @Override
+    public void closeAndRecycleState() {
+        prepareClose(true);
+
+        if (state() != State.CREATED && state() != State.RUNNING) {
+            throw new IllegalStateException("Illegal state " + state() + " while closing standby task " + id);
+        }
+
+        // standby tasks never have written offsets, so an empty map is passed and the stores'
+        // current offsets end up in the checkpoint
+        stateMgr.checkpoint(Collections.emptyMap());
+        offsetSnapshotSinceLastCommit = new HashMap<>(stateMgr.changelogOffsets());
+        stateMgr.recycle();
+
+        closeTaskSensor.record();
+        transitionTo(State.CLOSED);
+
+        log.info("Closed clean and recycled state");
+    }
+
     private void close(final boolean clean) {
         if (state() == State.CREATED || state() == State.RUNNING) {
             if (clean) {
@@ -209,15 +233,18 @@ public class StandbyTask extends AbstractTask implements Task {
                 stateMgr.checkpoint(Collections.emptyMap());
                 offsetSnapshotSinceLastCommit = new HashMap<>(stateMgr.changelogOffsets());
             }
-            executeAndMaybeSwallow(clean, () ->
-                StateManagerUtil.closeStateManager(
+            // unlike closeAndRecycleState(), a regular close shuts the state manager down instead of recycling it
+            executeAndMaybeSwallow(
+                clean,
+                () -> StateManagerUtil.closeStateManager(
                     log,
                     logPrefix,
                     clean,
                     eosEnabled,
                     stateMgr,
                     stateDirectory,
-                    TaskType.STANDBY),
+                    TaskType.STANDBY
+                ),
                 "state manager close",
                 log
             );
@@ -245,6 +271,13 @@ public class StandbyTask extends AbstractTask implements Task {
         throw new IllegalStateException("Attempted to add records to task " + id() + " for invalid input partition " + partition);
     }
 
+    /**
+     * @return the processor context of this task, exposed so that it can be reused when the task changes type
+     */
+    InternalProcessorContext processorContext() {
+        return processorContext;
+    }
+
     /**
      * Produces a string representation containing useful information about a Task.
      * This is useful in debugging scenarios.
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
index 2cbb3ea..443db8e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
@@ -21,8 +21,10 @@ import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics;
+import org.apache.kafka.streams.state.internals.ThreadCache;
 import org.slf4j.Logger;
 
 import java.util.ArrayList;
@@ -37,6 +39,7 @@ class StandbyTaskCreator {
     private final StreamsMetricsImpl streamsMetrics;
     private final StateDirectory stateDirectory;
     private final ChangelogReader storeChangelogReader;
+    private final ThreadCache dummyCache;
     private final Logger log;
     private final Sensor createTaskSensor;
 
@@ -53,7 +56,17 @@ class StandbyTaskCreator {
         this.stateDirectory = stateDirectory;
         this.storeChangelogReader = storeChangelogReader;
         this.log = log;
+
         createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics);
+
+        // zero-sized placeholder cache handed to the processor contexts and tasks of all standby tasks created here.
+        // The owning stream thread is named via threadId rather than Thread.currentThread(), because this
+        // constructor is not necessarily invoked on the stream thread (NOTE(review): verify against StreamThread#create)
+        dummyCache = new ThreadCache(
+            new LogContext(String.format("stream-thread [%s] ", threadId)),
+            0,
+            streamsMetrics
+        );
     }
 
     Collection<Task> createTasks(final Map<TaskId, Set<TopicPartition>> tasksToBeCreated) {
@@ -62,10 +72,6 @@ class StandbyTaskCreator {
             final TaskId taskId = newTaskAndPartitions.getKey();
             final Set<TopicPartition> partitions = newTaskAndPartitions.getValue();
 
-            final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName());
-            final String logPrefix = threadIdPrefix + String.format("%s [%s] ", "standby-task", taskId);
-            final LogContext logContext = new LogContext(logPrefix);
-
             final ProcessorTopology topology = builder.buildSubtopology(taskId.topicGroupId);
 
             if (topology.hasStateWithChangelogs()) {
@@ -73,26 +79,22 @@ class StandbyTaskCreator {
                     taskId,
                     Task.TaskType.STANDBY,
                     StreamThread.eosEnabled(config),
-                    logContext,
+                    getLogContext(taskId),
                     stateDirectory,
                     storeChangelogReader,
                     topology.storeToChangelogTopic(),
                     partitions
                 );
 
-                final StandbyTask task = new StandbyTask(
+                final InternalProcessorContext context = new ProcessorContextImpl(
                     taskId,
-                    partitions,
-                    topology,
                     config,
-                    streamsMetrics,
                     stateManager,
-                    stateDirectory
+                    streamsMetrics,
+                    dummyCache
                 );
 
-                log.trace("Created task {} with assigned partitions {}", taskId, partitions);
-                createdTasks.add(task);
-                createTaskSensor.record();
+                createdTasks.add(createStandbyTask(taskId, partitions, topology, stateManager, context));
             } else {
                 log.trace(
                     "Skipped standby task {} with assigned partitions {} " +
@@ -100,10 +102,59 @@ class StandbyTaskCreator {
                     taskId, partitions
                 );
             }
         }
         return createdTasks;
     }
 
+    /**
+     * Convert an active task into a standby task without closing its state stores: the active task is closed
+     * cleanly but its state manager is recycled, and the new standby task reuses that state manager as well as
+     * the active task's processor context.
+     */
+    StandbyTask createStandbyTaskFromActive(final StreamTask streamTask,
+                                            final Set<TopicPartition> partitions) {
+        final TaskId taskId = streamTask.id();
+        final InternalProcessorContext context = streamTask.processorContext();
+        final ProcessorStateManager stateManager = streamTask.stateMgr;
+
+        streamTask.closeAndRecycleState();
+        stateManager.transitionTaskType(TaskType.STANDBY, getLogContext(taskId));
+
+        return createStandbyTask(
+            taskId,
+            partitions,
+            builder.buildSubtopology(taskId.topicGroupId),
+            stateManager,
+            context
+        );
+    }
+
+    /**
+     * Create a standby task around an already constructed state manager and processor context, and record
+     * the creation in the task-creation sensor.
+     */
+    StandbyTask createStandbyTask(final TaskId taskId,
+                                  final Set<TopicPartition> partitions,
+                                  final ProcessorTopology topology,
+                                  final ProcessorStateManager stateManager,
+                                  final InternalProcessorContext context) {
+        final StandbyTask task = new StandbyTask(
+            taskId,
+            partitions,
+            topology,
+            config,
+            streamsMetrics,
+            stateManager,
+            stateDirectory,
+            dummyCache,
+            context
+        );
+
+        log.trace("Created task {} with assigned partitions {}", taskId, partitions);
+        createTaskSensor.record();
+        return task;
+    }
+
     public InternalTopologyBuilder builder() {
         return builder;
     }
@@ -111,4 +153,13 @@ class StandbyTaskCreator {
     public StateDirectory stateDirectory() {
         return stateDirectory;
     }
+
+    /**
+     * Build the log context for a standby task, prefixed with the current thread's name and the task id.
+     */
+    private LogContext getLogContext(final TaskId taskId) {
+        final String threadPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName());
+        final String taskPrefix = String.format("%s [%s] ", "standby-task", taskId);
+        return new LogContext(threadPrefix + taskPrefix);
+    }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManagerUtil.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManagerUtil.java
index 17fc095..422b124 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManagerUtil.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManagerUtil.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.io.IOException;
 import java.util.concurrent.atomic.AtomicReference;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.errors.LockException;
@@ -27,8 +28,6 @@ import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.streams.state.internals.RecordConverter;
 import org.slf4j.Logger;
 
-import java.io.IOException;
-
 import static org.apache.kafka.streams.state.internals.RecordConverters.identity;
 import static org.apache.kafka.streams.state.internals.RecordConverters.rawValueToTimestampedValue;
 import static org.apache.kafka.streams.state.internals.WrappedStateStore.isTimestamped;
@@ -47,7 +46,7 @@ final class StateManagerUtil {
     }
 
     /**
-     * @throws StreamsException If the store's change log does not contain the partition
+     * @throws StreamsException If the store's changelog does not contain the partition
      */
     static void registerStateStores(final Logger log,
                                     final String logPrefix,
@@ -74,21 +73,14 @@ final class StateManagerUtil {
 
         final boolean storeDirsEmpty = stateDirectory.directoryForTaskIsEmpty(id);
 
+        // stores that are already registered (e.g. kept open while recycling a task of the other type) are
+        // expected to be skipped here -- NOTE(review): verify in ProcessorStateManager#registerStateStores
         stateMgr.registerStateStores(topology.stateStores(), processorContext);
         log.debug("Registered state stores");
 
         // We should only load checkpoint AFTER the corresponding state directory lock has been acquired and
         // the state stores have been registered; we should not try to load at the state manager construction time.
         // See https://issues.apache.org/jira/browse/KAFKA-8574
-
-        for (final StateStore store : topology.stateStores()) {
-            if (stateMgr.getStore(store.name()) != null) {
-                log.warn("Skip the registration of store {} since it is already registered. This could be due " +
-                    "to a half-way registration in the previous round of initialization.", store.name());
-                continue;
-            }
-            processorContext.uninitialize();
-            store.init(processorContext, store);
-            log.trace("Registered state store {}", store.name());
-        }
-
         stateMgr.initializeStoreOffsetsFromCheckpoint(storeDirsEmpty);
         log.debug("Initialized state stores");
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
index aed6c54..c712bd3 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.util.concurrent.atomic.AtomicReference;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
@@ -546,7 +547,7 @@ public class StoreChangelogReader implements ChangelogReader {
             pauseChangelogsFromRestoreConsumer(Collections.singleton(partition));
 
             try {
-                // first trigger the store's specific listener if its registered callback is also an lister,
+                // first trigger the store's specific listener if its registered callback is also a listener,
                 // then trigger the user registered global listener
                 final StateRestoreCallback restoreCallback = storeMetadata.restoreCallback();
                 if (restoreCallback instanceof StateRestoreListener) {
@@ -811,17 +812,49 @@ public class StoreChangelogReader implements ChangelogReader {
         }
     }
 
+    /**
+     * Invoke {@code onRestoreEnd} on the store's own restore listener, if its restore callback is one, for a
+     * changelog that is unregistered while still restoring.
+     *
+     * @return the exception thrown by the listener, or {@code null} if it completed normally or there is no listener
+     */
+    private RuntimeException invokeOnRestoreEnd(final TopicPartition partition,
+                                                final ChangelogMetadata changelogMetadata) {
+        // only trigger the store's specific listener to make sure we disable bulk loading before transition to standby
+        final StateStoreMetadata storeMetadata = changelogMetadata.storeMetadata;
+        final StateRestoreCallback restoreCallback = storeMetadata.restoreCallback();
+        final String storeName = storeMetadata.store().name();
+        if (restoreCallback instanceof StateRestoreListener) {
+            try {
+                ((StateRestoreListener) restoreCallback).onRestoreEnd(partition, storeName, changelogMetadata.totalRestored);
+            } catch (final RuntimeException e) {
+                return e;
+            }
+        }
+        return null;
+    }
+
     @Override
-    public void remove(final Collection<TopicPartition> revokedChangelogs) {
-        // Only changelogs that are initialized that been added to the restore consumer's assignment
+    public void unregister(final Collection<TopicPartition> revokedChangelogs,
+                           final boolean triggerOnRestoreEnd) {
+        final AtomicReference<RuntimeException> firstException = new AtomicReference<>(null);
+
+        // Only changelogs that are initialized have been added to the restore consumer's assignment
         final List<TopicPartition> revokedInitializedChangelogs = new ArrayList<>();
 
         for (final TopicPartition partition : revokedChangelogs) {
             final ChangelogMetadata changelogMetadata = changelogs.remove(partition);
             if (changelogMetadata != null) {
-                if (changelogMetadata.state() != ChangelogState.REGISTERED) {
+                // if requested (e.g. when an active task is recycled), let the store's own listener finish a
+                // restoration that is cut short; compareAndSet only records the first exception returned
+                if (triggerOnRestoreEnd && changelogMetadata.state().equals(ChangelogState.RESTORING)) {
+                    firstException.compareAndSet(null, invokeOnRestoreEnd(partition, changelogMetadata));
+                }
+
+                if (!changelogMetadata.state().equals(ChangelogState.REGISTERED)) {
                     revokedInitializedChangelogs.add(partition);
                 }
+
                 changelogMetadata.clear();
             } else {
                 log.debug("Changelog partition {} could not be found," +
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index cf1f443..89be8a5 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -33,7 +33,6 @@ import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.processor.Cancellable;
-import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.processor.PunctuationType;
 import org.apache.kafka.streams.processor.Punctuator;
 import org.apache.kafka.streams.processor.TaskId;
@@ -118,10 +117,16 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
                       final ThreadCache cache,
                       final Time time,
                       final ProcessorStateManager stateMgr,
-                      final RecordCollector recordCollector) {
+                      final RecordCollector recordCollector,
+                      final InternalProcessorContext processorContext) {
         super(id, topology, stateDirectory, stateMgr, partitions);
         this.mainConsumer = mainConsumer;
 
+        // the processor context may be recycled from a standby task, so (re-)wire it for active processing;
+        // NOTE(review): this hands out `this` before the rest of the constructor has run
+        this.processorContext = processorContext;
+        processorContext.transitionToActive(this, recordCollector, cache);
+
         final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName());
         logPrefix = threadIdPrefix + String.format("%s [%s] ", "task", id);
         final LogContext logContext = new LogContext(logPrefix);
@@ -169,8 +172,6 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         consumedOffsets = new HashMap<>();
 
         recordQueueCreator = new RecordQueueCreator(logContext, config.defaultTimestampExtractor(), config.defaultDeserializationExceptionHandler());
-        // initialize the topology with its own context
-        processorContext = new ProcessorContextImpl(id, this, config, this.recordCollector, stateMgr, streamsMetrics, cache);
 
         recordInfo = new PartitionGroup.RecordInfo();
         partitionGroup = new PartitionGroup(createPartitionQueues(),
@@ -450,6 +451,40 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         }
     }
 
+    /**
+     * Close this active task cleanly, but recycle its state manager instead of closing the underlying state
+     * stores, so that the stores can be taken over by a standby task for the same task id.
+     *
+     * @throws IllegalStateException if the task is not CREATED, RUNNING, RESTORING or SUSPENDED
+     */
+    @Override
+    public void closeAndRecycleState() {
+        final Map<TopicPartition, Long> checkpoint = prepareClose(true);
+
+        // NOTE(review): the checkpoint is written before the task state is validated by the switch below
+        if (checkpoint != null) {
+            stateMgr.checkpoint(checkpoint);
+        }
+        switch (state()) {
+            case CREATED:
+            case RUNNING:
+            case RESTORING:
+            case SUSPENDED:
+                stateMgr.recycle();
+                recordCollector.close();
+                break;
+            default:
+                throw new IllegalStateException("Illegal state " + state() + " while closing active task " + id);
+        }
+
+        partitionGroup.close();
+        closeTaskSensor.record();
+
+        transitionTo(State.CLOSED);
+
+        log.info("Closed clean and recycled state");
+    }
+
     /**
      * <pre>
      * the following order must be followed:
@@ -500,8 +528,7 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
      *  3. finally release the state manager lock
      * </pre>
      */
-    private void close(final boolean clean,
-                       final Map<TopicPartition, Long> checkpoint) {
+    private void close(final boolean clean, final Map<TopicPartition, Long> checkpoint) {
         if (clean && checkpoint != null) {
             executeAndMaybeSwallow(clean, () -> stateMgr.checkpoint(checkpoint), "state manager checkpoint", log);
         }
@@ -970,7 +997,10 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         }
     }
 
-    public ProcessorContext context() {
+    /**
+     * @return the processor context of this task; it is also handed over to the standby task when this task is recycled
+     */
+    public InternalProcessorContext processorContext() {
         return processorContext;
     }
 
@@ -1037,15 +1064,11 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator,
         return numBuffered() > 0;
     }
 
-    // below are visible for testing only
     RecordCollector recordCollector() {
         return recordCollector;
     }
 
-    InternalProcessorContext processorContext() {
-        return processorContext;
-    }
-
+    // below are visible for testing only
     int numBuffered() {
         return partitionGroup.numBuffered();
     }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
index 8ca1b4c..f95e476 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java
@@ -176,6 +176,12 @@ public interface Task {
     void update(final Set<TopicPartition> topicPartitions, final ProcessorTopology processorTopology);
 
     /**
+     * Attempt a clean close but do not close the underlying state: the state stores are kept open and the state
+     * manager is recycled, so that a new task of the other type (active/standby) can take them over.
+     */
+    void closeAndRecycleState();
+
+    /**
      * Revive a closed task to a created one; should never throw an exception
      */
     void revive();
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 4d074d8..15ffa4a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -155,9 +155,7 @@ public class TaskManager {
             final TaskId taskId = entry.getKey();
             final Task task = tasks.get(taskId);
 
-            // this call is idempotent so even if the task is only CREATED we can still call it
-            changelogReader.remove(task.changelogPartitions());
-
+            // no need to unregister the changelogs here: ProcessorStateManager#close unregisters them when the task is closed
             // mark corrupted partitions to not be checkpointed, and then close the task as dirty
             final Collection<TopicPartition> corruptedPartitions = entry.getValue();
             task.markChangelogAsCorrupted(corruptedPartitions);
@@ -183,12 +180,16 @@ public class TaskManager {
                      "\tExisting standby tasks: {}",
                  activeTasks.keySet(), standbyTasks.keySet(), activeTaskIds(), standbyTaskIds());
 
-        final Map<TaskId, Set<TopicPartition>> activeTasksToCreate = new TreeMap<>(activeTasks);
-        final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate = new TreeMap<>(standbyTasks);
+        final Map<TaskId, Set<TopicPartition>> activeTasksToCreate = new HashMap<>(activeTasks);
+        final Map<TaskId, Set<TopicPartition>> standbyTasksToCreate = new HashMap<>(standbyTasks);
+        // tasks we keep whose type flips between active and standby; their state is recycled instead of closed
+        final Set<Task> tasksToRecycle = new HashSet<>();
+
         builder.addSubscribedTopicsFromAssignment(
-                activeTasks.values().stream().flatMap(Collection::stream).collect(Collectors.toList()),
-                logPrefix
+            activeTasks.values().stream().flatMap(Collection::stream).collect(Collectors.toList()),
+            logPrefix
         );
+
         // first rectify all existing tasks
         final LinkedHashMap<TaskId, RuntimeException> taskCloseExceptions = new LinkedHashMap<>();
 
@@ -207,7 +207,10 @@ public class TaskManager {
             } else if (standbyTasks.containsKey(task.id()) && !task.isActive()) {
                 updateInputPartitionsAndResume(task, standbyTasks.get(task.id()));
                 standbyTasksToCreate.remove(task.id());
-            } else /* we previously owned this task, and we don't have it anymore, or it has changed active/standby state */ {
+            } else if (activeTasks.containsKey(task.id()) || standbyTasks.containsKey(task.id())) {
+                // we previously owned this task but it has changed active/standby status: recycle its state
+                tasksToRecycle.add(task);
+            } else /* we previously owned this task, and we don't have it anymore */ {
                 try {
                     final Map<TopicPartition, Long> checkpoint = task.prepareCloseClean();
                     final Map<TopicPartition, OffsetAndMetadata> committableOffsets = task
@@ -275,6 +278,26 @@ public class TaskManager {
             }
         }
 
+        for (final Task oldTask : tasksToRecycle) {
+            final Task newTask;
+            try {
+                if (oldTask.isActive()) {
+                    final Set<TopicPartition> partitions = standbyTasksToCreate.remove(oldTask.id());
+                    newTask = standbyTaskCreator.createStandbyTaskFromActive((StreamTask) oldTask, partitions);
+                } else {
+                    final Set<TopicPartition> partitions = activeTasksToCreate.remove(oldTask.id());
+                    newTask = activeTaskCreator.createActiveTaskFromStandby((StandbyTask) oldTask, partitions, mainConsumer);
+                }
+                tasks.remove(oldTask.id());
+                addNewTask(newTask);
+            } catch (final RuntimeException e) {
+                final String uncleanMessage = String.format("Failed to recycle task %s cleanly. Attempting to close remaining tasks before re-throwing:", oldTask.id());
+                log.error(uncleanMessage, e);
+                taskCloseExceptions.put(oldTask.id(), e);
+                dirtyTasks.add(oldTask);
+            }
+        }
+
         for (final Task task : dirtyTasks) {
             closeTaskDirty(task);
             cleanUpTaskProducer(task, taskCloseExceptions);
@@ -385,7 +408,9 @@ public class TaskManager {
         }
 
         if (allRunning && !restoringTasks.isEmpty()) {
+
             final Set<TopicPartition> restored = changelogReader.completedChangelogs();
+
             for (final Task task : restoringTasks) {
                 if (restored.containsAll(task.changelogPartitions())) {
                     try {
@@ -611,12 +636,8 @@ public class TaskManager {
 
     // Note: this MUST be called *before* actually closing the task
     private void cleanupTask(final Task task) {
-        // 1. remove the changelog partitions from changelog reader;
-        // 2. remove the input partitions from the materialized map;
-        // 3. remove the task metrics from the metrics registry
-        if (!task.changelogPartitions().isEmpty()) {
-            changelogReader.remove(task.changelogPartitions());
-        }
+        // 1. remove the input partitions from the materialized map;
+        // 2. remove the task metrics from the metrics registry
 
         for (final TopicPartition inputPartition : task.inputPartitions()) {
             partitionToTask.remove(inputPartition);
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
index 8026b04..3502fad 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
@@ -46,7 +46,6 @@ public class CachingKeyValueStore
     private CacheFlushListener<byte[], byte[]> flushListener;
     private boolean sendOldValues;
     private String cacheName;
-    private ThreadCache cache;
     private InternalProcessorContext context;
     private Thread streamThread;
     private final ReadWriteLock lock = new ReentrantReadWriteLock();
@@ -69,9 +68,8 @@ public class CachingKeyValueStore
     private void initInternal(final ProcessorContext context) {
         this.context = (InternalProcessorContext) context;
 
-        this.cache = this.context.cache();
         this.cacheName = ThreadCache.nameSpaceFromTaskIdAndStore(context.taskId().toString(), name());
-        cache.addDirtyEntryFlushListener(cacheName, entries -> {
+        this.context.registerCacheFlushListener(cacheName, entries -> {
             for (final ThreadCache.DirtyEntry entry : entries) {
                 putAndMaybeForward(entry, (InternalProcessorContext) context);
             }
@@ -133,7 +131,7 @@ public class CachingKeyValueStore
 
     private void putInternal(final Bytes key,
                              final byte[] value) {
-        cache.put(
+        context.cache().put(
             cacheName,
             key,
             new LRUCacheEntry(
@@ -219,8 +217,8 @@ public class CachingKeyValueStore
 
     private byte[] getInternal(final Bytes key) {
         LRUCacheEntry entry = null;
-        if (cache != null) {
-            entry = cache.get(cacheName, key);
+        if (context.cache() != null) {
+            entry = context.cache().get(cacheName, key);
         }
         if (entry == null) {
             final byte[] rawValue = wrapped().get(key);
@@ -230,7 +228,7 @@ public class CachingKeyValueStore
             // only update the cache if this call is on the streamThread
             // as we don't want other threads to trigger an eviction/flush
             if (Thread.currentThread().equals(streamThread)) {
-                cache.put(cacheName, key, new LRUCacheEntry(rawValue));
+                context.cache().put(cacheName, key, new LRUCacheEntry(rawValue));
             }
             return rawValue;
         } else {
@@ -250,7 +248,7 @@ public class CachingKeyValueStore
 
         validateStoreOpen();
         final KeyValueIterator<Bytes, byte[]> storeIterator = wrapped().range(from, to);
-        final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.range(cacheName, from, to);
+        final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().range(cacheName, from, to);
         return new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator);
     }
 
@@ -259,7 +257,7 @@ public class CachingKeyValueStore
         validateStoreOpen();
         final KeyValueIterator<Bytes, byte[]> storeIterator =
             new DelegatingPeekingKeyValueIterator<>(this.name(), wrapped().all());
-        final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.all(cacheName);
+        final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().all(cacheName);
         return new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator);
     }
 
@@ -281,7 +279,7 @@ public class CachingKeyValueStore
         lock.writeLock().lock();
         try {
             validateStoreOpen();
-            cache.flush(cacheName);
+            context.cache().flush(cacheName);
             wrapped().flush();
         } finally {
             lock.writeLock().unlock();
@@ -293,8 +291,8 @@ public class CachingKeyValueStore
         lock.writeLock().lock();
         try {
             final LinkedList<RuntimeException> suppressed = executeAll(
-                () -> cache.flush(cacheName),
-                () -> cache.close(cacheName),
+                () -> context.cache().flush(cacheName),
+                () -> context.cache().close(cacheName),
                 wrapped()::close
             );
             if (!suppressed.isEmpty()) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
index 4976ef1..25068c0 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
@@ -46,7 +46,6 @@ class CachingSessionStore
     private final SessionKeySchema keySchema;
     private final SegmentedCacheFunction cacheFunction;
     private String cacheName;
-    private ThreadCache cache;
     private InternalProcessorContext context;
     private CacheFlushListener<byte[], byte[]> flushListener;
     private boolean sendOldValues;
@@ -72,8 +71,7 @@ class CachingSessionStore
         this.context = context;
 
         cacheName = context.taskId() + "-" + name();
-        cache = context.cache();
-        cache.addDirtyEntryFlushListener(cacheName, entries -> {
+        context.registerCacheFlushListener(cacheName, entries -> {
             for (final ThreadCache.DirtyEntry entry : entries) {
                 putAndMaybeForward(entry, context);
             }
@@ -133,7 +131,7 @@ class CachingSessionStore
                 context.timestamp(),
                 context.partition(),
                 context.topic());
-        cache.put(cacheName, cacheFunction.cacheKey(binaryKey), entry);
+        context.cache().put(cacheName, cacheFunction.cacheKey(binaryKey), entry);
 
         maxObservedTimestamp = Math.max(keySchema.segmentTimestamp(binaryKey), maxObservedTimestamp);
     }
@@ -152,7 +150,7 @@ class CachingSessionStore
 
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> cacheIterator = wrapped().persistent() ?
             new CacheIteratorWrapper(key, earliestSessionEndTime, latestSessionStartTime) :
-            cache.range(cacheName,
+            context.cache().range(cacheName,
                         cacheFunction.cacheKey(keySchema.lowerRangeFixedSize(key, earliestSessionEndTime)),
                         cacheFunction.cacheKey(keySchema.upperRangeFixedSize(key, latestSessionStartTime))
             );
@@ -185,7 +183,7 @@ class CachingSessionStore
 
         final Bytes cacheKeyFrom = cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, earliestSessionEndTime));
         final Bytes cacheKeyTo = cacheFunction.cacheKey(keySchema.upperRange(keyTo, latestSessionStartTime));
-        final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.range(cacheName, cacheKeyFrom, cacheKeyTo);
+        final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo);
 
         final KeyValueIterator<Windowed<Bytes>, byte[]> storeIterator = wrapped().findSessions(
             keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime
@@ -203,12 +201,12 @@ class CachingSessionStore
     public byte[] fetchSession(final Bytes key, final long startTime, final long endTime) {
         Objects.requireNonNull(key, "key cannot be null");
         validateStoreOpen();
-        if (cache == null) {
+        if (context.cache() == null) {
             return wrapped().fetchSession(key, startTime, endTime);
         } else {
             final Bytes bytesKey = SessionKeySchema.toBinary(key, startTime, endTime);
             final Bytes cacheKey = cacheFunction.cacheKey(bytesKey);
-            final LRUCacheEntry entry = cache.get(cacheName, cacheKey);
+            final LRUCacheEntry entry = context.cache().get(cacheName, cacheKey);
             if (entry == null) {
                 return wrapped().fetchSession(key, startTime, endTime);
             } else {
@@ -232,14 +230,14 @@ class CachingSessionStore
     }
 
     public void flush() {
-        cache.flush(cacheName);
+        context.cache().flush(cacheName);
         wrapped().flush();
     }
 
     public void close() {
         final LinkedList<RuntimeException> suppressed = executeAll(
-            () -> cache.flush(cacheName),
-            () -> cache.close(cacheName),
+            () -> context.cache().flush(cacheName),
+            () -> context.cache().close(cacheName),
             wrapped()::close
         );
         if (!suppressed.isEmpty()) {
@@ -283,7 +281,7 @@ class CachingSessionStore
 
             setCacheKeyRange(earliestSessionEndTime, currentSegmentLastTime());
 
-            this.current = cache.range(cacheName, cacheKeyFrom, cacheKeyTo);
+            this.current = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo);
         }
 
         @Override
@@ -354,7 +352,7 @@ class CachingSessionStore
             setCacheKeyRange(currentSegmentBeginTime(), currentSegmentLastTime());
 
             current.close();
-            current = cache.range(cacheName, cacheKeyFrom, cacheKeyTo);
+            current = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo);
         }
 
         private void setCacheKeyRange(final long lowerRangeEndTime, final long upperRangeEndTime) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
index e71f87e..78e16a9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
@@ -49,7 +49,6 @@ class CachingWindowStore
     private final SegmentedBytesStore.KeySchema keySchema = new WindowKeySchema();
 
     private String name;
-    private ThreadCache cache;
     private boolean sendOldValues;
     private InternalProcessorContext context;
     private StateSerdes<Bytes, byte[]> bytesSerdes;
@@ -84,9 +83,8 @@ class CachingWindowStore
             Serdes.Bytes(),
             Serdes.ByteArray());
         name = context.taskId() + "-" + name();
-        cache = this.context.cache();
 
-        cache.addDirtyEntryFlushListener(name, entries -> {
+        context.registerCacheFlushListener(name, entries -> {
             for (final ThreadCache.DirtyEntry entry : entries) {
                 putAndMaybeForward(entry, context);
             }
@@ -161,7 +159,7 @@ class CachingWindowStore
                 context.timestamp(),
                 context.partition(),
                 context.topic());
-        cache.put(name, cacheFunction.cacheKey(keyBytes), entry);
+        context.cache().put(name, cacheFunction.cacheKey(keyBytes), entry);
 
         maxObservedTimestamp = Math.max(keySchema.segmentTimestamp(keyBytes), maxObservedTimestamp);
     }
@@ -172,10 +170,10 @@ class CachingWindowStore
         validateStoreOpen();
         final Bytes bytesKey = WindowKeySchema.toStoreKeyBinary(key, timestamp, 0);
         final Bytes cacheKey = cacheFunction.cacheKey(bytesKey);
-        if (cache == null) {
+        if (context.cache() == null) {
             return wrapped().fetch(key, timestamp);
         }
-        final LRUCacheEntry entry = cache.get(name, cacheKey);
+        final LRUCacheEntry entry = context.cache().get(name, cacheKey);
         if (entry == null) {
             return wrapped().fetch(key, timestamp);
         } else {
@@ -193,13 +191,13 @@ class CachingWindowStore
         validateStoreOpen();
 
         final WindowStoreIterator<byte[]> underlyingIterator = wrapped().fetch(key, timeFrom, timeTo);
-        if (cache == null) {
+        if (context.cache() == null) {
             return underlyingIterator;
         }
 
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> cacheIterator = wrapped().persistent() ?
             new CacheIteratorWrapper(key, timeFrom, timeTo) :
-            cache.range(name,
+            context.cache().range(name,
                         cacheFunction.cacheKey(keySchema.lowerRangeFixedSize(key, timeFrom)),
                         cacheFunction.cacheKey(keySchema.upperRangeFixedSize(key, timeTo))
             );
@@ -231,13 +229,13 @@ class CachingWindowStore
 
         final KeyValueIterator<Windowed<Bytes>, byte[]> underlyingIterator =
             wrapped().fetch(from, to, timeFrom, timeTo);
-        if (cache == null) {
+        if (context.cache() == null) {
             return underlyingIterator;
         }
 
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> cacheIterator = wrapped().persistent() ?
             new CacheIteratorWrapper(from, to, timeFrom, timeTo) :
-            cache.range(name,
+            context.cache().range(name,
                         cacheFunction.cacheKey(keySchema.lowerRange(from, timeFrom)),
                         cacheFunction.cacheKey(keySchema.upperRange(to, timeTo))
             );
@@ -261,7 +259,7 @@ class CachingWindowStore
         validateStoreOpen();
 
         final KeyValueIterator<Windowed<Bytes>, byte[]> underlyingIterator = wrapped().fetchAll(timeFrom, timeTo);
-        final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.all(name);
+        final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().all(name);
 
         final HasNextCondition hasNextCondition = keySchema.hasNextCondition(null, null, timeFrom, timeTo);
         final PeekingKeyValueIterator<Bytes, LRUCacheEntry> filteredCacheIterator =
@@ -280,7 +278,7 @@ class CachingWindowStore
         validateStoreOpen();
 
         final KeyValueIterator<Windowed<Bytes>, byte[]>  underlyingIterator = wrapped().all();
-        final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.all(name);
+        final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().all(name);
 
         return new MergedSortedCacheWindowStoreKeyValueIterator(
             cacheIterator,
@@ -293,15 +291,15 @@ class CachingWindowStore
 
     @Override
     public synchronized void flush() {
-        cache.flush(name);
+        context.cache().flush(name);
         wrapped().flush();
     }
 
     @Override
     public synchronized void close() {
         final LinkedList<RuntimeException> suppressed = executeAll(
-            () -> cache.flush(name),
-            () -> cache.close(name),
+            () -> context.cache().flush(name),
+            () -> context.cache().close(name),
             wrapped()::close
         );
         if (!suppressed.isEmpty()) {
@@ -347,7 +345,7 @@ class CachingWindowStore
 
             setCacheKeyRange(timeFrom, currentSegmentLastTime());
 
-            this.current = cache.range(name, cacheKeyFrom, cacheKeyTo);
+            this.current = context.cache().range(name, cacheKeyFrom, cacheKeyTo);
         }
 
         @Override
@@ -418,7 +416,7 @@ class CachingWindowStore
             setCacheKeyRange(currentSegmentBeginTime(), currentSegmentLastTime());
 
             current.close();
-            current = cache.range(name, cacheKeyFrom, cacheKeyTo);
+            current = context.cache().range(name, cacheKeyFrom, cacheKeyTo);
         }
 
         private void setCacheKeyRange(final long lowerRangeEndTime, final long upperRangeEndTime) {
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java
index c813a47..9feccb9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java
@@ -77,7 +77,6 @@ public final class InMemoryTimeOrderedKeyValueBuffer<K, V> implements TimeOrdere
     private long memBufferSize = 0L;
     private long minTimestamp = Long.MAX_VALUE;
     private InternalProcessorContext context;
-    private RecordCollector collector;
     private String changelogTopic;
     private Sensor bufferSizeSensor;
     private Sensor bufferCountSensor;
@@ -210,10 +209,7 @@ public final class InMemoryTimeOrderedKeyValueBuffer<K, V> implements TimeOrdere
         );
 
         context.register(root, (RecordBatchingStateRestoreCallback) this::restoreBatch);
-        if (loggingEnabled) {
-            collector = ((RecordCollector.Supplier) context).recordCollector();
-            changelogTopic = ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName);
-        }
+        changelogTopic = ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName);
         updateBufferMetrics();
         open = true;
         partition = context.taskId().partition;
@@ -263,27 +259,28 @@ public final class InMemoryTimeOrderedKeyValueBuffer<K, V> implements TimeOrdere
         final ByteBuffer buffer = value.serialize(sizeOfBufferTime);
         buffer.putLong(bufferKey.time());
 
-        collector.send(
-            changelogTopic,
-            key,
-            buffer.array(),
-            V_2_CHANGELOG_HEADERS,
-            partition,
-            null,
-            KEY_SERIALIZER,
-            VALUE_SERIALIZER
+        ((RecordCollector.Supplier) context).recordCollector().send(
+                changelogTopic,
+                key,
+                buffer.array(),
+                V_2_CHANGELOG_HEADERS,
+                partition,
+                null,
+                KEY_SERIALIZER,
+                VALUE_SERIALIZER
         );
     }
 
     private void logTombstone(final Bytes key) {
-        collector.send(changelogTopic,
-                       key,
-                       null,
-                       null,
-                       partition,
-                       null,
-                       KEY_SERIALIZER,
-                       VALUE_SERIALIZER
+        ((RecordCollector.Supplier) context).recordCollector().send(
+                changelogTopic,
+                key,
+                null,
+                null,
+                partition,
+                null,
+                KEY_SERIALIZER,
+                VALUE_SERIALIZER
         );
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/TopologyTestDriverWrapper.java b/streams/src/test/java/org/apache/kafka/streams/TopologyTestDriverWrapper.java
index 5c73bb4..69c3f38 100644
--- a/streams/src/test/java/org/apache/kafka/streams/TopologyTestDriverWrapper.java
+++ b/streams/src/test/java/org/apache/kafka/streams/TopologyTestDriverWrapper.java
@@ -43,7 +43,7 @@ public class TopologyTestDriverWrapper extends TopologyTestDriver {
      * @return the processor context
      */
     public ProcessorContext setCurrentNodeForProcessorContext(final String processorName) {
-        final ProcessorContext context = task.context();
+        final ProcessorContext context = task.processorContext();
         ((ProcessorContextImpl) context).setCurrentNode(getProcessor(processorName));
         return context;
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/InternalTopicIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/InternalTopicIntegrationTest.java
index 2eafabd..0a71481 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/InternalTopicIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/InternalTopicIntegrationTest.java
@@ -162,7 +162,7 @@ public class InternalTopicIntegrationTest {
         //
         // Step 3: Verify the state changelog topics are compact
         //
-        waitForCompletion(streams, 2, 30000);
+        waitForCompletion(streams, 2, 30000L);
         streams.close();
 
         final Properties changelogProps = getTopicProperties(ProcessorStateManager.storeChangelogTopic(appID, "Counts"));
@@ -202,7 +202,7 @@ public class InternalTopicIntegrationTest {
         //
         // Step 3: Verify the state changelog topics are compact
         //
-        waitForCompletion(streams, 2, 30000);
+        waitForCompletion(streams, 2, 30000L);
         streams.close();
         final Properties properties = getTopicProperties(ProcessorStateManager.storeChangelogTopic(appID, "CountWindows"));
         final List<String> policies = Arrays.asList(properties.getProperty(LogConfig.CleanupPolicyProp()).split(","));
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/OptimizedKTableIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/OptimizedKTableIntegrationTest.java
index 092cffc..ea24090 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/OptimizedKTableIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/OptimizedKTableIntegrationTest.java
@@ -22,6 +22,7 @@ import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.notNullValue;
+import static org.junit.Assert.assertTrue;
 
 import java.time.Duration;
 import java.util.ArrayList;
@@ -34,7 +35,6 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.producer.ProducerConfig;
-import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.IntegerSerializer;
 import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.utils.Bytes;
@@ -46,9 +46,9 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.KeyQueryMetadata;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils.TrackingStateRestoreListener;
 import org.apache.kafka.streams.kstream.Consumed;
 import org.apache.kafka.streams.kstream.Materialized;
-import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.state.KeyValueStore;
 import org.apache.kafka.streams.state.QueryableStoreTypes;
 import org.apache.kafka.streams.state.ReadOnlyKeyValueStore;
@@ -109,9 +109,7 @@ public class OptimizedKTableIntegrationTest {
         final List<KafkaStreams> kafkaStreamsList = Arrays.asList(kafkaStreams1, kafkaStreams2);
         final TrackingStateRestoreListener listener = new TrackingStateRestoreListener();
 
-        kafkaStreamsList.forEach(kafkaStreams -> {
-            kafkaStreams.setGlobalStateRestoreListener(listener);
-        });
+        kafkaStreamsList.forEach(kafkaStreams -> kafkaStreams.setGlobalStateRestoreListener(listener));
         startApplicationAndWaitUntilRunning(kafkaStreamsList, Duration.ofSeconds(60));
 
         produceValueRange(key, 0, batch1NumMessages);
@@ -135,8 +133,8 @@ public class OptimizedKTableIntegrationTest {
 
         // Assert that no restore has occurred, ensures that when we check later that the restore
         // notification actually came from after the rebalance.
-        assertThat(listener.startOffset, is(equalTo(0L)));
-        assertThat(listener.totalNumRestored, is(equalTo(0L)));
+        assertTrue(listener.allStartOffsetsAtZero());
+        assertThat(listener.totalNumRestored(), is(equalTo(0L)));
 
         // Assert that the current value in store reflects all messages being processed
         assertThat(kafkaStreams1WasFirstActive ? store1.get(key) : store2.get(key), is(equalTo(batch1NumMessages - 1)));
@@ -187,33 +185,6 @@ public class OptimizedKTableIntegrationTest {
         return streams;
     }
 
-    private class TrackingStateRestoreListener implements StateRestoreListener {
-        long startOffset = -1L;
-        long endOffset = -1L;
-        long totalNumRestored = 0L;
-
-        @Override
-        public void onRestoreStart(final TopicPartition topicPartition,
-                                   final String storeName,
-                                   final long startingOffset,
-                                   final long endingOffset) {
-            startOffset = startingOffset;
-            endOffset = endingOffset;
-        }
-
-        @Override
-        public void onBatchRestored(final TopicPartition topicPartition,
-                                    final String storeName,
-                                    final long batchEndOffset,
-                                    final long numRestored) {
-            totalNumRestored += numRestored;
-        }
-
-        @Override
-        public void onRestoreEnd(final TopicPartition topicPartition, final String storeName, final long totalRestored) {
-
-        }
-    }
 
     private Properties streamsConfiguration() {
         final String safeTestName = safeUniqueTestName(getClass(), testName);
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/RestoreIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/RestoreIntegrationTest.java
index 160313a..faac172 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/RestoreIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/RestoreIntegrationTest.java
@@ -29,12 +29,14 @@ import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KafkaStreams.State;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.Topology;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils.TrackingStateRestoreListener;
 import org.apache.kafka.streams.kstream.Consumed;
 import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.kstream.Materialized;
@@ -47,28 +49,38 @@ import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier;
 import org.apache.kafka.streams.state.KeyValueStore;
 import org.apache.kafka.streams.state.StoreBuilder;
 import org.apache.kafka.streams.state.Stores;
+import org.apache.kafka.streams.state.internals.InMemoryKeyValueStore;
 import org.apache.kafka.streams.state.internals.KeyValueStoreBuilder;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
 import org.apache.kafka.test.IntegrationTest;
 import org.apache.kafka.test.TestUtils;
+import org.hamcrest.CoreMatchers;
 import org.junit.After;
-import org.junit.BeforeClass;
+import org.junit.Before;
 import org.junit.ClassRule;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
 
 import java.io.File;
 import java.time.Duration;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Properties;
 import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
-
+import org.junit.rules.TestName;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.purgeLocalStreamsState;
+import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName;
+import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning;
+import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForApplicationState;
+import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForCompletion;
+import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForStandbyCompletion;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.IsEqual.equalTo;
 import static org.junit.Assert.assertTrue;
@@ -77,29 +89,30 @@ import static org.junit.Assert.assertTrue;
 public class RestoreIntegrationTest {
     private static final int NUM_BROKERS = 1;
 
-    private static final String APPID = "restore-test";
-
     @ClassRule
-    public static final EmbeddedKafkaCluster CLUSTER =
-            new EmbeddedKafkaCluster(NUM_BROKERS);
-    private static final String INPUT_STREAM = "input-stream";
-    private static final String INPUT_STREAM_2 = "input-stream-2";
+    public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+    @Rule
+    public final TestName testName = new TestName();
+    private String appId;
+    private String inputStream;
+
     private final int numberOfKeys = 10000;
     private KafkaStreams kafkaStreams;
 
-    @BeforeClass
-    public static void createTopics() throws InterruptedException {
-        CLUSTER.createTopic(INPUT_STREAM, 2, 1);
-        CLUSTER.createTopic(INPUT_STREAM_2, 2, 1);
-        CLUSTER.createTopic(APPID + "-store-changelog", 2, 1);
+    @Before
+    public void createTopics() throws InterruptedException {
+        appId = safeUniqueTestName(RestoreIntegrationTest.class, testName);
+        inputStream = appId + "-input-stream";
+        CLUSTER.createTopic(inputStream, 2, 1);
     }
 
-    private Properties props(final String applicationId) {
+    private Properties props() {
         final Properties streamsConfiguration = new Properties();
-        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId);
+        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, appId);
         streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
         streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0);
-        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(applicationId).getPath());
+        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(appId).getPath());
         streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass());
         streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass());
         streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000);
@@ -121,26 +134,26 @@ public class RestoreIntegrationTest {
         final AtomicInteger numReceived = new AtomicInteger(0);
         final StreamsBuilder builder = new StreamsBuilder();
 
-        final Properties props = props(APPID);
+        final Properties props = props();
         props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION, StreamsConfig.OPTIMIZE);
 
         // restoring from 1000 to 4000 (committed), and then process from 4000 to 5000 on each of the two partitions
         final int offsetLimitDelta = 1000;
         final int offsetCheckpointed = 1000;
-        createStateForRestoration(INPUT_STREAM, 0);
-        setCommittedOffset(INPUT_STREAM, offsetLimitDelta);
+        createStateForRestoration(inputStream, 0);
+        setCommittedOffset(inputStream, offsetLimitDelta);
 
         final StateDirectory stateDirectory = new StateDirectory(new StreamsConfig(props), new MockTime(), true);
         // note here the checkpointed offset is the last processed record's offset, so without control message we should write this offset - 1
         new OffsetCheckpoint(new File(stateDirectory.directoryForTask(new TaskId(0, 0)), ".checkpoint"))
-                .write(Collections.singletonMap(new TopicPartition(INPUT_STREAM, 0), (long) offsetCheckpointed - 1));
+                .write(Collections.singletonMap(new TopicPartition(inputStream, 0), (long) offsetCheckpointed - 1));
         new OffsetCheckpoint(new File(stateDirectory.directoryForTask(new TaskId(0, 1)), ".checkpoint"))
-                .write(Collections.singletonMap(new TopicPartition(INPUT_STREAM, 1), (long) offsetCheckpointed - 1));
+                .write(Collections.singletonMap(new TopicPartition(inputStream, 1), (long) offsetCheckpointed - 1));
 
         final CountDownLatch startupLatch = new CountDownLatch(1);
         final CountDownLatch shutdownLatch = new CountDownLatch(1);
 
-        builder.table(INPUT_STREAM, Materialized.<Integer, Integer, KeyValueStore<Bytes, byte[]>>as("store").withKeySerde(Serdes.Integer()).withValueSerde(Serdes.Integer()))
+        builder.table(inputStream, Materialized.<Integer, Integer, KeyValueStore<Bytes, byte[]>>as("store").withKeySerde(Serdes.Integer()).withValueSerde(Serdes.Integer()))
                 .toStream()
                 .foreach((key, value) -> {
                     if (numReceived.incrementAndGet() == offsetLimitDelta * 2) {
@@ -183,27 +196,30 @@ public class RestoreIntegrationTest {
 
     @Test
     public void shouldRestoreStateFromChangelogTopic() throws Exception {
+        final String changelog = appId + "-store-changelog";
+        CLUSTER.createTopic(changelog, 2, 1);
+
         final AtomicInteger numReceived = new AtomicInteger(0);
         final StreamsBuilder builder = new StreamsBuilder();
 
-        final Properties props = props(APPID);
+        final Properties props = props();
 
         // restoring from 1000 to 5000, and then process from 5000 to 10000 on each of the two partitions
         final int offsetCheckpointed = 1000;
-        createStateForRestoration(APPID + "-store-changelog", 0);
-        createStateForRestoration(INPUT_STREAM, 10000);
+        createStateForRestoration(changelog, 0);
+        createStateForRestoration(inputStream, 10000);
 
         final StateDirectory stateDirectory = new StateDirectory(new StreamsConfig(props), new MockTime(), true);
         // note here the checkpointed offset is the last processed record's offset, so without control message we should write this offset - 1
         new OffsetCheckpoint(new File(stateDirectory.directoryForTask(new TaskId(0, 0)), ".checkpoint"))
-                .write(Collections.singletonMap(new TopicPartition(APPID + "-store-changelog", 0), (long) offsetCheckpointed - 1));
+                .write(Collections.singletonMap(new TopicPartition(changelog, 0), (long) offsetCheckpointed - 1));
         new OffsetCheckpoint(new File(stateDirectory.directoryForTask(new TaskId(0, 1)), ".checkpoint"))
-                .write(Collections.singletonMap(new TopicPartition(APPID + "-store-changelog", 1), (long) offsetCheckpointed - 1));
+                .write(Collections.singletonMap(new TopicPartition(changelog, 1), (long) offsetCheckpointed - 1));
 
         final CountDownLatch startupLatch = new CountDownLatch(1);
         final CountDownLatch shutdownLatch = new CountDownLatch(1);
 
-        builder.table(INPUT_STREAM, Consumed.with(Serdes.Integer(), Serdes.Integer()), Materialized.as("store"))
+        builder.table(inputStream, Consumed.with(Serdes.Integer(), Serdes.Integer()), Materialized.as("store"))
                 .toStream()
                 .foreach((key, value) -> {
                     if (numReceived.incrementAndGet() == numberOfKeys) {
@@ -244,19 +260,18 @@ public class RestoreIntegrationTest {
         assertThat(numReceived.get(), equalTo(numberOfKeys));
     }
 
-
     @Test
     public void shouldSuccessfullyStartWhenLoggingDisabled() throws InterruptedException {
         final StreamsBuilder builder = new StreamsBuilder();
 
-        final KStream<Integer, Integer> stream = builder.stream(INPUT_STREAM);
+        final KStream<Integer, Integer> stream = builder.stream(inputStream);
         stream.groupByKey()
                 .reduce(
                     (value1, value2) -> value1 + value2,
                     Materialized.<Integer, Integer, KeyValueStore<Bytes, byte[]>>as("reduce-store").withLoggingDisabled());
 
         final CountDownLatch startupLatch = new CountDownLatch(1);
-        kafkaStreams = new KafkaStreams(builder.build(), props(APPID));
+        kafkaStreams = new KafkaStreams(builder.build(), props());
         kafkaStreams.setStateListener((newState, oldState) -> {
             if (newState == KafkaStreams.State.RUNNING && oldState == KafkaStreams.State.REBALANCING) {
                 startupLatch.countDown();
@@ -269,10 +284,10 @@ public class RestoreIntegrationTest {
     }
 
     @Test
-    public void shouldProcessDataFromStoresWithLoggingDisabled() throws InterruptedException, ExecutionException {
+    public void shouldProcessDataFromStoresWithLoggingDisabled() throws InterruptedException {
 
-        IntegrationTestUtils.produceKeyValuesSynchronously(INPUT_STREAM_2,
-                                                           Arrays.asList(KeyValue.pair(1, 1),
+        IntegrationTestUtils.produceKeyValuesSynchronously(inputStream,
+                                                           asList(KeyValue.pair(1, 1),
                                                                          KeyValue.pair(2, 2),
                                                                          KeyValue.pair(3, 3)),
                                                            TestUtils.producerConfig(CLUSTER.bootstrapServers(),
@@ -280,7 +295,7 @@ public class RestoreIntegrationTest {
                                                                                     IntegerSerializer.class),
                                                            CLUSTER.time);
 
-        final KeyValueBytesStoreSupplier lruMapSupplier = Stores.lruMap(INPUT_STREAM_2, 10);
+        final KeyValueBytesStoreSupplier lruMapSupplier = Stores.lruMap(inputStream, 10);
 
         final StoreBuilder<KeyValueStore<Integer, Integer>> storeBuilder = new KeyValueStoreBuilder<>(lruMapSupplier,
                                                                                                       Serdes.Integer(),
@@ -292,13 +307,13 @@ public class RestoreIntegrationTest {
 
         streamsBuilder.addStateStore(storeBuilder);
 
-        final KStream<Integer, Integer> stream = streamsBuilder.stream(INPUT_STREAM_2);
+        final KStream<Integer, Integer> stream = streamsBuilder.stream(inputStream);
         final CountDownLatch processorLatch = new CountDownLatch(3);
-        stream.process(() -> new KeyValueStoreProcessor(INPUT_STREAM_2, processorLatch), INPUT_STREAM_2);
+        stream.process(() -> new KeyValueStoreProcessor(inputStream, processorLatch), inputStream);
 
         final Topology topology = streamsBuilder.build();
 
-        kafkaStreams = new KafkaStreams(topology, props(APPID + "-logging-disabled"));
+        kafkaStreams = new KafkaStreams(topology, props());
 
         final CountDownLatch latch = new CountDownLatch(1);
         kafkaStreams.setStateListener((newState, oldState) -> {
@@ -311,9 +326,98 @@ public class RestoreIntegrationTest {
         latch.await(30, TimeUnit.SECONDS);
 
         assertTrue(processorLatch.await(30, TimeUnit.SECONDS));
+    }
+
+    @Test
+    public void shouldRecycleStateFromStandbyTaskPromotedToActiveTaskAndNotRestore() throws Exception {
+        final StreamsBuilder builder = new StreamsBuilder();
+        builder.table(
+            inputStream,
+            Consumed.with(Serdes.Integer(), Serdes.Integer()), Materialized.as(getCloseCountingStore("store"))
+        );
+        createStateForRestoration(inputStream, 0);
+        // Two instances, each configured with one standby replica and its own state directory
+        final Properties props1 = props();
+        props1.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
+        props1.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(appId + "-1").getPath());
+        purgeLocalStreamsState(props1);
+        final KafkaStreams client1 = new KafkaStreams(builder.build(), props1);
+
+        final Properties props2 = props();
+        props2.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
+        props2.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(appId + "-2").getPath());
+        purgeLocalStreamsState(props2);
+        final KafkaStreams client2 = new KafkaStreams(builder.build(), props2);
+        // Only client1 is tracked: it is the instance that takes over client2's tasks below
+        final TrackingStateRestoreListener restoreListener = new TrackingStateRestoreListener();
+        client1.setGlobalStateRestoreListener(restoreListener);
+
+        startApplicationAndWaitUntilRunning(asList(client1, client2), Duration.ofSeconds(60));
+
+        waitForCompletion(client1, 1, 30 * 1000L);
+        waitForCompletion(client2, 1, 30 * 1000L);
+        waitForStandbyCompletion(client1, 1, 30 * 1000L);
+        waitForStandbyCompletion(client2, 1, 30 * 1000L);
+        // Before failover: no store has been closed and nothing has been restored
+        assertThat(CloseCountingInMemoryStore.numStoresClosed(), CoreMatchers.equalTo(0));
+        assertThat(restoreListener.totalNumRestored(), CoreMatchers.equalTo(0L));
+        // Stop client2 and wait for client1 to rebalance and take over its tasks
+        client2.close();
+        waitForApplicationState(singletonList(client2), State.NOT_RUNNING, Duration.ofSeconds(60));
+        waitForApplicationState(singletonList(client1), State.REBALANCING, Duration.ofSeconds(60));
+        waitForApplicationState(singletonList(client1), State.RUNNING, Duration.ofSeconds(60));
+
+        waitForCompletion(client1, 1, 30 * 1000L);
+        waitForStandbyCompletion(client1, 1, 30 * 1000L);
+        // The promoted standby's state was recycled, so the takeover must not have restored anything
+        assertThat(restoreListener.totalNumRestored(), CoreMatchers.equalTo(0L));
+
+        // After stopping instance 2 and letting instance 1 take over its tasks, we should have closed just two stores
+        // in total: the active and standby stores that were running on instance 2
+        assertThat(CloseCountingInMemoryStore.numStoresClosed(), CoreMatchers.equalTo(2));
 
+        client1.close();
+        waitForApplicationState(singletonList(client1), State.NOT_RUNNING, Duration.ofSeconds(60));
+        // Closing client1 closes the two stores it still holds, for four closed stores overall
+        assertThat(CloseCountingInMemoryStore.numStoresClosed(), CoreMatchers.equalTo(4));
     }
 
+    private static KeyValueBytesStoreSupplier getCloseCountingStore(final String name) { // supplies CloseCountingInMemoryStore instances
+        return new KeyValueBytesStoreSupplier() {
+            @Override
+            public String name() {
+                return name;
+            }
+
+            @Override
+            public KeyValueStore<Bytes, byte[]> get() { // a new store instance on every call
+                return new CloseCountingInMemoryStore(name);
+            }
+
+            @Override
+            public String metricsScope() {
+                return "close-counting";
+            }
+        };
+    }
+
+    static class CloseCountingInMemoryStore extends InMemoryKeyValueStore { // counts close() calls across all instances
+        static final AtomicInteger numStoresClosed = new AtomicInteger(0); // shared by all instances, never reset
+
+        CloseCountingInMemoryStore(final String name) {
+            super(name);
+        }
+
+        @Override
+        public void close() {
+            numStoresClosed.incrementAndGet();
+            super.close();
+        }
+
+        static int numStoresClosed() { // total number of close() calls made so far
+            return numStoresClosed.get();
+        }
+    }
 
     public static class KeyValueStoreProcessor implements Processor<Integer, Integer> {
 
@@ -362,13 +466,13 @@ public class RestoreIntegrationTest {
     private void setCommittedOffset(final String topic, final int limitDelta) {
         final Properties consumerConfig = new Properties();
         consumerConfig.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
-        consumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, APPID);
+        consumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, appId);
         consumerConfig.put(ConsumerConfig.CLIENT_ID_CONFIG, "commit-consumer");
         consumerConfig.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, IntegerDeserializer.class);
         consumerConfig.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, IntegerDeserializer.class);
 
         final Consumer<Integer, Integer> consumer = new KafkaConsumer<>(consumerConfig);
-        final List<TopicPartition> partitions = Arrays.asList(
+        final List<TopicPartition> partitions = asList(
             new TopicPartition(topic, 0),
             new TopicPartition(topic, 1));
 
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
index d8b5816..c4cb121 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java
@@ -16,7 +16,10 @@
  */
 package org.apache.kafka.streams.integration.utils;
 
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+
 import kafka.api.Request;
 import kafka.server.KafkaServer;
 import kafka.server.MetadataCache;
@@ -46,6 +49,7 @@ import org.apache.kafka.streams.StoreQueryParameters;
 import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.InvalidStateStoreException;
+import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.processor.internals.StreamThread;
 import org.apache.kafka.streams.processor.internals.ThreadStateTransitionValidator;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentListener;
@@ -369,12 +373,10 @@ public class IntegrationTestUtils {
      * @param <K>                 Key type of the data records
      * @param <V>                 Value type of the data records
      */
-    public static <K, V> void produceAbortedKeyValuesSynchronouslyWithTimestamp(
-        final String topic,
-        final Collection<KeyValue<K, V>> records,
-        final Properties producerConfig,
-        final Long timestamp
-    ) throws Exception {
+    public static <K, V> void produceAbortedKeyValuesSynchronouslyWithTimestamp(final String topic,
+                                                                                final Collection<KeyValue<K, V>> records,
+                                                                                final Properties producerConfig,
+                                                                                final Long timestamp) throws Exception {
         try (final Producer<K, V> producer = new KafkaProducer<>(producerConfig)) {
             producer.initTransactions();
             for (final KeyValue<K, V> record : records) {
@@ -424,7 +426,8 @@ public class IntegrationTestUtils {
     }
 
     /**
-     * Wait for streams to "finish", based on the consumer lag metric.
+     * Wait for streams to "finish", based on the consumer lag metric. Only the main consumer is included; to also
+     * wait for standby tasks to catch up, see {@link #waitForStandbyCompletion}.
      *
      * Caveats:
      * - Inputs must be finite, fully loaded, and flushed before this method is called
@@ -434,15 +437,54 @@ public class IntegrationTestUtils {
      */
     public static void waitForCompletion(final KafkaStreams streams,
                                          final int expectedPartitions,
-                                         final int timeoutMilliseconds) {
+                                         final long timeoutMilliseconds) {
         final long start = System.currentTimeMillis();
         while (true) {
             int lagMetrics = 0;
             double totalLag = 0.0;
             for (final Metric metric : streams.metrics().values()) {
                 if (metric.metricName().name().equals("records-lag")) {
-                    lagMetrics++;
-                    totalLag += ((Number) metric.metricValue()).doubleValue();
+                    if (!metric.metricName().tags().get("client-id").endsWith("restore-consumer")) { // main consumer only
+                        lagMetrics++;
+                        totalLag += ((Number) metric.metricValue()).doubleValue();
+                    }
+                }
+            }
+            if (lagMetrics >= expectedPartitions && totalLag == 0.0) { // enough partitions reported and none lagging
+                return;
+            }
+            if (System.currentTimeMillis() - start >= timeoutMilliseconds) {
+                throw new RuntimeException(String.format(
+                    "Timed out waiting for completion. lagMetrics=[%s/%s] totalLag=[%s]",
+                    lagMetrics, expectedPartitions, totalLag
+                ));
+            }
+        } // NOTE(review): re-polls metrics in a tight loop with no sleep between iterations
+    }
+
+    /**
+     * Wait for streams to "finish" processing standbys, based on the restore consumer lag metric. Only the restore
+     * consumer is included; to wait for active tasks to complete, see {@link #waitForCompletion}.
+     *
+     * Caveats:
+     * - Inputs must be finite, fully loaded, and flushed before this method is called
+     * - expectedPartitions is the total number of partitions to watch the lag on, including both input and internal.
+     *   It's somewhat ok to get this wrong, as the main failure case would be an immediate return due to the clients
+     *   not being initialized, which you can avoid with any non-zero value. But it's probably better to get it right ;)
+     */
+    public static void waitForStandbyCompletion(final KafkaStreams streams,
+                                                final int expectedPartitions,
+                                                final long timeoutMilliseconds) {
+        final long start = System.currentTimeMillis();
+        while (true) {
+            int lagMetrics = 0;
+            double totalLag = 0.0;
+            for (final Metric metric : streams.metrics().values()) {
+                if (metric.metricName().name().equals("records-lag")) {
+                    if (metric.metricName().tags().get("client-id").endsWith("restore-consumer")) {
+                        lagMetrics++;
+                        totalLag += ((Number) metric.metricValue()).doubleValue();
+                    }
                 }
             }
             if (lagMetrics >= expectedPartitions && totalLag == 0.0) {
@@ -468,11 +510,9 @@ public class IntegrationTestUtils {
      * @return All the records consumed, or null if no records are consumed
      */
     @SuppressWarnings("WeakerAccess")
-    public static <K, V> List<ConsumerRecord<K, V>> waitUntilMinRecordsReceived(
-        final Properties consumerConfig,
-        final String topic,
-        final int expectedNumRecords
-    ) throws Exception {
+    public static <K, V> List<ConsumerRecord<K, V>> waitUntilMinRecordsReceived(final Properties consumerConfig,
+                                                                                final String topic,
+                                                                                final int expectedNumRecords) throws Exception { // uses DEFAULT_TIMEOUT
         return waitUntilMinRecordsReceived(consumerConfig, topic, expectedNumRecords, DEFAULT_TIMEOUT);
     }
 
@@ -488,12 +528,10 @@ public class IntegrationTestUtils {
      * @return All the records consumed, or null if no records are consumed
      */
     @SuppressWarnings("WeakerAccess")
-    public static <K, V> List<ConsumerRecord<K, V>> waitUntilMinRecordsReceived(
-        final Properties consumerConfig,
-        final String topic,
-        final int expectedNumRecords,
-        final long waitTime
-    ) throws Exception {
+    public static <K, V> List<ConsumerRecord<K, V>> waitUntilMinRecordsReceived(final Properties consumerConfig,
+                                                                                final String topic,
+                                                                                final int expectedNumRecords,
+                                                                                final long waitTime) throws Exception {
         final List<ConsumerRecord<K, V>> accumData = new ArrayList<>();
         final String reason = String.format(
             "Did not receive all %d records from topic %s within %d ms",
@@ -522,11 +560,9 @@ public class IntegrationTestUtils {
      * @param <V>                 Value type of the data records
      * @return All the records consumed, or null if no records are consumed
      */
-    public static <K, V> List<KeyValue<K, V>> waitUntilMinKeyValueRecordsReceived(
-        final Properties consumerConfig,
-        final String topic,
-        final int expectedNumRecords
-    ) throws Exception {
+    public static <K, V> List<KeyValue<K, V>> waitUntilMinKeyValueRecordsReceived(final Properties consumerConfig,
+                                                                                  final String topic,
+                                                                                  final int expectedNumRecords) throws Exception { // uses DEFAULT_TIMEOUT
         return waitUntilMinKeyValueRecordsReceived(consumerConfig, topic, expectedNumRecords, DEFAULT_TIMEOUT);
     }
 
@@ -542,12 +578,10 @@ public class IntegrationTestUtils {
      * @return All the records consumed, or null if no records are consumed
      * @throws AssertionError    if the given wait time elapses
      */
-    public static <K, V> List<KeyValue<K, V>> waitUntilMinKeyValueRecordsReceived(
-        final Properties consumerConfig,
-        final String topic,
-        final int expectedNumRecords,
-        final long waitTime
-    ) throws Exception {
+    public static <K, V> List<KeyValue<K, V>> waitUntilMinKeyValueRecordsReceived(final Properties consumerConfig,
+                                                                                  final String topic,
+                                                                                  final int expectedNumRecords,
+                                                                                  final long waitTime) throws Exception {
         final List<KeyValue<K, V>> accumData = new ArrayList<>();
         final String reason = String.format(
             "Did not receive all %d records from topic %s within %d ms",
@@ -577,12 +611,10 @@ public class IntegrationTestUtils {
      * @param <K>                Key type of the data records
      * @param <V>                Value type of the data records
      */
-    public static <K, V> List<KeyValueTimestamp<K, V>> waitUntilMinKeyValueWithTimestampRecordsReceived(
-        final Properties consumerConfig,
-        final String topic,
-        final int expectedNumRecords,
-        final long waitTime
-    ) throws Exception {
+    public static <K, V> List<KeyValueTimestamp<K, V>> waitUntilMinKeyValueWithTimestampRecordsReceived(final Properties consumerConfig,
+                                                                                                        final String topic,
+                                                                                                        final int expectedNumRecords,
+                                                                                                        final long waitTime) throws Exception {
         final List<KeyValueTimestamp<K, V>> accumData = new ArrayList<>();
         final String reason = String.format(
             "Did not receive all %d records from topic %s within %d ms",
@@ -611,11 +643,9 @@ public class IntegrationTestUtils {
      * @param <V>                Value type of the data records
      * @return All the mappings consumed, or null if no records are consumed
      */
-    public static <K, V> List<KeyValue<K, V>> waitUntilFinalKeyValueRecordsReceived(
-        final Properties consumerConfig,
-        final String topic,
-        final List<KeyValue<K, V>> expectedRecords
-    ) throws Exception {
+    public static <K, V> List<KeyValue<K, V>> waitUntilFinalKeyValueRecordsReceived(final Properties consumerConfig,
+                                                                                    final String topic,
+                                                                                    final List<KeyValue<K, V>> expectedRecords) throws Exception { // uses DEFAULT_TIMEOUT
         return waitUntilFinalKeyValueRecordsReceived(consumerConfig, topic, expectedRecords, DEFAULT_TIMEOUT);
     }
 
@@ -629,11 +659,9 @@ public class IntegrationTestUtils {
      * @param <V>                Value type of the data records
      * @return All the mappings consumed, or null if no records are consumed
      */
-    public static <K, V> List<KeyValueTimestamp<K, V>> waitUntilFinalKeyValueTimestampRecordsReceived(
-        final Properties consumerConfig,
-        final String topic,
-        final List<KeyValueTimestamp<K, V>> expectedRecords
-    ) throws Exception {
+    public static <K, V> List<KeyValueTimestamp<K, V>> waitUntilFinalKeyValueTimestampRecordsReceived(final Properties consumerConfig,
+                                                                                                      final String topic,
+                                                                                                      final List<KeyValueTimestamp<K, V>> expectedRecords) throws Exception { // DEFAULT_TIMEOUT, flag=true
         return waitUntilFinalKeyValueRecordsReceived(consumerConfig, topic, expectedRecords, DEFAULT_TIMEOUT, true);
     }
 
@@ -649,12 +677,10 @@ public class IntegrationTestUtils {
      * @return All the mappings consumed, or null if no records are consumed
      */
     @SuppressWarnings("WeakerAccess")
-    public static <K, V> List<KeyValue<K, V>> waitUntilFinalKeyValueRecordsReceived(
-        final Properties consumerConfig,
-        final String topic,
-        final List<KeyValue<K, V>> expectedRecords,
-        final long waitTime
-    ) throws Exception {
+    public static <K, V> List<KeyValue<K, V>> waitUntilFinalKeyValueRecordsReceived(final Properties consumerConfig,
+                                                                                    final String topic,
+                                                                                    final List<KeyValue<K, V>> expectedRecords,
+                                                                                    final long waitTime) throws Exception { // flag=false (timestamp variant passes true)
         return waitUntilFinalKeyValueRecordsReceived(consumerConfig, topic, expectedRecords, waitTime, false);
     }
 
@@ -875,7 +901,7 @@ public class IntegrationTestUtils {
     }
 
     /**
-     * Waits for the given {@link KafkaStreams} instances to all be in a  {@link State#RUNNING}
+     * Waits for the given {@link KafkaStreams} instances to all be in a {@link State#RUNNING}
      * state. Prefer {@link #startApplicationAndWaitUntilRunning(List, Duration)} when possible
      * because this method uses polling, which can be more error prone and slightly slower.
      *
@@ -1252,4 +1278,57 @@ public class IntegrationTestUtils {
             );
         }
     }
+
+    /**
+     * A {@link StateRestoreListener} that tracks, per changelog partition, the start/end offsets reported when
+     * restoration begins and the running number of restored records. The store name passed to the callbacks is
+     * ignored, so this currently assumes only one store in the topology; update it if you need per-store tracking.
+     */
+    public static class TrackingStateRestoreListener implements StateRestoreListener {
+        public final Map<TopicPartition, AtomicLong> changelogToStartOffset = new ConcurrentHashMap<>();
+        public final Map<TopicPartition, AtomicLong> changelogToEndOffset = new ConcurrentHashMap<>();
+        public final Map<TopicPartition, AtomicLong> changelogToTotalNumRestored = new ConcurrentHashMap<>();
+
+        @Override
+        public void onRestoreStart(final TopicPartition topicPartition,
+                                   final String storeName,
+                                   final long startingOffset,
+                                   final long endingOffset) { // (re)starting a restoration resets its counters
+            changelogToStartOffset.put(topicPartition, new AtomicLong(startingOffset));
+            changelogToEndOffset.put(topicPartition, new AtomicLong(endingOffset));
+            changelogToTotalNumRestored.put(topicPartition, new AtomicLong(0L));
+        }
+
+        @Override
+        public void onBatchRestored(final TopicPartition topicPartition,
+                                    final String storeName,
+                                    final long batchEndOffset,
+                                    final long numRestored) { // computeIfAbsent: no NPE if onRestoreStart was not seen
+            changelogToTotalNumRestored.computeIfAbsent(topicPartition, tp -> new AtomicLong(0L)).addAndGet(numRestored);
+        }
+
+        @Override
+        public void onRestoreEnd(final TopicPartition topicPartition,
+                                 final String storeName,
+                                 final long totalRestored) { // no-op: totals are accumulated in onBatchRestored
+        }
+
+        public boolean allStartOffsetsAtZero() { // vacuously true if no restoration has started
+            for (final AtomicLong startOffset : changelogToStartOffset.values()) {
+                if (startOffset.get() != 0L) {
+                    return false;
+                }
+            }
+            return true;
+        }
+
+        public long totalNumRestored() { // summed over all tracked changelog partitions
+            long totalNumRestored = 0L;
+            for (final AtomicLong numRestored : changelogToTotalNumRestored.values()) {
+                totalNumRestored += numRestored.get();
+            }
+            return totalNumRestored;
+        }
+    }
+
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContextTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContextTest.java
index 72b415f..26caa46 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContextTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContextTest.java
@@ -32,6 +32,7 @@ import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.To;
 import org.apache.kafka.streams.state.RocksDBConfigSetter;
 import org.apache.kafka.streams.state.internals.ThreadCache;
+import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener;
 import org.apache.kafka.test.MockKeyValueStore;
 import org.junit.Before;
 import org.junit.Test;
@@ -241,5 +242,17 @@ public class AbstractProcessorContextTest {
                               final byte[] value,
                               final long timestamp) {
         }
+
+        @Override
+        public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) {
+        } // no-op stub: this test context ignores the transition
+
+        @Override
+        public void transitionToStandby(final ThreadCache newCache) {
+        } // no-op stub: this test context ignores the transition
+
+        @Override
+        public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) {
+        } // no-op stub: the listener is not stored
     }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
index bd705a5..ad7dad6 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
@@ -64,7 +64,7 @@ public class MockChangelogReader implements ChangelogReader {
     }
 
     @Override
-    public void remove(final Collection<TopicPartition> partitions) {
+    public void unregister(final Collection<TopicPartition> partitions, final boolean triggerOnRestoreEnd) {
         restoringPartitions.removeAll(partitions);
 
         for (final TopicPartition partition : partitions) {
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextImplTest.java
index 41cfdfa..58f4eb4 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextImplTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextImplTest.java
@@ -132,14 +132,16 @@ public class ProcessorContextImplTest {
 
         context = new ProcessorContextImpl(
             mock(TaskId.class),
-            mock(StreamTask.class),
             streamsConfig,
-            recordCollector,
             stateManager,
             mock(StreamsMetricsImpl.class),
             mock(ThreadCache.class)
         );
 
+        final StreamTask task = mock(StreamTask.class);
+        ((InternalProcessorContext) context).transitionToActive(task, null, null);
+        EasyMock.expect(task.recordCollector()).andStubReturn(recordCollector);
+
         context.setCurrentNode(new ProcessorNode<String, Long>("fake", null,
             new HashSet<>(asList(
                 "LocalKeyValueStore",
@@ -157,7 +159,8 @@ public class ProcessorContextImplTest {
             mock(TaskId.class),
             streamsConfig,
             stateManager,
-            mock(StreamsMetricsImpl.class)
+            mock(StreamsMetricsImpl.class),
+            mock(ThreadCache.class)
         );
     }
 
@@ -375,7 +378,10 @@ public class ProcessorContextImplTest {
 
         recordCollector.send(null, key, value, null, 0, 42L, BYTES_KEY_SERIALIZER, BYTEARRAY_VALUE_SERIALIZER);
 
-        replay(recordCollector);
+        final StreamTask task = EasyMock.createNiceMock(StreamTask.class);
+
+        replay(recordCollector, task);
+        context.transitionToActive(task, recordCollector, null);
         context.logChange("Store", key, value, 42L);
 
         verify(recordCollector);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextTest.java
index 44f01b0..46d8f70 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextTest.java
@@ -52,13 +52,12 @@ public class ProcessorContextTest {
 
         context = new ProcessorContextImpl(
             mock(TaskId.class),
-            mock(StreamTask.class),
             streamsConfig,
-            mock(RecordCollector.class),
             stateManager,
             mock(StreamsMetricsImpl.class),
             mock(ThreadCache.class)
         );
+        ((InternalProcessorContext) context).transitionToActive(mock(StreamTask.class), null, null);
     }
 
     @Test
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
index 6507349..fdfb3c2 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
@@ -26,6 +26,7 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
@@ -65,6 +66,10 @@ import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.streams.processor.internals.StateManagerUtil.CHECKPOINT_FILE_NAME;
+import static org.easymock.EasyMock.expect;
+import static org.easymock.EasyMock.replay;
+import static org.easymock.EasyMock.reset;
+import static org.easymock.EasyMock.verify;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.notNullValue;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -105,6 +110,7 @@ public class ProcessorStateManagerTest {
         new ConsumerRecord<>(persistentStoreTopicName, 1, 100L, keyBytes, valueBytes);
     private final MockChangelogReader changelogReader = new MockChangelogReader();
     private final LogContext logContext = new LogContext("process-state-manager-test ");
+    private final StateRestoreCallback noopStateRestoreCallback = (k, v) -> { };
 
     private File baseDir;
     private File checkpointFile;
@@ -115,6 +121,8 @@ public class ProcessorStateManagerTest {
     private StateStore store;
     @Mock(type = MockType.NICE)
     private StateStoreMetadata storeMetadata;
+    @Mock(type = MockType.NICE)
+    private InternalProcessorContext context;
 
     @Before
     public void setup() {
@@ -130,10 +138,10 @@ public class ProcessorStateManagerTest {
         checkpointFile = new File(stateDirectory.directoryForTask(taskId), CHECKPOINT_FILE_NAME);
         checkpoint = new OffsetCheckpoint(checkpointFile);
 
-        EasyMock.expect(storeMetadata.changelogPartition()).andReturn(persistentStorePartition).anyTimes();
-        EasyMock.expect(storeMetadata.store()).andReturn(store).anyTimes();
-        EasyMock.expect(store.name()).andReturn(persistentStoreName).anyTimes();
-        EasyMock.replay(storeMetadata, store);
+        expect(storeMetadata.changelogPartition()).andReturn(persistentStorePartition).anyTimes();
+        expect(storeMetadata.store()).andReturn(store).anyTimes();
+        expect(store.name()).andReturn(persistentStoreName).anyTimes();
+        replay(storeMetadata, store);
     }
 
     @After
@@ -268,6 +276,70 @@ public class ProcessorStateManagerTest {
     }
 
     @Test
+    public void shouldUnregisterChangelogsDuringClose() {
+        final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
+        reset(storeMetadata);
+        final StateStore store = EasyMock.createMock(StateStore.class);
+        expect(storeMetadata.changelogPartition()).andStubReturn(persistentStorePartition);
+        expect(storeMetadata.store()).andStubReturn(store);
+        expect(store.name()).andStubReturn(persistentStoreName);
+
+        context.uninitialize();
+        store.init(context, store);
+        replay(storeMetadata, context, store);
+
+        stateMgr.registerStateStores(singletonList(store), context);
+        verify(context, store);
+
+        stateMgr.registerStore(store, noopStateRestoreCallback);
+        assertTrue(changelogReader.isPartitionRegistered(persistentStorePartition));
+
+        reset(store);
+        expect(store.name()).andStubReturn(persistentStoreName);
+        store.close();
+        replay(store);
+
+        stateMgr.close();
+        verify(store);
+
+        assertFalse(changelogReader.isPartitionRegistered(persistentStorePartition));
+    }
+
+    @Test
+    public void shouldRecycleStoreAndReregisterChangelog() {
+        final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
+        reset(storeMetadata);
+        final StateStore store = EasyMock.createMock(StateStore.class);
+        expect(storeMetadata.changelogPartition()).andStubReturn(persistentStorePartition);
+        expect(storeMetadata.store()).andStubReturn(store);
+        expect(store.name()).andStubReturn(persistentStoreName);
+
+        context.uninitialize();
+        store.init(context, store);
+        replay(storeMetadata, context, store);
+
+        stateMgr.registerStateStores(singletonList(store), context);
+        verify(context, store);
+
+        stateMgr.registerStore(store, noopStateRestoreCallback);
+        assertTrue(changelogReader.isPartitionRegistered(persistentStorePartition));
+
+        stateMgr.recycle();
+        assertFalse(changelogReader.isPartitionRegistered(persistentStorePartition));
+        assertThat(stateMgr.getStore(persistentStoreName), equalTo(store));
+
+        reset(context, store);
+        context.uninitialize();
+        expect(store.name()).andStubReturn(persistentStoreName);
+        replay(context, store);
+
+        stateMgr.registerStateStores(singletonList(store), context);
+
+        verify(context, store);
+        assertTrue(changelogReader.isPartitionRegistered(persistentStorePartition));
+    }
+
+    @Test
     public void shouldRegisterPersistentStores() {
         final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE);
 
@@ -682,13 +754,13 @@ public class ProcessorStateManagerTest {
 
         try {
             stateManager.flush();
-        } catch (final ProcessorStateException expected) { /* ignode */ }
+        } catch (final ProcessorStateException expected) { /* ignore */ }
 
         Assert.assertTrue(flushedStore.get());
     }
 
     @Test
-    public void shouldCloseAllStoresEvenIfStoreThrowsExcepiton() {
+    public void shouldCloseAllStoresEvenIfStoreThrowsException() {
         final AtomicBoolean closedStore = new AtomicBoolean(false);
 
         final MockKeyValueStore stateStore1 = new MockKeyValueStore(persistentStoreName, true) {
@@ -710,7 +782,7 @@ public class ProcessorStateManagerTest {
 
         try {
             stateManager.close();
-        } catch (final ProcessorStateException expected) { /* ignode */ }
+        } catch (final ProcessorStateException expected) { /* ignore */ }
 
         Assert.assertTrue(closedStore.get());
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index 7c14f23..cbba51d 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -25,6 +25,7 @@ import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.metrics.stats.CumulativeSum;
 import org.apache.kafka.common.serialization.IntegerSerializer;
+import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.StreamsConfig;
@@ -33,6 +34,7 @@ import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.state.internals.ThreadCache;
 import org.apache.kafka.test.MockKeyValueStore;
 import org.apache.kafka.test.MockKeyValueStoreBuilder;
 import org.apache.kafka.test.MockRestoreConsumer;
@@ -112,7 +114,8 @@ public class StandbyTaskTest {
 
     @Before
     public void setup() throws Exception {
-        EasyMock.expect(stateManager.taskId()).andReturn(taskId).anyTimes();
+        EasyMock.expect(stateManager.taskId()).andStubReturn(taskId);
+        EasyMock.expect(stateManager.taskType()).andStubReturn(TaskType.STANDBY);
 
         restoreStateConsumer.reset();
         restoreStateConsumer.updatePartitions(storeChangelogTopicName1, asList(
@@ -156,10 +159,7 @@ public class StandbyTaskTest {
 
     @Test
     public void shouldTransitToRunningAfterInitialization() {
-        stateManager.registerStore(store1, store1.stateRestoreCallback);
-        EasyMock.expectLastCall();
-        stateManager.registerStore(store2, store2.stateRestoreCallback);
-        EasyMock.expectLastCall();
+        stateManager.registerStateStores(EasyMock.anyObject(), EasyMock.anyObject());
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes();
         EasyMock.replay(stateManager);
 
@@ -181,6 +181,7 @@ public class StandbyTaskTest {
 
     @Test
     public void shouldThrowIfCommittingOnIllegalState() {
+        EasyMock.replay(stateManager);
         task = createStandbyTask();
 
         assertThrows(IllegalStateException.class, task::prepareCommit);
@@ -384,6 +385,8 @@ public class StandbyTaskTest {
 
     @Test
     public void shouldThrowIfClosingOnIllegalState() {
+        EasyMock.replay(stateManager);
+
         task = createStandbyTask();
 
         task.transitionTo(Task.State.RESTORING);
@@ -475,7 +478,31 @@ public class StandbyTaskTest {
     }
 
     private StandbyTask createStandbyTask() {
-        return new StandbyTask(taskId, Collections.singleton(partition), topology, config, streamsMetrics, stateManager, stateDirectory);
+
+        final ThreadCache cache = new ThreadCache(
+            new LogContext(String.format("stream-thread [%s] ", Thread.currentThread().getName())),
+            0,
+            streamsMetrics
+        );
+
+        final InternalProcessorContext context = new ProcessorContextImpl(
+            taskId,
+            config,
+            stateManager,
+            streamsMetrics,
+            cache
+        );
+
+        return new StandbyTask(
+            taskId,
+            Collections.singleton(partition),
+            topology,
+            config,
+            streamsMetrics,
+            stateManager,
+            stateDirectory,
+            cache,
+            context);
     }
 
     private MetricName setupCloseTaskMetric() {
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerUtilTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerUtilTest.java
index 6902227..9f2a5fd 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerUtilTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerUtilTest.java
@@ -16,11 +16,13 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.util.List;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.errors.LockException;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.test.MockKeyValueStore;
@@ -139,30 +141,20 @@ public class StateManagerUtilTest {
 
     @Test
     public void testRegisterStateStores() throws IOException {
-        expect(topology.stateStores())
-            .andReturn(Arrays.asList(new MockKeyValueStore("store1", false),
-                new MockKeyValueStore("store2", false)));
+        final MockKeyValueStore store1 = new MockKeyValueStore("store1", false);
+        final MockKeyValueStore store2 = new MockKeyValueStore("store2", false);
+        final List<StateStore> stateStores = Arrays.asList(store1, store2);
+
+        expect(topology.stateStores()).andReturn(stateStores);
 
         expect(stateManager.taskId()).andReturn(taskId);
 
         expect(stateDirectory.lock(taskId)).andReturn(true);
         expect(stateDirectory.directoryForTaskIsEmpty(taskId)).andReturn(true);
 
-        final MockKeyValueStore store1 = new MockKeyValueStore("store1", false);
-        final MockKeyValueStore store2 = new MockKeyValueStore("store2", false);
-
-        expect(topology.stateStores()).andReturn(Arrays.asList(store1, store2));
+        expect(topology.stateStores()).andReturn(stateStores);
 
-        // Store1 will be registered as it hasn't been registered before.
-        expect(stateManager.getStore(store1.name())).andReturn(null);
-
-        processorContext.uninitialize();
-        expectLastCall();
-        processorContext.register(store1, store1.stateRestoreCallback);
-        expectLastCall();
-
-        // Store2 is already registered, so no more registration will happen.
-        expect(stateManager.getStore(store2.name())).andReturn(store2);
+        stateManager.registerStateStores(stateStores, processorContext);
 
         stateManager.initializeStoreOffsetsFromCheckpoint(true);
         expectLastCall();
@@ -350,7 +342,6 @@ public class StateManagerUtilTest {
     public void shouldNotWipeStateStoresIfUnableToLockTaskDirectory() throws IOException {
         final File unknownFile = new File("/unknown/path");
         expect(stateManager.taskId()).andReturn(taskId);
-
         expect(stateDirectory.lock(taskId)).andReturn(false);
 
         expect(stateManager.baseDir()).andReturn(unknownFile);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
index 54599e0..e873e04 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
@@ -865,7 +865,7 @@ public class StoreChangelogReaderTest extends EasyMockSupport {
         assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp2).state());
 
         // should support removing and clearing changelogs
-        changelogReader.remove(Collections.singletonList(tp));
+        changelogReader.unregister(Collections.singletonList(tp), false);
         assertNull(changelogReader.changelogMetadata(tp));
         assertFalse(changelogReader.isEmpty());
         assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp1).state());
@@ -1006,7 +1006,7 @@ public class StoreChangelogReaderTest extends EasyMockSupport {
         // transition to update standby is NOT idempotent
         assertThrows(IllegalStateException.class, changelogReader::transitToUpdateStandby);
 
-        changelogReader.remove(Collections.singletonList(tp));
+        changelogReader.unregister(Collections.singletonList(tp), false);
         changelogReader.register(tp, activeStateManager);
 
         // if a new active is registered, we should immediately transit to standby updating
@@ -1060,7 +1060,7 @@ public class StoreChangelogReaderTest extends EasyMockSupport {
     public void shouldNotThrowOnUnknownRevokedPartition() {
         LogCaptureAppender.setClassLoggerToDebug(StoreChangelogReader.class);
         try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StoreChangelogReader.class)) {
-            changelogReader.remove(Collections.singletonList(new TopicPartition("unknown", 0)));
+            changelogReader.unregister(Collections.singletonList(new TopicPartition("unknown", 0)), false);
 
             assertThat(
                 appender.getMessages(),
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index 5df51e8..13d039b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -48,7 +48,9 @@ import org.apache.kafka.streams.processor.PunctuationType;
 import org.apache.kafka.streams.processor.Punctuator;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.state.internals.ThreadCache;
 import org.apache.kafka.test.MockKeyValueStore;
 import org.apache.kafka.test.MockProcessorNode;
 import org.apache.kafka.test.MockSourceNode;
@@ -152,6 +154,8 @@ public class StreamTaskTest {
     private ProcessorStateManager stateManager;
     @Mock(type = MockType.NICE)
     private RecordCollector recordCollector;
+    @Mock(type = MockType.NICE)
+    private ThreadCache cache;
 
     private final Punctuator punctuator = new Punctuator() {
         @Override
@@ -203,7 +207,8 @@ public class StreamTaskTest {
 
     @Before
     public void setup() {
-        EasyMock.expect(stateManager.taskId()).andReturn(taskId).anyTimes();
+        EasyMock.expect(stateManager.taskId()).andStubReturn(taskId);
+        EasyMock.expect(stateManager.taskType()).andStubReturn(TaskType.ACTIVE);
 
         consumer.assign(asList(partition1, partition2));
         consumer.updateBeginningOffsets(mkMap(mkEntry(partition1, 0L), mkEntry(partition2, 0L)));
@@ -251,6 +256,7 @@ public class StreamTaskTest {
     public void shouldAttemptToDeleteStateDirectoryWhenCloseDirtyAndEosEnabled() throws IOException {
         final IMocksControl ctrl = EasyMock.createStrictControl();
         final ProcessorStateManager stateManager = ctrl.createMock(ProcessorStateManager.class);
+        EasyMock.expect(stateManager.taskType()).andStubReturn(TaskType.ACTIVE);
         stateDirectory = ctrl.createMock(StateDirectory.class);
 
         stateManager.registerGlobalStateStores(Collections.emptyList());
@@ -259,7 +265,6 @@ public class StreamTaskTest {
         EasyMock.expect(stateManager.taskId()).andReturn(taskId);
 
         EasyMock.expect(stateDirectory.lock(taskId)).andReturn(true);
-        EasyMock.expectLastCall();
 
         stateManager.close();
         EasyMock.expectLastCall();
@@ -1404,7 +1409,7 @@ public class StreamTaskTest {
         task.initializeIfNeeded();
         task.completeRestoration();
         task.punctuate(processorStreamTime, 5, PunctuationType.STREAM_TIME, punctuator);
-        assertThat(((InternalProcessorContext) task.context()).currentNode(), nullValue());
+        assertThat(task.processorContext().currentNode(), nullValue());
     }
 
     @Test(expected = IllegalStateException.class)
@@ -1421,9 +1426,8 @@ public class StreamTaskTest {
     }
 
     @Test
-    public void shouldCloseStateManagerEvenFailureOnUncleanTaskClose() {
+    public void shouldCloseStateManagerEvenDuringFailureOnUncleanTaskClose() {
         EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes();
-        stateManager.registerStore(stateStore, stateStore.stateRestoreCallback);
         EasyMock.expectLastCall();
 
         stateManager.close();
@@ -1457,18 +1461,29 @@ public class StreamTaskTest {
         EasyMock.expect(recordCollector.offsets()).andReturn(Collections.emptyMap()).anyTimes();
         EasyMock.replay(stateManager, recordCollector);
 
+        final StreamsConfig config = createConfig(false, "0");
+        final InternalProcessorContext context = new ProcessorContextImpl(
+            taskId,
+            config,
+            stateManager,
+            streamsMetrics,
+            null
+        );
+
         task = new StreamTask(
             taskId,
             mkSet(partition1, repartition),
             topology,
             consumer,
-            createConfig(false, "0"),
+            config,
             streamsMetrics,
             stateDirectory,
-            null,
+            cache,
             time,
             stateManager,
-            recordCollector);
+            recordCollector,
+            context);
+
         task.initializeIfNeeded();
         task.completeRestoration();
 
@@ -1800,6 +1815,14 @@ public class StreamTaskTest {
             singletonList(stateStore),
             Collections.singletonMap(storeName, topic1));
 
+        final InternalProcessorContext context = new ProcessorContextImpl(
+            taskId,
+            config,
+            stateManager,
+            streamsMetrics,
+            null
+        );
+
         return new StreamTask(
             taskId,
             mkSet(partition1),
@@ -1808,10 +1831,11 @@ public class StreamTaskTest {
             config,
             streamsMetrics,
             stateDirectory,
-            null,
+            cache,
             time,
             stateManager,
-            recordCollector);
+            recordCollector,
+            context);
     }
 
     private StreamTask createDisconnectedTask(final StreamsConfig config) {
@@ -1830,6 +1854,14 @@ public class StreamTaskTest {
             }
         };
 
+        final InternalProcessorContext context = new ProcessorContextImpl(
+            taskId,
+            config,
+            stateManager,
+            streamsMetrics,
+            null
+        );
+
         return new StreamTask(
             taskId,
             partitions,
@@ -1838,10 +1870,11 @@ public class StreamTaskTest {
             config,
             streamsMetrics,
             stateDirectory,
-            null,
+            cache,
             time,
             stateManager,
-            recordCollector);
+            recordCollector,
+            context);
     }
 
     private StreamTask createFaultyStatefulTask(final StreamsConfig config) {
@@ -1851,6 +1884,14 @@ public class StreamTaskTest {
             singletonList(stateStore),
             Collections.emptyMap());
 
+        final InternalProcessorContext context = new ProcessorContextImpl(
+            taskId,
+            config,
+            stateManager,
+            streamsMetrics,
+            null
+        );
+
         return new StreamTask(
             taskId,
             partitions,
@@ -1859,10 +1900,11 @@ public class StreamTaskTest {
             config,
             streamsMetrics,
             stateDirectory,
-            null,
+            cache,
             time,
             stateManager,
-            recordCollector);
+            recordCollector,
+            context);
     }
 
     private StreamTask createStatefulTask(final StreamsConfig config, final boolean logged) {
@@ -1878,6 +1920,14 @@ public class StreamTaskTest {
             singletonList(stateStore),
             logged ? Collections.singletonMap(storeName, storeName + "-changelog") : Collections.emptyMap());
 
+        final InternalProcessorContext context = new ProcessorContextImpl(
+            taskId,
+            config,
+            stateManager,
+            streamsMetrics,
+            null
+        );
+
         return new StreamTask(
             taskId,
             partitions,
@@ -1886,13 +1936,14 @@ public class StreamTaskTest {
             config,
             streamsMetrics,
             stateDirectory,
-            null,
+            cache,
             time,
             stateManager,
-            recordCollector);
+            recordCollector,
+            context);
     }
 
-    private StreamTask createStatelessTask(final StreamsConfig streamsConfig,
+    private StreamTask createStatelessTask(final StreamsConfig config,
                                            final String builtInMetricsVersion) {
         final ProcessorTopology topology = withSources(
             asList(source1, source2, processorStreamTime, processorSystemTime),
@@ -1908,18 +1959,27 @@ public class StreamTaskTest {
         EasyMock.expect(recordCollector.offsets()).andReturn(Collections.emptyMap()).anyTimes();
         EasyMock.replay(stateManager, recordCollector);
 
+        final InternalProcessorContext context = new ProcessorContextImpl(
+            taskId,
+            config,
+            stateManager,
+            streamsMetrics,
+            null
+        );
+
         return new StreamTask(
             taskId,
             partitions,
             topology,
             consumer,
-            streamsConfig,
+            config,
             new StreamsMetricsImpl(metrics, "test", builtInMetricsVersion),
             stateDirectory,
-            null,
+            cache,
             time,
             stateManager,
-            recordCollector);
+            recordCollector,
+            context);
     }
 
     private ConsumerRecord<byte[], byte[]> getConsumerRecord(final TopicPartition topicPartition, final long offset) {
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index c6a0eab..c7446e2 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -978,9 +978,6 @@ public class TaskManagerTest {
         consumer.commitSync(expectedCommittedOffsets);
         expectLastCall().andThrow(new RuntimeException("Something went wrong!"));
 
-        changeLogReader.remove(singleton(t1p1));
-        expectLastCall();
-
         replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
 
         taskManager.handleAssignment(assignmentActive, emptyMap());
@@ -1165,9 +1162,6 @@ public class TaskManagerTest {
 
         resetToStrict(changeLogReader);
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
-        // make sure we also remove the changelog partitions from the changelog reader
-        changeLogReader.remove(eq(singletonList(changelog)));
-        expectLastCall();
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment)))
             .andReturn(asList(task00, task01, task02, task03)).anyTimes();
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
@@ -1249,9 +1243,6 @@ public class TaskManagerTest {
 
         resetToStrict(changeLogReader);
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
-        // make sure we also remove the changelog partitions from the changelog reader
-        changeLogReader.remove(eq(singletonList(changelog)));
-        expectLastCall();
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(singletonList(task00)).anyTimes();
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
         expectLastCall().andThrow(new RuntimeException("whatever"));
@@ -1303,9 +1294,6 @@ public class TaskManagerTest {
 
         resetToStrict(changeLogReader);
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
-        // make sure we also remove the changelog partitions from the changelog reader
-        changeLogReader.remove(eq(singletonList(changelog)));
-        expectLastCall();
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(singletonList(task00)).anyTimes();
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
         expectLastCall();
@@ -1468,9 +1456,6 @@ public class TaskManagerTest {
 
         resetToStrict(changeLogReader);
         expect(changeLogReader.completedChangelogs()).andReturn(emptySet());
-        // make sure we also remove the changelog partitions from the changelog reader
-        changeLogReader.remove(eq(singletonList(changelog)));
-        expectLastCall();
         expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andReturn(asList(task00, task01, task02)).anyTimes();
         activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00));
         expectLastCall().andThrow(new RuntimeException("whatever 0"));
@@ -2630,6 +2615,8 @@ public class TaskManagerTest {
         private Map<TopicPartition, OffsetAndMetadata> committableOffsets = Collections.emptyMap();
         private Map<TopicPartition, Long> purgeableOffsets;
         private Map<TopicPartition, Long> changelogOffsets = Collections.emptyMap();
+        private InternalProcessorContext processorContext = mock(InternalProcessorContext.class);
+
         private final Map<TopicPartition, LinkedList<ConsumerRecord<byte[], byte[]>>> queue = new HashMap<>();
 
         StateMachineTask(final TaskId id,
@@ -2719,6 +2706,11 @@ public class TaskManagerTest {
             transitionTo(State.CLOSED);
         }
 
+        @Override
+        public void closeAndRecycleState() {
+            transitionTo(State.CLOSED);
+        }
+
         void setCommittableOffsetsAndMetadata(final Map<TopicPartition, OffsetAndMetadata> committableOffsets) {
             if (!active) {
                 throw new IllegalStateException("Cannot set CommittableOffsetsAndMetadate for StandbyTasks");
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStoreTest.java
index ce16d91..5b0081c 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStoreTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStoreTest.java
@@ -364,11 +364,11 @@ public abstract class AbstractRocksDBSegmentedBytesStoreTest<S extends Segment>
 
     @Test
     public void shouldRestoreToByteStoreForStandbyTask() {
+        context.transitionToStandby(null);
         shouldRestoreToByteStore(TaskType.STANDBY);
     }
 
     private void shouldRestoreToByteStore(final TaskType taskType) {
-        context.setTaskType(taskType);
         bytesStore.init(context, bytesStore);
         // 0 segments initially.
         assertEquals(0, bytesStore.getSegments().size());
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
index 7aad8d6..b7d2d32 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
@@ -29,8 +29,10 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.TopologyWrapper;
 import org.apache.kafka.streams.errors.InvalidStateStoreException;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
 import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder;
 import org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
+import org.apache.kafka.streams.processor.internals.ProcessorContextImpl;
 import org.apache.kafka.streams.processor.internals.ProcessorStateManager;
 import org.apache.kafka.streams.processor.internals.ProcessorTopology;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
@@ -41,6 +43,7 @@ import org.apache.kafka.streams.processor.internals.StreamTask;
 import org.apache.kafka.streams.processor.internals.StreamThread;
 import org.apache.kafka.streams.processor.internals.StreamsProducer;
 import org.apache.kafka.streams.processor.internals.Task;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.QueryableStoreTypes;
 import org.apache.kafka.streams.state.ReadOnlyKeyValueStore;
 import org.apache.kafka.streams.state.ReadOnlyWindowStore;
@@ -383,18 +386,27 @@ public class StreamThreadStateStoreProviderTest {
             ),
             streamsConfig.defaultProductionExceptionHandler(),
             new MockStreamsMetrics(metrics));
+        final StreamsMetricsImpl streamsMetrics = new MockStreamsMetrics(metrics);
+        final InternalProcessorContext context = new ProcessorContextImpl(
+            taskId,
+            streamsConfig,
+            stateManager,
+            streamsMetrics,
+            null
+        );
         return new StreamTask(
             taskId,
             partitions,
             topology,
             clientSupplier.consumer,
             streamsConfig,
-            new MockStreamsMetrics(metrics),
+            streamsMetrics,
             stateDirectory,
-            null,
+            EasyMock.createNiceMock(ThreadCache.class),
             new MockTime(),
             stateManager,
-            recordCollector);
+            recordCollector,
+            context);
     }
 
     private void mockThread(final boolean initialized) {
diff --git a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
index ebe9053..ed5d943 100644
--- a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java
@@ -39,11 +39,13 @@ import org.apache.kafka.streams.processor.internals.ProcessorNode;
 import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 import org.apache.kafka.streams.processor.internals.RecordBatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
+import org.apache.kafka.streams.processor.internals.StreamTask;
 import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.streams.processor.internals.ToInternal;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.StateSerdes;
 import org.apache.kafka.streams.state.internals.ThreadCache;
+import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener;
 import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecordingTrigger;
 
 import java.io.File;
@@ -66,10 +68,10 @@ public class InternalMockProcessorContext
     private final Map<String, StateRestoreCallback> restoreFuncs = new HashMap<>();
     private final ToInternal toInternal = new ToInternal();
 
+    private TaskType taskType = TaskType.ACTIVE;
     private Serde<?> keySerde;
     private Serde<?> valSerde;
     private long timestamp = -1L;
-    private TaskType taskType = TaskType.ACTIVE;
 
     public InternalMockProcessorContext() {
         this(null,
@@ -160,7 +162,8 @@ public class InternalMockProcessorContext
                                         final Serde<?> valSerde,
                                         final RecordCollector collector,
                                         final ThreadCache cache) {
-        this(stateDir,
+        this(
+            stateDir,
             keySerde,
             valSerde,
             new StreamsMetricsImpl(new Metrics(), "mock", StreamsConfig.METRICS_LATEST),
@@ -177,11 +180,13 @@ public class InternalMockProcessorContext
                                         final StreamsConfig config,
                                         final RecordCollector.Supplier collectorSupplier,
                                         final ThreadCache cache) {
-        super(new TaskId(0, 0),
+        super(
+            new TaskId(0, 0),
             config,
             metrics,
             null,
-            cache);
+            cache
+        );
         super.setCurrentNode(new ProcessorNode<>("TESTING_NODE"));
         this.stateDir = stateDir;
         this.keySerde = keySerde;
@@ -357,10 +362,6 @@ public class InternalMockProcessorContext
         return taskType;
     }
 
-    public void setTaskType(final TaskType taskType) {
-        this.taskType = taskType;
-    }
-
     @Override
     public void logChange(final String storeName,
                           final Bytes key,
@@ -377,6 +378,21 @@ public class InternalMockProcessorContext
             BYTEARRAY_VALUE_SERIALIZER);
     }
 
+    @Override
+    public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) {
+        taskType = TaskType.ACTIVE;
+    }
+
+    @Override
+    public void transitionToStandby(final ThreadCache newCache) {
+        taskType = TaskType.STANDBY;
+    }
+
+    @Override
+    public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) {
+        cache().addDirtyEntryFlushListener(namespace, listener);
+    }
+
     public StateRestoreListener getRestoreListener(final String storeName) {
         return getStateRestoreListener(restoreFuncs.get(storeName));
     }
diff --git a/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java
index e375085..c62f69a 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java
@@ -25,6 +25,7 @@ import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
 import org.apache.kafka.streams.processor.internals.ProcessorNode;
 import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
+import org.apache.kafka.streams.processor.internals.StreamTask;
 import org.apache.kafka.streams.processor.internals.Task.TaskType;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.internals.ThreadCache;
@@ -33,6 +34,7 @@ import java.io.File;
 import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Properties;
+import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener;
 
 public class MockInternalProcessorContext extends MockProcessorContext implements InternalProcessorContext {
 
@@ -131,4 +133,16 @@ public class MockInternalProcessorContext extends MockProcessorContext implement
                           final byte[] value,
                           final long timestamp) {
     }
+
+    @Override
+    public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) {
+    }
+
+    @Override
+    public void transitionToStandby(final ThreadCache newCache) {
+    }
+
+    @Override
+    public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) {
+    }
 }
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/test/NoOpProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/NoOpProcessorContext.java
index da8b7b4..88d5fca 100644
--- a/streams/src/test/java/org/apache/kafka/test/NoOpProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/NoOpProcessorContext.java
@@ -33,7 +33,11 @@ import java.time.Duration;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Properties;
+import org.apache.kafka.streams.processor.internals.RecordCollector;
+import org.apache.kafka.streams.processor.internals.StreamTask;
 import org.apache.kafka.streams.processor.internals.Task.TaskType;
+import org.apache.kafka.streams.state.internals.ThreadCache;
+import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener;
 
 public class NoOpProcessorContext extends AbstractProcessorContext {
     public boolean initialized;
@@ -117,4 +121,17 @@ public class NoOpProcessorContext extends AbstractProcessorContext {
                           final byte[] value,
                           final long timestamp) {
     }
+
+    @Override
+    public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) {
+    }
+
+    @Override
+    public void transitionToStandby(final ThreadCache newCache) {
+    }
+
+    @Override
+    public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) {
+        cache.addDirtyEntryFlushListener(namespace, listener);
+    }
 }
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
index 8475172..923a980 100644
--- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
@@ -490,6 +490,15 @@ public class TopologyTestDriver implements Closeable {
                 streamsConfig.defaultProductionExceptionHandler(),
                 streamsMetrics
             );
+
+            final InternalProcessorContext context = new ProcessorContextImpl(
+                TASK_ID,
+                streamsConfig,
+                stateManager,
+                streamsMetrics,
+                cache
+            );
+
             task = new StreamTask(
                 TASK_ID,
                 new HashSet<>(partitionsByInputTopic.values()),
@@ -501,11 +510,12 @@ public class TopologyTestDriver implements Closeable {
                 cache,
                 mockWallClockTime,
                 stateManager,
-                recordCollector
+                recordCollector,
+                context
             );
             task.initializeIfNeeded();
             task.completeRestoration();
-            ((InternalProcessorContext) task.context()).setRecordContext(new ProcessorRecordContext(
+            task.processorContext().setRecordContext(new ProcessorRecordContext(
                 0L,
                 -1L,
                 -1,
@@ -991,7 +1001,7 @@ public class TopologyTestDriver implements Closeable {
     private StateStore getStateStore(final String name,
                                      final boolean throwForBuiltInStores) {
         if (task != null) {
-            final StateStore stateStore = ((ProcessorContextImpl) task.context()).stateManager().getStore(name);
+            final StateStore stateStore = ((ProcessorContextImpl) task.processorContext()).stateManager().getStore(name);
             if (stateStore != null) {
                 if (throwForBuiltInStores) {
                     throwIfBuiltInStore(stateStore);