You are viewing a plain-text version of this content. (The canonical link to the original message was not preserved in this copy.)
Posted to commits@beam.apache.org by lc...@apache.org on 2022/10/10 18:43:55 UTC

[beam] branch master updated: [fixes #23000] Update the Python SDK harness state cache to be a loading cache (#23046)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 0e61b026ea7 [fixes #23000] Update the Python SDK harness state cache to be a loading cache (#23046)
0e61b026ea7 is described below

commit 0e61b026ea7accd666fc443f3aeec7f93147a3b6
Author: Luke Cwik <lc...@google.com>
AuthorDate: Mon Oct 10 11:43:46 2022 -0700

    [fixes #23000] Update the Python SDK harness state cache to be a loading cache (#23046)
    
    * [#23000] Update the Python SDK harness state cache to be a loading cache
    
    Also record statistics for the cost of performing a load.
    
    * Use get/peek within sdk_worker
    
    * Address PR comments
    
    * Address PR comments
---
 .../apache_beam/runners/worker/sdk_worker.py       |  47 +---
 .../apache_beam/runners/worker/statecache.py       | 174 ++++++++++---
 .../apache_beam/runners/worker/statecache_test.py  | 282 +++++++++++++++------
 3 files changed, 352 insertions(+), 151 deletions(-)

diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 968f213d6b2..93c9c609aa0 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -1149,22 +1149,9 @@ class GlobalCachingStateHandler(CachingStateHandler):
       return self._lazy_iterator(state_key, coder)
     # Cache lookup
     cache_state_key = self._convert_to_cache_key(state_key)
-    cached_value = self._state_cache.get(cache_state_key, cache_token)
-    if cached_value is None:
-      # Cache miss, need to retrieve from the Runner
-      # Further size estimation or the use of the continuation token on the
-      # runner side could fall back to materializing one item at a time.
-      # https://jira.apache.org/jira/browse/BEAM-8297
-      materialized = cached_value = (
-          self._partially_cached_iterable(state_key, coder))
-      if isinstance(materialized, (list, self.ContinuationIterable)):
-        self._state_cache.put(cache_state_key, cache_token, materialized)
-      else:
-        _LOGGER.error(
-            "Uncacheable type %s for key %s. Not caching.",
-            materialized,
-            state_key)
-    return cached_value
+    return self._state_cache.get(
+        (cache_state_key, cache_token),
+        lambda key: self._partially_cached_iterable(state_key, coder))
 
   def extend(
       self,
@@ -1175,29 +1162,21 @@ class GlobalCachingStateHandler(CachingStateHandler):
     # type: (...) -> _Future
     cache_token = self._get_cache_token(state_key)
     if cache_token:
-      # Update the cache
+      # Update the cache if the value is already present and
+      # can be updated.
       cache_key = self._convert_to_cache_key(state_key)
-      cached_value = self._state_cache.get(cache_key, cache_token)
-      # Keep in mind that the state for this key can be evicted
-      # while executing this function. Either read or write to the cache
-      # but never do both here!
-      if cached_value is None:
-        # We have never cached this key before, first retrieve state
-        cached_value = self.blocking_get(state_key, coder)
-      # Just extend the already cached value
+      cached_value = self._state_cache.peek((cache_key, cache_token))
       if isinstance(cached_value, list):
+        # The state is fully cached and can be extended
+
         # Materialize provided iterable to ensure reproducible iterations,
         # here and when writing to the state handler below.
         elements = list(elements)
-        # The state is fully cached and can be extended
         cached_value.extend(elements)
-      elif isinstance(cached_value, self.ContinuationIterable):
-        # The state is too large to be fully cached (continuation token used),
-        # only the first part is cached, the rest if enumerated via the runner.
-        pass
-      else:
-        # When a corrupt value made it into the cache, we have to fail.
-        raise Exception("Unexpected cached value: %s" % cached_value)
+        # Re-insert into the cache the updated value so the updated size is
+        # reflected.
+        self._state_cache.put((cache_key, cache_token), cached_value)
+
     # Write to state handler
     futures = []
     out = coder_impl.create_OutputStream()
@@ -1220,7 +1199,7 @@ class GlobalCachingStateHandler(CachingStateHandler):
     cache_token = self._get_cache_token(state_key)
     if cache_token:
       cache_key = self._convert_to_cache_key(state_key)
-      self._state_cache.clear(cache_key, cache_token)
+      self._state_cache.put((cache_key, cache_token), [])
     return self._underlying.clear(state_key)
 
   def done(self):
diff --git a/sdks/python/apache_beam/runners/worker/statecache.py b/sdks/python/apache_beam/runners/worker/statecache.py
index 839b2b0568a..e3f37fec114 100644
--- a/sdks/python/apache_beam/runners/worker/statecache.py
+++ b/sdks/python/apache_beam/runners/worker/statecache.py
@@ -28,8 +28,8 @@ import time
 import types
 import weakref
 from typing import Any
+from typing import Callable
 from typing import List
-from typing import Optional
 from typing import Tuple
 from typing import Union
 
@@ -76,6 +76,7 @@ class WeightedValue(object):
 
 
 class CacheAware(object):
+  """Allows cache users to override what objects are measured."""
   def __init__(self):
     # type: () -> None
     pass
@@ -181,18 +182,47 @@ def get_deep_size(*objs):
       filter_func=_filter_func)
 
 
+class _LoadingValue(WeightedValue):
+  """Allows concurrent users of the cache to wait for a value to be loaded."""
+  def __init__(self):
+    # type: () -> None
+    super().__init__(None, 1)
+    self._wait_event = threading.Event()
+
+  def load(self, key, loading_fn):
+    # type: (Any, Callable[[Any], Any]) -> None
+    try:
+      self._value = loading_fn(key)
+    except Exception as err:
+      self._error = err
+    finally:
+      self._wait_event.set()
+
+  def value(self):
+    # type: () -> Any
+    self._wait_event.wait()
+    err = getattr(self, "_error", None)
+    if err:
+      raise err
+    return self._value
+
+
 class StateCache(object):
-  """Cache for Beam state access, scoped by state key and cache_token.
+  """LRU cache for Beam state access, scoped by state key and cache_token.
      Assumes a bag state implementation.
 
-  For a given state_key and cache_token, caches a value and allows to
-    a) read from the cache (get),
-           if the currently stored cache_token matches the provided
-    b) write to the cache (put),
-           storing the new value alongside with a cache token
-    c) empty a cached element (clear),
-           if the currently stored cache_token matches the provided
+  For a given key, caches a value and allows to
+    a) peek at the cache (peek),
+           returns the value for the provided key or None if it doesn't exist.
+           Will never block.
+    b) read from the cache (get),
+           returns the value for the provided key or loads it using the
+           supplied function. Multiple calls for the same key will block
+           until the value is loaded.
+    c) write to the cache (put),
+           store the provided value overwriting any previous result
     d) invalidate a cached element (invalidate)
