You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@samza.apache.org by ni...@apache.org on 2015/05/08 03:58:37 UTC

samza git commit: SAMZA-658: fix cached store iterator remove() function

Repository: samza
Updated Branches:
  refs/heads/master f0809a54b -> 4323003dc


SAMZA-658: fix cached store iterator remove() function


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/4323003d
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/4323003d
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/4323003d

Branch: refs/heads/master
Commit: 4323003dc4a0749d10b36be590e42697749f7dca
Parents: f0809a5
Author: Guozhang Wang <wa...@gmail.com>
Authored: Thu May 7 18:58:23 2015 -0700
Committer: Yi Pan (Data Infrastructure) <yi...@linkedin.com>
Committed: Thu May 7 18:58:23 2015 -0700

----------------------------------------------------------------------
 .../kv/inmemory/InMemoryKeyValueStore.scala     |  58 +++++-----
 .../kv/BaseKeyValueStorageEngineFactory.scala   |  24 ++--
 .../apache/samza/storage/kv/CachedStore.scala   | 110 +++++++++++--------
 .../storage/kv/KeyValueStorageEngine.scala      |  34 +++---
 .../samza/storage/kv/MockKeyValueStore.scala    |  80 ++++++++++++++
 .../samza/storage/kv/TestCachedStore.scala      |  58 +++++++++-
 6 files changed, 264 insertions(+), 100 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/4323003d/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
----------------------------------------------------------------------
diff --git a/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala b/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
index 23d028b..72f25a3 100644
--- a/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
+++ b/samza-kv-inmemory/src/main/scala/org/apache/samza/storage/kv/inmemory/InMemoryKeyValueStore.scala
@@ -35,57 +35,57 @@ class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStor
 
   val underlying = new util.TreeMap[Array[Byte], Array[Byte]] (UnsignedBytes.lexicographicalComparator())
 
