You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by ni...@apache.org on 2015/05/22 17:46:53 UTC

samza git commit: Backport SAMZA-608, SAMZA-616, and SAMZA-658 fixes to 0.9.1

Repository: samza
Updated Branches:
  refs/heads/0.9.1 0d140058c -> 93faa06cb


Backport SAMZA-608, SAMZA-616, and SAMZA-658 fixes to 0.9.1


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/93faa06c
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/93faa06c
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/93faa06c

Branch: refs/heads/0.9.1
Commit: 93faa06cb25ea3a6b9050de15326391ce99cd44e
Parents: 0d14005
Author: Yi Pan (Data Infrastructure) <yi...@linkedin.com>
Authored: Fri May 22 08:46:18 2015 -0700
Committer: Yi Pan (Data Infrastructure) <yi...@linkedin.com>
Committed: Fri May 22 08:46:18 2015 -0700

----------------------------------------------------------------------
 .../versioned/jobs/configuration-table.html     | 30 +++++---
 .../org/apache/samza/config/TaskConfig.scala    |  6 ++
 .../org/apache/samza/container/RunLoop.scala    | 48 ++++++-------
 .../apache/samza/container/SamzaContainer.scala |  7 +-
 .../apache/samza/system/SystemConsumers.scala   | 18 +++--
 .../apache/samza/container/TestRunLoop.scala    | 34 ---------
 .../samza/system/TestSystemConsumers.scala      | 25 ++++++-
 .../kv/inmemory/InMemoryKeyValueStore.scala     | 46 ++++++------
 .../kv/BaseKeyValueStorageEngineFactory.scala   | 24 ++++---
 .../apache/samza/storage/kv/CachedStore.scala   | 62 +++++++++++------
 .../storage/kv/KeyValueStorageEngine.scala      | 34 ++++-----
 .../samza/storage/kv/MockKeyValueStore.scala    | 73 ++++++++++++++++++++
 .../samza/storage/kv/TestCachedStore.scala      | 58 +++++++++++++++-
 13 files changed, 318 insertions(+), 147 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/docs/learn/documentation/versioned/jobs/configuration-table.html
----------------------------------------------------------------------
diff --git a/docs/learn/documentation/versioned/jobs/configuration-table.html b/docs/learn/documentation/versioned/jobs/configuration-table.html
index e091460..5ebe8a7 100644
--- a/docs/learn/documentation/versioned/jobs/configuration-table.html
+++ b/docs/learn/documentation/versioned/jobs/configuration-table.html
@@ -358,7 +358,7 @@
                     <td class="default"></td>
                     <td class="description">
                         This property is to define how the system deals with deserialization failure situation. If set to true, the system will
-                        skip the error messages and keep running. If set to false, the system with throw exceptions and fail the container. Default 
+                        skip the error messages and keep running. If set to false, the system will throw exceptions and fail the container. Default
                         is false.
                     </td>
                 </tr>
@@ -368,7 +368,7 @@
                     <td class="default"></td>
                     <td class="description">
                         This property is to define how the system deals with serialization failure situation. If set to true, the system will
-                        drop the error messages and keep running. If set to false, the system with throw exceptions and fail the container. Default 
+                        drop the error messages and keep running. If set to false, the system will throw exceptions and fail the container. Default
                         is false.
                     </td>
                 </tr>
@@ -392,7 +392,7 @@
                     <td class="default">false</td>
                     <td class="description">
                         Defines whether or not to include log4j's LocationInfo data in Log4j StreamAppender messages. LocationInfo includes
-                        information such as the file, class, and line that wrote a log message. This setting is only active if the Log4j 
+                        information such as the file, class, and line that wrote a log message. This setting is only active if the Log4j
                         stream appender is being used. (See <a href="logging.html#stream-log4j-appender">Stream Log4j Appender</a>)
                         <dl>
                             <dt>Example: <code>task.log4j.location.info.enabled=true</code></dt>
@@ -404,13 +404,13 @@
                     <td class="property" id="task-poll-interval-ms">task.poll.interval.ms</td>
                     <td class="default"></td>
                     <td class="description">
-                      Samza's container polls for more messages under two conditions. The first condition arises when there are simply no remaining 
-                      buffered messages to process for any input SystemStreamPartition. The second condition arises when some input 
+                      Samza's container polls for more messages under two conditions. The first condition arises when there are simply no remaining
+                      buffered messages to process for any input SystemStreamPartition. The second condition arises when some input
                       SystemStreamPartitions have empty buffers, but some do not. In the latter case, a polling interval is defined to determine how
