Posted to commits@samza.apache.org by bh...@apache.org on 2021/05/12 23:47:21 UTC

[samza] branch state-backend-async-commit updated: SAMZA-2591: Async Commit [3/3]: Container restore lifecycle (#1491)

This is an automated email from the ASF dual-hosted git repository.

bharathkk pushed a commit to branch state-backend-async-commit
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/state-backend-async-commit by this push:
     new c85aade  SAMZA-2591: Async Commit [3/3]: Container restore lifecycle (#1491)
c85aade is described below

commit c85aade351a3e7dea2a699a5836968a9fc7ec016
Author: Daniel Chen <dc...@linkedin.com>
AuthorDate: Wed May 12 16:47:09 2021 -0700

    SAMZA-2591: Async Commit [3/3]: Container restore lifecycle (#1491)
    
    - Refactor ContainerStorageManager lifecycle for restores
---
 .../org/apache/samza/checkpoint/CheckpointId.java  |   2 +-
 .../java/org/apache/samza/job/model/JobModel.java  |   4 +
 .../samza/storage/KafkaChangelogRestoreParams.java |  85 ++++++
 .../apache/samza/storage/StateBackendFactory.java  |   2 +-
 .../org/apache/samza/storage/StorageEngine.java    |   8 +
 .../org/apache/samza/storage/StoreProperties.java  |  23 +-
 .../org/apache/samza/storage/kv/KeyValueStore.java |  13 +
 .../apache/samza/system/ChangelogSSPIterator.java  |   2 +-
 .../org/apache/samza/config/StorageConfig.java     |   5 +-
 .../samza/coordinator/MetadataResourceUtil.java    |   2 +-
 .../samza/serializers/CheckpointV2Serde.java       |   6 +-
 .../serializers/model/JsonCheckpointV2Mixin.java   |   3 -
 .../storage/KafkaChangelogStateBackendFactory.java |   2 +-
 .../NonTransactionalStateTaskRestoreManager.java   |  80 +++--
 .../org/apache/samza/storage/StorageRecovery.java  |  15 +-
 .../apache/samza/storage/TaskRestoreManager.java   |  55 ----
 .../samza/storage/TaskRestoreManagerFactory.java   |  82 -----
 .../samza/storage/TaskSideInputStorageManager.java |   1 +
 .../TransactionalStateTaskRestoreManager.java      | 127 ++++++--
 .../java/org/apache/samza/system/SystemAdmins.java |   8 +
 .../samza/storage/ContainerStorageManager.java     | 329 +++++++++++----------
 .../samza/storage/TaskStorageManagerFactory.java   |  46 ---
 .../org/apache/samza/job/model/TestJobModel.java   |   2 +-
 .../TestTransactionalStateTaskRestoreManager.java  | 252 ++++++++--------
 .../samza/storage/TestContainerStorageManager.java |  34 ++-
 .../InMemoryKeyValueStorageEngineFactory.java      |   2 -
 .../kv/RocksDbKeyValueStorageEngineFactory.scala   |   2 -
 .../samza/storage/kv/RocksDbKeyValueStore.scala    |   9 +-
 .../storage/kv/TestRocksDbKeyValueStore.scala      |  20 +-
 .../kv/BaseKeyValueStorageEngineFactory.java       |   9 +-
 .../kv/MockKeyValueStorageEngineFactory.java       |   4 +-
 .../kv/TestBaseKeyValueStorageEngineFactory.java   |  37 ++-
 .../processor/TestZkLocalApplicationRunner.java    |   4 +-
 .../NonTransactionalStateIntegrationTest.scala     |   5 +-
 .../test/integration/StreamTaskTestUtil.scala      |   8 +-
 35 files changed, 717 insertions(+), 571 deletions(-)
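
For orientation, a minimal sketch of the restore lifecycle this refactor implies for a TaskRestoreManager. The helper method and variable names are illustrative, and the interface shape is assumed from the signature changes in the diff below (init(Checkpoint) replacing init(Map<SystemStreamPartition, String>), close() replacing stopPersistentStores()):

    // Hypothetical helper: drives one task's state restore with the refactored lifecycle.
    static void restoreTaskState(TaskRestoreManager restoreManager, Checkpoint lastTaskCheckpoint)
        throws InterruptedException {
      restoreManager.init(lastTaskCheckpoint);  // prepare store dirs, register starting offsets
      try {
        restoreManager.restore();               // may be interrupted by the container on shutdown
      } finally {
        restoreManager.close();                 // stop persistent stores so they can be reopened read-write
      }
    }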

diff --git a/samza-api/src/main/java/org/apache/samza/checkpoint/CheckpointId.java b/samza-api/src/main/java/org/apache/samza/checkpoint/CheckpointId.java
index f2aa8e1..a4d4ca3 100644
--- a/samza-api/src/main/java/org/apache/samza/checkpoint/CheckpointId.java
+++ b/samza-api/src/main/java/org/apache/samza/checkpoint/CheckpointId.java
@@ -92,7 +92,7 @@ public class CheckpointId implements Comparable<CheckpointId> {
 
   @Override
   public int compareTo(CheckpointId that) {
-    if(this.millis != that.millis) return Long.compare(this.millis, that.millis);
+    if (this.millis != that.millis) return Long.compare(this.millis, that.millis);
     else return Long.compare(this.nanoId, that.nanoId);
   }
 
diff --git a/samza-api/src/main/java/org/apache/samza/job/model/JobModel.java b/samza-api/src/main/java/org/apache/samza/job/model/JobModel.java
index d1f5e72..4bd995b 100644
--- a/samza-api/src/main/java/org/apache/samza/job/model/JobModel.java
+++ b/samza-api/src/main/java/org/apache/samza/job/model/JobModel.java
@@ -57,6 +57,10 @@ public class JobModel {
     }
   }
 
+  public int getMaxChangeLogStreamPartitions() {
+    return maxChangeLogStreamPartitions;
+  }
+
   public Config getConfig() {
     return config;
   }
diff --git a/samza-api/src/main/java/org/apache/samza/storage/KafkaChangelogRestoreParams.java b/samza-api/src/main/java/org/apache/samza/storage/KafkaChangelogRestoreParams.java
new file mode 100644
index 0000000..961a85b
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/storage/KafkaChangelogRestoreParams.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.serializers.Serde;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.task.MessageCollector;
+
+/**
+ * Provides the parameters required by Kafka Changelog restore managers.
+ */
+public class KafkaChangelogRestoreParams {
+  private final Map<String, SystemConsumer> storeConsumers;
+  private final Map<String, StorageEngine> inMemoryStores;
+  private final Map<String, SystemAdmin> systemAdmins;
+  private final Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories;
+  private final Map<String, Serde<Object>> serdes;
+  private final MessageCollector collector;
+  private final Set<String> storeNames;
+
+  public KafkaChangelogRestoreParams(
+      Map<String, SystemConsumer> storeConsumers,
+      Map<String, StorageEngine> inMemoryStores,
+      Map<String, SystemAdmin> systemAdmins,
+      Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, Serde<Object>> serdes,
+      MessageCollector collector,
+      Set<String> storeNames) {
+    this.storeConsumers = storeConsumers;
+    this.inMemoryStores = inMemoryStores;
+    this.systemAdmins = systemAdmins;
+    this.storageEngineFactories = storageEngineFactories;
+    this.serdes = serdes;
+    this.collector = collector;
+    this.storeNames = storeNames;
+  }
+
+  public Map<String, SystemConsumer> getStoreConsumers() {
+    return storeConsumers;
+  }
+
+  public Map<String, StorageEngine> getInMemoryStores() {
+    return inMemoryStores;
+  }
+
+  public Map<String, SystemAdmin> getSystemAdmins() {
+    return systemAdmins;
+  }
+
+  public Map<String, StorageEngineFactory<Object, Object>> getStorageEngineFactories() {
+    return storageEngineFactories;
+  }
+
+  public Map<String, Serde<Object>> getSerdes() {
+    return serdes;
+  }
+
+  public MessageCollector getCollector() {
+    return collector;
+  }
+
+  public Set<String> getStoreNames() {
+    return storeNames;
+  }
+}
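
For illustration, a minimal sketch of constructing the new parameter object, assuming the container has already built the Kafka-specific dependencies elsewhere; the wrapper class, method name, and empty collections are placeholders only:

    import java.util.Collections;
    import org.apache.samza.storage.KafkaChangelogRestoreParams;
    import org.apache.samza.task.MessageCollector;

    public class RestoreParamsExample {
      // Bundles the dependencies a Kafka changelog restore manager needs, instead of
      // threading each map through ContainerStorageManager separately.
      public static KafkaChangelogRestoreParams emptyParams(MessageCollector collector) {
        return new KafkaChangelogRestoreParams(
            Collections.emptyMap(),   // store name -> changelog SystemConsumer
            Collections.emptyMap(),   // store name -> in-memory StorageEngine
            Collections.emptyMap(),   // system name -> SystemAdmin
            Collections.emptyMap(),   // store name -> StorageEngineFactory
            Collections.emptyMap(),   // serdes by name
            collector,                // task's MessageCollector
            Collections.emptySet());  // non-side-input store names
      }
    }
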
diff --git a/samza-api/src/main/java/org/apache/samza/storage/StateBackendFactory.java b/samza-api/src/main/java/org/apache/samza/storage/StateBackendFactory.java
index 9946d2a..54a9a81 100644
--- a/samza-api/src/main/java/org/apache/samza/storage/StateBackendFactory.java
+++ b/samza-api/src/main/java/org/apache/samza/storage/StateBackendFactory.java
@@ -33,7 +33,7 @@ import org.apache.samza.util.Clock;
  * Factory to build the Samza {@link TaskBackupManager}, {@link TaskRestoreManager} and {@link TaskStorageAdmin}
  * for a particular state storage backend, which are used to durably backup the Samza task state.
  */
-  public interface StateBackendFactory {
+public interface StateBackendFactory {
   TaskBackupManager getBackupManager(JobContext jobContext,
       ContainerContext containerContext,
       TaskModel taskModel,
diff --git a/samza-api/src/main/java/org/apache/samza/storage/StorageEngine.java b/samza-api/src/main/java/org/apache/samza/storage/StorageEngine.java
index 7b12c85..0d62ca4 100644
--- a/samza-api/src/main/java/org/apache/samza/storage/StorageEngine.java
+++ b/samza-api/src/main/java/org/apache/samza/storage/StorageEngine.java
@@ -24,6 +24,9 @@ import java.nio.file.Path;
 import java.util.Optional;
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.checkpoint.CheckpointId;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.ExternalContext;
+import org.apache.samza.context.JobContext;
 import org.apache.samza.system.ChangelogSSPIterator;
 
 /**
@@ -39,6 +42,11 @@ import org.apache.samza.system.ChangelogSSPIterator;
 public interface StorageEngine {
 
   /**
+   * Initialize the storage engine
+   */
+  default void init(ExternalContext externalContext, JobContext jobContext, ContainerContext containerContext) { };
+
+  /**
    * Restore the content of this StorageEngine from the changelog. Messages are
    * provided in one {@link java.util.Iterator} and not deserialized for
    * efficiency, allowing the implementation to optimize replay, if possible.
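
A minimal sketch of how a container-side caller might exercise the new default init hook, assuming the contexts are already available; the helper name is illustrative:

    // Engines that do not override init(...) simply hit the no-op default added above.
    static void initStores(Map<String, StorageEngine> stores, ExternalContext externalContext,
        JobContext jobContext, ContainerContext containerContext) {
      stores.values().forEach(engine -> engine.init(externalContext, jobContext, containerContext));
    }
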
diff --git a/samza-api/src/main/java/org/apache/samza/storage/StoreProperties.java b/samza-api/src/main/java/org/apache/samza/storage/StoreProperties.java
index a398271..1244adf 100644
--- a/samza-api/src/main/java/org/apache/samza/storage/StoreProperties.java
+++ b/samza-api/src/main/java/org/apache/samza/storage/StoreProperties.java
@@ -24,12 +24,15 @@ package org.apache.samza.storage;
 public class StoreProperties {
   private final boolean persistedToDisk;
   private final boolean loggedStore;
+  private final boolean durable;
 
   private StoreProperties(
       final boolean persistedToDisk,
-      final boolean loggedStore) {
+      final boolean loggedStore,
+      final boolean durable) {
     this.persistedToDisk = persistedToDisk;
     this.loggedStore = loggedStore;
+    this.durable = durable;
   }
 
   /**
@@ -50,9 +53,20 @@ public class StoreProperties {
     return loggedStore;
   }
 
+  /**
+   * Flag to indicate whether a store is durable, that is, its contents are available across container restarts
+   * or host reallocation.
+   *
+   * @return True, if the store is durable. False by default.
+   */
+  public boolean isDurableStore() {
+    return durable;
+  }
+
   public static class StorePropertiesBuilder {
     private boolean persistedToDisk = false;
     private boolean loggedStore = false;
+    private boolean durable = false;
 
     public StorePropertiesBuilder setPersistedToDisk(boolean persistedToDisk) {
       this.persistedToDisk = persistedToDisk;
@@ -64,8 +78,13 @@ public class StoreProperties {
       return this;
     }
 
+    public StorePropertiesBuilder setIsDurable(boolean durable) {
+      this.durable = durable;
+      return this;
+    }
+
     public StoreProperties build() {
-      return new StoreProperties(persistedToDisk, loggedStore);
+      return new StoreProperties(persistedToDisk, loggedStore, durable);
     }
   }
 }
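
A minimal sketch of describing a store with the extended builder; the flag values are illustrative only:

    // A persistent, changelog-backed store that is also durable across container restarts.
    StoreProperties properties = new StoreProperties.StorePropertiesBuilder()
        .setPersistedToDisk(true)
        .setLoggedStore(true)
        .setIsDurable(true)
        .build();
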
diff --git a/samza-api/src/main/java/org/apache/samza/storage/kv/KeyValueStore.java b/samza-api/src/main/java/org/apache/samza/storage/kv/KeyValueStore.java
index 41faac3..a3552f0 100644
--- a/samza-api/src/main/java/org/apache/samza/storage/kv/KeyValueStore.java
+++ b/samza-api/src/main/java/org/apache/samza/storage/kv/KeyValueStore.java
@@ -26,6 +26,9 @@ import java.util.Map;
 import java.util.Optional;
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.checkpoint.CheckpointId;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.ExternalContext;
+import org.apache.samza.context.JobContext;
 
 
 /**
@@ -35,6 +38,16 @@ import org.apache.samza.checkpoint.CheckpointId;
  * @param <V> the type of values maintained by this key-value store.
  */
 public interface KeyValueStore<K, V> {
+
+  /**
+   * Initializes the KeyValueStore
+   *
+   * @param externalContext any external context required for initialization
+   * @param jobContext context of the job the KeyValueStore is in
+   * @param containerContext context of the KeyValueStore's container
+   */
+  default void init(ExternalContext externalContext, JobContext jobContext, ContainerContext containerContext) { }
+
   /**
    * Gets the value associated with the specified {@code key}.
    *
diff --git a/samza-api/src/main/java/org/apache/samza/system/ChangelogSSPIterator.java b/samza-api/src/main/java/org/apache/samza/system/ChangelogSSPIterator.java
index ea44b9d..8e5bc93 100644
--- a/samza-api/src/main/java/org/apache/samza/system/ChangelogSSPIterator.java
+++ b/samza-api/src/main/java/org/apache/samza/system/ChangelogSSPIterator.java
@@ -61,7 +61,7 @@ public class ChangelogSSPIterator extends BoundedSSPIterator {
   public IncomingMessageEnvelope next() {
     IncomingMessageEnvelope envelope = super.next();
 
-    // if trimming changelog is enabled, then switch to trim mode if if we've consumed past the restore offset
+    // if trimming changelog is enabled, then switch to trim mode if we've consumed past the restore offset
     // (i.e., restoreOffset was null or current offset is > restoreOffset)
     if (this.trimEnabled && (restoreOffset == null || admin.offsetComparator(envelope.getOffset(), restoreOffset) > 0)) {
       mode = Mode.TRIM;
diff --git a/samza-core/src/main/java/org/apache/samza/config/StorageConfig.java b/samza-core/src/main/java/org/apache/samza/config/StorageConfig.java
index f5e2055..4dd753e 100644
--- a/samza-core/src/main/java/org/apache/samza/config/StorageConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/StorageConfig.java
@@ -22,7 +22,6 @@ package org.apache.samza.config;
 import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -72,6 +71,8 @@ public class StorageConfig extends MapConfig {
   public static final List<String> DEFAULT_STATE_BACKEND_BACKUP_FACTORIES = ImmutableList.of(
       DEFAULT_STATE_BACKEND_FACTORY);
   public static final String STATE_BACKEND_RESTORE_FACTORY = STORE_PREFIX + "state.restore.backend";
+  public static final String INMEMORY_KV_STORAGE_ENGINE_FACTORY =
+      "org.apache.samza.storage.kv.inmemory.InMemoryKeyValueStorageEngineFactory";
 
   static final String CHANGELOG_SYSTEM = "job.changelog.system";
   static final String CHANGELOG_DELETE_RETENTION_MS = STORE_PREFIX + "%s.changelog.delete.retention.ms";
@@ -83,8 +84,6 @@ public class StorageConfig extends MapConfig {
   static final String SIDE_INPUTS_PROCESSOR_FACTORY = STORE_PREFIX + "%s" + SIDE_INPUT_PROCESSOR_FACTORY_SUFFIX;
   static final String SIDE_INPUTS_PROCESSOR_SERIALIZED_INSTANCE =
       STORE_PREFIX + "%s.side.inputs.processor.serialized.instance";
-  static final String INMEMORY_KV_STORAGE_ENGINE_FACTORY =
-      "org.apache.samza.storage.kv.inmemory.InMemoryKeyValueStorageEngineFactory";
 
   // Internal config to clean storeDirs of a store on container start. This is used to benchmark bootstrap performance.
   static final String CLEAN_LOGGED_STOREDIRS_ON_START = STORE_PREFIX + "%s.clean.on.container.start";
diff --git a/samza-core/src/main/java/org/apache/samza/coordinator/MetadataResourceUtil.java b/samza-core/src/main/java/org/apache/samza/coordinator/MetadataResourceUtil.java
index 1050662..9e366ea 100644
--- a/samza-core/src/main/java/org/apache/samza/coordinator/MetadataResourceUtil.java
+++ b/samza-core/src/main/java/org/apache/samza/coordinator/MetadataResourceUtil.java
@@ -61,7 +61,7 @@ public class MetadataResourceUtil {
 
   @VisibleForTesting
   void createChangelogStreams() {
-    ChangelogStreamManager.createChangelogStreams(config, jobModel.maxChangeLogStreamPartitions);
+    ChangelogStreamManager.createChangelogStreams(config, jobModel.getMaxChangeLogStreamPartitions());
   }
 
   @VisibleForTesting
diff --git a/samza-core/src/main/java/org/apache/samza/serializers/CheckpointV2Serde.java b/samza-core/src/main/java/org/apache/samza/serializers/CheckpointV2Serde.java
index 10e6b3d..48aa564 100644
--- a/samza-core/src/main/java/org/apache/samza/serializers/CheckpointV2Serde.java
+++ b/samza-core/src/main/java/org/apache/samza/serializers/CheckpointV2Serde.java
@@ -21,15 +21,13 @@ package org.apache.samza.serializers;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import org.apache.samza.SamzaException;
-import org.apache.samza.checkpoint.CheckpointId;
 import org.apache.samza.checkpoint.CheckpointV2;
 import org.apache.samza.serializers.model.SamzaObjectMapper;
-import org.apache.samza.system.SystemStreamPartition;
 
 
 /**
- * The {@link Serde} for {@link CheckpointV2} which includes {@link CheckpointId}s, state checkpoint markers
- * and the input {@link SystemStreamPartition} offsets.
+ * The {@link Serde} for {@link CheckpointV2} which includes {@link org.apache.samza.checkpoint.CheckpointId}s,
+ * state checkpoint markers and the input {@link org.apache.samza.system.SystemStreamPartition} offsets.
  *
  * The overall payload is serde'd as JSON using {@link SamzaObjectMapper}, since the Samza classes cannot be directly
  * serialized by Jackson without {@link org.apache.samza.serializers.model.JsonCheckpointV2Mixin}.
diff --git a/samza-core/src/main/java/org/apache/samza/serializers/model/JsonCheckpointV2Mixin.java b/samza-core/src/main/java/org/apache/samza/serializers/model/JsonCheckpointV2Mixin.java
index 8e26745..9edd234 100644
--- a/samza-core/src/main/java/org/apache/samza/serializers/model/JsonCheckpointV2Mixin.java
+++ b/samza-core/src/main/java/org/apache/samza/serializers/model/JsonCheckpointV2Mixin.java
@@ -23,10 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import java.util.Map;
-import java.util.Set;
-import org.apache.samza.Partition;
 import org.apache.samza.checkpoint.CheckpointId;
-import org.apache.samza.container.TaskName;
 import org.apache.samza.system.SystemStreamPartition;
 
 @JsonIgnoreProperties(ignoreUnknown = true)
diff --git a/samza-core/src/main/java/org/apache/samza/storage/KafkaChangelogStateBackendFactory.java b/samza-core/src/main/java/org/apache/samza/storage/KafkaChangelogStateBackendFactory.java
index e3230f6..36ea5b5 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/KafkaChangelogStateBackendFactory.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/KafkaChangelogStateBackendFactory.java
@@ -202,7 +202,7 @@ public class KafkaChangelogStateBackendFactory implements StateBackendFactory {
     Map<SystemStreamPartition, String> changelogSSPToStore = new HashMap<>();
     changelogSystemStreams.forEach((storeName, systemStream) ->
         containerModel.getTasks().forEach((taskName, taskModel) -> {
-          if (TaskMode.Standby.equals(taskModel.getTaskMode())) {
+          if (!TaskMode.Standby.equals(taskModel.getTaskMode())) {
             changelogSSPToStore.put(new SystemStreamPartition(systemStream, taskModel.getChangelogPartition()),
                 storeName);
           }
diff --git a/samza-core/src/main/java/org/apache/samza/storage/NonTransactionalStateTaskRestoreManager.java b/samza-core/src/main/java/org/apache/samza/storage/NonTransactionalStateTaskRestoreManager.java
index 44dd59a..f648b06 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/NonTransactionalStateTaskRestoreManager.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/NonTransactionalStateTaskRestoreManager.java
@@ -28,9 +28,14 @@ import java.util.Set;
 import java.util.stream.Collectors;
 import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
+import org.apache.samza.checkpoint.Checkpoint;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.StorageConfig;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.JobContext;
 import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.serializers.Serde;
 import org.apache.samza.system.ChangelogSSPIterator;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.StreamSpec;
@@ -40,6 +45,7 @@ import org.apache.samza.system.SystemConsumer;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamMetadata;
 import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.MessageCollector;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.FileUtil;
 import org.slf4j.Logger;
@@ -64,44 +70,52 @@ class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
   private final Clock clock; // Clock value used to validate base-directories for staleness. See isLoggedStoreValid.
   private Map<SystemStream, String> changeLogOldestOffsets; // Map of changelog oldest known offsets
   private final Map<SystemStreamPartition, String> fileOffsets; // Map of offsets read from offset file indexed by changelog SSP
-  private final Map<String, SystemStream> changelogSystemStreams; // Map of change log system-streams indexed by store name
+  private final Map<String, SystemStream> storeChangelogs; // Map of change log system-streams indexed by store name
   private final SystemAdmins systemAdmins;
   private final File loggedStoreBaseDirectory;
   private final File nonLoggedStoreBaseDirectory;
   private final StreamMetadataCache streamMetadataCache;
   private final Map<String, SystemConsumer> storeConsumers;
   private final int maxChangeLogStreamPartitions;
-  private final StorageConfig storageConfig;
+  private final Config config;
   private final StorageManagerUtil storageManagerUtil;
 
   NonTransactionalStateTaskRestoreManager(
+      Set<String> storeNames,
+      JobContext jobContext,
+      ContainerContext containerContext,
       TaskModel taskModel,
-      Map<String, SystemStream> changelogSystemStreams,
-      Map<String, StorageEngine> taskStores,
+      Map<String, SystemStream> storeChangelogs,
+      Map<String, StorageEngine> inMemoryStores,
+      Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, Serde<Object>> serdes,
       SystemAdmins systemAdmins,
       StreamMetadataCache streamMetadataCache,
       Map<String, SystemConsumer> storeConsumers,
+      MetricsRegistry metricsRegistry,
+      MessageCollector messageCollector,
       int maxChangeLogStreamPartitions,
       File loggedStoreBaseDirectory,
       File nonLoggedStoreBaseDirectory,
       Config config,
       Clock clock) {
-    this.taskStores = taskStores;
     this.taskModel = taskModel;
     this.clock = clock;
-    this.changelogSystemStreams = changelogSystemStreams;
+    this.storeChangelogs = storeChangelogs;
     this.systemAdmins = systemAdmins;
     this.fileOffsets = new HashMap<>();
-    this.taskStoresToRestore = this.taskStores.entrySet().stream()
-        .filter(x -> x.getValue().getStoreProperties().isLoggedStore())
-        .map(x -> x.getKey()).collect(Collectors.toSet());
     this.loggedStoreBaseDirectory = loggedStoreBaseDirectory;
     this.nonLoggedStoreBaseDirectory = nonLoggedStoreBaseDirectory;
     this.streamMetadataCache = streamMetadataCache;
     this.storeConsumers = storeConsumers;
     this.maxChangeLogStreamPartitions = maxChangeLogStreamPartitions;
-    this.storageConfig = new StorageConfig(config);
+    this.config = config;
     this.storageManagerUtil = new StorageManagerUtil();
+    this.taskStores = createStoreEngines(storeNames, jobContext, containerContext,
+        storageEngineFactories, serdes, metricsRegistry, messageCollector, inMemoryStores);
+    this.taskStoresToRestore = this.taskStores.entrySet().stream()
+        .filter(x -> x.getValue().getStoreProperties().isLoggedStore())
+        .map(x -> x.getKey()).collect(Collectors.toSet());
   }
 
   /**
@@ -109,7 +123,7 @@ class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
    * and registers SSPs with the respective consumers.
    */
   @Override
-  public void init(Map<SystemStreamPartition, String> checkpointedChangelogSSPOffsets) {
+  public void init(Checkpoint checkpoint) {
     cleanBaseDirsAndReadOffsetFiles();
     setupBaseDirs();
     validateChangelogStreams();
@@ -124,7 +138,7 @@ class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
    */
   private void cleanBaseDirsAndReadOffsetFiles() {
     LOG.debug("Cleaning base directories for stores.");
-
+    StorageConfig storageConfig = new StorageConfig(config);
     FileUtil fileUtil = new FileUtil();
     taskStores.forEach((storeName, storageEngine) -> {
       if (!storageEngine.getStoreProperties().isLoggedStore()) {
@@ -147,7 +161,7 @@ class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
           fileUtil.rm(loggedStorePartitionDir);
         } else {
 
-          SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStreams.get(storeName), taskModel.getChangelogPartition());
+          SystemStreamPartition changelogSSP = new SystemStreamPartition(storeChangelogs.get(storeName), taskModel.getChangelogPartition());
           Map<SystemStreamPartition, String> offset =
               storageManagerUtil.readOffsetFile(loggedStorePartitionDir, Collections.singleton(changelogSSP), false);
           LOG.info("Read offset {} for the store {} from logged storage partition directory {}", offset, storeName, loggedStorePartitionDir);
@@ -170,10 +184,10 @@ class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
    * @return true if the logged store is valid, false otherwise.
    */
   private boolean isLoggedStoreValid(String storeName, File loggedStoreDir) {
-    long changeLogDeleteRetentionInMs = storageConfig.getChangeLogDeleteRetentionInMs(storeName);
+    long changeLogDeleteRetentionInMs = new StorageConfig(config).getChangeLogDeleteRetentionInMs(storeName);
 
-    if (changelogSystemStreams.containsKey(storeName)) {
-      SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStreams.get(storeName), taskModel.getChangelogPartition());
+    if (storeChangelogs.containsKey(storeName)) {
+      SystemStreamPartition changelogSSP = new SystemStreamPartition(storeChangelogs.get(storeName), taskModel.getChangelogPartition());
       return this.taskStores.get(storeName).getStoreProperties().isPersistedToDisk()
           && storageManagerUtil.isOffsetFileValid(loggedStoreDir, Collections.singleton(changelogSSP), false)
           && !storageManagerUtil.isStaleStore(loggedStoreDir, changeLogDeleteRetentionInMs, clock.currentTimeMillis(), false);
@@ -213,9 +227,9 @@ class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
    *  Validates each changelog system-stream with its respective SystemAdmin.
    */
   private void validateChangelogStreams() {
-    LOG.info("Validating change log streams: " + changelogSystemStreams);
+    LOG.info("Validating change log streams: " + storeChangelogs);
 
-    for (SystemStream changelogSystemStream : changelogSystemStreams.values()) {
+    for (SystemStream changelogSystemStream : storeChangelogs.values()) {
       SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(changelogSystemStream.getSystem());
       StreamSpec changelogSpec =
           StreamSpec.createChangeLogStreamSpec(changelogSystemStream.getStream(), changelogSystemStream.getSystem(),
@@ -232,7 +246,7 @@ class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
 
     Map<SystemStream, SystemStreamMetadata> changeLogMetadata = JavaConverters.mapAsJavaMapConverter(
         streamMetadataCache.getStreamMetadata(
-            JavaConverters.asScalaSetConverter(new HashSet<>(changelogSystemStreams.values())).asScala().toSet(),
+            JavaConverters.asScalaSetConverter(new HashSet<>(storeChangelogs.values())).asScala().toSet(),
             false)).asJava();
 
     LOG.info("Got change log stream metadata: {}", changeLogMetadata);
@@ -267,7 +281,7 @@ class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
    */
   private void registerStartingOffsets() {
 
-    for (Map.Entry<String, SystemStream> changelogSystemStreamEntry : changelogSystemStreams.entrySet()) {
+    for (Map.Entry<String, SystemStream> changelogSystemStreamEntry : storeChangelogs.entrySet()) {
       SystemStreamPartition systemStreamPartition =
           new SystemStreamPartition(changelogSystemStreamEntry.getValue(), taskModel.getChangelogPartition());
       SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(changelogSystemStreamEntry.getValue().getSystem());
@@ -312,6 +326,28 @@ class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
     return storageManagerUtil.getStartingOffset(systemStreamPartition, systemAdmin, fileOffset, oldestOffset);
   }
 
+  // TODO dchen put this in common code path for transactional and non-transactional
+  private Map<String, StorageEngine> createStoreEngines(Set<String> storeNames, JobContext jobContext,
+      ContainerContext containerContext, Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, Serde<Object>> serdes, MetricsRegistry metricsRegistry,
+      MessageCollector messageCollector, Map<String, StorageEngine> nonPersistedStores) {
+    Map<String, StorageEngine> storageEngines = new HashMap<>();
+    // Put non persisted stores
+    nonPersistedStores.forEach(storageEngines::put);
+    // Create persisted stores
+    storeNames.forEach(storeName -> {
+      boolean isLogged = this.storeChangelogs.containsKey(storeName);
+      File storeBaseDir = isLogged ? this.loggedStoreBaseDirectory : this.nonLoggedStoreBaseDirectory;
+      File storeDirectory = storageManagerUtil.getTaskStoreDir(storeBaseDir, storeName, taskModel.getTaskName(),
+          taskModel.getTaskMode());
+      StorageEngine engine = ContainerStorageManager.createStore(storeName, storeDirectory, taskModel, jobContext, containerContext,
+          storageEngineFactories, serdes, metricsRegistry, messageCollector,
+          StorageEngineFactory.StoreMode.BulkLoad, this.storeChangelogs, this.config);
+      storageEngines.put(storeName, engine);
+    });
+    return storageEngines;
+  }
+
   /**
    * Restore each store in taskStoresToRestore sequentially
    */
@@ -320,7 +356,7 @@ class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
     for (String storeName : taskStoresToRestore) {
       LOG.info("Restoring store: {} for task: {}", storeName, taskModel.getTaskName());
       SystemConsumer systemConsumer = storeConsumers.get(storeName);
-      SystemStream systemStream = changelogSystemStreams.get(storeName);
+      SystemStream systemStream = storeChangelogs.get(storeName);
       SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(systemStream.getSystem());
       ChangelogSSPIterator changelogSSPIterator = new ChangelogSSPIterator(systemConsumer,
           new SystemStreamPartition(systemStream, taskModel.getChangelogPartition()), null, systemAdmin, false);
@@ -333,7 +369,7 @@ class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
    * Stop only persistent stores. In case of certain stores and store mode (such as RocksDB), this
    * can invoke compaction.
    */
-  public void stopPersistentStores() {
+  public void close() {
 
     Map<String, StorageEngine> persistentStores = this.taskStores.entrySet().stream().filter(e -> {
       return e.getValue().getStoreProperties().isPersistedToDisk();
diff --git a/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java b/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
index 9d1896e..0925949 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
@@ -20,13 +20,10 @@
 package org.apache.samza.storage;
 
 import java.io.File;
-import java.time.Duration;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
 import org.apache.samza.SamzaException;
 import org.apache.samza.checkpoint.CheckpointManager;
 import org.apache.samza.config.Config;
@@ -46,12 +43,10 @@ import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.serializers.SerdeFactory;
-import org.apache.samza.system.SSPMetadataCache;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemFactory;
 import org.apache.samza.system.SystemStream;
-import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.CoordinatorStreamUtil;
 import org.apache.samza.util.ReflectionUtil;
@@ -220,6 +215,7 @@ public class StorageRecovery {
    */
   @SuppressWarnings("rawtypes")
   private void getContainerStorageManagers() {
+    String factoryClass = new StorageConfig(jobConfig).getStateBackendRestoreFactory();
     Clock clock = SystemClock.instance();
     StreamMetadataCache streamMetadataCache = new StreamMetadataCache(systemAdmins, 5000, clock);
     // don't worry about prefetching for this; looks like the tool doesn't flush to offset files anyways
@@ -230,18 +226,11 @@ public class StorageRecovery {
     for (ContainerModel containerModel : containers.values()) {
       ContainerContext containerContext = new ContainerContextImpl(containerModel, new MetricsRegistryMap());
 
-      Set<SystemStreamPartition> changelogSSPs = changeLogSystemStreams.values().stream()
-          .flatMap(ss -> containerModel.getTasks().values().stream()
-              .map(tm -> new SystemStreamPartition(ss, tm.getChangelogPartition())))
-          .collect(Collectors.toSet());
-      SSPMetadataCache sspMetadataCache = new SSPMetadataCache(systemAdmins, Duration.ofMillis(5000), clock, changelogSSPs);
-
       ContainerStorageManager containerStorageManager =
           new ContainerStorageManager(
               checkpointManager,
               containerModel,
               streamMetadataCache,
-              sspMetadataCache,
               systemAdmins,
               changeLogSystemStreams,
               new HashMap<>(),
@@ -253,10 +242,10 @@ public class StorageRecovery {
               new SamzaContainerMetrics(containerModel.getId(), new MetricsRegistryMap(), ""),
               JobContextImpl.fromConfigWithDefaults(jobConfig, jobModel),
               containerContext,
+              ReflectionUtil.getObj(factoryClass, StateBackendFactory.class),
               new HashMap<>(),
               storeBaseDir,
               storeBaseDir,
-              maxPartitionNumber,
               null,
               new SystemClock());
       this.containerStorageManagers.put(containerModel.getId(), containerStorageManager);
diff --git a/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManager.java b/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManager.java
deleted file mode 100644
index f60e148..0000000
--- a/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManager.java
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.storage;
-
-import java.util.Map;
-import org.apache.samza.system.SystemStreamPartition;
-
-
-/**
- * The helper interface restores task state.
- */
-public interface TaskRestoreManager {
-
-  /**
-   * Init state resources such as file directories.
-   */
-  void init(Map<SystemStreamPartition, String> checkpointedChangelogSSPOffsets);
-
-  /**
-   * Restore state from checkpoints, state snapshots and changelog.
-   * Currently, store restoration happens on a separate thread pool within {@code ContainerStorageManager}. In case of
-   * interrupt/shutdown signals from {@code SamzaContainer}, {@code ContainerStorageManager} may interrupt the restore
-   * thread.
-   *
-   * Note: Typically, interrupt signals don't bubble up as {@link InterruptedException} unless the restore thread is
-   * waiting on IO/network. In case of busy looping, implementors are expected to check the interrupt status of the
-   * thread periodically and shutdown gracefully before throwing {@link InterruptedException} upstream.
-   * {@code SamzaContainer} will not wait for clean up and the interrupt signal is the best effort by the container
-   * to notify that its shutting down.
-   */
-  void restore() throws InterruptedException;
-
-  /**
-   * Stop all persistent stores after restoring.
-   */
-  void stopPersistentStores();
-
-}
diff --git a/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManagerFactory.java b/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManagerFactory.java
deleted file mode 100644
index 9da9bc0..0000000
--- a/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManagerFactory.java
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.storage;
-
-import java.io.File;
-import java.util.Map;
-import org.apache.samza.config.Config;
-import org.apache.samza.config.TaskConfig;
-import org.apache.samza.job.model.TaskModel;
-import org.apache.samza.system.SSPMetadataCache;
-import org.apache.samza.system.StreamMetadataCache;
-import org.apache.samza.system.SystemAdmins;
-import org.apache.samza.system.SystemConsumer;
-import org.apache.samza.system.SystemStream;
-import org.apache.samza.util.Clock;
-
-/**
- * Factory class to create {@link TaskRestoreManager}.
- */
-class TaskRestoreManagerFactory {
-
-  public static TaskRestoreManager create(
-      TaskModel taskModel,
-      Map<String, SystemStream> changelogSystemStreams,
-      Map<String, StorageEngine> taskStores,
-      SystemAdmins systemAdmins,
-      StreamMetadataCache streamMetadataCache,
-      SSPMetadataCache sspMetadataCache,
-      Map<String, SystemConsumer> storeConsumers,
-      int maxChangeLogStreamPartitions,
-      File loggedStoreBaseDirectory,
-      File nonLoggedStoreBaseDirectory,
-      Config config,
-      Clock clock) {
-
-    if (new TaskConfig(config).getTransactionalStateRestoreEnabled()) {
-      // Create checkpoint-snapshot based state restoration which is transactional.
-      return new TransactionalStateTaskRestoreManager(
-          taskModel,
-          taskStores,
-          changelogSystemStreams,
-          systemAdmins,
-          storeConsumers,
-          sspMetadataCache,
-          loggedStoreBaseDirectory,
-          nonLoggedStoreBaseDirectory,
-          config,
-          clock
-      );
-    } else {
-      // Create legacy offset-file based state restoration which is NOT transactional.
-      return new NonTransactionalStateTaskRestoreManager(
-          taskModel,
-          changelogSystemStreams,
-          taskStores,
-          systemAdmins,
-          streamMetadataCache,
-          storeConsumers,
-          maxChangeLogStreamPartitions,
-          loggedStoreBaseDirectory,
-          nonLoggedStoreBaseDirectory,
-          config,
-          clock);
-    }
-  }
-}
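
With TaskRestoreManagerFactory removed, restore-manager selection is driven by configuration through a StateBackendFactory instead. A minimal sketch following the pattern StorageRecovery uses above; the helper name is illustrative:

    // Resolves the configured restore backend factory via reflection.
    static StateBackendFactory restoreBackendFactory(Config config) {
      String factoryClass = new StorageConfig(config).getStateBackendRestoreFactory();
      return ReflectionUtil.getObj(factoryClass, StateBackendFactory.class);
    }
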
diff --git a/samza-core/src/main/java/org/apache/samza/storage/TaskSideInputStorageManager.java b/samza-core/src/main/java/org/apache/samza/storage/TaskSideInputStorageManager.java
index c93e0b3..f407b1a 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/TaskSideInputStorageManager.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/TaskSideInputStorageManager.java
@@ -211,6 +211,7 @@ public class TaskSideInputStorageManager {
 
   private void validateStoreConfiguration(Map<String, StorageEngine> stores) {
     stores.forEach((storeName, storageEngine) -> {
+      // Ensure that side input stores are NOT logged (they are durable)
       if (storageEngine.getStoreProperties().isLoggedStore()) {
         throw new SamzaException(
             String.format("Cannot configure both side inputs and a changelog for store: %s.", storeName));
diff --git a/samza-core/src/main/java/org/apache/samza/storage/TransactionalStateTaskRestoreManager.java b/samza-core/src/main/java/org/apache/samza/storage/TransactionalStateTaskRestoreManager.java
index 4b6ac1f..1329a0d 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/TransactionalStateTaskRestoreManager.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/TransactionalStateTaskRestoreManager.java
@@ -31,24 +31,35 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
-import org.apache.samza.checkpoint.CheckpointedChangelogOffset;
+import org.apache.samza.checkpoint.Checkpoint;
+import org.apache.samza.checkpoint.CheckpointId;
+import org.apache.samza.checkpoint.CheckpointV1;
+import org.apache.samza.checkpoint.CheckpointV2;
+import org.apache.samza.checkpoint.kafka.KafkaChangelogSSPOffset;
+import org.apache.samza.checkpoint.kafka.KafkaStateCheckpointMarker;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.StorageConfig;
 import org.apache.samza.config.TaskConfig;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.JobContext;
 import org.apache.samza.job.model.TaskMode;
 import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.serializers.Serde;
 import org.apache.samza.system.ChangelogSSPIterator;
 import org.apache.samza.system.SSPMetadataCache;
 import org.apache.samza.system.SystemAdmin;
 import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemConsumer;
 import org.apache.samza.system.SystemStream;
-import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.MessageCollector;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.FileUtil;
 import org.slf4j.Logger;
@@ -64,7 +75,7 @@ public class TransactionalStateTaskRestoreManager implements TaskRestoreManager
   private final Map<String, StorageEngine> storeEngines; // store name to storage engines
   private final Map<String, SystemStream> storeChangelogs; // store name to changelog system stream
   private final SystemAdmins systemAdmins;
-  private final Map<String, SystemConsumer> storeConsumers;
+  private final Map<String, SystemConsumer> storeConsumers; // store name to system consumer
   private final SSPMetadataCache sspMetadataCache;
   private final File loggedStoreBaseDirectory;
   private final File nonLoggedStoreBaseDirectory;
@@ -77,18 +88,24 @@ public class TransactionalStateTaskRestoreManager implements TaskRestoreManager
   private Map<SystemStreamPartition, SystemStreamPartitionMetadata> currentChangelogOffsets;
 
   public TransactionalStateTaskRestoreManager(
+      Set<String> storeNames,  // non-side input stores
+      JobContext jobContext,
+      ContainerContext containerContext,
       TaskModel taskModel,
-      Map<String, StorageEngine> storeEngines,
       Map<String, SystemStream> storeChangelogs,
+      Map<String, StorageEngine> inMemoryStores, // in memory stores to be mutated during restore
+      Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, Serde<Object>> serdes,
       SystemAdmins systemAdmins,
       Map<String, SystemConsumer> storeConsumers,
+      MetricsRegistry metricsRegistry,
+      MessageCollector messageCollector,
       SSPMetadataCache sspMetadataCache,
       File loggedStoreBaseDirectory,
       File nonLoggedStoreBaseDirectory,
       Config config,
       Clock clock) {
     this.taskModel = taskModel;
-    this.storeEngines = storeEngines;
     this.storeChangelogs = storeChangelogs;
     this.systemAdmins = systemAdmins;
     this.storeConsumers = storeConsumers;
@@ -101,14 +118,17 @@ public class TransactionalStateTaskRestoreManager implements TaskRestoreManager
     this.clock = clock;
     this.storageManagerUtil = new StorageManagerUtil();
     this.fileUtil = new FileUtil();
+    this.storeEngines = createStoreEngines(storeNames, jobContext, containerContext,
+        storageEngineFactories, serdes, metricsRegistry, messageCollector, inMemoryStores);
   }
 
   @Override
-  public void init(Map<SystemStreamPartition, String> checkpointedChangelogOffsets) {
+  public void init(Checkpoint checkpoint) {
+    Map<String, KafkaStateCheckpointMarker> storeStateCheckpointMarkers = getCheckpointedChangelogOffsets(checkpoint);
     currentChangelogOffsets = getCurrentChangelogOffsets(taskModel, storeChangelogs, sspMetadataCache);
 
     this.storeActions = getStoreActions(taskModel, storeEngines, storeChangelogs,
-        checkpointedChangelogOffsets, currentChangelogOffsets, systemAdmins, storageManagerUtil,
+        storeStateCheckpointMarkers, getCheckpointId(checkpoint), currentChangelogOffsets, systemAdmins, storageManagerUtil,
         loggedStoreBaseDirectory, nonLoggedStoreBaseDirectory, config, clock);
 
     setupStoreDirs(taskModel, storeEngines, storeActions, storageManagerUtil, fileUtil,
@@ -143,7 +163,7 @@ public class TransactionalStateTaskRestoreManager implements TaskRestoreManager
    * Stop only persistent stores. In case of certain stores and store mode (such as RocksDB), this
    * can invoke compaction. Persisted stores are recreated in read-write mode in {@link ContainerStorageManager}.
    */
-  public void stopPersistentStores() {
+  public void close() {
     TaskName taskName = taskModel.getTaskName();
     storeEngines.forEach((storeName, storeEngine) -> {
       if (storeEngine.getStoreProperties().isPersistedToDisk())
@@ -152,6 +172,27 @@ public class TransactionalStateTaskRestoreManager implements TaskRestoreManager
     });
   }
 
+  private Map<String, StorageEngine> createStoreEngines(Set<String> storeNames, JobContext jobContext,
+      ContainerContext containerContext, Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, Serde<Object>> serdes, MetricsRegistry metricsRegistry,
+      MessageCollector messageCollector, Map<String, StorageEngine> nonPersistedStores) {
+    Map<String, StorageEngine> storageEngines = new HashMap<>();
+    // Put non persisted stores
+    nonPersistedStores.forEach(storageEngines::put);
+    // Create persisted stores
+    storeNames.forEach(storeName -> {
+      boolean isLogged = this.storeChangelogs.containsKey(storeName);
+      File storeBaseDir = isLogged ? this.loggedStoreBaseDirectory : this.nonLoggedStoreBaseDirectory;
+      File storeDirectory = storageManagerUtil.getTaskStoreDir(storeBaseDir, storeName, taskModel.getTaskName(),
+          taskModel.getTaskMode());
+      StorageEngine engine = ContainerStorageManager.createStore(storeName, storeDirectory, taskModel, jobContext, containerContext,
+          storageEngineFactories, serdes, metricsRegistry, messageCollector,
+          StorageEngineFactory.StoreMode.BulkLoad, this.storeChangelogs, this.config);
+      storageEngines.put(storeName, engine);
+    });
+    return storageEngines;
+  }
+
   /**
    * Get offset metadata for each changelog SSP for this task. A task may have multiple changelog streams
    * (e.g., for different stores), but will have the same partition for all of them.
@@ -194,7 +235,8 @@ public class TransactionalStateTaskRestoreManager implements TaskRestoreManager
       TaskModel taskModel,
       Map<String, StorageEngine> storeEngines,
       Map<String, SystemStream> storeChangelogs,
-      Map<SystemStreamPartition, String> checkpointedChangelogOffsets,
+      Map<String, KafkaStateCheckpointMarker> kafkaStateCheckpointMarkers,
+      CheckpointId checkpointId,
       Map<SystemStreamPartition, SystemStreamPartitionMetadata> currentChangelogOffsets,
       SystemAdmins systemAdmins,
       StorageManagerUtil storageManagerUtil,
@@ -236,15 +278,12 @@ public class TransactionalStateTaskRestoreManager implements TaskRestoreManager
       String oldestOffset = changelogSSPMetadata.getOldestOffset();
       String newestOffset = changelogSSPMetadata.getNewestOffset();
 
-      String checkpointMessage = checkpointedChangelogOffsets.get(changelogSSP);
-      String checkpointedOffset = null;  // can be null if no message, or message has null offset
-      long timeSinceLastCheckpointInMs = Long.MAX_VALUE;
-      if (StringUtils.isNotBlank(checkpointMessage)) {
-        CheckpointedChangelogOffset checkpointedChangelogOffset = CheckpointedChangelogOffset.fromString(checkpointMessage);
-        checkpointedOffset = checkpointedChangelogOffset.getOffset();
-        timeSinceLastCheckpointInMs = System.currentTimeMillis() -
-            checkpointedChangelogOffset.getCheckpointId().getMillis();
+      String checkpointedOffset = null; // can be null if no message, or message has null offset
+      if (kafkaStateCheckpointMarkers.containsKey(storeName) &&
+          StringUtils.isNotBlank(kafkaStateCheckpointMarkers.get(storeName).getChangelogOffset())) {
+        checkpointedOffset = kafkaStateCheckpointMarkers.get(storeName).getChangelogOffset();
       }
+      long timeSinceLastCheckpointInMs = checkpointId == null ? Long.MAX_VALUE : System.currentTimeMillis() - checkpointId.getMillis();
 
       // if the clean.store.start config is set, delete current and checkpoint dirs, restore from oldest offset to checkpointed
       if (storageEngine.getStoreProperties().isPersistedToDisk() && new StorageConfig(
@@ -557,6 +596,60 @@ public class TransactionalStateTaskRestoreManager implements TaskRestoreManager
     }
   }
 
+  private Map<String, KafkaStateCheckpointMarker> getCheckpointedChangelogOffsets(Checkpoint checkpoint) {
+    Map<String, KafkaStateCheckpointMarker> checkpointedChangelogOffsets = new HashMap<>();
+    if (checkpoint == null) return checkpointedChangelogOffsets;
+
+    if (checkpoint instanceof CheckpointV2) {
+      Map<String, Map<String, String>> factoryStoreSCMs = ((CheckpointV2) checkpoint).getStateCheckpointMarkers();
+      if (factoryStoreSCMs.containsKey(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME)) {
+        factoryStoreSCMs.get(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME)
+            .forEach((storeName, scmString) -> {
+              KafkaStateCheckpointMarker kafkaSCM = KafkaStateCheckpointMarker.deserialize(scmString);
+              checkpointedChangelogOffsets.put(storeName, kafkaSCM);
+            });
+      } // skip the non-KafkaStateCheckpointMarkers
+    } else if (checkpoint instanceof CheckpointV1) {
+      // If the checkpoint v1 is used, we need to fetch the changelog SSPs in the inputOffsets in order to get the
+      // store offset.
+      Map<SystemStreamPartition, String> checkpointedOffsets = ((CheckpointV1) checkpoint).getOffsets();
+      storeChangelogs.forEach((storeName, systemStream) -> {
+        Partition changelogPartition = taskModel.getChangelogPartition();
+        SystemStreamPartition storeChangelogSSP = new SystemStreamPartition(systemStream, changelogPartition);
+        String checkpointedOffset = checkpointedOffsets.get(storeChangelogSSP);
+        if (StringUtils.isNotBlank(checkpointedOffset)) {
+          KafkaChangelogSSPOffset kafkaChangelogSSPOffset = KafkaChangelogSSPOffset.fromString(checkpointedOffset);
+          KafkaStateCheckpointMarker marker = new KafkaStateCheckpointMarker(
+              storeChangelogSSP, kafkaChangelogSSPOffset.getChangelogOffset());
+          checkpointedChangelogOffsets.put(storeName, marker);
+        }
+      });
+    } else {
+      throw new SamzaException("Unsupported checkpoint version: " + checkpoint.getVersion());
+    }
+
+    return checkpointedChangelogOffsets;
+  }
+
+  private CheckpointId getCheckpointId(Checkpoint checkpoint) {
+    if (checkpoint == null) return null;
+    if (checkpoint instanceof CheckpointV1) {
+      for (Map.Entry<String, SystemStream> storeNameSystemStream : storeChangelogs.entrySet()) {
+        SystemStreamPartition storeChangelogSSP = new SystemStreamPartition(storeNameSystemStream.getValue(), taskModel.getChangelogPartition());
+        String checkpointMessage = checkpoint.getOffsets().get(storeChangelogSSP);
+        if (StringUtils.isNotBlank(checkpointMessage)) {
+          KafkaChangelogSSPOffset kafkaStateChanglogOffset = KafkaChangelogSSPOffset.fromString(checkpointMessage);
+          return kafkaStateChanglogOffset.getCheckpointId();
+        }
+      }
+    } else if (checkpoint instanceof CheckpointV2) {
+      return ((CheckpointV2) checkpoint).getCheckpointId();
+    } else {
+      throw new SamzaException("Unsupported checkpoint version: " + checkpoint.getVersion());
+    }
+    return null;
+  }
+
   @VisibleForTesting
   static class StoreActions {
     final Map<String, File> storeDirsToRetain;
diff --git a/samza-core/src/main/java/org/apache/samza/system/SystemAdmins.java b/samza-core/src/main/java/org/apache/samza/system/SystemAdmins.java
index 2ca81ca..987e0c6 100644
--- a/samza-core/src/main/java/org/apache/samza/system/SystemAdmins.java
+++ b/samza-core/src/main/java/org/apache/samza/system/SystemAdmins.java
@@ -43,6 +43,10 @@ public class SystemAdmins {
     this.systemAdminMap = systemConfig.getSystemAdmins(adminLabel);
   }
 
+  public SystemAdmins(Map<String, SystemAdmin> systemAdminMap) {
+    this.systemAdminMap = systemAdminMap;
+  }
+
   /**
    * Creates a new instance of {@link SystemAdmins} with an empty admin mapping.
    * @return New empty instance of {@link SystemAdmins}
@@ -70,6 +74,10 @@ public class SystemAdmins {
     return systemAdminMap.get(systemName);
   }
 
+  public Map<String, SystemAdmin> getSystemAdmins() {
+    return systemAdminMap;
+  }
+
   public Set<String> getSystemNames() {
     return systemAdminMap.keySet();
   }
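
A minimal sketch of the new map-based constructor and getter, assuming the SystemAdmin instances are created elsewhere; the map contents are illustrative:

    // Builds SystemAdmins directly from pre-constructed admins rather than from config.
    Map<String, SystemAdmin> adminMap = new HashMap<>();
    adminMap.put("kafka", kafkaSystemAdmin);  // kafkaSystemAdmin assumed to exist already
    SystemAdmins systemAdmins = new SystemAdmins(adminMap);
    Map<String, SystemAdmin> allAdmins = systemAdmins.getSystemAdmins();
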
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
index c333748..6838ec2 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
+++ b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
@@ -45,6 +45,7 @@ import org.apache.samza.SamzaException;
 import org.apache.samza.checkpoint.Checkpoint;
 import org.apache.samza.checkpoint.CheckpointManager;
 import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.StorageConfig;
 import org.apache.samza.config.TaskConfig;
 import org.apache.samza.container.RunLoop;
@@ -65,7 +66,6 @@ import org.apache.samza.serializers.SerdeManager;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.storage.kv.KeyValueStore;
 import org.apache.samza.system.IncomingMessageEnvelope;
-import org.apache.samza.system.SSPMetadataCache;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemConsumer;
@@ -79,6 +79,7 @@ import org.apache.samza.system.chooser.DefaultChooser;
 import org.apache.samza.system.chooser.MessageChooser;
 import org.apache.samza.system.chooser.RoundRobinChooserFactory;
 import org.apache.samza.table.utils.SerdeUtils;
+import org.apache.samza.task.MessageCollector;
 import org.apache.samza.task.TaskInstanceCollector;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.ReflectionUtil;
@@ -116,11 +117,14 @@ public class ContainerStorageManager {
   private static final int SIDE_INPUT_CHECK_TIMEOUT_SECONDS = 10;
   private static final int SIDE_INPUT_SHUTDOWN_TIMEOUT_SECONDS = 60;
 
+  private static final int RESTORE_THREAD_POOL_SHUTDOWN_TIMEOUT_SECONDS = 60;
+
   /** Maps containing relevant per-task objects */
-  private final Map<TaskName, Map<String, StorageEngine>> taskStores;
   private final Map<TaskName, TaskRestoreManager> taskRestoreManagers;
   private final Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics;
   private final Map<TaskName, TaskInstanceCollector> taskInstanceCollectors;
+  private final Map<TaskName, Map<String, StorageEngine>> inMemoryStores; // subset of taskStores after #start()
+  private Map<TaskName, Map<String, StorageEngine>> taskStores; // Will be available after #start()
 
   private final Map<String, SystemConsumer> storeConsumers; // Mapping from store name to SystemConsumers
   private final Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories; // Map of storageEngineFactories indexed by store name
@@ -129,7 +133,6 @@ public class ContainerStorageManager {
   private final SystemAdmins systemAdmins;
 
   private final StreamMetadataCache streamMetadataCache;
-  private final SSPMetadataCache sspMetadataCache;
   private final SamzaContainerMetrics samzaContainerMetrics;
 
   private final CheckpointManager checkpointManager;
@@ -142,13 +145,12 @@ public class ContainerStorageManager {
   private final File nonLoggedStoreBaseDirectory;
   private final Set<Path> storeDirectoryPaths; // the set of store directory paths, used by SamzaContainer to initialize its disk-space-monitor
 
-  private final int parallelRestoreThreadPoolSize;
-  private final int maxChangeLogStreamPartitions; // The partition count of each changelog-stream topic. This is used for validating changelog streams before restoring.
-
   /* Sideinput related parameters */
   private final boolean hasSideInputs;
+  private final Map<TaskName, Map<String, StorageEngine>> sideInputStores; // subset of taskStores after #start()
   // side inputs indexed first by task, then store name
   private final Map<TaskName, Map<String, Set<SystemStreamPartition>>> taskSideInputStoreSSPs;
+  private final Set<String> sideInputStoreNames;
   private final Map<SystemStreamPartition, TaskSideInputHandler> sspSideInputHandlers;
   private SystemConsumers sideInputSystemConsumers;
   private volatile Map<TaskName, CountDownLatch> sideInputTaskLatches; // Used by the sideInput-read thread to signal to the main thread
@@ -157,17 +159,19 @@ public class ContainerStorageManager {
 
   private final ExecutorService sideInputsExecutor = Executors.newSingleThreadExecutor(
       new ThreadFactoryBuilder().setDaemon(true).setNameFormat(SIDEINPUTS_THREAD_NAME).build());
+  private final ExecutorService restoreExecutor;
 
   private volatile Throwable sideInputException = null;
 
   private final Config config;
   private final StorageManagerUtil storageManagerUtil = new StorageManagerUtil();
 
+  private boolean isStarted = false;
+
   public ContainerStorageManager(
       CheckpointManager checkpointManager,
       ContainerModel containerModel,
       StreamMetadataCache streamMetadataCache,
-      SSPMetadataCache sspMetadataCache,
       SystemAdmins systemAdmins,
       Map<String, SystemStream> changelogSystemStreams,
       Map<String, Set<SystemStream>> sideInputSystemStreams,
@@ -179,22 +183,22 @@ public class ContainerStorageManager {
       SamzaContainerMetrics samzaContainerMetrics,
       JobContext jobContext,
       ContainerContext containerContext,
+      StateBackendFactory stateBackendFactory,
       Map<TaskName, TaskInstanceCollector> taskInstanceCollectors,
       File loggedStoreBaseDirectory,
       File nonLoggedStoreBaseDirectory,
-      int maxChangeLogStreamPartitions,
       SerdeManager serdeManager,
       Clock clock) {
     this.checkpointManager = checkpointManager;
     this.containerModel = containerModel;
     this.taskSideInputStoreSSPs = getTaskSideInputSSPs(containerModel, sideInputSystemStreams);
+    this.sideInputStoreNames = sideInputSystemStreams.keySet();
     this.sideInputTaskLatches = new HashMap<>();
     this.hasSideInputs = this.taskSideInputStoreSSPs.values().stream()
         .flatMap(m -> m.values().stream())
         .flatMap(Collection::stream)
         .findAny()
         .isPresent();
-    this.sspMetadataCache = sspMetadataCache;
     this.changelogSystemStreams = getChangelogSystemStreams(containerModel, changelogSystemStreams); // handling standby tasks
 
     LOG.info("Starting with changelogSystemStreams = {} taskSideInputStoreSSPs = {}", this.changelogSystemStreams, this.taskSideInputStoreSSPs);
@@ -226,15 +230,22 @@ public class ContainerStorageManager {
     // initializing the set of store directory paths
     this.storeDirectoryPaths = new HashSet<>();
 
-    // Setting the restore thread pool size equal to the number of taskInstances
-    this.parallelRestoreThreadPoolSize = containerModel.getTasks().size();
-
-    this.maxChangeLogStreamPartitions = maxChangeLogStreamPartitions;
     this.streamMetadataCache = streamMetadataCache;
     this.systemAdmins = systemAdmins;
 
-    // create taskStores for all tasks in the containerModel and each store in storageEngineFactories
-    this.taskStores = createTaskStores(containerModel, jobContext, containerContext, storageEngineFactories, serdes, taskInstanceMetrics, taskInstanceCollectors);
+    // create side input task stores for all tasks in the containerModel and each side input store
+    this.sideInputStores = createTaskStores(sideInputStoreNames, containerModel, jobContext, containerContext,
+        storageEngineFactories, serdes, taskInstanceMetrics, taskInstanceCollectors);
+    StorageConfig storageConfig = new StorageConfig(config);
+    Set<String> inMemoryStoreNames = storageEngineFactories.keySet().stream()
+        .filter(storeName -> {
+          Optional<String> storeFactory = storageConfig.getStorageFactoryClassName(storeName);
+          return storeFactory.isPresent() && storeFactory.get()
+              .equals(StorageConfig.INMEMORY_KV_STORAGE_ENGINE_FACTORY);
+        })
+        .collect(Collectors.toSet());
+    this.inMemoryStores = createTaskStores(inMemoryStoreNames,
+        this.containerModel, jobContext, containerContext, storageEngineFactories, serdes, taskInstanceMetrics, taskInstanceCollectors);
 
     Set<String> containerChangelogSystems = this.changelogSystemStreams.values().stream()
         .map(SystemStream::getSystem)
@@ -245,8 +256,19 @@ public class ContainerStorageManager {
         containerChangelogSystems, systemFactories, config, this.samzaContainerMetrics.registry());
     this.storeConsumers = createStoreIndexedMap(this.changelogSystemStreams, storeSystemConsumers);
 
+    // TODO HIGH dchen tune based on observed concurrency
+    JobConfig jobConfig = new JobConfig(config);
+    int restoreThreadPoolSize =
+        Math.min(
+            Math.max(containerModel.getTasks().size() * 2, jobConfig.getRestoreThreadPoolSize()),
+            jobConfig.getRestoreThreadPoolMaxSize()
+        );
+    this.restoreExecutor = Executors.newFixedThreadPool(restoreThreadPoolSize,
+        new ThreadFactoryBuilder().setDaemon(true).setNameFormat(RESTORE_THREAD_NAME).build());
+
     // creating task restore managers
-    this.taskRestoreManagers = createTaskRestoreManagers(systemAdmins, clock, this.samzaContainerMetrics);
+    this.taskRestoreManagers = createTaskRestoreManagers(stateBackendFactory, clock,
+        this.samzaContainerMetrics);
 
     this.sspSideInputHandlers = createSideInputHandlers(clock);
 
@@ -311,7 +333,8 @@ public class ContainerStorageManager {
    * @param changelogSystemStreams the passed in set of changelogSystemStreams
    * @return A map of changeLogSSP to storeName across all tasks, assuming no two stores have the same changelogSSP
    */
-  private Map<String, SystemStream> getChangelogSystemStreams(ContainerModel containerModel, Map<String, SystemStream> changelogSystemStreams) {
+  private Map<String, SystemStream> getChangelogSystemStreams(ContainerModel containerModel,
+      Map<String, SystemStream> changelogSystemStreams) {
 
     if (MapUtils.invertMap(changelogSystemStreams).size() != changelogSystemStreams.size()) {
       throw new SamzaException("Two stores cannot have the same changelog system-stream");
@@ -319,20 +342,22 @@ public class ContainerStorageManager {
 
     Map<SystemStreamPartition, String> changelogSSPToStore = new HashMap<>();
     changelogSystemStreams.forEach((storeName, systemStream) ->
-        containerModel.getTasks().forEach((taskName, taskModel) -> { changelogSSPToStore.put(new SystemStreamPartition(systemStream, taskModel.getChangelogPartition()), storeName); })
+        containerModel.getTasks().forEach((taskName, taskModel) ->
+            changelogSSPToStore.put(new SystemStreamPartition(systemStream, taskModel.getChangelogPartition()), storeName))
     );
 
     getTasks(containerModel, TaskMode.Standby).forEach((taskName, taskModel) -> {
-      this.taskSideInputStoreSSPs.putIfAbsent(taskName, new HashMap<>());
+      taskSideInputStoreSSPs.putIfAbsent(taskName, new HashMap<>());
       changelogSystemStreams.forEach((storeName, systemStream) -> {
         SystemStreamPartition ssp = new SystemStreamPartition(systemStream, taskModel.getChangelogPartition());
         changelogSSPToStore.remove(ssp);
-        this.taskSideInputStoreSSPs.get(taskName).put(storeName, Collections.singleton(ssp));
+        taskSideInputStoreSSPs.get(taskName).put(storeName, Collections.singleton(ssp));
       });
     });
 
     // changelogSystemStreams correspond only to active tasks (since those of standby-tasks moved to sideInputs above)
-    return MapUtils.invertMap(changelogSSPToStore).entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, x -> x.getValue().getSystemStream()));
+    return MapUtils.invertMap(changelogSSPToStore).entrySet().stream()
+        .collect(Collectors.toMap(Map.Entry::getKey, x -> x.getValue().getSystemStream()));
   }
 
 
@@ -357,25 +382,35 @@ public class ContainerStorageManager {
   }
 
   private static Map<String, SystemConsumer> createStoreIndexedMap(Map<String, SystemStream> changelogSystemStreams,
-      Map<String, SystemConsumer> storeSystemConsumers) {
+      Map<String, SystemConsumer> systemNameToSystemConsumers) {
     // Map of each storeName to its respective systemConsumer
     Map<String, SystemConsumer> storeConsumers = new HashMap<>();
 
     // Populate the map of storeName to its relevant systemConsumer
     for (String storeName : changelogSystemStreams.keySet()) {
-      storeConsumers.put(storeName, storeSystemConsumers.get(changelogSystemStreams.get(storeName).getSystem()));
+      storeConsumers.put(storeName, systemNameToSystemConsumers.get(changelogSystemStreams.get(storeName).getSystem()));
     }
     return storeConsumers;
   }
 
-  private Map<TaskName, TaskRestoreManager> createTaskRestoreManagers(SystemAdmins systemAdmins, Clock clock, SamzaContainerMetrics samzaContainerMetrics) {
+  private Map<TaskName, TaskRestoreManager> createTaskRestoreManagers(StateBackendFactory factory, Clock clock,
+      SamzaContainerMetrics samzaContainerMetrics) {
     Map<TaskName, TaskRestoreManager> taskRestoreManagers = new HashMap<>();
+
     containerModel.getTasks().forEach((taskName, taskModel) -> {
+      MetricsRegistry taskMetricsRegistry =
+          taskInstanceMetrics.get(taskName) != null ? taskInstanceMetrics.get(taskName).registry() : new MetricsRegistryMap();
+      Set<String> nonSideInputStoreNames = storageEngineFactories.keySet().stream()
+          .filter(storeName -> !sideInputStoreNames.contains(storeName))
+          .collect(Collectors.toSet());
+      KafkaChangelogRestoreParams kafkaChangelogRestoreParams = new KafkaChangelogRestoreParams(storeConsumers,
+          inMemoryStores.get(taskName), systemAdmins.getSystemAdmins(), storageEngineFactories, serdes,
+          taskInstanceCollectors.get(taskName), nonSideInputStoreNames);
+
       taskRestoreManagers.put(taskName,
-          TaskRestoreManagerFactory.create(
-              taskModel, changelogSystemStreams, getNonSideInputStores(taskName), systemAdmins,
-              streamMetadataCache, sspMetadataCache, storeConsumers, maxChangeLogStreamPartitions,
-              loggedStoreBaseDirectory, nonLoggedStoreBaseDirectory, config, clock));
+          factory.getRestoreManager(jobContext, containerContext, taskModel, restoreExecutor,
+              taskMetricsRegistry, config, clock, loggedStoreBaseDirectory, nonLoggedStoreBaseDirectory,
+              kafkaChangelogRestoreParams));
       samzaContainerMetrics.addStoresRestorationGauge(taskName);
     });
     return taskRestoreManagers;
@@ -388,106 +423,77 @@ public class ContainerStorageManager {
   }
 
   /**
-   * Create taskStores for all stores in storageEngineFactories.
-   * The store mode is chosen as bulk-load if its a non-sideinput store, and readWrite if its a sideInput store
+   * Create taskStores for all stores in storesToCreate.
+   * All stores are created in read-write mode.
    */
-  private Map<TaskName, Map<String, StorageEngine>> createTaskStores(ContainerModel containerModel, JobContext jobContext, ContainerContext containerContext,
+  private Map<TaskName, Map<String, StorageEngine>> createTaskStores(Set<String> storesToCreate,
+      ContainerModel containerModel, JobContext jobContext, ContainerContext containerContext,
       Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories, Map<String, Serde<Object>> serdes,
       Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics,
       Map<TaskName, TaskInstanceCollector> taskInstanceCollectors) {
-
     Map<TaskName, Map<String, StorageEngine>> taskStores = new HashMap<>();
+    StorageConfig storageConfig = new StorageConfig(config);
 
-    // iterate over each task in the containerModel, and each store in storageEngineFactories
+    // iterate over each task and each storeName
     for (Map.Entry<TaskName, TaskModel> task : containerModel.getTasks().entrySet()) {
       TaskName taskName = task.getKey();
       TaskModel taskModel = task.getValue();
-
       if (!taskStores.containsKey(taskName)) {
         taskStores.put(taskName, new HashMap<>());
       }
 
-      for (String storeName : storageEngineFactories.keySet()) {
-
-        StorageEngineFactory.StoreMode storeMode = this.taskSideInputStoreSSPs.get(taskName).containsKey(storeName) ?
-            StorageEngineFactory.StoreMode.ReadWrite : StorageEngineFactory.StoreMode.BulkLoad;
+      for (String storeName : storesToCreate) {
+        // A store is considered durable if it is backed by a changelog or another backupManager factory
+        boolean isDurable = changelogSystemStreams.containsKey(storeName) ||
+            !storageConfig.getStoreBackupManagerClassName(storeName).isEmpty();
+        boolean isSideInput = this.sideInputStoreNames.contains(storeName);
+        // Use the logged-store-base-directory for durable (changelogged or backed-up) stores and side input
+        // stores, and the non-logged-store-base-dir for all other stores
+        File storeBaseDir = isDurable || isSideInput ? this.loggedStoreBaseDirectory : this.nonLoggedStoreBaseDirectory;
+        File storeDirectory = storageManagerUtil.getTaskStoreDir(storeBaseDir, storeName, taskName,
+            taskModel.getTaskMode());
+        this.storeDirectoryPaths.add(storeDirectory.toPath());
+
+        // if taskInstanceMetrics are specified use those for store metrics,
+        // otherwise (in case of StorageRecovery) use a blank MetricsRegistryMap
+        MetricsRegistry storeMetricsRegistry =
+            taskInstanceMetrics.get(taskName) != null ? taskInstanceMetrics.get(taskName).registry() : new MetricsRegistryMap();
 
         StorageEngine storageEngine =
-            createStore(storeName, taskName, taskModel, jobContext, containerContext, storageEngineFactories, serdes, taskInstanceMetrics, taskInstanceCollectors, storeMode);
+            createStore(storeName, storeDirectory, taskModel, jobContext, containerContext, storageEngineFactories,
+                serdes, storeMetricsRegistry, taskInstanceCollectors.get(taskName),
+                StorageEngineFactory.StoreMode.ReadWrite, this.changelogSystemStreams, this.config);
 
         // add created store to map
         taskStores.get(taskName).put(storeName, storageEngine);
 
-        LOG.info("Created store {} for task {} in mode {}", storeName, taskName, storeMode);
+        LOG.info("Created non side input store store {} in read-write mode for task {}", storeName, taskName);
       }
     }
-
     return taskStores;
   }
 
   /**
-   * Recreate all non-sideInput persistent stores in ReadWrite mode.
-   *
-   */
-  private void recreatePersistentTaskStoresInReadWriteMode(ContainerModel containerModel, JobContext jobContext,
-      ContainerContext containerContext, Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
-      Map<String, Serde<Object>> serdes, Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics,
-      Map<TaskName, TaskInstanceCollector> taskInstanceCollectors) {
-
-    // iterate over each task and each storeName
-    for (Map.Entry<TaskName, TaskModel> task : containerModel.getTasks().entrySet()) {
-      TaskName taskName = task.getKey();
-      TaskModel taskModel = task.getValue();
-      Map<String, StorageEngine> nonSideInputStores = getNonSideInputStores(taskName);
-
-      for (String storeName : nonSideInputStores.keySet()) {
-
-        // if this store has been already created then re-create and overwrite it only if it is a
-        // persistentStore and a non-sideInputStore, because sideInputStores are always created in RW mode
-        if (nonSideInputStores.get(storeName).getStoreProperties().isPersistedToDisk()) {
-
-          StorageEngine storageEngine =
-              createStore(storeName, taskName, taskModel, jobContext, containerContext, storageEngineFactories, serdes, taskInstanceMetrics, taskInstanceCollectors,
-                  StorageEngineFactory.StoreMode.ReadWrite);
-
-          // add created store to map
-          this.taskStores.get(taskName).put(storeName, storageEngine);
-
-          LOG.info("Re-created store {} in read-write mode for task {} because it a persistent store", storeName, taskName);
-        } else {
-          LOG.info("Skipping re-creation of store {} for task {}", storeName, taskName);
-        }
-      }
-    }
-  }
-
-  /**
    * Method to instantiate a StorageEngine with the given parameters, and populate the storeDirectory paths (used to monitor
    * disk space).
    */
-  private StorageEngine createStore(String storeName, TaskName taskName, TaskModel taskModel, JobContext jobContext,
-      ContainerContext containerContext, Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
-      Map<String, Serde<Object>> serdes, Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics,
-      Map<TaskName, TaskInstanceCollector> taskInstanceCollectors, StorageEngineFactory.StoreMode storeMode) {
+  public static StorageEngine createStore(
+      String storeName,
+      File storeDirectory,
+      TaskModel taskModel,
+      JobContext jobContext,
+      ContainerContext containerContext,
+      Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, Serde<Object>> serdes,
+      MetricsRegistry storeMetricsRegistry,
+      MessageCollector messageCollector,
+      StorageEngineFactory.StoreMode storeMode,
+      Map<String, SystemStream> changelogSystemStreams,
+      Config config) {
 
     StorageConfig storageConfig = new StorageConfig(config);
-
-    SystemStreamPartition changeLogSystemStreamPartition =
-        (changelogSystemStreams.containsKey(storeName)) ? new SystemStreamPartition(
-            changelogSystemStreams.get(storeName), taskModel.getChangelogPartition()) : null;
-
-    // Use the logged-store-base-directory for change logged stores and sideInput stores, and non-logged-store-base-dir
-    // for non logged stores
-    File storeDirectory;
-    if (changeLogSystemStreamPartition != null || this.taskSideInputStoreSSPs.get(taskName).containsKey(storeName)) {
-      storeDirectory = storageManagerUtil.getTaskStoreDir(this.loggedStoreBaseDirectory, storeName, taskName,
-          taskModel.getTaskMode());
-    } else {
-      storeDirectory = storageManagerUtil.getTaskStoreDir(this.nonLoggedStoreBaseDirectory, storeName, taskName,
-          taskModel.getTaskMode());
-    }
-
-    this.storeDirectoryPaths.add(storeDirectory.toPath());
+    SystemStreamPartition changeLogSystemStreamPartition = changelogSystemStreams.containsKey(storeName) ?
+        new SystemStreamPartition(changelogSystemStreams.get(storeName), taskModel.getChangelogPartition()) : null;
 
     Optional<String> storageKeySerde = storageConfig.getStorageKeySerde(storeName);
     Serde keySerde = null;
@@ -500,14 +506,8 @@ public class ContainerStorageManager {
       messageSerde = serdes.get(storageMsgSerde.get());
     }
 
-    // if taskInstanceMetrics are specified use those for store metrics,
-    // otherwise (in case of StorageRecovery) use a blank MetricsRegistryMap
-    MetricsRegistry storeMetricsRegistry =
-        taskInstanceMetrics.get(taskName) != null ? taskInstanceMetrics.get(taskName).registry()
-            : new MetricsRegistryMap();
-
     return storageEngineFactories.get(storeName)
-        .getStorageEngine(storeName, storeDirectory, keySerde, messageSerde, taskInstanceCollectors.get(taskName),
+        .getStorageEngine(storeName, storeDirectory, keySerde, messageSerde, messageCollector,
             storeMetricsRegistry, changeLogSystemStreamPartition, jobContext, containerContext, storeMode);
   }
 
@@ -584,10 +584,10 @@ public class ContainerStorageManager {
     if (this.hasSideInputs) {
       containerModel.getTasks().forEach((taskName, taskModel) -> {
 
-        Map<String, StorageEngine> sideInputStores = getSideInputStores(taskName);
+        Map<String, StorageEngine> taskSideInputStores = sideInputStores.get(taskName);
         Map<String, Set<SystemStreamPartition>> sideInputStoresToSSPs = new HashMap<>();
         boolean taskHasSideInputs = false;
-        for (String storeName : sideInputStores.keySet()) {
+        for (String storeName : taskSideInputStores.keySet()) {
           Set<SystemStreamPartition> storeSSPs = this.taskSideInputStoreSSPs.get(taskName).get(storeName);
           taskHasSideInputs = taskHasSideInputs || !storeSSPs.isEmpty();
           sideInputStoresToSSPs.put(storeName, storeSSPs);
@@ -600,7 +600,7 @@ public class ContainerStorageManager {
           TaskSideInputHandler taskSideInputHandler = new TaskSideInputHandler(taskName,
               taskModel.getTaskMode(),
               loggedStoreBaseDirectory,
-              sideInputStores,
+              taskSideInputStores,
               sideInputStoresToSSPs,
               taskSideInputProcessors.get(taskName),
               this.systemAdmins,
@@ -612,73 +612,60 @@ public class ContainerStorageManager {
             handlers.put(ssp, taskSideInputHandler);
           });
 
-          LOG.info("Created TaskSideInputHandler for task {}, sideInputStores {} and loggedStoreBaseDirectory {}",
-              taskName, sideInputStores, loggedStoreBaseDirectory);
+          LOG.info("Created TaskSideInputHandler for task {}, taskSideInputStores {} and loggedStoreBaseDirectory {}",
+              taskName, taskSideInputStores, loggedStoreBaseDirectory);
         }
       });
     }
     return handlers;
   }
 
-  private Map<String, StorageEngine> getSideInputStores(TaskName taskName) {
-    return taskStores.get(taskName).entrySet().stream().
-        filter(e -> this.taskSideInputStoreSSPs.get(taskName).containsKey(e.getKey())).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
-  }
-
-  private Map<String, StorageEngine> getNonSideInputStores(TaskName taskName) {
-    return taskStores.get(taskName).entrySet().stream().
-        filter(e -> !this.taskSideInputStoreSSPs.get(taskName).containsKey(e.getKey())).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
-  }
-
   private Set<TaskSideInputHandler> getSideInputHandlers() {
     return this.sspSideInputHandlers.values().stream().collect(Collectors.toSet());
   }
 
   public void start() throws SamzaException, InterruptedException {
-    Map<SystemStreamPartition, String> checkpointedChangelogSSPOffsets = new HashMap<>();
-    if (new TaskConfig(config).getTransactionalStateRestoreEnabled()) {
-      getTasks(containerModel, TaskMode.Active).forEach((taskName, taskModel) -> {
-        if (checkpointManager != null) {
-          Set<SystemStream> changelogSystemStreams = new HashSet<>(this.changelogSystemStreams.values());
-          Checkpoint checkpoint = checkpointManager.readLastCheckpoint(taskName);
-          if (checkpoint != null) {
-            checkpoint.getOffsets().forEach((ssp, offset) -> {
-              if (changelogSystemStreams.contains(new SystemStream(ssp.getSystem(), ssp.getStream()))) {
-                checkpointedChangelogSSPOffsets.put(ssp, offset);
-              }
-            });
-          }
-        }
-      });
+    // Restores stores and (re)creates the task stores in read-write mode
+    restoreStores();
+    // Shutdown restore executor since it will no longer be used
+    try {
+      restoreExecutor.shutdown();
+      if (!restoreExecutor.awaitTermination(RESTORE_THREAD_POOL_SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
+        restoreExecutor.shutdownNow();
+      }
+    } catch (Exception e) {
+      LOG.error("Error shutting down the restore executor", e);
     }
-    LOG.info("Checkpointed changelog ssp offsets: {}", checkpointedChangelogSSPOffsets);
-    restoreStores(checkpointedChangelogSSPOffsets);
     if (this.hasSideInputs) {
       startSideInputs();
     }
+    isStarted = true;
   }
 
   // Restoration of all stores, in parallel across tasks
-  private void restoreStores(Map<SystemStreamPartition, String> checkpointedChangelogSSPOffsets)
-      throws InterruptedException {
+  private void restoreStores() throws InterruptedException {
     LOG.info("Store Restore started");
+    Set<TaskName> activeTasks = getTasks(containerModel, TaskMode.Active).keySet();
 
     // initialize each TaskStorageManager
-    this.taskRestoreManagers.values().forEach(taskStorageManager ->
-       taskStorageManager.init(checkpointedChangelogSSPOffsets));
+    this.taskRestoreManagers.forEach((taskName, taskRestoreManager) -> {
+      Checkpoint taskCheckpoint = null;
+      if (checkpointManager != null && activeTasks.contains(taskName)) {
+        // only pass in checkpoints for active tasks
+        taskCheckpoint = checkpointManager.readLastCheckpoint(taskName);
+        LOG.info("Obtained checkpoint: {} for state restore for taskName: {}", taskCheckpoint, taskName);
+      }
+      taskRestoreManager.init(taskCheckpoint);
+    });
 
     // Start each store consumer once
     this.storeConsumers.values().stream().distinct().forEach(SystemConsumer::start);
 
-    // Create a thread pool for parallel restores (and stopping of persistent stores)
-    ExecutorService executorService = Executors.newFixedThreadPool(this.parallelRestoreThreadPoolSize,
-        new ThreadFactoryBuilder().setDaemon(true).setNameFormat(RESTORE_THREAD_NAME).build());
-
     List<Future> taskRestoreFutures = new ArrayList<>(this.taskRestoreManagers.entrySet().size());
 
     // Submit restore callable for each taskInstance
     this.taskRestoreManagers.forEach((taskInstance, taskRestoreManager) -> {
-      taskRestoreFutures.add(executorService.submit(
+      taskRestoreFutures.add(restoreExecutor.submit(
           new TaskRestoreCallable(this.samzaContainerMetrics, taskInstance, taskRestoreManager)));
     });
 
@@ -690,7 +677,7 @@ public class ContainerStorageManager {
       } catch (InterruptedException e) {
         LOG.warn("Received an interrupt during store restoration. Issuing interrupts to the store restoration workers to exit "
             + "prematurely without restoring full state.");
-        executorService.shutdownNow();
+        restoreExecutor.shutdownNow();
         throw e;
       } catch (Exception e) {
         LOG.error("Exception when restoring ", e);
@@ -698,14 +685,29 @@ public class ContainerStorageManager {
       }
     }
 
-    executorService.shutdown();
-
     // Stop each store consumer once
     this.storeConsumers.values().stream().distinct().forEach(SystemConsumer::stop);
 
-    // Now re-create persistent stores in read-write mode, leave non-persistent stores as-is
-    recreatePersistentTaskStoresInReadWriteMode(this.containerModel, jobContext, containerContext,
+    // Now (re)create all non side input stores in read-write mode; previously created in-memory and side input stores are added back below
+    Set<String> nonSideInputStoreNames = storageEngineFactories.keySet().stream()
+        .filter(storeName -> !sideInputStoreNames.contains(storeName))
+        .collect(Collectors.toSet());
+    this.taskStores = createTaskStores(nonSideInputStoreNames, this.containerModel, jobContext, containerContext,
         storageEngineFactories, serdes, taskInstanceMetrics, taskInstanceCollectors);
+    // Add in memory stores
+    this.inMemoryStores.forEach((taskName, stores) -> {
+      if (!this.taskStores.containsKey(taskName)) {
+        taskStores.put(taskName, new HashMap<>());
+      }
+      taskStores.get(taskName).putAll(stores);
+    });
+    // Add side input stores
+    this.sideInputStores.forEach((taskName, stores) -> {
+      if (!this.taskStores.containsKey(taskName)) {
+        taskStores.put(taskName, new HashMap<>());
+      }
+      taskStores.get(taskName).putAll(stores);
+    });
 
     LOG.info("Store Restore complete");
   }
@@ -838,15 +840,24 @@ public class ContainerStorageManager {
    * @return the task store.
    */
   public Optional<StorageEngine> getStore(TaskName taskName, String storeName) {
+    if (!isStarted) {
+      throw new SamzaException(String.format(
+          "Attempting to access store %s for task %s before ContainerStorageManager is started.",
+          storeName, taskName));
+    }
     return Optional.ofNullable(this.taskStores.get(taskName).get(storeName));
   }
 
   /**
-   *  Get all {@link StorageEngine} instance used by a given task.
-   * @param taskName  the task name, all stores for which are desired.
+   * Get all {@link StorageEngine} instances used by a given task.
+   * @param taskName the task name, all stores for which are desired.
    * @return map of stores used by the given task, indexed by storename
    */
   public Map<String, StorageEngine> getAllStores(TaskName taskName) {
+    if (!isStarted) {
+      throw new SamzaException(String.format(
+          "Attempting to access stores for task %s before ContainerStorageManager is started.", taskName));
+    }
     return this.taskStores.get(taskName);
   }
 
@@ -865,9 +876,13 @@ public class ContainerStorageManager {
 
   public void shutdown() {
     // stop all nonsideinputstores including persistent and non-persistent stores
-    this.containerModel.getTasks().forEach((taskName, taskModel) ->
-        getNonSideInputStores(taskName).forEach((storeName, store) -> store.stop())
-    );
+    if (taskStores != null) {
+      this.containerModel.getTasks()
+          .forEach((taskName, taskModel) -> taskStores.get(taskName)
+              .entrySet().stream()
+              .filter(e -> !sideInputStoreNames.contains(e.getKey()))
+              .forEach(e -> e.getValue().stop()));
+    }
 
     this.shouldShutdown = true;
 
@@ -926,7 +941,7 @@ public class ContainerStorageManager {
       } finally {
         // Stop all persistent stores after restoring. Certain persistent stores opened in BulkLoad mode are compacted
         // on stop, so paralleling stop() also parallelizes their compaction (a time-intensive operation).
-        taskRestoreManager.stopPersistentStores();
+        taskRestoreManager.close();
         long timeToRestore = System.currentTimeMillis() - startTime;
 
         if (this.samzaContainerMetrics != null) {
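Taken together, the ContainerStorageManager changes above replace the per-start thread pool and checkpointed-offset plumbing with a long-lived restore executor plus per-task restore managers obtained from the StateBackendFactory. Below is a minimal sketch of that lifecycle (init with the task's last checkpoint, restore in parallel, close, then shut the pool down). The TaskRestore interface is only a stand-in for the TaskRestoreManager contract assumed by this patch, and the pool bounds mirror the Math.min/Math.max sizing added in the constructor.

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;
    import java.util.concurrent.TimeUnit;

    public class RestoreLifecycleSketch {

      /** Stand-in for the TaskRestoreManager contract used by this patch (assumed shape). */
      public interface TaskRestore {
        void init(Object lastTaskCheckpoint); // checkpoint may be null (e.g. standby tasks)
        void restore() throws InterruptedException;
        void close();
      }

      public static void restoreAll(Map<String, TaskRestore> restoreManagers,
          Map<String, Object> lastCheckpoints, int numTasks,
          int configuredPoolSize, int configuredPoolMaxSize) throws Exception {
        // Same bounding as the constructor change above: at least two threads per task,
        // capped at the configured maximum.
        int poolSize = Math.min(Math.max(numTasks * 2, configuredPoolSize), configuredPoolMaxSize);
        ExecutorService restoreExecutor = Executors.newFixedThreadPool(poolSize);

        // init() is called serially with each task's last checkpoint (null if none).
        restoreManagers.forEach((task, manager) -> manager.init(lastCheckpoints.get(task)));

        // restore() runs in parallel; close() also stops (and may compact) persistent stores.
        List<Future<?>> restoreFutures = new ArrayList<>();
        restoreManagers.forEach((task, manager) ->
            restoreFutures.add(restoreExecutor.submit(() -> {
              try {
                manager.restore();
              } finally {
                manager.close();
              }
              return null;
            })));
        for (Future<?> future : restoreFutures) {
          future.get(); // propagate restore failures
        }

        // The executor is no longer needed once restore completes.
        restoreExecutor.shutdown();
        if (!restoreExecutor.awaitTermination(60, TimeUnit.SECONDS)) {
          restoreExecutor.shutdownNow();
        }
      }
    }
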
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManagerFactory.java b/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManagerFactory.java
deleted file mode 100644
index 97e4504..0000000
--- a/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManagerFactory.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.storage;
-
-import scala.collection.immutable.Map;
-
-import java.io.File;
-import org.apache.samza.Partition;
-import org.apache.samza.config.Config;
-import org.apache.samza.config.TaskConfig;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.job.model.TaskMode;
-import org.apache.samza.system.SystemAdmins;
-import org.apache.samza.system.SystemStream;
-
-public class TaskStorageManagerFactory {
-  public static TaskStorageManager create(TaskName taskName, ContainerStorageManager containerStorageManager,
-      Map<String, SystemStream> storeChangelogs, SystemAdmins systemAdmins,
-      File loggedStoreBaseDir, Partition changelogPartition,
-      Config config, TaskMode taskMode) {
-    if (new TaskConfig(config).getTransactionalStateCheckpointEnabled()) {
-      return new TransactionalStateTaskStorageManager(taskName, containerStorageManager, storeChangelogs, systemAdmins,
-          loggedStoreBaseDir, changelogPartition, taskMode, new StorageManagerUtil());
-    } else {
-      return new NonTransactionalStateTaskStorageManager(taskName, containerStorageManager, storeChangelogs, systemAdmins,
-          loggedStoreBaseDir, changelogPartition);
-    }
-  }
-}
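With TaskStorageManagerFactory removed, the transactional vs. non-transactional choice it encoded is presumably made behind StateBackendFactory#getRestoreManager instead (its call site appears in the ContainerStorageManager changes above). A heavily simplified sketch of such a dispatch, keyed off the transactional-state-restore flag already used in this patch; the class and interface names below are placeholders, not the actual factory implementation:

    import org.apache.samza.config.Config;
    import org.apache.samza.config.TaskConfig;

    public class RestoreDispatchSketch {
      /** Placeholder for the restore-manager contract (init / restore / close). */
      public interface RestoreManager { }

      /** Chooses between transactional and non-transactional restore based on config. */
      public static RestoreManager getRestoreManager(Config config,
          RestoreManager transactionalRestoreManager, RestoreManager nonTransactionalRestoreManager) {
        return new TaskConfig(config).getTransactionalStateRestoreEnabled()
            ? transactionalRestoreManager
            : nonTransactionalRestoreManager;
      }
    }
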
diff --git a/samza-core/src/test/java/org/apache/samza/job/model/TestJobModel.java b/samza-core/src/test/java/org/apache/samza/job/model/TestJobModel.java
index 77fe639..e66e213 100644
--- a/samza-core/src/test/java/org/apache/samza/job/model/TestJobModel.java
+++ b/samza-core/src/test/java/org/apache/samza/job/model/TestJobModel.java
@@ -45,6 +45,6 @@ public class TestJobModel {
     ContainerModel containerModel2 = new ContainerModel("1", tasksForContainer2);
     Map<String, ContainerModel> containers = ImmutableMap.of("0", containerModel1, "1", containerModel2);
     JobModel jobModel = new JobModel(config, containers);
-    assertEquals(jobModel.maxChangeLogStreamPartitions, 5);
+    assertEquals(jobModel.getMaxChangeLogStreamPartitions(), 5);
   }
 }
\ No newline at end of file
diff --git a/samza-core/src/test/java/org/apache/samza/storage/TestTransactionalStateTaskRestoreManager.java b/samza-core/src/test/java/org/apache/samza/storage/TestTransactionalStateTaskRestoreManager.java
index 2bdd6c3..6d5b45c 100644
--- a/samza-core/src/test/java/org/apache/samza/storage/TestTransactionalStateTaskRestoreManager.java
+++ b/samza-core/src/test/java/org/apache/samza/storage/TestTransactionalStateTaskRestoreManager.java
@@ -33,7 +33,7 @@ import java.util.Map;
 import java.util.Set;
 import org.apache.samza.Partition;
 import org.apache.samza.checkpoint.CheckpointId;
-import org.apache.samza.checkpoint.CheckpointedChangelogOffset;
+import org.apache.samza.checkpoint.kafka.KafkaStateCheckpointMarker;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.TaskConfig;
@@ -134,7 +134,7 @@ public class TestTransactionalStateTaskRestoreManager {
 
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of();
 
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset = ImmutableMap.of();
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset = ImmutableMap.of();
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets = ImmutableMap.of();
 
     SystemAdmins mockSystemAdmins = mock(SystemAdmins.class);
@@ -149,7 +149,7 @@ public class TestTransactionalStateTaskRestoreManager {
         .thenReturn(mockCurrentStoreDir);
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, null,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -185,10 +185,11 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    ImmutableMap<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        ImmutableMap.of(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    ImmutableMap<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        ImmutableMap.of(store1Name, kafkaStateCheckpointMarker);
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
 
@@ -218,7 +219,7 @@ public class TestTransactionalStateTaskRestoreManager {
         .thenReturn(ImmutableList.of(dummyCheckpointDir));
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -255,10 +256,11 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    ImmutableMap<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        ImmutableMap.of(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    ImmutableMap<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        ImmutableMap.of(store1Name, kafkaStateCheckpointMarker);
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
 
@@ -279,7 +281,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -319,11 +321,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "21";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -345,7 +348,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -386,11 +389,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -412,7 +416,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -453,12 +457,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointId checkpointId = CheckpointId.fromString("0-0"); // checkpoint id older than default min.compaction.lag.ms
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(checkpointId, changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.deserialize("0-0"); // checkpoint id older than default min.compaction.lag.ms
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -480,7 +484,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -524,11 +528,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = null;
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -552,7 +557,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -588,11 +593,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = null;
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -616,7 +622,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -652,10 +658,11 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    ImmutableMap<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        ImmutableMap.of(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    ImmutableMap<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        ImmutableMap.of(store1Name, kafkaStateCheckpointMarker);
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
 
@@ -679,7 +686,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -717,10 +724,11 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    ImmutableMap<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        ImmutableMap.of(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    ImmutableMap<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        ImmutableMap.of(store1Name, kafkaStateCheckpointMarker);
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
 
@@ -760,7 +768,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -800,10 +808,11 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    ImmutableMap<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        ImmutableMap.of(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    ImmutableMap<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        ImmutableMap.of(store1Name, kafkaStateCheckpointMarker);
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
 
@@ -843,7 +852,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -886,10 +895,11 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    ImmutableMap<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        ImmutableMap.of(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    ImmutableMap<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        ImmutableMap.of(store1Name, kafkaStateCheckpointMarker);
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
 
@@ -927,7 +937,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -971,10 +981,11 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    ImmutableMap<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        ImmutableMap.of(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    ImmutableMap<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        ImmutableMap.of(store1Name, kafkaStateCheckpointMarker);
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
 
@@ -1012,7 +1023,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -1056,10 +1067,11 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    ImmutableMap<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        ImmutableMap.of(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    ImmutableMap<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        ImmutableMap.of(store1Name, kafkaStateCheckpointMarker);
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
 
@@ -1097,7 +1109,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -1144,11 +1156,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = null;
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -1191,7 +1204,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -1239,11 +1252,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = null;
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -1286,7 +1300,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -1332,11 +1346,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -1376,7 +1391,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -1421,11 +1436,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = null;
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -1471,7 +1487,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -1516,11 +1532,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -1561,7 +1578,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -1606,12 +1623,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "5";
-    CheckpointId checkpointId = CheckpointId.fromString("0-0"); // checkpoint timestamp older than default min compaction lag
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(checkpointId, changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.deserialize("0-0"); // checkpoint timestamp older than default min compaction lag
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -1652,7 +1669,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
@@ -1697,11 +1714,12 @@ public class TestTransactionalStateTaskRestoreManager {
     Map<String, SystemStream> mockStoreChangelogs = ImmutableMap.of(store1Name, changelog1SystemStream);
 
     String changelog1CheckpointedOffset = "21";
-    CheckpointedChangelogOffset changelog1CheckpointMessage =
-        new CheckpointedChangelogOffset(CheckpointId.create(), changelog1CheckpointedOffset);
-    Map<SystemStreamPartition, String> mockCheckpointedChangelogOffset =
-        new HashMap<SystemStreamPartition, String>() { {
-          put(changelog1SSP, changelog1CheckpointMessage.toString());
+    CheckpointId checkpointId = CheckpointId.create();
+    KafkaStateCheckpointMarker kafkaStateCheckpointMarker =
+        new KafkaStateCheckpointMarker(changelog1SSP, changelog1CheckpointedOffset);
+    Map<String, KafkaStateCheckpointMarker> mockCheckpointedChangelogOffset =
+        new HashMap<String, KafkaStateCheckpointMarker>() { {
+          put(store1Name, kafkaStateCheckpointMarker);
         } };
     Map<SystemStreamPartition, SystemStreamPartitionMetadata> mockCurrentChangelogOffsets =
         ImmutableMap.of(changelog1SSP, changelog1SSPMetadata);
@@ -1742,7 +1760,7 @@ public class TestTransactionalStateTaskRestoreManager {
         });
 
     StoreActions storeActions = TransactionalStateTaskRestoreManager.getStoreActions(
-        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset,
+        mockTaskModel, mockStoreEngines, mockStoreChangelogs, mockCheckpointedChangelogOffset, checkpointId,
         mockCurrentChangelogOffsets, mockSystemAdmins, mockStorageManagerUtil,
         mockLoggedStoreBaseDir, mockNonLoggedStoreBaseDir, mockConfig, mockClock);
 
diff --git a/samza-core/src/test/scala/org/apache/samza/storage/TestContainerStorageManager.java b/samza-core/src/test/scala/org/apache/samza/storage/TestContainerStorageManager.java
index c36a3be..f381f3d 100644
--- a/samza-core/src/test/scala/org/apache/samza/storage/TestContainerStorageManager.java
+++ b/samza-core/src/test/scala/org/apache/samza/storage/TestContainerStorageManager.java
@@ -24,8 +24,8 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import org.apache.samza.Partition;
-import org.apache.samza.checkpoint.Checkpoint;
 import org.apache.samza.checkpoint.CheckpointManager;
+import org.apache.samza.checkpoint.CheckpointV1;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.TaskConfig;
@@ -205,24 +205,36 @@ public class TestContainerStorageManager {
             new scala.collection.immutable.Map.Map1(new SystemStream(SYSTEM_NAME, STREAM_NAME), systemStreamMetadata));
 
     CheckpointManager checkpointManager = mock(CheckpointManager.class);
-    when(checkpointManager.readLastCheckpoint(any(TaskName.class))).thenReturn(new Checkpoint(new HashMap<>()));
+    when(checkpointManager.readLastCheckpoint(any(TaskName.class))).thenReturn(new CheckpointV1(new HashMap<>()));
 
     SSPMetadataCache mockSSPMetadataCache = mock(SSPMetadataCache.class);
     when(mockSSPMetadataCache.getMetadata(any(SystemStreamPartition.class)))
         .thenReturn(new SystemStreamMetadata.SystemStreamPartitionMetadata("0", "10", "11"));
 
+    ContainerContext mockContainerContext = mock(ContainerContext.class);
+    ContainerModel mockContainerModel = new ContainerModel("samza-container-test", tasks);
+    when(mockContainerContext.getContainerModel()).thenReturn(mockContainerModel);
+
     // Reset the  expected number of sysConsumer create, start and stop calls, and store.restore() calls
     this.systemConsumerCreationCount = 0;
     this.systemConsumerStartCount = 0;
     this.systemConsumerStopCount = 0;
     this.storeRestoreCallCount = 0;
 
+    StateBackendFactory backendFactory = mock(StateBackendFactory.class);
+    TaskRestoreManager restoreManager = mock(TaskRestoreManager.class);
+    when(backendFactory.getRestoreManager(any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
+        .thenReturn(restoreManager);
+    doAnswer(invocation -> {
+      storeRestoreCallCount++;
+      return null;
+    }).when(restoreManager).restore();
+
     // Create the container storage manager
     this.containerStorageManager = new ContainerStorageManager(
         checkpointManager,
-        new ContainerModel("samza-container-test", tasks),
+        mockContainerModel,
         mockStreamMetadataCache,
-        mockSSPMetadataCache,
         mockSystemAdmins,
         changelogSystemStreams,
         new HashMap<>(),
@@ -233,11 +245,11 @@ public class TestContainerStorageManager {
         taskInstanceMetrics,
         samzaContainerMetrics,
         mock(JobContext.class),
-        mock(ContainerContext.class),
+        mockContainerContext,
+        backendFactory,
         mock(Map.class),
         DEFAULT_LOGGED_STORE_BASE_DIR,
         DEFAULT_STORE_BASE_DIR,
-        2,
         null,
         new SystemClock());
   }
@@ -252,10 +264,10 @@ public class TestContainerStorageManager {
           mockingDetails(gauge).getInvocations().size() >= 1);
     }
 
-    Assert.assertTrue("Store restore count should be 2 because there are 2 tasks", this.storeRestoreCallCount == 2);
-    Assert.assertTrue("systemConsumerCreation count should be 1 (1 consumer per system)",
-        this.systemConsumerCreationCount == 1);
-    Assert.assertTrue("systemConsumerStopCount count should be 1", this.systemConsumerStopCount == 1);
-    Assert.assertTrue("systemConsumerStartCount count should be 1", this.systemConsumerStartCount == 1);
+    Assert.assertEquals("Store restore count should be 2 because there are 2 tasks", 2, this.storeRestoreCallCount);
+    Assert.assertEquals("systemConsumerCreation count should be 1 (1 consumer per system)", 1,
+        this.systemConsumerCreationCount);
+    Assert.assertEquals("systemConsumerStopCount count should be 1", 1, this.systemConsumerStopCount);
+    Assert.assertEquals("systemConsumerStartCount count should be 1", 1, this.systemConsumerStartCount);
   }
 }
diff --git a/samza-kv-inmemory/src/main/java/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStorageEngineFactory.java b/samza-kv-inmemory/src/main/java/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStorageEngineFactory.java
index 8cd4e36..883766b 100644
--- a/samza-kv-inmemory/src/main/java/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStorageEngineFactory.java
+++ b/samza-kv-inmemory/src/main/java/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStorageEngineFactory.java
@@ -25,7 +25,6 @@ import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.storage.kv.BaseKeyValueStorageEngineFactory;
 import org.apache.samza.storage.kv.KeyValueStore;
 import org.apache.samza.storage.kv.KeyValueStoreMetrics;
-import org.apache.samza.system.SystemStreamPartition;
 
 
 public class InMemoryKeyValueStorageEngineFactory<K, V> extends BaseKeyValueStorageEngineFactory<K, V> {
@@ -33,7 +32,6 @@ public class InMemoryKeyValueStorageEngineFactory<K, V> extends BaseKeyValueStor
   protected KeyValueStore<byte[], byte[]> getKVStore(String storeName,
       File storeDir,
       MetricsRegistry registry,
-      SystemStreamPartition changeLogSystemStreamPartition,
       JobContext jobContext,
       ContainerContext containerContext,
       StoreMode storeMode) {
diff --git a/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStorageEngineFactory.scala b/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStorageEngineFactory.scala
index d02d623..e1f44d2 100644
--- a/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStorageEngineFactory.scala
+++ b/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStorageEngineFactory.scala
@@ -34,14 +34,12 @@ class RocksDbKeyValueStorageEngineFactory [K, V] extends BaseKeyValueStorageEngi
    * @param storeName Name of the store
    * @param storeDir The directory of the store
    * @param registry MetricsRegistry to which to publish store specific metrics.
-   * @param changeLogSystemStreamPartition Samza stream partition from which to receive the changelog.
    * @param containerContext Information about the container in which the task is executing.
    * @return A valid KeyValueStore instance
    */
   override def getKVStore(storeName: String,
     storeDir: File,
     registry: MetricsRegistry,
-    changeLogSystemStreamPartition: SystemStreamPartition,
     jobContext: JobContext,
     containerContext: ContainerContext, storeMode: StoreMode): KeyValueStore[Array[Byte], Array[Byte]] = {
     val storageConfigSubset = jobContext.getConfig.subset("stores." + storeName + ".", true)
diff --git a/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala b/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala
index 300177a..9d5ddfd 100644
--- a/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala
+++ b/samza-kv-rocksdb/src/main/scala/org/apache/samza/storage/kv/RocksDbKeyValueStore.scala
@@ -24,11 +24,11 @@ import java.nio.file.{Path, Paths}
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.locks.ReentrantReadWriteLock
 import java.util.{Comparator, Optional}
-
 import org.apache.samza.SamzaException
 import org.apache.samza.checkpoint.CheckpointId
 import org.apache.samza.config.Config
-import org.apache.samza.util.Logging
+import org.apache.samza.storage.StorageManagerUtil
+import org.apache.samza.util.{FileUtil, Logging}
 import org.rocksdb.{TtlDB, _}
 
 object RocksDbKeyValueStore extends Logging {
@@ -67,6 +67,9 @@ object RocksDbKeyValueStore extends Logging {
     }
 
     try {
+      // Create the path if it doesn't exist
+      new FileUtil().createDirectories(dir.toPath)
+
       val rocksDb =
         if (useTTL) {
           info("Opening RocksDB store: %s in path: %s with TTL value: %s" format (storeName, dir.toString, ttl))
@@ -239,7 +242,7 @@ class RocksDbKeyValueStore(
 
   override def checkpoint(id: CheckpointId): Optional[Path] = {
     val checkpoint = Checkpoint.create(db)
-    val checkpointPath = dir.getPath + "-" + id.toString
+    val checkpointPath = StorageManagerUtil.getCheckpointDirPath(dir, id)
     checkpoint.createCheckpoint(checkpointPath)
     Optional.of(Paths.get(checkpointPath))
   }
diff --git a/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala b/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala
index 2a4f44e..54dca8f 100644
--- a/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala
+++ b/samza-kv-rocksdb/src/test/scala/org/apache/samza/storage/kv/TestRocksDbKeyValueStore.scala
@@ -22,11 +22,10 @@ package org.apache.samza.storage.kv
 
 import java.io.File
 import java.util
-
 import org.apache.samza.SamzaException
 import org.apache.samza.config.MapConfig
 import org.apache.samza.metrics.{Counter, Gauge, MetricsRegistryMap, MetricsVisitor, Timer}
-import org.apache.samza.util.ExponentialSleepStrategy
+import org.apache.samza.util.{ExponentialSleepStrategy, FileUtil}
 import org.junit.{Assert, Test}
 import org.rocksdb.{FlushOptions, Options, RocksDB, RocksIterator}
 
@@ -124,6 +123,23 @@ class TestRocksDbKeyValueStore
   }
 
   @Test
+  def testRocksDbCreatePathIfNotExist(): Unit = {
+    val map = new util.HashMap[String, String]()
+    val config = new MapConfig(map)
+    val options = new Options()
+    options.setCreateIfMissing(true)
+
+    val dbDir = new File(System.getProperty("java.io.tmpdir") + File.separator + "samza-test2", "rocksDbFiles")
+    val rocksDB = new RocksDbKeyValueStore(dbDir, options, config, false, "dbStore")
+    val key = "key".getBytes("UTF-8")
+    rocksDB.put(key, "val".getBytes("UTF-8"))
+
+    rocksDB.close()
+
+    new FileUtil().rm(dbDir)
+  }
+
+  @Test
   def testIteratorWithRemoval(): Unit = {
     val lock = new Object
 
diff --git a/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.java b/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.java
index d6c1196..704e0cb 100644
--- a/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.java
+++ b/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.java
@@ -19,6 +19,7 @@
 package org.apache.samza.storage.kv;
 
 import java.io.File;
+import java.util.List;
 import java.util.Optional;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.samza.SamzaException;
@@ -60,7 +61,6 @@ public abstract class BaseKeyValueStorageEngineFactory<K, V> implements StorageE
    * @param storeName Name of the store
    * @param storeDir The directory of the store
    * @param registry MetricsRegistry to which to publish store specific metrics.
-   * @param changeLogSystemStreamPartition Samza stream partition from which to receive the changelog.
    * @param jobContext Information about the job in which the task is executing.
    * @param containerContext Information about the container in which the task is executing.
    * @return A raw KeyValueStore instance
@@ -68,7 +68,6 @@ public abstract class BaseKeyValueStorageEngineFactory<K, V> implements StorageE
   protected abstract KeyValueStore<byte[], byte[]> getKVStore(String storeName,
       File storeDir,
       MetricsRegistry registry,
-      SystemStreamPartition changeLogSystemStreamPartition,
       JobContext jobContext,
       ContainerContext containerContext,
       StoreMode storeMode);
@@ -106,6 +105,10 @@ public abstract class BaseKeyValueStorageEngineFactory<K, V> implements StorageE
     if (!storeFactory.get().equals(INMEMORY_KV_STORAGE_ENGINE_FACTORY)) {
       storePropertiesBuilder.setPersistedToDisk(true);
     }
+    // The store is durable iff it is backed by the task backup manager
+    List<String> storeBackupManager = storageConfig.getStoreBackupManagerClassName(storeName);
+    storePropertiesBuilder.setIsDurable(!storeBackupManager.isEmpty());
+
     int batchSize = storageConfigSubset.getInt(WRITE_BATCH_SIZE, DEFAULT_WRITE_BATCH_SIZE);
     int cacheSize = storageConfigSubset.getInt(OBJECT_CACHE_SIZE, Math.max(batchSize, DEFAULT_OBJECT_CACHE_SIZE));
     if (cacheSize > 0 && cacheSize < batchSize) {
@@ -123,7 +126,7 @@ public abstract class BaseKeyValueStorageEngineFactory<K, V> implements StorageE
     }
 
     KeyValueStore<byte[], byte[]> rawStore =
-        getKVStore(storeName, storeDir, registry, changelogSSP, jobContext, containerContext, storeMode);
+        getKVStore(storeName, storeDir, registry, jobContext, containerContext, storeMode);
     KeyValueStore<byte[], byte[]> maybeLoggedStore = buildMaybeLoggedStore(changelogSSP,
         storeName, registry, storePropertiesBuilder, rawStore, changelogCollector);
     // this also applies serialization and caching layers
diff --git a/samza-kv/src/test/java/org/apache/samza/storage/kv/MockKeyValueStorageEngineFactory.java b/samza-kv/src/test/java/org/apache/samza/storage/kv/MockKeyValueStorageEngineFactory.java
index 3430ae9..2c8251e 100644
--- a/samza-kv/src/test/java/org/apache/samza/storage/kv/MockKeyValueStorageEngineFactory.java
+++ b/samza-kv/src/test/java/org/apache/samza/storage/kv/MockKeyValueStorageEngineFactory.java
@@ -22,7 +22,6 @@ import java.io.File;
 import org.apache.samza.context.ContainerContext;
 import org.apache.samza.context.JobContext;
 import org.apache.samza.metrics.MetricsRegistry;
-import org.apache.samza.system.SystemStreamPartition;
 
 
 /**
@@ -38,8 +37,7 @@ public class MockKeyValueStorageEngineFactory extends BaseKeyValueStorageEngineF
 
   @Override
   protected KeyValueStore<byte[], byte[]> getKVStore(String storeName, File storeDir, MetricsRegistry registry,
-      SystemStreamPartition changeLogSystemStreamPartition, JobContext jobContext, ContainerContext containerContext,
-      StoreMode storeMode) {
+      JobContext jobContext, ContainerContext containerContext, StoreMode storeMode) {
     return this.rawKeyValueStore;
   }
 }
diff --git a/samza-kv/src/test/java/org/apache/samza/storage/kv/TestBaseKeyValueStorageEngineFactory.java b/samza-kv/src/test/java/org/apache/samza/storage/kv/TestBaseKeyValueStorageEngineFactory.java
index 22a1b57..05b106d 100644
--- a/samza-kv/src/test/java/org/apache/samza/storage/kv/TestBaseKeyValueStorageEngineFactory.java
+++ b/samza-kv/src/test/java/org/apache/samza/storage/kv/TestBaseKeyValueStorageEngineFactory.java
@@ -127,7 +127,23 @@ public class TestBaseKeyValueStorageEngineFactory {
         "org.apache.samza.storage.kv.inmemory.InMemoryKeyValueStorageEngineFactory"));
     StorageEngine storageEngine = callGetStorageEngine(config, null);
     KeyValueStorageEngine<?, ?> keyValueStorageEngine = baseStorageEngineValidation(storageEngine);
-    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), false, false);
+    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), false, false, false);
+    NullSafeKeyValueStore<?, ?> nullSafeKeyValueStore =
+        assertAndCast(keyValueStorageEngine.getWrapperStore(), NullSafeKeyValueStore.class);
+    SerializedKeyValueStore<?, ?> serializedKeyValueStore =
+        assertAndCast(nullSafeKeyValueStore.getStore(), SerializedKeyValueStore.class);
+    // config has the in-memory key-value factory, but still calling the test factory, so store will be the test store
+    assertEquals(this.rawKeyValueStore, serializedKeyValueStore.getStore());
+  }
+
+  @Test
+  public void testDurableKeyValueStore() {
+    Config config = new MapConfig(BASE_CONFIG, DISABLE_CACHE,
+        ImmutableMap.of(String.format(StorageConfig.STORE_BACKEND_BACKUP_FACTORIES, STORE_NAME),
+        "backendFactory,backendFactory2"));
+    StorageEngine storageEngine = callGetStorageEngine(config, null);
+    KeyValueStorageEngine<?, ?> keyValueStorageEngine = baseStorageEngineValidation(storageEngine);
+    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, false, true);
     NullSafeKeyValueStore<?, ?> nullSafeKeyValueStore =
         assertAndCast(keyValueStorageEngine.getWrapperStore(), NullSafeKeyValueStore.class);
     SerializedKeyValueStore<?, ?> serializedKeyValueStore =
@@ -141,7 +157,7 @@ public class TestBaseKeyValueStorageEngineFactory {
     Config config = new MapConfig(BASE_CONFIG, DISABLE_CACHE);
     StorageEngine storageEngine = callGetStorageEngine(config, null);
     KeyValueStorageEngine<?, ?> keyValueStorageEngine = baseStorageEngineValidation(storageEngine);
-    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, false);
+    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, false, false);
     NullSafeKeyValueStore<?, ?> nullSafeKeyValueStore =
         assertAndCast(keyValueStorageEngine.getWrapperStore(), NullSafeKeyValueStore.class);
     SerializedKeyValueStore<?, ?> serializedKeyValueStore =
@@ -154,7 +170,7 @@ public class TestBaseKeyValueStorageEngineFactory {
     Config config = new MapConfig(BASE_CONFIG, DISABLE_CACHE);
     StorageEngine storageEngine = callGetStorageEngine(config, CHANGELOG_SSP);
     KeyValueStorageEngine<?, ?> keyValueStorageEngine = baseStorageEngineValidation(storageEngine);
-    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, true);
+    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, true, false);
     NullSafeKeyValueStore<?, ?> nullSafeKeyValueStore =
         assertAndCast(keyValueStorageEngine.getWrapperStore(), NullSafeKeyValueStore.class);
     SerializedKeyValueStore<?, ?> serializedKeyValueStore =
@@ -170,7 +186,7 @@ public class TestBaseKeyValueStorageEngineFactory {
     Config config = new MapConfig(BASE_CONFIG);
     StorageEngine storageEngine = callGetStorageEngine(config, CHANGELOG_SSP);
     KeyValueStorageEngine<?, ?> keyValueStorageEngine = baseStorageEngineValidation(storageEngine);
-    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, true);
+    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, true, false);
     NullSafeKeyValueStore<?, ?> nullSafeKeyValueStore =
         assertAndCast(keyValueStorageEngine.getWrapperStore(), NullSafeKeyValueStore.class);
     CachedStore<?, ?> cachedStore = assertAndCast(nullSafeKeyValueStore.getStore(), CachedStore.class);
@@ -187,7 +203,7 @@ public class TestBaseKeyValueStorageEngineFactory {
     Config config = new MapConfig(BASE_CONFIG, DISABLE_CACHE, DISALLOW_LARGE_MESSAGES);
     StorageEngine storageEngine = callGetStorageEngine(config, null);
     KeyValueStorageEngine<?, ?> keyValueStorageEngine = baseStorageEngineValidation(storageEngine);
-    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, false);
+    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, false, false);
     NullSafeKeyValueStore<?, ?> nullSafeKeyValueStore =
         assertAndCast(keyValueStorageEngine.getWrapperStore(), NullSafeKeyValueStore.class);
     SerializedKeyValueStore<?, ?> serializedKeyValueStore =
@@ -202,7 +218,7 @@ public class TestBaseKeyValueStorageEngineFactory {
     Config config = new MapConfig(BASE_CONFIG, DISALLOW_LARGE_MESSAGES);
     StorageEngine storageEngine = callGetStorageEngine(config, null);
     KeyValueStorageEngine<?, ?> keyValueStorageEngine = baseStorageEngineValidation(storageEngine);
-    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, false);
+    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, false, false);
     NullSafeKeyValueStore<?, ?> nullSafeKeyValueStore =
         assertAndCast(keyValueStorageEngine.getWrapperStore(), NullSafeKeyValueStore.class);
     SerializedKeyValueStore<?, ?> serializedKeyValueStore =
@@ -220,7 +236,7 @@ public class TestBaseKeyValueStorageEngineFactory {
     Config config = new MapConfig(BASE_CONFIG, DISABLE_CACHE, DROP_LARGE_MESSAGES);
     StorageEngine storageEngine = callGetStorageEngine(config, null);
     KeyValueStorageEngine<?, ?> keyValueStorageEngine = baseStorageEngineValidation(storageEngine);
-    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, false);
+    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, false, false);
     NullSafeKeyValueStore<?, ?> nullSafeKeyValueStore =
         assertAndCast(keyValueStorageEngine.getWrapperStore(), NullSafeKeyValueStore.class);
     SerializedKeyValueStore<?, ?> serializedKeyValueStore =
@@ -235,7 +251,7 @@ public class TestBaseKeyValueStorageEngineFactory {
     Config config = new MapConfig(BASE_CONFIG, DROP_LARGE_MESSAGES);
     StorageEngine storageEngine = callGetStorageEngine(config, null);
     KeyValueStorageEngine<?, ?> keyValueStorageEngine = baseStorageEngineValidation(storageEngine);
-    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, false);
+    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, false, false);
     NullSafeKeyValueStore<?, ?> nullSafeKeyValueStore =
         assertAndCast(keyValueStorageEngine.getWrapperStore(), NullSafeKeyValueStore.class);
     CachedStore<?, ?> cachedStore = assertAndCast(nullSafeKeyValueStore.getStore(), CachedStore.class);
@@ -252,7 +268,7 @@ public class TestBaseKeyValueStorageEngineFactory {
     // AccessLoggedStore requires a changelog SSP
     StorageEngine storageEngine = callGetStorageEngine(config, CHANGELOG_SSP);
     KeyValueStorageEngine<?, ?> keyValueStorageEngine = baseStorageEngineValidation(storageEngine);
-    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, true);
+    assertStoreProperties(keyValueStorageEngine.getStoreProperties(), true, true, false);
     NullSafeKeyValueStore<?, ?> nullSafeKeyValueStore =
         assertAndCast(keyValueStorageEngine.getWrapperStore(), NullSafeKeyValueStore.class);
     AccessLoggedStore<?, ?> accessLoggedStore =
@@ -278,9 +294,10 @@ public class TestBaseKeyValueStorageEngineFactory {
   }
 
   private static void assertStoreProperties(StoreProperties storeProperties, boolean expectedPersistedToDisk,
-      boolean expectedLoggedStore) {
+      boolean expectedLoggedStore, boolean expectedDurable) {
     assertEquals(expectedPersistedToDisk, storeProperties.isPersistedToDisk());
     assertEquals(expectedLoggedStore, storeProperties.isLoggedStore());
+    assertEquals(expectedDurable, storeProperties.isDurableStore());
   }
 
   /**
diff --git a/samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java b/samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java
index 4e69372..e2e599d 100644
--- a/samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java
+++ b/samza-test/src/test/java/org/apache/samza/test/processor/TestZkLocalApplicationRunner.java
@@ -856,7 +856,7 @@ public class TestZkLocalApplicationRunner extends IntegrationTestHarness {
     }
 
     Assert.assertEquals(expectedTaskAssignments, actualTaskAssignments);
-    Assert.assertEquals(32, jobModel.maxChangeLogStreamPartitions);
+    Assert.assertEquals(32, jobModel.getMaxChangeLogStreamPartitions());
   }
 
   /**
@@ -958,7 +958,7 @@ public class TestZkLocalApplicationRunner extends IntegrationTestHarness {
     // Validate that the new JobModel has the expected task assignments.
     actualTaskAssignments = getTaskAssignments(jobModel);
     Assert.assertEquals(expectedTaskAssignments, actualTaskAssignments);
-    Assert.assertEquals(32, jobModel.maxChangeLogStreamPartitions);
+    Assert.assertEquals(32, jobModel.getMaxChangeLogStreamPartitions());
   }
 
   @Test
diff --git a/samza-test/src/test/scala/org/apache/samza/test/integration/NonTransactionalStateIntegrationTest.scala b/samza-test/src/test/scala/org/apache/samza/test/integration/NonTransactionalStateIntegrationTest.scala
index 00a530d..229e026 100644
--- a/samza-test/src/test/scala/org/apache/samza/test/integration/NonTransactionalStateIntegrationTest.scala
+++ b/samza-test/src/test/scala/org/apache/samza/test/integration/NonTransactionalStateIntegrationTest.scala
@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
 package org.apache.samza.test.integration
 
 import org.apache.samza.config.TaskConfig
@@ -78,7 +77,9 @@ class NonTransactionalStateIntegrationTest extends StreamTaskTestUtil {
     "stores.mystore.changelog.replication.factor" -> "1",
     // However, don't have the inputs use the checkpoint manager
     // since the second part of the test expects to replay the input streams.
-    "systems.kafka.streams.input.samza.reset.offset" -> "true"))
+    "systems.kafka.streams.input.samza.reset.offset" -> "true",
+    TaskConfig.COMMIT_MAX_DELAY_MS -> "0" // Ensure no commits are skipped due to in progress commits
+  ))
 
   @Test
   def testShouldStartAndRestore {
diff --git a/samza-test/src/test/scala/org/apache/samza/test/integration/StreamTaskTestUtil.scala b/samza-test/src/test/scala/org/apache/samza/test/integration/StreamTaskTestUtil.scala
index edcb159..b375511 100644
--- a/samza-test/src/test/scala/org/apache/samza/test/integration/StreamTaskTestUtil.scala
+++ b/samza-test/src/test/scala/org/apache/samza/test/integration/StreamTaskTestUtil.scala
@@ -34,7 +34,7 @@ import org.apache.kafka.clients.consumer.{ConsumerRecord, KafkaConsumer}
 import org.apache.kafka.clients.producer.{KafkaProducer, Producer, ProducerConfig, ProducerRecord}
 import org.apache.kafka.common.security.auth.SecurityProtocol
 import org.apache.samza.Partition
-import org.apache.samza.checkpoint.Checkpoint
+import org.apache.samza.checkpoint.{Checkpoint, CheckpointV1}
 import org.apache.samza.config._
 import org.apache.samza.container.TaskName
 import org.apache.samza.context.Context
@@ -73,8 +73,8 @@ object StreamTaskTestUtil {
 
   var producer: Producer[Array[Byte], Array[Byte]] = null
   var adminClient: AdminClient = null
-  val cp1 = new Checkpoint(Map(new SystemStreamPartition("kafka", "topic", new Partition(0)) -> "123").asJava)
-  val cp2 = new Checkpoint(Map(new SystemStreamPartition("kafka", "topic", new Partition(0)) -> "12345").asJava)
+  val cp1 = new CheckpointV1(Map(new SystemStreamPartition("kafka", "topic", new Partition(0)) -> "123").asJava)
+  val cp2 = new CheckpointV1(Map(new SystemStreamPartition("kafka", "topic", new Partition(0)) -> "12345").asJava)
 
   // use a random store directory for each run. prevents test failures due to left over state from
   // previously aborted test runs
@@ -291,7 +291,7 @@ class StreamTaskTestUtil {
       case _ => throw new ConfigException("No checkpoint manager factory configured")
     }
 
-    ChangelogStreamManager.createChangelogStreams(jobModel.getConfig, jobModel.maxChangeLogStreamPartitions)
+    ChangelogStreamManager.createChangelogStreams(jobModel.getConfig, jobModel.getMaxChangeLogStreamPartitions)
   }
 }