You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by da...@apache.org on 2023/08/11 22:49:25 UTC

[beam] branch master updated: Add ability to run per key inference (#27857)

This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a5f1347699a Add ability to run per key inference (#27857)
a5f1347699a is described below

commit a5f1347699a3ed142d7d066922c3d1002f0b0f31
Author: Danny McCormick <da...@google.com>
AuthorDate: Fri Aug 11 18:49:17 2023 -0400

    Add ability to run per key inference (#27857)
    
    * Add ability to run per key inference
    
    * lint
    
    * lint
    
    * address feedback
    
    * lint
    
    * Small feedback updates
---
 sdks/python/apache_beam/ml/inference/base.py      | 348 +++++++++++++++-------
 sdks/python/apache_beam/ml/inference/base_test.py |  65 ++++
 2 files changed, 308 insertions(+), 105 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 8acdbaa5da1..5f2b4dc465f 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -35,11 +35,13 @@ import threading
 import time
 import uuid
 from collections import OrderedDict
+from collections import defaultdict
 from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Generic
 from typing import Iterable
+from typing import List
 from typing import Mapping
 from typing import NamedTuple
 from typing import Optional
@@ -243,70 +245,278 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     return False
 
 
+class _ModelManager:
+  """
+  A class for efficiently managing copies of multiple models. Will load a
+  single copy of each model into a multi_process_shared object and then
+  return a lookup key for that object. Optionally takes in a max_models
+  parameter, if that is set it will only hold that many models in memory at
+  once before evicting one (using LRU logic).
+  """
+  def __init__(
+      self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+    """
+    Args:
+      mh_map: A map from keys to model handlers which can be used to load a
+        model.
+      max_models: The maximum number of models to load at any given time
+        before evicting 1 from memory (using LRU logic). Leave as None to
+        allow unlimited models.
+    """
+    self._max_models = max_models
+    self._mh_map: Dict[str, ModelHandler] = mh_map
+    self._proxy_map: Dict[
+        str, Tuple[multi_process_shared.MultiProcessShared, Any]] = {}
+    self._tag_map: Dict[str, str] = OrderedDict()
+
+  def load(self, key: str) -> str:
+    """
+    Loads the appropriate model for the given key into memory.
+    Args:
+      key: the key associated with the model we'd like to load.
+    Returns:
+      the tag we can use to access the model using multi_process_shared.py.
+    """
+    # Map the key for a model to a unique tag that will persist until the model
+    # is released. This needs to be unique between releasing/reacquiring the
+    # model because otherwise the ProxyManager will try to reuse the model that
+    # has been released and deleted.
+    if key in self._tag_map:
+      self._tag_map.move_to_end(key)
+    else:
+      self._tag_map[key] = uuid.uuid4().hex
+
+    tag = self._tag_map[key]
+    mh = self._mh_map[key]
+
+    if self._max_models is not None and self._max_models < len(self._tag_map):
+      # If we've exceeded our LRU size, release the least recently used model.
+      tag_to_remove = self._tag_map.popitem(last=False)[1]
+      shared_handle, model_to_remove = self._proxy_map.pop(tag_to_remove)
+      shared_handle.release(model_to_remove)
+
+    # Load the new model
+    shared_handle = multi_process_shared.MultiProcessShared(
+        mh.load_model, tag=tag)
+    model_reference = shared_handle.acquire()
+    self._proxy_map[tag] = (shared_handle, model_reference)
+
+    return tag
+
+  def increment_max_models(self, increment: int):
+    """
+    Increments the number of models that this instance of a _ModelManager is
+    able to hold.
+    Args:
+      increment: the amount by which we are incrementing the number of models.
+    """
+    if self._max_models is None:
+      raise ValueError(
+          "Cannot increment max_models if self._max_models is None (unlimited" +
+          " models mode).")
+    self._max_models += increment
+
+
+# Use a plain class instead of a NamedTuple because NamedTuples and generics
+# don't mix well across the board for all versions:
+# https://github.com/python/typing/issues/653
+class KeyMhMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):
+  """
+  Dataclass for mapping 1 or more keys to 1 model handler.
+  Given `KeyMhMapping(['key1', 'key2'], myMh)`, all examples with keys `key1`
+  or `key2` will be run against the model defined by the `myMh` ModelHandler.
+  """
+  def __init__(
+      self, keys: List[KeyT], mh: ModelHandler[ExampleT, PredictionT, ModelT]):
+    self.keys = keys
+    self.mh = mh
+
+
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],
                                      Tuple[KeyT, PredictionT],
-                                     ModelT]):
-  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+                                     Union[ModelT, _ModelManager]]):
+  def __init__(
+      self,
+      unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
+                     List[KeyMhMapping[KeyT, ExampleT, PredictionT, ModelT]]]):
     """A ModelHandler that takes keyed examples and returns keyed predictions.
 
     For example, if the original model is used with RunInference to take a
     PCollection[E] to a PCollection[P], this ModelHandler would take a
     PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], making it possible
-    to use the key to associate the outputs with the inputs.
+    to use the key to associate the outputs with the inputs. KeyedModelHandler
+    is able to accept either a single unkeyed ModelHandler or many different
+    model handlers corresponding to the keys for which that ModelHandler should
+    be used. For example, the following configuration could be used to map keys
+    1-3 to ModelHandler1 and keys 4-5 to ModelHandler2:
+
+        k1 = ['k1', 'k2', 'k3']
+        k2 = ['k4', 'k5']
+        KeyedModelHandler([KeyMhMapping(k1, mh1), KeyMhMapping(k2, mh2)])
+
+    Note that a single copy of each of these models may all be held in memory
+    at the same time; be careful not to load too many large models or your
+    pipeline may cause Out of Memory exceptions.
 
     Args:
-      unkeyed: An implementation of ModelHandler that does not require keys.
+      unkeyed: Either (a) an implementation of ModelHandler that does not
+        require keys or (b) a list of KeyMhMappings mapping lists of keys to
+        unkeyed ModelHandlers.
     """
