You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2022/09/03 23:34:50 UTC

[beam] branch master updated: [#19857] Migrate to using a memory aware cache within the Python SDK harness (#22924)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 25c6ed74c98 [#19857] Migrate to using a memory aware cache within the Python SDK harness (#22924)
25c6ed74c98 is described below

commit 25c6ed74c9846c89a92655c1e8d313ef87d6adb1
Author: Luke Cwik <lc...@google.com>
AuthorDate: Sat Sep 3 16:34:43 2022 -0700

    [#19857] Migrate to using a memory aware cache within the Python SDK harness (#22924)
    
    * [#19857] Migrate to using a memory aware cache within the Python SDK harness
    
    This relies on getting the deep object size. The objsize library seemed like a more appropriate dependency than the larger and more complex pympler library.
    
    We remove the state cache metrics in favor of plugging into the status client output, since the state cache metrics are not defined within monitoring.proto as a well-defined type. In the future we may re-introduce a similar set of metrics.
    
    * Address PR comments
---
 .../runners/portability/flink_runner_test.py       |  93 -------
 .../portability/fn_api_runner/worker_handlers.py   |  12 +-
 .../apache_beam/runners/worker/sdk_worker.py       |  24 +-
 .../apache_beam/runners/worker/sdk_worker_main.py  |   6 +-
 .../apache_beam/runners/worker/sdk_worker_test.py  |   6 +-
 .../apache_beam/runners/worker/statecache.py       | 286 ++++++++-------------
 .../apache_beam/runners/worker/statecache_test.py  | 224 +++++++---------
 .../apache_beam/runners/worker/worker_status.py    |  27 +-
 .../container/py37/base_image_requirements.txt     |   1 +
 .../container/py38/base_image_requirements.txt     |   1 +
 .../container/py39/base_image_requirements.txt     |   1 +
 sdks/python/setup.py                               |   1 +
 12 files changed, 250 insertions(+), 432 deletions(-)

diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index 48e5df54d2d..27e4ca4973e 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -32,12 +32,9 @@ import pytest
 import apache_beam as beam
 from apache_beam import Impulse
 from apache_beam import Map
-from apache_beam import Pipeline
-from apache_beam.coders import VarIntCoder
 from apache_beam.io.external.generate_sequence import GenerateSequence
 from apache_beam.io.kafka import ReadFromKafka
 from apache_beam.io.kafka import WriteToKafka
-from apache_beam.metrics import Metrics
 from apache_beam.options.pipeline_options import DebugOptions
 from apache_beam.options.pipeline_options import FlinkRunnerOptions
 from apache_beam.options.pipeline_options import PortableOptions
@@ -47,7 +44,6 @@ from apache_beam.runners.portability import portable_runner
 from apache_beam.runners.portability import portable_runner_test
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
-from apache_beam.transforms import userstate
 from apache_beam.transforms.sql import SqlTransform
 
 # Run as
@@ -296,95 +292,6 @@ class FlinkRunnerTest(portable_runner_test.PortableRunnerTest):
   def test_metrics(self):
     super().test_metrics(check_gauge=False)
 
-  def test_flink_metrics(self):
-    """Run a simple DoFn that increments a counter and verifies state
-    caching metrics. Verifies that its expected value is written to a
-    temporary file by the FileReporter"""
-
-    counter_name = 'elem_counter'
-    state_spec = userstate.BagStateSpec('state', VarIntCoder())
-
-    class DoFn(beam.DoFn):
-      def __init__(self):
-        self.counter = Metrics.counter(self.__class__, counter_name)
-        _LOGGER.info('counter: %s' % self.counter.metric_name)
-
-      def process(self, kv, state=beam.DoFn.StateParam(state_spec)):
-        # Trigger materialization
-        list(state.read())
-        state.add(1)
-        self.counter.inc()
-
-    options = self.create_options()
-    # Test only supports parallelism of 1
-    options._all_options['parallelism'] = 1
-    # Create multiple bundles to test cache metrics
-    options._all_options['max_bundle_size'] = 10
-    options._all_options['max_bundle_time_millis'] = 95130590130
-    experiments = options.view_as(DebugOptions).experiments or []
-    experiments.append('state_cache_size=123')
-    options.view_as(DebugOptions).experiments = experiments
-    with Pipeline(self.get_runner(), options) as p:
-      # pylint: disable=expression-not-assigned
-      (
-          p
-          | "create" >> beam.Create(list(range(0, 110)))
-          | "mapper" >> beam.Map(lambda x: (x % 10, 'val'))
-          | "stateful" >> beam.ParDo(DoFn()))
-
-    lines_expected = {'counter: 110'}
-    if options.view_as(StandardOptions).streaming:
-      lines_expected.update([
-          # Gauges for the last finished bundle
-          'stateful.beam_metric:statecache:capacity: 123',
-          'stateful.beam_metric:statecache:size: 10',
-          'stateful.beam_metric:statecache:get: 20',
-          'stateful.beam_metric:statecache:miss: 0',
-          'stateful.beam_metric:statecache:hit: 20',
-          'stateful.beam_metric:statecache:put: 0',
-          'stateful.beam_metric:statecache:evict: 0',
-          # Counters
-          'stateful.beam_metric:statecache:get_total: 220',
-          'stateful.beam_metric:statecache:miss_total: 10',
-          'stateful.beam_metric:statecache:hit_total: 210',
-          'stateful.beam_metric:statecache:put_total: 10',
-          'stateful.beam_metric:statecache:evict_total: 0',
-      ])
-    else:
-      # Batch has a different processing model. All values for
-      # a key are processed at once.
-      lines_expected.update([
-          # Gauges
-          'stateful).beam_metric:statecache:capacity: 123',
-          # For the first key, the cache token will not be set yet.
-          # It's lazily initialized after first access in StateRequestHandlers
-          'stateful).beam_metric:statecache:size: 10',
-          # We have 11 here because there are 110 / 10 elements per key
-          'stateful).beam_metric:statecache:get: 12',
-          'stateful).beam_metric:statecache:miss: 1',
-          'stateful).beam_metric:statecache:hit: 11',
-          # State is flushed back once per key
-          'stateful).beam_metric:statecache:put: 1',
-          'stateful).beam_metric:statecache:evict: 0',
-          # Counters
-          'stateful).beam_metric:statecache:get_total: 120',
-          'stateful).beam_metric:statecache:miss_total: 10',
-          'stateful).beam_metric:statecache:hit_total: 110',
-          'stateful).beam_metric:statecache:put_total: 10',
-          'stateful).beam_metric:statecache:evict_total: 0',
-      ])
-    lines_actual = set()
-    with open(self.test_metrics_path, 'r') as f:
-      for line in f:
-        print(line, end='')
-        for metric_str in lines_expected:
-          metric_name = metric_str.split()[0]
-          if metric_str in line:
-            lines_actual.add(metric_str)
-          elif metric_name in line:
-            lines_actual.add(line)
-    self.assertSetEqual(lines_actual, lines_expected)
-
   def test_sdf_with_watermark_tracking(self):
     raise unittest.SkipTest("BEAM-2939")
 
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
index 5aaadbd4387..abb356d5ff4 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
@@ -81,7 +81,8 @@ if TYPE_CHECKING:
 # State caching is enabled in the fn_api_runner for testing, except for one
 # test which runs without state caching (FnApiRunnerTestWithDisabledCaching).
 # The cache is disabled in production for other runners.
-STATE_CACHE_SIZE = 100
+STATE_CACHE_SIZE_MB = 100
+MB_TO_BYTES = 1 << 20
 
 # Time-based flush is enabled in the fn_api_runner by default.
 DATA_BUFFER_TIME_LIMIT_MS = 1000
@@ -360,16 +361,14 @@ class EmbeddedWorkerHandler(WorkerHandler):
         self, data_plane.InMemoryDataChannel(), state, provision_info)
     self.control_conn = self  # type: ignore  # need Protocol to describe this
     self.data_conn = self.data_plane_handler
-    state_cache = StateCache(STATE_CACHE_SIZE)
+    state_cache = StateCache(STATE_CACHE_SIZE_MB * MB_TO_BYTES)
     self.bundle_processor_cache = sdk_worker.BundleProcessorCache(
         SingletonStateHandlerFactory(
             sdk_worker.GlobalCachingStateHandler(state_cache, state)),
         data_plane.InMemoryDataChannelFactory(
             self.data_plane_handler.inverse()),
         worker_manager._process_bundle_descriptors)
-    self.worker = sdk_worker.SdkWorker(
-        self.bundle_processor_cache,
-        state_cache_metrics_fn=state_cache.get_monitoring_infos)
+    self.worker = sdk_worker.SdkWorker(self.bundle_processor_cache)
     self._uid_counter = 0
 
   def push(self, request):
@@ -653,7 +652,8 @@ class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler):
 
     from apache_beam.transforms.environments import EmbeddedPythonGrpcEnvironment
     config = EmbeddedPythonGrpcEnvironment.parse_config(payload.decode('utf-8'))
-    self._state_cache_size = config.get('state_cache_size') or STATE_CACHE_SIZE
+    self._state_cache_size = (
+        config.get('state_cache_size') or STATE_CACHE_SIZE_MB) << 20
     self._data_buffer_time_limit_ms = \
         config.get('data_buffer_time_limit_ms') or DATA_BUFFER_TIME_LIMIT_MS
 
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 562c3139739..968f213d6b2 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -61,6 +61,7 @@ from apache_beam.runners.worker import data_plane
 from apache_beam.runners.worker import statesampler
 from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
 from apache_beam.runners.worker.data_plane import PeriodicThread
+from apache_beam.runners.worker.statecache import CacheAware
 from apache_beam.runners.worker.statecache import StateCache
 from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
 from apache_beam.runners.worker.worker_status import FnApiWorkerStatusHandler
@@ -212,7 +213,9 @@ class SdkHarness(object):
     if status_address:
       try:
         self._status_handler = FnApiWorkerStatusHandler(
-            status_address, self._bundle_processor_cache,
+            status_address,
+            self._bundle_processor_cache,
+            self._state_cache,
             enable_heap_dump)  # type: Optional[FnApiWorkerStatusHandler]
       except Exception:
         traceback_string = traceback.format_exc()
@@ -363,9 +366,7 @@ class SdkHarness(object):
   def create_worker(self):
     # type: () -> SdkWorker
     return SdkWorker(
-        self._bundle_processor_cache,
-        state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
-        profiler_factory=self._profiler_factory)
+        self._bundle_processor_cache, profiler_factory=self._profiler_factory)
 
 
 class BundleProcessorCache(object):
