You are viewing a plain-text version of this content. The canonical link to the original is not included in this plain-text version.
Posted to commits@samza.apache.org by pm...@apache.org on 2019/10/08 23:17:57 UTC

[samza] branch master updated: Transactional State [4/5]: Added new interfaces for TaskRestoreManager and TaskStorageManager

This is an automated email from the ASF dual-hosted git repository.

pmaheshwari pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new dc55bd3  Transactional State [4/5]: Added new interfaces for TaskRestoreManager and TaskStorageManager
dc55bd3 is described below

commit dc55bd3fe6d9177cc9f57a63d02cee902b2bddab
Author: Prateek Maheshwari <pr...@utexas.edu>
AuthorDate: Tue Oct 8 16:17:52 2019 -0700

    Transactional State [4/5]: Added new interfaces for TaskRestoreManager and TaskStorageManager
    
    This PR adds new internal interfaces for switching the per-task store commit and restore implementations, and introduces new configs to enable/disable transactional state.
---
 .../java/org/apache/samza/config/TaskConfig.java   |  13 +
 .../NonTransactionalStateTaskRestoreManager.java   | 349 +++++++++++++++++++++
 .../org/apache/samza/storage/StorageRecovery.java  |  21 +-
 .../apache/samza/storage/TaskRestoreManager.java   |  46 +++
 .../samza/storage/TaskRestoreManagerFactory.java   |  70 +++++
 .../apache/samza/container/SamzaContainer.scala    |  35 ++-
 .../samza/storage/ContainerStorageManager.java     | 336 +++-----------------
 .../NonTransactionalStateTaskStorageManager.scala  | 144 +++++++++
 .../apache/samza/storage/TaskStorageManager.scala  |  93 +-----
 .../samza/storage/TaskStorageManagerFactory.java   |  45 +++
 .../samza/container/TestSamzaContainer.scala       |  14 +-
 .../samza/storage/TestContainerStorageManager.java |  87 +++--
 .../samza/storage/TestTaskStorageManager.scala     | 115 ++++---
 .../samza/storage/kv/KeyValueStorageEngine.scala   |   2 +-
 14 files changed, 868 insertions(+), 502 deletions(-)

diff --git a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
index 2b26020..54a0827 100644
--- a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
@@ -106,6 +106,11 @@ public class TaskConfig extends MapConfig {
   private static final String BROADCAST_STREAM_RANGE_PATTERN = "^\\[[\\d]+\\-[\\d]+\\]$";
   public static final String CHECKPOINT_MANAGER_FACTORY = "task.checkpoint.factory";
 
+  public static final String TRANSACTIONAL_STATE_ENABLED = "samza.transactional.state.enabled";
+  private static final boolean DEFAULT_TRANSACTIONAL_STATE_ENABLED = false;
+  public static final String TRANSACTIONAL_STATE_RETAIN_EXISTING_STATE = "samza.transactional.state.retain.existing.state";
+  private static final boolean DEFAULT_TRANSACTIONAL_STATE_RETAIN_EXISTING_STATE = true;
+
   public TaskConfig(Config config) {
     super(config);
   }
@@ -296,4 +301,12 @@ public class TaskConfig extends MapConfig {
       return DEFAULT_TASK_SHUTDOWN_MS;
     }
   }
+
+  public boolean getTransactionalStateEnabled() {
+    return getBoolean(TRANSACTIONAL_STATE_ENABLED, DEFAULT_TRANSACTIONAL_STATE_ENABLED);
+  }
+
+  public boolean getTransactionalStateRetainExistingState() {
+    return getBoolean(TRANSACTIONAL_STATE_RETAIN_EXISTING_STATE, DEFAULT_TRANSACTIONAL_STATE_RETAIN_EXISTING_STATE);
+  }
 }
