You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/09/08 19:17:09 UTC

[GitHub] [beam] ryanthompson591 commented on a diff in pull request #23046: [fixes #23000] Update the Python SDK harness state cache to be a loading cache

ryanthompson591 commented on code in PR #23046:
URL: https://github.com/apache/beam/pull/23046#discussion_r966300851


##########
sdks/python/apache_beam/runners/worker/sdk_worker.py:
##########
@@ -1149,22 +1149,9 @@ def blocking_get(
       return self._lazy_iterator(state_key, coder)
     # Cache lookup
     cache_state_key = self._convert_to_cache_key(state_key)
-    cached_value = self._state_cache.get(cache_state_key, cache_token)
-    if cached_value is None:
-      # Cache miss, need to retrieve from the Runner
-      # Further size estimation or the use of the continuation token on the
-      # runner side could fall back to materializing one item at a time.
-      # https://jira.apache.org/jira/browse/BEAM-8297

Review Comment:
   Is this now resolved with this PR?



##########
sdks/python/apache_beam/runners/worker/sdk_worker.py:
##########
@@ -1149,22 +1149,9 @@ def blocking_get(
       return self._lazy_iterator(state_key, coder)
     # Cache lookup
     cache_state_key = self._convert_to_cache_key(state_key)
-    cached_value = self._state_cache.get(cache_state_key, cache_token)
-    if cached_value is None:
-      # Cache miss, need to retrieve from the Runner
-      # Further size estimation or the use of the continuation token on the
-      # runner side could fall back to materializing one item at a time.
-      # https://jira.apache.org/jira/browse/BEAM-8297
-      materialized = cached_value = (
-          self._partially_cached_iterable(state_key, coder))
-      if isinstance(materialized, (list, self.ContinuationIterable)):
-        self._state_cache.put(cache_state_key, cache_token, materialized)
-      else:
-        _LOGGER.error(
-            "Uncacheable type %s for key %s. Not caching.",
-            materialized,
-            state_key)
-    return cached_value
+    return self._state_cache.get(
+        (cache_state_key, cache_token),

Review Comment:
   I'm just curious about why we need a state key and a token to retrieve data from the cache. Isn't the token alone adequate?



##########
sdks/python/apache_beam/runners/worker/sdk_worker.py:
##########
@@ -1149,7 +1149,7 @@ def blocking_get(
       return self._lazy_iterator(state_key, coder)
     # Cache lookup
     cache_state_key = self._convert_to_cache_key(state_key)
-    cached_value = self._state_cache.get(cache_state_key, cache_token)
+    cached_value = self._state_cache.peek((cache_state_key, cache_token))

Review Comment:
   Are the extra parentheses there because peek wants a tuple?



##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -87,18 +88,46 @@ def get_referents_for_cache(*objs):
   return rval
 
 
+class _LoadingValue(WeightedValue):
+  """Allows concurrent users of the cache to wait for a value to be loaded."""
+  def __init__(self):
+    # type: () -> None
+    super().__init__(None, 1)
+    self._wait_event = threading.Event()
+
+  def load(self, key, loading_fn):
+    # type: (Any, Callable[[Any], Any]) -> None
+    try:
+      self._value = loading_fn(key)
+    except Exception as err:
+      self._error = err
+    finally:
+      self._wait_event.set()
+
+  def value(self):
+    # type: () -> Any
+    self._wait_event.wait()
+    err = getattr(self, "_error", None)
+    if err:
+      raise err
+    return self._value
+
+
 class StateCache(object):
   """Cache for Beam state access, scoped by state key and cache_token.
      Assumes a bag state implementation.
 
-  For a given state_key and cache_token, caches a value and allows to
-    a) read from the cache (get),
-           if the currently stored cache_token matches the provided
-    b) write to the cache (put),
-           storing the new value alongside with a cache token
-    c) empty a cached element (clear),
-           if the currently stored cache_token matches the provided
+  For a given key, caches a value and allows to
+    a) peek at the cache (peek),

Review Comment:
   Is it worth documenting that peek doesn't just check whether the value is there, but also moves the value to the head of the cache (i.e., marks it as most recently used)?
   
   That's the behavior we want with peek right?



##########
sdks/python/apache_beam/runners/worker/sdk_worker.py:
##########
@@ -1149,22 +1149,9 @@ def blocking_get(
       return self._lazy_iterator(state_key, coder)
     # Cache lookup
     cache_state_key = self._convert_to_cache_key(state_key)
-    cached_value = self._state_cache.get(cache_state_key, cache_token)
-    if cached_value is None:
-      # Cache miss, need to retrieve from the Runner
-      # Further size estimation or the use of the continuation token on the
-      # runner side could fall back to materializing one item at a time.
-      # https://jira.apache.org/jira/browse/BEAM-8297
-      materialized = cached_value = (
-          self._partially_cached_iterable(state_key, coder))
-      if isinstance(materialized, (list, self.ContinuationIterable)):
-        self._state_cache.put(cache_state_key, cache_token, materialized)
-      else:
-        _LOGGER.error(

Review Comment:
   We are removing this; are there no longer cache misses we need to worry about? How are those handled?



##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -181,12 +263,17 @@ def describe_stats(self):
         hit_ratio = 100.0 * self._hit_count / request_count
       else:
         hit_ratio = 100.0
-      return 'used/max %d/%d MB, hit %.2f%%, lookups %d, evictions %d' % (
-          self._current_weight >> 20,
-          self._max_weight >> 20,
-          hit_ratio,
-          request_count,
-          self._evict_count)
+      return (
+          'used/max %d/%d MB, hit %.2f%%, lookups %d, '
+          'avg load time %.0f ns, loads %d, evictions %d') % (
+              self._current_weight >> 20,
+              self._max_weight >> 20,
+              hit_ratio,
+              request_count,
+              self._load_time_ns /
+              self._load_count if self._load_count > 0 else 0,

Review Comment:
   This is just to prevent a divide by 0 right?



##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -111,28 +140,86 @@ def __init__(self, max_weight):
     self._max_weight = max_weight
     self._current_weight = 0
     self._cache = collections.OrderedDict(
-    )  # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue]
+    )  # type: collections.OrderedDict[Any, WeightedValue]
     self._hit_count = 0
     self._miss_count = 0
     self._evict_count = 0
+    self._load_time_ns = 0
+    self._load_count = 0
     self._lock = threading.RLock()
 
-  def get(self, state_key, cache_token):
-    # type: (bytes, Optional[bytes]) -> Any
-    assert cache_token and self.is_cache_enabled()
-    key = (state_key, cache_token)
+  def peek(self, key):
+    # type: (Any) -> Any
+    assert self.is_cache_enabled()
     with self._lock:

Review Comment:
   If the lock is held because we are waiting on a load, will peek block? Is that what we want?



##########
sdks/python/apache_beam/runners/worker/sdk_worker.py:
##########
@@ -1175,29 +1162,21 @@ def extend(
     # type: (...) -> _Future
     cache_token = self._get_cache_token(state_key)
     if cache_token:
-      # Update the cache
+      # Update the cache if the value is already present and
+      # can be updated.
       cache_key = self._convert_to_cache_key(state_key)
-      cached_value = self._state_cache.get(cache_key, cache_token)
-      # Keep in mind that the state for this key can be evicted
-      # while executing this function. Either read or write to the cache
-      # but never do both here!
-      if cached_value is None:
-        # We have never cached this key before, first retrieve state
-        cached_value = self.blocking_get(state_key, coder)
-      # Just extend the already cached value
+      cached_value = self._state_cache.peek((cache_key, cache_token))

Review Comment:
   So is the difference between peek and get that get has some sort of effect on the cache and peek has none?
   
   I guess I'm not sure why we are peeking here. Shouldn't the mere fact that we peek mean we are interested in updating the state of the data we looked at in the cache?



##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -181,12 +263,17 @@ def describe_stats(self):
         hit_ratio = 100.0 * self._hit_count / request_count
       else:
         hit_ratio = 100.0
-      return 'used/max %d/%d MB, hit %.2f%%, lookups %d, evictions %d' % (
-          self._current_weight >> 20,
-          self._max_weight >> 20,
-          hit_ratio,
-          request_count,
-          self._evict_count)
+      return (
+          'used/max %d/%d MB, hit %.2f%%, lookups %d, '
+          'avg load time %.0f ns, loads %d, evictions %d') % (

Review Comment:
   At first, when I read "loads", I thought it referred to load time in seconds. Would a name like num_loads or total_loads be clearer?



##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -111,28 +140,86 @@ def __init__(self, max_weight):
     self._max_weight = max_weight
     self._current_weight = 0
     self._cache = collections.OrderedDict(
-    )  # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue]
+    )  # type: collections.OrderedDict[Any, WeightedValue]
     self._hit_count = 0
     self._miss_count = 0
     self._evict_count = 0
+    self._load_time_ns = 0
+    self._load_count = 0
     self._lock = threading.RLock()
 
-  def get(self, state_key, cache_token):
-    # type: (bytes, Optional[bytes]) -> Any
-    assert cache_token and self.is_cache_enabled()
-    key = (state_key, cache_token)
+  def peek(self, key):
+    # type: (Any) -> Any
+    assert self.is_cache_enabled()
     with self._lock:
       value = self._cache.get(key, None)
-      if value is None:
+      if value is None or isinstance(value, _LoadingValue):
         self._miss_count += 1
         return None
+
       self._cache.move_to_end(key)
       self._hit_count += 1
-      return value.value()
+    return value.value()
+
+  def get(self, key, loading_fn):
+    # type: (Any, Callable[[Any], Any]) -> Any
+    assert self.is_cache_enabled() and callable(loading_fn)
+
+    self._lock.acquire()
+    value = self._cache.get(key, None)
+    if value is None:
+      self._miss_count += 1
+      loading_value = _LoadingValue()
+      self._cache[key] = loading_value
 
-  def put(self, state_key, cache_token, value):
-    # type: (bytes, Optional[bytes], Any) -> None
-    assert cache_token and self.is_cache_enabled()
+      # Ensure that we unlock the lock while loading to allow for parallel gets
+      self._lock.release()
+
+      start_time_ns = time.time_ns()
+      loading_value.load(key, loading_fn)
+      elapsed_time_ns = time.time_ns() - start_time_ns
+
+      # Replace the value in the cache with a weighted value now that the
+      # loading has completed successfully.
+      value = loading_value.value()
+      weight = objsize.get_deep_size(
+          value, get_referents_func=get_referents_for_cache)
+      if weight <= 0:
+        _LOGGER.warning(
+            'Expected object size to be >= 0 for %s but received %d.',
+            value,
+            weight)
+        weight = 8
+      value = WeightedValue(value, weight)
+      with self._lock:
+        self._load_count += 1
+        self._load_time_ns += elapsed_time_ns
+        # Don't replace values that have already been replaced with a different
+        # value by a put/invalidate that occurred concurrently with the load.
+        # The put/invalidate will have been responsible for updating the
+        # cache weight appropriately already.
+        old_value = self._cache.get(key, None)
+        if old_value is not loading_value:
+          return value.value()
+
+        self._current_weight -= loading_value.weight()
+        self._cache[key] = value
+        self._current_weight += value.weight()
+        while self._current_weight > self._max_weight:
+          (_, weighted_value) = self._cache.popitem(last=False)
+          self._current_weight -= weighted_value.weight()
+          self._evict_count += 1
+
+    else:
+      self._cache.move_to_end(key)
+      self._hit_count += 1
+      self._lock.release()
+
+    return value.value()

Review Comment:
   This assumes the value is either in the cache or the load function will return something we can use. Is there a case where the load function could fail and we should do something different?
   
   If I'm correct, the value method raises if there was an exception. Is there a case where the load function could be a no-op, and then we would have strange values here?



##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -87,18 +88,46 @@ def get_referents_for_cache(*objs):
   return rval
 
 
+class _LoadingValue(WeightedValue):
+  """Allows concurrent users of the cache to wait for a value to be loaded."""
+  def __init__(self):
+    # type: () -> None
+    super().__init__(None, 1)
+    self._wait_event = threading.Event()
+
+  def load(self, key, loading_fn):
+    # type: (Any, Callable[[Any], Any]) -> None
+    try:
+      self._value = loading_fn(key)
+    except Exception as err:

Review Comment:
   what kind of exceptions should we expect? Can we be more specific about which exceptions we catch?
   
   Oh I see we are raising this exception below in the value method.



##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -111,28 +140,86 @@ def __init__(self, max_weight):
     self._max_weight = max_weight
     self._current_weight = 0
     self._cache = collections.OrderedDict(
-    )  # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue]
+    )  # type: collections.OrderedDict[Any, WeightedValue]
     self._hit_count = 0
     self._miss_count = 0
     self._evict_count = 0
+    self._load_time_ns = 0
+    self._load_count = 0
     self._lock = threading.RLock()
 
-  def get(self, state_key, cache_token):
-    # type: (bytes, Optional[bytes]) -> Any
-    assert cache_token and self.is_cache_enabled()
-    key = (state_key, cache_token)
+  def peek(self, key):
+    # type: (Any) -> Any
+    assert self.is_cache_enabled()
     with self._lock:
       value = self._cache.get(key, None)
-      if value is None:
+      if value is None or isinstance(value, _LoadingValue):
         self._miss_count += 1
         return None
+
       self._cache.move_to_end(key)
       self._hit_count += 1
-      return value.value()
+    return value.value()
+
+  def get(self, key, loading_fn):
+    # type: (Any, Callable[[Any], Any]) -> Any
+    assert self.is_cache_enabled() and callable(loading_fn)
+
+    self._lock.acquire()
+    value = self._cache.get(key, None)
+    if value is None:
+      self._miss_count += 1
+      loading_value = _LoadingValue()

Review Comment:
   I'm just trying to go through the possible calls to the method in my head and trying to determine if the lock could be released and the value at cache[key] could be a _LoadingValue still.
   
   Do we need any logic to handle that? I think if we have strong tests to prevent this kind of race condition, we should be fine.



##########
sdks/python/apache_beam/runners/worker/statecache_test.py:
##########
@@ -19,144 +19,239 @@
 # pytype: skip-file
 
 import logging
+import re
+import threading
+import time
 import unittest
 
+from hamcrest import assert_that
+from hamcrest import contains_string
+
 from apache_beam.runners.worker.statecache import CacheAware
 from apache_beam.runners.worker.statecache import StateCache
 from apache_beam.runners.worker.statecache import WeightedValue
 
 
 class StateCacheTest(unittest.TestCase):
-  def test_empty_cache_get(self):
+  def test_empty_cache_peek(self):
     cache = StateCache(5 << 20)
-    self.assertEqual(cache.get("key", 'cache_token'), None)
-    with self.assertRaises(Exception):
-      # Invalid cache token provided
-      self.assertEqual(cache.get("key", None), None)
+    self.assertEqual(cache.peek("key"), None)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 0/5 MB, hit 0.00%, lookups 1, evictions 0')
+        (
+            'used/max 0/5 MB, hit 0.00%, lookups 1, '
+            'avg load time 0 ns, loads 0, evictions 0'))
 
-  def test_put_get(self):
+  def test_put_peek(self):
     cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
+    cache.put("key", WeightedValue("value", 1 << 20))
     self.assertEqual(cache.size(), 1)
-    self.assertEqual(cache.get("key", "cache_token"), "value")
-    self.assertEqual(cache.get("key", "cache_token2"), None)
-    with self.assertRaises(Exception):
-      self.assertEqual(cache.get("key", None), None)
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 1/5 MB, hit 50.00%, lookups 2, evictions 0')
-
-  def test_clear(self):
-    cache = StateCache(5 << 20)
-    cache.clear("new-key", "cache_token")
-    cache.put("key", "cache_token", WeightedValue(["value"], 1 << 20))
-    self.assertEqual(cache.size(), 2)
-    self.assertEqual(cache.get("new-key", "new_token"), None)
-    self.assertEqual(cache.get("key", "cache_token"), ['value'])
-    # test clear without existing key/token
-    cache.clear("non-existing", "token")
-    self.assertEqual(cache.size(), 3)
-    self.assertEqual(cache.get("non-existing", "token"), [])
+    self.assertEqual(cache.peek("key"), "value")
+    self.assertEqual(cache.peek("key2"), None)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 1/5 MB, hit 66.67%, lookups 3, evictions 0')
+        (
+            'used/max 1/5 MB, hit 50.00%, lookups 2, '
+            'avg load time 0 ns, loads 0, evictions 0'))
 
   def test_default_sized_put(self):
     cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", bytearray(1 << 20))
-    cache.put("key2", "cache_token", bytearray(1 << 20))
-    cache.put("key3", "cache_token", bytearray(1 << 20))
-    self.assertEqual(cache.get("key3", "cache_token"), bytearray(1 << 20))
-    cache.put("key4", "cache_token", bytearray(1 << 20))
-    cache.put("key5", "cache_token", bytearray(1 << 20))
+    cache.put("key", bytearray(1 << 20))
+    cache.put("key2", bytearray(1 << 20))
+    cache.put("key3", bytearray(1 << 20))
+    self.assertEqual(cache.peek("key3"), bytearray(1 << 20))
+    cache.put("key4", bytearray(1 << 20))
+    cache.put("key5", bytearray(1 << 20))
     # note that each byte array instance takes slightly over 1 MB which is why
     # these 5 byte arrays can't all be stored in the cache causing a single
     # eviction
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 4/5 MB, hit 100.00%, lookups 1, evictions 1')
+        (
+            'used/max 4/5 MB, hit 100.00%, lookups 1, '
+            'avg load time 0 ns, loads 0, evictions 1'))
 
   def test_max_size(self):
     cache = StateCache(2 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
-    cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
+    cache.put("key", WeightedValue("value", 1 << 20))
+    cache.put("key2", WeightedValue("value2", 1 << 20))
     self.assertEqual(cache.size(), 2)
-    cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
+    cache.put("key3", WeightedValue("value3", 1 << 20))
     self.assertEqual(cache.size(), 2)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 2/2 MB, hit 100.00%, lookups 0, evictions 1')
+        (
+            'used/max 2/2 MB, hit 100.00%, lookups 0, '
+            'avg load time 0 ns, loads 0, evictions 1'))
 
   def test_invalidate_all(self):
     cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
-    cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
+    cache.put("key", WeightedValue("value", 1 << 20))
+    cache.put("key2", WeightedValue("value2", 1 << 20))
     self.assertEqual(cache.size(), 2)
     cache.invalidate_all()
     self.assertEqual(cache.size(), 0)
-    self.assertEqual(cache.get("key", "cache_token"), None)
-    self.assertEqual(cache.get("key2", "cache_token"), None)
+    self.assertEqual(cache.peek("key"), None)
+    self.assertEqual(cache.peek("key2"), None)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 0/5 MB, hit 0.00%, lookups 2, evictions 0')
+        (
+            'used/max 0/5 MB, hit 0.00%, lookups 2, '
+            'avg load time 0 ns, loads 0, evictions 0'))
 
   def test_lru(self):
     cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
-    cache.put("key2", "cache_token2", WeightedValue("value2", 1 << 20))
-    cache.put("key3", "cache_token", WeightedValue("value0", 1 << 20))
-    cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
-    cache.put("key4", "cache_token4", WeightedValue("value4", 1 << 20))
-    cache.put("key5", "cache_token", WeightedValue("value0", 1 << 20))
-    cache.put("key5", "cache_token", WeightedValue(["value5"], 1 << 20))
+    cache.put("key", WeightedValue("value", 1 << 20))
+    cache.put("key2", WeightedValue("value2", 1 << 20))
+    cache.put("key3", WeightedValue("value0", 1 << 20))
+    cache.put("key3", WeightedValue("value3", 1 << 20))
+    cache.put("key4", WeightedValue("value4", 1 << 20))
+    cache.put("key5", WeightedValue("value0", 1 << 20))
+    cache.put("key5", WeightedValue(["value5"], 1 << 20))
     self.assertEqual(cache.size(), 5)
-    self.assertEqual(cache.get("key", "cache_token"), "value")
-    self.assertEqual(cache.get("key2", "cache_token2"), "value2")
-    self.assertEqual(cache.get("key3", "cache_token"), "value3")
-    self.assertEqual(cache.get("key4", "cache_token4"), "value4")
-    self.assertEqual(cache.get("key5", "cache_token"), ["value5"])
+    self.assertEqual(cache.peek("key"), "value")
+    self.assertEqual(cache.peek("key2"), "value2")
+    self.assertEqual(cache.peek("key3"), "value3")
+    self.assertEqual(cache.peek("key4"), "value4")
+    self.assertEqual(cache.peek("key5"), ["value5"])
     # insert another key to trigger cache eviction
-    cache.put("key6", "cache_token2", WeightedValue("value6", 1 << 20))
+    cache.put("key6", WeightedValue("value6", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key")
-    self.assertEqual(cache.get("key", "cache_token"), None)
+    self.assertEqual(cache.peek("key"), None)
     # trigger a read on "key2"
-    cache.get("key2", "cache_token2")
+    cache.peek("key2")
     # insert another key to trigger cache eviction
-    cache.put("key7", "cache_token", WeightedValue("value7", 1 << 20))
+    cache.put("key7", WeightedValue("value7", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key3")
-    self.assertEqual(cache.get("key3", "cache_token"), None)
-    # trigger a put on "key2"
-    cache.put("key2", "cache_token", WeightedValue("put", 1 << 20))
+    self.assertEqual(cache.peek("key3"), None)
+    # insert another key to trigger cache eviction
+    cache.put("key8", WeightedValue("put", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # insert another key to trigger cache eviction
-    cache.put("key8", "cache_token", WeightedValue("value8", 1 << 20))
+    cache.put("key9", WeightedValue("value8", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key4")
-    self.assertEqual(cache.get("key4", "cache_token"), None)
+    self.assertEqual(cache.peek("key4"), None)
     # make "key5" used by writing to it
-    cache.put("key5", "cache_token", WeightedValue("val", 1 << 20))
+    cache.put("key5", WeightedValue("val", 1 << 20))
     # least recently used key should be gone ("key6")
-    self.assertEqual(cache.get("key6", "cache_token"), None)
+    self.assertEqual(cache.peek("key6"), None)
     self.assertEqual(
         cache.describe_stats(),
-        'used/max 5/5 MB, hit 60.00%, lookups 10, evictions 5')
+        (
+            'used/max 5/5 MB, hit 60.00%, lookups 10, '
+            'avg load time 0 ns, loads 0, evictions 5'))
+
+  def test_get(self):
+    def check_key(key):
+      self.assertEqual(key, "key")
+      time.sleep(0.5)
+      return "value"
+
+    cache = StateCache(5 << 20)
+    self.assertEqual("value", cache.get("key", check_key))
+    self.assertEqual("value", cache.peek("key"))
+    cache.invalidate_all()
+    self.assertEqual("value", cache.get("key", check_key))
+    self.assertEqual("value", cache.peek("key"))
+
+    assert_that(cache.describe_stats(), contains_string(", loads 2,"))
+    load_time_ns = re.search(
+        ", avg load time (.+) ns,", cache.describe_stats()).group(1)
+    # Load time should be larger then the sleep time and less than 2x sleep time
+    self.assertGreater(int(load_time_ns), 0.5 * 1_000_000_000)
+    self.assertLess(int(load_time_ns), 1_000_000_000)
+
+  def test_concurrent_get_waits(self):

Review Comment:
   cool test.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org