-    if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
-      raise Exception(
-          'Cannot make make an unkeyed model handler with pre or '
-          'postprocessing functions defined into a keyed model handler. All '
-          'pre/postprocessing functions must be defined on the outer model'
-          'handler.')
-    self._unkeyed = unkeyed
-    self._env_vars = unkeyed._env_vars
-
-  def load_model(self) -> ModelT:
-    return self._unkeyed.load_model()
+    self._single_model = not isinstance(unkeyed, list)
+    if self._single_model:
+      if len(unkeyed.get_preprocess_fns()) or len(
+          unkeyed.get_postprocess_fns()):
+        raise Exception(
+            'Cannot make an unkeyed model handler with pre or '
+            'postprocessing functions defined into a keyed model handler. All '
+            'pre/postprocessing functions must be defined on the outer model'
+            ' handler.')
+      self._env_vars = unkeyed._env_vars
+      self._unkeyed = unkeyed
+      return
+
+    # To maintain an efficient representation, we will map all keys in a given
+    # KeyMhMapping to a single id (the first key in the KeyMhMapping list).
+    # We will then map that key to a ModelHandler. This will allow us to
+    # quickly look up the appropriate ModelHandler for any given key.
+    self._id_to_mh_map: Dict[str, ModelHandler[ExampleT, PredictionT,
+                                               ModelT]] = {}
+    self._key_to_id_map: Dict[str, str] = {}
+    for mh_tuple in unkeyed:
+      mh = mh_tuple.mh
+      keys = mh_tuple.keys
+      if len(mh.get_preprocess_fns()) or len(mh.get_postprocess_fns()):
+        raise ValueError(
+            'Cannot use an unkeyed model handler with pre or '
+            'postprocessing functions defined in a keyed model handler. All '
+            'pre/postprocessing functions must be defined on the outer model'
+            ' handler.')
+      hints = mh.get_resource_hints()
+      if len(hints) > 0:
+        logging.warning(
+            'mh %s defines the following resource hints, which will be '
+            'ignored: %s. Resource hints are not respected when more than one '
+            'model handler is used in a KeyedModelHandler. If you would like '
+            'to specify resource hints, you can do so by overriding the '
+            'KeyedModelHandler.get_resource_hints() method.',
+            mh,
+            hints)
+      batch_kwargs = mh.batch_elements_kwargs()
+      if batch_kwargs:
+        logging.warning(
+            'mh %s defines the following batching kwargs which will be '
+            'ignored %s. Batching kwargs are not respected when '
+            'more than one model handler is used in a KeyedModelHandler. If '
+            'you would like to specify batching kwargs, you can do so by '
+            'overriding the KeyedModelHandler.batch_elements_kwargs() method.',
+            mh,
+            batch_kwargs)
+      env_vars = mh._env_vars
+      if env_vars:
+        logging.warning(
+            'mh %s defines the following _env_vars which will be ignored %s. '
+            '_env_vars are not respected when more than one model handler is '
+            'used in a KeyedModelHandler. If you need env vars set at '
+            'inference time, you can do so with '
+            'a custom inference function.',
+            mh,
+            env_vars)
+
+      if len(keys) == 0:
+        raise ValueError(
+            f'Empty list maps to model handler {mh}. All model handlers must '
+            'have one or more associated keys.')
+      self._id_to_mh_map[keys[0]] = mh
+      for key in keys:
+        if key in self._key_to_id_map:
+          raise ValueError(
+              f'key {key} maps to multiple model handlers. All keys must map '
+              'to exactly one model handler.')
+        self._key_to_id_map[key] = keys[0]
+
+  def load_model(self) -> Union[ModelT, _ModelManager]:
+    if self._single_model:
+      return self._unkeyed.load_model()
+    return _ModelManager(self._id_to_mh_map)
 
   def run_inference(
       self,
       batch: Sequence[Tuple[KeyT, ExampleT]],
-      model: ModelT,
+      model: Union[ModelT, _ModelManager],
       inference_args: Optional[Dict[str, Any]] = None
   ) -> Iterable[Tuple[KeyT, PredictionT]]:
