You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by mj...@apache.org on 2018/10/08 22:14:27 UTC

[kafka] branch 1.0 updated: KAFKA-7192: Wipe out state store if EOS is turned on and checkpoint file does not exist (#5657)

This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch 1.0
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/1.0 by this push:
     new 88bca08  KAFKA-7192: Wipe out state store if EOS is turned on and checkpoint file does not exist (#5657)
88bca08 is described below

commit 88bca08e1455d950958cc2583c24a8b8f46f66e5
Author: Matthias J. Sax <mj...@apache.org>
AuthorDate: Mon Oct 8 15:14:19 2018 -0700

    KAFKA-7192: Wipe out state store if EOS is turned on and checkpoint file does not exist (#5657)
    
    Reviewers: Guozhang Wang <gu...@confluent.io>, John Roesler <jo...@confluent.io>, Bill Bejeck <bi...@confluent.io>
---
 .../internals/AbstractProcessorContext.java        |  6 ++
 .../streams/processor/internals/AbstractTask.java  |  5 +-
 .../streams/processor/internals/AssignedTasks.java |  9 +--
 .../internals/InternalProcessorContext.java        |  5 ++
 .../processor/internals/ProcessorStateManager.java | 61 ++++++++++++++++++++
 .../processor/internals/RestoringTasks.java        |  2 +-
 .../streams/processor/internals/StateRestorer.java | 19 ++++--
 .../processor/internals/StoreChangelogReader.java  | 58 +++++++++++++------
 .../streams/integration/EosIntegrationTest.java    |  4 +-
 .../processor/internals/AssignedTasksTest.java     |  4 +-
 .../internals/StoreChangelogReaderTest.java        | 67 +++++++++++++---------
 11 files changed, 180 insertions(+), 60 deletions(-)

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
index 410212e..eeaed89 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java
@@ -198,4 +198,10 @@ public abstract class AbstractProcessorContext implements InternalProcessorConte
     public void initialized() {
         initialized = true;
     }
+
+    @Override
+    public void uninitialize() {
+        initialized = false;
+    }
+
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
index baea4af..fcf2f6b 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java
@@ -107,7 +107,7 @@ public abstract class AbstractTask implements Task {
     }
 
     @Override
-    public final Set<TopicPartition> partitions() {
+    public Set<TopicPartition> partitions() {
         return partitions;
     }
 
@@ -226,6 +226,9 @@ public abstract class AbstractTask implements Task {
         }
     }
 
+    void reinitializeStateStoresForPartitions(final TopicPartition partitions) {
+        stateMgr.reinitializeStateStoresForPartitions(partitions, processorContext);
+    }
 
     /**
      * @throws ProcessorStateException if there is an error while closing the state manager
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedTasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedTasks.java
index 0d9d04d..cfce575 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedTasks.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AssignedTasks.java
@@ -50,7 +50,7 @@ class AssignedTasks implements RestoringTasks {
     // IQ may access this map.
     private Map<TaskId, Task> running = new ConcurrentHashMap<>();
     private Map<TopicPartition, Task> runningByPartition = new HashMap<>();
-    private Map<TopicPartition, Task> restoringByPartition = new HashMap<>();
+    private Map<TopicPartition, StreamTask> restoringByPartition = new HashMap<>();
     private int committed = 0;
 
 
@@ -122,7 +122,8 @@ class AssignedTasks implements RestoringTasks {
             try {
                 if (!entry.getValue().initializeStateStores()) {
                     log.debug("Transitioning {} {} to restoring", taskTypeName, entry.getKey());
-                    addToRestoring(entry.getValue());
+                    // cast is safe, because a StandbyTask always returns `true` from `initializeStateStores()` above
+                    addToRestoring((StreamTask) entry.getValue());
                 } else {
                     transitionToRunning(entry.getValue(), readyPartitions);
                 }
@@ -278,7 +279,7 @@ class AssignedTasks implements RestoringTasks {
         return false;
     }
 
-    private void addToRestoring(final Task task) {
+    private void addToRestoring(final StreamTask task) {
         restoring.put(task.id(), task);
         for (TopicPartition topicPartition : task.partitions()) {
             restoringByPartition.put(topicPartition, task);
@@ -307,7 +308,7 @@ class AssignedTasks implements RestoringTasks {
     }
 
     @Override
-    public Task restoringTaskFor(final TopicPartition partition) {
+    public StreamTask restoringTaskFor(final TopicPartition partition) {
         return restoringByPartition.get(partition);
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
index 57bb3ac..b5719b1 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java
@@ -53,4 +53,9 @@ public interface InternalProcessorContext extends ProcessorContext {
      * Mark this contex as being initialized
      */
     void initialized();
+
+    /**
+     * Mark this context as being uninitialized
+     */
+    void uninitialize();
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
index 5536ac1..9ccb458 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
@@ -19,8 +19,10 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.errors.ProcessorStateException;
+import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.BatchingStateRestoreCallback;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
@@ -33,9 +35,11 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 
 public class ProcessorStateManager implements StateManager {
@@ -49,6 +53,7 @@ public class ProcessorStateManager implements StateManager {
     private final String logPrefix;
     private final boolean isStandby;
     private final ChangelogReader changelogReader;
+    private final boolean eosEnabled;
     private final Map<String, StateStore> stores;
     private final Map<String, StateStore> globalStores;
     private final Map<TopicPartition, Long> offsetLimits;
@@ -98,6 +103,7 @@ public class ProcessorStateManager implements StateManager {
         checkpoint = new OffsetCheckpoint(new File(baseDir, CHECKPOINT_FILE_NAME));
         checkpointedOffsets = new HashMap<>(checkpoint.read());
 
+        this.eosEnabled = eosEnabled;
         if (eosEnabled) {
             // delete the checkpoint file after finish loading its stored offsets
             checkpoint.delete();
@@ -176,6 +182,61 @@ public class ProcessorStateManager implements StateManager {
         return partitionsAndOffsets;
     }
 
+    void reinitializeStateStoresForPartitions(final TopicPartition topicPartition,
+                                              final InternalProcessorContext processorContext) {
+        final Map<String, String> changelogTopicToStore = inverseOneToOneMap(storeToChangelogTopic);
+        final Set<String> storeToBeReinitialized = new HashSet<>();
+        final Map<String, StateStore> storesCopy = new HashMap<>(stores);
+
+        checkpointedOffsets.remove(topicPartition);
+        storeToBeReinitialized.add(changelogTopicToStore.get(topicPartition.topic()));
+
+        if (!eosEnabled) {
+            try {
+                checkpoint.write(checkpointedOffsets);
+            } catch (final IOException fatalException) {
+                log.error("Failed to write offset checkpoint file to {} while re-initializing {}: {}", checkpoint, stores, fatalException);
+                throw new StreamsException("Failed to reinitialize stores.", fatalException);
+            }
+        }
+
+        for (final Map.Entry<String, StateStore> entry : storesCopy.entrySet()) {
+            final StateStore stateStore = entry.getValue();
+            final String storeName = stateStore.name();
+            if (storeToBeReinitialized.contains(storeName)) {
+                try {
+                    stateStore.close();
+                } catch (final RuntimeException ignoreAndSwallow) { /* ignore */ }
+                processorContext.uninitialize();
+                stores.remove(entry.getKey());
+
+                try {
+                    Utils.delete(new File(baseDir + File.separator + "rocksdb" + File.separator + storeName));
+                } catch (final IOException fatalException) {
+                    log.error("Failed to reinitialize store {}.", storeName, fatalException);
+                    throw new StreamsException(String.format("Failed to reinitialize store %s.", storeName), fatalException);
+                }
+
+                try {
+                    Utils.delete(new File(baseDir + File.separator + storeName));
+                } catch (final IOException fatalException) {
+                    log.error("Failed to reinitialize store {}.", storeName, fatalException);
+                    throw new StreamsException(String.format("Failed to reinitialize store %s.", storeName), fatalException);
+                }
+
+                stateStore.init(processorContext, stateStore);
+            }
+        }
+    }
+
+    private Map<String, String> inverseOneToOneMap(final Map<String, String> origin) {
+        final Map<String, String> reversedMap = new HashMap<>();
+        for (final Map.Entry<String, String> entry : origin.entrySet()) {
+            reversedMap.put(entry.getValue(), entry.getKey());
+        }
+        return reversedMap;
+    }
+
     List<ConsumerRecord<byte[], byte[]>> updateStandbyStates(final TopicPartition storePartition,
                                                              final List<ConsumerRecord<byte[], byte[]>> records) {
         final long limit = offsetLimit(storePartition);
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RestoringTasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RestoringTasks.java
index 6ed28fd..3671b49 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RestoringTasks.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RestoringTasks.java
@@ -19,5 +19,5 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.common.TopicPartition;
 
 public interface RestoringTasks {
-    Task restoringTaskFor(final TopicPartition partition);
+    StreamTask restoringTaskFor(final TopicPartition partition);
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java
index 33dce9e..e0bac93 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestorer.java
@@ -26,13 +26,13 @@ public class StateRestorer {
 
     static final int NO_CHECKPOINT = -1;
 
-    private final Long checkpoint;
     private final long offsetLimit;
     private final boolean persistent;
     private final String storeName;
     private final TopicPartition partition;
     private final CompositeRestoreListener compositeRestoreListener;
 
+    private long checkpointOffset;
     private long restoredOffset;
     private long startingOffset;
     private long endingOffset;
@@ -45,7 +45,7 @@ public class StateRestorer {
                   final String storeName) {
         this.partition = partition;
         this.compositeRestoreListener = compositeRestoreListener;
-        this.checkpoint = checkpoint;
+        this.checkpointOffset = checkpoint == null ? NO_CHECKPOINT : checkpoint;
         this.offsetLimit = offsetLimit;
         this.persistent = persistent;
         this.storeName = storeName;
@@ -56,7 +56,15 @@ public class StateRestorer {
     }
 
     long checkpoint() {
-        return checkpoint == null ? NO_CHECKPOINT : checkpoint;
+        return checkpointOffset;
+    }
+
+    void setCheckpointOffset(final long checkpointOffset) {
+        this.checkpointOffset = checkpointOffset;
+    }
+
+    public String storeName() {
+        return storeName;
     }
 
     void restoreStarted() {
@@ -67,7 +75,8 @@ public class StateRestorer {
         compositeRestoreListener.onRestoreEnd(partition, storeName, restoredNumRecords());
     }
 
-    void restoreBatchCompleted(long currentRestoredOffset, int numRestored) {
+    void restoreBatchCompleted(final long currentRestoredOffset,
+                               final int numRestored) {
         compositeRestoreListener.onBatchRestored(partition, storeName, currentRestoredOffset, numRestored);
     }
 
@@ -79,7 +88,7 @@ public class StateRestorer {
         return persistent;
     }
 
-    void setUserRestoreListener(StateRestoreListener userRestoreListener) {
+    void setUserRestoreListener(final StateRestoreListener userRestoreListener) {
         this.compositeRestoreListener.setUserRestoreListener(userRestoreListener);
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
index 8d85b1d..34350c1 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
@@ -69,7 +69,7 @@ public class StoreChangelogReader implements ChangelogReader {
      */
     public Collection<TopicPartition> restore(final RestoringTasks active) {
         if (!needsInitializing.isEmpty()) {
-            initialize();
+            initialize(active);
         }
 
         if (needsRestoring.isEmpty()) {
@@ -90,7 +90,7 @@ public class StoreChangelogReader implements ChangelogReader {
         return completed();
     }
 
-    private void initialize() {
+    private void initialize(final RestoringTasks active) {
         if (!consumer.subscription().isEmpty()) {
             throw new IllegalStateException("Restore consumer should not be subscribed to any topics (" + consumer.subscription() + ")");
         }
@@ -99,8 +99,8 @@ public class StoreChangelogReader implements ChangelogReader {
         // the needsInitializing map is not empty, meaning we do not know the metadata for some of them yet
         refreshChangelogInfo();
 
-        Map<TopicPartition, StateRestorer> initializable = new HashMap<>();
-        for (Map.Entry<TopicPartition, StateRestorer> entry : needsInitializing.entrySet()) {
+        final Map<TopicPartition, StateRestorer> initializable = new HashMap<>();
+        for (final Map.Entry<TopicPartition, StateRestorer> entry : needsInitializing.entrySet()) {
             final TopicPartition topicPartition = entry.getKey();
             if (hasPartition(topicPartition)) {
                 initializable.put(entry.getKey(), entry.getValue());
@@ -144,11 +144,12 @@ public class StoreChangelogReader implements ChangelogReader {
 
         // set up restorer for those initializable
         if (!initializable.isEmpty()) {
-            startRestoration(initializable);
+            startRestoration(initializable, active);
         }
     }
 
-    private void startRestoration(final Map<TopicPartition, StateRestorer> initialized) {
+    private void startRestoration(final Map<TopicPartition, StateRestorer> initialized,
+                                  final RestoringTasks active) {
         log.debug("Start restoring state stores from changelog topics {}", initialized.keySet());
 
         final Set<TopicPartition> assignment = new HashSet<>(consumer.assignment());
@@ -157,26 +158,47 @@ public class StoreChangelogReader implements ChangelogReader {
 
         final List<StateRestorer> needsPositionUpdate = new ArrayList<>();
         for (final StateRestorer restorer : initialized.values()) {
+            final TopicPartition restoringPartition = restorer.partition();
             if (restorer.checkpoint() != StateRestorer.NO_CHECKPOINT) {
-                consumer.seek(restorer.partition(), restorer.checkpoint());
-                logRestoreOffsets(restorer.partition(),
-                        restorer.checkpoint(),
-                        endOffsets.get(restorer.partition()));
-                restorer.setStartingOffset(consumer.position(restorer.partition()));
+                consumer.seek(restoringPartition, restorer.checkpoint());
+                logRestoreOffsets(restoringPartition,
+                                  restorer.checkpoint(),
+                                  endOffsets.get(restoringPartition));
+                restorer.setStartingOffset(consumer.position(restoringPartition));
                 restorer.restoreStarted();
             } else {
-                consumer.seekToBeginning(Collections.singletonList(restorer.partition()));
+                consumer.seekToBeginning(Collections.singletonList(restoringPartition));
                 needsPositionUpdate.add(restorer);
             }
         }
 
         for (final StateRestorer restorer : needsPositionUpdate) {
-            final long position = consumer.position(restorer.partition());
-            logRestoreOffsets(restorer.partition(),
-                              position,
-                              endOffsets.get(restorer.partition()));
-            restorer.setStartingOffset(position);
-            restorer.restoreStarted();
+            final TopicPartition restoringPartition = restorer.partition();
+            final StreamTask task = active.restoringTaskFor(restoringPartition);
+
+            // If no checkpoint exists, the task was not shut down gracefully before;
+            // in this case, if EOS is turned on, we should wipe out the state and re-initialize the task
+            if (task.eosEnabled) {
+                log.info("No checkpoint found for task {} state store {} changelog {} with EOS turned on. " +
+                    "Reinitializing the task and restore its state from the beginning.", task.id, restorer.storeName(), restoringPartition);
+
+                // we remove the partition here, because it will be added back within
+                // `task.reinitializeStateStoresForPartitions()` that calls `register()` internally again
+                needsInitializing.remove(restoringPartition);
+                restorer.setCheckpointOffset(consumer.position(restoringPartition));
+
+                task.reinitializeStateStoresForPartitions(restoringPartition);
+                stateRestorers.get(restoringPartition).restoreStarted();
+            } else {
+                log.info("Restoring task {}'s state store {} from beginning of the changelog {} ", task.id, restorer.storeName(), restoringPartition);
+
+                final long position = consumer.position(restoringPartition);
+                logRestoreOffsets(restoringPartition,
+                                  position,
+                                  endOffsets.get(restoringPartition));
+                restorer.setStartingOffset(position);
+                restorer.restoreStarted();
+            }
         }
 
         needsRestoring.putAll(initialized);
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
index 6c7b2b4..d07781f 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
@@ -392,7 +392,7 @@ public class EosIntegrationTest {
         // the app commits after each 10 records per partition, and thus will have 2*5 uncommitted writes
         // and store updates (ie, another 5 uncommitted writes to a changelog topic per partition)
         //
-        // the failure gets inject after 20 committed and 30 uncommitted records got received
+        // the failure gets injected after 20 committed and 10 uncommitted records were received
         // -> the failure only kills one thread
         // after fail over, we should read 40 committed records and the state stores should contain the correct sums
         // per key (even if some records got processed twice)
@@ -402,7 +402,7 @@ public class EosIntegrationTest {
             streams.start();
 
             final List<KeyValue<Long, Long>> committedDataBeforeFailure = prepareData(0L, 10L, 0L, 1L);
-            final List<KeyValue<Long, Long>> uncommittedDataBeforeFailure = prepareData(10L, 15L, 0L, 1L);
+            final List<KeyValue<Long, Long>> uncommittedDataBeforeFailure = prepareData(10L, 15L, 0L, 1L, 2L, 3L);
 
             final List<KeyValue<Long, Long>> dataBeforeFailure = new ArrayList<>();
             dataBeforeFailure.addAll(committedDataBeforeFailure);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java
index 5c8b7c4..01439ad 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AssignedTasksTest.java
@@ -40,8 +40,8 @@ import static org.junit.Assert.fail;
 
 public class AssignedTasksTest {
 
-    private final Task t1 = EasyMock.createMock(Task.class);
-    private final Task t2 = EasyMock.createMock(Task.class);
+    private final StreamTask t1 = EasyMock.createMock(StreamTask.class);
+    private final StreamTask t2 = EasyMock.createMock(StreamTask.class);
     private final TopicPartition tp1 = new TopicPartition("t1", 0);
     private final TopicPartition tp2 = new TopicPartition("t2", 0);
     private final TopicPartition changeLog1 = new TopicPartition("cl1", 0);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
index bb0c51e..a0e2140 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
@@ -59,7 +59,7 @@ public class StoreChangelogReaderTest {
     @Mock(type = MockType.NICE)
     private RestoringTasks active;
     @Mock(type = MockType.NICE)
-    private Task task;
+    private StreamTask task;
 
     private final MockStateRestoreListener callback = new MockStateRestoreListener();
     private final CompositeRestoreListener restoreListener = new CompositeRestoreListener(callback);
@@ -107,6 +107,10 @@ public class StoreChangelogReaderTest {
         final int messages = 10;
         setupConsumer(messages, topicPartition);
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName"));
+
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
+
         changelogReader.restore(active);
         assertThat(callback.restored.size(), equalTo(messages));
     }
@@ -115,8 +119,7 @@ public class StoreChangelogReaderTest {
     public void shouldRestoreMessagesFromCheckpoint() {
         final int messages = 10;
         setupConsumer(messages, topicPartition);
-        changelogReader.register(new StateRestorer(topicPartition, restoreListener, 5L, Long.MAX_VALUE, true,
-                                                   "storeName"));
+        changelogReader.register(new StateRestorer(topicPartition, restoreListener, 5L, Long.MAX_VALUE, true, "storeName"));
 
         changelogReader.restore(active);
         assertThat(callback.restored.size(), equalTo(5));
@@ -126,9 +129,9 @@ public class StoreChangelogReaderTest {
     public void shouldClearAssignmentAtEndOfRestore() {
         final int messages = 1;
         setupConsumer(messages, topicPartition);
-        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true,
-                                                   "storeName"));
-
+        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName"));
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
         assertThat(consumer.assignment(), equalTo(Collections.<TopicPartition>emptySet()));
     }
@@ -136,9 +139,10 @@ public class StoreChangelogReaderTest {
     @Test
     public void shouldRestoreToLimitWhenSupplied() {
         setupConsumer(10, topicPartition);
-        final StateRestorer restorer = new StateRestorer(topicPartition, restoreListener, null, 3, true,
-                                                         "storeName");
+        final StateRestorer restorer = new StateRestorer(topicPartition, restoreListener, null, 3, true, "storeName");
         changelogReader.register(restorer);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
         assertThat(callback.restored.size(), equalTo(3));
         assertThat(restorer.restoredOffset(), equalTo(3L));
@@ -156,14 +160,14 @@ public class StoreChangelogReaderTest {
         setupConsumer(5, one);
         setupConsumer(3, two);
 
-        changelogReader
-            .register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName1"));
+        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName1"));
         changelogReader.register(new StateRestorer(one, restoreListener1, null, Long.MAX_VALUE, true, "storeName2"));
         changelogReader.register(new StateRestorer(two, restoreListener2, null, Long.MAX_VALUE, true, "storeName3"));
 
-        expect(active.restoringTaskFor(one)).andReturn(null);
-        expect(active.restoringTaskFor(two)).andReturn(null);
-        replay(active);
+        expect(active.restoringTaskFor(one)).andStubReturn(task);
+        expect(active.restoringTaskFor(two)).andStubReturn(task);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
 
         assertThat(callback.restored.size(), equalTo(10));
@@ -188,9 +192,10 @@ public class StoreChangelogReaderTest {
         changelogReader.register(new StateRestorer(one, restoreListener1, null, Long.MAX_VALUE, true, "storeName2"));
         changelogReader.register(new StateRestorer(two, restoreListener2, null, Long.MAX_VALUE, true, "storeName3"));
 
-        expect(active.restoringTaskFor(one)).andReturn(null);
-        expect(active.restoringTaskFor(two)).andReturn(null);
-        replay(active);
+        expect(active.restoringTaskFor(one)).andStubReturn(task);
+        expect(active.restoringTaskFor(two)).andStubReturn(task);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
 
         assertThat(callback.restored.size(), equalTo(10));
@@ -210,8 +215,10 @@ public class StoreChangelogReaderTest {
     @Test
     public void shouldOnlyReportTheLastRestoredOffset() {
         setupConsumer(10, topicPartition);
-        changelogReader
-            .register(new StateRestorer(topicPartition, restoreListener, null, 5, true, "storeName1"));
+        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, 5, true, "storeName1"));
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
+
         changelogReader.restore(active);
 
         assertThat(callback.restored.size(), equalTo(5));
@@ -270,6 +277,8 @@ public class StoreChangelogReaderTest {
     public void shouldReturnRestoredOffsetsForPersistentStores() {
         setupConsumer(10, topicPartition);
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName"));
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
         final Map<TopicPartition, Long> restoredOffsets = changelogReader.restoredOffsets();
         assertThat(restoredOffsets, equalTo(Collections.singletonMap(topicPartition, 10L)));
@@ -279,6 +288,8 @@ public class StoreChangelogReaderTest {
     public void shouldNotReturnRestoredOffsetsForNonPersistentStore() {
         setupConsumer(10, topicPartition);
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, false, "storeName"));
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
         final Map<TopicPartition, Long> restoredOffsets = changelogReader.restoredOffsets();
         assertThat(restoredOffsets, equalTo(Collections.<TopicPartition, Long>emptyMap()));
@@ -292,8 +303,9 @@ public class StoreChangelogReaderTest {
         consumer.addRecord(new ConsumerRecord<>(topicPartition.topic(), topicPartition.partition(), 1, (byte[]) null, bytes));
         consumer.addRecord(new ConsumerRecord<>(topicPartition.topic(), topicPartition.partition(), 2, bytes, bytes));
         consumer.assign(Collections.singletonList(topicPartition));
-        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, false,
-                                                   "storeName"));
+        changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, false, "storeName"));
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         changelogReader.restore(active);
 
         assertThat(callback.restored, CoreMatchers.equalTo(Utils.mkList(KeyValue.pair(bytes, bytes), KeyValue.pair(bytes, bytes))));
@@ -318,10 +330,11 @@ public class StoreChangelogReaderTest {
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, false, "storeName"));
 
         final TopicPartition postInitialization = new TopicPartition("other", 0);
-        expect(active.restoringTaskFor(topicPartition)).andReturn(null);
-        expect(active.restoringTaskFor(topicPartition)).andReturn(null);
-        expect(active.restoringTaskFor(postInitialization)).andReturn(null);
-        replay(active);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        expect(active.restoringTaskFor(postInitialization)).andStubReturn(task);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
 
         assertTrue(changelogReader.restore(active).isEmpty());
 
@@ -348,7 +361,7 @@ public class StoreChangelogReaderTest {
         consumer.updateEndOffsets(Collections.singletonMap(topicPartition, 5L));
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName"));
 
-        expect(active.restoringTaskFor(topicPartition)).andReturn(task);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
         replay(active);
 
         try {
@@ -371,8 +384,8 @@ public class StoreChangelogReaderTest {
         consumer.updateEndOffsets(Collections.singletonMap(topicPartition, 6L));
         changelogReader.register(new StateRestorer(topicPartition, restoreListener, null, Long.MAX_VALUE, true, "storeName"));
 
-        expect(active.restoringTaskFor(topicPartition)).andReturn(task);
-        replay(active);
+        expect(active.restoringTaskFor(topicPartition)).andStubReturn(task);
+        replay(active, task);
         try {
             changelogReader.restore(active);
             fail("Should have thrown task migrated exception");