You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by pm...@apache.org on 2019/01/12 00:23:06 UTC

[2/2] samza git commit: SAMZA-2018: State restore improvements using RocksDB bulk load

SAMZA-2018: State restore improvements using RocksDB bulk load

This PR makes the following changes:
* Moves all the state-restore code from TaskStorageManager.scala to ContainerStorageManager (and its internal private java classes).
* Introduces a StoreMode enum and adds a StoreMode parameter to StorageEngineFactory.getStorageEngine.
* Changes RocksDB store creation to use that enum, and to use RocksDB's bulk-load option when creating a store in bulk-load mode.
* Changes the ContainerStorageManager to create stores in BulkLoad mode when restoring, then close such persistent changelogged stores and re-open them in ReadWrite mode.
* Adds tests for ContainerStorageManager and changes tests for TaskStorageManager accordingly.

Author: Ray Matharu <rm...@linkedin.com>
Author: rmatharu <40...@users.noreply.github.com>

Reviewers: Jagadish Venkatraman <vj...@gmail.com>

Closes #843 from rmatharu/refactoringCSM


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/61255666
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/61255666
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/61255666

Branch: refs/heads/master
Commit: 612556660f3b80e6b4110dee3ad2392f9786379e
Parents: a60da19
Author: Ray Matharu <rm...@linkedin.com>
Authored: Fri Jan 11 16:23:01 2019 -0800
Committer: Prateek Maheshwari <pm...@apache.org>
Committed: Fri Jan 11 16:23:01 2019 -0800

----------------------------------------------------------------------
 .../samza/storage/StorageEngineFactory.java     |  20 +-
 .../samza/storage/StorageManagerUtil.java       |  13 +
 .../apache/samza/storage/StorageRecovery.java   | 136 ++--
 .../apache/samza/container/SamzaContainer.scala | 127 ++--
 .../samza/storage/ContainerStorageManager.java  | 635 ++++++++++++++++++-
 .../samza/storage/TaskStorageManager.scala      | 195 +-----
 .../org/apache/samza/util/ScalaJavaUtil.scala   |  37 +-
 .../samza/storage/MockStorageEngineFactory.java |   2 +-
 .../samza/storage/TestStorageRecovery.java      |   9 +-
 .../storage/TestContainerStorageManager.java    | 209 ++++--
 .../samza/storage/TestTaskStorageManager.scala  | 303 ++++++---
 .../InMemoryKeyValueStorageEngineFactory.scala  |   3 +-
 .../samza/storage/kv/RocksDbKeyValueReader.java |   3 +-
 .../samza/storage/kv/RocksDbOptionsHelper.java  |   7 +-
 .../RocksDbKeyValueStorageEngineFactory.scala   |   5 +-
 .../samza/storage/kv/RocksDbKeyValueStore.scala |   8 +
 .../kv/BaseKeyValueStorageEngineFactory.scala   |   7 +-
 .../apache/samza/monitor/LocalStoreMonitor.java |   4 +-
 .../performance/TestKeyValuePerformance.scala   |   3 +-
 .../table/TestLocalTableWithSideInputs.java     |   3 +-
 20 files changed, 1190 insertions(+), 539 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/61255666/samza-api/src/main/java/org/apache/samza/storage/StorageEngineFactory.java