diff --git a/samza-core/src/main/java/org/apache/samza/storage/NonTransactionalStateTaskRestoreManager.java b/samza-core/src/main/java/org/apache/samza/storage/NonTransactionalStateTaskRestoreManager.java
new file mode 100644
index 0000000..5952647
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/storage/NonTransactionalStateTaskRestoreManager.java
@@ -0,0 +1,349 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import java.io.File;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.samza.Partition;
+import org.apache.samza.SamzaException;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.StorageConfig;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.system.ChangelogSSPIterator;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemAdmins;
+import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.util.Clock;
+import org.apache.samza.util.FileUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.collection.JavaConverters;
+
+
+/**
+ * This is the legacy state restoration, based on changelog and offset files.
+ *
+ * Restore logic for all stores of a task including directory cleanup, setup, changelogSSP validation, registering
+ * with the respective consumer, restoring stores, and stopping stores.
+ */
+class NonTransactionalStateTaskRestoreManager implements TaskRestoreManager {
+  private static final Logger LOG = LoggerFactory.getLogger(NonTransactionalStateTaskRestoreManager.class);
+
+  private final Map<String, StorageEngine> taskStores; // Map of all StorageEngines for this task indexed by store name
+  private final Set<String> taskStoresToRestore;
+  // Set of store names which need to be restored by consuming using system-consumers (see registerStartingOffsets)
+
+  private final TaskModel taskModel;
+  private final Clock clock; // Clock value used to validate base-directories for staleness. See isLoggedStoreValid.
+  private Map<SystemStream, String> changeLogOldestOffsets; // Map of changelog oldest known offsets
+  private final Map<SystemStreamPartition, String> fileOffsets; // Map of offsets read from offset file indexed by changelog SSP
+  private final Map<String, SystemStream> changelogSystemStreams; // Map of change log system-streams indexed by store name
+  private final SystemAdmins systemAdmins;
+  private final File loggedStoreBaseDirectory;
+  private final File nonLoggedStoreBaseDirectory;
+  private final StreamMetadataCache streamMetadataCache;
+  private final Map<String, SystemConsumer> storeConsumers;
+  private final int maxChangeLogStreamPartitions;
+  private final Config config;
+  private final StorageManagerUtil storageManagerUtil;
+
+  NonTransactionalStateTaskRestoreManager(
+      TaskModel taskModel,
+      Map<String, SystemStream> changelogSystemStreams,
+      Map<String, StorageEngine> taskStores,
+      SystemAdmins systemAdmins,
+      StreamMetadataCache streamMetadataCache,
+      Map<String, SystemConsumer> storeConsumers,
+      int maxChangeLogStreamPartitions,
+      File loggedStoreBaseDirectory,
+      File nonLoggedStoreBaseDirectory,
+      Config config,
+      Clock clock) {
+    this.taskStores = taskStores;
+    this.taskModel = taskModel;
+    this.clock = clock;
+    this.changelogSystemStreams = changelogSystemStreams;
+    this.systemAdmins = systemAdmins;
+    this.fileOffsets = new HashMap<>();
+    this.taskStoresToRestore = this.taskStores.entrySet().stream()
+        .filter(x -> x.getValue().getStoreProperties().isLoggedStore())
+        .map(x -> x.getKey()).collect(Collectors.toSet());
+    this.loggedStoreBaseDirectory = loggedStoreBaseDirectory;
+    this.nonLoggedStoreBaseDirectory = nonLoggedStoreBaseDirectory;
+    this.streamMetadataCache = streamMetadataCache;
+    this.storeConsumers = storeConsumers;
+    this.maxChangeLogStreamPartitions = maxChangeLogStreamPartitions;
+    this.config = config;
+    this.storageManagerUtil = new StorageManagerUtil();
+  }
+
+  /**
+   * Cleans up and sets up store directories, validates changeLog SSPs for all stores of this task,
+   * and registers SSPs with the respective consumers.
+   */
+  @Override
+  public void init(Map<SystemStreamPartition, String> checkpointedChangelogSSPOffsets) {
+    cleanBaseDirsAndReadOffsetFiles();
+    setupBaseDirs();
+    validateChangelogStreams();
+    getOldestChangeLogOffsets();
+    registerStartingOffsets();
+  }
+
+  /**
+   * For each store for this task,
+   * a. Deletes the corresponding non-logged-store base dir.
+   * b. Deletes the logged-store-base-dir if it is not valid. See {@link #isLoggedStoreValid} for validation semantics.
+   * c. If the logged-store-base-dir is valid, this method reads the offset file and stores each offset.
+   */
+  private void cleanBaseDirsAndReadOffsetFiles() {
+    LOG.debug("Cleaning base directories for stores.");
+
+    FileUtil fileUtil = new FileUtil();
+    taskStores.forEach((storeName, storageEngine) -> {
+        if (!storageEngine.getStoreProperties().isLoggedStore()) {
+          File nonLoggedStorePartitionDir =
+              storageManagerUtil.getTaskStoreDir(nonLoggedStoreBaseDirectory, storeName, taskModel.getTaskName(), taskModel.getTaskMode());
+          LOG.info("Got non logged storage partition directory as " + nonLoggedStorePartitionDir.toPath().toString());
+
+          if (nonLoggedStorePartitionDir.exists()) {
+            LOG.info("Deleting non logged storage partition directory " + nonLoggedStorePartitionDir.toPath().toString());
+            fileUtil.rm(nonLoggedStorePartitionDir);
+          }
+        } else {
+          File loggedStorePartitionDir =
+              storageManagerUtil.getTaskStoreDir(loggedStoreBaseDirectory, storeName, taskModel.getTaskName(), taskModel.getTaskMode());
+          LOG.info("Got logged storage partition directory as " + loggedStorePartitionDir.toPath().toString());
+
+          // Delete the logged store if it is not valid.
+          if (!isLoggedStoreValid(storeName, loggedStorePartitionDir)) {
+            LOG.info("Deleting logged storage partition directory " + loggedStorePartitionDir.toPath().toString());
+            fileUtil.rm(loggedStorePartitionDir);
+          } else {
+
+            SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStreams.get(storeName), taskModel.getChangelogPartition());
+            Map<SystemStreamPartition, String> offset =
+                storageManagerUtil.readOffsetFile(loggedStorePartitionDir, Collections.singleton(changelogSSP), false);
+            LOG.info("Read offset {} for the store {} from logged storage partition directory {}", offset, storeName, loggedStorePartitionDir);
+
+            if (offset.containsKey(changelogSSP)) {
+              fileOffsets.put(changelogSSP, offset.get(changelogSSP));
+            }
+          }
+        }
+      });
+  }
+
+  /**
+   * Directory loggedStoreDir associated with the logged store storeName is determined to be valid
+   * if all of the following conditions are true.
+   * a) If the store has to be persisted to disk.
+   * b) If there is a valid offset file associated with the logged store.
+   * c) If the logged store has not gone stale.
+   *
+   * @return true if the logged store is valid, false otherwise.
+   */
+  private boolean isLoggedStoreValid(String storeName, File loggedStoreDir) {
+    long changeLogDeleteRetentionInMs = new StorageConfig(config).getChangeLogDeleteRetentionInMs(storeName);
+
+    if (changelogSystemStreams.containsKey(storeName)) {
+      SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStreams.get(storeName), taskModel.getChangelogPartition());
+      return this.taskStores.get(storeName).getStoreProperties().isPersistedToDisk()
+          && storageManagerUtil.isOffsetFileValid(loggedStoreDir, Collections.singleton(changelogSSP), false)
+          && !storageManagerUtil.isStaleStore(loggedStoreDir, changeLogDeleteRetentionInMs, clock.currentTimeMillis(), false);
+    }
+
+    return false;
+  }
+
+  /**
+   * Create stores' base directories for logged-stores if they dont exist.
+   */
+  private void setupBaseDirs() {
+    LOG.debug("Setting up base directories for stores.");
+    taskStores.forEach((storeName, storageEngine) -> {
+        if (storageEngine.getStoreProperties().isLoggedStore()) {
+
+          File loggedStorePartitionDir =
+              storageManagerUtil.getTaskStoreDir(loggedStoreBaseDirectory, storeName, taskModel.getTaskName(), taskModel.getTaskMode());
+
+          LOG.info("Using logged storage partition directory: " + loggedStorePartitionDir.toPath().toString()
+              + " for store: " + storeName);
+
+          if (!loggedStorePartitionDir.exists()) {
+            loggedStorePartitionDir.mkdirs();
+          }
+        } else {
+          File nonLoggedStorePartitionDir =
+              storageManagerUtil.getTaskStoreDir(nonLoggedStoreBaseDirectory, storeName, taskModel.getTaskName(), taskModel.getTaskMode());
+          LOG.info("Using non logged storage partition directory: " + nonLoggedStorePartitionDir.toPath().toString()
+              + " for store: " + storeName);
+          nonLoggedStorePartitionDir.mkdirs();
+        }
+      });
+  }
+
+  /**
+   *  Validates each changelog system-stream with its respective SystemAdmin.
+   */
+  private void validateChangelogStreams() {
+    LOG.info("Validating change log streams: " + changelogSystemStreams);
+
+    for (SystemStream changelogSystemStream : changelogSystemStreams.values()) {
+      SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(changelogSystemStream.getSystem());
+      StreamSpec changelogSpec =
+          StreamSpec.createChangeLogStreamSpec(changelogSystemStream.getStream(), changelogSystemStream.getSystem(),
+              maxChangeLogStreamPartitions);
+
+      systemAdmin.validateStream(changelogSpec);
+    }
+  }
+
+  /**
+   * Get the oldest offset for each changelog SSP based on the stream's metadata (obtained from streamMetadataCache).
+   */
+  private void getOldestChangeLogOffsets() {
+
+    Map<SystemStream, SystemStreamMetadata> changeLogMetadata = JavaConverters.mapAsJavaMapConverter(
+        streamMetadataCache.getStreamMetadata(
+            JavaConverters.asScalaSetConverter(new HashSet<>(changelogSystemStreams.values())).asScala().toSet(),
+            false)).asJava();
+
+    LOG.info("Got change log stream metadata: {}", changeLogMetadata);
+
+    changeLogOldestOffsets =
+        getChangeLogOldestOffsetsForPartition(taskModel.getChangelogPartition(), changeLogMetadata);
+    LOG.info("Assigning oldest change log offsets for taskName {} : {}", taskModel.getTaskName(),
+        changeLogOldestOffsets);
+  }
+
+  /**
+   * Builds a map from each changelog SystemStream to its oldest offset in the given partition.
+   */
+  private Map<SystemStream, String> getChangeLogOldestOffsetsForPartition(Partition partition,
+      Map<SystemStream, SystemStreamMetadata> inputStreamMetadata) {
+
+    Map<SystemStream, String> retVal = new HashMap<>();
+
+    // NOTE: do not use Collectors.Map because of https://bugs.openjdk.java.net/browse/JDK-8148463
+    inputStreamMetadata.entrySet()
+        .stream()
+        .filter(x -> x.getValue().getSystemStreamPartitionMetadata().get(partition) != null)
+        .forEach(e -> retVal.put(e.getKey(),
+            e.getValue().getSystemStreamPartitionMetadata().get(partition).getOldestOffset()));
+
+    return retVal;
+  }
+
+  /**
+   * Determines the starting offset for each store SSP (based on {@link #getStartingOffset(SystemStreamPartition, SystemAdmin)}) and
+   * registers it with the respective SystemConsumer for starting consumption.
+   */
+  private void registerStartingOffsets() {
+
+    for (Map.Entry<String, SystemStream> changelogSystemStreamEntry : changelogSystemStreams.entrySet()) {
+      SystemStreamPartition systemStreamPartition =
+          new SystemStreamPartition(changelogSystemStreamEntry.getValue(), taskModel.getChangelogPartition());
+      SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(changelogSystemStreamEntry.getValue().getSystem());
+      SystemConsumer systemConsumer = storeConsumers.get(changelogSystemStreamEntry.getKey());
+
+      String offset = getStartingOffset(systemStreamPartition, systemAdmin);
+
+      if (offset != null) {
+        LOG.info("Registering change log consumer with offset " + offset + " for %" + systemStreamPartition);
+        systemConsumer.register(systemStreamPartition, offset);
+      } else {
+        LOG.info("Skipping change log restoration for {} because stream appears to be empty (offset was null).",
+            systemStreamPartition);
+        taskStoresToRestore.remove(changelogSystemStreamEntry.getKey());
+      }
+    }
+  }
+
+  /**
+   * Returns the offset with which the changelog consumer should be initialized for the given SystemStreamPartition.
+   *
+   * If a file offset exists, it represents the last changelog offset which is also reflected in the on-disk state.
+   * In that case, we use the next offset after the file offset, as long as it is newer than the oldest offset
+   * currently available in the stream.
+   *
+   * If there isn't a file offset or it's older than the oldest available offset, we simply start with the oldest.
+   *
+   * @param systemStreamPartition  the changelog partition for which the offset is needed.
+   * @param systemAdmin            the {@link SystemAdmin} for the changelog.
+   * @return the offset from which the changelog consumer should be initialized.
+   */
+  private String getStartingOffset(SystemStreamPartition systemStreamPartition, SystemAdmin systemAdmin) {
+    String fileOffset = fileOffsets.get(systemStreamPartition);
+
+    // NOTE: changeLogOldestOffsets may contain a null-offset for the given SSP (signifying an empty stream)
+    // therefore, we need to differentiate that from the case where the offset is simply missing
+    if (!changeLogOldestOffsets.containsKey(systemStreamPartition.getSystemStream())) {
+      throw new SamzaException("Missing a change log offset for " + systemStreamPartition);
+    }
+
+    String oldestOffset = changeLogOldestOffsets.get(systemStreamPartition.getSystemStream());
+    return storageManagerUtil.getStartingOffset(systemStreamPartition, systemAdmin, fileOffset, oldestOffset);
+  }
+
+  /**
+   * Restore each store in taskStoresToRestore sequentially
+   */
+  @Override
+  public void restore() {
+    for (String storeName : taskStoresToRestore) {
+      LOG.info("Restoring store: {} for task: {}", storeName, taskModel.getTaskName());
+      SystemConsumer systemConsumer = storeConsumers.get(storeName);
+      SystemStream systemStream = changelogSystemStreams.get(storeName);
+      SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(systemStream.getSystem());
+      ChangelogSSPIterator changelogSSPIterator = new ChangelogSSPIterator(systemConsumer,
+          new SystemStreamPartition(systemStream, taskModel.getChangelogPartition()), null, systemAdmin, false);
+
+      taskStores.get(storeName).restore(changelogSSPIterator);
+    }
+  }
+
+  /**
+   * Stop only persistent stores. In case of certain stores and store mode (such as RocksDB), this
+   * can invoke compaction.
+   */
+  public void stopPersistentStores() {
+
+    Map<String, StorageEngine> persistentStores = this.taskStores.entrySet().stream().filter(e -> {
+        return e.getValue().getStoreProperties().isPersistedToDisk();
+      }).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+    persistentStores.forEach((storeName, storageEngine) -> {
+        storageEngine.stop();
+        this.taskStores.remove(storeName);
+      });
+    LOG.info("Stopped persistent stores {}", persistentStores);
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java b/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
index a8a0942..2cbeddc 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
@@ -20,15 +20,20 @@
 package org.apache.samza.storage;
 
 import java.io.File;
+import java.time.Duration;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
 import org.apache.samza.SamzaException;
+import org.apache.samza.checkpoint.CheckpointManager;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.SerializerConfig;
 import org.apache.samza.config.StorageConfig;
 import org.apache.samza.config.SystemConfig;
+import org.apache.samza.config.TaskConfig;
 import org.apache.samza.container.SamzaContainerMetrics;
 import org.apache.samza.context.ContainerContext;
 import org.apache.samza.context.ContainerContextImpl;
@@ -41,10 +46,12 @@ import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.serializers.SerdeFactory;
+import org.apache.samza.system.SSPMetadataCache;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemFactory;
 import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.CommandLine;
 import org.apache.samza.util.CoordinatorStreamUtil;
@@ -207,15 +214,25 @@ public class StorageRecovery extends CommandLine {
     Clock clock = SystemClock.instance();
     StreamMetadataCache streamMetadataCache = new StreamMetadataCache(systemAdmins, 5000, clock);
     // don't worry about prefetching for this; looks like the tool doesn't flush to offset files anyways
-
     Map<String, SystemFactory> systemFactories = new SystemConfig(jobConfig).getSystemFactories();
+    CheckpointManager checkpointManager = new TaskConfig(jobConfig)
+        .getCheckpointManager(new MetricsRegistryMap()).orElse(null);
 
     for (ContainerModel containerModel : containers.values()) {
       ContainerContext containerContext = new ContainerContextImpl(containerModel, new MetricsRegistryMap());
 
+      Set<SystemStreamPartition> changelogSSPs = changeLogSystemStreams.values().stream()
+          .flatMap(ss -> containerModel.getTasks().values().stream()
+              .map(tm -> new SystemStreamPartition(ss, tm.getChangelogPartition())))
+          .collect(Collectors.toSet());
+      SSPMetadataCache sspMetadataCache = new SSPMetadataCache(systemAdmins, Duration.ofMillis(5000), clock, changelogSSPs);
+
       ContainerStorageManager containerStorageManager =
-          new ContainerStorageManager(containerModel,
+          new ContainerStorageManager(
+              checkpointManager,
+              containerModel,
               streamMetadataCache,
+              sspMetadataCache,
               systemAdmins,
               changeLogSystemStreams,
               new HashMap<>(),
diff --git a/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManager.java b/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManager.java
new file mode 100644
index 0000000..2bdeeea
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManager.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import java.util.Map;
+import org.apache.samza.system.SystemStreamPartition;
+
+
+/**
+ * Helper interface for restoring the state of a task's stores.
+ */
+public interface TaskRestoreManager {
+
+  /**
+   * Init state resources such as file directories.
+   */
+  void init(Map<SystemStreamPartition, String> checkpointedChangelogSSPOffsets);
+
+  /**
+   * Restore state from checkpoints, state snapshots and changelog.
+   */
+  void restore();
+
+  /**
+   * Stop all persistent stores after restoring.
+   */
+  void stopPersistentStores();
+
+}
diff --git a/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManagerFactory.java b/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManagerFactory.java
new file mode 100644
index 0000000..8fa9f1b
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/storage/TaskRestoreManagerFactory.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.storage;
+
+import java.io.File;
+import java.util.Map;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.system.SSPMetadataCache;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.SystemAdmins;
+import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.util.Clock;
+
+/**
+ * Factory class to create {@link TaskRestoreManager}.
+ */
+class TaskRestoreManagerFactory {
+
+  public static TaskRestoreManager create(
+      TaskModel taskModel,
+      Map<String, SystemStream> changelogSystemStreams,
+      Map<String, StorageEngine> taskStores,
+      SystemAdmins systemAdmins,
+      StreamMetadataCache streamMetadataCache,
+      SSPMetadataCache sspMetadataCache,
+      Map<String, SystemConsumer> storeConsumers,
+      int maxChangeLogStreamPartitions,
+      File loggedStoreBaseDirectory,
+      File nonLoggedStoreBaseDirectory,
+      Config config,
+      Clock clock) {
+
+    if (new TaskConfig(config).getTransactionalStateEnabled()) {
+      throw new UnsupportedOperationException();
+    } else {
+      // Create legacy offset-file based state restoration which is NOT transactional.
+      return new NonTransactionalStateTaskRestoreManager(
+          taskModel,
+          changelogSystemStreams,
+          taskStores,
+          systemAdmins,
+          streamMetadataCache,
+          storeConsumers,
+          maxChangeLogStreamPartitions,
+          loggedStoreBaseDirectory,
+          nonLoggedStoreBaseDirectory,
+          config,
+          clock);
+    }
+  }
+}
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index 922c13e..d15c109 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -335,13 +335,13 @@ object SamzaContainer extends Logging {
 
     debug("Got system stream message serdes: %s" format systemStreamMessageSerdes)
 
-    val changeLogSystemStreams = storageConfig
+    val storeChangelogs = storageConfig
       .getStoreNames.asScala
       .filter(storageConfig.getChangelogStream(_).isPresent)
       .map(name => (name, storageConfig.getChangelogStream(name).get)).toMap
       .mapValues(StreamUtil.getSystemStreamFromNames(_))
 
-    info("Got change log system streams: %s" format changeLogSystemStreams)
+    info("Got change log system streams: %s" format storeChangelogs)
 
     /*
      * This keeps track of the changelog SSPs that are associated with the whole container. This is used so that we can
@@ -358,7 +358,7 @@ object SamzaContainer extends Logging {
     val changelogSSPMetadataCache = new SSPMetadataCache(systemAdmins,
       Duration.ofSeconds(5),
       SystemClock.instance,
-      getChangelogSSPsForContainer(containerModel, changeLogSystemStreams).asJava)
+      getChangelogSSPsForContainer(containerModel, storeChangelogs).asJava)
 
     val intermediateStreams = streamConfig
       .getStreamIds()
@@ -390,7 +390,7 @@ object SamzaContainer extends Logging {
       systemMessageSerdes = systemMessageSerdes,
       systemStreamKeySerdes = systemStreamKeySerdes,
       systemStreamMessageSerdes = systemStreamMessageSerdes,
-      changeLogSystemStreams = changeLogSystemStreams.values.toSet,
+      changeLogSystemStreams = storeChangelogs.values.toSet,
       controlMessageKeySerdes = controlMessageKeySerdes,
       intermediateMessageSerdes = intermediateStreamMessageSerdes)
 
@@ -509,10 +509,13 @@ object SamzaContainer extends Logging {
     val loggedStorageBaseDir = getLoggedStorageBaseDir(jobConfig, defaultStoreBaseDir)
     info("Got base directory for logged data stores: %s" format loggedStorageBaseDir)
 
-    val containerStorageManager = new ContainerStorageManager(containerModel,
+    val containerStorageManager = new ContainerStorageManager(
+      checkpointManager,
+      containerModel,
       streamMetadataCache,
+      changelogSSPMetadataCache,
       systemAdmins,
-      changeLogSystemStreams.asJava,
+      storeChangelogs.asJava,
       sideInputStoresToSystemStreams.mapValues(systemStreamSet => systemStreamSet.toSet.asJava).asJava,
       storageEngineFactories.asJava,
       systemFactories.asJava,
@@ -552,13 +555,15 @@ object SamzaContainer extends Logging {
       val taskSideInputSSPs = sideInputStoresToSSPs.values.flatMap(_.asScala).toSet
       info ("Got task side input SSPs: %s" format taskSideInputSSPs)
 
-      val storageManager = new TaskStorageManager(
-        taskName = taskName,
-        containerStorageManager = containerStorageManager,
-        changeLogSystemStreams = changeLogSystemStreams,
-        sspMetadataCache = changelogSSPMetadataCache,
-        loggedStoreBaseDir = loggedStorageBaseDir,
-        partition = taskModel.getChangelogPartition)
+      val storageManager = TaskStorageManagerFactory.create(
+        taskName,
+        containerStorageManager,
+        storeChangelogs,
+        systemAdmins,
+        loggedStorageBaseDir,
+        taskModel.getChangelogPartition,
+        config,
+        taskModel.getTaskMode)
 
       val tableManager = new TableManager(config)
 
@@ -742,11 +747,13 @@ class SamzaContainer(
       startAdmins
       startOffsetManager
       storeContainerLocality
+      // TODO HIGH pmaheshw SAMZA-2338: since store restore needs to trim changelog messages,
+      // need to start changelog producers before the stores, but stop them after stores.
+      startProducers
       startStores
       startTableManager
       startDiskSpaceMonitor
       startHostStatisticsMonitor
-      startProducers
       startTask
       startConsumers
       startSecurityManger
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
index 58ec045..c55253f 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
+++ b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
@@ -44,8 +44,9 @@ import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 import org.apache.commons.collections4.MapUtils;
-import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
+import org.apache.samza.checkpoint.Checkpoint;
+import org.apache.samza.checkpoint.CheckpointManager;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.StorageConfig;
 import org.apache.samza.config.TaskConfig;
@@ -64,10 +65,9 @@ import org.apache.samza.serializers.Serde;
 import org.apache.samza.serializers.SerdeManager;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.storage.kv.KeyValueStore;
-import org.apache.samza.system.ChangelogSSPIterator;
 import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SSPMetadataCache;
 import org.apache.samza.system.StreamMetadataCache;
-import org.apache.samza.system.StreamSpec;
 import org.apache.samza.system.SystemAdmin;
 import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemConsumer;
@@ -83,13 +83,11 @@ import org.apache.samza.system.chooser.RoundRobinChooserFactory;
 import org.apache.samza.table.utils.SerdeUtils;
 import org.apache.samza.task.TaskInstanceCollector;
 import org.apache.samza.util.Clock;
-import org.apache.samza.util.FileUtil;
 import org.apache.samza.util.ReflectionUtil;
 import org.apache.samza.util.ScalaJavaUtil;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.collection.JavaConversions;
-import scala.collection.JavaConverters;
 
 
 /**
@@ -134,8 +132,10 @@ public class ContainerStorageManager {
   private final SystemAdmins systemAdmins;
 
   private final StreamMetadataCache streamMetadataCache;
+  private final SSPMetadataCache sspMetadataCache;
   private final SamzaContainerMetrics samzaContainerMetrics;
 
+  private final CheckpointManager checkpointManager;
   /* Parameters required to re-create taskStores post-restoration */
   private final ContainerModel containerModel;
   private final JobContext jobContext;
@@ -170,8 +170,11 @@ public class ContainerStorageManager {
   private final Config config;
   private final StorageManagerUtil storageManagerUtil = new StorageManagerUtil();
 
-  public ContainerStorageManager(ContainerModel containerModel,
+  public ContainerStorageManager(
+      CheckpointManager checkpointManager,
+      ContainerModel containerModel,
       StreamMetadataCache streamMetadataCache,
+      SSPMetadataCache sspMetadataCache,
       SystemAdmins systemAdmins,
       Map<String, SystemStream> changelogSystemStreams,
       Map<String, Set<SystemStream>> sideInputSystemStreams,
@@ -189,11 +192,11 @@ public class ContainerStorageManager {
       int maxChangeLogStreamPartitions,
       SerdeManager serdeManager,
       Clock clock) {
-
+    this.checkpointManager = checkpointManager;
     this.containerModel = containerModel;
     this.sideInputSystemStreams = new HashMap<>(sideInputSystemStreams);
     this.taskSideInputSSPs = getTaskSideInputSSPs(containerModel, sideInputSystemStreams);
-
+    this.sspMetadataCache = sspMetadataCache;
     this.changelogSystemStreams = getChangelogSystemStreams(containerModel, changelogSystemStreams); // handling standby tasks
 
     LOG.info("Starting with changelogSystemStreams = {} sideInputSystemStreams = {}", this.changelogSystemStreams, this.sideInputSystemStreams);
@@ -366,7 +369,10 @@ public class ContainerStorageManager {
     Map<TaskName, TaskRestoreManager> taskRestoreManagers = new HashMap<>();
     containerModel.getTasks().forEach((taskName, taskModel) -> {
         taskRestoreManagers.put(taskName,
-            new TaskRestoreManager(taskModel, changelogSystemStreams, getNonSideInputStores(taskName), systemAdmins, clock));
+            TaskRestoreManagerFactory.create(
+                taskModel, changelogSystemStreams, getNonSideInputStores(taskName), systemAdmins,
+                streamMetadataCache, sspMetadataCache, storeConsumers, maxChangeLogStreamPartitions,
+                loggedStoreBaseDirectory, nonLoggedStoreBaseDirectory, config, clock));
         samzaContainerMetrics.addStoresRestorationGauge(taskName);
       });
     return taskRestoreManagers;
@@ -622,18 +628,36 @@ public class ContainerStorageManager {
   }
 
   public void start() throws SamzaException {
-    restoreStores();
+    // Changelog SSP offsets from each active task's last checkpoint; passed to each TaskRestoreManager's init.
+    // Only populated when transactional state is enabled and a checkpoint manager is available.
+    Map<SystemStreamPartition, String> checkpointedChangelogSSPOffsets = new HashMap<>();
+    if (new TaskConfig(config).getTransactionalStateEnabled() && checkpointManager != null) {
+      Set<SystemStream> storeChangelogStreams = new HashSet<>(this.changelogSystemStreams.values());
+      getTasks(containerModel, TaskMode.Active).forEach((taskName, taskModel) -> {
+          Checkpoint checkpoint = checkpointManager.readLastCheckpoint(taskName);
+          if (checkpoint != null) {
+            checkpoint.getOffsets().forEach((ssp, offset) -> {
+                if (storeChangelogStreams.contains(ssp.getSystemStream())) {
+                  checkpointedChangelogSSPOffsets.put(ssp, offset);
+                }
+              });
+          }
+        });
+    }
+    LOG.info("Checkpointed changelog ssp offsets: {}", checkpointedChangelogSSPOffsets);
+    restoreStores(checkpointedChangelogSSPOffsets);
     if (sideInputsPresent()) {
       startSideInputs();
     }
   }
 
   // Restoration of all stores, in parallel across tasks
-  private void restoreStores() {
+  private void restoreStores(Map<SystemStreamPartition, String> checkpointedChangelogSSPOffsets) {
     LOG.info("Store Restore started");
 
     // initialize each TaskStorageManager
-    this.taskRestoreManagers.values().forEach(taskStorageManager -> taskStorageManager.initialize());
+    this.taskRestoreManagers.values().forEach(taskStorageManager ->
+       taskStorageManager.init(checkpointedChangelogSSPOffsets));
 
     // Start each store consumer once
     this.storeConsumers.values().stream().distinct().forEach(systemConsumer -> systemConsumer.start());
@@ -869,7 +893,7 @@ public class ContainerStorageManager {
   }
 
   /**
-   * Callable for performing the restoreStores on a task restore manager and emitting the task-restoration metric.
+   * Callable for performing the restore on a task restore manager and emitting the task-restoration metric.
    * After restoration, all persistent stores are stopped (which will invoke compaction in case of certain persistent
    * stores that were opened in bulk-load mode).
    * Performing stop here parallelizes this compaction, which is a time-intensive operation.
@@ -892,7 +916,7 @@ public class ContainerStorageManager {
     public Void call() {
       long startTime = System.currentTimeMillis();
       LOG.info("Starting stores in task instance {}", this.taskName.getTaskName());
-      taskRestoreManager.restoreStores();
+      taskRestoreManager.restore();
 
       // Stop all persistent stores after restoring. Certain persistent stores opened in BulkLoad mode are compacted
       // on stop, so paralleling stop() also parallelizes their compaction (a time-intensive operation).
@@ -910,288 +934,4 @@ public class ContainerStorageManager {
       return null;
     }
   }
-
-  /**
-   * Restore logic for all stores of a task including directory cleanup, setup, changelogSSP validation, registering
-   * with the respective consumer, restoring stores, and stopping stores.
-   */
-  private class TaskRestoreManager {
-    private final Map<String, StorageEngine> taskStores; // Map of all StorageEngines for this task indexed by store name
-    private final Set<String> taskStoresToRestore;
-    // Set of store names which need to be restored by consuming using system-consumers (see registerStartingOffsets)
-
-    private final TaskModel taskModel;
-    private final Clock clock; // Clock value used to validate base-directories for staleness. See isLoggedStoreValid.
-    private Map<SystemStream, String> changeLogOldestOffsets; // Map of changelog oldest known offsets
-    private final Map<SystemStreamPartition, String> fileOffsets; // Map of offsets read from offset file indexed by changelog SSP
-    private final Map<String, SystemStream> changelogSystemStreams; // Map of change log system-streams indexed by store name
-    private final SystemAdmins systemAdmins;
-
-    public TaskRestoreManager(TaskModel taskModel, Map<String, SystemStream> changelogSystemStreams,
-        Map<String, StorageEngine> taskStores, SystemAdmins systemAdmins, Clock clock) {
-      this.taskStores = taskStores;
-      this.taskModel = taskModel;
-      this.clock = clock;
-      this.changelogSystemStreams = changelogSystemStreams;
-      this.systemAdmins = systemAdmins;
-      this.fileOffsets = new HashMap<>();
-      this.taskStoresToRestore = this.taskStores.entrySet().stream()
-          .filter(x -> x.getValue().getStoreProperties().isLoggedStore())
-          .map(x -> x.getKey()).collect(Collectors.toSet());
-    }
-
-    /**
-     * Cleans up and sets up store directories, validates changeLog SSPs for all stores of this task,
-     * and registers SSPs with the respective consumers.
-     */
-    public void initialize() {
-      cleanBaseDirsAndReadOffsetFiles();
-      setupBaseDirs();
-      validateChangelogStreams();
-      getOldestChangeLogOffsets();
-      registerStartingOffsets();
-    }
-
-    /**
-     * For each store for this task,
-     * a. Deletes the corresponding non-logged-store base dir.
-     * b. Deletes the logged-store-base-dir if it not valid. See {@link #isLoggedStoreValid} for validation semantics.
-     * c. If the logged-store-base-dir is valid, this method reads the offset file and stores each offset.
-     */
-    private void cleanBaseDirsAndReadOffsetFiles() {
-      LOG.debug("Cleaning base directories for stores.");
-      FileUtil fileUtil = new FileUtil();
-      taskStores.forEach((storeName, storageEngine) -> {
-          if (!storageEngine.getStoreProperties().isLoggedStore()) {
-            File nonLoggedStorePartitionDir =
-                storageManagerUtil.getTaskStoreDir(nonLoggedStoreBaseDirectory, storeName, taskModel.getTaskName(),
-                    taskModel.getTaskMode());
-            LOG.info("Got non logged storage partition directory as " + nonLoggedStorePartitionDir.toPath().toString());
-
-            if (nonLoggedStorePartitionDir.exists()) {
-              LOG.info("Deleting non logged storage partition directory " + nonLoggedStorePartitionDir.toPath().toString());
-              fileUtil.rm(nonLoggedStorePartitionDir);
-            }
-          } else {
-            File loggedStorePartitionDir =
-                storageManagerUtil.getTaskStoreDir(loggedStoreBaseDirectory, storeName, taskModel.getTaskName(),
-                    taskModel.getTaskMode());
-            LOG.info("Got logged storage partition directory as " + loggedStorePartitionDir.toPath().toString());
-
-            // Delete the logged store if it is not valid.
-            if (!isLoggedStoreValid(storeName, loggedStorePartitionDir)) {
-              LOG.info("Deleting logged storage partition directory " + loggedStorePartitionDir.toPath().toString());
-              fileUtil.rm(loggedStorePartitionDir);
-            } else {
-
-              SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStreams.get(storeName), taskModel.getChangelogPartition());
-              Map<SystemStreamPartition, String> offset =
-                  storageManagerUtil.readOffsetFile(loggedStorePartitionDir, Collections.singleton(changelogSSP), false);
-              LOG.info("Read offset {} for the store {} from logged storage partition directory {}", offset, storeName, loggedStorePartitionDir);
-
-              if (offset.containsKey(changelogSSP)) {
-                fileOffsets.put(changelogSSP, offset.get(changelogSSP));
-              }
-            }
-          }
-        });
-    }
-
-    /**
-     * Directory loggedStoreDir associated with the logged store storeName is determined to be valid
-     * if all of the following conditions are true.
-     * a) If the store has to be persisted to disk.
-     * b) If there is a valid offset file associated with the logged store.
-     * c) If the logged store has not gone stale.
-     *
-     * @return true if the logged store is valid, false otherwise.
-     */
-    private boolean isLoggedStoreValid(String storeName, File loggedStoreDir) {
-      long changeLogDeleteRetentionInMs = new StorageConfig(config).getChangeLogDeleteRetentionInMs(storeName);
-      if (changelogSystemStreams.containsKey(storeName)) {
-        SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStreams.get(storeName), taskModel.getChangelogPartition());
-        return this.taskStores.get(storeName).getStoreProperties().isPersistedToDisk()
-            && storageManagerUtil.isOffsetFileValid(loggedStoreDir, Collections.singleton(changelogSSP), false)
-            && !storageManagerUtil.isStaleStore(loggedStoreDir, changeLogDeleteRetentionInMs, clock.currentTimeMillis(), false);
-      }
-
-      return false;
-    }
-
-    /**
-     * Create stores' base directories for logged-stores if they dont exist.
-     */
-    private void setupBaseDirs() {
-      LOG.debug("Setting up base directories for stores.");
-      taskStores.forEach((storeName, storageEngine) -> {
-          if (storageEngine.getStoreProperties().isLoggedStore()) {
-
-            File loggedStorePartitionDir =
-                storageManagerUtil.getTaskStoreDir(loggedStoreBaseDirectory, storeName, taskModel.getTaskName(), taskModel.getTaskMode());
-
-            LOG.info("Using logged storage partition directory: " + loggedStorePartitionDir.toPath().toString()
-                + " for store: " + storeName);
-
-            if (!loggedStorePartitionDir.exists()) {
-              loggedStorePartitionDir.mkdirs();
-            }
-          } else {
-            File nonLoggedStorePartitionDir =
-                storageManagerUtil.getTaskStoreDir(nonLoggedStoreBaseDirectory, storeName, taskModel.getTaskName(), taskModel.getTaskMode());
-            LOG.info("Using non logged storage partition directory: " + nonLoggedStorePartitionDir.toPath().toString()
-                + " for store: " + storeName);
-            nonLoggedStorePartitionDir.mkdirs();
-          }
-        });
-    }
-
-    /**
-     *  Validates each changelog system-stream with its respective SystemAdmin.
-     */
-    private void validateChangelogStreams() {
-      LOG.info("Validating change log streams: " + changelogSystemStreams);
-
-      for (SystemStream changelogSystemStream : changelogSystemStreams.values()) {
-        SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(changelogSystemStream.getSystem());
-        StreamSpec changelogSpec =
-            StreamSpec.createChangeLogStreamSpec(changelogSystemStream.getStream(), changelogSystemStream.getSystem(),
-                maxChangeLogStreamPartitions);
-
-        systemAdmin.validateStream(changelogSpec);
-      }
-    }
-
-    /**
-     * Get the oldest offset for each changelog SSP based on the stream's metadata (obtained from streamMetadataCache).
-     */
-    private void getOldestChangeLogOffsets() {
-
-      Map<SystemStream, SystemStreamMetadata> changeLogMetadata = JavaConverters.mapAsJavaMapConverter(
-          streamMetadataCache.getStreamMetadata(
-              JavaConverters.asScalaSetConverter(new HashSet<>(changelogSystemStreams.values())).asScala().toSet(),
-              false)).asJava();
-
-      LOG.info("Got change log stream metadata: {}", changeLogMetadata);
-
-      changeLogOldestOffsets =
-          getChangeLogOldestOffsetsForPartition(taskModel.getChangelogPartition(), changeLogMetadata);
-      LOG.info("Assigning oldest change log offsets for taskName {} : {}", taskModel.getTaskName(),
-          changeLogOldestOffsets);
-    }
-
-    /**
-     * Builds a map from SystemStreamPartition to oldest offset for changelogs.
-     */
-    private Map<SystemStream, String> getChangeLogOldestOffsetsForPartition(Partition partition,
-        Map<SystemStream, SystemStreamMetadata> inputStreamMetadata) {
-
-      Map<SystemStream, String> retVal = new HashMap<>();
-
-      // NOTE: do not use Collectors.Map because of https://bugs.openjdk.java.net/browse/JDK-8148463
-      inputStreamMetadata.entrySet()
-          .stream()
-          .filter(x -> x.getValue().getSystemStreamPartitionMetadata().get(partition) != null)
-          .forEach(e -> retVal.put(e.getKey(),
-              e.getValue().getSystemStreamPartitionMetadata().get(partition).getOldestOffset()));
-
-      return retVal;
-    }
-
-    /**
-     * Determines the starting offset for each store SSP (based on {@link #getStartingOffset(SystemStreamPartition, SystemAdmin)}) and
-     * registers it with the respective SystemConsumer for starting consumption.
-     */
-    private void registerStartingOffsets() {
-
-      for (Map.Entry<String, SystemStream> changelogSystemStreamEntry : changelogSystemStreams.entrySet()) {
-        SystemStreamPartition systemStreamPartition =
-            new SystemStreamPartition(changelogSystemStreamEntry.getValue(), taskModel.getChangelogPartition());
-        SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(changelogSystemStreamEntry.getValue().getSystem());
-        SystemConsumer systemConsumer = storeConsumers.get(changelogSystemStreamEntry.getKey());
-
-        String offset = getStartingOffset(systemStreamPartition, systemAdmin);
-
-        if (offset != null) {
-          LOG.info("Registering change log consumer with offset " + offset + " for %" + systemStreamPartition);
-          systemConsumer.register(systemStreamPartition, offset);
-        } else {
-          LOG.info("Skipping change log restoration for {} because stream appears to be empty (offset was null).",
-              systemStreamPartition);
-          taskStoresToRestore.remove(changelogSystemStreamEntry.getKey());
-        }
-      }
-    }
-
-    /**
-     * Returns the offset with which the changelog consumer should be initialized for the given SystemStreamPartition.
-     *
-     * If a file offset exists, it represents the last changelog offset which is also reflected in the on-disk state.
-     * In that case, we use the next offset after the file offset, as long as it is newer than the oldest offset
-     * currently available in the stream.
-     *
-     * If there isn't a file offset or it's older than the oldest available offset, we simply start with the oldest.
-     *
-     * @param systemStreamPartition  the changelog partition for which the offset is needed.
-     * @param systemAdmin                  the [[SystemAdmin]] for the changelog.
-     * @return the offset to from which the changelog consumer should be initialized.
-     */
-    private String getStartingOffset(SystemStreamPartition systemStreamPartition, SystemAdmin systemAdmin) {
-      String fileOffset = fileOffsets.get(systemStreamPartition);
-
-      // NOTE: changeLogOldestOffsets may contain a null-offset for the given SSP (signifying an empty stream)
-      // therefore, we need to differentiate that from the case where the offset is simply missing
-      if (!changeLogOldestOffsets.containsKey(systemStreamPartition.getSystemStream())) {
-        throw new SamzaException("Missing a change log offset for " + systemStreamPartition);
-      }
-
-      String oldestOffset = changeLogOldestOffsets.get(systemStreamPartition.getSystemStream());
-      return storageManagerUtil.getStartingOffset(systemStreamPartition, systemAdmin, fileOffset, oldestOffset);
-    }
-
-
-    /**
-     * Restore each store in taskStoresToRestore sequentially
-     */
-    public void restoreStores() {
-      LOG.debug("Restoring stores for task: {}", taskModel.getTaskName());
-
-      for (String storeName : taskStoresToRestore) {
-        SystemConsumer systemConsumer = storeConsumers.get(storeName);
-        SystemStream systemStream = changelogSystemStreams.get(storeName);
-        SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(systemStream.getSystem());
-
-        // TODO HIGH pmaheshw: use actual changelog topic newest offset instead of trimEnabled flag
-        ChangelogSSPIterator changelogSSPIterator = new ChangelogSSPIterator(systemConsumer,
-            new SystemStreamPartition(systemStream, taskModel.getChangelogPartition()), null, systemAdmin, false);
-
-        taskStores.get(storeName).restore(changelogSSPIterator);
-      }
-    }
-
-    /**
-     * Stop all stores.
-     */
-    public void stop() {
-      this.taskStores.values().forEach(storageEngine -> {
-          storageEngine.stop();
-        });
-    }
-
-    /**
-     * Stop only persistent stores. In case of certain stores and store mode (such as RocksDB), this
-     * can invoke compaction.
-     */
-    public void stopPersistentStores() {
-
-      Map<String, StorageEngine> persistentStores = this.taskStores.entrySet().stream().filter(e -> {
-          return e.getValue().getStoreProperties().isPersistedToDisk();
-        }).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
-
-      persistentStores.forEach((storeName, storageEngine) -> {
-          storageEngine.stop();
-          this.taskStores.remove(storeName);
-        });
-      LOG.info("Stopped persistent stores {}", persistentStores);
-    }
-  }
 }
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/NonTransactionalStateTaskStorageManager.scala b/samza-core/src/main/scala/org/apache/samza/storage/NonTransactionalStateTaskStorageManager.scala
new file mode 100644
index 0000000..9a81cb6
--- /dev/null
+++ b/samza-core/src/main/scala/org/apache/samza/storage/NonTransactionalStateTaskStorageManager.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage
+
+import java.io._
+
+import com.google.common.annotations.VisibleForTesting
+import com.google.common.collect.ImmutableSet
+import org.apache.samza.container.TaskName
+import org.apache.samza.job.model.TaskMode
+import org.apache.samza.system._
+import org.apache.samza.util.Logging
+import org.apache.samza.util.ScalaJavaUtil.JavaOptionals
+import org.apache.samza.{Partition, SamzaException}
+
+import scala.collection.JavaConverters._
+
+/**
+ * Manages all the storage engines for a given task when transactional state is disabled.
+ *
+ * On flush, the stores are flushed and the newest offset of each store changelog SSP is written to the OFFSET
+ * file of the corresponding persisted store. No checkpoints are created: checkpoint returns null and
+ * removeOldCheckpoints is a no-op.
+ *
+ * @param taskName                the task whose stores are managed
+ * @param containerStorageManager provides (and stops) the task's storage engines
+ * @param storeChangelogs         changelog system stream of each logged store, keyed by store name
+ * @param systemAdmins            used to fetch the newest offset of each changelog SSP
+ * @param loggedStoreBaseDir      base directory under which the logged stores' OFFSET files are written
+ * @param partition               the changelog partition of this task
+ */
+class NonTransactionalStateTaskStorageManager(
+  taskName: TaskName,
+  containerStorageManager: ContainerStorageManager,
+  storeChangelogs: Map[String, SystemStream] = Map(),
+  systemAdmins: SystemAdmins,
+  loggedStoreBaseDir: File = new File(System.getProperty("user.dir"), "state"),
+  partition: Partition) extends Logging with TaskStorageManager {
+
+  private val storageManagerUtil = new StorageManagerUtil
+
+  // OFFSET files are only written for the stores in this map (see writeChangelogOffsetFiles).
+  private val persistedStores = containerStorageManager.getAllStores(taskName).asScala
+    .filter { case (_, storageEngine) => storageEngine.getStoreProperties.isPersistedToDisk }
+
+  override def getStore(storeName: String): Option[StorageEngine] =
+    JavaOptionals.toRichOptional(containerStorageManager.getStore(taskName, storeName)).toOption
+
+  /**
+   * Flushes all stores of this task, then writes the newest changelog offsets to the stores' OFFSET files.
+   *
+   * @return the newest offset of each store changelog SSP of this task (None if the SSP is empty)
+   * @throws SamzaException if fetching a newest offset or writing an OFFSET file fails
+   */
+  override def flush(): Map[SystemStreamPartition, Option[String]] = {
+    debug("Flushing stores.")
+    containerStorageManager.getAllStores(taskName).asScala.values.foreach(_.flush)
+    val newestChangelogSSPOffsets = getNewestChangelogSSPOffsets()
+    writeChangelogOffsetFiles(newestChangelogSSPOffsets)
+    newestChangelogSSPOffsets
+  }
+
+  /**
+   * Non-transactional state does not create store checkpoints.
+   *
+   * @return always null
+   */
+  override def checkpoint(newestChangelogOffsets: Map[SystemStreamPartition, Option[String]]): String = null
+
+  /** No-op, since [[checkpoint]] never creates a checkpoint. */
+  override def removeOldCheckpoints(checkpointId: String): Unit = {}
+
+  /** Stops the stores via the container storage manager. */
+  @VisibleForTesting
+  override def stop(): Unit = {
+    debug("Stopping stores.")
+    containerStorageManager.stopStores()
+  }
+
+  /**
+   * Returns the newest offset for each store changelog SSP for this task.
+   *
+   * @return A map of changelog SSPs for this task to their newest offset (or None if ssp is empty)
+   * @throws SamzaException if there was an error fetching newest offset for any SSP
+   */
+  private def getNewestChangelogSSPOffsets(): Map[SystemStreamPartition, Option[String]] = {
+    storeChangelogs
+      .map { case (storeName, systemStream) =>
+        debug("Fetching newest offset for taskName %s store %s changelog %s" format (taskName, storeName, systemStream))
+        val ssp = new SystemStreamPartition(systemStream.getSystem, systemStream.getStream, partition)
+        val systemAdmin = systemAdmins.getSystemAdmin(systemStream.getSystem)
+
+        try {
+          val sspMetadataOption = Option(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp)).get(ssp))
+
+          // newest offset == null implies topic is empty
+          val newestOffsetOption = sspMetadataOption.flatMap(sspMetadata => Option(sspMetadata.getNewestOffset))
+          newestOffsetOption.foreach(newestOffset =>
+            debug("Got newest offset %s for taskName %s store %s changelog %s" format(newestOffset, taskName, storeName, systemStream)))
+
+          (ssp, newestOffsetOption)
+        } catch {
+          case e: Exception =>
+            throw new SamzaException("Error getting newest changelog offset for taskName %s store %s changelog %s."
+              format(taskName, storeName, systemStream), e)
+        }
+      }
+  }
+
+  /**
+   * Writes the newest changelog SSP offset of each persisted store to the OFFSET file in its store directory,
+   * or deletes the OFFSET file if the changelog SSP is empty.
+   * These files are used during container startup to determine whether there is any new information in the
+   * changelog that is not reflected in the on-disk copy of the store. If there is any delta, it is replayed
+   * from the changelog. This can happen, e.g., if the job was run on this host, then on another host, and then
+   * back on this host.
+   *
+   * @param newestChangelogOffsets newest offset of each store changelog SSP of this task
+   * @throws SamzaException if writing or deleting the OFFSET file of any store fails
+   */
+  private def writeChangelogOffsetFiles(newestChangelogOffsets: Map[SystemStreamPartition, Option[String]]): Unit = {
+    debug("Writing OFFSET files for logged persistent key value stores for task %s." format(taskName))
+
+    storeChangelogs
+      .filterKeys(storeName => persistedStores.contains(storeName))
+      .foreach { case (storeName, systemStream) =>
+        debug("Writing changelog offset for taskName %s store %s changelog %s." format(taskName, storeName, systemStream))
+        try {
+          val ssp = new SystemStreamPartition(systemStream.getSystem, systemStream.getStream, partition)
+          newestChangelogOffsets(ssp) match {
+            case Some(newestOffset) =>
+              debug("Storing newest offset %s for taskName %s store %s changelog %s in OFFSET file."
+                format(newestOffset, taskName, storeName, systemStream))
+              // TaskStorageManagers are only created for active tasks
+              val currentStoreDir = storageManagerUtil.getTaskStoreDir(loggedStoreBaseDir, storeName, taskName, TaskMode.Active)
+              storageManagerUtil.writeOffsetFile(currentStoreDir, Map(ssp -> newestOffset).asJava, false)
+              debug("Successfully stored offset %s for taskName %s store %s changelog %s in OFFSET file."
+                format(newestOffset, taskName, storeName, systemStream))
+            case None =>
+              // if newestOffset is null, then it means the changelog ssp is (or has become) empty. This could be
+              // either because the changelog topic was newly added, repartitioned, or manually deleted and recreated.
+              // No need to persist the offset file.
+              debug("Deleting OFFSET file for taskName %s store %s changelog ssp %s since the newestOffset is null."
+                format (taskName, storeName, ssp))
+              storageManagerUtil.deleteOffsetFile(loggedStoreBaseDir, storeName, taskName)
+          }
+        } catch {
+          case e: Exception =>
+            throw new SamzaException("Error storing offset for taskName %s store %s changelog %s."
+              format(taskName, storeName, systemStream), e)
+        }
+      }
+    debug("Done writing OFFSET files for logged persistent key value stores for task %s" format(taskName))
+  }
+}
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala b/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
index be089c6..c98461b 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
+++ b/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
@@ -19,95 +19,44 @@
 
 package org.apache.samza.storage
 
-import java.io._
+import org.apache.samza.system.SystemStreamPartition
 
-import com.google.common.annotations.VisibleForTesting
-import org.apache.samza.Partition
-import org.apache.samza.container.TaskName
-import org.apache.samza.job.model.TaskMode
-import org.apache.samza.system._
-import org.apache.samza.util.ScalaJavaUtil.JavaOptionals
-import org.apache.samza.util.{FileUtil, Logging}
+/**
+ * Manages all the storage engines for a given task.
+ *
+ * Instances are created by [[TaskStorageManagerFactory]], which picks the implementation based on whether
+ * transactional state is enabled.
+ */
+trait TaskStorageManager {
 
-import scala.collection.JavaConverters._
+  /**
+   * @return the storage engine named `storeName` for this task, or None if it is not available
+   */
+  def getStore(storeName: String): Option[StorageEngine]
 
-/**
- * Manage all the storage engines for a given task
- */
-class TaskStorageManager(
-  taskName: TaskName,
-  containerStorageManager: ContainerStorageManager,
-  changeLogSystemStreams: Map[String, SystemStream] = Map(),
-  sspMetadataCache: SSPMetadataCache,
-  loggedStoreBaseDir: File = new File(System.getProperty("user.dir"), "state"),
-  partition: Partition) extends Logging {
-
-  val persistedStores = containerStorageManager.getAllStores(taskName).asScala.filter{
-    case (storeName, storageEngine) => storageEngine.getStoreProperties.isPersistedToDisk
-  }
-
-  def getStore(storeName: String): Option[StorageEngine] =  JavaOptionals.toRichOptional(containerStorageManager.getStore(taskName, storeName)).toOption
-
-  def init {
-  }
-
-  def flush() {
-    debug("Flushing stores.")
-
-    containerStorageManager.getAllStores(taskName).asScala.values.foreach(_.flush)
-    flushChangelogOffsetFiles()
-  }
-
-  def stopStores() {
-    debug("Stopping stores.")
-    containerStorageManager.stopStores();
-  }
-
-  @VisibleForTesting
-  def stop() {
-    stopStores()
-
-    flushChangelogOffsetFiles()
-  }
-
-  /**
-    * Writes the offset files for each changelog to disk.
-    * These files are used when stores are restored from disk to determine whether
-    * there is any new information in the changelog that is not reflected in the disk
-    * copy of the store. If there is any delta, it is replayed from the changelog
-    * e.g. This can happen if the job was run on this host, then another
-    * host and back to this host.
-    */
-  private def flushChangelogOffsetFiles() {
-    debug("Persisting logged key value stores")
-
-    for ((storeName, systemStream) <- changeLogSystemStreams.filterKeys(storeName => persistedStores.contains(storeName))) {
-      debug("Fetching newest offset for store %s" format(storeName))
-      try {
-        val ssp = new SystemStreamPartition(systemStream.getSystem, systemStream.getStream, partition)
-        val sspMetadata = sspMetadataCache.getMetadata(ssp)
-        val newestOffset = if (sspMetadata == null) null else sspMetadata.getNewestOffset
-        debug("Got offset %s for store %s" format(newestOffset, storeName))
+  /**
+   * Flushes the stores of this task.
+   *
+   * @return the newest offset of each store changelog SSP of this task (None if the SSP is empty)
+   */
+  def flush(): Map[SystemStreamPartition, Option[String]]
 
-        val storageManagerUtil = new StorageManagerUtil()
-        if (newestOffset != null) {
-          debug("Storing offset for store in OFFSET file ")
+  /**
+   * Checkpoints the stores of this task. `newestChangelogOffsets` has the same shape as the result of [[flush]].
+   *
+   * @return an id for the checkpoint; may be null for implementations that do not create checkpoints
+   */
+  def checkpoint(newestChangelogOffsets: Map[SystemStreamPartition, Option[String]]): String
 
-          // TaskStorageManagers are only spun-up for active tasks
-          val currentStoreDir = storageManagerUtil.getTaskStoreDir(loggedStoreBaseDir, storeName, taskName, TaskMode.Active)
-          storageManagerUtil.writeOffsetFile(currentStoreDir, Map(ssp -> newestOffset).asJava, false)
-          debug("Successfully stored offset %s for store %s in OFFSET file " format(newestOffset, storeName))
-        } else {
-          //if newestOffset is null, then it means the store is (or has become) empty. No need to persist the offset file
-          storageManagerUtil.deleteOffsetFile(loggedStoreBaseDir, storeName, taskName);
-          debug("Not storing OFFSET file for taskName %s. Store %s backed by changelog topic: %s, partition: %s is empty. " format (taskName, storeName, systemStream.getStream, partition.getPartitionId))
-        }
-      } catch {
-        case e: Exception => error("Exception storing offset for store %s. Skipping." format(storeName), e)
-      }
+  /**
+   * Removes old checkpoints, given a checkpoint id as returned by [[checkpoint]].
+   * NOTE(review): which checkpoints count as "old" is implementation-defined.
+   */
+  def removeOldCheckpoints(checkpointId: String): Unit
 
-    }
+  /**
+   * Stops the stores.
+   */
+  def stop(): Unit
 
-    debug("Done persisting logged key value stores")
-  }
-}
+}
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManagerFactory.java b/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManagerFactory.java
new file mode 100644
index 0000000..571780b
--- /dev/null
+++ b/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManagerFactory.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import scala.collection.immutable.Map;
+
+import java.io.File;
+import org.apache.samza.Partition;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.job.model.TaskMode;
+import org.apache.samza.system.SystemAdmins;
+import org.apache.samza.system.SystemStream;
+
+/**
+ * Creates the {@link TaskStorageManager} for a single task.
+ *
+ * <p>The implementation is chosen by {@link TaskConfig#getTransactionalStateEnabled()}; only the
+ * non-transactional implementation is currently available.
+ */
+public class TaskStorageManagerFactory {
+  /**
+   * Creates a {@link TaskStorageManager} for the given task.
+   *
+   * @param taskName the task whose stores are managed
+   * @param containerStorageManager the storage manager of the task's container
+   * @param storeChangelogs changelog stream of each store (presumably keyed by store name)
+   * @param systemAdmins the system admins passed through to the storage manager
+   * @param loggedStoreBaseDir base directory of the logged stores
+   * @param changelogPartition the task's changelog partition
+   * @param config the job config; determines whether transactional state is enabled
+   * @param taskMode the task's mode; currently unused by the non-transactional implementation
+   * @return a {@link NonTransactionalStateTaskStorageManager}
+   * @throws UnsupportedOperationException if transactional state is enabled, since no implementation
+   *     for it exists yet
+   */
+  public static TaskStorageManager create(TaskName taskName, ContainerStorageManager containerStorageManager,
+      Map<String, SystemStream> storeChangelogs, SystemAdmins systemAdmins,
+      File loggedStoreBaseDir, Partition changelogPartition,
+      Config config, TaskMode taskMode) {
+    if (new TaskConfig(config).getTransactionalStateEnabled()) {
+      // Fail with an actionable message rather than a bare exception, so a misconfigured job is easy to diagnose.
+      throw new UnsupportedOperationException(String.format(
+          "Transactional state is not yet supported. Set %s to false to use non-transactional state.",
+          TaskConfig.TRANSACTIONAL_STATE_ENABLED));
+    }
+    return new NonTransactionalStateTaskStorageManager(taskName, containerStorageManager, storeChangelogs, systemAdmins,
+        loggedStoreBaseDir, changelogPartition);
+  }
+}
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
index 47eb2a4..78136bf 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
@@ -22,29 +22,25 @@ package org.apache.samza.container
 import java.util
 import java.util.concurrent.atomic.AtomicReference
 
+import org.apache.samza.Partition
 import org.apache.samza.config.{ClusterManagerConfig, Config, MapConfig}
-import org.apache.samza.context.{ApplicationContainerContext, ContainerContext, JobContext}
+import org.apache.samza.context.{ApplicationContainerContext, ContainerContext}
 import org.apache.samza.coordinator.JobModelManager
 import org.apache.samza.coordinator.server.{HttpServer, JobServlet}
 import org.apache.samza.job.model.{ContainerModel, JobModel, TaskModel}
-import org.apache.samza.metrics.{Gauge, MetricsReporter, Timer}
-import org.apache.samza.storage.{ContainerStorageManager, TaskStorageManager}
+import org.apache.samza.metrics.Gauge
+import org.apache.samza.storage.ContainerStorageManager
 import org.apache.samza.system._
-import org.apache.samza.task.{StreamTaskFactory, TaskFactory}
-import org.apache.samza.Partition
 import org.junit.Assert._
 import org.junit.{Before, Test}
 import org.mockito.Matchers.{any, notNull}
 import org.mockito.Mockito._
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-import org.mockito.{ArgumentCaptor, Mock, Mockito, MockitoAnnotations}
+import org.mockito.{Mock, Mockito, MockitoAnnotations}
 import org.scalatest.junit.AssertionsForJUnit
 import org.scalatest.mockito.MockitoSugar
 
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 
 class TestSamzaContainer extends AssertionsForJUnit with MockitoSugar {
   private val TASK_NAME = new TaskName("taskName")
diff --git a/samza-core/src/test/scala/org/apache/samza/storage/TestContainerStorageManager.java b/samza-core/src/test/scala/org/apache/samza/storage/TestContainerStorageManager.java
index 794fd89..43872e9 100644
--- a/samza-core/src/test/scala/org/apache/samza/storage/TestContainerStorageManager.java
+++ b/samza-core/src/test/scala/org/apache/samza/storage/TestContainerStorageManager.java
@@ -24,8 +24,11 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import org.apache.samza.Partition;
+import org.apache.samza.checkpoint.Checkpoint;
+import org.apache.samza.checkpoint.CheckpointManager;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
+import org.apache.samza.config.TaskConfig;
 import org.apache.samza.container.SamzaContainerMetrics;
 import org.apache.samza.container.TaskInstance;
 import org.apache.samza.container.TaskInstanceMetrics;
@@ -37,6 +40,7 @@ import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.Gauge;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.serializers.StringSerdeFactory;
+import org.apache.samza.system.SSPMetadataCache;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemAdmin;
 import org.apache.samza.system.SystemAdmins;
@@ -44,15 +48,17 @@ import org.apache.samza.system.SystemConsumer;
 import org.apache.samza.system.SystemFactory;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.util.SystemClock;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
-import org.mockito.Mockito;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 import scala.collection.JavaConverters;
 
+import static org.mockito.Mockito.*;
+
 
 public class TestContainerStorageManager {
 
@@ -79,16 +85,16 @@ public class TestContainerStorageManager {
    * @param taskname the desired taskname.
    */
   private void addMockedTask(String taskname, int changelogPartition) {
-    TaskInstance mockTaskInstance = Mockito.mock(TaskInstance.class);
-    Mockito.doAnswer(invocation -> {
+    TaskInstance mockTaskInstance = mock(TaskInstance.class);
+    doAnswer(invocation -> {
         return new TaskName(taskname);
       }).when(mockTaskInstance).taskName();
 
-    Gauge testGauge = Mockito.mock(Gauge.class);
+    Gauge testGauge = mock(Gauge.class);
     this.tasks.put(new TaskName(taskname),
         new TaskModel(new TaskName(taskname), new HashSet<>(), new Partition(changelogPartition)));
     this.taskRestoreMetricGauges.put(new TaskName(taskname), testGauge);
-    this.taskInstanceMetrics.put(new TaskName(taskname), Mockito.mock(TaskInstanceMetrics.class));
+    this.taskInstanceMetrics.put(new TaskName(taskname), mock(TaskInstanceMetrics.class));
   }
 
   /**
@@ -105,8 +111,8 @@ public class TestContainerStorageManager {
     addMockedTask("task 1", 1);
 
     // Mock container metrics
-    samzaContainerMetrics = Mockito.mock(SamzaContainerMetrics.class);
-    Mockito.when(samzaContainerMetrics.taskStoreRestorationMetrics()).thenReturn(taskRestoreMetricGauges);
+    samzaContainerMetrics = mock(SamzaContainerMetrics.class);
+    when(samzaContainerMetrics.taskStoreRestorationMetrics()).thenReturn(taskRestoreMetricGauges);
 
     // Create a map of test changeLogSSPs
     Map<String, SystemStream> changelogSystemStreams = new HashMap<>();
@@ -115,33 +121,35 @@ public class TestContainerStorageManager {
     // Create mocked storage engine factories
     Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories = new HashMap<>();
     StorageEngineFactory mockStorageEngineFactory =
-        (StorageEngineFactory<Object, Object>) Mockito.mock(StorageEngineFactory.class);
-    StorageEngine mockStorageEngine = Mockito.mock(StorageEngine.class);
-    Mockito.doAnswer(invocation -> {
+        (StorageEngineFactory<Object, Object>) mock(StorageEngineFactory.class);
+    StorageEngine mockStorageEngine = mock(StorageEngine.class);
+    when(mockStorageEngine.getStoreProperties())
+        .thenReturn(new StoreProperties.StorePropertiesBuilder().setLoggedStore(true).setPersistedToDisk(true).build());
+    doAnswer(invocation -> {
         return mockStorageEngine;
-      }).when(mockStorageEngineFactory).getStorageEngine(Mockito.anyString(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(),
-            Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any());
+      }).when(mockStorageEngineFactory).getStorageEngine(anyString(), any(), any(), any(), any(),
+            any(), any(), any(), any(), any());
 
     storageEngineFactories.put(STORE_NAME, mockStorageEngineFactory);
 
     // Add instrumentation to mocked storage engine, to record the number of store.restore() calls
-    Mockito.doAnswer(invocation -> {
+    doAnswer(invocation -> {
         storeRestoreCallCount++;
         return null;
-      }).when(mockStorageEngine).restore(Mockito.any());
+      }).when(mockStorageEngine).restore(any());
 
     // Set the mocked stores' properties to be persistent
-    Mockito.doAnswer(invocation -> {
+    doAnswer(invocation -> {
         return new StoreProperties.StorePropertiesBuilder().setLoggedStore(true).build();
       }).when(mockStorageEngine).getStoreProperties();
 
     // Mock and setup sysconsumers
-    SystemConsumer mockSystemConsumer = Mockito.mock(SystemConsumer.class);
-    Mockito.doAnswer(invocation -> {
+    SystemConsumer mockSystemConsumer = mock(SystemConsumer.class);
+    doAnswer(invocation -> {
         systemConsumerStartCount++;
         return null;
       }).when(mockSystemConsumer).start();
-    Mockito.doAnswer(invocation -> {
+    doAnswer(invocation -> {
         systemConsumerStopCount++;
         return null;
       }).when(mockSystemConsumer).stop();
@@ -150,11 +158,11 @@ public class TestContainerStorageManager {
     Map<String, SystemFactory> systemFactories = new HashMap<>();
 
     // Count the number of sysConsumers created
-    SystemFactory mockSystemFactory = Mockito.mock(SystemFactory.class);
-    Mockito.doAnswer(invocation -> {
+    SystemFactory mockSystemFactory = mock(SystemFactory.class);
+    doAnswer(invocation -> {
         this.systemConsumerCreationCount++;
         return mockSystemConsumer;
-      }).when(mockSystemFactory).getConsumer(Mockito.anyString(), Mockito.any(), Mockito.any());
+      }).when(mockSystemFactory).getConsumer(anyString(), any(), any());
 
     systemFactories.put(SYSTEM_NAME, mockSystemFactory);
 
@@ -163,22 +171,23 @@ public class TestContainerStorageManager {
     configMap.put("stores." + STORE_NAME + ".key.serde", "stringserde");
     configMap.put("stores." + STORE_NAME + ".msg.serde", "stringserde");
     configMap.put("serializers.registry.stringserde.class", StringSerdeFactory.class.getName());
+    configMap.put(TaskConfig.TRANSACTIONAL_STATE_RETAIN_EXISTING_STATE, "true");
     Config config = new MapConfig(configMap);
 
     Map<String, Serde<Object>> serdes = new HashMap<>();
-    serdes.put("stringserde", Mockito.mock(Serde.class));
+    serdes.put("stringserde", mock(Serde.class));
 
     // Create mocked system admins
-    SystemAdmin mockSystemAdmin = Mockito.mock(SystemAdmin.class);
-    Mockito.doAnswer(new Answer<Void>() {
+    SystemAdmin mockSystemAdmin = mock(SystemAdmin.class);
+    doAnswer(new Answer<Void>() {
         public Void answer(InvocationOnMock invocation) {
           Object[] args = invocation.getArguments();
           System.out.println("called with arguments: " + Arrays.toString(args));
           return null;
         }
-      }).when(mockSystemAdmin).validateStream(Mockito.any());
-    SystemAdmins mockSystemAdmins = Mockito.mock(SystemAdmins.class);
-    Mockito.when(mockSystemAdmins.getSystemAdmin("kafka")).thenReturn(mockSystemAdmin);
+      }).when(mockSystemAdmin).validateStream(any());
+    SystemAdmins mockSystemAdmins = mock(SystemAdmins.class);
+    when(mockSystemAdmins.getSystemAdmin("kafka")).thenReturn(mockSystemAdmin);
 
     // Create a mocked mockStreamMetadataCache
     SystemStreamMetadata.SystemStreamPartitionMetadata sspMetadata =
@@ -187,14 +196,21 @@ public class TestContainerStorageManager {
     partitionMetadata.put(new Partition(0), sspMetadata);
     partitionMetadata.put(new Partition(1), sspMetadata);
     SystemStreamMetadata systemStreamMetadata = new SystemStreamMetadata(STREAM_NAME, partitionMetadata);
-    StreamMetadataCache mockStreamMetadataCache = Mockito.mock(StreamMetadataCache.class);
+    StreamMetadataCache mockStreamMetadataCache = mock(StreamMetadataCache.class);
 
-    Mockito.when(mockStreamMetadataCache.
+    when(mockStreamMetadataCache.
         getStreamMetadata(JavaConverters.
             asScalaSetConverter(new HashSet<SystemStream>(changelogSystemStreams.values())).asScala().toSet(), false))
         .thenReturn(
             new scala.collection.immutable.Map.Map1(new SystemStream(SYSTEM_NAME, STREAM_NAME), systemStreamMetadata));
 
+    CheckpointManager checkpointManager = mock(CheckpointManager.class);
+    when(checkpointManager.readLastCheckpoint(any(TaskName.class))).thenReturn(new Checkpoint(new HashMap<>()));
+
+    SSPMetadataCache mockSSPMetadataCache = mock(SSPMetadataCache.class);
+    when(mockSSPMetadataCache.getMetadata(any(SystemStreamPartition.class)))
+        .thenReturn(new SystemStreamMetadata.SystemStreamPartitionMetadata("0", "10", "11"));
+
     // Reset the  expected number of sysConsumer create, start and stop calls, and store.restore() calls
     this.systemConsumerCreationCount = 0;
     this.systemConsumerStartCount = 0;
@@ -202,8 +218,11 @@ public class TestContainerStorageManager {
     this.storeRestoreCallCount = 0;
 
     // Create the container storage manager
-    this.containerStorageManager = new ContainerStorageManager(new ContainerModel("samza-container-test", tasks),
+    this.containerStorageManager = new ContainerStorageManager(
+        checkpointManager,
+        new ContainerModel("samza-container-test", tasks),
         mockStreamMetadataCache,
+        mockSSPMetadataCache,
         mockSystemAdmins,
         changelogSystemStreams,
         new HashMap<>(),
@@ -213,9 +232,9 @@ public class TestContainerStorageManager {
         config,
         taskInstanceMetrics,
         samzaContainerMetrics,
-        Mockito.mock(JobContext.class),
-        Mockito.mock(ContainerContext.class),
-        Mockito.mock(Map.class),
+        mock(JobContext.class),
+        mock(ContainerContext.class),
+        mock(Map.class),
         DEFAULT_LOGGED_STORE_BASE_DIR,
         DEFAULT_STORE_BASE_DIR,
         2,
@@ -230,7 +249,7 @@ public class TestContainerStorageManager {
 
     for (Gauge gauge : taskRestoreMetricGauges.values()) {
       Assert.assertTrue("Restoration time gauge value should be invoked atleast once",
-          Mockito.mockingDetails(gauge).getInvocations().size() >= 1);
+          mockingDetails(gauge).getInvocations().size() >= 1);
     }
 
     Assert.assertTrue("Store restore count should be 2 because there are 2 tasks", this.storeRestoreCallCount == 2);
diff --git a/samza-core/src/test/scala/org/apache/samza/storage/TestTaskStorageManager.scala b/samza-core/src/test/scala/org/apache/samza/storage/TestTaskStorageManager.scala
index 5451854..b5f70fc 100644
--- a/samza-core/src/test/scala/org/apache/samza/storage/TestTaskStorageManager.scala
+++ b/samza-core/src/test/scala/org/apache/samza/storage/TestTaskStorageManager.scala
@@ -48,6 +48,8 @@ import org.scalatest.mockito.MockitoSugar
 import scala.collection.JavaConverters._
 import scala.collection.immutable.HashMap
 import scala.collection.mutable
+import com.google.common.collect.{ImmutableMap, ImmutableSet}
+import org.apache.samza.checkpoint.{Checkpoint, CheckpointManager}
 
 /**
   * This test is parameterized on the offsetFileName and is run for both
@@ -56,7 +58,7 @@ import scala.collection.mutable
   * @param offsetFileName the name of the offset file.
   */
 @RunWith(value = classOf[Parameterized])
-class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
+class TestNonTransactionalStateTaskStorageManager(offsetFileName: String) extends MockitoSugar {
 
   val store = "store1"
   val loggedStore = "loggedStore1"
@@ -93,7 +95,7 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
     val ss = new SystemStream("kafka", getStreamName(loggedStore))
     val partition = new Partition(0)
     val ssp = new SystemStreamPartition(ss, partition)
-    val storeDirectory = storageManagerUtil.getTaskStoreDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, 
+    val storeDirectory = storageManagerUtil.getTaskStoreDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir,
       loggedStore, taskName, TaskMode.Active)
     val storeFile = new File(storeDirectory, "store.sst")
     val offsetFile = new File(storeDirectory, offsetFileName)
@@ -102,7 +104,6 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
 
     // Mock for StreamMetadataCache, SystemConsumer, SystemAdmin
     val mockStreamMetadataCache = mock[StreamMetadataCache]
-    val mockSSPMetadataCache = mock[SSPMetadataCache]
     val mockSystemConsumer = mock[SystemConsumer]
     val mockSystemAdmin = mock[SystemAdmin]
     val changelogSpec = StreamSpec.createChangeLogStreamSpec(getStreamName(loggedStore), "kafka", 1)
@@ -118,18 +119,15 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
       }
     })
     when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
-    when(mockSSPMetadataCache.getMetadata(ssp)).thenReturn(sspMetadata)
+    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata))
 
     var taskManager = new TaskStorageManagerBuilder()
       .addStore(loggedStore, mockStorageEngine, mockSystemConsumer)
       .setStreamMetadataCache(mockStreamMetadataCache)
-      .setSSPMetadataCache(mockSSPMetadataCache)
       .setSystemAdmin("kafka", mockSystemAdmin)
       .initializeContainerStorageManager()
       .build
 
-    taskManager.init
-
     assertTrue(storeFile.exists())
     assertFalse(offsetFile.exists())
     verify(mockSystemConsumer).register(ssp, "0")
@@ -139,12 +137,14 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
     assertTrue(offsetFile.exists())
     validateOffsetFileContents(offsetFile, "kafka.testStream-loggedStore1.0", "50")
 
-    // Test 3: Update sspMetadata before shutdown and verify that offset file is updated correctly
-    when(mockSSPMetadataCache.getMetadata(ssp)).thenReturn(new SystemStreamPartitionMetadata("0", "100", "101"))
+    // Test 3: Update sspMetadata before shutdown and verify that offset file is not updated
+    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp)))
+      .thenReturn(ImmutableMap.of(ssp, new SystemStreamPartitionMetadata("0", "100", "101")))
     taskManager.stop()
+    verify(mockStorageEngine, times(1)).flush() // only called once during Test 2.
     assertTrue(storeFile.exists())
     assertTrue(offsetFile.exists())
-    validateOffsetFileContents(offsetFile, "kafka.testStream-loggedStore1.0", "100")
+    validateOffsetFileContents(offsetFile, "kafka.testStream-loggedStore1.0", "50")
 
     // Test 4: Initialize again with an updated sspMetadata; Verify that it restores from the correct offset
     sspMetadata = new SystemStreamPartitionMetadata("0", "150", "151")
@@ -154,23 +154,21 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
       }
     })
     when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
-    when(mockSSPMetadataCache.getMetadata(ssp)).thenReturn(sspMetadata)
-    when(mockSystemAdmin.getOffsetsAfter(Map(ssp -> "100").asJava)).thenReturn(Map(ssp -> "101").asJava)
+    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp)))
+      .thenReturn(ImmutableMap.of(ssp, sspMetadata))
+    when(mockSystemAdmin.getOffsetsAfter(Map(ssp -> "50").asJava)).thenReturn(Map(ssp -> "51").asJava)
     Mockito.reset(mockSystemConsumer)
 
     taskManager = new TaskStorageManagerBuilder()
       .addStore(loggedStore, mockStorageEngine, mockSystemConsumer)
       .setStreamMetadataCache(mockStreamMetadataCache)