@@ -581,12 +582,10 @@ class SdkWorker(object):
   def __init__(
       self,
       bundle_processor_cache,  # type: BundleProcessorCache
-      state_cache_metrics_fn=list,  # type: Callable[[], Iterable[metrics_pb2.MonitoringInfo]]
       profiler_factory=None,  # type: Optional[Callable[..., Profile]]
   ):
     # type: (...) -> None
     self.bundle_processor_cache = bundle_processor_cache
-    self.state_cache_metrics_fn = state_cache_metrics_fn
     self.profiler_factory = profiler_factory
 
   def do_instruction(self, request):
@@ -634,7 +633,6 @@ class SdkWorker(object):
           delayed_applications, requests_finalization = (
               bundle_processor.process_bundle(instruction_id))
           monitoring_infos = bundle_processor.monitoring_infos()
-          monitoring_infos.extend(self.state_cache_metrics_fn())
           response = beam_fn_api_pb2.InstructionResponse(
               instruction_id=instruction_id,
               process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
@@ -878,7 +876,7 @@ class GrpcStateHandlerFactory(StateHandlerFactory):
     for _, state_handler in self._state_handler_cache.items():
       state_handler.done()
     self._state_handler_cache.clear()
-    self._state_cache.evict_all()
+    self._state_cache.invalidate_all()
 
 
 class CachingStateHandler(metaclass=abc.ABCMeta):
@@ -1130,7 +1128,6 @@ class GlobalCachingStateHandler(CachingStateHandler):
     # for items cached at the bundle level.
     self._context.bundle_cache_token = bundle_id
     try:
-      self._state_cache.initialize_metrics()
       self._context.user_state_cache_token = user_state_cache_token
       with self._underlying.process_instruction_id(bundle_id):
         yield
@@ -1290,7 +1287,7 @@ class GlobalCachingStateHandler(CachingStateHandler):
           functools.partial(
               self._lazy_iterator, state_key, coder, continuation_token))
 