+           removes the value from the cache for the provided key
     e) invalidate all cached elements (invalidate_all)
 
   The operations on the cache are thread-safe for use by multiple workers.
@@ -205,28 +235,106 @@ class StateCache(object):
     self._max_weight = max_weight
     self._current_weight = 0
     self._cache = collections.OrderedDict(
-    )  # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue]
+    )  # type: collections.OrderedDict[Any, WeightedValue]
     self._hit_count = 0
     self._miss_count = 0
     self._evict_count = 0
+    self._load_time_ns = 0
+    self._load_count = 0
     self._lock = threading.RLock()
 
-  def get(self, state_key, cache_token):
-    # type: (bytes, Optional[bytes]) -> Any
-    assert cache_token and self.is_cache_enabled()
-    key = (state_key, cache_token)
+  def peek(self, key):
+    # type: (Any) -> Any
+    assert self.is_cache_enabled()
     with self._lock:
       value = self._cache.get(key, None)
-      if value is None:
+      if value is None or _safe_isinstance(value, _LoadingValue):
         self._miss_count += 1
         return None
+
       self._cache.move_to_end(key)
       self._hit_count += 1
+    return value.value()
+
+  def get(self, key, loading_fn):
+    # type: (Any, Callable[[Any], Any]) -> Any
+    assert self.is_cache_enabled() and callable(loading_fn)
+
+    self._lock.acquire()
+    value = self._cache.get(key, None)
+
+    # Return the already cached value
+    if value is not None:
+      self._cache.move_to_end(key)
+      self._hit_count += 1
+      self._lock.release()
       return value.value()
 