----------------------------------------------------------------------
diff --git a/samza-api/src/main/java/org/apache/samza/storage/StorageEngineFactory.java b/samza-api/src/main/java/org/apache/samza/storage/StorageEngineFactory.java
index 2425cf3..26d6e75 100644
--- a/samza-api/src/main/java/org/apache/samza/storage/StorageEngineFactory.java
+++ b/samza-api/src/main/java/org/apache/samza/storage/StorageEngineFactory.java
@@ -34,6 +34,23 @@ import org.apache.samza.task.MessageCollector;
 public interface StorageEngineFactory<K, V> {
 
   /**
+   * Enum to describe different modes a {@link StorageEngine} can be created in.
+   * The BulkLoad mode is used when bulk loading of data onto the store, e.g., store-restoration at Samza
+   * startup. In this mode, the underlying store will tailor itself for write-intensive ops -- tune its params,
+   * adapt its compaction behaviour, etc.
+   *
+   * The ReadWrite mode is used for normal read-write ops by the application.
+   */
+  enum StoreMode {
+    BulkLoad("bulk"), ReadWrite("rw");
+    public final String mode;
+
+    StoreMode(String mode) {
+      this.mode = mode;
+    }
+  }
+
+  /**
    * Create an instance of the given storage engine.
    *
    * @param storeName The name of the storage engine.
@@ -45,6 +62,7 @@ public interface StorageEngineFactory<K, V> {
    * @param changeLogSystemStreamPartition Samza stream partition from which to receive the changelog.
    * @param jobContext Information about the job in which the task is executing
    * @param containerContext Information about the container in which the task is executing.
+   * @param storeMode The mode in which the instance should be created in.
    * @return The storage engine instance.
    */
   StorageEngine getStorageEngine(
@@ -56,5 +74,5 @@ public interface StorageEngineFactory<K, V> {
     MetricsRegistry registry,
     SystemStreamPartition changeLogSystemStreamPartition,
     JobContext jobContext,
-    ContainerContext containerContext);
+    ContainerContext containerContext, StoreMode storeMode);
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/61255666/samza-core/src/main/java/org/apache/samza/storage/StorageManagerUtil.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/storage/StorageManagerUtil.java b/samza-core/src/main/java/org/apache/samza/storage/StorageManagerUtil.java
index e7301ea..2086aa4 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/StorageManagerUtil.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/StorageManagerUtil.java
@@ -21,6 +21,7 @@ package org.apache.samza.storage;
 
 import com.google.common.collect.ImmutableMap;
 import java.io.File;
+import org.apache.samza.container.TaskName;
 import org.apache.samza.system.SystemAdmin;
 import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.util.FileUtil;
@@ -139,4 +140,16 @@ public class StorageManagerUtil {
 
     return offset;
   }
+
+  /**
+   * Creates and returns a File pointing to the directory for the given store and task, given a particular base directory.
+   *
+   * @param storeBaseDir the base directory to use
+   * @param storeName the store name to use
+   * @param taskName the task name which is referencing the store
+   * @return the partition directory for the store
+   */
+  public static File getStorePartitionDir(File storeBaseDir, String storeName, TaskName taskName) {
+    return new File(storeBaseDir, (storeName + File.separator + taskName.toString()).replace(' ', '_'));
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/61255666/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java b/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
index 5442d6e..cf8338a 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/StorageRecovery.java
@@ -20,18 +20,15 @@
 package org.apache.samza.storage;
 
 import java.io.File;
-import java.time.Duration;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Map.Entry;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JavaStorageConfig;
 import org.apache.samza.config.JavaSystemConfig;
-import org.apache.samza.config.StorageConfig;
-import org.apache.samza.container.TaskName;
+import org.apache.samza.config.SerializerConfig;
+import org.apache.samza.container.SamzaContainerMetrics;
 import org.apache.samza.context.ContainerContext;
 import org.apache.samza.context.ContainerContextImpl;
 import org.apache.samza.context.JobContextImpl;
@@ -41,15 +38,12 @@ import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
-import org.apache.samza.serializers.ByteSerde;
 import org.apache.samza.serializers.Serde;
-import org.apache.samza.system.SSPMetadataCache;
+import org.apache.samza.serializers.SerdeFactory;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemAdmins;
-import org.apache.samza.system.SystemConsumer;
 import org.apache.samza.system.SystemFactory;
 import org.apache.samza.system.SystemStream;
-import org.apache.samza.system.SystemStreamPartition;
 import org.apache.samza.util.Clock;
 import org.apache.samza.util.CommandLine;
 import org.apache.samza.util.ScalaJavaUtil;
@@ -58,6 +52,8 @@ import org.apache.samza.util.SystemClock;
 import org.apache.samza.util.Util;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.Option;
+
 
 /**
  * Recovers the state storages from the changelog streams and store the storages
@@ -70,13 +66,13 @@ public class StorageRecovery extends CommandLine {
   private int maxPartitionNumber = 0;
   private File storeBaseDir = null;
   private HashMap<String, SystemStream> changeLogSystemStreams = new HashMap<>();
-  private HashMap<String, StorageEngineFactory<?, ?>> storageEngineFactories = new HashMap<>();
+  private HashMap<String, StorageEngineFactory<Object, Object>> storageEngineFactories = new HashMap<>();
   private Map<String, ContainerModel> containers = new HashMap<>();
-  private ContainerStorageManager containerStorageManager;
+  private Map<String, ContainerStorageManager> containerStorageManagers = new HashMap<>();
+
   private Logger log = LoggerFactory.getLogger(StorageRecovery.class);
   private SystemAdmins systemAdmins = null;
 
-
   /**
    * Construct the StorageRecovery
    *
@@ -101,7 +97,7 @@ public class StorageRecovery extends CommandLine {
     getContainerModels();
     getChangeLogSystemStreamsAndStorageFactories();
     getChangeLogMaxPartitionNumber();
-    getContainerStorageManager();
+    getContainerStorageManagers();
   }
 
   /**
@@ -113,16 +109,19 @@ public class StorageRecovery extends CommandLine {
     log.info("start recovering...");
 
     systemAdmins.start();
-    this.containerStorageManager.start();
-    this.containerStorageManager.shutdown();
+    this.containerStorageManagers.forEach((containerName, containerStorageManager) -> {
+        containerStorageManager.start();
+      });
+    this.containerStorageManagers.forEach((containerName, containerStorageManager) -> {
+        containerStorageManager.shutdown();
+      });
     systemAdmins.stop();
 
     log.info("successfully recovered in " + storeBaseDir.toString());
   }
 
   /**
-   * build the ContainerModels from job config file and put the results in the
-   * map
+   * Build ContainerModels from job config file and put the results in the containerModels map.
    */
   private void getContainerModels() {
     MetricsRegistryMap metricsRegistryMap = new MetricsRegistryMap();
@@ -137,7 +136,7 @@ public class StorageRecovery extends CommandLine {
   }
 
   /**
-   * get the changelog streams and the storage factories from the config file
+   * Get the changelog streams and the storage factories from the config file
    * and put them into the maps
    */
   private void getChangeLogSystemStreamsAndStorageFactories() {
@@ -165,24 +164,6 @@ public class StorageRecovery extends CommandLine {
   }
 
   /**
-   * get the SystemConsumers for the stores
-   */
-  private HashMap<String, SystemConsumer> getStoreConsumers() {
-    HashMap<String, SystemConsumer> storeConsumers = new HashMap<>();
-    Map<String, SystemFactory> systemFactories = new JavaSystemConfig(jobConfig).getSystemFactories();
-
-    for (Entry<String, SystemStream> entry : changeLogSystemStreams.entrySet()) {
-      String storeSystem = entry.getValue().getSystem();
-      if (!systemFactories.containsKey(storeSystem)) {
-        throw new SamzaException("Changelog system " + storeSystem + " for store " + entry.getKey() + " does not exist in the config.");
-      }
-      storeConsumers.put(entry.getKey(), systemFactories.get(storeSystem).getConsumer(storeSystem, jobConfig, new MetricsRegistryMap()));
-    }
-
-    return storeConsumers;
-  }
-
-  /**
    * get the max partition number of the changelog stream
    */
   private void getChangeLogMaxPartitionNumber() {
@@ -195,68 +176,49 @@ public class StorageRecovery extends CommandLine {
     maxPartitionNumber = maxPartitionId + 1;
   }
 
+  private Map<String, Serde<Object>> getSerdes() {
+    Map<String, Serde<Object>> serdeMap = new HashMap<>();
+    SerializerConfig serializerConfig = new SerializerConfig(jobConfig);
+
+    // Adding all serdes from factories
+    ScalaJavaUtil.toJavaCollection(serializerConfig.getSerdeNames())
+        .stream()
+        .forEach(serdeName -> {
+            Option<String> serdeClassName = serializerConfig.getSerdeClass(serdeName);
+
+            if (serdeClassName.isEmpty()) {
+              serdeClassName = Option.apply(SerializerConfig.getSerdeFactoryName(serdeName));
+            }
+
+            Serde serde = Util.getObj(serdeClassName.get(), SerdeFactory.class).getSerde(serdeName, serializerConfig);
+            serdeMap.put(serdeName, serde);
+          });
+
+    return serdeMap;
+  }
+
   /**
    * create one TaskStorageManager for each task. Add all of them to the
    * List<TaskStorageManager>
    */
-  @SuppressWarnings({ "unchecked", "rawtypes" })
-  private void getContainerStorageManager() {
+  @SuppressWarnings({"unchecked", "rawtypes"})
+  private void getContainerStorageManagers() {
     Clock clock = SystemClock.instance();
-    Map<TaskName, TaskStorageManager> taskStorageManagers = new HashMap<>();
-    HashMap<String, SystemConsumer> storeConsumers = getStoreConsumers();
     StreamMetadataCache streamMetadataCache = new StreamMetadataCache(systemAdmins, 5000, clock);
     // don't worry about prefetching for this; looks like the tool doesn't flush to offset files anyways
-    SSPMetadataCache sspMetadataCache =
-        new SSPMetadataCache(systemAdmins, Duration.ofSeconds(5), clock, Collections.emptySet());
+
+    Map<String, SystemFactory> systemFactories = new JavaSystemConfig(jobConfig).getSystemFactories();
 
     for (ContainerModel containerModel : containers.values()) {
-      HashMap<String, StorageEngine> taskStores = new HashMap<String, StorageEngine>();
       ContainerContext containerContext = new ContainerContextImpl(containerModel, new MetricsRegistryMap());
 
-      for (TaskModel taskModel : containerModel.getTasks().values()) {
-
-        for (Entry<String, StorageEngineFactory<?, ?>> entry : storageEngineFactories.entrySet()) {
-          String storeName = entry.getKey();
-
-          if (changeLogSystemStreams.containsKey(storeName)) {
-            SystemStreamPartition changeLogSystemStreamPartition = new SystemStreamPartition(changeLogSystemStreams.get(storeName),
-                taskModel.getChangelogPartition());
-            File storePartitionDir = TaskStorageManager.getStorePartitionDir(storeBaseDir, storeName, taskModel.getTaskName());
-
-            log.info("Got storage engine directory: " + storePartitionDir);
-
-            StorageEngine storageEngine = (entry.getValue()).getStorageEngine(
-                storeName,
-                storePartitionDir,
-                (Serde) new ByteSerde(),
-                (Serde) new ByteSerde(),
-                null,
-                new MetricsRegistryMap(),
-                changeLogSystemStreamPartition,
-                JobContextImpl.fromConfigWithDefaults(jobConfig),
-                containerContext);
-            taskStores.put(storeName, storageEngine);
-          }
-        }
-        TaskStorageManager taskStorageManager = new TaskStorageManager(
-            taskModel.getTaskName(),
-            ScalaJavaUtil.toScalaMap(taskStores),
-            ScalaJavaUtil.toScalaMap(storeConsumers),
-            ScalaJavaUtil.toScalaMap(changeLogSystemStreams),
-            maxPartitionNumber,
-            streamMetadataCache,
-            sspMetadataCache,
-            storeBaseDir,
-            storeBaseDir,
-            taskModel.getChangelogPartition(),
-            systemAdmins,
-            new StorageConfig(jobConfig).getChangeLogDeleteRetentionsInMs(),
-            new SystemClock());
-
-        taskStorageManagers.put(taskModel.getTaskName(), taskStorageManager);
-      }
+      ContainerStorageManager containerStorageManager =
+          new ContainerStorageManager(containerModel, streamMetadataCache, systemAdmins, changeLogSystemStreams,
+              storageEngineFactories, systemFactories, this.getSerdes(), jobConfig, new HashMap<>(),
+              new SamzaContainerMetrics(containerModel.getId(), new MetricsRegistryMap()),
+              JobContextImpl.fromConfigWithDefaults(jobConfig), containerContext, new HashMap<>(),
+              storeBaseDir, storeBaseDir, maxPartitionNumber, new SystemClock());
+      this.containerStorageManagers.put(containerModel.getId(), containerStorageManager);
     }
-
-    this.containerStorageManager = new ContainerStorageManager(taskStorageManagers, storeConsumers, null);
   }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/61255666/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index e38a451..ec7360a 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -26,7 +26,7 @@ import java.net.{URL, UnknownHostException}
 import java.nio.file.Path
 import java.time.Duration
 import java.util
-import java.util.Base64
+import java.util.{Base64}
 import java.util.concurrent.{ExecutorService, Executors, ScheduledExecutorService, TimeUnit}
 
 import com.google.common.annotations.VisibleForTesting
@@ -50,6 +50,7 @@ import org.apache.samza.metrics.{JmxServer, JvmMetrics, MetricsRegistryMap, Metr
 import org.apache.samza.serializers._
 import org.apache.samza.serializers.model.SamzaObjectMapper
 import org.apache.samza.startpoint.StartpointManager
+import org.apache.samza.storage.StorageEngineFactory.StoreMode
 import org.apache.samza.storage._
 import org.apache.samza.system._
 import org.apache.samza.system.chooser.{DefaultChooser, MessageChooserFactory, RoundRobinChooserFactory}
@@ -81,7 +82,9 @@ object SamzaContainer extends Logging {
         classOf[JobModel])
   }
 
-  // TODO: SAMZA-1701 SamzaContainer should not contain any logic related to store directories
+  /**
+    * If a base-directory was NOT explicitly provided in config, a default base directory is returned.
+    */
   def getNonLoggedStorageBaseDir(config: Config, defaultStoreBaseDir: File) = {
     config.getNonLoggedStorePath match {
       case Some(nonLoggedStorePath) =>
@@ -91,7 +94,10 @@ object SamzaContainer extends Logging {
     }
   }
 
-  // TODO: SAMZA-1701 SamzaContainer should not contain any logic related to store directories
+  /**
+    * If a base-directory was NOT explicitly provided in config or via an environment variable
+    * (see ShellCommandConfig.ENV_LOGGED_STORE_BASE_DIR), then a default base directory is returned.
+    */
   def getLoggedStorageBaseDir(config: Config, defaultStoreBaseDir: File) = {
     val defaultLoggedStorageBaseDir = config.getLoggedStorePath match {
       case Some(durableStorePath) =>
@@ -504,6 +510,7 @@ object SamzaContainer extends Logging {
       .map(_.getTaskName)
       .toSet
 
+    val taskModels = containerModel.getTasks.values.asScala
     val containerContext = new ContainerContextImpl(containerModel, samzaContainerMetrics.registry)
     val applicationContainerContextOption = applicationContainerContextFactoryOption
       .map(_.create(externalContextOption.orNull, jobContext, containerContext))
@@ -512,23 +519,37 @@ object SamzaContainer extends Logging {
 
     val timerExecutor = Executors.newSingleThreadScheduledExecutor
 
-    // We create a map of store SystemName to its respective SystemConsumer
-    val storeSystemConsumers: Map[String, SystemConsumer] = changeLogSystemStreams.mapValues {
-      case (changeLogSystemStream) => (changeLogSystemStream.getSystem)
-    }.values.toSet.map {
-      systemName: String =>
-        (systemName, systemFactories
-          .getOrElse(systemName,
-            throw new SamzaException("Changelog system %s exist in the config." format (systemName)))
-          .getConsumer(systemName, config, samzaContainerMetrics.registry))
-    }.toMap
+    var taskStorageManagers : Map[TaskName, TaskStorageManager] = Map()
 
-    info("Created store system consumers: %s" format storeSystemConsumers)
+    val taskInstanceMetrics: Map[TaskName, TaskInstanceMetrics] = taskModels.map(taskModel => {
+      (taskModel.getTaskName, new TaskInstanceMetrics("TaskName-%s" format taskModel.getTaskName))
+    }).toMap
 
-    var taskStorageManagers : Map[TaskName, TaskStorageManager] = Map()
+    val taskCollectors : Map[TaskName, TaskInstanceCollector] = taskModels.map(taskModel => {
+      (taskModel.getTaskName, new TaskInstanceCollector(producerMultiplexer, taskInstanceMetrics.get(taskModel.getTaskName).get))
+    }).toMap
+
+    val defaultStoreBaseDir = new File(System.getProperty("user.dir"), "state")
+    info("Got default storage engine base directory: %s" format defaultStoreBaseDir)
+
+    val nonLoggedStorageBaseDir = getNonLoggedStorageBaseDir(config, defaultStoreBaseDir)
+    info("Got base directory for non logged data stores: %s" format nonLoggedStorageBaseDir)
+
+    val loggedStorageBaseDir = getLoggedStorageBaseDir(config, defaultStoreBaseDir)
+    info("Got base directory for logged data stores: %s" format loggedStorageBaseDir)
+
+    val sideInputStorageEngineFactories = storageEngineFactories.filterKeys(storeName => sideInputStoresToSystemStreams.contains(storeName))
+    val nonSideInputStorageEngineFactories = (storageEngineFactories.toSet diff sideInputStorageEngineFactories.toSet).toMap
+
+    val containerStorageManager = new ContainerStorageManager(containerModel, streamMetadataCache, systemAdmins,
+      changeLogSystemStreams.asJava, nonSideInputStorageEngineFactories.asJava, systemFactories.asJava, serdes.asJava, config,
+      taskInstanceMetrics.asJava, samzaContainerMetrics, jobContext, containerContext, taskCollectors.asJava,
+      loggedStorageBaseDir, nonLoggedStorageBaseDir, maxChangeLogStreamPartitions, new SystemClock)
+
+    storeWatchPaths.addAll(containerStorageManager.getStoreDirectoryPaths)
 
     // Create taskInstances
-    val taskInstances: Map[TaskName, TaskInstance] = containerModel.getTasks.values.asScala.map(taskModel => {
+    val taskInstances: Map[TaskName, TaskInstance] = taskModels.map(taskModel => {
       debug("Setting up task instance: %s" format taskModel)
 
       val taskName = taskModel.getTaskName
@@ -538,32 +559,7 @@ object SamzaContainer extends Logging {
         case tf: StreamTaskFactory => tf.asInstanceOf[StreamTaskFactory].createInstance()
       }
 
-      val taskInstanceMetrics = new TaskInstanceMetrics("TaskName-%s" format taskName)
-
-      val collector = new TaskInstanceCollector(producerMultiplexer, taskInstanceMetrics)
-
-      // Re-use the storeConsumers, stored in storeSystemConsumers
-      val storeConsumers : Map[String, SystemConsumer] = changeLogSystemStreams
-        .map {
-          case (storeName, changeLogSystemStream) =>
-            val systemConsumer = storeSystemConsumers.get(changeLogSystemStream.getSystem).get
-            samzaContainerMetrics.addStoreRestorationGauge(taskName, storeName)
-            (storeName, systemConsumer)
-        }
-
-      info("Got store consumers: %s" format storeConsumers)
-
-      val defaultStoreBaseDir = new File(System.getProperty("user.dir"), "state")
-      info("Got default storage engine base directory: %s" format defaultStoreBaseDir)
-
-      val nonLoggedStorageBaseDir = getNonLoggedStorageBaseDir(config, defaultStoreBaseDir)
-      info("Got base directory for non logged data stores: %s" format nonLoggedStorageBaseDir)
-
-      val loggedStorageBaseDir = getLoggedStorageBaseDir(config, defaultStoreBaseDir)
-      info("Got base directory for logged data stores: %s" format loggedStorageBaseDir)
-
-      val taskStores = storageEngineFactories
-        .map {
+      val sideInputStores = sideInputStorageEngineFactories.map {
           case (storeName, storageEngineFactory) =>
             val changeLogSystemStreamPartition = if (changeLogSystemStreams.contains(storeName)) {
               new SystemStreamPartition(changeLogSystemStreams(storeName), taskModel.getChangelogPartition)
@@ -583,37 +579,29 @@ object SamzaContainer extends Logging {
               case _ => null
             }
 
-            // We use the logged storage base directory for change logged and side input stores since side input stores
+            // We use the logged storage base directory for side input stores since side input stores
             // dont have changelog configured.
-            val storeDir = if (changeLogSystemStreamPartition != null || sideInputStoresToSystemStreams.contains(storeName)) {
-              TaskStorageManager.getStorePartitionDir(loggedStorageBaseDir, storeName, taskName)
-            } else {
-              TaskStorageManager.getStorePartitionDir(nonLoggedStorageBaseDir, storeName, taskName)
-            }
-
+            val storeDir = StorageManagerUtil.getStorePartitionDir(loggedStorageBaseDir, storeName, taskName)
             storeWatchPaths.add(storeDir.toPath)
 
-            val storageEngine = storageEngineFactory.getStorageEngine(
+            val sideInputStorageEngine = storageEngineFactory.getStorageEngine(
               storeName,
               storeDir,
               keySerde,
               msgSerde,
-              collector,
-              taskInstanceMetrics.registry,
+              taskCollectors.get(taskName).get,
+              taskInstanceMetrics.get(taskName).get.registry,
               changeLogSystemStreamPartition,
               jobContext,
-              containerContext)
-            (storeName, storageEngine)
+              containerContext, StoreMode.ReadWrite)
+            (storeName, sideInputStorageEngine)
         }
 
-      info("Got task stores: %s" format taskStores)
+      info("Got side input stores: %s" format sideInputStores)
 
       val taskSSPs = taskModel.getSystemStreamPartitions.asScala.toSet
       info("Got task SSPs: %s" format taskSSPs)
 
-      val (sideInputStores, nonSideInputStores) =
-        taskStores.partition { case (storeName, _) => sideInputStoresToSystemStreams.contains(storeName)}
-
       val sideInputStoresToSSPs = sideInputStoresToSystemStreams.mapValues(sideInputSystemStreams =>
         taskSSPs.filter(ssp => sideInputSystemStreams.contains(ssp.getSystemStream)).asJava)
 
@@ -627,24 +615,17 @@ object SamzaContainer extends Logging {
               (storeName, SerdeUtils.deserialize("Side Inputs Processor", serializedInstance)))
             .orElse(config.getSideInputsProcessorFactory(storeName).map(factoryClassName =>
               (storeName, Util.getObj(factoryClassName, classOf[SideInputsProcessorFactory])
-                .getSideInputsProcessor(config, taskInstanceMetrics.registry))))
+                .getSideInputsProcessor(config, taskInstanceMetrics.get(taskName).get.registry))))
             .get
         }).toMap
 
       val storageManager = new TaskStorageManager(
         taskName = taskName,
-        taskStores = nonSideInputStores,
-        storeConsumers = storeConsumers,
+        containerStorageManager = containerStorageManager,
         changeLogSystemStreams = changeLogSystemStreams,
-        maxChangeLogStreamPartitions,
-        streamMetadataCache = streamMetadataCache,
         sspMetadataCache = changelogSSPMetadataCache,
-        nonLoggedStoreBaseDir = nonLoggedStorageBaseDir,
         loggedStoreBaseDir = loggedStorageBaseDir,
-        partition = taskModel.getChangelogPartition,
-        systemAdmins = systemAdmins,
-        new StorageConfig(config).getChangeLogDeleteRetentionsInMs,
-        new SystemClock)
+        partition = taskModel.getChangelogPartition)
 
       var sideInputStorageManager: TaskSideInputStorageManager = null
       if (sideInputStores.nonEmpty) {
@@ -667,16 +648,16 @@ object SamzaContainer extends Logging {
       def createTaskInstance(task: Any): TaskInstance = new TaskInstance(
           task = task,
           taskModel = taskModel,
-          metrics = taskInstanceMetrics,
+          metrics = taskInstanceMetrics.get(taskName).get,
           systemAdmins = systemAdmins,
           consumerMultiplexer = consumerMultiplexer,
-          collector = collector,
+          collector = taskCollectors.get(taskName).get,
           offsetManager = offsetManager,
           storageManager = storageManager,
           tableManager = tableManager,
           reporters = reporters,
           systemStreamPartitions = taskSSPs,
-          exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics, config),
+          exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics.get(taskName).get, config),
           jobModel = jobModel,
           streamMetadataCache = streamMetadataCache,
           timerExecutor = timerExecutor,
@@ -694,10 +675,6 @@ object SamzaContainer extends Logging {
       (taskName, taskInstance)
     }).toMap
 
-
-    val containerStorageManager = new ContainerStorageManager(taskStorageManagers.asJava, storeSystemConsumers.asJava,
-      samzaContainerMetrics)
-
     val maxThrottlingDelayMs = config.getLong("container.disk.quota.delay.max.ms", TimeUnit.SECONDS.toMillis(1))
 
     val runLoop = RunLoopFactory.createRunLoop(

http://git-wip-us.apache.org/repos/asf/samza/blob/61255666/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
index c39d6e7..b896267 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
+++ b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
@@ -18,21 +18,53 @@
  */
 package org.apache.samza.storage;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import java.io.File;
+import java.nio.file.Path;
 import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.StorageConfig;
 import org.apache.samza.container.SamzaContainerMetrics;
+import org.apache.samza.container.TaskInstanceMetrics;
 import org.apache.samza.container.TaskName;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.JobContext;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.Gauge;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.serializers.Serde;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.StreamSpec;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemAdmins;
 import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.system.SystemFactory;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.system.SystemStreamPartitionIterator;
+import org.apache.samza.task.TaskInstanceCollector;
+import org.apache.samza.util.Clock;
+import org.apache.samza.util.FileUtil;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import scala.collection.JavaConverters;
 
 
 /**
@@ -43,53 +75,279 @@ import org.slf4j.LoggerFactory;
  *  a) performing all container-level actions for restore such as, initializing and shutting down
  *  taskStorage managers, starting, registering and stopping consumers, etc.
  *
- *  b) performing individual taskStorageManager restores in parallel.
+ *  b) performing individual task stores' restores in parallel.
  *
  */
 public class ContainerStorageManager {
-
   private static final Logger LOG = LoggerFactory.getLogger(ContainerStorageManager.class);
-  private final Map<TaskName, TaskStorageManager> taskStorageManagers;
+  private static final String RESTORE_THREAD_NAME = "Samza Restore Thread-%d";
+
+  /** Maps containing relevant per-task objects */
+  private final Map<TaskName, Map<String, StorageEngine>> taskStores;
+  private final Map<TaskName, TaskRestoreManager> taskRestoreManagers;
+  private final Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics;
+  private final Map<TaskName, TaskInstanceCollector> taskInstanceCollectors;
+
+  private final Map<String, SystemConsumer> systemConsumers; // Mapping from storeSystemNames to SystemConsumers
+  private final Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories; // Map of storageEngineFactories indexed by store name
+  private final Map<String, SystemStream> changelogSystemStreams; // Map of changelog system-streams indexed by store name
+  private final Map<String, Serde<Object>> serdes; // Map of Serde objects indexed by serde name (specified in config)
+
+  private final StreamMetadataCache streamMetadataCache;
   private final SamzaContainerMetrics samzaContainerMetrics;
 
-  // Mapping of from storeSystemNames to SystemConsumers
-  private final Map<String, SystemConsumer> systemConsumers;
+  /* Parameters required to re-create taskStores post-restoration */
+  private final ContainerModel containerModel;
+  private final JobContext jobContext;
+  private final ContainerContext containerContext;
+
+  private final File loggedStoreBaseDirectory;
+  private final File nonLoggedStoreBaseDirectory;
+  private final Set<Path> storeDirectoryPaths; // the set of store directory paths, used by SamzaContainer to initialize its disk-space-monitor
 
-  // Size of thread-pool to be used for parallel restores
   private final int parallelRestoreThreadPoolSize;
+  private final int maxChangeLogStreamPartitions; // The partition count of each changelog-stream topic. This is used for validating changelog streams before restoring.
 
-  // Naming convention to be used for restore threads
-  private static final String RESTORE_THREAD_NAME = "Samza Restore Thread-%d";
+  private final Config config;
+
+  public ContainerStorageManager(ContainerModel containerModel, StreamMetadataCache streamMetadataCache,
+      SystemAdmins systemAdmins, Map<String, SystemStream> changelogSystemStreams,
+      Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, SystemFactory> systemFactories, Map<String, Serde<Object>> serdes, Config config,
+      Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics, SamzaContainerMetrics samzaContainerMetrics,
+      JobContext jobContext, ContainerContext containerContext,
+      Map<TaskName, TaskInstanceCollector> taskInstanceCollectors, File loggedStoreBaseDirectory,
+      File nonLoggedStoreBaseDirectory, int maxChangeLogStreamPartitions, Clock clock) {
+
+    this.containerModel = containerModel;
+    this.changelogSystemStreams = changelogSystemStreams;
+    this.storageEngineFactories = storageEngineFactories;
+    this.serdes = serdes;
+    this.loggedStoreBaseDirectory = loggedStoreBaseDirectory;
+    this.nonLoggedStoreBaseDirectory = nonLoggedStoreBaseDirectory;
+
+    // set the config
+    this.config = config;
 
-  public ContainerStorageManager(Map<TaskName, TaskStorageManager> taskStorageManagers,
-      Map<String, SystemConsumer> systemConsumers, SamzaContainerMetrics samzaContainerMetrics) {
-    this.taskStorageManagers = taskStorageManagers;
-    this.systemConsumers = systemConsumers;
+    this.taskInstanceMetrics = taskInstanceMetrics;
+
+    // Setting the metrics registry
     this.samzaContainerMetrics = samzaContainerMetrics;
 
-    // Setting thread pool size equal to the number of tasks
-    this.parallelRestoreThreadPoolSize = taskStorageManagers.size();
+    this.jobContext = jobContext;
+    this.containerContext = containerContext;
+
+    this.taskInstanceCollectors = taskInstanceCollectors;
+
+    // initializing the set of store directory paths
+    this.storeDirectoryPaths = new HashSet<>();
+
+    // Setting the restore thread pool size equal to the number of taskInstances
+    this.parallelRestoreThreadPoolSize = containerModel.getTasks().size();
+
+    this.maxChangeLogStreamPartitions = maxChangeLogStreamPartitions;
+    this.streamMetadataCache = streamMetadataCache;
+
+    // create taskStores for all tasks in the containerModel and each store in storageEngineFactories
+    this.taskStores = createTaskStores(containerModel, jobContext, containerContext, storageEngineFactories, changelogSystemStreams,
+        serdes, taskInstanceMetrics, taskInstanceCollectors, StorageEngineFactory.StoreMode.BulkLoad);
+
+    // create system consumers (1 per store system)
+    this.systemConsumers = createStoreConsumers(changelogSystemStreams, systemFactories, config, this.samzaContainerMetrics.registry());
+
+    // creating task restore managers
+    this.taskRestoreManagers = createTaskRestoreManagers(systemAdmins, clock);
+  }
+
+  /**
+   *  Creates SystemConsumer objects for store restoration, creating one consumer per system.
+   */
+  private static Map<String, SystemConsumer> createStoreConsumers(Map<String, SystemStream> changelogSystemStreams,
+      Map<String, SystemFactory> systemFactories, Config config, MetricsRegistry registry) {
+    // Determine the set of systems being used across all stores
+    Set<String> storeSystems =
+        changelogSystemStreams.values().stream().map(SystemStream::getSystem).collect(Collectors.toSet());
+
+    // Create one consumer for each system in use, map with one entry for each such system
+    Map<String, SystemConsumer> storeSystemConsumers = new HashMap<>();
+
+    // Map of each storeName to its respective systemConsumer
+    Map<String, SystemConsumer> storeConsumers = new HashMap<>();
+
+    // Iterate over the list of storeSystems and create one sysConsumer per system
+    for (String storeSystemName : storeSystems) {
+      SystemFactory systemFactory = systemFactories.get(storeSystemName);
+      if (systemFactory == null) {
+        throw new SamzaException("Changelog system " + storeSystemName + " does not exist in config");
+      }
+      storeSystemConsumers.put(storeSystemName,
+          systemFactory.getConsumer(storeSystemName, config, registry));
+    }
+
+    // Populate the map of storeName to its relevant systemConsumer
+    for (String storeName : changelogSystemStreams.keySet()) {
+      storeConsumers.put(storeName, storeSystemConsumers.get(changelogSystemStreams.get(storeName).getSystem()));
+    }
+
+    return storeConsumers;
+  }
+
+  private Map<TaskName, TaskRestoreManager> createTaskRestoreManagers(SystemAdmins systemAdmins, Clock clock) {
+    Map<TaskName, TaskRestoreManager> taskRestoreManagers = new HashMap<>();
+    containerModel.getTasks().forEach((taskName, taskModel) ->
+      taskRestoreManagers.put(taskName, new TaskRestoreManager(taskModel, changelogSystemStreams, taskStores.get(taskName), systemAdmins, clock)));
+    return taskRestoreManagers;
+  }
+
+  /**
+   * Create taskStores with the given store mode for all stores in storageEngineFactories.
+   */
+  private Map<TaskName, Map<String, StorageEngine>> createTaskStores(ContainerModel containerModel, JobContext jobContext, ContainerContext containerContext,
+      Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, SystemStream> changelogSystemStreams, Map<String, Serde<Object>> serdes,
+      Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics,
+      Map<TaskName, TaskInstanceCollector> taskInstanceCollectors, StorageEngineFactory.StoreMode storeMode) {
+
+    Map<TaskName, Map<String, StorageEngine>> taskStores = new HashMap<>();
+
+    // iterate over each task in the containerModel, and each store in storageEngineFactories
+    for (Map.Entry<TaskName, TaskModel> task : containerModel.getTasks().entrySet()) {
+      TaskName taskName = task.getKey();
+      TaskModel taskModel = task.getValue();
+
+      if (!taskStores.containsKey(taskName)) {
+        taskStores.put(taskName, new HashMap<>());
+      }
+
+      for (String storeName : storageEngineFactories.keySet()) {
+
+        StorageEngine storageEngine =
+            createStore(storeName, taskName, taskModel, jobContext, containerContext, storageEngineFactories,
+                changelogSystemStreams, serdes, taskInstanceMetrics, taskInstanceCollectors, storeMode);
+
+        // add created store to map
+        taskStores.get(taskName).put(storeName, storageEngine);
+
+        LOG.info("Created store {} for task {}", storeName, taskName);
+      }
+    }
+
+    return taskStores;
+  }
+
+  /**
+   * Recreate all persistent stores in ReadWrite mode.
+   *
+   */
+  private void recreatePersistentTaskStoresInReadWriteMode(ContainerModel containerModel, JobContext jobContext,
+      ContainerContext containerContext, Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, SystemStream> changelogSystemStreams, Map<String, Serde<Object>> serdes,
+      Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics,
+      Map<TaskName, TaskInstanceCollector> taskInstanceCollectors) {
+
+    // iterate over each task and each storeName
+    for (Map.Entry<TaskName, TaskModel> task : containerModel.getTasks().entrySet()) {
+      TaskName taskName = task.getKey();
+      TaskModel taskModel = task.getValue();
+
+      for (String storeName : storageEngineFactories.keySet()) {
+
+        // if this store has been already created in the taskStores, then re-create and overwrite it only if it is a persistentStore
+        if (this.taskStores.get(taskName).containsKey(storeName) && this.taskStores.get(taskName)
+            .get(storeName)
+            .getStoreProperties()
+            .isPersistedToDisk()) {
+
+          StorageEngine storageEngine =
+              createStore(storeName, taskName, taskModel, jobContext, containerContext, storageEngineFactories,
+                  changelogSystemStreams, serdes, taskInstanceMetrics, taskInstanceCollectors,
+                  StorageEngineFactory.StoreMode.ReadWrite);
+
+          // add created store to map
+          this.taskStores.get(taskName).put(storeName, storageEngine);
+
+          LOG.info("Re-created store {} in read-write mode for task {} because it a persistent store", storeName, taskName);
+        } else {
+
+          LOG.info("Skipping re-creation of store {} for task {} because it a non-persistent store", storeName, taskName);
+        }
+      }
+    }
+  }
+
+  /**
+   * Method to instantiate a StorageEngine with the given parameters, and populate the storeDirectory paths (used to monitor
+   * disk space).
+   */
+  private StorageEngine createStore(String storeName, TaskName taskName, TaskModel taskModel, JobContext jobContext,
+      ContainerContext containerContext, Map<String, StorageEngineFactory<Object, Object>> storageEngineFactories,
+      Map<String, SystemStream> changelogSystemStreams, Map<String, Serde<Object>> serdes,
+      Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics,
+      Map<TaskName, TaskInstanceCollector> taskInstanceCollectors, StorageEngineFactory.StoreMode storeMode) {
+
+    StorageConfig storageConfig = new StorageConfig(config);
+
+    SystemStreamPartition changeLogSystemStreamPartition =
+        (changelogSystemStreams.containsKey(storeName)) ? new SystemStreamPartition(
+            changelogSystemStreams.get(storeName), taskModel.getChangelogPartition()) : null;
+
+    // Use the logged-store-base-directory for change logged stores, and non-logged-store-base-dir for non logged stores
+    File storeDirectory =
+        (changeLogSystemStreamPartition != null) ? StorageManagerUtil.getStorePartitionDir(this.loggedStoreBaseDirectory,
+            storeName, taskName)
+            : StorageManagerUtil.getStorePartitionDir(this.nonLoggedStoreBaseDirectory, storeName, taskName);
+    this.storeDirectoryPaths.add(storeDirectory.toPath());
+
+    if (storageConfig.getStorageKeySerde(storeName).isEmpty()) {
+      throw new SamzaException("No key serde defined for store: " + storeName);
+    }
+
+    Serde keySerde = serdes.get(storageConfig.getStorageKeySerde(storeName).get());
+    if (keySerde == null) {
+      throw new SamzaException(
+          "StorageKeySerde: No class defined for serde: " + storageConfig.getStorageKeySerde(storeName));
+    }
+
+    if (storageConfig.getStorageMsgSerde(storeName).isEmpty()) {
+      throw new SamzaException("No msg serde defined for store: " + storeName);
+    }
+
+    Serde messageSerde = serdes.get(storageConfig.getStorageMsgSerde(storeName).get());
+    if (messageSerde == null) {
+      throw new SamzaException(
+          "StorageMsgSerde: No class defined for serde: " + storageConfig.getStorageMsgSerde(storeName));
+    }
+
+    // if taskInstanceMetrics are specified use those for store metrics,
+    // otherwise (in case of StorageRecovery) use a blank MetricsRegistryMap
+    MetricsRegistry storeMetricsRegistry =
+        taskInstanceMetrics.get(taskName) != null ? taskInstanceMetrics.get(taskName).registry()
+            : new MetricsRegistryMap();
+
+    return storageEngineFactories.get(storeName)
+        .getStorageEngine(storeName, storeDirectory, keySerde, messageSerde, taskInstanceCollectors.get(taskName),
+            storeMetricsRegistry, changeLogSystemStreamPartition, jobContext, containerContext, storeMode);
   }
 
   public void start() throws SamzaException {
     LOG.info("Restore started");
 
     // initialize each TaskStorageManager
-    this.taskStorageManagers.values().forEach(taskStorageManager -> taskStorageManager.init());
+    this.taskRestoreManagers.values().forEach(taskStorageManager -> taskStorageManager.initialize());
 
     // Start consumers
     this.systemConsumers.values().forEach(systemConsumer -> systemConsumer.start());
 
-    // Create a thread pool for parallel restores
+    // Create a thread pool for parallel restores (and stopping of persistent stores)
     ExecutorService executorService = Executors.newFixedThreadPool(this.parallelRestoreThreadPoolSize,
         new ThreadFactoryBuilder().setNameFormat(RESTORE_THREAD_NAME).build());
 
-    List<Future> taskRestoreFutures = new ArrayList<>(this.taskStorageManagers.entrySet().size());
+    List<Future> taskRestoreFutures = new ArrayList<>(this.taskRestoreManagers.entrySet().size());
 
     // Submit restore callable for each taskInstance
-    this.taskStorageManagers.forEach((taskInstance, taskStorageManager) -> {
-        taskRestoreFutures.add(
-            executorService.submit(new TaskRestoreCallable(this.samzaContainerMetrics, taskInstance, taskStorageManager)));
+    this.taskRestoreManagers.forEach((taskInstance, taskRestoreManager) -> {
+        taskRestoreFutures.add(executorService.submit(
+            new TaskRestoreCallable(this.samzaContainerMetrics, taskInstance, taskRestoreManager)));
       });
 
     // loop-over the future list to wait for each thread to finish, catch any exceptions during restore and throw
@@ -108,14 +366,50 @@ public class ContainerStorageManager {
     // Stop consumers
     this.systemConsumers.values().forEach(systemConsumer -> systemConsumer.stop());
 
+    // Now re-create persistent stores in read-write mode, leave non-persistent stores as-is
+    recreatePersistentTaskStoresInReadWriteMode(this.containerModel, jobContext, containerContext,
+        storageEngineFactories, changelogSystemStreams, serdes, taskInstanceMetrics, taskInstanceCollectors);
+
     LOG.info("Restore complete");
   }
 
+  /**
+   * Get the {@link StorageEngine} instance with a given name for a given task.
+   * @param taskName the task name for which the storage engine is desired.
+   * @param storeName the desired store's name.
+   * @return the task store.
+   */
+  public Optional<StorageEngine> getStore(TaskName taskName, String storeName) {
+    return Optional.ofNullable(this.taskStores.get(taskName).get(storeName));
+  }
+
+  /**
+   *  Get all {@link StorageEngine} instances used by a given task.
+   * @param taskName  the task name, all stores for which are desired.
+   * @return map of stores used by the given task, indexed by store name
+   */
+  public Map<String, StorageEngine> getAllStores(TaskName taskName) {
+    return this.taskStores.get(taskName);
+  }
+
+  /**
+   * Set of directory paths for all stores restored by this {@link ContainerStorageManager}.
+   * @return the set of all store directory paths
+   */
+  public Set<Path> getStoreDirectoryPaths() {
+    return this.storeDirectoryPaths;
+  }
+
+  @VisibleForTesting
+  public void stopStores() {
+    this.taskStores.forEach((taskName, storeMap) -> storeMap.forEach((storeName, store) -> store.stop()));
+  }
+
   public void shutdown() {
-    this.taskStorageManagers.forEach((taskInstance, taskStorageManager) -> {
-        if (taskStorageManager != null) {
+    this.taskRestoreManagers.forEach((taskInstance, taskRestoreManager) -> {
+        if (taskRestoreManager != null) {
           LOG.debug("Shutting down task storage manager for taskName: {} ", taskInstance);
-          taskStorageManager.stop();
+          taskRestoreManager.stop();
         } else {
           LOG.debug("Skipping task storage manager shutdown for taskName: {}", taskInstance);
         }
@@ -124,27 +418,36 @@ public class ContainerStorageManager {
     LOG.info("Shutdown complete");
   }
 
-  /** Callable for performing the restoreStores on a taskStorage manager and emitting task-restoration metric.
+  /**
+   * Callable for performing the restoreStores on a task restore manager and emitting the task-restoration metric.
+   * After restoration, all persistent stores are stopped (which will invoke compaction in case of certain persistent
+   * stores that were opened in bulk-load mode).
+   * Performing stop here parallelizes this compaction, which is a time-intensive operation.
    *
    */
   private class TaskRestoreCallable implements Callable<Void> {
 
     private TaskName taskName;
-    private TaskStorageManager taskStorageManager;
+    private TaskRestoreManager taskRestoreManager;
     private SamzaContainerMetrics samzaContainerMetrics;
 
     public TaskRestoreCallable(SamzaContainerMetrics samzaContainerMetrics, TaskName taskName,
-        TaskStorageManager taskStorageManager) {
+        TaskRestoreManager taskRestoreManager) {
       this.samzaContainerMetrics = samzaContainerMetrics;
       this.taskName = taskName;
-      this.taskStorageManager = taskStorageManager;
+      this.taskRestoreManager = taskRestoreManager;
     }
 
     @Override
     public Void call() {
       long startTime = System.currentTimeMillis();
       LOG.info("Starting stores in task instance {}", this.taskName.getTaskName());
-      taskStorageManager.restoreStores();
+      taskRestoreManager.restoreStores();
+
+      // Stop all persistent stores after restoring. Certain persistent stores opened in BulkLoad mode are compacted
+      // on stop, so paralleling stop() also parallelizes their compaction (a time-intensive operation).
+      taskRestoreManager.stopPersistentStores();
+
       long timeToRestore = System.currentTimeMillis() - startTime;
 
       if (this.samzaContainerMetrics != null) {
@@ -157,4 +460,280 @@ public class ContainerStorageManager {
       return null;
     }
   }
+
+  /**
+   * Restore logic for all stores of a task including directory cleanup, setup, changelogSSP validation, registering
+   * with the respective consumer, restoring stores, and stopping stores.
+   */
+  private class TaskRestoreManager {
+
+    private final static String OFFSET_FILE_NAME = "OFFSET";
+    private final Map<String, StorageEngine> taskStores; // Map of all StorageEngines for this task indexed by store name
+    private final Set<String> taskStoresToRestore;
+    // Set of store names which need to be restored by consuming using system-consumers (see registerStartingOffsets)
+
+    private final TaskModel taskModel;
+    private final Clock clock; // Clock value used to validate base-directories for staleness. See isLoggedStoreValid.
+    private Map<SystemStream, String> changeLogOldestOffsets; // Map of changelog oldest known offsets
+    private final Map<SystemStreamPartition, String> fileOffsets; // Map of offsets read from offset file indexed by changelog SSP
+    private final Map<String, SystemStream> changelogSystemStreams; // Map of change log system-streams indexed by store name
+    private final SystemAdmins systemAdmins;
+
+    public TaskRestoreManager(TaskModel taskModel, Map<String, SystemStream> changelogSystemStreams,
+        Map<String, StorageEngine> taskStores, SystemAdmins systemAdmins, Clock clock) {
+      this.taskStores = taskStores;
+      this.taskModel = taskModel;
+      this.clock = clock;
+      this.changelogSystemStreams = changelogSystemStreams;
+      this.systemAdmins = systemAdmins;
+      this.fileOffsets = new HashMap<>();
+      this.taskStoresToRestore = this.taskStores.entrySet().stream()
+          .filter(x -> x.getValue().getStoreProperties().isLoggedStore())
+          .map(x -> x.getKey()).collect(Collectors.toSet());
+    }
+
+    /**
+     * Cleans up and sets up store directories, validates changeLog SSPs for all stores of this task,
+     * and registers SSPs with the respective consumers.
+     */
+    public void initialize() {
+      cleanBaseDirsAndReadOffsetFiles();
+      setupBaseDirs();
+      validateChangelogStreams();
+      getOldestChangeLogOffsets();
+      registerStartingOffsets();
+    }
+
+    /**
+     * For each store for this task,
+     * a. Deletes the corresponding non-logged-store base dir.
+     * b. Deletes the logged-store-base-dir if it is not valid. See {@link #isLoggedStoreValid} for validation semantics.
+     * c. If the logged-store-base-dir is valid, this method reads the offset file and stores each offset.
+     */
+    private void cleanBaseDirsAndReadOffsetFiles() {
+      LOG.debug("Cleaning base directories for stores.");
+
+      taskStores.keySet().forEach(storeName -> {
+          File nonLoggedStorePartitionDir =
+              StorageManagerUtil.getStorePartitionDir(nonLoggedStoreBaseDirectory, storeName, taskModel.getTaskName());
+          LOG.info("Got non logged storage partition directory as " + nonLoggedStorePartitionDir.toPath().toString());
+
+          if (nonLoggedStorePartitionDir.exists()) {
+            LOG.info("Deleting non logged storage partition directory " + nonLoggedStorePartitionDir.toPath().toString());
+            FileUtil.rm(nonLoggedStorePartitionDir);
+          }
+
+          File loggedStorePartitionDir =
+              StorageManagerUtil.getStorePartitionDir(loggedStoreBaseDirectory, storeName, taskModel.getTaskName());
+          LOG.info("Got logged storage partition directory as " + loggedStorePartitionDir.toPath().toString());
+
+          // Delete the logged store if it is not valid.
+          if (!isLoggedStoreValid(storeName, loggedStorePartitionDir)) {
+            LOG.info("Deleting logged storage partition directory " + loggedStorePartitionDir.toPath().toString());
+            FileUtil.rm(loggedStorePartitionDir);
+          } else {
+            String offset = StorageManagerUtil.readOffsetFile(loggedStorePartitionDir, OFFSET_FILE_NAME);
+            LOG.info("Read offset " + offset + " for the store " + storeName + " from logged storage partition directory "
+                + loggedStorePartitionDir);
+
+            if (offset != null) {
+              fileOffsets.put(
+                  new SystemStreamPartition(changelogSystemStreams.get(storeName), taskModel.getChangelogPartition()),
+                  offset);
+            }
+          }
+        });
+    }
+
+    /**
+     * Directory loggedStoreDir associated with the logged store storeName is determined to be valid
+     * if all of the following conditions are true.
+     * a) If the store has to be persisted to disk.
+     * b) If there is a valid offset file associated with the logged store.
+     * c) If the logged store has not gone stale.
+     *
+     * @return true if the logged store is valid, false otherwise.
+     */
+    private boolean isLoggedStoreValid(String storeName, File loggedStoreDir) {
+      long changeLogDeleteRetentionInMs = StorageConfig.DEFAULT_CHANGELOG_DELETE_RETENTION_MS();
+
+      if (new StorageConfig(config).getChangeLogDeleteRetentionsInMs().get(storeName).isDefined()) {
+        changeLogDeleteRetentionInMs =
+            (long) new StorageConfig(config).getChangeLogDeleteRetentionsInMs().get(storeName).get();
+      }
+
+      return this.taskStores.get(storeName).getStoreProperties().isPersistedToDisk()
+          && StorageManagerUtil.isOffsetFileValid(loggedStoreDir, OFFSET_FILE_NAME) && !StorageManagerUtil.isStaleStore(
+          loggedStoreDir, OFFSET_FILE_NAME, changeLogDeleteRetentionInMs, clock.currentTimeMillis());
+    }
+
+    /**
+     * Create stores' base directories for logged-stores if they don't exist.
+     */
+    private void setupBaseDirs() {
+      LOG.debug("Setting up base directories for stores.");
+      taskStores.forEach((storeName, storageEngine) -> {
+          if (storageEngine.getStoreProperties().isLoggedStore()) {
+
+            File loggedStorePartitionDir =
+                StorageManagerUtil.getStorePartitionDir(loggedStoreBaseDirectory, storeName, taskModel.getTaskName());
+
+            LOG.info("Using logged storage partition directory: " + loggedStorePartitionDir.toPath().toString()
+                + " for store: " + storeName);
+
+            if (!loggedStorePartitionDir.exists()) {
+              loggedStorePartitionDir.mkdirs();
+            }
+          } else {
+            File nonLoggedStorePartitionDir =
+                StorageManagerUtil.getStorePartitionDir(nonLoggedStoreBaseDirectory, storeName, taskModel.getTaskName());
+            LOG.info("Using non logged storage partition directory: " + nonLoggedStorePartitionDir.toPath().toString()
+                + " for store: " + storeName);
+            nonLoggedStorePartitionDir.mkdirs();
+          }
+        });
+    }
+
+    /**
+     *  Validates each changelog system-stream with its respective SystemAdmin.
+     */
+    private void validateChangelogStreams() {
+      LOG.info("Validating change log streams: " + changelogSystemStreams);
+
+      for (SystemStream changelogSystemStream : changelogSystemStreams.values()) {
+        SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(changelogSystemStream.getSystem());
+        StreamSpec changelogSpec =
+            StreamSpec.createChangeLogStreamSpec(changelogSystemStream.getStream(), changelogSystemStream.getSystem(),
+                maxChangeLogStreamPartitions);
+
+        systemAdmin.validateStream(changelogSpec);
+      }
+    }
+
+    /**
+     * Get the oldest offset for each changelog SSP based on the stream's metadata (obtained from streamMetadataCache).
+     */
+    private void getOldestChangeLogOffsets() {
+
+      Map<SystemStream, SystemStreamMetadata> changeLogMetadata = JavaConverters.mapAsJavaMapConverter(
+          streamMetadataCache.getStreamMetadata(
+              JavaConverters.asScalaSetConverter(new HashSet<>(changelogSystemStreams.values())).asScala().toSet(),
+              false)).asJava();
+
+      LOG.info("Got change log stream metadata: {}", changeLogMetadata);
+
+      changeLogOldestOffsets =
+          getChangeLogOldestOffsetsForPartition(taskModel.getChangelogPartition(), changeLogMetadata);
+      LOG.info("Assigning oldest change log offsets for taskName {} : {}", taskModel.getTaskName(),
+          changeLogOldestOffsets);
+    }
+
+    /**
+     * Builds a map from SystemStream to oldest offset for changelogs.
+     */
+    private Map<SystemStream, String> getChangeLogOldestOffsetsForPartition(Partition partition,
+        Map<SystemStream, SystemStreamMetadata> inputStreamMetadata) {
+
+      Map<SystemStream, String> retVal = new HashMap<>();
+
+      // NOTE: do not use Collectors.toMap because of https://bugs.openjdk.java.net/browse/JDK-8148463
+      inputStreamMetadata.entrySet()
+          .stream()
+          .filter(x -> x.getValue().getSystemStreamPartitionMetadata().get(partition) != null)
+          .forEach(e -> retVal.put(e.getKey(),
+              e.getValue().getSystemStreamPartitionMetadata().get(partition).getOldestOffset()));
+
+      return retVal;
+    }
+
+    /**
+     * Determines the starting offset for each store SSP (based on {@link #getStartingOffset(SystemStreamPartition, SystemAdmin)}) and
+     * registers it with the respective SystemConsumer for starting consumption.
+     */
+    private void registerStartingOffsets() {
+
+      for (Map.Entry<String, SystemStream> changelogSystemStreamEntry : changelogSystemStreams.entrySet()) {
+        SystemStreamPartition systemStreamPartition =
+            new SystemStreamPartition(changelogSystemStreamEntry.getValue(), taskModel.getChangelogPartition());
+        SystemAdmin systemAdmin = systemAdmins.getSystemAdmin(changelogSystemStreamEntry.getValue().getSystem());
+        SystemConsumer systemConsumer = systemConsumers.get(changelogSystemStreamEntry.getKey());
+
+        String offset = getStartingOffset(systemStreamPartition, systemAdmin);
+
+        if (offset != null) {
+          LOG.info("Registering change log consumer with offset " + offset + " for %" + systemStreamPartition);
+          systemConsumer.register(systemStreamPartition, offset);
+        } else {
+          LOG.info("Skipping change log restoration for {} because stream appears to be empty (offset was null).",
+              systemStreamPartition);
+          taskStoresToRestore.remove(changelogSystemStreamEntry.getKey());
+        }
+      }
+    }
+
+    /**
+     * Returns the offset with which the changelog consumer should be initialized for the given SystemStreamPartition.
+     *
+     * If a file offset exists, it represents the last changelog offset which is also reflected in the on-disk state.
+     * In that case, we use the next offset after the file offset, as long as it is newer than the oldest offset
+     * currently available in the stream.
+     *
+     * If there isn't a file offset or it's older than the oldest available offset, we simply start with the oldest.
+     *
+     * @param systemStreamPartition  the changelog partition for which the offset is needed.
+     * @param systemAdmin                  the {@link SystemAdmin} for the changelog.
+     * @return the offset from which the changelog consumer should be initialized.
+     */
+    private String getStartingOffset(SystemStreamPartition systemStreamPartition, SystemAdmin systemAdmin) {
+      String fileOffset = fileOffsets.get(systemStreamPartition);
+
+      // NOTE: changeLogOldestOffsets may contain a null-offset for the given SSP (signifying an empty stream)
+      // therefore, we need to differentiate that from the case where the offset is simply missing
+      if (!changeLogOldestOffsets.containsKey(systemStreamPartition.getSystemStream())) {
+        throw new SamzaException("Missing a change log offset for " + systemStreamPartition);
+      }
+
+      String oldestOffset = changeLogOldestOffsets.get(systemStreamPartition.getSystemStream());
+      return StorageManagerUtil.getStartingOffset(systemStreamPartition, systemAdmin, fileOffset, oldestOffset);
+    }
+
+
+    /**
+     * Restore each store in taskStoresToRestore sequentially
+     */
+    public void restoreStores() {
+      LOG.debug("Restoring stores for task: {}", taskModel.getTaskName());
+
+      for (String storeName : taskStoresToRestore) {
+        SystemConsumer systemConsumer = systemConsumers.get(storeName);
+        SystemStream systemStream = changelogSystemStreams.get(storeName);
+
+        SystemStreamPartitionIterator systemStreamPartitionIterator = new SystemStreamPartitionIterator(systemConsumer,
+            new SystemStreamPartition(systemStream, taskModel.getChangelogPartition()));
+
+        taskStores.get(storeName).restore(systemStreamPartitionIterator);
+      }
+    }
+
+    /**
+     * Stop all stores.
+     */
+    public void stop() {
+      this.taskStores.values().forEach(storageEngine -> {
+          storageEngine.stop();
+        });
+    }
+
+    /**
+     * Stop only persistent stores. In case of certain stores and store mode (such as RocksDB), this
+     * can invoke compaction.
+     */
+    public void stopPersistentStores() {
+      this.taskStores.values().stream().filter(storageEngine -> {
+          return storageEngine.getStoreProperties().isPersistedToDisk();
+        }).forEach(storageEngine -> {
+            storageEngine.stop();
+          });
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/61255666/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala b/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
index 4bcf2d3..f2c4679 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
+++ b/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
@@ -20,210 +20,51 @@
 package org.apache.samza.storage
 
 import java.io._
-import java.util
 
-import org.apache.samza.config.StorageConfig
-import org.apache.samza.{Partition, SamzaException}
+import com.google.common.annotations.VisibleForTesting
+import org.apache.samza.Partition
 import org.apache.samza.container.TaskName
 import org.apache.samza.system._
-import org.apache.samza.util.{Clock, FileUtil, Logging}
+import org.apache.samza.util.ScalaJavaUtil.JavaOptionals
+import org.apache.samza.util.{FileUtil, Logging}
 
-object TaskStorageManager {
-  def getStoreDir(storeBaseDir: File, storeName: String) = {
-    new File(storeBaseDir, storeName)
-  }
-
-  def getStorePartitionDir(storeBaseDir: File, storeName: String, taskName: TaskName) = {
-    // TODO: Sanitize, check and clean taskName string as a valid value for a file
-    new File(storeBaseDir, (storeName + File.separator + taskName.toString).replace(' ', '_'))
-  }
-}
+import scala.collection.JavaConverters._
 
 /**
  * Manage all the storage engines for a given task
  */
 class TaskStorageManager(
   taskName: TaskName,
-  taskStores: Map[String, StorageEngine] = Map(),
-  storeConsumers: Map[String, SystemConsumer] = Map(),
+  containerStorageManager: ContainerStorageManager,
   changeLogSystemStreams: Map[String, SystemStream] = Map(),
-  changeLogStreamPartitions: Int,
-  streamMetadataCache: StreamMetadataCache,
   sspMetadataCache: SSPMetadataCache,
-  nonLoggedStoreBaseDir: File = new File(System.getProperty("user.dir"), "state"),
   loggedStoreBaseDir: File = new File(System.getProperty("user.dir"), "state"),
-  partition: Partition,
-  systemAdmins: SystemAdmins,
-  changeLogDeleteRetentionsInMs: Map[String, Long],
-  clock: Clock) extends Logging {
+  partition: Partition) extends Logging {
 
-  var taskStoresToRestore = taskStores.filter{
-    case (storeName, storageEngine) => storageEngine.getStoreProperties.isLoggedStore
-  }
-  val persistedStores = taskStores.filter{
+  val persistedStores = containerStorageManager.getAllStores(taskName).asScala.filter{
     case (storeName, storageEngine) => storageEngine.getStoreProperties.isPersistedToDisk
   }
 
-  var changeLogOldestOffsets: Map[SystemStream, String] = Map()
-  val fileOffsets: util.Map[SystemStreamPartition, String] = new util.HashMap[SystemStreamPartition, String]()
   val offsetFileName = "OFFSET"
 
-  def getStore(storeName: String): Option[StorageEngine] = taskStores.get(storeName)
+  def getStore(storeName: String): Option[StorageEngine] =  JavaOptionals.toRichOptional(containerStorageManager.getStore(taskName, storeName)).toOption
 
   def init {
-    cleanBaseDirs()
-    setupBaseDirs()
-    validateChangelogStreams()
-    registerSSPs()
-  }
-
-  private def cleanBaseDirs() {
-    debug("Cleaning base directories for stores.")
-
-    taskStores.keys.foreach(storeName => {
-      val nonLoggedStorePartitionDir = TaskStorageManager.getStorePartitionDir(nonLoggedStoreBaseDir, storeName, taskName)
-      info("Got non logged storage partition directory as %s" format nonLoggedStorePartitionDir.toPath.toString)
-
-      if(nonLoggedStorePartitionDir.exists()) {
-        info("Deleting non logged storage partition directory %s" format nonLoggedStorePartitionDir.toPath.toString)
-        FileUtil.rm(nonLoggedStorePartitionDir)
-      }
-
-      val loggedStorePartitionDir = TaskStorageManager.getStorePartitionDir(loggedStoreBaseDir, storeName, taskName)
-      info("Got logged storage partition directory as %s" format loggedStorePartitionDir.toPath.toString)
-
-      // Delete the logged store if it is not valid.
-      if (!isLoggedStoreValid(storeName, loggedStorePartitionDir)) {
-        info("Deleting logged storage partition directory %s." format loggedStorePartitionDir.toPath.toString)
-        FileUtil.rm(loggedStorePartitionDir)
-      } else {
-        val offset = StorageManagerUtil.readOffsetFile(loggedStorePartitionDir, offsetFileName)
-        info("Read offset %s for the store %s from logged storage partition directory %s." format(offset, storeName, loggedStorePartitionDir))
-        if (offset != null) {
-          fileOffsets.put(new SystemStreamPartition(changeLogSystemStreams(storeName), partition), offset)
-        }
-      }
-    })
-  }
-
-  /**
-    * Directory loggedStoreDir associated with the logged store storeName is valid
-    * if all of the following conditions are true.
-    * a) If the store has to be persisted to disk.
-    * b) If there is a valid offset file associated with the logged store.
-    * c) If the logged store has not gone stale.
-    *
-    * @return true if the logged store is valid, false otherwise.
-    */
-  private def isLoggedStoreValid(storeName: String, loggedStoreDir: File): Boolean = {
-    val changeLogDeleteRetentionInMs = changeLogDeleteRetentionsInMs
-      .getOrElse(storeName, StorageConfig.DEFAULT_CHANGELOG_DELETE_RETENTION_MS)
-
-    persistedStores.contains(storeName) &&
-      StorageManagerUtil.isOffsetFileValid(loggedStoreDir, offsetFileName) &&
-      !StorageManagerUtil.isStaleStore(loggedStoreDir, offsetFileName, changeLogDeleteRetentionInMs, clock.currentTimeMillis())
-  }
-
-  private def setupBaseDirs() {
-    debug("Setting up base directories for stores.")
-    taskStores.foreach {
-      case (storeName, storageEngine) =>
-        if (storageEngine.getStoreProperties.isLoggedStore) {
-          val loggedStorePartitionDir = TaskStorageManager.getStorePartitionDir(loggedStoreBaseDir, storeName, taskName)
-          info("Using logged storage partition directory: %s for store: %s." format(loggedStorePartitionDir.toPath.toString, storeName))
-          if (!loggedStorePartitionDir.exists()) loggedStorePartitionDir.mkdirs()
-        } else {
-          val nonLoggedStorePartitionDir = TaskStorageManager.getStorePartitionDir(nonLoggedStoreBaseDir, storeName, taskName)
-          info("Using non logged storage partition directory: %s for store: %s." format(nonLoggedStorePartitionDir.toPath.toString, storeName))
-          nonLoggedStorePartitionDir.mkdirs()
-        }
-    }
-  }
-
-  private def validateChangelogStreams() = {
-    info("Validating change log streams: " + changeLogSystemStreams)
-
-    for ((storeName, systemStream) <- changeLogSystemStreams) {
-      val systemAdmin = systemAdmins.getSystemAdmin(systemStream.getSystem)
-      val changelogSpec = StreamSpec.createChangeLogStreamSpec(systemStream.getStream, systemStream.getSystem, changeLogStreamPartitions)
-
-      systemAdmin.validateStream(changelogSpec)
-    }
-
-    val changeLogMetadata = streamMetadataCache.getStreamMetadata(changeLogSystemStreams.values.toSet)
-    info("Got change log stream metadata: %s" format changeLogMetadata)
-
-    changeLogOldestOffsets = getChangeLogOldestOffsetsForPartition(partition, changeLogMetadata)
-    info("Assigning oldest change log offsets for taskName %s: %s" format (taskName, changeLogOldestOffsets))
-  }
-
-  private def registerSSPs() {
-    debug("Starting consumers for stores.")
-
-    for ((storeName, systemStream) <- changeLogSystemStreams) {
-      val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
-      val admin = systemAdmins.getSystemAdmin(systemStream.getSystem)
-      val consumer = storeConsumers(storeName)
-
-      val offset = getStartingOffset(systemStreamPartition, admin)
-      if (offset != null) {
-        info("Registering change log consumer with offset %s for %s." format (offset, systemStreamPartition))
-        consumer.register(systemStreamPartition, offset)
-      } else {
-        info("Skipping change log restoration for %s because stream appears to be empty (offset was null)." format systemStreamPartition)
-        taskStoresToRestore -= storeName
-      }
-    }
-  }
-
-  /**
-    * Returns the offset with which the changelog consumer should be initialized for the given SystemStreamPartition.
-    *
-    * If a file offset exists, it represents the last changelog offset which is also reflected in the on-disk state.
-    * In that case, we use the next offset after the file offset, as long as it is newer than the oldest offset
-    * currently available in the stream.
-    *
-    * If there isn't a file offset or it's older than the oldest available offset, we simply start with the oldest.
-    *
-    * @param systemStreamPartition  the changelog partition for which the offset is needed.
-    * @param admin                  the [[SystemAdmin]] for the changelog.
-    * @return                       the offset to from which the changelog consumer should be initialized.
-    */
-  private def getStartingOffset(systemStreamPartition: SystemStreamPartition, admin: SystemAdmin) = {
-    val fileOffset = fileOffsets.get(systemStreamPartition)
-    val oldestOffset = changeLogOldestOffsets
-      .getOrElse(systemStreamPartition.getSystemStream,
-        throw new SamzaException("Missing a change log offset for %s." format systemStreamPartition))
-
-    StorageManagerUtil.getStartingOffset(systemStreamPartition, admin, fileOffset, oldestOffset)
-  }
-
-  def restoreStores() {
-    debug("Restoring stores for task: %s." format taskName.getTaskName)
-
-    for ((storeName, store) <- taskStoresToRestore) {
-      if (changeLogSystemStreams.contains(storeName)) {
-        val systemStream = changeLogSystemStreams(storeName)
-        val systemStreamPartition = new SystemStreamPartition(systemStream, partition)
-        val systemConsumer = storeConsumers(storeName)
-        val systemConsumerIterator = new SystemStreamPartitionIterator(systemConsumer, systemStreamPartition)
-        store.restore(systemConsumerIterator)
-      }
-    }
   }
 
   def flush() {
     debug("Flushing stores.")
 
-    taskStores.values.foreach(_.flush)
+    containerStorageManager.getAllStores(taskName).asScala.values.foreach(_.flush)
     flushChangelogOffsetFiles()
   }
 
   def stopStores() {
     debug("Stopping stores.")
-    taskStores.values.foreach(_.stop)
+    containerStorageManager.stopStores();
   }
 
+  @VisibleForTesting
   def stop() {
     stopStores()
 
@@ -249,7 +90,7 @@ class TaskStorageManager(
         val newestOffset = if (sspMetadata == null) null else sspMetadata.getNewestOffset
         debug("Got offset %s for store %s" format(newestOffset, storeName))
 
-        val loggedStorePartitionDir = TaskStorageManager.getStorePartitionDir(loggedStoreBaseDir, storeName, taskName)
+        val loggedStorePartitionDir = StorageManagerUtil.getStorePartitionDir(loggedStoreBaseDir, storeName, taskName)
         val offsetFile = new File(loggedStorePartitionDir, offsetFileName)
         if (newestOffset != null) {
           debug("Storing offset for store in OFFSET file ")
@@ -270,14 +111,4 @@ class TaskStorageManager(
 
     debug("Done persisting logged key value stores")
   }
-
-  /**
-   * Builds a map from SystemStreamPartition to oldest offset for changelogs.
-   */
-  private def getChangeLogOldestOffsetsForPartition(partition: Partition, inputStreamMetadata: Map[SystemStream, SystemStreamMetadata]): Map[SystemStream, String] = {
-    inputStreamMetadata
-      .mapValues(_.getSystemStreamPartitionMetadata.get(partition))
-      .filter(_._2 != null)
-      .mapValues(_.getOldestOffset)
-  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/61255666/samza-core/src/main/scala/org/apache/samza/util/ScalaJavaUtil.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/util/ScalaJavaUtil.scala b/samza-core/src/main/scala/org/apache/samza/util/ScalaJavaUtil.scala
index a359cd5..049feba 100644
--- a/samza-core/src/main/scala/org/apache/samza/util/ScalaJavaUtil.scala
+++ b/samza-core/src/main/scala/org/apache/samza/util/ScalaJavaUtil.scala
@@ -21,8 +21,10 @@
 
 package org.apache.samza.util
 
-import java.util.function
+import java.util
+import java.util.{Optional, function}
 
+import scala.collection.JavaConverters
 import scala.collection.immutable.Map
 import scala.collection.JavaConverters._
 import scala.runtime.AbstractFunction0
@@ -37,6 +39,13 @@ object ScalaJavaUtil {
   }
 
   /**
+    * Converts a Scala iterable to a Java collection.
+    */
+  def toJavaCollection[T](iterable : Iterable[T]): util.Collection[T] = {
+    JavaConverters.asJavaCollectionConverter(iterable).asJavaCollection
+  }
+
+  /**
     * Wraps the provided value in an Scala Function, e.g. for use in [[Option#getOrDefault]]
     *
     * @param value the value to be wrapped
@@ -71,4 +80,30 @@ object ScalaJavaUtil {
   def toScalaFunction[T, R](javaFunction: java.util.function.Function[T, R]): Function1[T, R] = {
     t => javaFunction.apply(t)
   }
+
+
+  /**
+    * Conversions between Scala Option and Java 8 Optional.
+    */
+  object JavaOptionals {
+    implicit def toRichOption[T](opt: Option[T]): RichOption[T] = new RichOption[T](opt)
+    implicit def toRichOptional[T](optional: Optional[T]): RichOptional[T] = new RichOptional[T](optional)
+  }
+
+  class RichOption[T] (opt: Option[T]) {
+
+    /**
+      * Transform this Option to an equivalent Java Optional
+      */
+    def toOptional: Optional[T] = Optional.ofNullable(opt.getOrElse(null).asInstanceOf[T])
+  }
+
+  class RichOptional[T] (opt: Optional[T]) {
+
+    /**
+      * Transform this Optional to an equivalent Scala Option
+      */
+    def toOption: Option[T] = if (opt.isPresent) Some(opt.get()) else None
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/61255666/samza-core/src/test/java/org/apache/samza/storage/MockStorageEngineFactory.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/storage/MockStorageEngineFactory.java b/samza-core/src/test/java/org/apache/samza/storage/MockStorageEngineFactory.java
index 8eff4ad..9a47705 100644
--- a/samza-core/src/test/java/org/apache/samza/storage/MockStorageEngineFactory.java
+++ b/samza-core/src/test/java/org/apache/samza/storage/MockStorageEngineFactory.java
@@ -37,7 +37,7 @@ public class MockStorageEngineFactory implements StorageEngineFactory<Object, Ob
       MetricsRegistry registry,
       SystemStreamPartition changeLogSystemStreamPartition,
       JobContext jobContext,
-      ContainerContext containerContext) {
+      ContainerContext containerContext, StoreMode storeMode) {
     StoreProperties storeProperties = new StoreProperties.StorePropertiesBuilder().setLoggedStore(true).build();
     return new MockStorageEngine(storeName, storeDir, changeLogSystemStreamPartition, storeProperties);
   }

http://git-wip-us.apache.org/repos/asf/samza/blob/61255666/samza-core/src/test/java/org/apache/samza/storage/TestStorageRecovery.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/storage/TestStorageRecovery.java b/samza-core/src/test/java/org/apache/samza/storage/TestStorageRecovery.java
index 7c1647e..0bd33fa 100644
--- a/samza-core/src/test/java/org/apache/samza/storage/TestStorageRecovery.java
+++ b/samza-core/src/test/java/org/apache/samza/storage/TestStorageRecovery.java
@@ -26,6 +26,7 @@ import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.coordinator.stream.MockCoordinatorStreamSystemFactory;
+import org.apache.samza.serializers.ByteSerdeFactory;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.MockSystemFactory;
 import org.apache.samza.system.SystemStreamPartition;
@@ -33,11 +34,12 @@ import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
-import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.*;
 
 public class TestStorageRecovery {
 
   public Config config = null;
+  String path = "/tmp/testing";
   private static final String SYSTEM_STREAM_NAME = "changelog";
   private static final String INPUT_STREAM = "input";
   private static final String STORE_NAME = "testStore";
@@ -59,7 +61,7 @@ public class TestStorageRecovery {
   public void testStorageEngineReceivedAllValues() {
     MockCoordinatorStreamSystemFactory.enableMockConsumerCache();
 
-    String path = "/tmp/testing";
+
     StorageRecovery storageRecovery = new StorageRecovery(config, path);
     storageRecovery.run();
 
@@ -79,6 +81,9 @@ public class TestStorageRecovery {
     map.put("systems.mockSystem.samza.factory", MockSystemFactory.class.getCanonicalName());
     map.put(String.format("stores.%s.factory", STORE_NAME), MockStorageEngineFactory.class.getCanonicalName());
     map.put(String.format("stores.%s.changelog", STORE_NAME), "mockSystem." + SYSTEM_STREAM_NAME);
+    map.put(String.format("stores.%s.key.serde", STORE_NAME), "byteserde");
+    map.put(String.format("stores.%s.msg.serde", STORE_NAME), "byteserde");
+    map.put("serializers.registry.byteserde.class", ByteSerdeFactory.class.getName());
     map.put("task.inputs", "mockSystem.input");
     map.put("job.coordinator.system", "coordinator");
     map.put("systems.coordinator.samza.factory", MockCoordinatorStreamSystemFactory.class.getCanonicalName());