-  class ContinuationIterable(Generic[T]):
+  class ContinuationIterable(Generic[T], CacheAware):
     def __init__(self, head, continue_iterator_fn):
       # type: (Iterable[T], Callable[[], Iterable[T]]) -> None
       self.head = head
@@ -1303,6 +1300,13 @@ class GlobalCachingStateHandler(CachingStateHandler):
       for item in self.continue_iterator_fn():
         yield item
 
+    def get_referents_for_cache(self):
+      # type: () -> List[Any]
+      # Only capture the size of the elements and not the
+      # continuation iterator since it references objects
+      # we don't want to include in the cache measurement.
+      return [self.head]
+
   @staticmethod
   def _convert_to_cache_key(state_key):
     # type: (beam_fn_api_pb2.StateKey) -> bytes
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index 53cdbad5d71..ec90a30ceca 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -219,8 +219,8 @@ def _get_state_cache_size(experiments):
   future releases.
 
   Returns:
-    an int indicating the maximum number of items to cache.
-      Default is 0 (disabled)
+    an int indicating the maximum number of megabytes to cache.
+      Default is 0 MB
   """
 
   for experiment in experiments:
@@ -228,7 +228,7 @@ def _get_state_cache_size(experiments):
     if re.match(r'state_cache_size=', experiment):
       return int(
           re.match(r'state_cache_size=(?P<state_cache_size>.*)',
-                   experiment).group('state_cache_size'))
+                   experiment).group('state_cache_size')) << 20
   return 0
 
 
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index d7309c149e4..05263aee96c 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -294,7 +294,7 @@ class CachingStateHandlerTest(unittest.TestCase):
         yield
 
     underlying_state = FakeUnderlyingState()
-    state_cache = statecache.StateCache(100)
+    state_cache = statecache.StateCache(100 << 20)
     caching_state_hander = GlobalCachingStateHandler(
         state_cache, underlying_state)
 
@@ -430,7 +430,7 @@ class CachingStateHandlerTest(unittest.TestCase):
     coder = VarIntCoder()
 
     underlying_state_handler = self.UnderlyingStateHandler()
-    state_cache = statecache.StateCache(100)
+    state_cache = statecache.StateCache(100 << 20)
     handler = GlobalCachingStateHandler(state_cache, underlying_state_handler)
 
     def get():
@@ -460,7 +460,7 @@ class CachingStateHandlerTest(unittest.TestCase):
 
   def test_continuation_token(self):
     underlying_state_handler = self.UnderlyingStateHandler()
-    state_cache = statecache.StateCache(100)
+    state_cache = statecache.StateCache(100 << 20)
     handler = GlobalCachingStateHandler(state_cache, underlying_state_handler)
 
     coder = VarIntCoder()
diff --git a/sdks/python/apache_beam/runners/worker/statecache.py b/sdks/python/apache_beam/runners/worker/statecache.py
index cebfbc7a096..733bcbbf235 100644
--- a/sdks/python/apache_beam/runners/worker/statecache.py
+++ b/sdks/python/apache_beam/runners/worker/statecache.py
@@ -20,245 +20,179 @@
 # mypy: disallow-untyped-defs
 
 import collections
+import gc
 import logging
 import threading
-from typing import TYPE_CHECKING
 from typing import Any
-from typing import Callable
-from typing import Generic
-from typing import Hashable
 from typing import List
 from typing import Optional
-from typing import Set
 from typing import Tuple
-from typing import TypeVar
 
-from apache_beam.metrics import monitoring_infos
-
-if TYPE_CHECKING:
-  from apache_beam.portability.api import metrics_pb2
+import objsize
 
 _LOGGER = logging.getLogger(__name__)
 
-CallableT = TypeVar('CallableT', bound='Callable')
-KT = TypeVar('KT')
-VT = TypeVar('VT')
 
+class WeightedValue(object):
+  """Value type that stores corresponding weight.
 
