You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2019/04/22 22:01:45 UTC

[kafka] branch 2.2 updated: KAFKA-8204: fix Streams store flush order (#6555)

This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch 2.2
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.2 by this push:
     new 3dbe849  KAFKA-8204: fix Streams store flush order (#6555)
3dbe849 is described below

commit 3dbe849f478fbbcf3208bb0da153f99249ad0475
Author: John Roesler <vv...@users.noreply.github.com>
AuthorDate: Mon Apr 22 16:53:51 2019 -0500

    KAFKA-8204: fix Streams store flush order (#6555)
    
    Streams previously flushed stores in the order of their registration, which is arbitrary. Because stores may forward values upon flush (as in cached state stores), we must flush stores in topological order.
    
    Reviewers: Bill Bejeck <bb...@gmail.com>, Guozhang Wang <wa...@gmail.com>
---
 .../apache/kafka/common/utils/FixedOrderMap.java   |  63 +++++++
 .../kafka/common/utils/FixedOrderMapTest.java      |  86 +++++++++
 .../processor/internals/AbstractStateManager.java  |  79 +++++----
 .../internals/GlobalStateManagerImpl.java          |  53 +++---
 .../processor/internals/ProcessorStateManager.java | 162 ++++++++++-------
 .../processor/internals/ProcessorTopology.java     |  62 +------
 .../streams/processor/internals/StreamTask.java    |   2 +
 .../processor/internals/AbstractTaskTest.java      |  11 +-
 .../internals/GlobalStateManagerImplTest.java      |  14 +-
 .../processor/internals/GlobalStateTaskTest.java   |   2 +-
 .../internals/ProcessorStateManagerTest.java       | 193 ++++++++++++---------
 .../internals/ProcessorTopologyFactories.java      |  53 ++++++
 .../processor/internals/StandbyTaskTest.java       |  14 +-
 .../processor/internals/StreamTaskTest.java        |  39 ++++-
 .../org/apache/kafka/test/MockKeyValueStore.java   |  10 ++
 15 files changed, 564 insertions(+), 279 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/utils/FixedOrderMap.java b/clients/src/main/java/org/apache/kafka/common/utils/FixedOrderMap.java
new file mode 100644
index 0000000..6518878
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/utils/FixedOrderMap.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+/**
+ * A {@link LinkedHashMap} whose entry order is meant to stay fixed once entries are inserted.
+ * To accomplish this, {@code remove(Object)}, {@code remove(Object, Object)} and {@code clear()}
+ * are marked deprecated and throw {@link UnsupportedOperationException}; {@code clone()} throws too.
+ * Putting a value for an existing key replaces the value and keeps the entry's position.
+ * This class is final to prevent subclasses from violating the desired property.
+ *
+ * @param <K> The key type
+ * @param <V> The value type
+ */
+public final class FixedOrderMap<K, V> extends LinkedHashMap<K, V> { // NOTE(review): removal via keySet()/entrySet()/values() or their iterators is not overridden here - confirm whether it must be blocked too
+    private static final long serialVersionUID = -6504110858733236170L;
+
+    @Deprecated
+    @Override
+    protected boolean removeEldestEntry(final Map.Entry<K, V> eldest) {
+        return false; // never evict on insert: an eviction would remove a mapping
+    }
+
+    @Deprecated
+    @Override
+    public V remove(final Object key) {
+        throw new UnsupportedOperationException("Removing from registeredStores is not allowed"); // NOTE(review): message names a caller's field ("registeredStores") although this is a general-purpose utility
+    }
+
+    @Deprecated
+    @Override
+    public boolean remove(final Object key, final Object value) {
+        throw new UnsupportedOperationException("Removing from registeredStores is not allowed"); // conditional removal is rejected as well, whether or not the value matches
+    }
+
+    @Deprecated
+    @Override
+    public void clear() {
+        throw new UnsupportedOperationException("Removing from registeredStores is not allowed"); // clearing would drop every mapping, so it is rejected like remove()
+    }
+
+    @Override
+    public FixedOrderMap<K, V> clone() { // NOTE(review): unlike the removal methods, not marked @Deprecated and thrown without a message
+        throw new UnsupportedOperationException();
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/FixedOrderMapTest.java b/clients/src/test/java/org/apache/kafka/common/utils/FixedOrderMapTest.java
new file mode 100644
index 0000000..7d3f3f7
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/FixedOrderMapTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import org.hamcrest.CoreMatchers;
+import org.junit.Test;
+
+import java.util.Iterator;
+import java.util.Map;
+
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.fail;
+
+public class FixedOrderMapTest { // checks that FixedOrderMap keeps insertion order and rejects remove/clear
+    @Test
+    public void shouldMaintainOrderWhenAdding() {
+        final FixedOrderMap<String, Integer> map = new FixedOrderMap<>();
+        map.put("a", 0);
+        map.put("b", 1);
+        map.put("c", 2);
+        map.put("b", 3); // overwrite an existing key: the value changes, the position must not
+        final Iterator<Map.Entry<String, Integer>> iterator = map.entrySet().iterator();
+        assertThat(iterator.next(), is(mkEntry("a", 0)));
+        assertThat(iterator.next(), is(mkEntry("b", 3))); // "b" is still second and carries the new value
+        assertThat(iterator.next(), is(mkEntry("c", 2)));
+        assertThat(iterator.hasNext(), is(false));
+    }
+
+    @SuppressWarnings("deprecation") // remove(Object) is marked @Deprecated in FixedOrderMap
+    @Test
+    public void shouldForbidRemove() {
+        final FixedOrderMap<String, Integer> map = new FixedOrderMap<>();
+        map.put("a", 0);
+        try {
+            map.remove("a");
+            fail("expected exception"); // only reached if remove(key) did not throw
+        } catch (final RuntimeException e) {
+            assertThat(e, CoreMatchers.instanceOf(UnsupportedOperationException.class));
+        }
+        assertThat(map.get("a"), is(0)); // the mapping must survive the rejected call
+    }
+
+    @SuppressWarnings("deprecation") // remove(Object, Object) is marked @Deprecated in FixedOrderMap
+    @Test
+    public void shouldForbidConditionalRemove() {
+        final FixedOrderMap<String, Integer> map = new FixedOrderMap<>();
+        map.put("a", 0);
+        try {
+            map.remove("a", 0); // key and value both match, so a plain map would remove the entry
+            fail("expected exception");
+        } catch (final RuntimeException e) {
+            assertThat(e, CoreMatchers.instanceOf(UnsupportedOperationException.class));
+        }
+        assertThat(map.get("a"), is(0)); // the mapping must survive the rejected call
+    }
+
+    @SuppressWarnings("deprecation") // clear() is marked @Deprecated in FixedOrderMap
+    @Test
+    public void shouldForbidConditionalClear() { // NOTE(review): exercises clear(); "Conditional" looks copied from the test above
+        final FixedOrderMap<String, Integer> map = new FixedOrderMap<>();
+        map.put("a", 0);
+        try {
+            map.clear();
+            fail("expected exception"); // only reached if clear() did not throw
+        } catch (final RuntimeException e) {
+            assertThat(e, CoreMatchers.instanceOf(UnsupportedOperationException.class));
+        }
+        assertThat(map.get("a"), is(0)); // the mapping must survive the rejected call
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractStateManager.java
index c306468..2190c43 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractStateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractStateManager.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.FixedOrderMap;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.StateStore;
@@ -29,8 +30,8 @@ import java.io.IOException;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.LinkedHashMap;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 
 import static org.apache.kafka.streams.state.internals.RecordConverters.identity;
@@ -45,8 +46,7 @@ abstract class AbstractStateManager implements StateManager {
     OffsetCheckpoint checkpoint;
 
     final Map<TopicPartition, Long> checkpointableOffsets = new HashMap<>();
-    final Map<String, StateStore> stores = new LinkedHashMap<>();
-    final Map<String, StateStore> globalStores = new LinkedHashMap<>();
+    final FixedOrderMap<String, Optional<StateStore>> globalStores = new FixedOrderMap<>();
 
     AbstractStateManager(final File baseDir,
                          final boolean eosEnabled) {
@@ -60,17 +60,16 @@ abstract class AbstractStateManager implements StateManager {
     }
 
     public void reinitializeStateStoresForPartitions(final Logger log,
-                                                     final Map<String, StateStore> stateStores,
+                                                     final FixedOrderMap<String, Optional<StateStore>> stateStores,
                                                      final Map<String, String> storeToChangelogTopic,
                                                      final Collection<TopicPartition> partitions,
                                                      final InternalProcessorContext processorContext) {
         final Map<String, String> changelogTopicToStore = inverseOneToOneMap(storeToChangelogTopic);
-        final Set<String> storeToBeReinitialized = new HashSet<>();
-        final Map<String, StateStore> storesCopy = new HashMap<>(stateStores);
+        final Set<String> storesToBeReinitialized = new HashSet<>();
 
         for (final TopicPartition topicPartition : partitions) {
             checkpointableOffsets.remove(topicPartition);
-            storeToBeReinitialized.add(changelogTopicToStore.get(topicPartition.topic()));
+            storesToBeReinitialized.add(changelogTopicToStore.get(topicPartition.topic()));
         }
 
         if (!eosEnabled) {
@@ -82,36 +81,44 @@ abstract class AbstractStateManager implements StateManager {
             }
         }
 
-        for (final Map.Entry<String, StateStore> entry : storesCopy.entrySet()) {
-            final StateStore stateStore = entry.getValue();
-            final String storeName = stateStore.name();
-            if (storeToBeReinitialized.contains(storeName)) {
-                try {
-                    stateStore.close();
-                } catch (final RuntimeException ignoreAndSwallow) { /* ignore */ }
-                processorContext.uninitialize();
-                stateStores.remove(entry.getKey());
-
-                // TODO remove this eventually
-                // -> (only after we are sure, we don't need it for backward compatibility reasons anymore; maybe 2.0 release?)
-                // this is an ugly "hack" that is required because RocksDBStore does not follow the pattern to put the
-                // store directory as <taskDir>/<storeName> but nests it with an intermediate <taskDir>/rocksdb/<storeName>
-                try {
-                    Utils.delete(new File(baseDir + File.separator + "rocksdb" + File.separator + storeName));
-                } catch (final IOException fatalException) {
-                    log.error("Failed to reinitialize store {}.", storeName, fatalException);
-                    throw new StreamsException(String.format("Failed to reinitialize store %s.", storeName), fatalException);
-                }
-
-                try {
-                    Utils.delete(new File(baseDir + File.separator + storeName));
-                } catch (final IOException fatalException) {
-                    log.error("Failed to reinitialize store {}.", storeName, fatalException);
-                    throw new StreamsException(String.format("Failed to reinitialize store %s.", storeName), fatalException);
-                }
-
-                stateStore.init(processorContext, stateStore);
+        for (final String storeName : storesToBeReinitialized) {
+            if (!stateStores.containsKey(storeName)) {
+                // the store has never been registered; carry on...
+                continue;
             }
+            final StateStore stateStore = stateStores
+                .get(storeName)
+                .orElseThrow(
+                    () -> new IllegalStateException(
+                        "Re-initializing store that has not been initialized. This is a bug in Kafka Streams."
+                    )
+                );
+
+            try {
+                stateStore.close();
+            } catch (final RuntimeException ignoreAndSwallow) { /* ignore */ }
+            processorContext.uninitialize();
+            stateStores.put(storeName, Optional.empty());
+
+            // TODO remove this eventually
+            // -> (only after we are sure, we don't need it for backward compatibility reasons anymore; maybe 2.0 release?)
+            // this is an ugly "hack" that is required because RocksDBStore does not follow the pattern to put the
+            // store directory as <taskDir>/<storeName> but nests it with an intermediate <taskDir>/rocksdb/<storeName>
+            try {
+                Utils.delete(new File(baseDir + File.separator + "rocksdb" + File.separator + storeName));
+            } catch (final IOException fatalException) {
+                log.error("Failed to reinitialize store {}.", storeName, fatalException);
+                throw new StreamsException(String.format("Failed to reinitialize store %s.", storeName), fatalException);
+            }
+
+            try {
+                Utils.delete(new File(baseDir + File.separator + storeName));
+            } catch (final IOException fatalException) {
+                log.error("Failed to reinitialize store {}.", storeName, fatalException);
+                throw new StreamsException(String.format("Failed to reinitialize store %s.", storeName), fatalException);
+            }
+
+            stateStore.init(processorContext, stateStore);
         }
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
index 7682814..f224e1e 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
@@ -45,6 +45,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 
 /**
@@ -140,7 +141,7 @@ public class GlobalStateManagerImpl extends AbstractStateManager implements Glob
 
     @Override
     public StateStore getGlobalStore(final String name) {
-        return globalStores.get(name);
+        return globalStores.getOrDefault(name, Optional.empty()).orElse(null);
     }
 
     @Override
@@ -203,7 +204,7 @@ public class GlobalStateManagerImpl extends AbstractStateManager implements Glob
                 store.name(),
                 converterForStore(store)
             );
-            globalStores.put(store.name(), store);
+            globalStores.put(store.name(), Optional.of(store));
         } finally {
             globalConsumer.unsubscribe();
         }
@@ -306,12 +307,20 @@ public class GlobalStateManagerImpl extends AbstractStateManager implements Glob
     @Override
     public void flush() {
         log.debug("Flushing all global globalStores registered in the state manager");
-        for (final StateStore store : this.globalStores.values()) {
-            try {
-                log.trace("Flushing global store={}", store.name());
-                store.flush();
-            } catch (final Exception e) {
-                throw new ProcessorStateException(String.format("Failed to flush global state store %s", store.name()), e);
+        for (final Map.Entry<String, Optional<StateStore>> entry : globalStores.entrySet()) {
+            if (entry.getValue().isPresent()) {
+                final StateStore store = entry.getValue().get();
+                try {
+                    log.trace("Flushing global store={}", store.name());
+                    store.flush();
+                } catch (final Exception e) {
+                    throw new ProcessorStateException(
+                        String.format("Failed to flush global state store %s", store.name()),
+                        e
+                    );
+                }
+            } else {
+                throw new IllegalStateException("Expected " + entry.getKey() + " to have been initialized");
             }
         }
     }
@@ -324,20 +333,24 @@ public class GlobalStateManagerImpl extends AbstractStateManager implements Glob
                 return;
             }
             final StringBuilder closeFailed = new StringBuilder();
-            for (final Map.Entry<String, StateStore> entry : globalStores.entrySet()) {
-                log.debug("Closing global storage engine {}", entry.getKey());
-                try {
-                    entry.getValue().close();
-                } catch (final Exception e) {
-                    log.error("Failed to close global state store {}", entry.getKey(), e);
-                    closeFailed.append("Failed to close global state store:")
-                            .append(entry.getKey())
-                            .append(". Reason: ")
-                            .append(e.toString())
-                            .append("\n");
+            for (final Map.Entry<String, Optional<StateStore>> entry : globalStores.entrySet()) {
+                if (entry.getValue().isPresent()) {
+                    log.debug("Closing global storage engine {}", entry.getKey());
+                    try {
+                        entry.getValue().get().close();
+                    } catch (final Exception e) {
+                        log.error("Failed to close global state store {}", entry.getKey(), e);
+                        closeFailed.append("Failed to close global state store:")
+                                   .append(entry.getKey())
+                                   .append(". Reason: ")
+                                   .append(e.toString())
+                                   .append("\n");
+                    }
+                    globalStores.put(entry.getKey(), Optional.empty());
+                } else {
+                    log.info("Skipping to close non-initialized store {}", entry.getKey());
                 }
             }
-            globalStores.clear();
             if (closeFailed.length() > 0) {
                 throw new ProcessorStateException("Exceptions caught during close of 1 or more global state globalStores\n" + closeFailed);
             }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
index 5c54ee7..f060d29 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java
@@ -18,6 +18,7 @@ package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.FixedOrderMap;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
@@ -34,7 +35,9 @@ import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 
+import static java.util.Collections.unmodifiableList;
 import static org.apache.kafka.streams.processor.internals.StateRestoreCallbackAdapter.adapt;
 
 
@@ -51,6 +54,10 @@ public class ProcessorStateManager extends AbstractStateManager {
     private final Map<String, StateRestoreCallback> restoreCallbacks; // used for standby tasks, keyed by state topic name
     private final Map<String, RecordConverter> recordConverters; // used for standby tasks, keyed by state topic name
     private final Map<String, String> storeToChangelogTopic;
+
+    // must be maintained in topological order
+    private final FixedOrderMap<String, Optional<StateStore>> registeredStores = new FixedOrderMap<>();
+
     private final List<TopicPartition> changelogPartitions = new ArrayList<>();
 
     // TODO: this map does not work with customized grouper where multiple partitions
@@ -85,7 +92,7 @@ public class ProcessorStateManager extends AbstractStateManager {
         this.isStandby = isStandby;
         restoreCallbacks = isStandby ? new HashMap<>() : null;
         recordConverters = isStandby ? new HashMap<>() : null;
-        this.storeToChangelogTopic = storeToChangelogTopic;
+        this.storeToChangelogTopic = new HashMap<>(storeToChangelogTopic);
 
         // load the checkpoint information
         checkpointableOffsets.putAll(checkpoint.read());
@@ -122,52 +129,52 @@ public class ProcessorStateManager extends AbstractStateManager {
             throw new IllegalArgumentException(String.format("%sIllegal store name: %s", logPrefix, CHECKPOINT_FILE_NAME));
         }
 
-        if (stores.containsKey(storeName)) {
+        if (registeredStores.containsKey(storeName) && registeredStores.get(storeName).isPresent()) {
             throw new IllegalArgumentException(String.format("%sStore %s has already been registered.", logPrefix, storeName));
         }
 
         // check that the underlying change log topic exist or not
         final String topic = storeToChangelogTopic.get(storeName);
         if (topic == null) {
-            stores.put(storeName, store);
-            return;
-        }
+            registeredStores.put(storeName, Optional.of(store));
+        } else {
 
-        final TopicPartition storePartition = new TopicPartition(topic, getPartition(topic));
+            final TopicPartition storePartition = new TopicPartition(topic, getPartition(topic));
 
-        final RecordConverter recordConverter = converterForStore(store);
+            final RecordConverter recordConverter = converterForStore(store);
 
-        if (isStandby) {
-            log.trace("Preparing standby replica of persistent state store {} with changelog topic {}", storeName, topic);
+            if (isStandby) {
+                log.trace("Preparing standby replica of persistent state store {} with changelog topic {}", storeName, topic);
 
-            restoreCallbacks.put(topic, stateRestoreCallback);
-            recordConverters.put(topic, recordConverter);
-        } else {
-            log.trace("Restoring state store {} from changelog topic {} at checkpoint {}", storeName, topic, checkpointableOffsets.get(storePartition));
-
-            final StateRestorer restorer = new StateRestorer(
-                storePartition,
-                new CompositeRestoreListener(stateRestoreCallback),
-                checkpointableOffsets.get(storePartition),
-                offsetLimit(storePartition),
-                store.persistent(),
-                storeName,
-                recordConverter
-            );
-
-            changelogReader.register(restorer);
-        }
-        changelogPartitions.add(storePartition);
+                restoreCallbacks.put(topic, stateRestoreCallback);
+                recordConverters.put(topic, recordConverter);
+            } else {
+                log.trace("Restoring state store {} from changelog topic {} at checkpoint {}", storeName, topic, checkpointableOffsets.get(storePartition));
 
-        stores.put(storeName, store);
+                final StateRestorer restorer = new StateRestorer(
+                    storePartition,
+                    new CompositeRestoreListener(stateRestoreCallback),
+                    checkpointableOffsets.get(storePartition),
+                    offsetLimit(storePartition),
+                    store.persistent(),
+                    storeName,
+                    recordConverter
+                );
+
+                changelogReader.register(restorer);
+            }
+            changelogPartitions.add(storePartition);
+
+            registeredStores.put(storeName, Optional.of(store));
+        }
     }
 
     @Override
     public void reinitializeStateStoresForPartitions(final Collection<TopicPartition> partitions,
                                                      final InternalProcessorContext processorContext) {
-        super.reinitializeStateStoresForPartitions(
+        reinitializeStateStoresForPartitions(
             log,
-            stores,
+            registeredStores,
             storeToChangelogTopic,
             partitions,
             processorContext);
@@ -224,24 +231,29 @@ public class ProcessorStateManager extends AbstractStateManager {
 
     @Override
     public StateStore getStore(final String name) {
-        return stores.get(name);
+        return registeredStores.getOrDefault(name, Optional.empty()).orElse(null);
     }
 
     @Override
     public void flush() {
         ProcessorStateException firstException = null;
         // attempting to flush the stores
-        if (!stores.isEmpty()) {
+        if (!registeredStores.isEmpty()) {
             log.debug("Flushing all stores registered in the state manager");
-            for (final StateStore store : stores.values()) {
-                log.trace("Flushing store {}", store.name());
-                try {
-                    store.flush();
-                } catch (final Exception e) {
-                    if (firstException == null) {
-                        firstException = new ProcessorStateException(String.format("%sFailed to flush state store %s", logPrefix, store.name()), e);
+            for (final Map.Entry<String, Optional<StateStore>> entry : registeredStores.entrySet()) {
+                if (entry.getValue().isPresent()) {
+                    final StateStore store = entry.getValue().get();
+                    log.trace("Flushing store {}", store.name());
+                    try {
+                        store.flush();
+                    } catch (final Exception e) {
+                        if (firstException == null) {
+                            firstException = new ProcessorStateException(String.format("%sFailed to flush state store %s", logPrefix, store.name()), e);
+                        }
+                        log.error("Failed to flush state store {}: ", store.name(), e);
                     }
-                    log.error("Failed to flush state store {}: ", store.name(), e);
+                } else {
+                    throw new IllegalStateException("Expected " + entry.getKey() + " to have been initialized");
                 }
             }
         }
@@ -261,20 +273,25 @@ public class ProcessorStateManager extends AbstractStateManager {
         ProcessorStateException firstException = null;
         // attempting to close the stores, just in case they
         // are not closed by a ProcessorNode yet
-        if (!stores.isEmpty()) {
+        if (!registeredStores.isEmpty()) {
             log.debug("Closing its state manager and all the registered state stores");
-            for (final StateStore store : stores.values()) {
-                log.debug("Closing storage engine {}", store.name());
-                try {
-                    store.close();
-                } catch (final Exception e) {
-                    if (firstException == null) {
-                        firstException = new ProcessorStateException(String.format("%sFailed to close state store %s", logPrefix, store.name()), e);
+            for (final Map.Entry<String, Optional<StateStore>> entry : registeredStores.entrySet()) {
+                if (entry.getValue().isPresent()) {
+                    final StateStore store = entry.getValue().get();
+                    log.debug("Closing storage engine {}", store.name());
+                    try {
+                        store.close();
+                        registeredStores.put(store.name(), Optional.empty());
+                    } catch (final Exception e) {
+                        if (firstException == null) {
+                            firstException = new ProcessorStateException(String.format("%sFailed to close state store %s", logPrefix, store.name()), e);
+                        }
+                        log.error("Failed to close state store {}: ", store.name(), e);
                     }
-                    log.error("Failed to close state store {}: ", store.name(), e);
+                } else {
+                    log.info("Skipping to close non-initialized store {}", entry.getKey());
                 }
             }
-            stores.clear();
         }
 
         if (!clean && eosEnabled && checkpoint != null) {
@@ -297,19 +314,24 @@ public class ProcessorStateManager extends AbstractStateManager {
     public void checkpoint(final Map<TopicPartition, Long> checkpointableOffsets) {
         this.checkpointableOffsets.putAll(changelogReader.restoredOffsets());
         log.trace("Checkpointable offsets updated with restored offsets: {}", this.checkpointableOffsets);
-        for (final StateStore store : stores.values()) {
-            final String storeName = store.name();
-            // only checkpoint the offset to the offsets file if
-            // it is persistent AND changelog enabled
-            if (store.persistent() && storeToChangelogTopic.containsKey(storeName)) {
-                final String changelogTopic = storeToChangelogTopic.get(storeName);
-                final TopicPartition topicPartition = new TopicPartition(changelogTopic, getPartition(storeName));
-                if (checkpointableOffsets.containsKey(topicPartition)) {
-                    // store the last offset + 1 (the log position after restoration)
-                    this.checkpointableOffsets.put(topicPartition, checkpointableOffsets.get(topicPartition) + 1);
-                } else if (standbyRestoredOffsets.containsKey(topicPartition)) {
-                    this.checkpointableOffsets.put(topicPartition, standbyRestoredOffsets.get(topicPartition));
+        for (final Map.Entry<String, Optional<StateStore>> entry : registeredStores.entrySet()) {
+            if (entry.getValue().isPresent()) {
+                final StateStore store = entry.getValue().get();
+                final String storeName = store.name();
+                // only checkpoint the offset to the offsets file if
+                // it is persistent AND changelog enabled
+                if (store.persistent() && storeToChangelogTopic.containsKey(storeName)) {
+                    final String changelogTopic = storeToChangelogTopic.get(storeName);
+                    final TopicPartition topicPartition = new TopicPartition(changelogTopic, getPartition(storeName));
+                    if (checkpointableOffsets.containsKey(topicPartition)) {
+                        // store the last offset + 1 (the log position after restoration)
+                        this.checkpointableOffsets.put(topicPartition, checkpointableOffsets.get(topicPartition) + 1);
+                    } else if (standbyRestoredOffsets.containsKey(topicPartition)) {
+                        this.checkpointableOffsets.put(topicPartition, standbyRestoredOffsets.get(topicPartition));
+                    }
                 }
+            } else {
+                throw new IllegalStateException("Expected " + entry.getKey() + " to have been initialized");
             }
         }
 
@@ -336,16 +358,26 @@ public class ProcessorStateManager extends AbstractStateManager {
     void registerGlobalStateStores(final List<StateStore> stateStores) {
         log.debug("Register global stores {}", stateStores);
         for (final StateStore stateStore : stateStores) {
-            globalStores.put(stateStore.name(), stateStore);
+            globalStores.put(stateStore.name(), Optional.of(stateStore));
         }
     }
 
     @Override
     public StateStore getGlobalStore(final String name) {
-        return globalStores.get(name);
+        return globalStores.getOrDefault(name, Optional.empty()).orElse(null);
     }
 
     Collection<TopicPartition> changelogPartitions() {
-        return changelogPartitions;
+        return unmodifiableList(changelogPartitions);
+    }
+
+    void ensureStoresRegistered() {
+        for (final Map.Entry<String, Optional<StateStore>> entry : registeredStores.entrySet()) {
+            if (!entry.getValue().isPresent()) {
+                throw new IllegalStateException(
+                    "store [" + entry.getKey() + "] has not been correctly registered. This is a bug in Kafka Streams."
+                );
+            }
+        }
     }
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorTopology.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorTopology.java
index 57af1f3..6776aa3 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorTopology.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorTopology.java
@@ -34,77 +34,19 @@ public class ProcessorTopology {
     private final Map<String, String> storeToChangelogTopic;
     private final Set<String> repartitionTopics;
 
-    public static ProcessorTopology with(final List<ProcessorNode> processorNodes,
-                                         final Map<String, SourceNode> sourcesByTopic,
-                                         final List<StateStore> stateStoresByName,
-                                         final Map<String, String> storeToChangelogTopic) {
-        return new ProcessorTopology(processorNodes,
-                sourcesByTopic,
-                Collections.<String, SinkNode>emptyMap(),
-                stateStoresByName,
-                Collections.<StateStore>emptyList(),
-                storeToChangelogTopic,
-                Collections.<String>emptySet());
-    }
-
-    static ProcessorTopology withSources(final List<ProcessorNode> processorNodes,
-                                         final Map<String, SourceNode> sourcesByTopic) {
-        return new ProcessorTopology(processorNodes,
-                sourcesByTopic,
-                Collections.<String, SinkNode>emptyMap(),
-                Collections.<StateStore>emptyList(),
-                Collections.<StateStore>emptyList(),
-                Collections.<String, String>emptyMap(),
-                Collections.<String>emptySet());
-    }
-
-    static ProcessorTopology withLocalStores(final List<StateStore> stateStores,
-                                             final Map<String, String> storeToChangelogTopic) {
-        return new ProcessorTopology(Collections.<ProcessorNode>emptyList(),
-                Collections.<String, SourceNode>emptyMap(),
-                Collections.<String, SinkNode>emptyMap(),
-                stateStores,
-                Collections.<StateStore>emptyList(),
-                storeToChangelogTopic,
-                Collections.<String>emptySet());
-    }
-
-    static ProcessorTopology withGlobalStores(final List<StateStore> stateStores,
-                                              final Map<String, String> storeToChangelogTopic) {
-        return new ProcessorTopology(Collections.<ProcessorNode>emptyList(),
-                Collections.<String, SourceNode>emptyMap(),
-                Collections.<String, SinkNode>emptyMap(),
-                Collections.<StateStore>emptyList(),
-                stateStores,
-                storeToChangelogTopic,
-                Collections.<String>emptySet());
-    }
-
-    static ProcessorTopology withRepartitionTopics(final List<ProcessorNode> processorNodes,
-                                                   final Map<String, SourceNode> sourcesByTopic,
-                                                   final Set<String> repartitionTopics) {
-        return new ProcessorTopology(processorNodes,
-                sourcesByTopic,
-                Collections.<String, SinkNode>emptyMap(),
-                Collections.<StateStore>emptyList(),
-                Collections.<StateStore>emptyList(),
-                Collections.<String, String>emptyMap(),
-                repartitionTopics);
-    }
-
     public ProcessorTopology(final List<ProcessorNode> processorNodes,
                              final Map<String, SourceNode> sourcesByTopic,
                              final Map<String, SinkNode> sinksByTopic,
                              final List<StateStore> stateStores,
                              final List<StateStore> globalStateStores,
-                             final Map<String, String> stateStoreToChangelogTopic,
+                             final Map<String, String> storeToChangelogTopic,
                              final Set<String> repartitionTopics) {
         this.processorNodes = Collections.unmodifiableList(processorNodes);
         this.sourcesByTopic = Collections.unmodifiableMap(sourcesByTopic);
         this.sinksByTopic = Collections.unmodifiableMap(sinksByTopic);
         this.stateStores = Collections.unmodifiableList(stateStores);
         this.globalStateStores = Collections.unmodifiableList(globalStateStores);
-        this.storeToChangelogTopic = Collections.unmodifiableMap(stateStoreToChangelogTopic);
+        this.storeToChangelogTopic = Collections.unmodifiableMap(storeToChangelogTopic);
         this.repartitionTopics = Collections.unmodifiableSet(repartitionTopics);
     }
 
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
index 65761e7..3d974a5 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java
@@ -284,6 +284,8 @@ public class StreamTask extends AbstractTask implements ProcessorNodePunctuator
         taskInitialized = true;
 
         idleStartTime = RecordQueue.UNKNOWN;
+
+        stateMgr.ensureStoresRegistered();
     }
 
     /**
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
index a84216e..c5080d7 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractTaskTest.java
@@ -50,6 +50,7 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Properties;
 
+import static org.apache.kafka.streams.processor.internals.ProcessorTopologyFactories.withLocalStores;
 import static org.easymock.EasyMock.expect;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
@@ -96,6 +97,8 @@ public class AbstractTaskTest {
     public void shouldThrowLockExceptionIfFailedToLockStateDirectoryWhenTopologyHasStores() throws IOException {
         final Consumer consumer = EasyMock.createNiceMock(Consumer.class);
         final StateStore store = EasyMock.createNiceMock(StateStore.class);
+        expect(store.name()).andReturn("dummy-store-name").anyTimes();
+        EasyMock.replay(store);
         expect(stateDirectory.lock(id)).andReturn(false);
         EasyMock.replay(stateDirectory);
 
@@ -232,9 +235,13 @@ public class AbstractTaskTest {
 
         return new AbstractTask(id,
                                 storeTopicPartitions,
-                                ProcessorTopology.withLocalStores(new ArrayList<>(stateStoresToChangelogTopics.keySet()), storeNamesToChangelogTopics),
+                                withLocalStores(new ArrayList<>(stateStoresToChangelogTopics.keySet()),
+                                                storeNamesToChangelogTopics),
                                 consumer,
-                                new StoreChangelogReader(consumer, Duration.ZERO, new MockStateRestoreListener(), new LogContext("stream-task-test ")),
+                                new StoreChangelogReader(consumer,
+                                                         Duration.ZERO,
+                                                         new MockStateRestoreListener(),
+                                                         new LogContext("stream-task-test ")),
                                 false,
                                 stateDirectory,
                                 config) {
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
index 30e7953..90cff85 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
@@ -33,6 +33,7 @@ import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.state.internals.TimestampedBytesStore;
+import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
 import org.apache.kafka.streams.state.internals.WrappedStateStore;
 import org.apache.kafka.test.InternalMockProcessorContext;
@@ -91,6 +92,17 @@ public class GlobalStateManagerImplTest {
     private ProcessorTopology topology;
     private InternalMockProcessorContext processorContext;
 
+    static ProcessorTopology withGlobalStores(final List<StateStore> stateStores,
+                                              final Map<String, String> storeToChangelogTopic) {
+        return new ProcessorTopology(Collections.emptyList(),
+                                     Collections.emptyMap(),
+                                     Collections.emptyMap(),
+                                     Collections.emptyList(),
+                                     stateStores,
+                                     storeToChangelogTopic,
+                                     Collections.emptySet());
+    }
+
     @Before
     public void before() {
         final Map<String, String> storeToTopic = new HashMap<>();
@@ -105,7 +117,7 @@ public class GlobalStateManagerImplTest {
         store3 = new NoOpReadOnlyStore<>(storeName3);
         store4 = new NoOpReadOnlyStore<>(storeName4);
 
-        topology = ProcessorTopology.withGlobalStores(asList(store1, store2, store3, store4), storeToTopic);
+        topology = withGlobalStores(asList(store1, store2, store3, store4), storeToTopic);
 
         streamsConfig = new StreamsConfig(new Properties() {
             {
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateTaskTest.java
index 9001255..b16fc54 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateTaskTest.java
@@ -84,7 +84,7 @@ public class GlobalStateTaskTest {
         final Map<String, String> storeToTopic = new HashMap<>();
         storeToTopic.put("t1-store", topic1);
         storeToTopic.put("t2-store", topic2);
-        topology = ProcessorTopology.with(
+        topology = ProcessorTopologyFactories.with(
             asList(sourceOne, sourceTwo, processorOne, processorTwo),
             sourceByTopics,
             Collections.<StateStore>emptyList(),
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
index b7cc4b1..0a6ada3 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java
@@ -32,9 +32,11 @@ import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.state.internals.TimestampedBytesStore;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
 import org.apache.kafka.test.MockBatchingStateRestoreListener;
+import org.apache.kafka.test.MockInternalProcessorContext;
 import org.apache.kafka.test.MockKeyValueStore;
 import org.apache.kafka.test.NoOpProcessorContext;
 import org.apache.kafka.test.TestUtils;
+import org.hamcrest.Matchers;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -43,7 +45,6 @@ import org.junit.Test;
 import java.io.File;
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -52,7 +53,12 @@ import java.util.Properties;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import static java.util.Arrays.asList;
+import static java.util.Collections.emptyMap;
 import static java.util.Collections.singletonList;
+import static java.util.Collections.singletonMap;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.Is.is;
@@ -80,7 +86,7 @@ public class ProcessorStateManagerTest {
     private final TaskId taskId = new TaskId(0, 1);
     private final MockChangelogReader changelogReader = new MockChangelogReader();
     private final MockKeyValueStore mockKeyValueStore = new MockKeyValueStore(storeName, true);
-    private final byte[] key = new byte[]{0x0, 0x0, 0x0, 0x1};
+    private final byte[] key = new byte[] {0x0, 0x0, 0x0, 0x1};
     private final byte[] value = "the-value".getBytes(StandardCharsets.UTF_8);
     private final ConsumerRecord<byte[], byte[]> consumerRecord = new ConsumerRecord<>(changelogTopic, 0, 0, key, value);
     private final LogContext logContext = new LogContext("process-state-manager-test ");
@@ -190,12 +196,10 @@ public class ProcessorStateManagerTest {
             noPartitions,
             false,
             stateDirectory,
-            new HashMap<String, String>() {
-                {
-                    put(persistentStoreName, persistentStoreTopicName);
-                    put(nonPersistentStoreName, nonPersistentStoreName);
-                }
-            },
+            mkMap(
+                mkEntry(persistentStoreName, persistentStoreTopicName),
+                mkEntry(nonPersistentStoreName, nonPersistentStoreName)
+            ),
             changelogReader,
             false,
             logContext);
@@ -217,12 +221,10 @@ public class ProcessorStateManagerTest {
             noPartitions,
             false,
             stateDirectory,
-            new HashMap<String, String>() {
-                {
-                    put(persistentStoreName, persistentStoreTopicName);
-                    put(nonPersistentStoreName, nonPersistentStoreTopicName);
-                }
-            },
+            mkMap(
+                mkEntry(persistentStoreName, persistentStoreTopicName),
+                mkEntry(nonPersistentStoreName, nonPersistentStoreTopicName)
+            ),
             changelogReader,
             false,
             logContext);
@@ -253,7 +255,7 @@ public class ProcessorStateManagerTest {
         storeToChangelogTopic.put(storeName3, storeTopicName3);
 
         final OffsetCheckpoint checkpoint = new OffsetCheckpoint(new File(stateDirectory.directoryForTask(taskId), ProcessorStateManager.CHECKPOINT_FILE_NAME));
-        checkpoint.write(Collections.singletonMap(new TopicPartition(storeTopicName1, 0), lastCheckpointedOffset));
+        checkpoint.write(singletonMap(new TopicPartition(storeTopicName1, 0), lastCheckpointedOffset));
 
         final TopicPartition partition1 = new TopicPartition(storeTopicName1, 0);
         final TopicPartition partition2 = new TopicPartition(storeTopicName2, 0);
@@ -304,7 +306,7 @@ public class ProcessorStateManagerTest {
             noPartitions,
             false,
             stateDirectory,
-            Collections.emptyMap(),
+            emptyMap(),
             changelogReader,
             false,
             logContext);
@@ -321,7 +323,7 @@ public class ProcessorStateManagerTest {
 
     @Test
     public void testFlushAndClose() throws IOException {
-        checkpoint.write(Collections.emptyMap());
+        checkpoint.write(emptyMap());
 
         // set up ack'ed offsets
         final HashMap<TopicPartition, Long> ackedOffsets = new HashMap<>();
@@ -334,12 +336,8 @@ public class ProcessorStateManagerTest {
             noPartitions,
             false,
             stateDirectory,
-            new HashMap<String, String>() {
-                {
-                    put(persistentStoreName, persistentStoreTopicName);
-                    put(nonPersistentStoreName, nonPersistentStoreTopicName);
-                }
-            },
+            mkMap(mkEntry(persistentStoreName, persistentStoreTopicName),
+                  mkEntry(nonPersistentStoreName, nonPersistentStoreTopicName)),
             changelogReader,
             false,
             logContext);
@@ -362,6 +360,9 @@ public class ProcessorStateManagerTest {
         assertTrue(nonPersistentStore.closed);
         assertTrue(checkpointFile.exists());
 
+        // make sure that flush is called in the proper order
+        assertThat(persistentStore.getLastFlushCount(), Matchers.lessThan(nonPersistentStore.getLastFlushCount()));
+
         // the checkpoint file should contain an offset from the persistent store only.
         final Map<TopicPartition, Long> checkpointedOffsets = checkpoint.read();
         assertEquals(1, checkpointedOffsets.size());
@@ -369,13 +370,49 @@ public class ProcessorStateManagerTest {
     }
 
     @Test
+    public void shouldMaintainRegistrationOrderWhenReregistered() throws IOException {
+        checkpoint.write(emptyMap());
+
+        // set up ack'ed offsets
+        final TopicPartition persistentTopicPartition = new TopicPartition(persistentStoreTopicName, 1);
+        final TopicPartition nonPersistentTopicPartition = new TopicPartition(nonPersistentStoreTopicName, 1);
+
+        final ProcessorStateManager stateMgr = new ProcessorStateManager(
+            taskId,
+            noPartitions,
+            false,
+            stateDirectory,
+            mkMap(mkEntry(persistentStoreName, persistentStoreTopicName),
+                  mkEntry(nonPersistentStoreName, nonPersistentStoreTopicName)),
+            changelogReader,
+            false,
+            logContext);
+        stateMgr.register(persistentStore, persistentStore.stateRestoreCallback);
+        stateMgr.register(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+        // de-registers the stores, but doesn't re-register them because
+        // the context isn't connected to our state manager
+        stateMgr.reinitializeStateStoresForPartitions(asList(nonPersistentTopicPartition, persistentTopicPartition),
+                                                      new MockInternalProcessorContext());
+        // register them in backward order
+        stateMgr.register(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
+        stateMgr.register(persistentStore, persistentStore.stateRestoreCallback);
+
+        stateMgr.flush();
+
+        // make sure that flush is called in the proper order
+        assertTrue(persistentStore.flushed);
+        assertTrue(nonPersistentStore.flushed);
+        assertThat(persistentStore.getLastFlushCount(), Matchers.lessThan(nonPersistentStore.getLastFlushCount()));
+    }
+
+    @Test
     public void shouldRegisterStoreWithoutLoggingEnabledAndNotBackedByATopic() throws IOException {
         final ProcessorStateManager stateMgr = new ProcessorStateManager(
             new TaskId(0, 1),
             noPartitions,
             false,
             stateDirectory,
-            Collections.emptyMap(),
+            emptyMap(),
             changelogReader,
             false,
             logContext);
@@ -385,7 +422,7 @@ public class ProcessorStateManagerTest {
 
     @Test
     public void shouldNotChangeOffsetsIfAckedOffsetsIsNull() throws IOException {
-        final Map<TopicPartition, Long> offsets = Collections.singletonMap(persistentStorePartition, 99L);
+        final Map<TopicPartition, Long> offsets = singletonMap(persistentStorePartition, 99L);
         checkpoint.write(offsets);
 
         final MockKeyValueStore persistentStore = new MockKeyValueStore(persistentStoreName, true);
@@ -394,7 +431,7 @@ public class ProcessorStateManagerTest {
             noPartitions,
             false,
             stateDirectory,
-            Collections.emptyMap(),
+            emptyMap(),
             changelogReader,
             false,
             logContext);
@@ -411,15 +448,15 @@ public class ProcessorStateManagerTest {
             noPartitions,
             false,
             stateDirectory,
-            Collections.singletonMap(persistentStore.name(), persistentStoreTopicName),
+            singletonMap(persistentStore.name(), persistentStoreTopicName),
             changelogReader,
             false,
             logContext);
         stateMgr.register(persistentStore, persistentStore.stateRestoreCallback);
 
-        stateMgr.checkpoint(Collections.singletonMap(persistentStorePartition, 10L));
+        stateMgr.checkpoint(singletonMap(persistentStorePartition, 10L));
         final Map<TopicPartition, Long> read = checkpoint.read();
-        assertThat(read, equalTo(Collections.singletonMap(persistentStorePartition, 11L)));
+        assertThat(read, equalTo(singletonMap(persistentStorePartition, 11L)));
     }
 
     @Test
@@ -429,7 +466,7 @@ public class ProcessorStateManagerTest {
             noPartitions,
             true, // standby
             stateDirectory,
-            Collections.singletonMap(persistentStore.name(), persistentStoreTopicName),
+            singletonMap(persistentStore.name(), persistentStoreTopicName),
             changelogReader,
             false,
             logContext);
@@ -442,10 +479,10 @@ public class ProcessorStateManagerTest {
             888L
         );
 
-        stateMgr.checkpoint(Collections.emptyMap());
+        stateMgr.checkpoint(emptyMap());
 
         final Map<TopicPartition, Long> read = checkpoint.read();
-        assertThat(read, equalTo(Collections.singletonMap(persistentStorePartition, 889L)));
+        assertThat(read, equalTo(singletonMap(persistentStorePartition, 889L)));
 
     }
 
@@ -458,16 +495,16 @@ public class ProcessorStateManagerTest {
             noPartitions,
             true, // standby
             stateDirectory,
-            Collections.singletonMap(nonPersistentStoreName, nonPersistentStoreTopicName),
+            singletonMap(nonPersistentStoreName, nonPersistentStoreTopicName),
             changelogReader,
             false,
             logContext);
 
         stateMgr.register(nonPersistentStore, nonPersistentStore.stateRestoreCallback);
-        stateMgr.checkpoint(Collections.singletonMap(topicPartition, 876L));
+        stateMgr.checkpoint(singletonMap(topicPartition, 876L));
 
         final Map<TopicPartition, Long> read = checkpoint.read();
-        assertThat(read, equalTo(Collections.emptyMap()));
+        assertThat(read, equalTo(emptyMap()));
     }
 
     @Test
@@ -477,17 +514,17 @@ public class ProcessorStateManagerTest {
             noPartitions,
             true, // standby
             stateDirectory,
-            Collections.emptyMap(),
+            emptyMap(),
             changelogReader,
             false,
             logContext);
 
         stateMgr.register(persistentStore, persistentStore.stateRestoreCallback);
 
-        stateMgr.checkpoint(Collections.singletonMap(persistentStorePartition, 987L));
+        stateMgr.checkpoint(singletonMap(persistentStorePartition, 987L));
 
         final Map<TopicPartition, Long> read = checkpoint.read();
-        assertThat(read, equalTo(Collections.emptyMap()));
+        assertThat(read, equalTo(emptyMap()));
     }
 
     @Test
@@ -497,7 +534,7 @@ public class ProcessorStateManagerTest {
             noPartitions,
             false,
             stateDirectory,
-            Collections.emptyMap(),
+            emptyMap(),
             changelogReader,
             false,
             logContext);
@@ -517,7 +554,7 @@ public class ProcessorStateManagerTest {
             noPartitions,
             false,
             stateDirectory,
-            Collections.emptyMap(),
+            emptyMap(),
             changelogReader,
             false,
             logContext);
@@ -541,7 +578,7 @@ public class ProcessorStateManagerTest {
             Collections.singleton(changelogTopicPartition),
             false,
             stateDirectory,
-            Collections.singletonMap(storeName, changelogTopic),
+            singletonMap(storeName, changelogTopic),
             changelogReader,
             false,
             logContext);
@@ -570,7 +607,7 @@ public class ProcessorStateManagerTest {
             Collections.singleton(changelogTopicPartition),
             false,
             stateDirectory,
-            Collections.singletonMap(storeName, changelogTopic),
+            singletonMap(storeName, changelogTopic),
             changelogReader,
             false,
             logContext);
@@ -604,7 +641,7 @@ public class ProcessorStateManagerTest {
                 noPartitions,
                 false,
                 stateDirectory,
-                Collections.singletonMap(persistentStore.name(), persistentStoreTopicName),
+                singletonMap(persistentStore.name(), persistentStoreTopicName),
                 changelogReader,
                 false,
                 logContext);
@@ -615,7 +652,7 @@ public class ProcessorStateManagerTest {
         stateMgr.register(persistentStore, persistentStore.stateRestoreCallback);
 
         stateDirectory.clean();
-        stateMgr.checkpoint(Collections.singletonMap(persistentStorePartition, 10L));
+        stateMgr.checkpoint(singletonMap(persistentStorePartition, 10L));
         LogCaptureAppender.unregister(appender);
 
         boolean foundExpectedLogMessage = false;
@@ -634,16 +671,6 @@ public class ProcessorStateManagerTest {
 
     @Test
     public void shouldFlushAllStoresEvenIfStoreThrowsException() throws IOException {
-        final ProcessorStateManager stateManager = new ProcessorStateManager(
-            taskId,
-            Collections.singleton(changelogTopicPartition),
-            false,
-            stateDirectory,
-            Collections.singletonMap(storeName, changelogTopic),
-            changelogReader,
-            false,
-            logContext);
-
         final AtomicBoolean flushedStore = new AtomicBoolean(false);
 
         final MockKeyValueStore stateStore1 = new MockKeyValueStore(storeName, true) {
@@ -658,6 +685,16 @@ public class ProcessorStateManagerTest {
                 flushedStore.set(true);
             }
         };
+        final ProcessorStateManager stateManager = new ProcessorStateManager(
+            taskId,
+            Collections.singleton(changelogTopicPartition),
+            false,
+            stateDirectory,
+            singletonMap(storeName, changelogTopic),
+            changelogReader,
+            false,
+            logContext);
+
         stateManager.register(stateStore1, stateStore1.stateRestoreCallback);
         stateManager.register(stateStore2, stateStore2.stateRestoreCallback);
 
@@ -669,15 +706,6 @@ public class ProcessorStateManagerTest {
 
     @Test
     public void shouldCloseAllStoresEvenIfStoreThrowsExcepiton() throws IOException {
-        final ProcessorStateManager stateManager = new ProcessorStateManager(
-            taskId,
-            Collections.singleton(changelogTopicPartition),
-            false,
-            stateDirectory,
-            Collections.singletonMap(storeName, changelogTopic),
-            changelogReader,
-            false,
-            logContext);
 
         final AtomicBoolean closedStore = new AtomicBoolean(false);
 
@@ -693,6 +721,16 @@ public class ProcessorStateManagerTest {
                 closedStore.set(true);
             }
         };
+        final ProcessorStateManager stateManager = new ProcessorStateManager(
+            taskId,
+            Collections.singleton(changelogTopicPartition),
+            false,
+            stateDirectory,
+            singletonMap(storeName, changelogTopic),
+            changelogReader,
+            false,
+            logContext);
+
         stateManager.register(stateStore1, stateStore1.stateRestoreCallback);
         stateManager.register(stateStore2, stateStore2.stateRestoreCallback);
 
@@ -704,7 +742,7 @@ public class ProcessorStateManagerTest {
 
     @Test
     public void shouldDeleteCheckpointFileOnCreationIfEosEnabled() throws IOException {
-        checkpoint.write(Collections.singletonMap(new TopicPartition(persistentStoreTopicName, 1), 123L));
+        checkpoint.write(singletonMap(new TopicPartition(persistentStoreTopicName, 1), 123L));
         assertTrue(checkpointFile.exists());
 
         ProcessorStateManager stateManager = null;
@@ -714,7 +752,7 @@ public class ProcessorStateManagerTest {
                 noPartitions,
                 false,
                 stateDirectory,
-                Collections.emptyMap(),
+                emptyMap(),
                 changelogReader,
                 true,
                 logContext);
@@ -741,13 +779,15 @@ public class ProcessorStateManagerTest {
         final String store2Name = "store2";
         final String store2Changelog = "store2-changelog";
         final TopicPartition store2Partition = new TopicPartition(store2Changelog, 0);
-        final List<TopicPartition> changelogPartitions = Arrays.asList(changelogTopicPartition, store2Partition);
-        final Map<String, String> storeToChangelog = new HashMap<String, String>() {
-            {
-                put(storeName, changelogTopic);
-                put(store2Name, store2Changelog);
-            }
-        };
+        final List<TopicPartition> changelogPartitions = asList(changelogTopicPartition, store2Partition);
+        final Map<String, String> storeToChangelog = mkMap(
+                mkEntry(storeName, changelogTopic),
+                mkEntry(store2Name, store2Changelog)
+        );
+
+        final MockKeyValueStore stateStore = new MockKeyValueStore(storeName, true);
+        final MockKeyValueStore stateStore2 = new MockKeyValueStore(store2Name, true);
+
         final ProcessorStateManager stateManager = new ProcessorStateManager(
             taskId,
             changelogPartitions,
@@ -758,9 +798,6 @@ public class ProcessorStateManagerTest {
             eosEnabled,
             logContext);
 
-        final MockKeyValueStore stateStore = new MockKeyValueStore(storeName, true);
-        final MockKeyValueStore stateStore2 = new MockKeyValueStore(store2Name, true);
-
         stateManager.register(stateStore, stateStore.stateRestoreCallback);
         stateManager.register(stateStore2, stateStore2.stateRestoreCallback);
 
@@ -784,11 +821,7 @@ public class ProcessorStateManagerTest {
             noPartitions,
             true,
             stateDirectory,
-            new HashMap<String, String>() {
-                {
-                    put(persistentStoreName, persistentStoreTopicName);
-                }
-            },
+            singletonMap(persistentStoreName, persistentStoreTopicName),
             changelogReader,
             false,
             logContext);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyFactories.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyFactories.java
new file mode 100644
index 0000000..5fe1505
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyFactories.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.streams.processor.StateStore;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+public final class ProcessorTopologyFactories {
+    private ProcessorTopologyFactories() {}
+
+
+    public static ProcessorTopology with(final List<ProcessorNode> processorNodes,
+                                         final Map<String, SourceNode> sourcesByTopic,
+                                         final List<StateStore> stateStoresByName,
+                                         final Map<String, String> storeToChangelogTopic) {
+        return new ProcessorTopology(processorNodes,
+                                     sourcesByTopic,
+                                     Collections.emptyMap(),
+                                     stateStoresByName,
+                                     Collections.emptyList(),
+                                     storeToChangelogTopic,
+                                     Collections.emptySet());
+    }
+
+    static ProcessorTopology withLocalStores(final List<StateStore> stateStores,
+                                             final Map<String, String> storeToChangelogTopic) {
+        return new ProcessorTopology(Collections.emptyList(),
+                                     Collections.emptyMap(),
+                                     Collections.emptyMap(),
+                                     stateStores,
+                                     Collections.emptyList(),
+                                     storeToChangelogTopic,
+                                     Collections.emptySet());
+    }
+
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
index a57c407..6e7655a 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java
@@ -104,8 +104,9 @@ public class StandbyTaskTest {
     private final MockStateRestoreListener stateRestoreListener = new MockStateRestoreListener();
 
     private final Set<TopicPartition> topicPartitions = Collections.emptySet();
-    private final ProcessorTopology topology = ProcessorTopology.withLocalStores(
-        asList(new MockKeyValueStoreBuilder(storeName1, false).build(), new MockKeyValueStoreBuilder(storeName2, true).build()),
+    private final ProcessorTopology topology = ProcessorTopologyFactories.withLocalStores(
+        asList(new MockKeyValueStoreBuilder(storeName1, false).build(),
+               new MockKeyValueStoreBuilder(storeName2, true).build()),
         mkMap(
             mkEntry(storeName1, storeChangelogTopicName1),
             mkEntry(storeName2, storeChangelogTopicName2)
@@ -113,8 +114,9 @@ public class StandbyTaskTest {
     );
     private final TopicPartition globalTopicPartition = new TopicPartition(globalStoreName, 0);
     private final Set<TopicPartition> ktablePartitions = Utils.mkSet(globalTopicPartition);
-    private final ProcessorTopology ktableTopology = ProcessorTopology.withLocalStores(
-        singletonList(new MockKeyValueStoreBuilder(globalTopicPartition.topic(), true).withLoggingDisabled().build()),
+    private final ProcessorTopology ktableTopology = ProcessorTopologyFactories.withLocalStores(
+        singletonList(new MockKeyValueStoreBuilder(globalTopicPartition.topic(), true)
+                          .withLoggingDisabled().build()),
         mkMap(
             mkEntry(globalStoreName, globalTopicPartition.topic())
         )
@@ -165,7 +167,7 @@ public class StandbyTaskTest {
 
     @After
     public void cleanup() throws IOException {
-        if (task != null) {
+        if (task != null && !task.isClosed()) {
             task.close(true, false);
             task = null;
         }
@@ -367,7 +369,7 @@ public class StandbyTaskTest {
         );
 
         task.suspend();
-        task.closeStateManager(true);
+        task.close(true, false);
 
         final File taskDir = stateDirectory.directoryForTask(taskId);
         final OffsetCheckpoint checkpoint = new OffsetCheckpoint(new File(taskDir, ProcessorStateManager.CHECKPOINT_FILE_NAME));
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
index 829106a..d22bb1b 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java
@@ -119,7 +119,7 @@ public class StreamTaskTest {
     private final TopicPartition changelogPartition = new TopicPartition("store-changelog", 0);
     private final Long offset = 543L;
 
-    private final ProcessorTopology topology = ProcessorTopology.withSources(
+    private final ProcessorTopology topology = withSources(
         asList(source1, source2, processorStreamTime, processorSystemTime),
         mkMap(mkEntry(topic1, source1), mkEntry(topic2, source2))
     );
@@ -152,6 +152,29 @@ public class StreamTaskTest {
         }
     };
 
+    static ProcessorTopology withRepartitionTopics(final List<ProcessorNode> processorNodes,
+                                                   final Map<String, SourceNode> sourcesByTopic,
+                                                   final Set<String> repartitionTopics) {
+        return new ProcessorTopology(processorNodes,
+                                     sourcesByTopic,
+                                     Collections.emptyMap(),
+                                     Collections.emptyList(),
+                                     Collections.emptyList(),
+                                     Collections.emptyMap(),
+                                     repartitionTopics);
+    }
+
+    static ProcessorTopology withSources(final List<ProcessorNode> processorNodes,
+                                         final Map<String, SourceNode> sourcesByTopic) {
+        return new ProcessorTopology(processorNodes,
+                                     sourcesByTopic,
+                                     Collections.emptyMap(),
+                                     Collections.emptyList(),
+                                     Collections.emptyList(),
+                                     Collections.emptyMap(),
+                                     Collections.emptySet());
+    }
+
     private StreamsConfig createConfig(final boolean enableEoS) {
         final String canonicalPath;
         try {
@@ -195,7 +218,7 @@ public class StreamTaskTest {
     public void shouldHandleInitTransactionsTimeoutExceptionOnCreation() {
         final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
 
-        final ProcessorTopology topology = ProcessorTopology.withSources(
+        final ProcessorTopology topology = withSources(
             asList(source1, source2, processorStreamTime, processorSystemTime),
             mkMap(mkEntry(topic1, (SourceNode) source1), mkEntry(topic2, (SourceNode) source2))
         );
@@ -245,7 +268,7 @@ public class StreamTaskTest {
     public void shouldHandleInitTransactionsTimeoutExceptionOnResume() {
         final LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
 
-        final ProcessorTopology topology = ProcessorTopology.withSources(
+        final ProcessorTopology topology = withSources(
             asList(source1, source2, processorStreamTime, processorSystemTime),
             mkMap(mkEntry(topic1, (SourceNode) source1), mkEntry(topic2, (SourceNode) source2))
         );
@@ -1353,7 +1376,7 @@ public class StreamTaskTest {
     public void shouldReturnOffsetsForRepartitionTopicsForPurging() {
         final TopicPartition repartition = new TopicPartition("repartition", 1);
 
-        final ProcessorTopology topology = ProcessorTopology.withRepartitionTopics(
+        final ProcessorTopology topology = withRepartitionTopics(
             asList(source1, source2),
             mkMap(mkEntry(topic1, source1), mkEntry(repartition.topic(), source2)),
             Collections.singleton(repartition.topic())
@@ -1425,7 +1448,7 @@ public class StreamTaskTest {
     }
 
     private StreamTask createStatefulTask(final StreamsConfig config, final boolean logged) {
-        final ProcessorTopology topology = ProcessorTopology.with(
+        final ProcessorTopology topology = ProcessorTopologyFactories.with(
             asList(source1, source2),
             mkMap(mkEntry(topic1, source1), mkEntry(topic2, source2)),
             singletonList(stateStore),
@@ -1447,7 +1470,7 @@ public class StreamTaskTest {
     }
 
     private StreamTask createStatefulTaskThatThrowsExceptionOnClose() {
-        final ProcessorTopology topology = ProcessorTopology.with(
+        final ProcessorTopology topology = ProcessorTopologyFactories.with(
             asList(source1, source3),
             mkMap(mkEntry(topic1, source1), mkEntry(topic2, source3)),
             singletonList(stateStore),
@@ -1469,7 +1492,7 @@ public class StreamTaskTest {
     }
 
     private StreamTask createStatelessTask(final StreamsConfig streamsConfig) {
-        final ProcessorTopology topology = ProcessorTopology.withSources(
+        final ProcessorTopology topology = withSources(
             asList(source1, source2, processorStreamTime, processorSystemTime),
             mkMap(mkEntry(topic1, source1), mkEntry(topic2, source2))
         );
@@ -1496,7 +1519,7 @@ public class StreamTaskTest {
 
     // this task will throw exception when processing (on partition2), flushing, suspending and closing
     private StreamTask createTaskThatThrowsException(final boolean enableEos) {
-        final ProcessorTopology topology = ProcessorTopology.withSources(
+        final ProcessorTopology topology = withSources(
             asList(source1, source3, processorStreamTime, processorSystemTime),
             mkMap(mkEntry(topic1, source1), mkEntry(topic2, source3))
         );
diff --git a/streams/src/test/java/org/apache/kafka/test/MockKeyValueStore.java b/streams/src/test/java/org/apache/kafka/test/MockKeyValueStore.java
index 1729b24..11b73d9 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockKeyValueStore.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockKeyValueStore.java
@@ -26,8 +26,13 @@ import org.apache.kafka.streams.state.KeyValueStore;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 
 public class MockKeyValueStore implements KeyValueStore {
+    // keep a global counter of flushes, and record in each instance the counter value of its most
+    // recent flush, so we can reason about the order in which stores get flushed.
+    private static final AtomicInteger GLOBAL_FLUSH_COUNTER = new AtomicInteger(0);
+    private final AtomicInteger instanceLastFlushCount = new AtomicInteger(-1);
     private final String name;
     private final boolean persistent;
 
@@ -58,9 +63,14 @@ public class MockKeyValueStore implements KeyValueStore {
 
     @Override
     public void flush() {
+        instanceLastFlushCount.set(GLOBAL_FLUSH_COUNTER.getAndIncrement());
         flushed = true;
     }
 
+    public int getLastFlushCount() {
+        return instanceLastFlushCount.get();
+    }
+
     @Override
     public void close() {
         closed = true;