-  def put(self, state_key, cache_token, value):
-    # type: (bytes, Optional[bytes], Any) -> None
-    assert cache_token and self.is_cache_enabled()
+    # Load the value since it isn't in the cache.
+    self._miss_count += 1
+    loading_value = _LoadingValue()
+    self._cache[key] = loading_value
+
+    # Ensure that we unlock the lock while loading to allow for parallel gets
+    self._lock.release()
+
+    start_time_ns = time.time_ns()
+    loading_value.load(key, loading_fn)
+    elapsed_time_ns = time.time_ns() - start_time_ns
+
+    try:
+      value = loading_value.value()
+    except Exception as err:
+      # If loading failed then delete the value from the cache allowing for
+      # the next lookup to possibly succeed.
+      with self._lock:
+        self._load_count += 1
+        self._load_time_ns += elapsed_time_ns
+        # Don't remove values that have already been replaced with a different
+        # value by a put/invalidate that occurred concurrently with the load.
+        # The put/invalidate will have been responsible for updating the
+        # cache weight appropriately already.
+        old_value = self._cache.get(key, None)
+        if old_value is not loading_value:
+          raise err
+        self._current_weight -= loading_value.weight()
+        del self._cache[key]
+      raise err
+
+    # Replace the value in the cache with a weighted value now that the
+    # loading has completed successfully.
+    weight = get_deep_size(value)
+    if weight <= 0:
+      _LOGGER.warning(
+          'Expected object size to be >= 0 for %s but received %d.',
+          value,
+          weight)
+      weight = 8
+    value = WeightedValue(value, weight)
+    with self._lock:
+      self._load_count += 1
+      self._load_time_ns += elapsed_time_ns
+      # Don't replace values that have already been replaced with a different
+      # value by a put/invalidate that occurred concurrently with the load.
+      # The put/invalidate will have been responsible for updating the
+      # cache weight appropriately already.
+      old_value = self._cache.get(key, None)
+      if old_value is not loading_value:
+        return value.value()
+
+      self._current_weight -= loading_value.weight()
+      self._cache[key] = value
+      self._current_weight += value.weight()
+      while self._current_weight > self._max_weight:
+        (_, weighted_value) = self._cache.popitem(last=False)
+        self._current_weight -= weighted_value.weight()
+        self._evict_count += 1
+
+    return value.value()
+
+  def put(self, key, value):
+    # type: (Any, Any) -> None
+    assert self.is_cache_enabled()
     if not _safe_isinstance(value, WeightedValue):
       weight = get_deep_size(value)
       if weight <= 0:
@@ -236,27 +344,22 @@ class StateCache(object):
             weight)
         weight = _DEFAULT_WEIGHT
       value = WeightedValue(value, weight)
-    key = (state_key, cache_token)
     with self._lock:
       old_value = self._cache.pop(key, None)
       if old_value is not None:
         self._current_weight -= old_value.weight()
-      self._cache[(state_key, cache_token)] = value
+      self._cache[key] = value
       self._current_weight += value.weight()
       while self._current_weight > self._max_weight:
         (_, weighted_value) = self._cache.popitem(last=False)
         self._current_weight -= weighted_value.weight()
         self._evict_count += 1
 
-  def clear(self, state_key, cache_token):
-    # type: (bytes, Optional[bytes]) -> None
-    self.put(state_key, cache_token, [])
-
-  def invalidate(self, state_key, cache_token):
-    # type: (bytes, Optional[bytes]) -> None
+  def invalidate(self, key):
+    # type: (Any) -> None
     assert self.is_cache_enabled()
     with self._lock:
-      weighted_value = self._cache.pop((state_key, cache_token), None)
+      weighted_value = self._cache.pop(key, None)
       if weighted_value is not None:
         self._current_weight -= weighted_value.weight()
 
@@ -274,12 +377,17 @@ class StateCache(object):
         hit_ratio = 100.0 * self._hit_count / request_count
       else:
         hit_ratio = 100.0
-      return 'used/max %d/%d MB, hit %.2f%%, lookups %d, evictions %d' % (
-          self._current_weight >> 20,
-          self._max_weight >> 20,
-          hit_ratio,
-          request_count,
-          self._evict_count)
+      return (
+          'used/max %d/%d MB, hit %.2f%%, lookups %d, '
+          'avg load time %.0f ns, loads %d, evictions %d') % (
+              self._current_weight >> 20,
+              self._max_weight >> 20,
+              hit_ratio,
+              request_count,
+              self._load_time_ns /
+              self._load_count if self._load_count > 0 else 0,
+              self._load_count,
+              self._evict_count)
 
   def is_cache_enabled(self):
     # type: () -> bool