-class Metrics(object):
-  """Metrics container for state cache metrics."""
+  :arg value The value to be stored.
+  :arg weight The associated weight of the value. If unspecified, the object's
+  size will be used.
+  """
+  def __init__(self, value, weight):
+    # type: (Any, int) -> None
+    self._value = value
+    if weight <= 0:
+      raise ValueError(
+          'Expected weight to be > 0 for %s but received %d' % (value, weight))
+    self._weight = weight
+
+  def weight(self):
+    # type: () -> int
+    return self._weight
 
-  # A set of all registered metrics
-  ALL_METRICS = set()  # type: Set[Hashable]
-  PREFIX = "beam:metric:statecache:"
+  def value(self):
+    # type: () -> Any
+    return self._value
 
-  def __init__(self):
-    # type: () -> None
-    self._context = threading.local()
 
-  def initialize(self):
+class CacheAware(object):
+  def __init__(self):
     # type: () -> None
+    pass
 
-    """Needs to be called once per thread to initialize the local metrics cache.
-    """
-    if hasattr(self._context, 'metrics'):
-      return  # Already initialized
-    self._context.metrics = collections.defaultdict(int)
-
-  def count(self, name):
-    # type: (str) -> None
-    self._context.metrics[name] += 1
-
-  def hit_miss(self, total_name, hit_miss_name):
-    # type: (str, str) -> None
-    self._context.metrics[total_name] += 1
-    self._context.metrics[hit_miss_name] += 1
+  def get_referents_for_cache(self):
+    # type: () -> List[Any]
 
-  def get_monitoring_infos(self, cache_size, cache_capacity):
-    # type: (int, int) -> List[metrics_pb2.MonitoringInfo]
+    """Returns the list of objects accounted during cache measurement."""
+    raise NotImplementedError()
 
-    """Returns the metrics scoped to the current bundle."""
-    metrics = self._context.metrics
-    if len(metrics) == 0:
-      # No metrics collected, do not report
-      return []
-    # Add all missing metrics which were not reported
-    for key in Metrics.ALL_METRICS:
-      if key not in metrics:
-        metrics[key] = 0
-    # Gauges which reflect the state since last queried
-    gauges = [
-        monitoring_infos.int64_gauge(self.PREFIX + name, val) for name,
-        val in metrics.items()
-    ]
-    gauges.append(
-        monitoring_infos.int64_gauge(self.PREFIX + 'size', cache_size))
-    gauges.append(
-        monitoring_infos.int64_gauge(self.PREFIX + 'capacity', cache_capacity))
-    # Counters for the summary across all metrics
-    counters = [
-        monitoring_infos.int64_counter(self.PREFIX + name + '_total', val)
-        for name,
-        val in metrics.items()
-    ]
-    # Reinitialize metrics for this thread/bundle
-    metrics.clear()
-    return gauges + counters
 
-  @staticmethod
-  def counter_hit_miss(total_name, hit_name, miss_name):
-    # type: (str, str, str) -> Callable[[CallableT], CallableT]
+def get_referents_for_cache(*objs):
+  # type: (List[Any]) -> List[Any]
 
-    """Decorator for counting function calls and whether
-       the return value equals None (=miss) or not (=hit)."""
-    Metrics.ALL_METRICS.update([total_name, hit_name, miss_name])
+  """Returns the list of objects accounted during cache measurement.
 
-    def decorator(function):
-      # type: (CallableT) -> CallableT
-      def reporter(self, *args, **kwargs):
-        # type: (StateCache, Any, Any) -> Any
-        value = function(self, *args, **kwargs)
-        if value is None:
-          self._metrics.hit_miss(total_name, miss_name)
-        else:
-          self._metrics.hit_miss(total_name, hit_name)
-        return value
-
-      return reporter  # type: ignore[return-value]
-
-    return decorator
-
-  @staticmethod
-  def counter(metric_name):
-    # type: (str) -> Callable[[CallableT], CallableT]
-
-    """Decorator for counting function calls."""
-    Metrics.ALL_METRICS.add(metric_name)
-
-    def decorator(function):
-      # type: (CallableT) -> CallableT
-      def reporter(self, *args, **kwargs):
-        # type: (StateCache, Any, Any) -> Any
-        self._metrics.count(metric_name)
-        return function(self, *args, **kwargs)
-
-      return reporter  # type: ignore[return-value]
-
-    return decorator
+  Users can inherit CacheAware to override which referents should be
+  used when measuring the deep size of the object. The default is to
+  use gc.get_referents(*objs).
+  """
+  rval = []
+  for obj in objs:
+    if isinstance(obj, CacheAware):
+      rval.extend(obj.get_referents_for_cache())
+    else:
+      rval.extend(gc.get_referents(obj))
+  return rval
 
 
 class StateCache(object):
-  """ Cache for Beam state access, scoped by state key and cache_token.
-      Assumes a bag state implementation.
+  """Cache for Beam state access, scoped by state key and cache_token.
+     Assumes a bag state implementation.
 
-  For a given state_key, caches a (cache_token, value) tuple and allows to
+  For a given state_key and cache_token, caches a value and allows to
     a) read from the cache (get),
            if the currently stored cache_token matches the provided
-    a) write to the cache (put),
+    b) write to the cache (put),
            storing the new value alongside with a cache token
-    c) append to the currently cache item (extend),
-           if the currently stored cache_token matches the provided
     c) empty a cached element (clear),
            if the currently stored cache_token matches the provided
-    d) evict a cached element (evict)
+    d) invalidate a cached element (invalidate)
+    e) invalidate all cached elements (invalidate_all)
 
   The operations on the cache are thread-safe for use by multiple workers.
 
-  :arg max_entries The maximum number of entries to store in the cache.
-  TODO Memory-based caching: https://github.com/apache/beam/issues/19857
+  :arg max_weight The maximum weight of entries to store in the cache in bytes.
   """