-    keys, unkeyed_batch = zip(*batch)
-    return zip(
-        keys, self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
+    if self._single_model:
+      keys, unkeyed_batch = zip(*batch)
+      return zip(
+          keys,
+          self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
+
+    batch_by_key = defaultdict(list)
+    key_by_id = defaultdict(set)
+    for key, example in batch:
+      batch_by_key[key].append(example)
+      key_by_id[self._key_to_id_map[key]].add(key)
+
+    predictions = []
+    for model_id, keys in key_by_id.items():
+      mh = self._id_to_mh_map[model_id]
+      keyed_model_tag = model.load(model_id)
+      keyed_model_shared_handle = multi_process_shared.MultiProcessShared(
+          mh.load_model, tag=keyed_model_tag)
+      keyed_model = keyed_model_shared_handle.acquire()
+      try:
+        for key in keys:
+          predictions.extend(
+              (key, inf) for inf in mh.run_inference(
+                  batch_by_key[key], keyed_model, inference_args))
+      finally:
+        keyed_model_shared_handle.release(keyed_model)
+
+    return predictions
 
   def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
-    keys, unkeyed_batch = zip(*batch)
-    return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+    if self._single_model:
+      keys, unkeyed_batch = zip(*batch)
+      return len(
+          pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+    return len(pickle.dumps(batch))
 
   def get_metrics_namespace(self) -> str:
-    return self._unkeyed.get_metrics_namespace()
+    if self._single_model:
+      return self._unkeyed.get_metrics_namespace()
+    return 'BeamML_KeyedModels'
 
   def get_resource_hints(self):
-    return self._unkeyed.get_resource_hints()
+    if self._single_model:
+      return self._unkeyed.get_resource_hints()
+    return {}
 
   def batch_elements_kwargs(self):
-    return self._unkeyed.batch_elements_kwargs()
+    if self._single_model:
+      return self._unkeyed.batch_elements_kwargs()
+    return {}
 
   def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
-    return self._unkeyed.validate_inference_args(inference_args)
+    if self._single_model:
+      return self._unkeyed.validate_inference_args(inference_args)
+    for mh in self._id_to_mh_map.values():
+      mh.validate_inference_args(inference_args)
 
   def update_model_path(self, model_path: Optional[str] = None):
-    return self._unkeyed.update_model_path(model_path=model_path)
-
-  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
-    return self._unkeyed.get_preprocess_fns()
-
-  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
-    return self._unkeyed.get_postprocess_fns()
+    if self._single_model:
+      return self._unkeyed.update_model_path(model_path=model_path)
+    if model_path is not None:
+      raise RuntimeError(
+          'Model updates are currently not supported for ' +
+          'KeyedModelHandlers with multiple different per-key ' +
+          'ModelHandlers.')
 
   def share_model_across_processes(self) -> bool:
-    return self._unkeyed.share_model_across_processes()
+    if self._single_model:
+      return self._unkeyed.share_model_across_processes()
+    return True
 
 
 class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
@@ -740,78 +950,6 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
     return self
 
 
-class _ModelManager:
-  """
-  A class for efficiently managing copies of multiple models. Will load a
-  single copy of each model into a multi_process_shared object and then
-  return a lookup key for that object. Optionally takes in a max_models
-  parameter, if that is set it will only hold that many models in memory at
-  once before evicting one (using LRU logic).
-  """
-  def __init__(
-      self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
-    """
-    Args:
-      mh_map: A map from keys to model handlers which can be used to load a
-        model.
-      max_models: The maximum number of models to load at any given time
-        before evicting 1 from memory (using LRU logic). Leave as None to
-        allow unlimited models.
-    """
-    self._max_models = max_models
-    self._mh_map: Dict[str, ModelHandler] = mh_map
-    self._proxy_map: Dict[str, str] = {}
-    self._tag_map: Dict[
-        str, multi_process_shared.MultiProcessShared] = OrderedDict()
-
-  def load(self, key: str) -> str:
-    """
-    Loads the appropriate model for the given key into memory.
-    Args:
-      key: the key associated with the model we'd like to load.
-    Returns:
-      the tag we can use to access the model using multi_process_shared.py.
-    """
-    # Map the key for a model to a unique tag that will persist until the model
-    # is released. This needs to be unique between releasing/reacquiring th
-    # model because otherwise the ProxyManager will try to reuse the model that
-    # has been released and deleted.
-    if key in self._tag_map:
-      self._tag_map.move_to_end(key)
-    else:
-      self._tag_map[key] = uuid.uuid4().hex
-
-    tag = self._tag_map[key]
-    mh = self._mh_map[key]
-
-    if self._max_models is not None and self._max_models < len(self._tag_map):
-      # If we're about to exceed our LRU size, release the last used model.
-      tag_to_remove = self._tag_map.popitem(last=False)[1]
-      shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
-      shared_handle.release(model_to_remove)
-
-    # Load the new model
-    shared_handle = multi_process_shared.MultiProcessShared(
-        mh.load_model, tag=tag)
-    model_reference = shared_handle.acquire()
-    self._proxy_map[tag] = (shared_handle, model_reference)
-
-    return tag
-
-  def increment_max_models(self, increment: int):
-    """
-    Increments the number of models that this instance of a _ModelManager is
-    able to hold.
-    Args:
-      increment: the amount by which we are incrementing the number of models.
-    """
-    if self._max_models is None:
-      raise ValueError(
-          "Cannot increment max_models if self._max_models is None (unlimited" +
-          " models mode).")
-    self._max_models += increment
-
-
 class _MetricsCollector:
   """
   A metrics collector that tracks ML related performance and memory usage.
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index be91efb9479..c79189718a9 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -254,6 +254,71 @@ class RunInferenceBaseTest(unittest.TestCase):
           base.KeyedModelHandler(FakeModelHandler()))
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_run_inference_impl_with_keyed_examples_many_model_handlers(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [(i, example + 1) for i, example in enumerate(examples)]
+      expected[0] = (0, 200)
+      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+      mhs = [
+          base.KeyMhMapping([0],
+                            FakeModelHandler(
+                                state=200, multi_process_shared=True)),
+          base.KeyMhMapping([1, 2, 3],
+                            FakeModelHandler(multi_process_shared=True))
+      ]
+      actual = pcoll | base.RunInference(base.KeyedModelHandler(mhs))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_keyed_many_model_handlers_validation(self):
+    def mult_two(example: str) -> int:
+      return int(example) * 2
+
+    mhs = [
+        base.KeyMhMapping(
+            [0],
+            FakeModelHandler(
+                state=200,
+                multi_process_shared=True).with_preprocess_fn(mult_two)),
+        base.KeyMhMapping([1, 2, 3],
+                          FakeModelHandler(multi_process_shared=True))
+    ]
+    with self.assertRaises(ValueError):
+      base.KeyedModelHandler(mhs)
+
+    mhs = [
+        base.KeyMhMapping(
+            [0],
+            FakeModelHandler(
+                state=200,
+                multi_process_shared=True).with_postprocess_fn(mult_two)),
+        base.KeyMhMapping([1, 2, 3],
+                          FakeModelHandler(multi_process_shared=True))
+    ]
+    with self.assertRaises(ValueError):
+      base.KeyedModelHandler(mhs)
+
+    mhs = [
+        base.KeyMhMapping([0],
+                          FakeModelHandler(
+                              state=200, multi_process_shared=True)),
+        base.KeyMhMapping([0, 1, 2, 3],
+                          FakeModelHandler(multi_process_shared=True))
+    ]
+    with self.assertRaises(ValueError):
+      base.KeyedModelHandler(mhs)
+
+    mhs = [
+        base.KeyMhMapping([],
+                          FakeModelHandler(
+                              state=200, multi_process_shared=True)),
+        base.KeyMhMapping([0, 1, 2, 3],
+                          FakeModelHandler(multi_process_shared=True))
+    ]
+    with self.assertRaises(ValueError):
+      base.KeyedModelHandler(mhs)
+
   def test_run_inference_impl_with_maybe_keyed_examples(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]