diff --git a/sdks/python/apache_beam/runners/worker/statecache_test.py b/sdks/python/apache_beam/runners/worker/statecache_test.py
index 155d2857ce4..6850cb21284 100644
--- a/sdks/python/apache_beam/runners/worker/statecache_test.py
+++ b/sdks/python/apache_beam/runners/worker/statecache_test.py
@@ -19,13 +19,19 @@
 # pytype: skip-file
 
 import logging
+import re
 import threading
+import time
 import unittest
 import weakref
 
+from hamcrest import assert_that
+from hamcrest import contains_string
+
 from apache_beam.runners.worker.statecache import CacheAware
 from apache_beam.runners.worker.statecache import StateCache
 from apache_beam.runners.worker.statecache import WeightedValue
+from apache_beam.runners.worker.statecache import _LoadingValue
 
 
 class StateCacheTest(unittest.TestCase):
@@ -39,21 +45,22 @@ class StateCacheTest(unittest.TestCase):
     cache = StateCache(5 << 20)
     wait_event = threading.Event()
     o = WeightedValueRef()
-    cache.put('deep ref', 'a', o)
+    cache.put('deep ref', o)
     # Ensure that the contents of the internal weak ref isn't sized
-    self.assertIsNotNone(cache.get('deep ref', 'a'))
+    self.assertIsNotNone(cache.peek('deep ref'))
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 0/5 MB, hit 100.00%, lookups 1, evictions 0')
+        'used/max 0/5 MB, hit 100.00%, lookups 1, avg load time 0 ns, loads 0, '
+        'evictions 0')
     cache.invalidate_all()
 
     # Ensure that putting in a weakref doesn't fail regardless of whether
     # it is alive or not
     o_ref = weakref.ref(o, lambda value: wait_event.set())
-    cache.put('not deleted ref', 'a', o_ref)
+    cache.put('not deleted ref', o_ref)
     del o
     wait_event.wait()
-    cache.put('deleted', 'a', o_ref)
+    cache.put('deleted', o_ref)
 
   def test_weakref_proxy(self):
     test_value = WeightedValue('test', 10 << 20)
@@ -65,21 +72,22 @@ class StateCacheTest(unittest.TestCase):
     cache = StateCache(5 << 20)
     wait_event = threading.Event()
     o = WeightedValueRef()
-    cache.put('deep ref', 'a', o)
+    cache.put('deep ref', o)
     # Ensure that the contents of the internal weak ref isn't sized
-    self.assertIsNotNone(cache.get('deep ref', 'a'))
+    self.assertIsNotNone(cache.peek('deep ref'))
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 0/5 MB, hit 100.00%, lookups 1, evictions 0')
+        'used/max 0/5 MB, hit 100.00%, lookups 1, avg load time 0 ns, loads 0, '
+        'evictions 0')
     cache.invalidate_all()
 
     # Ensure that putting in a weakref doesn't fail regardless of whether
     # it is alive or not
     o_ref = weakref.proxy(o, lambda value: wait_event.set())
-    cache.put('not deleted', 'a', o_ref)
+    cache.put('not deleted', o_ref)
     del o
     wait_event.wait()
-    cache.put('deleted', 'a', o_ref)
+    cache.put('deleted', o_ref)
 
   def test_size_of_fails(self):
     class BadSizeOf(object):
@@ -89,143 +97,247 @@ class StateCacheTest(unittest.TestCase):
     cache = StateCache(5 << 20)
     with self.assertLogs('apache_beam.runners.worker.statecache',
                          level='WARNING') as context:
-      cache.put('key', 'a', BadSizeOf())
+      cache.put('key', BadSizeOf())
       self.assertEqual(1, len(context.output))
       self.assertTrue('Failed to size' in context.output[0])
       # Test that we don't spam the logs
-      cache.put('key', 'a', BadSizeOf())
+      cache.put('key', BadSizeOf())
       self.assertEqual(1, len(context.output))
 
-  def test_empty_cache_get(self):
+  def test_empty_cache_peek(self):
     cache = StateCache(5 << 20)