-  def __init__(self, max_entries):
+  def __init__(self, max_weight):
     # type: (int) -> None
-    _LOGGER.info('Creating state cache with size %s', max_entries)
-    self._missing = None
-    self._cache = self.LRUCache[Tuple[bytes, Optional[bytes]],
-                                Any](max_entries, self._missing)
+    _LOGGER.info('Creating state cache with size %s', max_weight)
+    self._max_weight = max_weight
+    self._current_weight = 0
+    self._cache = collections.OrderedDict(
+    )  # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue]
+    self._hit_count = 0
+    self._miss_count = 0
+    self._evict_count = 0
     self._lock = threading.RLock()
-    self._metrics = Metrics()
 
-  @Metrics.counter_hit_miss("get", "hit", "miss")
   def get(self, state_key, cache_token):
     # type: (bytes, Optional[bytes]) -> Any
     assert cache_token and self.is_cache_enabled()
+    key = (state_key, cache_token)
     with self._lock:
-      return self._cache.get((state_key, cache_token))
+      value = self._cache.get(key, None)
+      if value is None:
+        self._miss_count += 1
+        return None
+      self._cache.move_to_end(key)
+      self._hit_count += 1
+      return value.value()
 
-  @Metrics.counter("put")
   def put(self, state_key, cache_token, value):
     # type: (bytes, Optional[bytes], Any) -> None
     assert cache_token and self.is_cache_enabled()
+    if not isinstance(value, WeightedValue):
+      weight = objsize.get_deep_size(
+          value, get_referents_func=get_referents_for_cache)
+      if weight <= 0:
+        _LOGGER.warning(
+            'Expected object size to be >= 0 for %s but received %d.',
+            value,
+            weight)
+        weight = 8
+      value = WeightedValue(value, weight)
+    key = (state_key, cache_token)
     with self._lock:
-      return self._cache.put((state_key, cache_token), value)
+      old_value = self._cache.pop(key, None)
+      if old_value is not None:
+        self._current_weight -= old_value.weight()
+      self._cache[(state_key, cache_token)] = value
+      self._current_weight += value.weight()
+      while self._current_weight > self._max_weight:
+        (_, weighted_value) = self._cache.popitem(last=False)
+        self._current_weight -= weighted_value.weight()
+        self._evict_count += 1
 
-  @Metrics.counter("clear")
   def clear(self, state_key, cache_token):
     # type: (bytes, Optional[bytes]) -> None
-    assert cache_token and self.is_cache_enabled()
-    with self._lock:
-      self._cache.put((state_key, cache_token), [])
+    self.put(state_key, cache_token, [])
 
-  @Metrics.counter("evict")
-  def evict(self, state_key, cache_token):
+  def invalidate(self, state_key, cache_token):
     # type: (bytes, Optional[bytes]) -> None
     assert self.is_cache_enabled()
     with self._lock:
-      self._cache.evict((state_key, cache_token))
+      weighted_value = self._cache.pop((state_key, cache_token), None)
+      if weighted_value is not None:
+        self._current_weight -= weighted_value.weight()
 
-  def evict_all(self):
+  def invalidate_all(self):
     # type: () -> None
     with self._lock:
-      self._cache.evict_all()
+      self._cache.clear()
+      self._current_weight = 0
 
-  def initialize_metrics(self):
-    # type: () -> None
-    self._metrics.initialize()
+  def describe_stats(self):
+    # type: () -> str
+    with self._lock:
+      request_count = self._hit_count + self._miss_count
+      if request_count > 0:
+        hit_ratio = 100.0 * self._hit_count / request_count
+      else:
+        hit_ratio = 100.0
+      return 'used/max %d/%d MB, hit %.2f%%, lookups %d, evictions %d' % (
+          self._current_weight >> 20,
+          self._max_weight >> 20,
+          hit_ratio,
+          request_count,
+          self._evict_count)
 
   def is_cache_enabled(self):
     # type: () -> bool
-    return self._cache._max_entries > 0
+    return self._max_weight > 0
 
   def size(self):
     # type: () -> int
-    return len(self._cache)
-
-  def get_monitoring_infos(self):
-    # type: () -> List[metrics_pb2.MonitoringInfo]
-
-    """Retrieves the monitoring infos and resets the counters."""
     with self._lock:
-      size = len(self._cache)
-    capacity = self._cache._max_entries
-    return self._metrics.get_monitoring_infos(size, capacity)
-
-  class LRUCache(Generic[KT, VT]):
-    def __init__(self, max_entries, default_entry):
-      # type: (int, VT) -> None
-      self._max_entries = max_entries
-      self._default_entry = default_entry
-      self._cache = collections.OrderedDict(
-      )  # type: collections.OrderedDict[KT, VT]
-
-    def get(self, key):
-      # type: (KT) -> VT
-      value = self._cache.pop(key, self._default_entry)
-      if value != self._default_entry:
-        self._cache[key] = value
-      return value
-
-    def put(self, key, value):
-      # type: (KT, VT) -> None
-      self._cache[key] = value
-      while len(self._cache) > self._max_entries:
-        self._cache.popitem(last=False)
-
-    def evict(self, key):
-      # type: (KT) -> None
-      self._cache.pop(key, self._default_entry)
-
-    def evict_all(self):
-      # type: () -> None
-      self._cache.clear()
-
-    def __len__(self):
-      # type: () -> int
       return len(self._cache)