-                      often to refresh the empty SystemStreamPartition buffers. By default, this interval is 50ms, which means that any empty 
-                      SystemStreamPartition buffer will be refreshed at least every 50ms. A higher value here means that empty SystemStreamPartitions 
-                      will be refreshed less often, which means more latency is introduced, but less CPU and network will be used. Decreasing this 
-                      value means that empty SystemStreamPartitions are refreshed more frequently, thereby introducing less latency, but increasing 
+                      often to refresh the empty SystemStreamPartition buffers. By default, this interval is 50ms, which means that any empty
+                      SystemStreamPartition buffer will be refreshed at least every 50ms. A higher value here means that empty SystemStreamPartitions
+                      will be refreshed less often, which means more latency is introduced, but less CPU and network will be used. Decreasing this
+                      value means that empty SystemStreamPartitions are refreshed more frequently, thereby introducing less latency, but increasing
                       CPU and network utilization.
                     </td>
                 </tr>
@@ -426,6 +426,14 @@
                 </tr>
 
                 <tr>
+                    <td class="property" id="task-shutdown-ms">task.shutdown.ms</td>
+                    <td class="default">5000</td>
+                    <td class="description">
+                        This property controls how long, in milliseconds, the Samza container will wait for an orderly shutdown of task instances.
+                    </td>
+                </tr>
+
+                <tr>
                     <th colspan="3" class="section" id="streams"><a href="../container/streams.html">Systems (input and output streams)</a></th>
                 </tr>
 
@@ -502,7 +510,7 @@
                         one of the following:
                         <dl>
                             <dt><code>upcoming</code></dt>
-                            <dd>Start processing messages that are published after the job starts. Any messages published while 
+                            <dd>Start processing messages that are published after the job starts. Any messages published while
                                 the job was not running are not processed.</dd>
                             <dt><code>oldest</code></dt>
                             <dd>Start processing at the oldest available message in the system, and
@@ -766,7 +774,7 @@
                 </tr>
 
                 <tr>
