Posted to commits@beam.apache.org by lc...@apache.org on 2022/10/10 18:43:55 UTC
[beam] branch master updated: [fixes #23000] Update the Python SDK harness state cache to be a loading cache (#23046)
This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 0e61b026ea7 [fixes #23000] Update the Python SDK harness state cache to be a loading cache (#23046)
0e61b026ea7 is described below
commit 0e61b026ea7accd666fc443f3aeec7f93147a3b6
Author: Luke Cwik <lc...@google.com>
AuthorDate: Mon Oct 10 11:43:46 2022 -0700
[fixes #23000] Update the Python SDK harness state cache to be a loading cache (#23046)
* [#23000] Update the Python SDK harness state cache to be a loading cache
Also record statistics for the cost of performing a load.
* Use get/peek within sdk_worker
* Address PR comments
* Address PR comments
---
.../apache_beam/runners/worker/sdk_worker.py | 47 +---
.../apache_beam/runners/worker/statecache.py | 174 ++++++++++---
.../apache_beam/runners/worker/statecache_test.py | 282 +++++++++++++++------
3 files changed, 352 insertions(+), 151 deletions(-)
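For orientation before the diff: the core API shift is that callers no longer check for a miss and populate the cache themselves. StateCache.get now takes the key plus a loader callable, loads on a miss, and guarantees a single in-flight load per key. A minimal before/after sketch, using a hypothetical load_from_runner stand-in for the runner fetch:

from apache_beam.runners.worker.statecache import StateCache

def load_from_runner(state_key):
    # Hypothetical loader; the real code materializes state from the runner.
    return ['element']

cache = StateCache(5 << 20)
state_key, cache_token = b'state', b'token'

# Before this commit: caller-managed miss handling.
#   value = cache.get(state_key, cache_token)
#   if value is None:
#       value = load_from_runner(state_key)
#       cache.put(state_key, cache_token, value)

# After this commit: the cache loads on miss; concurrent gets for the same
# key block until the single in-flight load completes.
value = cache.get(
    (state_key, cache_token), lambda key: load_from_runner(state_key))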
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 968f213d6b2..93c9c609aa0 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -1149,22 +1149,9 @@ class GlobalCachingStateHandler(CachingStateHandler):
return self._lazy_iterator(state_key, coder)
# Cache lookup
cache_state_key = self._convert_to_cache_key(state_key)
- cached_value = self._state_cache.get(cache_state_key, cache_token)
- if cached_value is None:
- # Cache miss, need to retrieve from the Runner
- # Further size estimation or the use of the continuation token on the
- # runner side could fall back to materializing one item at a time.
- # https://jira.apache.org/jira/browse/BEAM-8297
- materialized = cached_value = (
- self._partially_cached_iterable(state_key, coder))
- if isinstance(materialized, (list, self.ContinuationIterable)):
- self._state_cache.put(cache_state_key, cache_token, materialized)
- else:
- _LOGGER.error(
- "Uncacheable type %s for key %s. Not caching.",
- materialized,
- state_key)
- return cached_value
+ return self._state_cache.get(
+ (cache_state_key, cache_token),
+ lambda key: self._partially_cached_iterable(state_key, coder))
def extend(
self,
@@ -1175,29 +1162,21 @@ class GlobalCachingStateHandler(CachingStateHandler):
# type: (...) -> _Future
cache_token = self._get_cache_token(state_key)
if cache_token:
- # Update the cache
+ # Update the cache if the value is already present and
+ # can be updated.
cache_key = self._convert_to_cache_key(state_key)
- cached_value = self._state_cache.get(cache_key, cache_token)
- # Keep in mind that the state for this key can be evicted
- # while executing this function. Either read or write to the cache
- # but never do both here!
- if cached_value is None:
- # We have never cached this key before, first retrieve state
- cached_value = self.blocking_get(state_key, coder)
- # Just extend the already cached value
+ cached_value = self._state_cache.peek((cache_key, cache_token))
if isinstance(cached_value, list):
+ # The state is fully cached and can be extended
+
# Materialize provided iterable to ensure reproducible iterations,
# here and when writing to the state handler below.
elements = list(elements)
- # The state is fully cached and can be extended
cached_value.extend(elements)
- elif isinstance(cached_value, self.ContinuationIterable):
- # The state is too large to be fully cached (continuation token used),
- # only the first part is cached, the rest if enumerated via the runner.
- pass
- else:
- # When a corrupt value made it into the cache, we have to fail.
- raise Exception("Unexpected cached value: %s" % cached_value)
+ # Re-insert the updated value into the cache so that its new size is
+ # reflected.
+ self._state_cache.put((cache_key, cache_token), cached_value)
+
# Write to state handler
futures = []
out = coder_impl.create_OutputStream()
@@ -1220,7 +1199,7 @@ class GlobalCachingStateHandler(CachingStateHandler):
cache_token = self._get_cache_token(state_key)
if cache_token:
cache_key = self._convert_to_cache_key(state_key)
- self._state_cache.clear(cache_key, cache_token)
+ self._state_cache.put((cache_key, cache_token), [])
return self._underlying.clear(state_key)
def done(self):
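The extend and clear paths above now follow a peek/mutate/re-put pattern: peek never blocks, a fully cached list is extended in place and re-inserted so the cache re-measures its weight, and clear overwrites the entry with an empty bag rather than deleting it. A condensed sketch of that pattern against the new cache API, with hypothetical key values:

from apache_beam.runners.worker.statecache import StateCache

cache = StateCache(5 << 20)
key = (b'cache_key', b'cache_token')  # hypothetical (state key, token) pair

cache.put(key, ['a'])
cached_value = cache.peek(key)        # non-blocking; None on a miss
if isinstance(cached_value, list):    # fully cached, safe to extend
    cached_value.extend(['b', 'c'])
    cache.put(key, cached_value)      # re-insert so the new weight is measured

cache.put(key, [])                    # "clear": overwrite with the empty bag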
diff --git a/sdks/python/apache_beam/runners/worker/statecache.py b/sdks/python/apache_beam/runners/worker/statecache.py
index 839b2b0568a..e3f37fec114 100644
--- a/sdks/python/apache_beam/runners/worker/statecache.py
+++ b/sdks/python/apache_beam/runners/worker/statecache.py
@@ -28,8 +28,8 @@ import time
import types
import weakref
from typing import Any
+from typing import Callable
from typing import List
-from typing import Optional
from typing import Tuple
from typing import Union
@@ -76,6 +76,7 @@ class WeightedValue(object):
class CacheAware(object):
+ """Allows cache users to override what objects are measured."""
def __init__(self):
# type: () -> None
pass
@@ -181,18 +182,47 @@ def get_deep_size(*objs):
filter_func=_filter_func)
+class _LoadingValue(WeightedValue):
+ """Allows concurrent users of the cache to wait for a value to be loaded."""
+ def __init__(self):
+ # type: () -> None
+ super().__init__(None, 1)
+ self._wait_event = threading.Event()
+
+ def load(self, key, loading_fn):
+ # type: (Any, Callable[[Any], Any]) -> None
+ try:
+ self._value = loading_fn(key)
+ except Exception as err:
+ self._error = err
+ finally:
+ self._wait_event.set()
+
+ def value(self):
+ # type: () -> Any
+ self._wait_event.wait()
+ err = getattr(self, "_error", None)
+ if err:
+ raise err
+ return self._value
+
+
class StateCache(object):
- """Cache for Beam state access, scoped by state key and cache_token.
+ """LRU cache for Beam state access, scoped by state key and cache_token.
Assumes a bag state implementation.
- For a given state_key and cache_token, caches a value and allows to
- a) read from the cache (get),
- if the currently stored cache_token matches the provided
- b) write to the cache (put),
- storing the new value alongside with a cache token
- c) empty a cached element (clear),
- if the currently stored cache_token matches the provided
+ For a given key, caches a value and allows one to
+ a) peek at the cache (peek),
+ returns the value for the provided key or None if it doesn't exist.
+ Will never block.
+ b) read from the cache (get),
+ returns the value for the provided key or loads it using the
+ supplied function. Multiple calls for the same key will block
+ until the value is loaded.
+ c) write to the cache (put),
+ stores the provided value, overwriting any previous result
d) invalidate a cached element (invalidate)
+ removes the value from the cache for the provided key
e) invalidate all cached elements (invalidate_all)
The operations on the cache are thread-safe for use by multiple workers.
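A minimal usage sketch of the operations listed in this docstring (hypothetical key and loader values):

from apache_beam.runners.worker.statecache import StateCache

cache = StateCache(5 << 20)                       # weight budget in bytes
value = cache.get('key', lambda key: ['loaded'])  # loads on miss, caches result
assert cache.peek('key') == ['loaded']            # never blocks; None on a miss
cache.put('key', ['replaced'])                    # overwrites any previous value
cache.invalidate('key')                           # drops this single entry
assert cache.peek('key') is None
cache.invalidate_all()                            # empties the whole cache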
@@ -205,28 +235,106 @@ class StateCache(object):
self._max_weight = max_weight
self._current_weight = 0
self._cache = collections.OrderedDict(
- ) # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue]
+ ) # type: collections.OrderedDict[Any, WeightedValue]
self._hit_count = 0
self._miss_count = 0
self._evict_count = 0
+ self._load_time_ns = 0
+ self._load_count = 0
self._lock = threading.RLock()
- def get(self, state_key, cache_token):
- # type: (bytes, Optional[bytes]) -> Any
- assert cache_token and self.is_cache_enabled()
- key = (state_key, cache_token)
+ def peek(self, key):
+ # type: (Any) -> Any
+ assert self.is_cache_enabled()
with self._lock:
value = self._cache.get(key, None)
- if value is None:
+ if value is None or _safe_isinstance(value, _LoadingValue):
self._miss_count += 1
return None
+
self._cache.move_to_end(key)
self._hit_count += 1
+ return value.value()
+
+ def get(self, key, loading_fn):
+ # type: (Any, Callable[[Any], Any]) -> Any
+ assert self.is_cache_enabled() and callable(loading_fn)
+
+ self._lock.acquire()
+ value = self._cache.get(key, None)
+
+ # Return the already cached value
+ if value is not None:
+ self._cache.move_to_end(key)
+ self._hit_count += 1
+ self._lock.release()
return value.value()
- def put(self, state_key, cache_token, value):
- # type: (bytes, Optional[bytes], Any) -> None
- assert cache_token and self.is_cache_enabled()
+ # Load the value since it isn't in the cache.
+ self._miss_count += 1
+ loading_value = _LoadingValue()
+ self._cache[key] = loading_value
+
+ # Release the lock while loading so that other gets can proceed in parallel
+ self._lock.release()
+
+ start_time_ns = time.time_ns()
+ loading_value.load(key, loading_fn)
+ elapsed_time_ns = time.time_ns() - start_time_ns
+
+ try:
+ value = loading_value.value()
+ except Exception as err:
+ # If loading failed, remove the placeholder from the cache so that a
+ # subsequent lookup can retry the load.
+ with self._lock:
+ self._load_count += 1
+ self._load_time_ns += elapsed_time_ns
+ # Don't remove values that have already been replaced with a different
+ # value by a put/invalidate that occurred concurrently with the load.
+ # The put/invalidate will have been responsible for updating the
+ # cache weight appropriately already.
+ old_value = self._cache.get(key, None)
+ if old_value is not loading_value:
+ raise err
+ self._current_weight -= loading_value.weight()
+ del self._cache[key]
+ raise err
+
+ # Replace the value in the cache with a weighted value now that the
+ # loading has completed successfully.
+ weight = get_deep_size(value)
+ if weight <= 0:
+ _LOGGER.warning(
'Expected object size to be > 0 for %s but received %d.',
+ value,
+ weight)
+ weight = 8
+ value = WeightedValue(value, weight)
+ with self._lock:
+ self._load_count += 1
+ self._load_time_ns += elapsed_time_ns
+ # Don't replace values that have already been replaced with a different
+ # value by a put/invalidate that occurred concurrently with the load.
+ # The put/invalidate will have been responsible for updating the
+ # cache weight appropriately already.
+ old_value = self._cache.get(key, None)
+ if old_value is not loading_value:
+ return value.value()
+
+ self._current_weight -= loading_value.weight()
+ self._cache[key] = value
+ self._current_weight += value.weight()
+ while self._current_weight > self._max_weight:
+ (_, weighted_value) = self._cache.popitem(last=False)
+ self._current_weight -= weighted_value.weight()
+ self._evict_count += 1
+
+ return value.value()
+
+ def put(self, key, value):
+ # type: (Any, Any) -> None
+ assert self.is_cache_enabled()
if not _safe_isinstance(value, WeightedValue):
weight = get_deep_size(value)
if weight <= 0:
@@ -236,27 +344,22 @@ class StateCache(object):
weight)
weight = _DEFAULT_WEIGHT
value = WeightedValue(value, weight)
- key = (state_key, cache_token)
with self._lock:
old_value = self._cache.pop(key, None)
if old_value is not None:
self._current_weight -= old_value.weight()
- self._cache[(state_key, cache_token)] = value
+ self._cache[key] = value
self._current_weight += value.weight()
while self._current_weight > self._max_weight:
(_, weighted_value) = self._cache.popitem(last=False)
self._current_weight -= weighted_value.weight()
self._evict_count += 1
- def clear(self, state_key, cache_token):
- # type: (bytes, Optional[bytes]) -> None
- self.put(state_key, cache_token, [])
-
- def invalidate(self, state_key, cache_token):
- # type: (bytes, Optional[bytes]) -> None
+ def invalidate(self, key):
+ # type: (Any) -> None
assert self.is_cache_enabled()
with self._lock:
- weighted_value = self._cache.pop((state_key, cache_token), None)
+ weighted_value = self._cache.pop(key, None)
if weighted_value is not None:
self._current_weight -= weighted_value.weight()
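The load path above hinges on dropping the lock while the loader runs: a _LoadingValue placeholder stays in the map so that concurrent readers block on an event rather than duplicating the load, and the placeholder is swapped out afterwards only if a concurrent put/invalidate has not already replaced it. A distilled, standalone sketch of the placeholder idea (not the Beam class itself):

import threading

class LoadingPlaceholder(object):
    """Lets concurrent readers wait on a single in-flight load."""
    def __init__(self):
        self._event = threading.Event()
        self._value = None
        self._error = None

    def load(self, key, loading_fn):
        # Runs outside the cache lock; exactly one thread calls this.
        try:
            self._value = loading_fn(key)
        except Exception as err:
            self._error = err
        finally:
            self._event.set()       # wake every waiting reader

    def value(self):
        self._event.wait()          # readers block until the load finishes
        if self._error:
            raise self._error
        return self._value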
@@ -274,12 +377,17 @@ class StateCache(object):
hit_ratio = 100.0 * self._hit_count / request_count
else:
hit_ratio = 100.0
- return 'used/max %d/%d MB, hit %.2f%%, lookups %d, evictions %d' % (
- self._current_weight >> 20,
- self._max_weight >> 20,
- hit_ratio,
- request_count,
- self._evict_count)
+ return (
+ 'used/max %d/%d MB, hit %.2f%%, lookups %d, '
+ 'avg load time %.0f ns, loads %d, evictions %d') % (
+ self._current_weight >> 20,
+ self._max_weight >> 20,
+ hit_ratio,
+ request_count,
+ self._load_time_ns /
+ self._load_count if self._load_count > 0 else 0,
+ self._load_count,
+ self._evict_count)
def is_cache_enabled(self):
# type: () -> bool
diff --git a/sdks/python/apache_beam/runners/worker/statecache_test.py b/sdks/python/apache_beam/runners/worker/statecache_test.py
index 155d2857ce4..6850cb21284 100644
--- a/sdks/python/apache_beam/runners/worker/statecache_test.py
+++ b/sdks/python/apache_beam/runners/worker/statecache_test.py
@@ -19,13 +19,19 @@
# pytype: skip-file
import logging
+import re
import threading
+import time
import unittest
import weakref
+from hamcrest import assert_that
+from hamcrest import contains_string
+
from apache_beam.runners.worker.statecache import CacheAware
from apache_beam.runners.worker.statecache import StateCache
from apache_beam.runners.worker.statecache import WeightedValue
+from apache_beam.runners.worker.statecache import _LoadingValue
class StateCacheTest(unittest.TestCase):
@@ -39,21 +45,22 @@ class StateCacheTest(unittest.TestCase):
cache = StateCache(5 << 20)
wait_event = threading.Event()
o = WeightedValueRef()
- cache.put('deep ref', 'a', o)
+ cache.put('deep ref', o)
# Ensure that the contents of the internal weak ref isn't sized
- self.assertIsNotNone(cache.get('deep ref', 'a'))
+ self.assertIsNotNone(cache.peek('deep ref'))
self.assertEqual(
cache.describe_stats(),
- 'used/max 0/5 MB, hit 100.00%, lookups 1, evictions 0')
+ 'used/max 0/5 MB, hit 100.00%, lookups 1, avg load time 0 ns, loads 0, '
+ 'evictions 0')
cache.invalidate_all()
# Ensure that putting in a weakref doesn't fail regardless of whether
# it is alive or not
o_ref = weakref.ref(o, lambda value: wait_event.set())
- cache.put('not deleted ref', 'a', o_ref)
+ cache.put('not deleted ref', o_ref)
del o
wait_event.wait()
- cache.put('deleted', 'a', o_ref)
+ cache.put('deleted', o_ref)
def test_weakref_proxy(self):
test_value = WeightedValue('test', 10 << 20)
@@ -65,21 +72,22 @@ class StateCacheTest(unittest.TestCase):
cache = StateCache(5 << 20)
wait_event = threading.Event()
o = WeightedValueRef()
- cache.put('deep ref', 'a', o)
+ cache.put('deep ref', o)
# Ensure that the contents of the internal weak ref isn't sized
- self.assertIsNotNone(cache.get('deep ref', 'a'))
+ self.assertIsNotNone(cache.peek('deep ref'))
self.assertEqual(
cache.describe_stats(),
- 'used/max 0/5 MB, hit 100.00%, lookups 1, evictions 0')
+ 'used/max 0/5 MB, hit 100.00%, lookups 1, avg load time 0 ns, loads 0, '
+ 'evictions 0')
cache.invalidate_all()
# Ensure that putting in a weakref doesn't fail regardless of whether
# it is alive or not
o_ref = weakref.proxy(o, lambda value: wait_event.set())
- cache.put('not deleted', 'a', o_ref)
+ cache.put('not deleted', o_ref)
del o
wait_event.wait()
- cache.put('deleted', 'a', o_ref)
+ cache.put('deleted', o_ref)
def test_size_of_fails(self):
class BadSizeOf(object):
@@ -89,143 +97,247 @@ class StateCacheTest(unittest.TestCase):
cache = StateCache(5 << 20)
with self.assertLogs('apache_beam.runners.worker.statecache',
level='WARNING') as context:
- cache.put('key', 'a', BadSizeOf())
+ cache.put('key', BadSizeOf())
self.assertEqual(1, len(context.output))
self.assertTrue('Failed to size' in context.output[0])
# Test that we don't spam the logs
- cache.put('key', 'a', BadSizeOf())
+ cache.put('key', BadSizeOf())
self.assertEqual(1, len(context.output))
- def test_empty_cache_get(self):
+ def test_empty_cache_peek(self):
cache = StateCache(5 << 20)
- self.assertEqual(cache.get("key", 'cache_token'), None)
- with self.assertRaises(Exception):
- # Invalid cache token provided
- self.assertEqual(cache.get("key", None), None)
+ self.assertEqual(cache.peek("key"), None)
self.assertEqual(
cache.describe_stats(),
- 'used/max 0/5 MB, hit 0.00%, lookups 1, evictions 0')
+ (
+ 'used/max 0/5 MB, hit 0.00%, lookups 1, '
+ 'avg load time 0 ns, loads 0, evictions 0'))
- def test_put_get(self):
+ def test_put_peek(self):
cache = StateCache(5 << 20)
- cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
+ cache.put("key", WeightedValue("value", 1 << 20))
self.assertEqual(cache.size(), 1)
- self.assertEqual(cache.get("key", "cache_token"), "value")
- self.assertEqual(cache.get("key", "cache_token2"), None)
- with self.assertRaises(Exception):
- self.assertEqual(cache.get("key", None), None)
- self.assertEqual(
- cache.describe_stats(),
- 'used/max 1/5 MB, hit 50.00%, lookups 2, evictions 0')
-
- def test_clear(self):
- cache = StateCache(5 << 20)
- cache.clear("new-key", "cache_token")
- cache.put("key", "cache_token", WeightedValue(["value"], 1 << 20))
- self.assertEqual(cache.size(), 2)
- self.assertEqual(cache.get("new-key", "new_token"), None)
- self.assertEqual(cache.get("key", "cache_token"), ['value'])
- # test clear without existing key/token
- cache.clear("non-existing", "token")
- self.assertEqual(cache.size(), 3)
- self.assertEqual(cache.get("non-existing", "token"), [])
+ self.assertEqual(cache.peek("key"), "value")
+ self.assertEqual(cache.peek("key2"), None)
self.assertEqual(
cache.describe_stats(),
- 'used/max 1/5 MB, hit 66.67%, lookups 3, evictions 0')
+ (
+ 'used/max 1/5 MB, hit 50.00%, lookups 2, '
+ 'avg load time 0 ns, loads 0, evictions 0'))
def test_default_sized_put(self):
cache = StateCache(5 << 20)
- cache.put("key", "cache_token", bytearray(1 << 20))
- cache.put("key2", "cache_token", bytearray(1 << 20))
- cache.put("key3", "cache_token", bytearray(1 << 20))
- self.assertEqual(cache.get("key3", "cache_token"), bytearray(1 << 20))
- cache.put("key4", "cache_token", bytearray(1 << 20))
- cache.put("key5", "cache_token", bytearray(1 << 20))
+ cache.put("key", bytearray(1 << 20))
+ cache.put("key2", bytearray(1 << 20))
+ cache.put("key3", bytearray(1 << 20))
+ self.assertEqual(cache.peek("key3"), bytearray(1 << 20))
+ cache.put("key4", bytearray(1 << 20))
+ cache.put("key5", bytearray(1 << 20))
# note that each byte array instance takes slightly over 1 MB which is why
# these 5 byte arrays can't all be stored in the cache causing a single
# eviction
self.assertEqual(
cache.describe_stats(),
- 'used/max 4/5 MB, hit 100.00%, lookups 1, evictions 1')
+ (
+ 'used/max 4/5 MB, hit 100.00%, lookups 1, '
+ 'avg load time 0 ns, loads 0, evictions 1'))
def test_max_size(self):
cache = StateCache(2 << 20)
- cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
- cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
+ cache.put("key", WeightedValue("value", 1 << 20))
+ cache.put("key2", WeightedValue("value2", 1 << 20))
self.assertEqual(cache.size(), 2)
- cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
+ cache.put("key3", WeightedValue("value3", 1 << 20))
self.assertEqual(cache.size(), 2)
self.assertEqual(
cache.describe_stats(),
- 'used/max 2/2 MB, hit 100.00%, lookups 0, evictions 1')
+ (
+ 'used/max 2/2 MB, hit 100.00%, lookups 0, '
+ 'avg load time 0 ns, loads 0, evictions 1'))
def test_invalidate_all(self):
cache = StateCache(5 << 20)
- cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
- cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
+ cache.put("key", WeightedValue("value", 1 << 20))
+ cache.put("key2", WeightedValue("value2", 1 << 20))
self.assertEqual(cache.size(), 2)
cache.invalidate_all()
self.assertEqual(cache.size(), 0)
- self.assertEqual(cache.get("key", "cache_token"), None)
- self.assertEqual(cache.get("key2", "cache_token"), None)
+ self.assertEqual(cache.peek("key"), None)
+ self.assertEqual(cache.peek("key2"), None)
self.assertEqual(
cache.describe_stats(),
- 'used/max 0/5 MB, hit 0.00%, lookups 2, evictions 0')
+ (
+ 'used/max 0/5 MB, hit 0.00%, lookups 2, '
+ 'avg load time 0 ns, loads 0, evictions 0'))
def test_lru(self):
cache = StateCache(5 << 20)
- cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
- cache.put("key2", "cache_token2", WeightedValue("value2", 1 << 20))
- cache.put("key3", "cache_token", WeightedValue("value0", 1 << 20))
- cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
- cache.put("key4", "cache_token4", WeightedValue("value4", 1 << 20))
- cache.put("key5", "cache_token", WeightedValue("value0", 1 << 20))
- cache.put("key5", "cache_token", WeightedValue(["value5"], 1 << 20))
+ cache.put("key", WeightedValue("value", 1 << 20))
+ cache.put("key2", WeightedValue("value2", 1 << 20))
+ cache.put("key3", WeightedValue("value0", 1 << 20))
+ cache.put("key3", WeightedValue("value3", 1 << 20))
+ cache.put("key4", WeightedValue("value4", 1 << 20))
+ cache.put("key5", WeightedValue("value0", 1 << 20))
+ cache.put("key5", WeightedValue(["value5"], 1 << 20))
self.assertEqual(cache.size(), 5)
- self.assertEqual(cache.get("key", "cache_token"), "value")
- self.assertEqual(cache.get("key2", "cache_token2"), "value2")
- self.assertEqual(cache.get("key3", "cache_token"), "value3")
- self.assertEqual(cache.get("key4", "cache_token4"), "value4")
- self.assertEqual(cache.get("key5", "cache_token"), ["value5"])
+ self.assertEqual(cache.peek("key"), "value")
+ self.assertEqual(cache.peek("key2"), "value2")
+ self.assertEqual(cache.peek("key3"), "value3")
+ self.assertEqual(cache.peek("key4"), "value4")
+ self.assertEqual(cache.peek("key5"), ["value5"])
# insert another key to trigger cache eviction
- cache.put("key6", "cache_token2", WeightedValue("value6", 1 << 20))
+ cache.put("key6", WeightedValue("value6", 1 << 20))
self.assertEqual(cache.size(), 5)
# least recently used key should be gone ("key")
- self.assertEqual(cache.get("key", "cache_token"), None)
+ self.assertEqual(cache.peek("key"), None)
# trigger a read on "key2"
- cache.get("key2", "cache_token2")
+ cache.peek("key2")
# insert another key to trigger cache eviction
- cache.put("key7", "cache_token", WeightedValue("value7", 1 << 20))
+ cache.put("key7", WeightedValue("value7", 1 << 20))
self.assertEqual(cache.size(), 5)
# least recently used key should be gone ("key3")
- self.assertEqual(cache.get("key3", "cache_token"), None)
- # trigger a put on "key2"
- cache.put("key2", "cache_token", WeightedValue("put", 1 << 20))
+ self.assertEqual(cache.peek("key3"), None)
+ # insert another key to trigger cache eviction
+ cache.put("key8", WeightedValue("put", 1 << 20))
self.assertEqual(cache.size(), 5)
# insert another key to trigger cache eviction
- cache.put("key8", "cache_token", WeightedValue("value8", 1 << 20))
+ cache.put("key9", WeightedValue("value8", 1 << 20))
self.assertEqual(cache.size(), 5)
# least recently used key should be gone ("key4")
- self.assertEqual(cache.get("key4", "cache_token"), None)
+ self.assertEqual(cache.peek("key4"), None)
# make "key5" used by writing to it
- cache.put("key5", "cache_token", WeightedValue("val", 1 << 20))
+ cache.put("key5", WeightedValue("val", 1 << 20))
# least recently used key should be gone ("key6")
- self.assertEqual(cache.get("key6", "cache_token"), None)
+ self.assertEqual(cache.peek("key6"), None)
self.assertEqual(
cache.describe_stats(),
- 'used/max 5/5 MB, hit 60.00%, lookups 10, evictions 5')
+ (
+ 'used/max 5/5 MB, hit 60.00%, lookups 10, '
+ 'avg load time 0 ns, loads 0, evictions 5'))
+
+ def test_get(self):
+ def check_key(key):
+ self.assertEqual(key, "key")
+ time.sleep(0.5)
+ return "value"
+
+ def raise_exception(key):
+ time.sleep(0.5)
+ raise Exception("TestException")
+
+ cache = StateCache(5 << 20)
+ self.assertEqual("value", cache.get("key", check_key))
+ with cache._lock:
+ self.assertFalse(isinstance(cache._cache["key"], _LoadingValue))
+ self.assertEqual("value", cache.peek("key"))
+ cache.invalidate_all()
+
+ with self.assertRaisesRegex(Exception, "TestException"):
+ cache.get("key", raise_exception)
+ # The failed load should leave nothing in the cache, so check_key
+ # is invoked to load the value.
+ self.assertEqual("value", cache.get("key", check_key))
+ with cache._lock:
+ self.assertFalse(isinstance(cache._cache["key"], _LoadingValue))
+ self.assertEqual("value", cache.peek("key"))
+
+ assert_that(cache.describe_stats(), contains_string(", loads 3,"))
+ load_time_ns = re.search(
+ ", avg load time (.+) ns,", cache.describe_stats()).group(1)
+ # Load time should be larger than the sleep time and less than 2x sleep time
+ self.assertGreater(int(load_time_ns), 0.5 * 1_000_000_000)
+ self.assertLess(int(load_time_ns), 1_000_000_000)
+
+ def test_concurrent_get_waits(self):
+ event = threading.Semaphore(0)
+ threads_running = threading.Barrier(3)
+
+ def wait_for_event(key):
+ with cache._lock:
+ self.assertTrue(isinstance(cache._cache["key"], _LoadingValue))
+ event.release()
+ return "value"
+
+ cache = StateCache(5 << 20)
+
+ def load_key(output):
+ threads_running.wait()
+ output["value"] = cache.get("key", wait_for_event)
+ output["time"] = time.time_ns()
+
+ t1_output = {}
+ t1 = threading.Thread(target=load_key, args=(t1_output, ))
+ t1.start()
+
+ t2_output = {}
+ t2 = threading.Thread(target=load_key, args=(t2_output, ))
+ t2.start()
+
+ # Wait for both threads to start
+ threads_running.wait()
+ # Record the time and wait for the load to start
+ current_time_ns = time.time_ns()
+ event.acquire()
+ t1.join()
+ t2.join()
+
+ # Ensure that only one thread did the loading and not both by checking that
+ # the semaphore was only released once
+ self.assertFalse(event.acquire(blocking=False))
+
+ # Ensure that both gets returned after the recorded time, showing that
+ # both had to wait for the load to complete
+ self.assertLessEqual(current_time_ns, t1_output["time"])
+ self.assertLessEqual(current_time_ns, t2_output["time"])
+ self.assertEqual("value", t1_output["value"])
+ self.assertEqual("value", t2_output["value"])
+ self.assertEqual("value", cache.peek("key"))
+
+ def test_concurrent_get_superseded_by_put(self):
+ load_happening = threading.Event()
+ finish_loading = threading.Event()
+
+ def wait_for_event(key):
+ load_happening.set()
+ finish_loading.wait()
+ return "value"
+
+ cache = StateCache(5 << 20)
+
+ def load_key(output):
+ output["value"] = cache.get("key", wait_for_event)
+
+ t1_output = {}
+ t1 = threading.Thread(target=load_key, args=(t1_output, ))
+ t1.start()
+
+ # Wait for the load to start, update the key, and then let the load finish
+ load_happening.wait()
+ cache.put("key", "value2")
+ finish_loading.set()
+ t1.join()
+
+ # Ensure that the original value is loaded and returned and not the
+ # updated value
+ self.assertEqual("value", t1_output["value"])
+ # Ensure that the updated value supersedes the loaded value.
+ self.assertEqual("value2", cache.peek("key"))
def test_is_cached_enabled(self):
cache = StateCache(1 << 20)
self.assertEqual(cache.is_cache_enabled(), True)
self.assertEqual(
cache.describe_stats(),
- 'used/max 0/1 MB, hit 100.00%, lookups 0, evictions 0')
+ (
+ 'used/max 0/1 MB, hit 100.00%, lookups 0, '
+ 'avg load time 0 ns, loads 0, evictions 0'))
cache = StateCache(0)
self.assertEqual(cache.is_cache_enabled(), False)
self.assertEqual(
cache.describe_stats(),
- 'used/max 0/0 MB, hit 100.00%, lookups 0, evictions 0')
+ (
+ 'used/max 0/0 MB, hit 100.00%, lookups 0, '
+ 'avg load time 0 ns, loads 0, evictions 0'))
def test_get_referents_for_cache(self):
class GetReferentsForCache(CacheAware):
@@ -237,10 +349,12 @@ class StateCacheTest(unittest.TestCase):
return [self.measure_me]
cache = StateCache(5 << 20)
- cache.put("key", "cache_token", GetReferentsForCache())
+ cache.put("key", GetReferentsForCache())
self.assertEqual(
cache.describe_stats(),
- 'used/max 1/5 MB, hit 100.00%, lookups 0, evictions 0')
+ (
+ 'used/max 1/5 MB, hit 100.00%, lookups 0, '
+ 'avg load time 0 ns, loads 0, evictions 0'))
if __name__ == '__main__':