-      .setSSPMetadataCache(mockSSPMetadataCache)
       .setSystemAdmin("kafka", mockSystemAdmin)
       .initializeContainerStorageManager()
       .build
 
-    taskManager.init
-
     assertTrue(storeFile.exists())
     assertTrue(offsetFile.exists())
-    verify(mockSystemConsumer).register(ssp, "101")
+    verify(mockSystemConsumer).register(ssp, "51")
   }
 
   /**
@@ -206,6 +204,7 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
       }
     })
     when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
+    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata))
     var taskManager = new TaskStorageManagerBuilder()
       .addStore(store, mockStorageEngine, mockSystemConsumer)
       .setStreamMetadataCache(mockStreamMetadataCache)
@@ -213,8 +212,6 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
       .initializeContainerStorageManager()
       .build
 
-    taskManager.init
-
     // Verify that the store directory doesn't have ANY files
     assertTrue(storeDirectory.list().isEmpty)
     verify(mockSystemConsumer).register(ssp, "0")
@@ -230,6 +227,7 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
       }
     })
     when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
+    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata))
     taskManager.stop()
     assertTrue(storeDirectory.list().isEmpty)
 
@@ -248,8 +246,6 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
       .initializeContainerStorageManager()
       .build
 
-    taskManager.init
-
     assertTrue(storeDirectory.list().isEmpty)
     // second time to register; make sure it starts from beginning
     verify(mockSystemConsumer, times(2)).register(ssp, "0")
@@ -323,16 +319,16 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
   }
 
   @Test
-  def testStopCreatesOffsetFileForLoggedStore() {
+  def testStopDoesNotCreatesOffsetFileForLoggedStore() {
     val partition = new Partition(0)
 
     val storeDirectory = storageManagerUtil.getTaskStoreDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName, TaskMode.Active)
     val offsetFile = new File(storeDirectory, offsetFileName)
 
-    val sspMetadataCache = mock[SSPMetadataCache]
+    val ssp = new SystemStreamPartition("kafka", getStreamName(loggedStore), partition)
+    val mockSystemAdmin = mock[SystemAdmin]
     val sspMetadata = new SystemStreamPartitionMetadata("20", "100", "101")
-    when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", getStreamName(loggedStore), partition)))
-      .thenReturn(sspMetadata)
+    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata))
 
     var metadata = new SystemStreamMetadata(getStreamName(loggedStore), new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
       {
@@ -347,7 +343,7 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
     val taskStorageManager = new TaskStorageManagerBuilder()
       .addLoggedStore(loggedStore, true)
       .setStreamMetadataCache(mockStreamMetadataCache)
-      .setSSPMetadataCache(sspMetadataCache)
+      .setSystemAdmin("kafka", mockSystemAdmin)
       .setPartition(partition)
       .initializeContainerStorageManager()
       .build
@@ -356,8 +352,7 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
     taskStorageManager.stop()
 
     //Check conditions
-    assertTrue("Offset file doesn't exist!", offsetFile.exists())
-    validateOffsetFileContents(offsetFile, "kafka.testStream-loggedStore1.0", "100")
+    assertFalse("Offset file doesn't exist!", offsetFile.exists())
   }
 
   /**
@@ -371,18 +366,19 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
     val anotherOffsetPath = new File(
       storageManagerUtil.getTaskStoreDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, store, taskName, TaskMode.Active) + File.separator + offsetFileName)
 
-    val sspMetadataCache = mock[SSPMetadataCache]
+    val ssp1 = new SystemStreamPartition("kafka", getStreamName(loggedStore), partition)
+    val ssp2 = new SystemStreamPartition("kafka", getStreamName(store), partition)
     val sspMetadata = new SystemStreamPartitionMetadata("20", "100", "101")
-    when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", getStreamName(loggedStore), partition)))
-      .thenReturn(sspMetadata)
-    when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", getStreamName(store), partition)))
-      .thenReturn(sspMetadata)
+
+    val mockSystemAdmin = mock[SystemAdmin]
+    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp1))).thenReturn(ImmutableMap.of(ssp1, sspMetadata))
+    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp2))).thenReturn(ImmutableMap.of(ssp2, sspMetadata))
 
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
       .addLoggedStore(loggedStore, true)
       .addStore(store, false)
-      .setSSPMetadataCache(sspMetadataCache)
+      .setSystemAdmin("kafka", mockSystemAdmin)
       .setStreamMetadataCache(createMockStreamMetadataCache("20", "100", "101"))
       .setPartition(partition)
       .initializeContainerStorageManager()
@@ -407,13 +403,13 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
 
     val offsetFilePath = new File(storageManagerUtil.getTaskStoreDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName, TaskMode.Active) + File.separator + offsetFileName)
 
-    val sspMetadataCache = mock[SSPMetadataCache]
+    val ssp = new SystemStreamPartition("kafka", getStreamName(loggedStore), partition)
     val sspMetadata = new SystemStreamPartitionMetadata("0", "100", "101")
-    when(sspMetadataCache.getMetadata(new SystemStreamPartition("kafka", getStreamName(loggedStore), partition)))
-      // first return some metadata
-      .thenReturn(sspMetadata)
-      // then return no metadata to trigger the delete
-      .thenReturn(null)
+    val nullSspMetadata = new SystemStreamPartitionMetadata(null, null, null)
+    val mockSystemAdmin = mock[SystemAdmin]
+    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp)))
+      .thenReturn(ImmutableMap.of(ssp, sspMetadata))
+      .thenReturn(ImmutableMap.of(ssp, nullSspMetadata))
 
     var metadata = new SystemStreamMetadata(getStreamName(loggedStore), new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
       {
@@ -427,7 +423,7 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
       .addLoggedStore(loggedStore, true)
-      .setSSPMetadataCache(sspMetadataCache)
+      .setSystemAdmin("kafka", mockSystemAdmin)
       .setStreamMetadataCache(mockStreamMetadataCache)
       .setPartition(partition)
       .initializeContainerStorageManager()
@@ -455,9 +451,10 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
     val offsetFilePath = new File(storageManagerUtil.getTaskStoreDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName, TaskMode.Active) + File.separator + offsetFileName)
     fileUtil.writeWithChecksum(offsetFilePath, "100")
 
-    val sspMetadataCache = mock[SSPMetadataCache]
     val sspMetadata = new SystemStreamPartitionMetadata("20", "139", "140")
-    when(sspMetadataCache.getMetadata(ssp)).thenReturn(sspMetadata)
+    val mockSystemAdmin = mock[SystemAdmin]
+    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata))
+
 
     var metadata = new SystemStreamMetadata(getStreamName(loggedStore), new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
       {
@@ -471,7 +468,7 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
       .addLoggedStore(loggedStore, true)
-      .setSSPMetadataCache(sspMetadataCache)
+      .setSystemAdmin("kafka", mockSystemAdmin)
       .setPartition(partition)
       .setStreamMetadataCache(mockStreamMetadataCache)
       .initializeContainerStorageManager()
@@ -485,7 +482,8 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
     validateOffsetFileContents(offsetFilePath, "kafka.testStream-loggedStore1.0", "139")
 
     // Flush again
-    when(sspMetadataCache.getMetadata(ssp)).thenReturn(new SystemStreamPartitionMetadata("20", "193", "194"))
+    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp)))
+      .thenReturn(ImmutableMap.of(ssp, new SystemStreamPartitionMetadata("20", "193", "194")))
 
     //Invoke test method
     taskStorageManager.flush()
@@ -526,7 +524,6 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
     //Build TaskStorageManager
     val taskStorageManager = new TaskStorageManagerBuilder()
       .addLoggedStore(loggedStore, true)
-      .setSSPMetadataCache(sspMetadataCache)
       .setPartition(partition)
       .setStreamMetadataCache(createMockStreamMetadataCache(null, null, null)) // null offsets for empty store
       .initializeContainerStorageManager()
@@ -713,8 +710,6 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
       .initializeContainerStorageManager()
       .build
 
-    taskManager.init
-
     verify(mockSystemConsumer).register(any(classOf[SystemStreamPartition]), anyString())
   }
 
@@ -763,7 +758,7 @@ class TestTaskStorageManager(offsetFileName: String) extends MockitoSugar {
   }
 }
 
-object TestTaskStorageManager {
+object TestNonTransactionalStateTaskStorageManager {
 
   @Parameters def parameters: util.Collection[Array[String]] = {
     val offsetFileNames = new util.ArrayList[Array[String]]()
@@ -784,7 +779,6 @@ class TaskStorageManagerBuilder extends MockitoSugar {
   var storeConsumers: Map[String, SystemConsumer] = Map()
   var changeLogSystemStreams: Map[String, SystemStream] = Map()
   var streamMetadataCache = mock[StreamMetadataCache]
-  var sspMetadataCache = mock[SSPMetadataCache]
   var partition: Partition = new Partition(0)
   var systemAdminsMap: Map[String, SystemAdmin] = Map("kafka" -> mock[SystemAdmin])
   var taskName: TaskName = new TaskName("testTask")
@@ -843,11 +837,6 @@ class TaskStorageManagerBuilder extends MockitoSugar {
     this
   }
 
-  def setSSPMetadataCache(cache: SSPMetadataCache) = {
-    sspMetadataCache = cache
-    this
-  }
-
   /**
     * This method creates and starts a {@link ContainerStorageManager}
     */