-  def flush(): Unit = {
+  override def flush(): Unit = {
     // No-op for In memory store.
     metrics.flushes.inc
   }
 
-  def close(): Unit = Unit
+  override def close(): Unit = Unit
 
-  private def getIter(tm:util.SortedMap[Array[Byte], Array[Byte]]) = {
-    new KeyValueIterator[Array[Byte], Array[Byte]] {
-      val iter = tm.entrySet().iterator()
+  private class InMemoryIterator (val iter: util.Iterator[util.Map.Entry[Array[Byte], Array[Byte]]])
+    extends KeyValueIterator[Array[Byte], Array[Byte]] {
 
-      override def close(): Unit = Unit
+    override def close(): Unit = Unit
 
-      override def remove(): Unit = iter.remove()
+    override def remove(): Unit = iter.remove()
 
-      override def next(): Entry[Array[Byte], Array[Byte]] = {
-        val n = iter.next()
-        if (n != null && n.getKey != null) {
-          metrics.bytesRead.inc(n.getKey.size)
-        }
-        if (n != null && n.getValue != null) {
-          metrics.bytesRead.inc(n.getValue.size)
-        }
-        new Entry(n.getKey, n.getValue)
+    override def next(): Entry[Array[Byte], Array[Byte]] = {
+      val n = iter.next()
+      if (n != null && n.getKey != null) {
+        metrics.bytesRead.inc(n.getKey.size)
       }
-
-      override def hasNext: Boolean = iter.hasNext
+      if (n != null && n.getValue != null) {
+        metrics.bytesRead.inc(n.getValue.size)
+      }
+      new Entry(n.getKey, n.getValue)
     }
+
+    override def hasNext: Boolean = iter.hasNext
   }
 
-  def all(): KeyValueIterator[Array[Byte], Array[Byte]] = {
+  override def all(): KeyValueIterator[Array[Byte], Array[Byte]] = {
     metrics.alls.inc
-    getIter(underlying)
+
+    new InMemoryIterator(underlying.entrySet().iterator())
   }
 
-  def range(from: Array[Byte], to: Array[Byte]): KeyValueIterator[Array[Byte], Array[Byte]] = {
+  override def range(from: Array[Byte], to: Array[Byte]): KeyValueIterator[Array[Byte], Array[Byte]] = {
     metrics.ranges.inc
     require(from != null && to != null, "Null bound not allowed.")
-    getIter(underlying.subMap(from, to))
+
+    new InMemoryIterator(underlying.subMap(from, to).entrySet().iterator())
   }
 
-  def delete(key: Array[Byte]): Unit = {
+  override def delete(key: Array[Byte]): Unit = {
     metrics.deletes.inc
     put(key, null)
   }
 
-  def deleteAll(keys: java.util.List[Array[Byte]]) = {
-    KeyValueStore.Extension.deleteAll(this, keys);
+  override def deleteAll(keys: java.util.List[Array[Byte]]) = {
+    KeyValueStore.Extension.deleteAll(this, keys)
   }
 
-  def putAll(entries: util.List[Entry[Array[Byte], Array[Byte]]]): Unit = {
+  override def putAll(entries: util.List[Entry[Array[Byte], Array[Byte]]]): Unit = {
     // TreeMap's putAll requires a map, so we'd need to iterate over all the entries anyway
     // to use it, in order to putAll here.  Therefore, just iterate here.
     val iter = entries.iterator()
@@ -95,7 +95,7 @@ class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStor
     }
   }
 
-  def put(key: Array[Byte], value: Array[Byte]): Unit = {
+  override def put(key: Array[Byte], value: Array[Byte]): Unit = {
     metrics.puts.inc
     require(key != null, "Null key not allowed.")
     if (value == null) {
@@ -107,7 +107,7 @@ class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStor
     }
   }
 
-  def get(key: Array[Byte]): Array[Byte] = {
+  override def get(key: Array[Byte]): Array[Byte] = {
     metrics.gets.inc
     require(key != null, "Null key not allowed.")
     val found = underlying.get(key)
@@ -117,7 +117,7 @@ class InMemoryKeyValueStore(val metrics: KeyValueStoreMetrics = new KeyValueStor
     found
   }
 
-  def getAll(keys: java.util.List[Array[Byte]]): java.util.Map[Array[Byte], Array[Byte]] = {
+  override def getAll(keys: java.util.List[Array[Byte]]): java.util.Map[Array[Byte], Array[Byte]] = {
     KeyValueStore.Extension.getAll(this, keys);
   }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/4323003d/samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala
----------------------------------------------------------------------
diff --git a/samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala b/samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala
index b3624e6..391cf89 100644
--- a/samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala
+++ b/samza-kv/src/main/scala/org/apache/samza/storage/kv/BaseKeyValueStorageEngineFactory.scala
@@ -38,7 +38,9 @@ import org.apache.samza.task.MessageCollector
 trait BaseKeyValueStorageEngineFactory[K, V] extends StorageEngineFactory[K, V] {
 
   /**
-   * Return a KeyValueStore instance for the given store name
+   * Return a KeyValueStore instance for the given store name,
+   * which will be used as the underlying raw store
+   *
    * @param storeName Name of the store
    * @param storeDir The directory of the store
    * @param registry MetricsRegistry to which to publish store specific metrics.
@@ -90,29 +92,35 @@ trait BaseKeyValueStorageEngineFactory[K, V] extends StorageEngineFactory[K, V]
       throw new SamzaException("Must define a message serde when using key value storage.")
     }
 
-    val kvStore = getKVStore(storeName, storeDir, registry, changeLogSystemStreamPartition, containerContext)
+    val rawStore = getKVStore(storeName, storeDir, registry, changeLogSystemStreamPartition, containerContext)
 
+    // maybe wrap with changelog logging (only when a changelog partition is configured)
     val maybeLoggedStore = if (changeLogSystemStreamPartition == null) {
-      kvStore
+      rawStore
     } else {
       val loggedStoreMetrics = new LoggedStoreMetrics(storeName, registry)
-      new LoggedStore(kvStore, changeLogSystemStreamPartition, collector, loggedStoreMetrics)
+      new LoggedStore(rawStore, changeLogSystemStreamPartition, collector, loggedStoreMetrics)
     }
 
+    // wrap with serialization
     val serializedMetrics = new SerializedKeyValueStoreMetrics(storeName, registry)
     val serialized = new SerializedKeyValueStore[K, V](maybeLoggedStore, keySerde, msgSerde, serializedMetrics)
+
+    // maybe wrap with caching (only when enableCache is set)
     val maybeCachedStore = if (enableCache) {
       val cachedStoreMetrics = new CachedStoreMetrics(storeName, registry)
       new CachedStore(serialized, cacheSize, batchSize, cachedStoreMetrics)
     } else {
       serialized
     }
-    val db = new NullSafeKeyValueStore(maybeCachedStore)
-    val keyValueStorageEngineMetrics = new KeyValueStorageEngineMetrics(storeName, registry)
 
-    // TODO: Decide if we should use raw bytes when restoring
+    // wrap with null value checking
+    val nullSafeStore = new NullSafeKeyValueStore(maybeCachedStore)
 
-    new KeyValueStorageEngine(db, kvStore, keyValueStorageEngineMetrics, batchSize)
+    // create the storage engine and return
+    // TODO: Decide if we should use raw bytes when restoring
+    val keyValueStorageEngineMetrics = new KeyValueStorageEngineMetrics(storeName, registry)
+    new KeyValueStorageEngine(nullSafeStore, rawStore, keyValueStorageEngineMetrics, batchSize)
   }
 
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/4323003d/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
----------------------------------------------------------------------
diff --git a/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala b/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
index 479016d..1112350 100644
--- a/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
+++ b/samza-kv/src/main/scala/org/apache/samza/storage/kv/CachedStore.scala
@@ -40,8 +40,9 @@ import java.util.Arrays
  * This class is very non-thread safe.
  *
  * @param store The store to cache
- * @param cacheEntries The number of entries to hold in the in memory-cache
+ * @param cacheSize The number of entries to hold in the in-memory cache
  * @param writeBatchSize The number of entries to batch together before forcing a write
+ * @param metrics The metrics recording object for this cached store
  */
 class CachedStore[K, V](
   val store: KeyValueStore[K, V],
@@ -82,7 +83,7 @@ class CachedStore[K, V](
   metrics.setDirtyCount(() => dirtyCount)
   metrics.setCacheSize(() => cacheCount)
 
-  def get(key: K) = {
+  override def get(key: K) = {
     metrics.gets.inc
 
     val c = cache.get(key)
@@ -97,46 +98,41 @@ class CachedStore[K, V](
     }
   }
 
-  def getAll(keys: java.util.List[K]): java.util.Map[K, V] = {
-    metrics.gets.inc(keys.size)
-    val returnValue = new java.util.HashMap[K, V](keys.size)
-    val misses = new java.util.ArrayList[K]
-    val keysIterator = keys.iterator
-    while (keysIterator.hasNext) {
-      val key = keysIterator.next
-      val cached = cache.get(key)
-      if (cached != null) {
-        metrics.cacheHits.inc
-        returnValue.put(key, cached.value)
-      } else {
-        misses.add(key)
-      }
+  private class CachedStoreIterator(val iter: KeyValueIterator[K, V])
+    extends KeyValueIterator[K, V] {
+
+    var last: Entry[K, V] = null
+
+    override def close(): Unit = iter.close()
+
+    override def remove(): Unit = {
+      iter.remove()
+      delete(last.getKey)
     }
-    if (!misses.isEmpty) {
-      val entryIterator = store.getAll(misses).entrySet.iterator
-      while (entryIterator.hasNext) {
-        val entry = entryIterator.next
-        returnValue.put(entry.getKey, entry.getValue)
-        cache.put(entry.getKey, new CacheEntry(entry.getValue, null))
-      }
-      cacheCount = cache.size // update outside the loop since it's used for metrics and not for time-sensitive logic
+
+    override def next() = {
+      last = iter.next()
+      last
     }
-    returnValue
+
+    override def hasNext: Boolean = iter.hasNext
   }
 
-  def range(from: K, to: K) = {
+  override def range(from: K, to: K): KeyValueIterator[K, V] = {
     metrics.ranges.inc
     flush()
-    store.range(from, to)
+
+    new CachedStoreIterator(store.range(from, to))
   }
 
-  def all() = {
+  override def all(): KeyValueIterator[K, V] = {
     metrics.alls.inc
     flush()
-    store.all()
+
+    new CachedStoreIterator(store.all())
   }
 
-  def put(key: K, value: V) {
+  override def put(key: K, value: V) {
     metrics.puts.inc
 
     checkKeyIsArray(key)
@@ -153,7 +149,7 @@ class CachedStore[K, V](
         this.dirty = found.dirty.next
         this.dirty.prev = null
       } else {
-        found.dirty.remove
+        found.dirty.remove()
       }
     }
     this.dirty = new mutable.DoubleLinkedList(key, this.dirty)
@@ -176,14 +172,14 @@ class CachedStore[K, V](
     }
   }
 
-  def flush() {
+  override def flush() {
     trace("Flushing.")
 
     metrics.flushes.inc
 
     // write out the contents of the dirty list oldest first
     val batch = new Array[Entry[K, V]](this.dirtyCount)
-    var pos : Int = this.dirtyCount - 1;
+    var pos : Int = this.dirtyCount - 1
     for (k <- this.dirty) {
       val entry = this.cache.get(k)
       entry.dirty = null // not dirty any more
@@ -191,7 +187,7 @@ class CachedStore[K, V](
       pos -= 1
     }
     store.putAll(Arrays.asList(batch : _*))
-    store.flush
+    store.flush()
     metrics.flushBatchSize.inc(batch.size)
 
     // reset the dirty list
@@ -199,7 +195,7 @@ class CachedStore[K, V](
     this.dirtyCount = 0
   }
 
-  def putAll(entries: java.util.List[Entry[K, V]]) {
+  override def putAll(entries: java.util.List[Entry[K, V]]) {
     val iter = entries.iterator
     while (iter.hasNext) {
       val curr = iter.next
@@ -207,22 +203,19 @@ class CachedStore[K, V](
     }
   }
 
-  def delete(key: K) {
+  override def delete(key: K) {
     metrics.deletes.inc
-
     put(key, null.asInstanceOf[V])
   }
 
-  def deleteAll(keys: java.util.List[K]) = {
-    KeyValueStore.Extension.deleteAll(this, keys);
-  }
-
-  def close() {
+  override def close() {
     trace("Closing.")
+    flush()
+    store.close()
+  }
 
-    flush
-
-    store.close
+  override def deleteAll(keys: java.util.List[K]) = {
+    KeyValueStore.Extension.deleteAll(this, keys)
   }
 
   private def checkKeyIsArray(key: K) {
@@ -233,6 +226,33 @@ class CachedStore[K, V](
     }
   }
 
+  override def getAll(keys: java.util.List[K]): java.util.Map[K, V] = {
+    metrics.gets.inc(keys.size)
+    val returnValue = new java.util.HashMap[K, V](keys.size)
+    val misses = new java.util.ArrayList[K]
+    val keysIterator = keys.iterator
+    while (keysIterator.hasNext) {
+      val key = keysIterator.next
+      val cached = cache.get(key)
+      if (cached != null) {
+        metrics.cacheHits.inc
+        returnValue.put(key, cached.value)
+      } else {
+        misses.add(key)
+      }
+    }
+    if (!misses.isEmpty) {
+      val entryIterator = store.getAll(misses).entrySet.iterator
+      while (entryIterator.hasNext) {
+        val entry = entryIterator.next
+        returnValue.put(entry.getKey, entry.getValue)
+        cache.put(entry.getKey, new CacheEntry(entry.getValue, null))
+      }
+      cacheCount = cache.size // update outside the loop since it's used for metrics and not for time-sensitive logic
+    }
+    returnValue
+  }
+
   def hasArrayKeys = containsArrayKeys
 }
 

http://git-wip-us.apache.org/repos/asf/samza/blob/4323003d/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
----------------------------------------------------------------------
diff --git a/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala b/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
index fc677b2..e5a66a4 100644
--- a/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
+++ b/samza-kv/src/main/scala/org/apache/samza/storage/kv/KeyValueStorageEngine.scala
@@ -31,8 +31,8 @@ import scala.collection.JavaConversions._
  * This implements both the key/value interface and the storage engine interface.
  */
 class KeyValueStorageEngine[K, V](
-  db: KeyValueStore[K, V],
-  rawDb: KeyValueStore[Array[Byte], Array[Byte]],
+  wrapperStore: KeyValueStore[K, V],
+  rawStore: KeyValueStore[Array[Byte], Array[Byte]],
   metrics: KeyValueStorageEngineMetrics = new KeyValueStorageEngineMetrics,
   batchSize: Int = 500) extends StorageEngine with KeyValueStore[K, V] with Logging {
 
@@ -41,47 +41,47 @@ class KeyValueStorageEngine[K, V](
   /* delegate to underlying store */
   def get(key: K): V = {
     metrics.gets.inc
-    db.get(key)
+    wrapperStore.get(key)
   }
 
   def getAll(keys: java.util.List[K]): java.util.Map[K, V] = {
     metrics.gets.inc(keys.size)
-    db.getAll(keys)
+    wrapperStore.getAll(keys)
   }
 
   def put(key: K, value: V) = {
     metrics.puts.inc
-    db.put(key, value)
+    wrapperStore.put(key, value)
   }
 
   def putAll(entries: java.util.List[Entry[K, V]]) = {
     metrics.puts.inc(entries.size)
-    db.putAll(entries)
+    wrapperStore.putAll(entries)
   }
 
   def delete(key: K) = {
     metrics.deletes.inc
-    db.delete(key)
+    wrapperStore.delete(key)
   }
 
   def deleteAll(keys: java.util.List[K]) = {
     metrics.deletes.inc(keys.size)
-    db.deleteAll(keys)
+    wrapperStore.deleteAll(keys)
   }
 
   def range(from: K, to: K) = {
     metrics.ranges.inc
-    db.range(from, to)
+    wrapperStore.range(from, to)
   }
 
   def all() = {
     metrics.alls.inc
-    db.all()
+    wrapperStore.all()
   }
 
   /**
    * Restore the contents of this key/value store from the change log,
-   * batching updates and skipping serialization for efficiency.
+   * batching updates and writing them directly to the underlying raw store, bypassing the wrapper stores for efficiency.
    */
   def restore(envelopes: java.util.Iterator[IncomingMessageEnvelope]) {
     val batch = new java.util.ArrayList[Entry[Array[Byte], Array[Byte]]](batchSize)
@@ -93,7 +93,7 @@ class KeyValueStorageEngine[K, V](
       batch.add(new Entry(keyBytes, valBytes))
 
       if (batch.size >= batchSize) {
-        rawDb.putAll(batch)
+        rawStore.putAll(batch)
         batch.clear()
       }
 
@@ -111,7 +111,7 @@ class KeyValueStorageEngine[K, V](
     }
 
     if (batch.size > 0) {
-      rawDb.putAll(batch)
+      rawStore.putAll(batch)
     }
   }
 
@@ -120,19 +120,19 @@ class KeyValueStorageEngine[K, V](
 
     metrics.flushes.inc
 
-    db.flush
+    wrapperStore.flush()
   }
 
   def stop() = {
     trace("Stopping.")
 
-    close
+    close()
   }
 
   def close() = {
     trace("Closing.")
 
-    flush
-    db.close
+    flush()
+    wrapperStore.close()
   }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/4323003d/samza-kv/src/test/scala/org/apache/samza/storage/kv/MockKeyValueStore.scala
----------------------------------------------------------------------
diff --git a/samza-kv/src/test/scala/org/apache/samza/storage/kv/MockKeyValueStore.scala b/samza-kv/src/test/scala/org/apache/samza/storage/kv/MockKeyValueStore.scala
new file mode 100644
index 0000000..595dd0d
--- /dev/null
+++ b/samza-kv/src/test/scala/org/apache/samza/storage/kv/MockKeyValueStore.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage.kv
+
+import scala.collection.JavaConversions._
+import java.util
+
+/**
+ * A mock String-to-String key-value store backed by an in-memory TreeMap, for use in tests.
+ */
+class MockKeyValueStore extends KeyValueStore[String, String] {
+
+  val kvMap = new java.util.TreeMap[String, String]()
+
+  override def get(key: String) = kvMap.get(key)
+
+  override def put(key: String, value: String) {
+    kvMap.put(key, value)
+  }
+
+  override def putAll(entries: java.util.List[Entry[String, String]]) {
+    for (entry <- entries) {
+      kvMap.put(entry.getKey, entry.getValue)
+    }
+  }
+
+  override def delete(key: String) {
+    kvMap.remove(key)
+  }
+
+  private class MockIterator(val iter: util.Iterator[util.Map.Entry[String, String]])
+    extends KeyValueIterator[String, String] {
+
+    override def hasNext = iter.hasNext
+
+    override def next() = {
+      val entry = iter.next()
+      new Entry(entry.getKey, entry.getValue)
+    }
+
+    override def remove(): Unit = iter.remove()
+
+    override def close(): Unit = Unit
+  }
+
+  override def range(from: String, to: String): KeyValueIterator[String, String] =
+    new MockIterator(kvMap.subMap(from, to).entrySet().iterator())
+
+  override def all(): KeyValueIterator[String, String] =
+    new MockIterator(kvMap.entrySet().iterator())
+
+  override def flush() {}  // no-op
+
+  override def close() { kvMap.clear() }
+
+  override def deleteAll(keys: java.util.List[String]) {
+    KeyValueStore.Extension.deleteAll(this, keys)
+  }
+
+  override def getAll(keys: java.util.List[String]): java.util.Map[String, String] = {
+    KeyValueStore.Extension.getAll(this, keys)
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/samza/blob/4323003d/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
----------------------------------------------------------------------
diff --git a/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala b/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
index d03ec92..cc9c9f3 100644
--- a/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
+++ b/samza-kv/src/test/scala/org/apache/samza/storage/kv/TestCachedStore.scala
@@ -23,13 +23,69 @@ import org.junit.Test
 import org.junit.Assert._
 import org.mockito.Mockito._
 
+import java.util.Arrays
+
 class TestCachedStore {
   @Test
-  def testArrayCheck {
+  def testArrayCheck() {
     val kv = mock(classOf[KeyValueStore[Array[Byte], Array[Byte]]])
     val store = new CachedStore[Array[Byte], Array[Byte]](kv, 100, 100)
+
     assertFalse(store.hasArrayKeys)
     store.put("test1-key".getBytes("UTF-8"), "test1-value".getBytes("UTF-8"))
     assertTrue(store.hasArrayKeys)
   }
+
+  @Test
+  def testIterator() {
+    val kv = new MockKeyValueStore()
+    val store = new CachedStore[String, String](kv, 100, 100)
+
+    val keys = Arrays.asList("test1-key",
+                             "test2-key",
+                             "test3-key")
+    val values = Arrays.asList("test1-value",
+                               "test2-value",
+                               "test3-value")
+
+    for (i <- 0 until 3) {
+      store.put(keys.get(i), values.get(i))
+    }
+
+    // test all iterator
+    var iter = store.all()
+    for (i <- 0 until 3) {
+      assertTrue(iter.hasNext)
+      val entry = iter.next()
+      assertEquals(entry.getKey, keys.get(i))
+      assertEquals(entry.getValue, values.get(i))
+    }
+    assertFalse(iter.hasNext)
+
+    // test range iterator
+    iter = store.range(keys.get(0), keys.get(2))
+    for (i <- 0 until 2) {
+      assertTrue(iter.hasNext)
+      val entry = iter.next()
+      assertEquals(entry.getKey, keys.get(i))
+      assertEquals(entry.getValue, values.get(i))
+    }
+    assertFalse(iter.hasNext)
+
+    // test iterator remove
+    iter = store.all()
+    iter.next()
+    iter.remove()
+
+    assertNull(kv.get(keys.get(0)))
+    assertNull(store.get(keys.get(0)))
+
+    iter = store.range(keys.get(1), keys.get(2))
+    iter.next()
+    iter.remove()
+
+    assertFalse(iter.hasNext)
+    assertNull(kv.get(keys.get(1)))
+    assertNull(store.get(keys.get(1)))
+  }
 }
\ No newline at end of file