You are viewing a plain text version of this content. The canonical link to it was a hyperlink on the original archive page and is not preserved in this plain-text rendering.
Posted to commits@beam.apache.org by lo...@apache.org on 2022/09/09 18:25:33 UTC
[beam] branch release-2.42.0 updated: Revert "[#19857] Migrate to using a memory aware cache within the Python SDK harness (#22924)" (#23107)
This is an automated email from the ASF dual-hosted git repository.
lostluck pushed a commit to branch release-2.42.0
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/release-2.42.0 by this push:
new 4d4fd066203 Revert "[#19857] Migrate to using a memory aware cache within the Python SDK harness (#22924)" (#23107)
4d4fd066203 is described below
commit 4d4fd066203831005d706b9c9a699cc072ca1863
Author: Luke Cwik <lc...@google.com>
AuthorDate: Fri Sep 9 11:25:27 2022 -0700
Revert "[#19857] Migrate to using a memory aware cache within the Python SDK harness (#22924)" (#23107)
This reverts commit 25c6ed74c9846c89a92655c1e8d313ef87d6adb1.
---
.../runners/portability/flink_runner_test.py | 93 +++++++
.../portability/fn_api_runner/worker_handlers.py | 12 +-
.../apache_beam/runners/worker/sdk_worker.py | 24 +-
.../apache_beam/runners/worker/sdk_worker_main.py | 6 +-
.../apache_beam/runners/worker/sdk_worker_test.py | 6 +-
.../apache_beam/runners/worker/statecache.py | 286 +++++++++++++--------
.../apache_beam/runners/worker/statecache_test.py | 224 +++++++++-------
.../apache_beam/runners/worker/worker_status.py | 27 +-
.../container/py37/base_image_requirements.txt | 1 -
.../container/py38/base_image_requirements.txt | 1 -
.../container/py39/base_image_requirements.txt | 1 -
sdks/python/setup.py | 1 -
12 files changed, 432 insertions(+), 250 deletions(-)
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index 27e4ca4973e..48e5df54d2d 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -32,9 +32,12 @@ import pytest
import apache_beam as beam
from apache_beam import Impulse
from apache_beam import Map
+from apache_beam import Pipeline
+from apache_beam.coders import VarIntCoder
from apache_beam.io.external.generate_sequence import GenerateSequence
from apache_beam.io.kafka import ReadFromKafka
from apache_beam.io.kafka import WriteToKafka
+from apache_beam.metrics import Metrics
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import FlinkRunnerOptions
from apache_beam.options.pipeline_options import PortableOptions
@@ -44,6 +47,7 @@ from apache_beam.runners.portability import portable_runner
from apache_beam.runners.portability import portable_runner_test
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
+from apache_beam.transforms import userstate
from apache_beam.transforms.sql import SqlTransform
# Run as
@@ -292,6 +296,95 @@ class FlinkRunnerTest(portable_runner_test.PortableRunnerTest):
def test_metrics(self):
super().test_metrics(check_gauge=False)
+ def test_flink_metrics(self):
+ """Run a simple DoFn that increments a counter and verifies state
+ caching metrics. Verifies that its expected value is written to a
+ temporary file by the FileReporter"""
+
+ counter_name = 'elem_counter'
+ state_spec = userstate.BagStateSpec('state', VarIntCoder())
+
+ class DoFn(beam.DoFn):
+ def __init__(self):
+ self.counter = Metrics.counter(self.__class__, counter_name)
+ _LOGGER.info('counter: %s' % self.counter.metric_name)
+
+ def process(self, kv, state=beam.DoFn.StateParam(state_spec)):
+ # Trigger materialization
+ list(state.read())
+ state.add(1)
+ self.counter.inc()
+
+ options = self.create_options()
+ # Test only supports parallelism of 1
+ options._all_options['parallelism'] = 1
+ # Create multiple bundles to test cache metrics
+ options._all_options['max_bundle_size'] = 10
+ options._all_options['max_bundle_time_millis'] = 95130590130
+ experiments = options.view_as(DebugOptions).experiments or []
+ experiments.append('state_cache_size=123')
+ options.view_as(DebugOptions).experiments = experiments
+ with Pipeline(self.get_runner(), options) as p:
+ # pylint: disable=expression-not-assigned
+ (
+ p
+ | "create" >> beam.Create(list(range(0, 110)))
+ | "mapper" >> beam.Map(lambda x: (x % 10, 'val'))
+ | "stateful" >> beam.ParDo(DoFn()))
+
+ lines_expected = {'counter: 110'}
+ if options.view_as(StandardOptions).streaming:
+ lines_expected.update([
+ # Gauges for the last finished bundle
+ 'stateful.beam_metric:statecache:capacity: 123',
+ 'stateful.beam_metric:statecache:size: 10',
+ 'stateful.beam_metric:statecache:get: 20',
+ 'stateful.beam_metric:statecache:miss: 0',
+ 'stateful.beam_metric:statecache:hit: 20',
+ 'stateful.beam_metric:statecache:put: 0',
+ 'stateful.beam_metric:statecache:evict: 0',
+ # Counters
+ 'stateful.beam_metric:statecache:get_total: 220',
+ 'stateful.beam_metric:statecache:miss_total: 10',
+ 'stateful.beam_metric:statecache:hit_total: 210',
+ 'stateful.beam_metric:statecache:put_total: 10',
+ 'stateful.beam_metric:statecache:evict_total: 0',
+ ])
+ else:
+ # Batch has a different processing model. All values for
+ # a key are processed at once.
+ lines_expected.update([
+ # Gauges
+ 'stateful).beam_metric:statecache:capacity: 123',
+ # For the first key, the cache token will not be set yet.
+ # It's lazily initialized after first access in StateRequestHandlers
+ 'stateful).beam_metric:statecache:size: 10',
+ # We have 11 here because there are 110 / 10 elements per key
+ 'stateful).beam_metric:statecache:get: 12',
+ 'stateful).beam_metric:statecache:miss: 1',
+ 'stateful).beam_metric:statecache:hit: 11',
+ # State is flushed back once per key
+ 'stateful).beam_metric:statecache:put: 1',
+ 'stateful).beam_metric:statecache:evict: 0',
+ # Counters
+ 'stateful).beam_metric:statecache:get_total: 120',
+ 'stateful).beam_metric:statecache:miss_total: 10',
+ 'stateful).beam_metric:statecache:hit_total: 110',
+ 'stateful).beam_metric:statecache:put_total: 10',
+ 'stateful).beam_metric:statecache:evict_total: 0',
+ ])
+ lines_actual = set()
+ with open(self.test_metrics_path, 'r') as f:
+ for line in f:
+ print(line, end='')
+ for metric_str in lines_expected:
+ metric_name = metric_str.split()[0]
+ if metric_str in line:
+ lines_actual.add(metric_str)
+ elif metric_name in line:
+ lines_actual.add(line)
+ self.assertSetEqual(lines_actual, lines_expected)
+
def test_sdf_with_watermark_tracking(self):
raise unittest.SkipTest("BEAM-2939")
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
index abb356d5ff4..5aaadbd4387 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
@@ -81,8 +81,7 @@ if TYPE_CHECKING:
# State caching is enabled in the fn_api_runner for testing, except for one
# test which runs without state caching (FnApiRunnerTestWithDisabledCaching).
# The cache is disabled in production for other runners.
-STATE_CACHE_SIZE_MB = 100
-MB_TO_BYTES = 1 << 20
+STATE_CACHE_SIZE = 100
# Time-based flush is enabled in the fn_api_runner by default.
DATA_BUFFER_TIME_LIMIT_MS = 1000
@@ -361,14 +360,16 @@ class EmbeddedWorkerHandler(WorkerHandler):
self, data_plane.InMemoryDataChannel(), state, provision_info)
self.control_conn = self # type: ignore # need Protocol to describe this
self.data_conn = self.data_plane_handler
- state_cache = StateCache(STATE_CACHE_SIZE_MB * MB_TO_BYTES)
+ state_cache = StateCache(STATE_CACHE_SIZE)
self.bundle_processor_cache = sdk_worker.BundleProcessorCache(
SingletonStateHandlerFactory(
sdk_worker.GlobalCachingStateHandler(state_cache, state)),
data_plane.InMemoryDataChannelFactory(
self.data_plane_handler.inverse()),
worker_manager._process_bundle_descriptors)
- self.worker = sdk_worker.SdkWorker(self.bundle_processor_cache)
+ self.worker = sdk_worker.SdkWorker(
+ self.bundle_processor_cache,
+ state_cache_metrics_fn=state_cache.get_monitoring_infos)
self._uid_counter = 0
def push(self, request):
@@ -652,8 +653,7 @@ class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler):
from apache_beam.transforms.environments import EmbeddedPythonGrpcEnvironment
config = EmbeddedPythonGrpcEnvironment.parse_config(payload.decode('utf-8'))
- self._state_cache_size = (
- config.get('state_cache_size') or STATE_CACHE_SIZE_MB) << 20
+ self._state_cache_size = config.get('state_cache_size') or STATE_CACHE_SIZE
self._data_buffer_time_limit_ms = \
config.get('data_buffer_time_limit_ms') or DATA_BUFFER_TIME_LIMIT_MS
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 968f213d6b2..562c3139739 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -61,7 +61,6 @@ from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker import statesampler
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.runners.worker.data_plane import PeriodicThread
-from apache_beam.runners.worker.statecache import CacheAware
from apache_beam.runners.worker.statecache import StateCache
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
from apache_beam.runners.worker.worker_status import FnApiWorkerStatusHandler
@@ -213,9 +212,7 @@ class SdkHarness(object):
if status_address:
try:
self._status_handler = FnApiWorkerStatusHandler(
- status_address,
- self._bundle_processor_cache,
- self._state_cache,
+ status_address, self._bundle_processor_cache,
enable_heap_dump) # type: Optional[FnApiWorkerStatusHandler]
except Exception:
traceback_string = traceback.format_exc()
@@ -366,7 +363,9 @@ class SdkHarness(object):
def create_worker(self):
# type: () -> SdkWorker
return SdkWorker(
- self._bundle_processor_cache, profiler_factory=self._profiler_factory)
+ self._bundle_processor_cache,
+ state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
+ profiler_factory=self._profiler_factory)
class BundleProcessorCache(object):
@@ -582,10 +581,12 @@ class SdkWorker(object):
def __init__(
self,
bundle_processor_cache, # type: BundleProcessorCache
+ state_cache_metrics_fn=list, # type: Callable[[], Iterable[metrics_pb2.MonitoringInfo]]
profiler_factory=None, # type: Optional[Callable[..., Profile]]
):
# type: (...) -> None
self.bundle_processor_cache = bundle_processor_cache
+ self.state_cache_metrics_fn = state_cache_metrics_fn
self.profiler_factory = profiler_factory
def do_instruction(self, request):
@@ -633,6 +634,7 @@ class SdkWorker(object):
delayed_applications, requests_finalization = (
bundle_processor.process_bundle(instruction_id))
monitoring_infos = bundle_processor.monitoring_infos()
+ monitoring_infos.extend(self.state_cache_metrics_fn())
response = beam_fn_api_pb2.InstructionResponse(
instruction_id=instruction_id,
process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
@@ -876,7 +878,7 @@ class GrpcStateHandlerFactory(StateHandlerFactory):
for _, state_handler in self._state_handler_cache.items():
state_handler.done()
self._state_handler_cache.clear()
- self._state_cache.invalidate_all()
+ self._state_cache.evict_all()
class CachingStateHandler(metaclass=abc.ABCMeta):
@@ -1128,6 +1130,7 @@ class GlobalCachingStateHandler(CachingStateHandler):
# for items cached at the bundle level.
self._context.bundle_cache_token = bundle_id
try:
+ self._state_cache.initialize_metrics()
self._context.user_state_cache_token = user_state_cache_token
with self._underlying.process_instruction_id(bundle_id):
yield
@@ -1287,7 +1290,7 @@ class GlobalCachingStateHandler(CachingStateHandler):
functools.partial(
self._lazy_iterator, state_key, coder, continuation_token))
- class ContinuationIterable(Generic[T], CacheAware):
+ class ContinuationIterable(Generic[T]):
def __init__(self, head, continue_iterator_fn):
# type: (Iterable[T], Callable[[], Iterable[T]]) -> None
self.head = head
@@ -1300,13 +1303,6 @@ class GlobalCachingStateHandler(CachingStateHandler):
for item in self.continue_iterator_fn():
yield item
- def get_referents_for_cache(self):
- # type: () -> List[Any]
- # Only capture the size of the elements and not the
- # continuation iterator since it references objects
- # we don't want to include in the cache measurement.
- return [self.head]
-
@staticmethod
def _convert_to_cache_key(state_key):
# type: (beam_fn_api_pb2.StateKey) -> bytes
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index ec90a30ceca..53cdbad5d71 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -219,8 +219,8 @@ def _get_state_cache_size(experiments):
future releases.
Returns:
- an int indicating the maximum number of megabytes to cache.
- Default is 0 MB
+ an int indicating the maximum number of items to cache.
+ Default is 0 (disabled)
"""
for experiment in experiments:
@@ -228,7 +228,7 @@ def _get_state_cache_size(experiments):
if re.match(r'state_cache_size=', experiment):
return int(
re.match(r'state_cache_size=(?P<state_cache_size>.*)',
- experiment).group('state_cache_size')) << 20
+ experiment).group('state_cache_size'))
return 0
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index 05263aee96c..d7309c149e4 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -294,7 +294,7 @@ class CachingStateHandlerTest(unittest.TestCase):
yield
underlying_state = FakeUnderlyingState()
- state_cache = statecache.StateCache(100 << 20)
+ state_cache = statecache.StateCache(100)
caching_state_hander = GlobalCachingStateHandler(
state_cache, underlying_state)
@@ -430,7 +430,7 @@ class CachingStateHandlerTest(unittest.TestCase):
coder = VarIntCoder()
underlying_state_handler = self.UnderlyingStateHandler()
- state_cache = statecache.StateCache(100 << 20)
+ state_cache = statecache.StateCache(100)
handler = GlobalCachingStateHandler(state_cache, underlying_state_handler)
def get():
@@ -460,7 +460,7 @@ class CachingStateHandlerTest(unittest.TestCase):
def test_continuation_token(self):
underlying_state_handler = self.UnderlyingStateHandler()
- state_cache = statecache.StateCache(100 << 20)
+ state_cache = statecache.StateCache(100)
handler = GlobalCachingStateHandler(state_cache, underlying_state_handler)
coder = VarIntCoder()
diff --git a/sdks/python/apache_beam/runners/worker/statecache.py b/sdks/python/apache_beam/runners/worker/statecache.py
index 733bcbbf235..cebfbc7a096 100644
--- a/sdks/python/apache_beam/runners/worker/statecache.py
+++ b/sdks/python/apache_beam/runners/worker/statecache.py
@@ -20,179 +20,245 @@
# mypy: disallow-untyped-defs
import collections
-import gc
import logging
import threading
+from typing import TYPE_CHECKING
from typing import Any
+from typing import Callable
+from typing import Generic
+from typing import Hashable
from typing import List
from typing import Optional
+from typing import Set
from typing import Tuple
+from typing import TypeVar
-import objsize
+from apache_beam.metrics import monitoring_infos
-_LOGGER = logging.getLogger(__name__)
+if TYPE_CHECKING:
+ from apache_beam.portability.api import metrics_pb2
+_LOGGER = logging.getLogger(__name__)
-class WeightedValue(object):
- """Value type that stores corresponding weight.
+CallableT = TypeVar('CallableT', bound='Callable')
+KT = TypeVar('KT')
+VT = TypeVar('VT')
- :arg value The value to be stored.
- :arg weight The associated weight of the value. If unspecified, the objects
- size will be used.
- """
- def __init__(self, value, weight):
- # type: (Any, int) -> None
- self._value = value
- if weight <= 0:
- raise ValueError(
- 'Expected weight to be > 0 for %s but received %d' % (value, weight))
- self._weight = weight
-
- def weight(self):
- # type: () -> int
- return self._weight
- def value(self):
- # type: () -> Any
- return self._value
+class Metrics(object):
+ """Metrics container for state cache metrics."""
+ # A set of all registered metrics
+ ALL_METRICS = set() # type: Set[Hashable]
+ PREFIX = "beam:metric:statecache:"
-class CacheAware(object):
def __init__(self):
# type: () -> None
- pass
+ self._context = threading.local()
- def get_referents_for_cache(self):
- # type: () -> List[Any]
+ def initialize(self):
+ # type: () -> None
- """Returns the list of objects accounted during cache measurement."""
- raise NotImplementedError()
+ """Needs to be called once per thread to initialize the local metrics cache.
+ """
+ if hasattr(self._context, 'metrics'):
+ return # Already initialized
+ self._context.metrics = collections.defaultdict(int)
+ def count(self, name):
+ # type: (str) -> None
+ self._context.metrics[name] += 1
-def get_referents_for_cache(*objs):
- # type: (List[Any]) -> List[Any]
+ def hit_miss(self, total_name, hit_miss_name):
+ # type: (str, str) -> None
+ self._context.metrics[total_name] += 1
+ self._context.metrics[hit_miss_name] += 1
- """Returns the list of objects accounted during cache measurement.
+ def get_monitoring_infos(self, cache_size, cache_capacity):
+ # type: (int, int) -> List[metrics_pb2.MonitoringInfo]
- Users can inherit CacheAware to override which referrents should be
- used when measuring the deep size of the object. The default is to
- use gc.get_referents(*objs).
- """
- rval = []
- for obj in objs:
- if isinstance(obj, CacheAware):
- rval.extend(obj.get_referents_for_cache())
- else:
- rval.extend(gc.get_referents(obj))
- return rval
+ """Returns the metrics scoped to the current bundle."""
+ metrics = self._context.metrics
+ if len(metrics) == 0:
+ # No metrics collected, do not report
+ return []
+ # Add all missing metrics which were not reported
+ for key in Metrics.ALL_METRICS:
+ if key not in metrics:
+ metrics[key] = 0
+ # Gauges which reflect the state since last queried
+ gauges = [
+ monitoring_infos.int64_gauge(self.PREFIX + name, val) for name,
+ val in metrics.items()
+ ]
+ gauges.append(
+ monitoring_infos.int64_gauge(self.PREFIX + 'size', cache_size))
+ gauges.append(
+ monitoring_infos.int64_gauge(self.PREFIX + 'capacity', cache_capacity))
+ # Counters for the summary across all metrics
+ counters = [
+ monitoring_infos.int64_counter(self.PREFIX + name + '_total', val)
+ for name,
+ val in metrics.items()
+ ]
+ # Reinitialize metrics for this thread/bundle
+ metrics.clear()
+ return gauges + counters
+
+ @staticmethod
+ def counter_hit_miss(total_name, hit_name, miss_name):
+ # type: (str, str, str) -> Callable[[CallableT], CallableT]
+
+ """Decorator for counting function calls and whether
+ the return value equals None (=miss) or not (=hit)."""
+ Metrics.ALL_METRICS.update([total_name, hit_name, miss_name])
+
+ def decorator(function):
+ # type: (CallableT) -> CallableT
+ def reporter(self, *args, **kwargs):
+ # type: (StateCache, Any, Any) -> Any
+ value = function(self, *args, **kwargs)
+ if value is None:
+ self._metrics.hit_miss(total_name, miss_name)
+ else:
+ self._metrics.hit_miss(total_name, hit_name)
+ return value
+
+ return reporter # type: ignore[return-value]
+
+ return decorator
+
+ @staticmethod
+ def counter(metric_name):
+ # type: (str) -> Callable[[CallableT], CallableT]
+
+ """Decorator for counting function calls."""
+ Metrics.ALL_METRICS.add(metric_name)
+
+ def decorator(function):
+ # type: (CallableT) -> CallableT
+ def reporter(self, *args, **kwargs):
+ # type: (StateCache, Any, Any) -> Any
+ self._metrics.count(metric_name)
+ return function(self, *args, **kwargs)
+
+ return reporter # type: ignore[return-value]
+
+ return decorator
class StateCache(object):
- """Cache for Beam state access, scoped by state key and cache_token.
- Assumes a bag state implementation.
+ """ Cache for Beam state access, scoped by state key and cache_token.
+ Assumes a bag state implementation.
- For a given state_key and cache_token, caches a value and allows to
+ For a given state_key, caches a (cache_token, value) tuple and allows to
a) read from the cache (get),
if the currently stored cache_token matches the provided
- b) write to the cache (put),
+ a) write to the cache (put),
storing the new value alongside with a cache token
+ c) append to the currently cache item (extend),
+ if the currently stored cache_token matches the provided
c) empty a cached element (clear),
if the currently stored cache_token matches the provided
- d) invalidate a cached element (invalidate)
- e) invalidate all cached elements (invalidate_all)
+ d) evict a cached element (evict)
The operations on the cache are thread-safe for use by multiple workers.
- :arg max_weight The maximum weight of entries to store in the cache in bytes.
+ :arg max_entries The maximum number of entries to store in the cache.
+ TODO Memory-based caching: https://github.com/apache/beam/issues/19857
"""
- def __init__(self, max_weight):
+ def __init__(self, max_entries):
# type: (int) -> None
- _LOGGER.info('Creating state cache with size %s', max_weight)
- self._max_weight = max_weight
- self._current_weight = 0
- self._cache = collections.OrderedDict(
- ) # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue]
- self._hit_count = 0
- self._miss_count = 0
- self._evict_count = 0
+ _LOGGER.info('Creating state cache with size %s', max_entries)
+ self._missing = None
+ self._cache = self.LRUCache[Tuple[bytes, Optional[bytes]],
+ Any](max_entries, self._missing)
self._lock = threading.RLock()
+ self._metrics = Metrics()
+ @Metrics.counter_hit_miss("get", "hit", "miss")
def get(self, state_key, cache_token):
# type: (bytes, Optional[bytes]) -> Any
assert cache_token and self.is_cache_enabled()
- key = (state_key, cache_token)
with self._lock:
- value = self._cache.get(key, None)
- if value is None:
- self._miss_count += 1
- return None
- self._cache.move_to_end(key)
- self._hit_count += 1
- return value.value()
+ return self._cache.get((state_key, cache_token))
+ @Metrics.counter("put")
def put(self, state_key, cache_token, value):
# type: (bytes, Optional[bytes], Any) -> None
assert cache_token and self.is_cache_enabled()
- if not isinstance(value, WeightedValue):
- weight = objsize.get_deep_size(
- value, get_referents_func=get_referents_for_cache)
- if weight <= 0:
- _LOGGER.warning(
- 'Expected object size to be >= 0 for %s but received %d.',
- value,
- weight)
- weight = 8
- value = WeightedValue(value, weight)
- key = (state_key, cache_token)
with self._lock:
- old_value = self._cache.pop(key, None)
- if old_value is not None:
- self._current_weight -= old_value.weight()
- self._cache[(state_key, cache_token)] = value
- self._current_weight += value.weight()
- while self._current_weight > self._max_weight:
- (_, weighted_value) = self._cache.popitem(last=False)
- self._current_weight -= weighted_value.weight()
- self._evict_count += 1
+ return self._cache.put((state_key, cache_token), value)
+ @Metrics.counter("clear")
def clear(self, state_key, cache_token):
# type: (bytes, Optional[bytes]) -> None
- self.put(state_key, cache_token, [])
+ assert cache_token and self.is_cache_enabled()
+ with self._lock:
+ self._cache.put((state_key, cache_token), [])
- def invalidate(self, state_key, cache_token):
+ @Metrics.counter("evict")
+ def evict(self, state_key, cache_token):
# type: (bytes, Optional[bytes]) -> None
assert self.is_cache_enabled()
with self._lock:
- weighted_value = self._cache.pop((state_key, cache_token), None)
- if weighted_value is not None:
- self._current_weight -= weighted_value.weight()
+ self._cache.evict((state_key, cache_token))
- def invalidate_all(self):
+ def evict_all(self):
# type: () -> None
with self._lock:
- self._cache.clear()
- self._current_weight = 0
+ self._cache.evict_all()
- def describe_stats(self):
- # type: () -> str
- with self._lock:
- request_count = self._hit_count + self._miss_count
- if request_count > 0:
- hit_ratio = 100.0 * self._hit_count / request_count
- else:
- hit_ratio = 100.0
- return 'used/max %d/%d MB, hit %.2f%%, lookups %d, evictions %d' % (
- self._current_weight >> 20,
- self._max_weight >> 20,
- hit_ratio,
- request_count,
- self._evict_count)
+ def initialize_metrics(self):
+ # type: () -> None
+ self._metrics.initialize()
def is_cache_enabled(self):
# type: () -> bool
- return self._max_weight > 0
+ return self._cache._max_entries > 0
def size(self):
# type: () -> int
+ return len(self._cache)
+
+ def get_monitoring_infos(self):
+ # type: () -> List[metrics_pb2.MonitoringInfo]
+
+ """Retrieves the monitoring infos and resets the counters."""
with self._lock:
+ size = len(self._cache)
+ capacity = self._cache._max_entries
+ return self._metrics.get_monitoring_infos(size, capacity)
+
+ class LRUCache(Generic[KT, VT]):
+ def __init__(self, max_entries, default_entry):
+ # type: (int, VT) -> None
+ self._max_entries = max_entries
+ self._default_entry = default_entry
+ self._cache = collections.OrderedDict(
+ ) # type: collections.OrderedDict[KT, VT]
+
+ def get(self, key):
+ # type: (KT) -> VT
+ value = self._cache.pop(key, self._default_entry)
+ if value != self._default_entry:
+ self._cache[key] = value
+ return value
+
+ def put(self, key, value):
+ # type: (KT, VT) -> None
+ self._cache[key] = value
+ while len(self._cache) > self._max_entries:
+ self._cache.popitem(last=False)
+
+ def evict(self, key):
+ # type: (KT) -> None
+ self._cache.pop(key, self._default_entry)
+
+ def evict_all(self):
+ # type: () -> None
+ self._cache.clear()
+
+ def __len__(self):
+ # type: () -> int
return len(self._cache)
diff --git a/sdks/python/apache_beam/runners/worker/statecache_test.py b/sdks/python/apache_beam/runners/worker/statecache_test.py
index 9bd952721ea..a1a175ed347 100644
--- a/sdks/python/apache_beam/runners/worker/statecache_test.py
+++ b/sdks/python/apache_beam/runners/worker/statecache_test.py
@@ -21,38 +21,55 @@
import logging
import unittest
-from apache_beam.runners.worker.statecache import CacheAware
+from apache_beam.metrics import monitoring_infos
from apache_beam.runners.worker.statecache import StateCache
-from apache_beam.runners.worker.statecache import WeightedValue
class StateCacheTest(unittest.TestCase):
def test_empty_cache_get(self):
- cache = StateCache(5 << 20)
+ cache = self.get_cache(5)
self.assertEqual(cache.get("key", 'cache_token'), None)
with self.assertRaises(Exception):
# Invalid cache token provided
self.assertEqual(cache.get("key", None), None)
- self.assertEqual(
- cache.describe_stats(),
- 'used/max 0/5 MB, hit 0.00%, lookups 1, evictions 0')
+ self.verify_metrics(
+ cache,
+ {
+ 'get': 1,
+ 'put': 0,
+ 'miss': 1,
+ 'hit': 0,
+ 'clear': 0,
+ 'evict': 0,
+ 'size': 0,
+ 'capacity': 5
+ })
def test_put_get(self):
- cache = StateCache(5 << 20)
- cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
+ cache = self.get_cache(5)
+ cache.put("key", "cache_token", "value")
self.assertEqual(cache.size(), 1)
self.assertEqual(cache.get("key", "cache_token"), "value")
self.assertEqual(cache.get("key", "cache_token2"), None)
with self.assertRaises(Exception):
self.assertEqual(cache.get("key", None), None)
- self.assertEqual(
- cache.describe_stats(),
- 'used/max 1/5 MB, hit 50.00%, lookups 2, evictions 0')
+ self.verify_metrics(
+ cache,
+ {
+ 'get': 2,
+ 'put': 1,
+ 'miss': 1,
+ 'hit': 1,
+ 'clear': 0,
+ 'evict': 0,
+ 'size': 1,
+ 'capacity': 5
+ })
def test_clear(self):
- cache = StateCache(5 << 20)
+ cache = self.get_cache(5)
cache.clear("new-key", "cache_token")
- cache.put("key", "cache_token", WeightedValue(["value"], 1 << 20))
+ cache.put("key", "cache_token", ["value"])
self.assertEqual(cache.size(), 2)
self.assertEqual(cache.get("new-key", "new_token"), None)
self.assertEqual(cache.get("key", "cache_token"), ['value'])
@@ -60,58 +77,72 @@ class StateCacheTest(unittest.TestCase):
cache.clear("non-existing", "token")
self.assertEqual(cache.size(), 3)
self.assertEqual(cache.get("non-existing", "token"), [])
- self.assertEqual(
- cache.describe_stats(),
- 'used/max 1/5 MB, hit 66.67%, lookups 3, evictions 0')
-
- def test_default_sized_put(self):
- cache = StateCache(5 << 20)
- cache.put("key", "cache_token", bytearray(1 << 20))
- cache.put("key2", "cache_token", bytearray(1 << 20))
- cache.put("key3", "cache_token", bytearray(1 << 20))
- self.assertEqual(cache.get("key3", "cache_token"), bytearray(1 << 20))
- cache.put("key4", "cache_token", bytearray(1 << 20))
- cache.put("key5", "cache_token", bytearray(1 << 20))
- # note that each byte array instance takes slightly over 1 MB which is why
- # these 5 byte arrays can't all be stored in the cache causing a single
- # eviction
- self.assertEqual(
- cache.describe_stats(),
- 'used/max 4/5 MB, hit 100.00%, lookups 1, evictions 1')
+ self.verify_metrics(
+ cache,
+ {
+ 'get': 3,
+ 'put': 1,
+ 'miss': 1,
+ 'hit': 2,
+ 'clear': 2,
+ 'evict': 0,
+ 'size': 3,
+ 'capacity': 5
+ })
def test_max_size(self):
- cache = StateCache(2 << 20)
- cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
- cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
+ cache = self.get_cache(2)
+ cache.put("key", "cache_token", "value")
+ cache.put("key2", "cache_token", "value")
self.assertEqual(cache.size(), 2)
- cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
+ cache.put("key2", "cache_token", "value")
self.assertEqual(cache.size(), 2)
- self.assertEqual(
- cache.describe_stats(),
- 'used/max 2/2 MB, hit 100.00%, lookups 0, evictions 1')
-
- def test_invalidate_all(self):
- cache = StateCache(5 << 20)
- cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
- cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
+ cache.put("key", "cache_token", "value")
+ self.assertEqual(cache.size(), 2)
+ self.verify_metrics(
+ cache,
+ {
+ 'get': 0,
+ 'put': 4,
+ 'miss': 0,
+ 'hit': 0,
+ 'clear': 0,
+ 'evict': 0,
+ 'size': 2,
+ 'capacity': 2
+ })
+
+ def test_evict_all(self):
+ cache = self.get_cache(5)
+ cache.put("key", "cache_token", "value")
+ cache.put("key2", "cache_token", "value2")
self.assertEqual(cache.size(), 2)
- cache.invalidate_all()
+ cache.evict_all()
self.assertEqual(cache.size(), 0)
self.assertEqual(cache.get("key", "cache_token"), None)
self.assertEqual(cache.get("key2", "cache_token"), None)
- self.assertEqual(
- cache.describe_stats(),
- 'used/max 0/5 MB, hit 0.00%, lookups 2, evictions 0')
+ self.verify_metrics(
+ cache,
+ {
+ 'get': 2,
+ 'put': 2,
+ 'miss': 2,
+ 'hit': 0,
+ 'clear': 0,
+ 'evict': 0,
+ 'size': 0,
+ 'capacity': 5
+ })
def test_lru(self):
- cache = StateCache(5 << 20)
- cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
- cache.put("key2", "cache_token2", WeightedValue("value2", 1 << 20))
- cache.put("key3", "cache_token", WeightedValue("value0", 1 << 20))
- cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
- cache.put("key4", "cache_token4", WeightedValue("value4", 1 << 20))
- cache.put("key5", "cache_token", WeightedValue("value0", 1 << 20))
- cache.put("key5", "cache_token", WeightedValue(["value5"], 1 << 20))
+ cache = self.get_cache(5)
+ cache.put("key", "cache_token", "value")
+ cache.put("key2", "cache_token2", "value2")
+ cache.put("key3", "cache_token", "value0")
+ cache.put("key3", "cache_token", "value3")
+ cache.put("key4", "cache_token4", "value4")
+ cache.put("key5", "cache_token", "value0")
+ cache.put("key5", "cache_token", ["value5"])
self.assertEqual(cache.size(), 5)
self.assertEqual(cache.get("key", "cache_token"), "value")
self.assertEqual(cache.get("key2", "cache_token2"), "value2")
@@ -119,59 +150,80 @@ class StateCacheTest(unittest.TestCase):
self.assertEqual(cache.get("key4", "cache_token4"), "value4")
self.assertEqual(cache.get("key5", "cache_token"), ["value5"])
# insert another key to trigger cache eviction
- cache.put("key6", "cache_token2", WeightedValue("value6", 1 << 20))
+ cache.put("key6", "cache_token2", "value7")
self.assertEqual(cache.size(), 5)
# least recently used key should be gone ("key")
self.assertEqual(cache.get("key", "cache_token"), None)
# trigger a read on "key2"
cache.get("key2", "cache_token2")
# insert another key to trigger cache eviction
- cache.put("key7", "cache_token", WeightedValue("value7", 1 << 20))
+ cache.put("key7", "cache_token", "value7")
self.assertEqual(cache.size(), 5)
# least recently used key should be gone ("key3")
self.assertEqual(cache.get("key3", "cache_token"), None)
# trigger a put on "key2"
- cache.put("key2", "cache_token", WeightedValue("put", 1 << 20))
+ cache.put("key2", "cache_token", "put")
self.assertEqual(cache.size(), 5)
# insert another key to trigger cache eviction
- cache.put("key8", "cache_token", WeightedValue("value8", 1 << 20))
+ cache.put("key8", "cache_token", "value8")
self.assertEqual(cache.size(), 5)
# least recently used key should be gone ("key4")
self.assertEqual(cache.get("key4", "cache_token"), None)
# make "key5" used by writing to it
- cache.put("key5", "cache_token", WeightedValue("val", 1 << 20))
+ cache.put("key5", "cache_token", "val")
# least recently used key should be gone ("key6")
self.assertEqual(cache.get("key6", "cache_token"), None)
- self.assertEqual(
- cache.describe_stats(),
- 'used/max 5/5 MB, hit 60.00%, lookups 10, evictions 5')
+ self.verify_metrics(
+ cache,
+ {
+ 'get': 10,
+ 'put': 12,
+ 'miss': 4,
+ 'hit': 6,
+ 'clear': 0,
+ 'evict': 0,
+ 'size': 5,
+ 'capacity': 5
+ })
def test_is_cached_enabled(self):
- cache = StateCache(1 << 20)
+ cache = self.get_cache(1)
self.assertEqual(cache.is_cache_enabled(), True)
- self.assertEqual(
- cache.describe_stats(),
- 'used/max 0/1 MB, hit 100.00%, lookups 0, evictions 0')
- cache = StateCache(0)
+ self.verify_metrics(cache, {})
+ cache = self.get_cache(0)
self.assertEqual(cache.is_cache_enabled(), False)
- self.assertEqual(
- cache.describe_stats(),
- 'used/max 0/0 MB, hit 100.00%, lookups 0, evictions 0')
-
- def test_get_referents_for_cache(self):
- class GetReferentsForCache(CacheAware):
- def __init__(self):
- self.measure_me = bytearray(1 << 20)
- self.ignore_me = bytearray(2 << 20)
-
- def get_referents_for_cache(self):
- return [self.measure_me]
-
- cache = StateCache(5 << 20)
- cache.put("key", "cache_token", GetReferentsForCache())
- self.assertEqual(
- cache.describe_stats(),
- 'used/max 1/5 MB, hit 100.00%, lookups 0, evictions 0')
+ self.verify_metrics(cache, {})
+
+ def verify_metrics(self, cache, expected_metrics):
+ infos = cache.get_monitoring_infos()
+ # Reconstruct metrics dictionary from monitoring infos
+ metrics = {
+ info.urn.rsplit(':',
+ 1)[1]: monitoring_infos.extract_gauge_value(info)[1]
+ for info in infos if "_total" not in info.urn and
+ info.type == monitoring_infos.LATEST_INT64_TYPE
+ }
+ self.assertDictEqual(metrics, expected_metrics)
+ # Metrics and total metrics should be identical for a single bundle.
+ # The following two gauges are not part of the total metrics:
+ try:
+ del metrics['capacity']
+ del metrics['size']
+ except KeyError:
+ pass
+ total_metrics = {
+ info.urn.rsplit(':', 1)[1].rsplit("_total")[0]:
+ monitoring_infos.extract_counter_value(info)
+ for info in infos
+ if "_total" in info.urn and info.type == monitoring_infos.SUM_INT64_TYPE
+ }
+ self.assertDictEqual(metrics, total_metrics)
+
+ @staticmethod
+ def get_cache(size):
+ cache = StateCache(size)
+ cache.initialize_metrics()
+ return cache
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/worker/worker_status.py b/sdks/python/apache_beam/runners/worker/worker_status.py
index 7604bd0867a..652b01d1e4e 100644
--- a/sdks/python/apache_beam/runners/worker/worker_status.py
+++ b/sdks/python/apache_beam/runners/worker/worker_status.py
@@ -30,7 +30,6 @@ import grpc
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
-from apache_beam.runners.worker.statecache import StateCache
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
from apache_beam.utils.sentinel import Sentinel
@@ -96,18 +95,6 @@ def heap_dump():
return banner + heap + ending
-def _state_cache_stats(state_cache):
- #type: (StateCache) -> str
-
- """Gather state cache statistics."""
- cache_stats = ['=' * 10 + ' CACHE STATS ' + '=' * 10]
- if not state_cache.is_cache_enabled():
- cache_stats.append("Cache disabled")
- else:
- cache_stats.append(state_cache.describe_stats())
- return '\n'.join(cache_stats)
-
-
def _active_processing_bundles_state(bundle_process_cache):
"""Gather information about the currently in-processing active bundles.
@@ -151,7 +138,6 @@ class FnApiWorkerStatusHandler(object):
self,
status_address,
bundle_process_cache=None,
- state_cache=None,
enable_heap_dump=False,
log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS):
"""Initialize FnApiWorkerStatusHandler.
@@ -159,11 +145,9 @@ class FnApiWorkerStatusHandler(object):
Args:
status_address: The URL Runner uses to host the WorkerStatus server.
bundle_process_cache: The BundleProcessor cache dict from sdk worker.
- state_cache: The StateCache form sdk worker.
"""
self._alive = True
self._bundle_process_cache = bundle_process_cache
- self._state_cache = state_cache
ch = GRPCChannelFactory.insecure_channel(status_address)
grpc.channel_ready_future(ch).result(timeout=60)
self._status_channel = grpc.intercept_channel(ch, WorkerIdInterceptor())
@@ -209,14 +193,9 @@ class FnApiWorkerStatusHandler(object):
"status page: %s" % traceback_string))
def generate_status_response(self):
- all_status_sections = []
-
- if self._state_cache:
- all_status_sections.append(_state_cache_stats(self._state_cache))
-
- if self._bundle_process_cache:
- all_status_sections.append(
- _active_processing_bundles_state(self._bundle_process_cache))
+ all_status_sections = [
+ _active_processing_bundles_state(self._bundle_process_cache)
+ ] if self._bundle_process_cache else []
all_status_sections.append(thread_dump())
if self._enable_heap_dump:
diff --git a/sdks/python/container/py37/base_image_requirements.txt b/sdks/python/container/py37/base_image_requirements.txt
index 898d4d8dab3..392a06f39d0 100644
--- a/sdks/python/container/py37/base_image_requirements.txt
+++ b/sdks/python/container/py37/base_image_requirements.txt
@@ -96,7 +96,6 @@ nose==1.3.7
numpy==1.21.6
oauth2client==4.1.3
oauthlib==3.2.0
-objsize==0.5.1
opt-einsum==3.3.0
orjson==3.7.11
overrides==6.1.0
diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt
index 1539b561487..24a81d1dcaf 100644
--- a/sdks/python/container/py38/base_image_requirements.txt
+++ b/sdks/python/container/py38/base_image_requirements.txt
@@ -96,7 +96,6 @@ nose==1.3.7
numpy==1.22.4
oauth2client==4.1.3
oauthlib==3.2.0
-objsize==0.5.1
opt-einsum==3.3.0
orjson==3.7.11
overrides==6.1.0
diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt
index 2861b6ddb96..67de8d2b356 100644
--- a/sdks/python/container/py39/base_image_requirements.txt
+++ b/sdks/python/container/py39/base_image_requirements.txt
@@ -96,7 +96,6 @@ nose==1.3.7
numpy==1.22.4
oauth2client==4.1.3
oauthlib==3.2.0
-objsize==0.5.1
opt-einsum==3.3.0
orjson==3.7.11
overrides==6.1.0
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index fa02a658612..f0863bf49cd 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -223,7 +223,6 @@ if __name__ == '__main__':
'hdfs>=2.1.0,<3.0.0',
'httplib2>=0.8,<0.21.0',
'numpy>=1.14.3,<1.23.0',
- 'objsize>=0.5.1,<1',
'pymongo>=3.8.0,<4.0.0',
'protobuf>=3.12.2,<4',
'proto-plus>=1.7.1,<2',