diff --git a/sdks/python/apache_beam/runners/worker/statecache_test.py b/sdks/python/apache_beam/runners/worker/statecache_test.py
index a1a175ed347..9bd952721ea 100644
--- a/sdks/python/apache_beam/runners/worker/statecache_test.py
+++ b/sdks/python/apache_beam/runners/worker/statecache_test.py
@@ -21,55 +21,38 @@
 import logging
 import unittest
 
-from apache_beam.metrics import monitoring_infos
+from apache_beam.runners.worker.statecache import CacheAware
 from apache_beam.runners.worker.statecache import StateCache
+from apache_beam.runners.worker.statecache import WeightedValue
 
 
 class StateCacheTest(unittest.TestCase):
   def test_empty_cache_get(self):
-    cache = self.get_cache(5)
+    cache = StateCache(5 << 20)
     self.assertEqual(cache.get("key", 'cache_token'), None)
     with self.assertRaises(Exception):
       # Invalid cache token provided
       self.assertEqual(cache.get("key", None), None)
-    self.verify_metrics(
-        cache,
-        {
-            'get': 1,
-            'put': 0,
-            'miss': 1,
-            'hit': 0,
-            'clear': 0,
-            'evict': 0,
-            'size': 0,
-            'capacity': 5
-        })
+    self.assertEqual(
+        cache.describe_stats(),
+        'used/max 0/5 MB, hit 0.00%, lookups 1, evictions 0')
 
   def test_put_get(self):
-    cache = self.get_cache(5)
-    cache.put("key", "cache_token", "value")
+    cache = StateCache(5 << 20)
+    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
     self.assertEqual(cache.size(), 1)
     self.assertEqual(cache.get("key", "cache_token"), "value")
     self.assertEqual(cache.get("key", "cache_token2"), None)
     with self.assertRaises(Exception):
       self.assertEqual(cache.get("key", None), None)
-    self.verify_metrics(
-        cache,
-        {
-            'get': 2,
-            'put': 1,
-            'miss': 1,
-            'hit': 1,
-            'clear': 0,
-            'evict': 0,
-            'size': 1,
-            'capacity': 5
-        })
+    self.assertEqual(
+        cache.describe_stats(),
+        'used/max 1/5 MB, hit 50.00%, lookups 2, evictions 0')
 
   def test_clear(self):
-    cache = self.get_cache(5)
+    cache = StateCache(5 << 20)
     cache.clear("new-key", "cache_token")
-    cache.put("key", "cache_token", ["value"])
+    cache.put("key", "cache_token", WeightedValue(["value"], 1 << 20))
     self.assertEqual(cache.size(), 2)
     self.assertEqual(cache.get("new-key", "new_token"), None)
     self.assertEqual(cache.get("key", "cache_token"), ['value'])
@@ -77,72 +60,58 @@ class StateCacheTest(unittest.TestCase):
     cache.clear("non-existing", "token")
     self.assertEqual(cache.size(), 3)
     self.assertEqual(cache.get("non-existing", "token"), [])
-    self.verify_metrics(
-        cache,
-        {
-            'get': 3,
-            'put': 1,
-            'miss': 1,
-            'hit': 2,
-            'clear': 2,
-            'evict': 0,
-            'size': 3,
-            'capacity': 5
-        })
+    self.assertEqual(
+        cache.describe_stats(),
+        'used/max 1/5 MB, hit 66.67%, lookups 3, evictions 0')
+
+  def test_default_sized_put(self):
+    cache = StateCache(5 << 20)
+    cache.put("key", "cache_token", bytearray(1 << 20))
+    cache.put("key2", "cache_token", bytearray(1 << 20))
+    cache.put("key3", "cache_token", bytearray(1 << 20))
+    self.assertEqual(cache.get("key3", "cache_token"), bytearray(1 << 20))
+    cache.put("key4", "cache_token", bytearray(1 << 20))
+    cache.put("key5", "cache_token", bytearray(1 << 20))
+    # Note that each bytearray instance takes slightly over 1 MB, which is why
+    # these 5 bytearrays can't all be stored in the cache, causing a single
+    # eviction.
+    self.assertEqual(
+        cache.describe_stats(),
+        'used/max 4/5 MB, hit 100.00%, lookups 1, evictions 1')
 
   def test_max_size(self):
-    cache = self.get_cache(2)
-    cache.put("key", "cache_token", "value")
-    cache.put("key2", "cache_token", "value")
-    self.assertEqual(cache.size(), 2)
-    cache.put("key2", "cache_token", "value")
+    cache = StateCache(2 << 20)
+    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
+    cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
     self.assertEqual(cache.size(), 2)
-    cache.put("key", "cache_token", "value")
+    cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
     self.assertEqual(cache.size(), 2)
-    self.verify_metrics(
-        cache,
-        {
-            'get': 0,
-            'put': 4,
-            'miss': 0,
-            'hit': 0,
-            'clear': 0,
-            'evict': 0,
-            'size': 2,
-            'capacity': 2
-        })
-
-  def test_evict_all(self):
-    cache = self.get_cache(5)
-    cache.put("key", "cache_token", "value")
-    cache.put("key2", "cache_token", "value2")
+    self.assertEqual(
+        cache.describe_stats(),
+        'used/max 2/2 MB, hit 100.00%, lookups 0, evictions 1')
+
+  def test_invalidate_all(self):
+    cache = StateCache(5 << 20)
+    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
+    cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
     self.assertEqual(cache.size(), 2)