-    self.assertEqual(cache.get("key", 'cache_token'), None)
-    with self.assertRaises(Exception):
-      # Invalid cache token provided
-      self.assertEqual(cache.get("key", None), None)
+    self.assertEqual(cache.peek("key"), None)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 0/5 MB, hit 0.00%, lookups 1, evictions 0')
+        (
+            'used/max 0/5 MB, hit 0.00%, lookups 1, '
+            'avg load time 0 ns, loads 0, evictions 0'))
 
-  def test_put_get(self):
+  def test_put_peek(self):
     cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
+    cache.put("key", WeightedValue("value", 1 << 20))
     self.assertEqual(cache.size(), 1)
-    self.assertEqual(cache.get("key", "cache_token"), "value")
-    self.assertEqual(cache.get("key", "cache_token2"), None)
-    with self.assertRaises(Exception):
-      self.assertEqual(cache.get("key", None), None)
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 1/5 MB, hit 50.00%, lookups 2, evictions 0')
-
-  def test_clear(self):
-    cache = StateCache(5 << 20)
-    cache.clear("new-key", "cache_token")
-    cache.put("key", "cache_token", WeightedValue(["value"], 1 << 20))
-    self.assertEqual(cache.size(), 2)
-    self.assertEqual(cache.get("new-key", "new_token"), None)
-    self.assertEqual(cache.get("key", "cache_token"), ['value'])
-    # test clear without existing key/token
-    cache.clear("non-existing", "token")
-    self.assertEqual(cache.size(), 3)
-    self.assertEqual(cache.get("non-existing", "token"), [])
+    self.assertEqual(cache.peek("key"), "value")
+    self.assertEqual(cache.peek("key2"), None)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 1/5 MB, hit 66.67%, lookups 3, evictions 0')
+        (
+            'used/max 1/5 MB, hit 50.00%, lookups 2, '
+            'avg load time 0 ns, loads 0, evictions 0'))
 
   def test_default_sized_put(self):
     cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", bytearray(1 << 20))
-    cache.put("key2", "cache_token", bytearray(1 << 20))
-    cache.put("key3", "cache_token", bytearray(1 << 20))
-    self.assertEqual(cache.get("key3", "cache_token"), bytearray(1 << 20))
-    cache.put("key4", "cache_token", bytearray(1 << 20))
-    cache.put("key5", "cache_token", bytearray(1 << 20))
+    cache.put("key", bytearray(1 << 20))
+    cache.put("key2", bytearray(1 << 20))
+    cache.put("key3", bytearray(1 << 20))
+    self.assertEqual(cache.peek("key3"), bytearray(1 << 20))
+    cache.put("key4", bytearray(1 << 20))
+    cache.put("key5", bytearray(1 << 20))
     # note that each byte array instance takes slightly over 1 MB which is why
     # these 5 byte arrays can't all be stored in the cache causing a single
     # eviction
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 4/5 MB, hit 100.00%, lookups 1, evictions 1')
+        (
+            'used/max 4/5 MB, hit 100.00%, lookups 1, '
+            'avg load time 0 ns, loads 0, evictions 1'))
 
   def test_max_size(self):
     cache = StateCache(2 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
-    cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
+    cache.put("key", WeightedValue("value", 1 << 20))
+    cache.put("key2", WeightedValue("value2", 1 << 20))
     self.assertEqual(cache.size(), 2)
-    cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
+    cache.put("key3", WeightedValue("value3", 1 << 20))
     self.assertEqual(cache.size(), 2)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 2/2 MB, hit 100.00%, lookups 0, evictions 1')
+        (
+            'used/max 2/2 MB, hit 100.00%, lookups 0, '
+            'avg load time 0 ns, loads 0, evictions 1'))
 
   def test_invalidate_all(self):
     cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
-    cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
+    cache.put("key", WeightedValue("value", 1 << 20))
+    cache.put("key2", WeightedValue("value2", 1 << 20))
     self.assertEqual(cache.size(), 2)
     cache.invalidate_all()
     self.assertEqual(cache.size(), 0)
-    self.assertEqual(cache.get("key", "cache_token"), None)
-    self.assertEqual(cache.get("key2", "cache_token"), None)
+    self.assertEqual(cache.peek("key"), None)
+    self.assertEqual(cache.peek("key2"), None)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 0/5 MB, hit 0.00%, lookups 2, evictions 0')
+        (
+            'used/max 0/5 MB, hit 0.00%, lookups 2, '
+            'avg load time 0 ns, loads 0, evictions 0'))
 
   def test_lru(self):
     cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
