Posted to commits@beam.apache.org by lo...@apache.org on 2022/09/09 18:25:33 UTC

[beam] branch release-2.42.0 updated: Revert "[#19857] Migrate to using a memory aware cache within the Python SDK harness (#22924)" (#23107)

This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch release-2.42.0
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/release-2.42.0 by this push:
     new 4d4fd066203 Revert "[#19857] Migrate to using a memory aware cache within the Python SDK harness (#22924)" (#23107)
4d4fd066203 is described below

commit 4d4fd066203831005d706b9c9a699cc072ca1863
Author: Luke Cwik <lc...@google.com>
AuthorDate: Fri Sep 9 11:25:27 2022 -0700

    Revert "[#19857] Migrate to using a memory aware cache within the Python SDK harness (#22924)" (#23107)
    
    This reverts commit 25c6ed74c9846c89a92655c1e8d313ef87d6adb1.
---
 .../runners/portability/flink_runner_test.py       |  93 +++++++
 .../portability/fn_api_runner/worker_handlers.py   |  12 +-
 .../apache_beam/runners/worker/sdk_worker.py       |  24 +-
 .../apache_beam/runners/worker/sdk_worker_main.py  |   6 +-
 .../apache_beam/runners/worker/sdk_worker_test.py  |   6 +-
 .../apache_beam/runners/worker/statecache.py       | 286 +++++++++++++--------
 .../apache_beam/runners/worker/statecache_test.py  | 224 +++++++++-------
 .../apache_beam/runners/worker/worker_status.py    |  27 +-
 .../container/py37/base_image_requirements.txt     |   1 -
 .../container/py38/base_image_requirements.txt     |   1 -
 .../container/py39/base_image_requirements.txt     |   1 -
 sdks/python/setup.py                               |   1 -
 12 files changed, 432 insertions(+), 250 deletions(-)

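The central change this revert undoes is the cache sizing model: the
memory-aware cache measured entries by their deep size in bytes (via the
objsize dependency dropped from setup.py below), while the restored code
simply counts entries. A minimal sketch of the two constructor
conventions, using only what the diff itself shows:

    from apache_beam.runners.worker.statecache import StateCache

    # Restored (post-revert) convention: capacity is a number of entries.
    cache = StateCache(100)
    # The reverted code sized the cache in bytes instead, e.g.
    # StateCache(100 << 20) for a 100 MB weight budget.
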
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index 27e4ca4973e..48e5df54d2d 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -32,9 +32,12 @@ import pytest
 import apache_beam as beam
 from apache_beam import Impulse
 from apache_beam import Map
+from apache_beam import Pipeline
+from apache_beam.coders import VarIntCoder
 from apache_beam.io.external.generate_sequence import GenerateSequence
 from apache_beam.io.kafka import ReadFromKafka
 from apache_beam.io.kafka import WriteToKafka
+from apache_beam.metrics import Metrics
 from apache_beam.options.pipeline_options import DebugOptions
 from apache_beam.options.pipeline_options import FlinkRunnerOptions
 from apache_beam.options.pipeline_options import PortableOptions
@@ -44,6 +47,7 @@ from apache_beam.runners.portability import portable_runner
 from apache_beam.runners.portability import portable_runner_test
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.transforms import userstate
 from apache_beam.transforms.sql import SqlTransform
 
 # Run as
@@ -292,6 +296,95 @@ class FlinkRunnerTest(portable_runner_test.PortableRunnerTest):
   def test_metrics(self):
     super().test_metrics(check_gauge=False)
 
+  def test_flink_metrics(self):
+    """Runs a simple DoFn that increments a counter and exercises state
+    caching. Verifies that the expected metric values are written to a
+    temporary file by the FileReporter."""
+
+    counter_name = 'elem_counter'
+    state_spec = userstate.BagStateSpec('state', VarIntCoder())
+
+    class DoFn(beam.DoFn):
+      def __init__(self):
+        self.counter = Metrics.counter(self.__class__, counter_name)
+        _LOGGER.info('counter: %s' % self.counter.metric_name)
+
+      def process(self, kv, state=beam.DoFn.StateParam(state_spec)):
+        # Trigger materialization
+        list(state.read())
+        state.add(1)
+        self.counter.inc()
+
+    options = self.create_options()
+    # Test only supports parallelism of 1
+    options._all_options['parallelism'] = 1
+    # Create multiple bundles to test cache metrics
+    options._all_options['max_bundle_size'] = 10
+    options._all_options['max_bundle_time_millis'] = 95130590130
+    experiments = options.view_as(DebugOptions).experiments or []
+    experiments.append('state_cache_size=123')
+    options.view_as(DebugOptions).experiments = experiments
+    with Pipeline(self.get_runner(), options) as p:
+      # pylint: disable=expression-not-assigned
+      (
+          p
+          | "create" >> beam.Create(list(range(0, 110)))
+          | "mapper" >> beam.Map(lambda x: (x % 10, 'val'))
+          | "stateful" >> beam.ParDo(DoFn()))
+
+    lines_expected = {'counter: 110'}
+    if options.view_as(StandardOptions).streaming:
+      lines_expected.update([
+          # Gauges for the last finished bundle
+          'stateful.beam_metric:statecache:capacity: 123',
+          'stateful.beam_metric:statecache:size: 10',
+          'stateful.beam_metric:statecache:get: 20',
+          'stateful.beam_metric:statecache:miss: 0',
+          'stateful.beam_metric:statecache:hit: 20',
+          'stateful.beam_metric:statecache:put: 0',
+          'stateful.beam_metric:statecache:evict: 0',
+          # Counters
+          'stateful.beam_metric:statecache:get_total: 220',
+          'stateful.beam_metric:statecache:miss_total: 10',
+          'stateful.beam_metric:statecache:hit_total: 210',
+          'stateful.beam_metric:statecache:put_total: 10',
+          'stateful.beam_metric:statecache:evict_total: 0',
+      ])
+    else:
+      # Batch has a different processing model. All values for
+      # a key are processed at once.
+      lines_expected.update([
+          # Gauges
+          'stateful).beam_metric:statecache:capacity: 123',
+          # For the first key, the cache token will not be set yet.
+          # It's lazily initialized after first access in StateRequestHandlers
+          'stateful).beam_metric:statecache:size: 10',
+          # 110 / 10 = 11 elements per key; 11 hits + 1 miss = 12 gets
+          'stateful).beam_metric:statecache:get: 12',
+          'stateful).beam_metric:statecache:miss: 1',
+          'stateful).beam_metric:statecache:hit: 11',
+          # State is flushed back once per key
+          'stateful).beam_metric:statecache:put: 1',
+          'stateful).beam_metric:statecache:evict: 0',
+          # Counters
+          'stateful).beam_metric:statecache:get_total: 120',
+          'stateful).beam_metric:statecache:miss_total: 10',
+          'stateful).beam_metric:statecache:hit_total: 110',
+          'stateful).beam_metric:statecache:put_total: 10',
+          'stateful).beam_metric:statecache:evict_total: 0',
+      ])
+    lines_actual = set()
+    with open(self.test_metrics_path, 'r') as f:
+      for line in f:
+        print(line, end='')
+        for metric_str in lines_expected:
+          metric_name = metric_str.split()[0]
+          if metric_str in line:
+            lines_actual.add(metric_str)
+          elif metric_name in line:
+            lines_actual.add(line)
+    self.assertSetEqual(lines_actual, lines_expected)
+
   def test_sdf_with_watermark_tracking(self):
     raise unittest.SkipTest("BEAM-2939")
 
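As a quick sanity check on the batch expectations above (110 elements
hashed onto 10 keys, so 11 elements per key), the *_total counters are
just the per-key gauge values multiplied by the number of keys:

    keys = 10
    elems_per_key = 110 // keys                        # 11
    per_key = {'get': 12, 'miss': 1, 'hit': 11, 'put': 1}
    assert per_key['hit'] == elems_per_key             # hits line up with elements per key
    assert per_key['get'] == per_key['hit'] + per_key['miss']
    totals = {name: n * keys for name, n in per_key.items()}
    assert totals == {'get': 120, 'miss': 10, 'hit': 110, 'put': 10}
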
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
index abb356d5ff4..5aaadbd4387 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
@@ -81,8 +81,7 @@ if TYPE_CHECKING:
 # State caching is enabled in the fn_api_runner for testing, except for one
 # test which runs without state caching (FnApiRunnerTestWithDisabledCaching).
 # The cache is disabled in production for other runners.
-STATE_CACHE_SIZE_MB = 100
-MB_TO_BYTES = 1 << 20
+STATE_CACHE_SIZE = 100
 
 # Time-based flush is enabled in the fn_api_runner by default.
 DATA_BUFFER_TIME_LIMIT_MS = 1000
@@ -361,14 +360,16 @@ class EmbeddedWorkerHandler(WorkerHandler):
         self, data_plane.InMemoryDataChannel(), state, provision_info)
     self.control_conn = self  # type: ignore  # need Protocol to describe this
     self.data_conn = self.data_plane_handler
-    state_cache = StateCache(STATE_CACHE_SIZE_MB * MB_TO_BYTES)
+    state_cache = StateCache(STATE_CACHE_SIZE)
     self.bundle_processor_cache = sdk_worker.BundleProcessorCache(
         SingletonStateHandlerFactory(
             sdk_worker.GlobalCachingStateHandler(state_cache, state)),
         data_plane.InMemoryDataChannelFactory(
             self.data_plane_handler.inverse()),
         worker_manager._process_bundle_descriptors)
-    self.worker = sdk_worker.SdkWorker(self.bundle_processor_cache)
+    self.worker = sdk_worker.SdkWorker(
+        self.bundle_processor_cache,
+        state_cache_metrics_fn=state_cache.get_monitoring_infos)
     self._uid_counter = 0
 
   def push(self, request):
@@ -652,8 +653,7 @@ class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler):
 
     from apache_beam.transforms.environments import EmbeddedPythonGrpcEnvironment
     config = EmbeddedPythonGrpcEnvironment.parse_config(payload.decode('utf-8'))
-    self._state_cache_size = (
-        config.get('state_cache_size') or STATE_CACHE_SIZE_MB) << 20
+    self._state_cache_size = config.get('state_cache_size') or STATE_CACHE_SIZE
     self._data_buffer_time_limit_ms = \
         config.get('data_buffer_time_limit_ms') or DATA_BUFFER_TIME_LIMIT_MS
 
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 968f213d6b2..562c3139739 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -61,7 +61,6 @@ from apache_beam.runners.worker import data_plane
 from apache_beam.runners.worker import statesampler
 from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
 from apache_beam.runners.worker.data_plane import PeriodicThread
-from apache_beam.runners.worker.statecache import CacheAware
 from apache_beam.runners.worker.statecache import StateCache
 from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
 from apache_beam.runners.worker.worker_status import FnApiWorkerStatusHandler
@@ -213,9 +212,7 @@ class SdkHarness(object):
     if status_address:
       try:
         self._status_handler = FnApiWorkerStatusHandler(
-            status_address,
-            self._bundle_processor_cache,
-            self._state_cache,
+            status_address, self._bundle_processor_cache,
             enable_heap_dump)  # type: Optional[FnApiWorkerStatusHandler]
       except Exception:
         traceback_string = traceback.format_exc()
@@ -366,7 +363,9 @@ class SdkHarness(object):
   def create_worker(self):
     # type: () -> SdkWorker
     return SdkWorker(
-        self._bundle_processor_cache, profiler_factory=self._profiler_factory)
+        self._bundle_processor_cache,
+        state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
+        profiler_factory=self._profiler_factory)
 
 
 class BundleProcessorCache(object):
@@ -582,10 +581,12 @@ class SdkWorker(object):
   def __init__(
       self,
       bundle_processor_cache,  # type: BundleProcessorCache
+      state_cache_metrics_fn=list,  # type: Callable[[], Iterable[metrics_pb2.MonitoringInfo]]
       profiler_factory=None,  # type: Optional[Callable[..., Profile]]
   ):
     # type: (...) -> None
     self.bundle_processor_cache = bundle_processor_cache
+    self.state_cache_metrics_fn = state_cache_metrics_fn
     self.profiler_factory = profiler_factory
 
   def do_instruction(self, request):
@@ -633,6 +634,7 @@ class SdkWorker(object):
           delayed_applications, requests_finalization = (
               bundle_processor.process_bundle(instruction_id))
           monitoring_infos = bundle_processor.monitoring_infos()
+          monitoring_infos.extend(self.state_cache_metrics_fn())
           response = beam_fn_api_pb2.InstructionResponse(
               instruction_id=instruction_id,
               process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
@@ -876,7 +878,7 @@ class GrpcStateHandlerFactory(StateHandlerFactory):
     for _, state_handler in self._state_handler_cache.items():
       state_handler.done()
     self._state_handler_cache.clear()
-    self._state_cache.invalidate_all()
+    self._state_cache.evict_all()
 
 
 class CachingStateHandler(metaclass=abc.ABCMeta):
@@ -1128,6 +1130,7 @@ class GlobalCachingStateHandler(CachingStateHandler):
     # for items cached at the bundle level.
     self._context.bundle_cache_token = bundle_id
     try:
+      self._state_cache.initialize_metrics()
       self._context.user_state_cache_token = user_state_cache_token
       with self._underlying.process_instruction_id(bundle_id):
         yield
@@ -1287,7 +1290,7 @@ class GlobalCachingStateHandler(CachingStateHandler):
           functools.partial(
               self._lazy_iterator, state_key, coder, continuation_token))
 
-  class ContinuationIterable(Generic[T], CacheAware):
+  class ContinuationIterable(Generic[T]):
     def __init__(self, head, continue_iterator_fn):
       # type: (Iterable[T], Callable[[], Iterable[T]]) -> None
       self.head = head
@@ -1300,13 +1303,6 @@ class GlobalCachingStateHandler(CachingStateHandler):
       for item in self.continue_iterator_fn():
         yield item
 
-    def get_referents_for_cache(self):
-      # type: () -> List[Any]
-      # Only capture the size of the elements and not the
-      # continuation iterator since it references objects
-      # we don't want to include in the cache measurement.
-      return [self.head]
-
   @staticmethod
   def _convert_to_cache_key(state_key):
     # type: (beam_fn_api_pb2.StateKey) -> bytes
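
The SdkWorker change above threads cache metrics through a zero-argument
callable instead of a direct StateCache reference; note the default of
list, a cheap callable that returns an empty iterable. A miniature of
the pattern with hypothetical names (not Beam API):

    def bundle_metrics(base_metrics, state_cache_metrics_fn=list):
        # Mirrors SdkWorker: per-bundle metrics are extended with whatever
        # the injected callable reports; the default contributes nothing.
        monitoring_infos = list(base_metrics)
        monitoring_infos.extend(state_cache_metrics_fn())
        return monitoring_infos

    assert bundle_metrics(['bundle']) == ['bundle']
    assert bundle_metrics(['bundle'], lambda: ['cache']) == ['bundle', 'cache']
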
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
index ec90a30ceca..53cdbad5d71 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -219,8 +219,8 @@ def _get_state_cache_size(experiments):
   future releases.
 
   Returns:
-    an int indicating the maximum number of megabytes to cache.
-      Default is 0 MB
+    an int indicating the maximum number of items to cache.
+      Default is 0 (disabled)
   """
 
   for experiment in experiments:
@@ -228,7 +228,7 @@ def _get_state_cache_size(experiments):
     if re.match(r'state_cache_size=', experiment):
       return int(
           re.match(r'state_cache_size=(?P<state_cache_size>.*)',
-                   experiment).group('state_cache_size')) << 20
+                   experiment).group('state_cache_size'))
   return 0
 
 
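With the megabyte shift removed, the state_cache_size experiment value
is now used directly as an entry count. The restored parsing logic,
restated as a standalone snippet:

    import re

    def get_state_cache_size(experiments):
        for experiment in experiments:
            match = re.match(
                r'state_cache_size=(?P<state_cache_size>.*)', experiment)
            if match:
                return int(match.group('state_cache_size'))
        return 0  # cache disabled unless the experiment is set

    assert get_state_cache_size(['state_cache_size=123']) == 123
    assert get_state_cache_size(['some_other_experiment=1']) == 0
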
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index 05263aee96c..d7309c149e4 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -294,7 +294,7 @@ class CachingStateHandlerTest(unittest.TestCase):
         yield
 
     underlying_state = FakeUnderlyingState()
-    state_cache = statecache.StateCache(100 << 20)
+    state_cache = statecache.StateCache(100)
     caching_state_hander = GlobalCachingStateHandler(
         state_cache, underlying_state)
 
@@ -430,7 +430,7 @@ class CachingStateHandlerTest(unittest.TestCase):
     coder = VarIntCoder()
 
     underlying_state_handler = self.UnderlyingStateHandler()
-    state_cache = statecache.StateCache(100 << 20)
+    state_cache = statecache.StateCache(100)
     handler = GlobalCachingStateHandler(state_cache, underlying_state_handler)
 
     def get():
@@ -460,7 +460,7 @@ class CachingStateHandlerTest(unittest.TestCase):
 
   def test_continuation_token(self):
     underlying_state_handler = self.UnderlyingStateHandler()
-    state_cache = statecache.StateCache(100 << 20)
+    state_cache = statecache.StateCache(100)
     handler = GlobalCachingStateHandler(state_cache, underlying_state_handler)
 
     coder = VarIntCoder()
diff --git a/sdks/python/apache_beam/runners/worker/statecache.py b/sdks/python/apache_beam/runners/worker/statecache.py
index 733bcbbf235..cebfbc7a096 100644
--- a/sdks/python/apache_beam/runners/worker/statecache.py
+++ b/sdks/python/apache_beam/runners/worker/statecache.py
@@ -20,179 +20,245 @@
 # mypy: disallow-untyped-defs
 
 import collections
-import gc
 import logging
 import threading
+from typing import TYPE_CHECKING
 from typing import Any
+from typing import Callable
+from typing import Generic
+from typing import Hashable
 from typing import List
 from typing import Optional
+from typing import Set
 from typing import Tuple
+from typing import TypeVar
 
-import objsize
+from apache_beam.metrics import monitoring_infos
 
-_LOGGER = logging.getLogger(__name__)
+if TYPE_CHECKING:
+  from apache_beam.portability.api import metrics_pb2
 
+_LOGGER = logging.getLogger(__name__)
 
-class WeightedValue(object):
-  """Value type that stores corresponding weight.
+CallableT = TypeVar('CallableT', bound='Callable')
+KT = TypeVar('KT')
+VT = TypeVar('VT')
 
-  :arg value The value to be stored.
-  :arg weight The associated weight of the value. If unspecified, the objects
-  size will be used.
-  """
-  def __init__(self, value, weight):
-    # type: (Any, int) -> None
-    self._value = value
-    if weight <= 0:
-      raise ValueError(
-          'Expected weight to be > 0 for %s but received %d' % (value, weight))
-    self._weight = weight
-
-  def weight(self):
-    # type: () -> int
-    return self._weight
 
-  def value(self):
-    # type: () -> Any
-    return self._value
+class Metrics(object):
+  """Metrics container for state cache metrics."""
 
+  # A set of all registered metrics
+  ALL_METRICS = set()  # type: Set[Hashable]
+  PREFIX = "beam:metric:statecache:"
 
-class CacheAware(object):
   def __init__(self):
     # type: () -> None
-    pass
+    self._context = threading.local()
 
-  def get_referents_for_cache(self):
-    # type: () -> List[Any]
+  def initialize(self):
+    # type: () -> None
 
-    """Returns the list of objects accounted during cache measurement."""
-    raise NotImplementedError()
+    """Needs to be called once per thread to initialize the local metrics cache.
+    """
+    if hasattr(self._context, 'metrics'):
+      return  # Already initialized
+    self._context.metrics = collections.defaultdict(int)
 
+  def count(self, name):
+    # type: (str) -> None
+    self._context.metrics[name] += 1
 
-def get_referents_for_cache(*objs):
-  # type: (List[Any]) -> List[Any]
+  def hit_miss(self, total_name, hit_miss_name):
+    # type: (str, str) -> None
+    self._context.metrics[total_name] += 1
+    self._context.metrics[hit_miss_name] += 1
 
-  """Returns the list of objects accounted during cache measurement.
+  def get_monitoring_infos(self, cache_size, cache_capacity):
+    # type: (int, int) -> List[metrics_pb2.MonitoringInfo]
 
-  Users can inherit CacheAware to override which referrents should be
-  used when measuring the deep size of the object. The default is to
-  use gc.get_referents(*objs).
-  """
-  rval = []
-  for obj in objs:
-    if isinstance(obj, CacheAware):
-      rval.extend(obj.get_referents_for_cache())
-    else:
-      rval.extend(gc.get_referents(obj))
-  return rval
+    """Returns the metrics scoped to the current bundle."""
+    metrics = self._context.metrics
+    if len(metrics) == 0:
+      # No metrics collected, do not report
+      return []
+    # Add all missing metrics which were not reported
+    for key in Metrics.ALL_METRICS:
+      if key not in metrics:
+        metrics[key] = 0
+    # Gauges which reflect the state since last queried
+    gauges = [
+        monitoring_infos.int64_gauge(self.PREFIX + name, val) for name,
+        val in metrics.items()
+    ]
+    gauges.append(
+        monitoring_infos.int64_gauge(self.PREFIX + 'size', cache_size))
+    gauges.append(
+        monitoring_infos.int64_gauge(self.PREFIX + 'capacity', cache_capacity))
+    # Counters for the summary across all metrics
+    counters = [
+        monitoring_infos.int64_counter(self.PREFIX + name + '_total', val)
+        for name,
+        val in metrics.items()
+    ]
+    # Reinitialize metrics for this thread/bundle
+    metrics.clear()
+    return gauges + counters
+
+  @staticmethod
+  def counter_hit_miss(total_name, hit_name, miss_name):
+    # type: (str, str, str) -> Callable[[CallableT], CallableT]
+
+    """Decorator for counting function calls and whether
+       the return value equals None (=miss) or not (=hit)."""
+    Metrics.ALL_METRICS.update([total_name, hit_name, miss_name])
+
+    def decorator(function):
+      # type: (CallableT) -> CallableT
+      def reporter(self, *args, **kwargs):
+        # type: (StateCache, Any, Any) -> Any
+        value = function(self, *args, **kwargs)
+        if value is None:
+          self._metrics.hit_miss(total_name, miss_name)
+        else:
+          self._metrics.hit_miss(total_name, hit_name)
+        return value
+
+      return reporter  # type: ignore[return-value]
+
+    return decorator
+
+  @staticmethod
+  def counter(metric_name):
+    # type: (str) -> Callable[[CallableT], CallableT]
+
+    """Decorator for counting function calls."""
+    Metrics.ALL_METRICS.add(metric_name)
+
+    def decorator(function):
+      # type: (CallableT) -> CallableT
+      def reporter(self, *args, **kwargs):
+        # type: (StateCache, Any, Any) -> Any
+        self._metrics.count(metric_name)
+        return function(self, *args, **kwargs)
+
+      return reporter  # type: ignore[return-value]
+
+    return decorator
 
 
 class StateCache(object):
-  """Cache for Beam state access, scoped by state key and cache_token.
-     Assumes a bag state implementation.
+  """ Cache for Beam state access, scoped by state key and cache_token.
+      Assumes a bag state implementation.
 
-  For a given state_key and cache_token, caches a value and allows to
+  For a given state_key, caches a (cache_token, value) tuple and allows to
     a) read from the cache (get),
            if the currently stored cache_token matches the provided
     b) write to the cache (put),
            storing the new value alongside with a cache token
+    c) append to the currently cached item (extend),
+           if the currently stored cache_token matches the provided
-    c) empty a cached element (clear),
+    d) empty a cached element (clear),
            if the currently stored cache_token matches the provided
-    d) invalidate a cached element (invalidate)
-    e) invalidate all cached elements (invalidate_all)
+    e) evict a cached element (evict)
 
   The operations on the cache are thread-safe for use by multiple workers.
 
-  :arg max_weight The maximum weight of entries to store in the cache in bytes.
+  :arg max_entries The maximum number of entries to store in the cache.
+  TODO Memory-based caching: https://github.com/apache/beam/issues/19857
   """
-  def __init__(self, max_weight):
+  def __init__(self, max_entries):
     # type: (int) -> None
-    _LOGGER.info('Creating state cache with size %s', max_weight)
-    self._max_weight = max_weight
-    self._current_weight = 0
-    self._cache = collections.OrderedDict(
-    )  # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue]
-    self._hit_count = 0
-    self._miss_count = 0
-    self._evict_count = 0
+    _LOGGER.info('Creating state cache with size %s', max_entries)
+    self._missing = None
+    self._cache = self.LRUCache[Tuple[bytes, Optional[bytes]],
+                                Any](max_entries, self._missing)
     self._lock = threading.RLock()
+    self._metrics = Metrics()
 
+  @Metrics.counter_hit_miss("get", "hit", "miss")
   def get(self, state_key, cache_token):
     # type: (bytes, Optional[bytes]) -> Any
     assert cache_token and self.is_cache_enabled()
-    key = (state_key, cache_token)
     with self._lock:
-      value = self._cache.get(key, None)
-      if value is None:
-        self._miss_count += 1
-        return None
-      self._cache.move_to_end(key)
-      self._hit_count += 1
-      return value.value()
+      return self._cache.get((state_key, cache_token))
 
+  @Metrics.counter("put")
   def put(self, state_key, cache_token, value):
     # type: (bytes, Optional[bytes], Any) -> None
     assert cache_token and self.is_cache_enabled()
-    if not isinstance(value, WeightedValue):
-      weight = objsize.get_deep_size(
-          value, get_referents_func=get_referents_for_cache)
-      if weight <= 0:
-        _LOGGER.warning(
-            'Expected object size to be >= 0 for %s but received %d.',
-            value,
-            weight)
-        weight = 8
-      value = WeightedValue(value, weight)
-    key = (state_key, cache_token)
     with self._lock:
-      old_value = self._cache.pop(key, None)
-      if old_value is not None:
-        self._current_weight -= old_value.weight()
-      self._cache[(state_key, cache_token)] = value
-      self._current_weight += value.weight()
-      while self._current_weight > self._max_weight:
-        (_, weighted_value) = self._cache.popitem(last=False)
-        self._current_weight -= weighted_value.weight()
-        self._evict_count += 1
+      return self._cache.put((state_key, cache_token), value)
 
+  @Metrics.counter("clear")
   def clear(self, state_key, cache_token):
     # type: (bytes, Optional[bytes]) -> None
-    self.put(state_key, cache_token, [])
+    assert cache_token and self.is_cache_enabled()
+    with self._lock:
+      self._cache.put((state_key, cache_token), [])
 
-  def invalidate(self, state_key, cache_token):
+  @Metrics.counter("evict")
+  def evict(self, state_key, cache_token):
     # type: (bytes, Optional[bytes]) -> None
     assert self.is_cache_enabled()
     with self._lock:
-      weighted_value = self._cache.pop((state_key, cache_token), None)
-      if weighted_value is not None:
-        self._current_weight -= weighted_value.weight()
+      self._cache.evict((state_key, cache_token))
 
-  def invalidate_all(self):
+  def evict_all(self):
     # type: () -> None
     with self._lock:
-      self._cache.clear()
-      self._current_weight = 0
+      self._cache.evict_all()
 
-  def describe_stats(self):
-    # type: () -> str
-    with self._lock:
-      request_count = self._hit_count + self._miss_count
-      if request_count > 0:
-        hit_ratio = 100.0 * self._hit_count / request_count
-      else:
-        hit_ratio = 100.0
-      return 'used/max %d/%d MB, hit %.2f%%, lookups %d, evictions %d' % (
-          self._current_weight >> 20,
-          self._max_weight >> 20,
-          hit_ratio,
-          request_count,
-          self._evict_count)
+  def initialize_metrics(self):
+    # type: () -> None
+    self._metrics.initialize()
 
   def is_cache_enabled(self):
     # type: () -> bool
-    return self._max_weight > 0
+    return self._cache._max_entries > 0
 
   def size(self):
     # type: () -> int
+    return len(self._cache)
+
+  def get_monitoring_infos(self):
+    # type: () -> List[metrics_pb2.MonitoringInfo]
+
+    """Retrieves the monitoring infos and resets the counters."""
     with self._lock:
+      size = len(self._cache)
+    capacity = self._cache._max_entries
+    return self._metrics.get_monitoring_infos(size, capacity)
+
+  class LRUCache(Generic[KT, VT]):
+    def __init__(self, max_entries, default_entry):
+      # type: (int, VT) -> None
+      self._max_entries = max_entries
+      self._default_entry = default_entry
+      self._cache = collections.OrderedDict(
+      )  # type: collections.OrderedDict[KT, VT]
+
+    def get(self, key):
+      # type: (KT) -> VT
+      value = self._cache.pop(key, self._default_entry)
+      if value != self._default_entry:
+        self._cache[key] = value
+      return value
+
+    def put(self, key, value):
+      # type: (KT, VT) -> None
+      self._cache[key] = value
+      while len(self._cache) > self._max_entries:
+        self._cache.popitem(last=False)
+
+    def evict(self, key):
+      # type: (KT) -> None
+      self._cache.pop(key, self._default_entry)
+
+    def evict_all(self):
+      # type: () -> None
+      self._cache.clear()
+
+    def __len__(self):
+      # type: () -> int
       return len(self._cache)
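
The restored LRUCache is an OrderedDict where get() pops and reinserts a
present key to mark it most recently used, and put() trims from the
least recently used end. A standalone sketch of the same mechanism
(StateCache.LRUCache itself is a nested class, so this restates it for
illustration):

    import collections

    class LRUCache:
        def __init__(self, max_entries, default_entry):
            self._max_entries = max_entries
            self._default_entry = default_entry
            self._cache = collections.OrderedDict()

        def get(self, key):
            # pop + reinsert moves a present key to the MRU end
            value = self._cache.pop(key, self._default_entry)
            if value != self._default_entry:
                self._cache[key] = value
            return value

        def put(self, key, value):
            self._cache[key] = value
            while len(self._cache) > self._max_entries:
                self._cache.popitem(last=False)  # drop the LRU entry

    cache = LRUCache(2, None)
    cache.put('a', 1)
    cache.put('b', 2)
    cache.get('a')     # 'a' becomes most recently used
    cache.put('c', 3)  # over capacity: 'b' is evicted, 'a' survives
    assert cache.get('b') is None and cache.get('a') == 1
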
diff --git a/sdks/python/apache_beam/runners/worker/statecache_test.py b/sdks/python/apache_beam/runners/worker/statecache_test.py
index 9bd952721ea..a1a175ed347 100644
--- a/sdks/python/apache_beam/runners/worker/statecache_test.py
+++ b/sdks/python/apache_beam/runners/worker/statecache_test.py
@@ -21,38 +21,55 @@
 import logging
 import unittest
 
-from apache_beam.runners.worker.statecache import CacheAware
+from apache_beam.metrics import monitoring_infos
 from apache_beam.runners.worker.statecache import StateCache
-from apache_beam.runners.worker.statecache import WeightedValue
 
 
 class StateCacheTest(unittest.TestCase):
   def test_empty_cache_get(self):
-    cache = StateCache(5 << 20)
+    cache = self.get_cache(5)
     self.assertEqual(cache.get("key", 'cache_token'), None)
     with self.assertRaises(Exception):
       # Invalid cache token provided
       self.assertEqual(cache.get("key", None), None)
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 0/5 MB, hit 0.00%, lookups 1, evictions 0')
+    self.verify_metrics(
+        cache,
+        {
+            'get': 1,
+            'put': 0,
+            'miss': 1,
+            'hit': 0,
+            'clear': 0,
+            'evict': 0,
+            'size': 0,
+            'capacity': 5
+        })
 
   def test_put_get(self):
-    cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
+    cache = self.get_cache(5)
+    cache.put("key", "cache_token", "value")
     self.assertEqual(cache.size(), 1)
     self.assertEqual(cache.get("key", "cache_token"), "value")
     self.assertEqual(cache.get("key", "cache_token2"), None)
     with self.assertRaises(Exception):
       self.assertEqual(cache.get("key", None), None)
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 1/5 MB, hit 50.00%, lookups 2, evictions 0')
+    self.verify_metrics(
+        cache,
+        {
+            'get': 2,
+            'put': 1,
+            'miss': 1,
+            'hit': 1,
+            'clear': 0,
+            'evict': 0,
+            'size': 1,
+            'capacity': 5
+        })
 
   def test_clear(self):
-    cache = StateCache(5 << 20)
+    cache = self.get_cache(5)
     cache.clear("new-key", "cache_token")
-    cache.put("key", "cache_token", WeightedValue(["value"], 1 << 20))
+    cache.put("key", "cache_token", ["value"])
     self.assertEqual(cache.size(), 2)
     self.assertEqual(cache.get("new-key", "new_token"), None)
     self.assertEqual(cache.get("key", "cache_token"), ['value'])
@@ -60,58 +77,72 @@ class StateCacheTest(unittest.TestCase):
     cache.clear("non-existing", "token")
     self.assertEqual(cache.size(), 3)
     self.assertEqual(cache.get("non-existing", "token"), [])
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 1/5 MB, hit 66.67%, lookups 3, evictions 0')
-
-  def test_default_sized_put(self):
-    cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", bytearray(1 << 20))
-    cache.put("key2", "cache_token", bytearray(1 << 20))
-    cache.put("key3", "cache_token", bytearray(1 << 20))
-    self.assertEqual(cache.get("key3", "cache_token"), bytearray(1 << 20))
-    cache.put("key4", "cache_token", bytearray(1 << 20))
-    cache.put("key5", "cache_token", bytearray(1 << 20))
-    # note that each byte array instance takes slightly over 1 MB which is why
-    # these 5 byte arrays can't all be stored in the cache causing a single
-    # eviction
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 4/5 MB, hit 100.00%, lookups 1, evictions 1')
+    self.verify_metrics(
+        cache,
+        {
+            'get': 3,
+            'put': 1,
+            'miss': 1,
+            'hit': 2,
+            'clear': 2,
+            'evict': 0,
+            'size': 3,
+            'capacity': 5
+        })
 
   def test_max_size(self):
-    cache = StateCache(2 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
-    cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
+    cache = self.get_cache(2)
+    cache.put("key", "cache_token", "value")
+    cache.put("key2", "cache_token", "value")
     self.assertEqual(cache.size(), 2)
-    cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
+    cache.put("key2", "cache_token", "value")
     self.assertEqual(cache.size(), 2)
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 2/2 MB, hit 100.00%, lookups 0, evictions 1')
-
-  def test_invalidate_all(self):
-    cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
-    cache.put("key2", "cache_token", WeightedValue("value2", 1 << 20))
+    cache.put("key", "cache_token", "value")
+    self.assertEqual(cache.size(), 2)
+    self.verify_metrics(
+        cache,
+        {
+            'get': 0,
+            'put': 4,
+            'miss': 0,
+            'hit': 0,
+            'clear': 0,
+            'evict': 0,
+            'size': 2,
+            'capacity': 2
+        })
+
+  def test_evict_all(self):
+    cache = self.get_cache(5)
+    cache.put("key", "cache_token", "value")
+    cache.put("key2", "cache_token", "value2")
     self.assertEqual(cache.size(), 2)
-    cache.invalidate_all()
+    cache.evict_all()
     self.assertEqual(cache.size(), 0)
     self.assertEqual(cache.get("key", "cache_token"), None)
     self.assertEqual(cache.get("key2", "cache_token"), None)
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 0/5 MB, hit 0.00%, lookups 2, evictions 0')
+    self.verify_metrics(
+        cache,
+        {
+            'get': 2,
+            'put': 2,
+            'miss': 2,
+            'hit': 0,
+            'clear': 0,
+            'evict': 0,
+            'size': 0,
+            'capacity': 5
+        })
 
   def test_lru(self):
-    cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", WeightedValue("value", 1 << 20))
-    cache.put("key2", "cache_token2", WeightedValue("value2", 1 << 20))
-    cache.put("key3", "cache_token", WeightedValue("value0", 1 << 20))
-    cache.put("key3", "cache_token", WeightedValue("value3", 1 << 20))
-    cache.put("key4", "cache_token4", WeightedValue("value4", 1 << 20))
-    cache.put("key5", "cache_token", WeightedValue("value0", 1 << 20))
-    cache.put("key5", "cache_token", WeightedValue(["value5"], 1 << 20))
+    cache = self.get_cache(5)
+    cache.put("key", "cache_token", "value")
+    cache.put("key2", "cache_token2", "value2")
+    cache.put("key3", "cache_token", "value0")
+    cache.put("key3", "cache_token", "value3")
+    cache.put("key4", "cache_token4", "value4")
+    cache.put("key5", "cache_token", "value0")
+    cache.put("key5", "cache_token", ["value5"])
     self.assertEqual(cache.size(), 5)
     self.assertEqual(cache.get("key", "cache_token"), "value")
     self.assertEqual(cache.get("key2", "cache_token2"), "value2")
@@ -119,59 +150,80 @@ class StateCacheTest(unittest.TestCase):
     self.assertEqual(cache.get("key4", "cache_token4"), "value4")
     self.assertEqual(cache.get("key5", "cache_token"), ["value5"])
     # insert another key to trigger cache eviction
-    cache.put("key6", "cache_token2", WeightedValue("value6", 1 << 20))
+    cache.put("key6", "cache_token2", "value7")
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key")
     self.assertEqual(cache.get("key", "cache_token"), None)
     # trigger a read on "key2"
     cache.get("key2", "cache_token2")
     # insert another key to trigger cache eviction
-    cache.put("key7", "cache_token", WeightedValue("value7", 1 << 20))
+    cache.put("key7", "cache_token", "value7")
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key3")
     self.assertEqual(cache.get("key3", "cache_token"), None)
     # trigger a put on "key2"
-    cache.put("key2", "cache_token", WeightedValue("put", 1 << 20))
+    cache.put("key2", "cache_token", "put")
     self.assertEqual(cache.size(), 5)
     # insert another key to trigger cache eviction
-    cache.put("key8", "cache_token", WeightedValue("value8", 1 << 20))
+    cache.put("key8", "cache_token", "value8")
     self.assertEqual(cache.size(), 5)
     # least recently used key should be gone ("key4")
     self.assertEqual(cache.get("key4", "cache_token"), None)
     # make "key5" used by writing to it
-    cache.put("key5", "cache_token", WeightedValue("val", 1 << 20))
+    cache.put("key5", "cache_token", "val")
     # least recently used key should be gone ("key6")
     self.assertEqual(cache.get("key6", "cache_token"), None)
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 5/5 MB, hit 60.00%, lookups 10, evictions 5')
+    self.verify_metrics(
+        cache,
+        {
+            'get': 10,
+            'put': 12,
+            'miss': 4,
+            'hit': 6,
+            'clear': 0,
+            'evict': 0,
+            'size': 5,
+            'capacity': 5
+        })
 
   def test_is_cached_enabled(self):
-    cache = StateCache(1 << 20)
+    cache = self.get_cache(1)
     self.assertEqual(cache.is_cache_enabled(), True)
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 0/1 MB, hit 100.00%, lookups 0, evictions 0')
-    cache = StateCache(0)
+    self.verify_metrics(cache, {})
+    cache = self.get_cache(0)
     self.assertEqual(cache.is_cache_enabled(), False)
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 0/0 MB, hit 100.00%, lookups 0, evictions 0')
-
-  def test_get_referents_for_cache(self):
-    class GetReferentsForCache(CacheAware):
-      def __init__(self):
-        self.measure_me = bytearray(1 << 20)
-        self.ignore_me = bytearray(2 << 20)
-
-      def get_referents_for_cache(self):
-        return [self.measure_me]
-
-    cache = StateCache(5 << 20)
-    cache.put("key", "cache_token", GetReferentsForCache())
-    self.assertEqual(
-        cache.describe_stats(),
-        'used/max 1/5 MB, hit 100.00%, lookups 0, evictions 0')
+    self.verify_metrics(cache, {})
+
+  def verify_metrics(self, cache, expected_metrics):
+    infos = cache.get_monitoring_infos()
+    # Reconstruct metrics dictionary from monitoring infos
+    metrics = {
+        info.urn.rsplit(':',
+                        1)[1]: monitoring_infos.extract_gauge_value(info)[1]
+        for info in infos if "_total" not in info.urn and
+        info.type == monitoring_infos.LATEST_INT64_TYPE
+    }
+    self.assertDictEqual(metrics, expected_metrics)
+    # Metrics and total metrics should be identical for a single bundle.
+    # The following two gauges are not part of the total metrics:
+    try:
+      del metrics['capacity']
+      del metrics['size']
+    except KeyError:
+      pass
+    total_metrics = {
+        info.urn.rsplit(':', 1)[1].rsplit("_total")[0]:
+        monitoring_infos.extract_counter_value(info)
+        for info in infos
+        if "_total" in info.urn and info.type == monitoring_infos.SUM_INT64_TYPE
+    }
+    self.assertDictEqual(metrics, total_metrics)
+
+  @staticmethod
+  def get_cache(size):
+    cache = StateCache(size)
+    cache.initialize_metrics()
+    return cache
 
 
 if __name__ == '__main__':
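
The urn handling in verify_metrics above recovers short metric names by
plain string splitting, given the beam:metric:statecache: prefix defined
in statecache.py. Worked through for one counter urn:

    urn = 'beam:metric:statecache:get_total'
    name = urn.rsplit(':', 1)[1]      # 'get_total' (the raw metric name)
    base = name.rsplit('_total')[0]   # 'get' (the matching gauge name)
    assert (name, base) == ('get_total', 'get')
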
diff --git a/sdks/python/apache_beam/runners/worker/worker_status.py b/sdks/python/apache_beam/runners/worker/worker_status.py
index 7604bd0867a..652b01d1e4e 100644
--- a/sdks/python/apache_beam/runners/worker/worker_status.py
+++ b/sdks/python/apache_beam/runners/worker/worker_status.py
@@ -30,7 +30,6 @@ import grpc
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_fn_api_pb2_grpc
 from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
-from apache_beam.runners.worker.statecache import StateCache
 from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
 from apache_beam.utils.sentinel import Sentinel
 
@@ -96,18 +95,6 @@ def heap_dump():
   return banner + heap + ending
 
 
-def _state_cache_stats(state_cache):
-  #type: (StateCache) -> str
-
-  """Gather state cache statistics."""
-  cache_stats = ['=' * 10 + ' CACHE STATS ' + '=' * 10]
-  if not state_cache.is_cache_enabled():
-    cache_stats.append("Cache disabled")
-  else:
-    cache_stats.append(state_cache.describe_stats())
-  return '\n'.join(cache_stats)
-
-
 def _active_processing_bundles_state(bundle_process_cache):
   """Gather information about the currently in-processing active bundles.
 
@@ -151,7 +138,6 @@ class FnApiWorkerStatusHandler(object):
       self,
       status_address,
       bundle_process_cache=None,
-      state_cache=None,
       enable_heap_dump=False,
       log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS):
     """Initialize FnApiWorkerStatusHandler.
@@ -159,11 +145,9 @@ class FnApiWorkerStatusHandler(object):
     Args:
       status_address: The URL Runner uses to host the WorkerStatus server.
       bundle_process_cache: The BundleProcessor cache dict from sdk worker.
-      state_cache: The StateCache form sdk worker.
     """
     self._alive = True
     self._bundle_process_cache = bundle_process_cache
-    self._state_cache = state_cache
     ch = GRPCChannelFactory.insecure_channel(status_address)
     grpc.channel_ready_future(ch).result(timeout=60)
     self._status_channel = grpc.intercept_channel(ch, WorkerIdInterceptor())
@@ -209,14 +193,9 @@ class FnApiWorkerStatusHandler(object):
                   "status page: %s" % traceback_string))
 
   def generate_status_response(self):
-    all_status_sections = []
-
-    if self._state_cache:
-      all_status_sections.append(_state_cache_stats(self._state_cache))
-
-    if self._bundle_process_cache:
-      all_status_sections.append(
-          _active_processing_bundles_state(self._bundle_process_cache))
+    all_status_sections = [
+        _active_processing_bundles_state(self._bundle_process_cache)
+    ] if self._bundle_process_cache else []
 
     all_status_sections.append(thread_dump())
     if self._enable_heap_dump:
diff --git a/sdks/python/container/py37/base_image_requirements.txt b/sdks/python/container/py37/base_image_requirements.txt
index 898d4d8dab3..392a06f39d0 100644
--- a/sdks/python/container/py37/base_image_requirements.txt
+++ b/sdks/python/container/py37/base_image_requirements.txt
@@ -96,7 +96,6 @@ nose==1.3.7
 numpy==1.21.6
 oauth2client==4.1.3
 oauthlib==3.2.0
-objsize==0.5.1
 opt-einsum==3.3.0
 orjson==3.7.11
 overrides==6.1.0
diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt
index 1539b561487..24a81d1dcaf 100644
--- a/sdks/python/container/py38/base_image_requirements.txt
+++ b/sdks/python/container/py38/base_image_requirements.txt
@@ -96,7 +96,6 @@ nose==1.3.7
 numpy==1.22.4
 oauth2client==4.1.3
 oauthlib==3.2.0
-objsize==0.5.1
 opt-einsum==3.3.0
 orjson==3.7.11
 overrides==6.1.0
diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt
index 2861b6ddb96..67de8d2b356 100644
--- a/sdks/python/container/py39/base_image_requirements.txt
+++ b/sdks/python/container/py39/base_image_requirements.txt
@@ -96,7 +96,6 @@ nose==1.3.7
 numpy==1.22.4
 oauth2client==4.1.3
 oauthlib==3.2.0
-objsize==0.5.1
 opt-einsum==3.3.0
 orjson==3.7.11
 overrides==6.1.0
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index fa02a658612..f0863bf49cd 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -223,7 +223,6 @@ if __name__ == '__main__':
         'hdfs>=2.1.0,<3.0.0',
         'httplib2>=0.8,<0.21.0',
         'numpy>=1.14.3,<1.23.0',
-        'objsize>=0.5.1,<1',
         'pymongo>=3.8.0,<4.0.0',
         'protobuf>=3.12.2,<4',
         'proto-plus>=1.7.1,<2',