-                    <td class="property" id="store-changelog-partitions">stores.<span class="store">store-name</span>.changelog.<br>replication.factor</td>
+                    <td class="property" id="store-changelog-replication-factor">stores.<span class="store">store-name</span>.changelog.<br>replication.factor</td>
                     <td class="default">2</td>
                     <td class="description">
                         The property defines the number of replicas to use for the change log stream.

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-core/src/main/scala/org/apache/samza/config/TaskConfig.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/config/TaskConfig.scala b/samza-core/src/main/scala/org/apache/samza/config/TaskConfig.scala
index 1ca9e2c..cd06c06 100644
--- a/samza-core/src/main/scala/org/apache/samza/config/TaskConfig.scala
+++ b/samza-core/src/main/scala/org/apache/samza/config/TaskConfig.scala
@@ -27,6 +27,7 @@ object TaskConfig {
   val INPUT_STREAMS = "task.inputs" // streaming.input-streams
   val WINDOW_MS = "task.window.ms" // window period in milliseconds
   val COMMIT_MS = "task.commit.ms" // commit period in milliseconds
+  val SHUTDOWN_MS = "task.shutdown.ms" // how long to wait for a clean shutdown, in milliseconds
   val TASK_CLASS = "task.class" // streaming.task-factory-class
   val COMMAND_BUILDER = "task.command.class" // streaming.task-factory-class
   val LIFECYCLE_LISTENERS = "task.lifecycle.listeners" // li-generator,foo
@@ -79,6 +80,11 @@ class TaskConfig(config: Config) extends ScalaMapConfig(config) {
     case _ => None
   }
 
+  def getShutdownMs: Option[Long] = getOption(TaskConfig.SHUTDOWN_MS) match {
+    case Some(ms) => Some(ms.toLong)
+    case _ => None
+  }
+
   def getLifecycleListeners(): Option[String] = getOption(TaskConfig.LIFECYCLE_LISTENERS)
 
   def getLifecycleListenerClass(name: String): Option[String] = getOption(TaskConfig.LIFECYCLE_LISTENER format name)

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala b/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala
index 4098235..4c0faf6 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/RunLoop.scala
@@ -39,7 +39,8 @@ class RunLoop(
   val metrics: SamzaContainerMetrics,
   val windowMs: Long = -1,
   val commitMs: Long = 60000,
-  val clock: () => Long = { System.currentTimeMillis }) extends Runnable with TimerUtils with Logging {
+  val clock: () => Long = { System.currentTimeMillis },
+  val shutdownMs: Long = 5000) extends Runnable with TimerUtils with Logging {
 
   private var lastWindowMs = 0L
   private var lastCommitMs = 0L
@@ -56,39 +57,36 @@ class RunLoop(
     taskInstances.values.map { getSystemStreamPartitionToTaskInstance }.flatten.toMap
   }
 
-  val shutdownHook = new Thread() {
-    override def run() = {
-      info("Triggering shutdown in response to shutdown hook")
-      shutdownNow = true
-    }
-  }
-
-  protected def addShutdownHook() {
-    Runtime.getRuntime().addShutdownHook(shutdownHook)
-  }
-
-  protected def removeShutdownHook() {
-    Runtime.getRuntime().removeShutdownHook(shutdownHook)
-  }
 
   /**
    * Starts the run loop. Blocks until either the tasks request shutdown, or an
    * unhandled exception is thrown.
    */
   def run {
-    try {
-      addShutdownHook()
+    addShutdownHook(Thread.currentThread())
 
-      while (!shutdownNow) {
-        process
-        window
-        commit
-      }
-    } finally {
-      removeShutdownHook()
+    while (!shutdownNow) {
+      process
+      window
+      commit
     }
   }
 
+  private def addShutdownHook(runLoopThread: Thread) {
+    Runtime.getRuntime().addShutdownHook(new Thread() {
+      override def run() = {
+        info("Shutting down, will wait up to %s ms" format shutdownMs)
+        shutdownNow = true
+        runLoopThread.join(shutdownMs)
+        if (runLoopThread.isAlive) {
+          warn("Did not shut down within %s ms, exiting" format shutdownMs)
+        } else {
+          info("Shutdown complete")
+        }
+      }
+    })
+  }
+
   /**
    * Chooses a message from an input stream to process, and calls the
    * process() method on the appropriate StreamTask to handle it.
@@ -189,4 +187,4 @@ class RunLoop(
       shutdownNow = true
     }
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index 6dbef29..ce4d527 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -401,6 +401,10 @@ object SamzaContainer extends Logging {
 
     info("Got commit milliseconds: %s" format taskCommitMs)
 
+    val taskShutdownMs = config.getShutdownMs.getOrElse(5000L)
+
+    info("Got shutdown timeout milliseconds: %s" format taskShutdownMs)
+
     // Wire up all task-instance-level (unshared) objects.
 
     val taskNames = containerModel
@@ -509,7 +513,8 @@ object SamzaContainer extends Logging {
       consumerMultiplexer = consumerMultiplexer,
       metrics = samzaContainerMetrics,
       windowMs = taskWindowMs,
-      commitMs = taskCommitMs)
+      commitMs = taskCommitMs,
+      shutdownMs = taskShutdownMs)
 
     info("Samza container setup complete.")
 

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
index 125d376..44cd140 100644
--- a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
+++ b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
@@ -204,9 +204,7 @@ class SystemConsumers(
       metrics.choseObject.inc
       metrics.systemStreamMessagesChosen(envelopeFromChooser.getSystemStreamPartition.getSystemStream).inc
 
-      if (!update(systemStreamPartition)) {
-        emptySystemStreamPartitionsBySystem.get(systemStreamPartition.getSystem).add(systemStreamPartition)
-      }
+      tryUpdate(systemStreamPartition)
     }
 
     if (envelopeFromChooser == null || lastPollMs < clock() - pollIntervalMs) {
@@ -257,7 +255,7 @@ class SystemConsumers(
 
           // Update the chooser if it needs a message for this SSP.
           if (emptySystemStreamPartitionsBySystem.get(systemStreamPartition.getSystem).remove(systemStreamPartition)) {
-            update(systemStreamPartition)
+            tryUpdate(systemStreamPartition)
           }
         }
       }
@@ -266,6 +264,18 @@ class SystemConsumers(
     }
   }
 
+  private def tryUpdate(ssp: SystemStreamPartition) {
+    var updated = false
+    try {
+      updated = update(ssp)
+    } finally {
+      if (!updated) {
+        // If the chooser was not updated, add the ssp back into the emptySystemStreamPartitionsBySystem map to ensure that we poll for its next message
+        emptySystemStreamPartitionsBySystem.get(ssp.getSystem).add(ssp)
+      }
+    }
+  }
+
   private def refresh {
     trace("Refreshing chooser with new messages.")
 

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
index 2a0897f..73ec2b5 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestRunLoop.scala
@@ -211,38 +211,4 @@ class TestRunLoop extends AssertionsForJUnit with MockitoSugar with ScalaTestMat
     testMetrics.processMs.getSnapshot.getSize should equal(2)
     testMetrics.commitMs.getSnapshot.getSize should equal(2)
   }
-
-  @Test
-  def testShutdownHook: Unit = {
-    // The shutdown hook can't be directly tested so we verify that a) both add and remove
-    // are called and b) invoking the shutdown hook actually kills the run loop.
-    val consumers = mock[SystemConsumers]
-    when(consumers.choose).thenReturn(envelope0)
-    val testMetrics = new SamzaContainerMetrics
-    var addCalled = false
-    var removeCalled = false
-    val runLoop = new RunLoop(
-      taskInstances = getMockTaskInstances,
-      consumerMultiplexer = consumers,
-      metrics = testMetrics) {
-      override def addShutdownHook() {
-        addCalled = true
-      }
-      override def removeShutdownHook() {
-        removeCalled = true
-      }
-    }
-
-    val runThread = new Thread(runLoop)
-    runThread.start()
-
-    runLoop.shutdownHook.start()
-    runLoop.shutdownHook.join(1000)
-    runThread.join(1000)
-
-    assert(addCalled)
-    assert(removeCalled)
-    assert(!runLoop.shutdownHook.isAlive)
-    assert(!runThread.isAlive)
-  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala b/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
index 3fdc781..fbaa8ee 100644
--- a/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
+++ b/samza-core/src/test/scala/org/apache/samza/system/TestSystemConsumers.scala
@@ -241,9 +241,10 @@ class TestSystemConsumers {
     // it should not throw exceptions when deserializaion fails if dropDeserializationError is set to true
     val consumers2 = new SystemConsumers(msgChooser, consumer, serdeManager, dropDeserializationError = true)
     consumers2.register(systemStreamPartition, "0")
-    consumers2.start
     consumer(system).putBytesMessage
     consumer(system).putStringMessage
+    consumer(system).putBytesMessage
+    consumers2.start
 
     var notThrowException = true;
     try {
@@ -251,9 +252,29 @@ class TestSystemConsumers {
     } catch {
       case e: Throwable => notThrowException = false
     }
-
     assertTrue("it should not throw any exception", notThrowException)
+
+    var msgEnvelope = Some(consumers2.choose)
+    assertTrue("Consumer did not succeed in receiving the second message after Serde exception in choose", msgEnvelope.get != null)
+    consumers2.stop
+
+    // ensure that the system consumer will continue after poll() method ignored a Serde exception
+    consumer(system).putStringMessage
+    consumer(system).putBytesMessage
+
+    notThrowException = true;
+    try {
+      consumers2.start
+    } catch {
+      case e: Throwable => notThrowException = false
+    }
+    assertTrue("SystemConsumer start should not throw any Serde exception", notThrowException)
+
+    msgEnvelope = null
+    msgEnvelope = Some(consumers2.choose)
+    assertTrue("Consumer did not succeed in receiving the second message after Serde exception in poll", msgEnvelope.get != null)
     consumers2.stop
+
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
----------------------------------------------------------------------
diff --git a/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala b/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
index 217333c..e93eb1e 100644
--- a/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
+++ b/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
@@ -20,7 +20,7 @@ package org.apache.samza.storage.kv.inmemory
 
 import com.google.common.primitives.UnsignedBytes
 import org.apache.samza.util.Logging
-import org.apache.samza.storage.kv.{KeyValueStoreMetrics, KeyValueIterator, Entry, KeyValueStore}
+import org.apache.samza.storage.kv.{ KeyValueStoreMetrics, KeyValueIterator, Entry, KeyValueStore }
 import java.util
 
 /**
@@ -31,9 +31,9 @@ import java.util
  * @param metrics A metrics instance to publish key-value store related statistics
  */
 class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStoreMetrics)
-  extends KeyValueStore[Array[Byte], Array[Byte]] with Logging {
+    extends KeyValueStore[Array[Byte], Array[Byte]] with Logging {
 
-  val underlying = new util.TreeMap[Array[Byte], Array[Byte]] (UnsignedBytes.lexicographicalComparator())
+  val underlying = new util.TreeMap[Array[Byte], Array[Byte]](UnsignedBytes.lexicographicalComparator())
 
   override def flush(): Unit = {
     // No-op for In memory store.
@@ -42,37 +42,38 @@ class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStor
 
   override def close(): Unit = Unit
 
-  private def getIter(tm:util.SortedMap[Array[Byte], Array[Byte]]) = {
-    new KeyValueIterator[Array[Byte], Array[Byte]] {
-      val iter = tm.entrySet().iterator()
+  private class InMemoryIterator(val iter: util.Iterator[util.Map.Entry[Array[Byte], Array[Byte]]])
+      extends KeyValueIterator[Array[Byte], Array[Byte]] {
 
-      override def close(): Unit = Unit
+    override def close(): Unit = Unit
 
-      override def remove(): Unit = iter.remove()
+    override def remove(): Unit = iter.remove()
 
-      override def next(): Entry[Array[Byte], Array[Byte]] = {
-        val n = iter.next()
-        if (n != null && n.getKey != null) {
-          metrics.bytesRead.inc(n.getKey.size)
-        }
-        if (n != null && n.getValue != null) {
-          metrics.bytesRead.inc(n.getValue.size)
-        }
-        new Entry(n.getKey, n.getValue)
+    override def next(): Entry[Array[Byte], Array[Byte]] = {
+      val n = iter.next()
+      if (n != null && n.getKey != null) {
+        metrics.bytesRead.inc(n.getKey.size)
       }
-
-      override def hasNext: Boolean = iter.hasNext
+      if (n != null && n.getValue != null) {
+        metrics.bytesRead.inc(n.getValue.size)
+      }
+      new Entry(n.getKey, n.getValue)
     }
+
+    override def hasNext: Boolean = iter.hasNext
   }
+
   override def all(): KeyValueIterator[Array[Byte], Array[Byte]] = {
     metrics.alls.inc
-    getIter(underlying)
+
+    new InMemoryIterator(underlying.entrySet().iterator())
   }
 
   override def range(from: Array[Byte], to: Array[Byte]): KeyValueIterator[Array[Byte], Array[Byte]] = {
     metrics.ranges.inc
     require(from != null && to != null, "Null bound not allowed.")
-    getIter(underlying.subMap(from, to))
+
+    new InMemoryIterator(underlying.subMap(from, to).entrySet().iterator())
   }
 
   override def delete(key: Array[Byte]): Unit = {
@@ -84,7 +85,7 @@ class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStor
     // TreeMap's putAll requires a map, so we'd need to iterate over all the entries anyway
     // to use it, in order to putAll here.  Therefore, just iterate here.
     val iter = entries.iterator()
-    while(iter.hasNext) {
+    while (iter.hasNext) {
       val next = iter.next()
       put(next.getKey, next.getValue)
     }
@@ -112,4 +113,3 @@ class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStor
     found
   }
 }
-

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala
----------------------------------------------------------------------
diff --git a/samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala b/samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala
index b3624e6..391cf89 100644
--- a/samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala
+++ b/samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala
@@ -38,7 +38,9 @@ import org.apache.samza.task.MessageCollector
 trait BaseKeyValueStorageEngineFactory[K, V] extends StorageEngineFactory[K, V] {
 
   /**
-   * Return a KeyValueStore instance for the given store name
+   * Return a KeyValueStore instance for the given store name,
+   * which will be used as the underlying raw store
+   *
    * @param storeName Name of the store
    * @param storeDir The directory of the store
    * @param registry MetricsRegistry to which to publish store specific metrics.
@@ -90,29 +92,35 @@ trait BaseKeyValueStorageEngineFactory[K, V] extends StorageEngineFactory[K, V]
       throw new SamzaException("Must define a message serde when using key value storage.")
     }
 
-    val kvStore = getKVStore(storeName, storeDir, registry, changeLogSystemStreamPartition, containerContext)
+    val rawStore = getKVStore(storeName, storeDir, registry, changeLogSystemStreamPartition, containerContext)
 
+    // maybe wrap with logging
     val maybeLoggedStore = if (changeLogSystemStreamPartition == null) {
-      kvStore
+      rawStore
     } else {
       val loggedStoreMetrics = new LoggedStoreMetrics(storeName, registry)
-      new LoggedStore(kvStore, changeLogSystemStreamPartition, collector, loggedStoreMetrics)
+      new LoggedStore(rawStore, changeLogSystemStreamPartition, collector, loggedStoreMetrics)
     }
 
+    // wrap with serialization
     val serializedMetrics = new SerializedKeyValueStoreMetrics(storeName, registry)
     val serialized = new SerializedKeyValueStore[K, V](maybeLoggedStore, keySerde, msgSerde, serializedMetrics)
+
+    // maybe wrap with caching
     val maybeCachedStore = if (enableCache) {
       val cachedStoreMetrics = new CachedStoreMetrics(storeName, registry)
       new CachedStore(serialized, cacheSize, batchSize, cachedStoreMetrics)
     } else {
       serialized
     }
-    val db = new NullSafeKeyValueStore(maybeCachedStore)
-    val keyValueStorageEngineMetrics = new KeyValueStorageEngineMetrics(storeName, registry)
 
-    // TODO: Decide if we should use raw bytes when restoring
+    // wrap with null value checking
+    val nullSafeStore = new NullSafeKeyValueStore(maybeCachedStore)
 
-    new KeyValueStorageEngine(db, kvStore, keyValueStorageEngineMetrics, batchSize)
+    // create the storage engine and return
+    // TODO: Decide if we should use raw bytes when restoring
+    val keyValueStorageEngineMetrics = new KeyValueStorageEngineMetrics(storeName, registry)
+    new KeyValueStorageEngine(nullSafeStore, rawStore, keyValueStorageEngineMetrics, batchSize)
   }
 
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
----------------------------------------------------------------------
diff --git a/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala b/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
index 61bb3f6..b94bb27 100644
--- a/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
+++ b/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
@@ -40,14 +40,15 @@ import java.util.Arrays
  * This class is very non-thread safe.
  *
  * @param store The store to cache
- * @param cacheEntries The number of entries to hold in the in memory-cache
+ * @param cacheSize The number of entries to hold in the in memory-cache
  * @param writeBatchSize The number of entries to batch together before forcing a write
+ * @param metrics The metrics recording object for this cached store
  */
 class CachedStore[K, V](
-  val store: KeyValueStore[K, V],
-  val cacheSize: Int,
-  val writeBatchSize: Int,
-  val metrics: CachedStoreMetrics = new CachedStoreMetrics) extends KeyValueStore[K, V] with Logging {
+    val store: KeyValueStore[K, V],
+    val cacheSize: Int,
+    val writeBatchSize: Int,
+    val metrics: CachedStoreMetrics = new CachedStoreMetrics) extends KeyValueStore[K, V] with Logging {
 
   /** the number of items in the dirty list */
   @volatile private var dirtyCount = 0
@@ -82,7 +83,7 @@ class CachedStore[K, V](
   metrics.setDirtyCount(() => dirtyCount)
   metrics.setCacheSize(() => cacheCount)
 
-  def get(key: K) = {
+  override def get(key: K) = {
     metrics.gets.inc
 
     val c = cache.get(key)
@@ -97,19 +98,41 @@ class CachedStore[K, V](
     }
   }
 
-  def range(from: K, to: K) = {
+  private class CachedStoreIterator(val iter: KeyValueIterator[K, V])
+      extends KeyValueIterator[K, V] {
+
+    var last: Entry[K, V] = null
+
+    override def close(): Unit = iter.close()
+
+    override def remove(): Unit = {
+      iter.remove()
+      delete(last.getKey)
+    }
+
+    override def next() = {
+      last = iter.next()
+      last
+    }
+
+    override def hasNext: Boolean = iter.hasNext
+  }
+
+  override def range(from: K, to: K): KeyValueIterator[K, V] = {
     metrics.ranges.inc
     flush()
-    store.range(from, to)
+
+    new CachedStoreIterator(store.range(from, to))
   }
 
-  def all() = {
+  override def all(): KeyValueIterator[K, V] = {
     metrics.alls.inc
     flush()
-    store.all()
+
+    new CachedStoreIterator(store.all())
   }
 
-  def put(key: K, value: V) {
+  override def put(key: K, value: V) {
     metrics.puts.inc
 
     checkKeyIsArray(key)
@@ -126,7 +149,7 @@ class CachedStore[K, V](
         this.dirty = found.dirty.next
         this.dirty.prev = null
       } else {
-        found.dirty.remove
+        found.dirty.remove()
       }
     }
     this.dirty = new mutable.DoubleLinkedList(key, this.dirty)
@@ -149,22 +172,22 @@ class CachedStore[K, V](
     }
   }
 
-  def flush() {
+  override def flush() {
     trace("Flushing.")
 
     metrics.flushes.inc
 
     // write out the contents of the dirty list oldest first
     val batch = new Array[Entry[K, V]](this.dirtyCount)
-    var pos : Int = this.dirtyCount - 1;
+    var pos: Int = this.dirtyCount - 1
     for (k <- this.dirty) {
       val entry = this.cache.get(k)
       entry.dirty = null // not dirty any more
       batch(pos) = new Entry(k, entry.value)
       pos -= 1
     }
-    store.putAll(Arrays.asList(batch : _*))
-    store.flush
+    store.putAll(Arrays.asList(batch: _*))
+    store.flush()
     metrics.flushBatchSize.inc(batch.size)
 
     // reset the dirty list
@@ -188,16 +211,13 @@ class CachedStore[K, V](
    */
   def delete(key: K) {
     metrics.deletes.inc
-
     put(key, null.asInstanceOf[V])
   }
 
   def close() {
     trace("Closing.")
-
-    flush
-
-    store.close
+    flush()
+    store.close()
   }
 
   private def checkKeyIsArray(key: K) {

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
----------------------------------------------------------------------
diff --git a/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala b/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
index 3a23daf..380b60c 100644
--- a/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
+++ b/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
@@ -31,47 +31,47 @@ import scala.collection.JavaConversions._
  * This implements both the key/value interface and the storage engine interface.
  */
 class KeyValueStorageEngine[K, V](
-  db: KeyValueStore[K, V],
-  rawDb: KeyValueStore[Array[Byte], Array[Byte]],
-  metrics: KeyValueStorageEngineMetrics = new KeyValueStorageEngineMetrics,
-  batchSize: Int = 500) extends StorageEngine with KeyValueStore[K, V] with Logging {
+    wrapperStore: KeyValueStore[K, V],
+    rawStore: KeyValueStore[Array[Byte], Array[Byte]],
+    metrics: KeyValueStorageEngineMetrics = new KeyValueStorageEngineMetrics,
+    batchSize: Int = 500) extends StorageEngine with KeyValueStore[K, V] with Logging {
 
   var count = 0
 
   /* delegate to underlying store */
   def get(key: K): V = {
     metrics.gets.inc
-    db.get(key)
+    wrapperStore.get(key)
   }
 
   def put(key: K, value: V) = {
     metrics.puts.inc
-    db.put(key, value)
+    wrapperStore.put(key, value)
   }
 
   def putAll(entries: java.util.List[Entry[K, V]]) = {
     metrics.puts.inc(entries.size)
-    db.putAll(entries)
+    wrapperStore.putAll(entries)
   }
 
   def delete(key: K) = {
     metrics.deletes.inc
-    db.delete(key)
+    wrapperStore.delete(key)
   }
 
   def range(from: K, to: K) = {
     metrics.ranges.inc
-    db.range(from, to)
+    wrapperStore.range(from, to)
   }
 
   def all() = {
     metrics.alls.inc
-    db.all()
+    wrapperStore.all()
   }
 
   /**
    * Restore the contents of this key/value store from the change log,
-   * batching updates and skipping serialization for efficiency.
+   * batching updates to the underlying raw store to skip the wrapping functions for efficiency.
    */
   def restore(envelopes: java.util.Iterator[IncomingMessageEnvelope]) {
     val batch = new java.util.ArrayList[Entry[Array[Byte], Array[Byte]]](batchSize)
@@ -83,7 +83,7 @@ class KeyValueStorageEngine[K, V](
       batch.add(new Entry(keyBytes, valBytes))
 
       if (batch.size >= batchSize) {
-        rawDb.putAll(batch)
+        rawStore.putAll(batch)
         batch.clear()
       }
 
@@ -101,7 +101,7 @@ class KeyValueStorageEngine[K, V](
     }
 
     if (batch.size > 0) {
-      rawDb.putAll(batch)
+      rawStore.putAll(batch)
     }
   }
 
@@ -110,19 +110,19 @@ class KeyValueStorageEngine[K, V](
 
     metrics.flushes.inc
 
-    db.flush
+    wrapperStore.flush()
   }
 
   def stop() = {
     trace("Stopping.")
 
-    close
+    close()
   }
 
   def close() = {
     trace("Closing.")
 
-    flush
-    db.close
+    flush()
+    wrapperStore.close()
   }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-kv/src/test/scala/org/apache/samza/storage/kv/MockKeyValueStore.scala
----------------------------------------------------------------------
diff --git a/samza-kv/src/test/scala/org/apache/samza/storage/kv/MockKeyValueStore.scala b/samza-kv/src/test/scala/org/apache/samza/storage/kv/MockKeyValueStore.scala
new file mode 100644
index 0000000..0822167
--- /dev/null
+++ b/samza-kv/src/test/scala/org/apache/samza/storage/kv/MockKeyValueStore.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage.kv
+
+import scala.collection.JavaConversions._
+import java.util
+
+/**
+ * A mock stand-in for the key-value store wrapper that handles serialization, backed by an in-memory TreeMap
+ */
+class MockKeyValueStore extends KeyValueStore[String, String] {
+
+  val kvMap = new java.util.TreeMap[String, String]()
+
+  override def get(key: String) = kvMap.get(key)
+
+  override def put(key: String, value: String) {
+    kvMap.put(key, value)
+  }
+
+  override def putAll(entries: java.util.List[Entry[String, String]]) {
+    for (entry <- entries) {
+      kvMap.put(entry.getKey, entry.getValue)
+    }
+  }
+
+  override def delete(key: String) {
+    kvMap.remove(key)
+  }
+
+  private class MockIterator(val iter: util.Iterator[util.Map.Entry[String, String]])
+      extends KeyValueIterator[String, String] {
+
+    override def hasNext = iter.hasNext
+
+    override def next() = {
+      val entry = iter.next()
+      new Entry(entry.getKey, entry.getValue)
+    }
+
+    override def remove(): Unit = iter.remove()
+
+    override def close(): Unit = Unit
+  }
+
+  override def range(from: String, to: String): KeyValueIterator[String, String] =
+    new MockIterator(kvMap.subMap(from, to).entrySet().iterator())
+
+  override def all(): KeyValueIterator[String, String] =
+    new MockIterator(kvMap.entrySet().iterator())
+
+  override def flush() {} // no-op
+
+  override def close() { kvMap.clear() }
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/samza/blob/93faa06c/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
----------------------------------------------------------------------
diff --git a/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala b/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
index d03ec92..cc9c9f3 100644
--- a/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
+++ b/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
@@ -23,13 +23,69 @@ import org.junit.Test
 import org.junit.Assert._
 import org.mockito.Mockito._
 
+import java.util.Arrays
+
 class TestCachedStore {
   @Test
-  def testArrayCheck {
+  def testArrayCheck() {
     val kv = mock(classOf[KeyValueStore[Array[Byte], Array[Byte]]])
     val store = new CachedStore[Array[Byte], Array[Byte]](kv, 100, 100)
+
     assertFalse(store.hasArrayKeys)
     store.put("test1-key".getBytes("UTF-8"), "test1-value".getBytes("UTF-8"))
     assertTrue(store.hasArrayKeys)
   }
+
+  @Test
+  def testIterator() {
+    val kv = new MockKeyValueStore()
+    val store = new CachedStore[String, String](kv, 100, 100)
+
+    val keys = Arrays.asList("test1-key",
+                             "test2-key",
+                             "test3-key")
+    val values = Arrays.asList("test1-value",
+                               "test2-value",
+                               "test3-value")
+
+    for (i <- 0 until 3) {
+      store.put(keys.get(i), values.get(i))
+    }
+
+    // test all iterator
+    var iter = store.all()
+    for (i <- 0 until 3) {
+      assertTrue(iter.hasNext)
+      val entry = iter.next()
+      assertEquals(entry.getKey, keys.get(i))
+      assertEquals(entry.getValue, values.get(i))
+    }
+    assertFalse(iter.hasNext)
+
+    // test range iterator
+    iter = store.range(keys.get(0), keys.get(2))
+    for (i <- 0 until 2) {
+      assertTrue(iter.hasNext)
+      val entry = iter.next()
+      assertEquals(entry.getKey, keys.get(i))
+      assertEquals(entry.getValue, values.get(i))
+    }
+    assertFalse(iter.hasNext)
+
+    // test iterator remove
+    iter = store.all()
+    iter.next()
+    iter.remove()
+
+    assertNull(kv.get(keys.get(0)))
+    assertNull(store.get(keys.get(0)))
+
+    iter = store.range(keys.get(1), keys.get(2))
+    iter.next()
+    iter.remove()
+
+    assertFalse(iter.hasNext)
+    assertNull(kv.get(keys.get(1)))
+    assertNull(store.get(keys.get(1)))
+  }
 }
\ No newline at end of file