-    cache.put("key2", "cache_token2", WeightedValue("value2", 1 << 20))
-    cache.put("key3", "cache_token", WeightedValue("value0", 1 << 20))
-    cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
-    cache.put("key4", "cache_token4", WeightedValue("value4", 1 << 20))
-    cache.put("key5", "cache_token", WeightedValue("value0", 1 << 20))
-    cache.put("key5", "cache_token", WeightedValue(["value5"], 1 << 20))
+    cache.put("key", WeightedValue("value", 1 << 20))
+    cache.put("key2", WeightedValue("value2", 1 << 20))
+    cache.put("key3", WeightedValue("value0", 1 << 20))
+    cache.put("key3", WeightedValue("value3", 1 << 20))
+    cache.put("key4", WeightedValue("value4", 1 << 20))
+    cache.put("key5", WeightedValue("value0", 1 << 20))
+    cache.put("key5", WeightedValue(["value5"], 1 << 20))
     self.assertEqual(cache.size(), 5)
-    self.assertEqual(cache.get("key", "cache_token"), "value")
-    self.assertEqual(cache.get("key2", "cache_token2"), "value2")
-    self.assertEqual(cache.get("key3", "cache_token"), "value3")
-    self.assertEqual(cache.get("key4", "cache_token4"), "value4")
-    self.assertEqual(cache.get("key5", "cache_token"), ["value5"])
+    self.assertEqual(cache.peek("key"), "value")
+    self.assertEqual(cache.peek("key2"), "value2")
+    self.assertEqual(cache.peek("key3"), "value3")
+    self.assertEqual(cache.peek("key4"), "value4")
+    self.assertEqual(cache.peek("key5"), ["value5"])
     # insert another key to trigger cache eviction