-    cache.evict_all()
+    cache.invalidate_all()
     self.assertEqual(cache.size(), 0)
     self.assertEqual(cache.get("key", "cache_token"), None)
     self.assertEqual(cache.get("key2", "cache_token"), None)
-    self.verify_metrics(
-        cache,
-        {
-            'get': 2,
-            'put': 2,
-            'miss': 2,
-            'hit': 0,
-            'clear': 0,
-            'evict': 0,
-            'size': 0,
-            'capacity': 5
-        })
+    self.assertEqual(
+        cache.describe_stats(),
+        'used/max 0/5 MB, hit 0.00%, lookups 2, evictions 0')
 
   def test_lru(self):
-    cache = self.get_cache(5)
-    cache.put("key", "cache_token", "value")
-    cache.put("key2", "cache_token2", "value2")
-    cache.put("key3", "cache_token", "value0")
-    cache.put("key3", "cache_token", "value3")
-    cache.put("key4", "cache_token4", "value4")
-    cache.put("key5", "cache_token", "value0")
-    cache.put("key5", "cache_token", ["value5"])
+    cache = StateCache(5 << 20)
+    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
+    cache.put("key2", "cache_token2", WeightedValue("value2", 1 << 20))
+    cache.put("key3", "cache_token", WeightedValue("value0", 1 << 20))
+    cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
+    cache.put("key4", "cache_token4", WeightedValue("value4", 1 << 20))
+    cache.put("key5", "cache_token", WeightedValue("value0", 1 << 20))
+    cache.put("key5", "cache_token", WeightedValue(["value5"], 1 << 20))
     self.assertEqual(cache.size(), 5)
     self.assertEqual(cache.get("key", "cache_token"), "value")
     self.assertEqual(cache.get("key2", "cache_token2"), "value2")
@@ -150,80 +119,59 @@ class StateCacheTest(unittest.TestCase):
     self.assertEqual(cache.get("key4", "cache_token4"), "value4")
     self.assertEqual(cache.get("key5", "cache_token"), ["value5"])
     # insert another key to trigger cache eviction