@@ -883,14 +872,22 @@ class TaskStorageManagerBuilder extends MockitoSugar {
       "stores.store1.key.serde" -> classOf[StringSerdeFactory].getCanonicalName,
       "stores.store1.msg.serde" -> classOf[StringSerdeFactory].getCanonicalName,
       "stores.loggedStore1.key.serde" -> classOf[StringSerdeFactory].getCanonicalName,
-      "stores.loggedStore1.msg.serde" -> classOf[StringSerdeFactory].getCanonicalName).asJava)
+      "stores.loggedStore1.msg.serde" -> classOf[StringSerdeFactory].getCanonicalName,
+      TaskConfig.TRANSACTIONAL_STATE_ENABLED -> "false").asJava)
 
     var mockSerdes: Map[String, Serde[AnyRef]] = HashMap[String, Serde[AnyRef]]((classOf[StringSerdeFactory].getCanonicalName, Mockito.mock(classOf[Serde[AnyRef]])))
 
+    val mockCheckpointManager = Mockito.mock(classOf[CheckpointManager])
+    when(mockCheckpointManager.readLastCheckpoint(any(classOf[TaskName])))
+      .thenReturn(new Checkpoint(new util.HashMap[SystemStreamPartition, String]()))
+
+    val mockSSPMetadataCache = Mockito.mock(classOf[SSPMetadataCache])
 
     containerStorageManager = new ContainerStorageManager(
+      mockCheckpointManager,
       containerModel,
       streamMetadataCache,
+      mockSSPMetadataCache,
       mockSystemAdmins,
       changeLogSystemStreams.asJava,
       Map[String, util.Set[SystemStream]]().asJava,
@@ -913,17 +910,17 @@ class TaskStorageManagerBuilder extends MockitoSugar {
 
 
 
-  def build: TaskStorageManager = {
+  def build: NonTransactionalStateTaskStorageManager = {
 
     if (containerStorageManager != null) {
       containerStorageManager.start()
     }
 
-    new TaskStorageManager(
+    new NonTransactionalStateTaskStorageManager(
       taskName = taskName,
       containerStorageManager = containerStorageManager,
-      changeLogSystemStreams = changeLogSystemStreams,
-      sspMetadataCache = sspMetadataCache,
+      storeChangelogs = changeLogSystemStreams,
+      systemAdmins = buildSystemAdmins(systemAdminsMap),
       loggedStoreBaseDir = loggedStoreBaseDir,
       partition = partition
     )
diff --git a/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala b/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
index 7ab5268..b1e4554 100644
--- a/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
+++ b/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
@@ -187,7 +187,7 @@ class KeyValueStorageEngine[K, V](
     info(restoredMessages + " entries trimmed for store: " + storeName + " in directory: " + storeDir.toString + ".")
 
     // flush the store and the changelog producer
-    flush() // TODO HIGH pmaheshw: Need a way to flush changelog producers. This only flushes the stores.
+    flush() // TODO HIGH pmaheshw SAMZA-2338: Need a way to flush changelog producers. This only flushes the stores.
   }
 
   def flush() = {