-    cache.put("key6", "cache_token2", WeightedValue("value6", 1 << 20))
+    cache.put("key6", WeightedValue("value6", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key")
-    self.assertEqual(cache.get("key", "cache_token"), None)
+    self.assertEqual(cache.peek("key"), None)
     # trigger a read on "key2"
-    cache.get("key2", "cache_token2")
+    cache.peek("key2")
     # insert another key to trigger cache eviction
-    cache.put("key7", "cache_token", WeightedValue("value7", 1 << 20))
+    cache.put("key7", WeightedValue("value7", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key3")
-    self.assertEqual(cache.get("key3", "cache_token"), None)
-    # trigger a put on "key2"
-    cache.put("key2", "cache_token", WeightedValue("put", 1 << 20))
+    self.assertEqual(cache.peek("key3"), None)
+    # insert another key to trigger cache eviction
+    cache.put("key8", WeightedValue("put", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # insert another key to trigger cache eviction
-    cache.put("key8", "cache_token", WeightedValue("value8", 1 << 20))
+    cache.put("key9", WeightedValue("value8", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key4")
-    self.assertEqual(cache.get("key4", "cache_token"), None)
+    self.assertEqual(cache.peek("key4"), None)
     # make "key5" used by writing to it
-    cache.put("key5", "cache_token", WeightedValue("val", 1 << 20))
+    cache.put("key5", WeightedValue("val", 1 << 20))
     # least recently used key should be gone ("key6")
-    self.assertEqual(cache.get("key6", "cache_token"), None)
+    self.assertEqual(cache.peek("key6"), None)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 5/5 MB, hit 60.00%, lookups 10, evictions 5')
+        (
+            'used/max 5/5 MB, hit 60.00%, lookups 10, '
+            'avg load time 0 ns, loads 0, evictions 5'))
+
+  def test_get(self):
+    def check_key(key):
+      self.assertEqual(key, "key")
+      time.sleep(0.5)
+      return "value"
+
+    def raise_exception(key):
+      time.sleep(0.5)
+      raise Exception("TestException")
+
+    cache = StateCache(5 << 20)
+    self.assertEqual("value", cache.get("key", check_key))
+    with cache._lock:
+      self.assertFalse(isinstance(cache._cache["key"], _LoadingValue))
+    self.assertEqual("value", cache.peek("key"))
+    cache.invalidate_all()
+
+    with self.assertRaisesRegex(Exception, "TestException"):
+      cache.get("key", raise_exception)
+    # The cache should not have the value after the failing load causing
+    # check_key to load the value.
+    self.assertEqual("value", cache.get("key", check_key))
+    with cache._lock:
+      self.assertFalse(isinstance(cache._cache["key"], _LoadingValue))
+    self.assertEqual("value", cache.peek("key"))
+
+    assert_that(cache.describe_stats(), contains_string(", loads 3,"))
+    load_time_ns = re.search(
+        ", avg load time (.+) ns,", cache.describe_stats()).group(1)
+    # Load time should be larger then the sleep time and less than 2x sleep time
+    self.assertGreater(int(load_time_ns), 0.5 * 1_000_000_000)
+    self.assertLess(int(load_time_ns), 1_000_000_000)
+
+  def test_concurrent_get_waits(self):
+    event = threading.Semaphore(0)
+    threads_running = threading.Barrier(3)
+
+    def wait_for_event(key):
+      with cache._lock:
+        self.assertTrue(isinstance(cache._cache["key"], _LoadingValue))
+      event.release()
+      return "value"
+
+    cache = StateCache(5 << 20)
+
+    def load_key(output):
+      threads_running.wait()
+      output["value"] = cache.get("key", wait_for_event)
+      output["time"] = time.time_ns()
+
+    t1_output = {}
+    t1 = threading.Thread(target=load_key, args=(t1_output, ))
+    t1.start()
+
+    t2_output = {}
+    t2 = threading.Thread(target=load_key, args=(t2_output, ))
+    t2.start()
+
+    # Wait for both threads to start
+    threads_running.wait()
+    # Record the time and wait for the load to start
+    current_time_ns = time.time_ns()
+    event.acquire()
+    t1.join()
+    t2.join()
+
+    # Ensure that only one thread did the loading and not both by checking that
+    # the semaphore was only released once
+    self.assertFalse(event.acquire(blocking=False))
+
+    # Ensure that the load time is greater than the set time ensuring that
+    # both loads had to wait for the event
+    self.assertLessEqual(current_time_ns, t1_output["time"])
+    self.assertLessEqual(current_time_ns, t2_output["time"])
+    self.assertEqual("value", t1_output["value"])
+    self.assertEqual("value", t2_output["value"])
+    self.assertEqual("value", cache.peek("key"))
+
+  def test_concurrent_get_superseded_by_put(self):
+    load_happening = threading.Event()
+    finish_loading = threading.Event()
+
+    def wait_for_event(key):
+      load_happening.set()
+      finish_loading.wait()
+      return "value"
+
+    cache = StateCache(5 << 20)
+
+    def load_key(output):
+      output["value"] = cache.get("key", wait_for_event)
+
+    t1_output = {}
+    t1 = threading.Thread(target=load_key, args=(t1_output, ))
+    t1.start()
+
+    # Wait for the load to start, update the key, and then let the load finish
+    load_happening.wait()
+    cache.put("key", "value2")
+    finish_loading.set()
+    t1.join()
+
+    # Ensure that the original value is loaded and returned and not the
+    # updated value
+    self.assertEqual("value", t1_output["value"])
+    # Ensure that the updated value supersedes the loaded value.
+    self.assertEqual("value2", cache.peek("key"))
 
   def test_is_cached_enabled(self):
     cache = StateCache(1 << 20)
     self.assertEqual(cache.is_cache_enabled(), True)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 0/1 MB, hit 100.00%, lookups 0, evictions 0')
+        (
+            'used/max 0/1 MB, hit 100.00%, lookups 0, '
+            'avg load time 0 ns, loads 0, evictions 0'))
     cache = StateCache(0)
     self.assertEqual(cache.is_cache_enabled(), False)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 0/0 MB, hit 100.00%, lookups 0, evictions 0')
+        (
+            'used/max 0/0 MB, hit 100.00%, lookups 0, '
+            'avg load time 0 ns, loads 0, evictions 0'))
 
   def test_get_referents_for_cache(self):
     class GetReferentsForCache(CacheAware):
@@ -237,10 +349,12 @@ class StateCacheTest(unittest.TestCase):
         return [self.measure_me]
 
     cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", GetReferentsForCache())
+    cache.put("key", GetReferentsForCache())
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 1/5 MB, hit 100.00%, lookups 0, evictions 0')
+        (
+            'used/max 1/5 MB, hit 100.00%, lookups 0, '
+            'avg load time 0 ns, loads 0, evictions 0'))
 
 
 if __name__ == '__main__':