-    cache.put("key6", "cache_token2", "value7")
+    cache.put("key6", "cache_token2", WeightedValue("value6", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key")
     self.assertEqual(cache.get("key", "cache_token"), None)
     # trigger a read on "key2"
     cache.get("key2", "cache_token2")
     # insert another key to trigger cache eviction
-    cache.put("key7", "cache_token", "value7")
+    cache.put("key7", "cache_token", WeightedValue("value7", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key3")
     self.assertEqual(cache.get("key3", "cache_token"), None)
     # trigger a put on "key2"
-    cache.put("key2", "cache_token", "put")
+    cache.put("key2", "cache_token", WeightedValue("put", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # insert another key to trigger cache eviction
-    cache.put("key8", "cache_token", "value8")
+    cache.put("key8", "cache_token", WeightedValue("value8", 1 << 20))
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key4")
     self.assertEqual(cache.get("key4", "cache_token"), None)
     # make "key5" used by writing to it
-    cache.put("key5", "cache_token", "val")
+    cache.put("key5", "cache_token", WeightedValue("val", 1 << 20))
     # least recently used key should be gone ("key6")
     self.assertEqual(cache.get("key6", "cache_token"), None)
-    self.verify_metrics(
-        cache,
-        {
-            'get': 10,
-            'put': 12,
-            'miss': 4,
-            'hit': 6,
-            'clear': 0,
-            'evict': 0,
-            'size': 5,
-            'capacity': 5
-        })
+    self.assertEqual(
+        cache.describe_stats(),
+        'used/max 5/5 MB, hit 60.00%, lookups 10, evictions 5')
 
   def test_is_cached_enabled(self):
-    cache = self.get_cache(1)
+    cache = StateCache(1 << 20)
     self.assertEqual(cache.is_cache_enabled(), True)
-    self.verify_metrics(cache, {})
-    cache = self.get_cache(0)
+    self.assertEqual(
+        cache.describe_stats(),
+        'used/max 0/1 MB, hit 100.00%, lookups 0, evictions 0')
+    cache = StateCache(0)
     self.assertEqual(cache.is_cache_enabled(), False)
-    self.verify_metrics(cache, {})
-
-  def verify_metrics(self, cache, expected_metrics):
-    infos = cache.get_monitoring_infos()
-    # Reconstruct metrics dictionary from monitoring infos
-    metrics = {
-        info.urn.rsplit(':',
-                        1)[1]: monitoring_infos.extract_gauge_value(info)[1]
-        for info in infos if "_total" not in info.urn and
-        info.type == monitoring_infos.LATEST_INT64_TYPE
-    }
-    self.assertDictEqual(metrics, expected_metrics)
-    # Metrics and total metrics should be identical for a single bundle.
-    # The following two gauges are not part of the total metrics:
-    try:
-      del metrics['capacity']
-      del metrics['size']
-    except KeyError:
-      pass
-    total_metrics = {
-        info.urn.rsplit(':', 1)[1].rsplit("_total")[0]:
-        monitoring_infos.extract_counter_value(info)
-        for info in infos
-        if "_total" in info.urn and info.type == monitoring_infos.SUM_INT64_TYPE
-    }
-    self.assertDictEqual(metrics, total_metrics)
-
-  @staticmethod
-  def get_cache(size):
-    cache = StateCache(size)
-    cache.initialize_metrics()
-    return cache
+    self.assertEqual(
+        cache.describe_stats(),
+        'used/max 0/0 MB, hit 100.00%, lookups 0, evictions 0')
+
+  def test_get_referents_for_cache(self):
+    class GetReferentsForCache(CacheAware):
+      def __init__(self):
+        self.measure_me = bytearray(1 << 20)
+        self.ignore_me = bytearray(2 << 20)
+
+      def get_referents_for_cache(self):
+        return [self.measure_me]
+
+    cache = StateCache(5 << 20)
+    cache.put("key", "cache_token", GetReferentsForCache())
+    self.assertEqual(
+        cache.describe_stats(),
+        'used/max 1/5 MB, hit 100.00%, lookups 0, evictions 0')
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/worker/worker_status.py b/sdks/python/apache_beam/runners/worker/worker_status.py
index 652b01d1e4e..7604bd0867a 100644
--- a/sdks/python/apache_beam/runners/worker/worker_status.py
+++ b/sdks/python/apache_beam/runners/worker/worker_status.py
@@ -30,6 +30,7 @@ import grpc
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_fn_api_pb2_grpc
 from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
+from apache_beam.runners.worker.statecache import StateCache
 from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
 from apache_beam.utils.sentinel import Sentinel
 
@@ -95,6 +96,18 @@ def heap_dump():
   return banner + heap + ending
 
 
+def _state_cache_stats(state_cache):
+  #type: (StateCache) -> str
+
+  """Gather state cache statistics."""
+  cache_stats = ['=' * 10 + ' CACHE STATS ' + '=' * 10]
+  if not state_cache.is_cache_enabled():
+    cache_stats.append("Cache disabled")
+  else:
+    cache_stats.append(state_cache.describe_stats())
+  return '\n'.join(cache_stats)
+
+
 def _active_processing_bundles_state(bundle_process_cache):
   """Gather information about the currently in-processing active bundles.
 
@@ -138,6 +151,7 @@ class FnApiWorkerStatusHandler(object):
       self,
       status_address,
       bundle_process_cache=None,
+      state_cache=None,
       enable_heap_dump=False,
       log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS):
     """Initialize FnApiWorkerStatusHandler.
@@ -145,9 +159,11 @@ class FnApiWorkerStatusHandler(object):
     Args:
       status_address: The URL Runner uses to host the WorkerStatus server.
       bundle_process_cache: The BundleProcessor cache dict from sdk worker.
+      state_cache: The StateCache from sdk worker.
     """
     self._alive = True
     self._bundle_process_cache = bundle_process_cache
+    self._state_cache = state_cache
     ch = GRPCChannelFactory.insecure_channel(status_address)
     grpc.channel_ready_future(ch).result(timeout=60)
     self._status_channel = grpc.intercept_channel(ch, WorkerIdInterceptor())
@@ -193,9 +209,14 @@ class FnApiWorkerStatusHandler(object):
                   "status page: %s" % traceback_string))
 
   def generate_status_response(self):
-    all_status_sections = [
-        _active_processing_bundles_state(self._bundle_process_cache)
-    ] if self._bundle_process_cache else []
+    all_status_sections = []
+
+    if self._state_cache:
+      all_status_sections.append(_state_cache_stats(self._state_cache))
+
+    if self._bundle_process_cache:
+      all_status_sections.append(
+          _active_processing_bundles_state(self._bundle_process_cache))
 
     all_status_sections.append(thread_dump())
     if self._enable_heap_dump:
diff --git a/sdks/python/container/py37/base_image_requirements.txt b/sdks/python/container/py37/base_image_requirements.txt
index 392a06f39d0..898d4d8dab3 100644
--- a/sdks/python/container/py37/base_image_requirements.txt
+++ b/sdks/python/container/py37/base_image_requirements.txt
@@ -96,6 +96,7 @@ nose==1.3.7
 numpy==1.21.6
 oauth2client==4.1.3
 oauthlib==3.2.0
+objsize==0.5.1
 opt-einsum==3.3.0
 orjson==3.7.11
 overrides==6.1.0
diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt
index 24a81d1dcaf..1539b561487 100644
--- a/sdks/python/container/py38/base_image_requirements.txt
+++ b/sdks/python/container/py38/base_image_requirements.txt
@@ -96,6 +96,7 @@ nose==1.3.7
 numpy==1.22.4
 oauth2client==4.1.3
 oauthlib==3.2.0
+objsize==0.5.1
 opt-einsum==3.3.0
 orjson==3.7.11
 overrides==6.1.0
diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt
index 67de8d2b356..2861b6ddb96 100644
--- a/sdks/python/container/py39/base_image_requirements.txt
+++ b/sdks/python/container/py39/base_image_requirements.txt
@@ -96,6 +96,7 @@ nose==1.3.7
 numpy==1.22.4
 oauth2client==4.1.3
 oauthlib==3.2.0
+objsize==0.5.1
 opt-einsum==3.3.0
 orjson==3.7.11
 overrides==6.1.0
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index f0863bf49cd..fa02a658612 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -223,6 +223,7 @@ if __name__ == '__main__':
         'hdfs>=2.1.0,<3.0.0',
         'httplib2>=0.8,<0.21.0',
         'numpy>=1.14.3,<1.23.0',
+        'objsize>=0.5.1,<1',
         'pymongo>=3.8.0,<4.0.0',
         'protobuf>=3.12.2,<4',
         'proto-plus>=1.7.1,<2',