Posted to commits@beam.apache.org by da...@apache.org on 2023/08/11 22:49:25 UTC
[beam] branch master updated: Add ability to run per key inference (#27857)
This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new a5f1347699a Add ability to run per key inference (#27857)
a5f1347699a is described below
commit a5f1347699a3ed142d7d066922c3d1002f0b0f31
Author: Danny McCormick <da...@google.com>
AuthorDate: Fri Aug 11 18:49:17 2023 -0400
Add ability to run per key inference (#27857)
* Add ability to run per key inference
* lint
* lint
* address feedback
* lint
* Small feedback updates
---
sdks/python/apache_beam/ml/inference/base.py | 348 +++++++++++++++-------
sdks/python/apache_beam/ml/inference/base_test.py | 65 ++++
2 files changed, 308 insertions(+), 105 deletions(-)
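As a rough usage sketch of the new per-key API (mh1, mh2, and the example values below are hypothetical placeholders; KeyedModelHandler, KeyMhMapping, and RunInference are the classes touched by this commit):

    import apache_beam as beam
    from apache_beam.ml.inference import base

    # Hypothetical unkeyed model handlers: mh1 serves keys 'k1'-'k3',
    # mh2 serves keys 'k4' and 'k5'. Any ModelHandler implementation works here.
    keyed_mh = base.KeyedModelHandler([
        base.KeyMhMapping(['k1', 'k2', 'k3'], mh1),
        base.KeyMhMapping(['k4', 'k5'], mh2),
    ])

    with beam.Pipeline() as p:
        _ = (
            p
            | beam.Create([('k1', example_a), ('k4', example_b)])
            | base.RunInference(keyed_mh))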
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 8acdbaa5da1..5f2b4dc465f 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -35,11 +35,13 @@ import threading
import time
import uuid
from collections import OrderedDict
+from collections import defaultdict
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import Iterable
+from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
@@ -243,70 +245,278 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
return False
+class _ModelManager:
+ """
+ A class for efficiently managing copies of multiple models. Will load a
+ single copy of each model into a multi_process_shared object and then
+ return a lookup key for that object. Optionally takes in a max_models
+ parameter; if it is set, it will only hold that many models in memory at
+ once before evicting one (using LRU logic).
+ """
+ def __init__(
+ self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+ """
+ Args:
+ mh_map: A map from keys to model handlers which can be used to load a
+ model.
+ max_models: The maximum number of models to load at any given time
+ before evicting 1 from memory (using LRU logic). Leave as None to
+ allow unlimited models.
+ """
+ self._max_models = max_models
+ self._mh_map: Dict[str, ModelHandler] = mh_map
+ self._proxy_map: Dict[str, Any] = {}
+ self._tag_map: Dict[str, str] = OrderedDict()
+
+ def load(self, key: str) -> str:
+ """
+ Loads the appropriate model for the given key into memory.
+ Args:
+ key: the key associated with the model we'd like to load.
+ Returns:
+ the tag we can use to access the model using multi_process_shared.py.
+ """
+ # Map the key for a model to a unique tag that will persist until the model
+ # is released. This needs to be unique between releasing/reacquiring the
+ # model because otherwise the ProxyManager will try to reuse the model that
+ # has been released and deleted.
+ if key in self._tag_map:
+ self._tag_map.move_to_end(key)
+ else:
+ self._tag_map[key] = uuid.uuid4().hex
+
+ tag = self._tag_map[key]
+ mh = self._mh_map[key]
+
+ if self._max_models is not None and self._max_models < len(self._tag_map):
+ # If we're about to exceed our LRU size, release the least recently used model.
+ tag_to_remove = self._tag_map.popitem(last=False)[1]
+ shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
+ shared_handle.release(model_to_remove)
+
+ # Load the new model
+ shared_handle = multi_process_shared.MultiProcessShared(
+ mh.load_model, tag=tag)
+ model_reference = shared_handle.acquire()
+ self._proxy_map[tag] = (shared_handle, model_reference)
+
+ return tag
+
+ def increment_max_models(self, increment: int):
+ """
+ Increments the number of models that this instance of a _ModelManager is
+ able to hold.
+ Args:
+ increment: the amount by which we are incrementing the number of models.
+ """
+ if self._max_models is None:
+ raise ValueError(
+ "Cannot increment max_models if self._max_models is None (unlimited" +
+ " models mode).")
+ self._max_models += increment
+
+
+# Use a dataclass instead of named tuple because NamedTuples and generics don't
+# mix well across the board for all versions:
+# https://github.com/python/typing/issues/653
+class KeyMhMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):
+ """
+ Dataclass for mapping 1 or more keys to 1 model handler.
+ Given `KeyMhMapping(['key1', 'key2'], myMh)`, all examples with keys `key1`
+ or `key2` will be run against the model defined by the `myMh` ModelHandler.
+ """
+ def __init__(
+ self, keys: List[KeyT], mh: ModelHandler[ExampleT, PredictionT, ModelT]):
+ self.keys = keys
+ self.mh = mh
+
+
class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
ModelHandler[Tuple[KeyT, ExampleT],
Tuple[KeyT, PredictionT],
- ModelT]):
- def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+ Union[ModelT, _ModelManager]]):
+ def __init__(
+ self,
+ unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
+ List[KeyMhMapping[KeyT, ExampleT, PredictionT, ModelT]]]):
"""A ModelHandler that takes keyed examples and returns keyed predictions.
For example, if the original model is used with RunInference to take a
PCollection[E] to a PCollection[P], this ModelHandler would take a
PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], making it possible
- to use the key to associate the outputs with the inputs.
+ to use the key to associate the outputs with the inputs. KeyedModelHandler
+ accepts either a single unkeyed ModelHandler or a list of KeyMhMappings, each
+ of which associates a set of keys with the ModelHandler that should serve
+ them. For example, the following configuration could be used to map keys
+ 'k1'-'k3' to mh1 and keys 'k4'-'k5' to mh2:
+
+ k1 = ['k1', 'k2', 'k3']
+ k2 = ['k4', 'k5']
+ KeyedModelHandler([KeyMhMapping(k1, mh1), KeyMhMapping(k2, mh2)])
+
+ Note that a single copy of each of these models may be held in memory at
+ the same time; be careful not to load too many large models, or your
+ pipeline may run into Out of Memory exceptions.
Args:
- unkeyed: An implementation of ModelHandler that does not require keys.
+ unkeyed: Either (a) an implementation of ModelHandler that does not
+ require keys or (b) a list of KeyMhMappings mapping lists of keys to
+ unkeyed ModelHandlers.
"""
- if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
- raise Exception(
- 'Cannot make make an unkeyed model handler with pre or '
- 'postprocessing functions defined into a keyed model handler. All '
- 'pre/postprocessing functions must be defined on the outer model'
- 'handler.')
- self._unkeyed = unkeyed
- self._env_vars = unkeyed._env_vars
-
- def load_model(self) -> ModelT:
- return self._unkeyed.load_model()
+ self._single_model = not isinstance(unkeyed, list)
+ if self._single_model:
+ if len(unkeyed.get_preprocess_fns()) or len(
+ unkeyed.get_postprocess_fns()):
+ raise Exception(
+ 'Cannot make an unkeyed model handler with pre or '
+ 'postprocessing functions defined into a keyed model handler. All '
+ 'pre/postprocessing functions must be defined on the outer model '
+ 'handler.')
+ self._env_vars = unkeyed._env_vars
+ self._unkeyed = unkeyed
+ return
+
+ # To maintain an efficient representation, we will map all keys in a given
+ # KeyMhMapping to a single id (the first key in the KeyMhMapping list).
+ # We will then map that id to a ModelHandler. This will allow us to
+ # quickly look up the appropriate ModelHandler for any given key.
+ self._id_to_mh_map: Dict[str, ModelHandler[ExampleT, PredictionT,
+ ModelT]] = {}
+ self._key_to_id_map: Dict[str, str] = {}
+ for mh_tuple in unkeyed:
+ mh = mh_tuple.mh
+ keys = mh_tuple.keys
+ if len(mh.get_preprocess_fns()) or len(mh.get_postprocess_fns()):
+ raise ValueError(
+ 'Cannot use an unkeyed model handler with pre or '
+ 'postprocessing functions defined in a keyed model handler. All '
+ 'pre/postprocessing functions must be defined on the outer model '
+ 'handler.')
+ hints = mh.get_resource_hints()
+ if len(hints) > 0:
+ logging.warning(
+ 'mh %s defines the following resource hints, which will be '
+ 'ignored: %s. Resource hints are not respected when more than one '
+ 'model handler is used in a KeyedModelHandler. If you would like '
+ 'to specify resource hints, you can do so by overriding the '
+ 'KeyedModelHandler.get_resource_hints() method.',
+ mh,
+ hints)
+ batch_kwargs = mh.batch_elements_kwargs()
+ if len(batch_kwargs) > 0:
+ logging.warning(
+ 'mh %s defines the following batching kwargs which will be '
+ 'ignored %s. Batching kwargs are not respected when '
+ 'more than one model handler is used in a KeyedModelHandler. If '
+ 'you would like to specify batching kwargs, you can do so by '
+ 'overriding the KeyedModelHandler.batch_elements_kwargs() method.',
+ mh,
+ batch_kwargs)
+ env_vars = mh._env_vars
+ if len(env_vars) > 0:
+ logging.warning(
+ 'mh %s defines the following _env_vars which will be ignored %s. '
+ '_env_vars are not respected when more than one model handler is '
+ 'used in a KeyedModelHandler. If you need env vars set at '
+ 'inference time, you can do so with '
+ 'a custom inference function.',
+ mh,
+ env_vars)
+
+ if len(keys) == 0:
+ raise ValueError(
+ f'Empty list maps to model handler {mh}. All model handlers must '
+ 'have one or more associated keys.')
+ self._id_to_mh_map[keys[0]] = mh
+ for key in keys:
+ if key in self._key_to_id_map:
+ raise ValueError(
+ f'key {key} maps to multiple model handlers. All keys must map '
+ 'to exactly one model handler.')
+ self._key_to_id_map[key] = keys[0]
+
+ def load_model(self) -> Union[ModelT, _ModelManager]:
+ if self._single_model:
+ return self._unkeyed.load_model()
+ return _ModelManager(self._id_to_mh_map)
def run_inference(
self,
batch: Sequence[Tuple[KeyT, ExampleT]],
- model: ModelT,
+ model: Union[ModelT, _ModelManager],
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[Tuple[KeyT, PredictionT]]:
- keys, unkeyed_batch = zip(*batch)
- return zip(
- keys, self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
+ if self._single_model:
+ keys, unkeyed_batch = zip(*batch)
+ return zip(
+ keys,
+ self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
+
+ batch_by_key = defaultdict(list)
+ key_by_id = defaultdict(set)
+ for key, example in batch:
+ batch_by_key[key].append(example)
+ key_by_id[self._key_to_id_map[key]].add(key)
+
+ predictions = []
+ for id, keys in key_by_id.items():
+ mh = self._id_to_mh_map[id]
+ keyed_model_tag = model.load(id)
+ keyed_model_shared_handle = multi_process_shared.MultiProcessShared(
+ mh.load_model, tag=keyed_model_tag)
+ keyed_model = keyed_model_shared_handle.acquire()
+ for key in keys:
+ unkeyed_batches = batch_by_key[key]
+ for inf in mh.run_inference(unkeyed_batches,
+ keyed_model,
+ inference_args):
+ predictions.append((key, inf))
+ keyed_model_shared_handle.release(keyed_model)
+
+ return predictions
def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
- keys, unkeyed_batch = zip(*batch)
- return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+ if self._single_model:
+ keys, unkeyed_batch = zip(*batch)
+ return len(
+ pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+ return len(pickle.dumps(batch))
def get_metrics_namespace(self) -> str:
- return self._unkeyed.get_metrics_namespace()
+ if self._single_model:
+ return self._unkeyed.get_metrics_namespace()
+ return 'BeamML_KeyedModels'
def get_resource_hints(self):
- return self._unkeyed.get_resource_hints()
+ if self._single_model:
+ return self._unkeyed.get_resource_hints()
+ return {}
def batch_elements_kwargs(self):
- return self._unkeyed.batch_elements_kwargs()
+ if self._single_model:
+ return self._unkeyed.batch_elements_kwargs()
+ return {}
def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
- return self._unkeyed.validate_inference_args(inference_args)
+ if self._single_model:
+ return self._unkeyed.validate_inference_args(inference_args)
+ for mh in self._id_to_mh_map.values():
+ mh.validate_inference_args(inference_args)
def update_model_path(self, model_path: Optional[str] = None):
- return self._unkeyed.update_model_path(model_path=model_path)
-
- def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
- return self._unkeyed.get_preprocess_fns()
-
- def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
- return self._unkeyed.get_postprocess_fns()
+ if self._single_model:
+ return self._unkeyed.update_model_path(model_path=model_path)
+ if model_path is not None:
+ raise RuntimeError(
+ 'Model updates are currently not supported for ' +
+ 'KeyedModelHandlers with multiple different per-key ' +
+ 'ModelHandlers.')
def share_model_across_processes(self) -> bool:
- return self._unkeyed.share_model_across_processes()
+ if self._single_model:
+ return self._unkeyed.share_model_across_processes()
+ return True
class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
@@ -740,78 +950,6 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
return self
-class _ModelManager:
- """
- A class for efficiently managing copies of multiple models. Will load a
- single copy of each model into a multi_process_shared object and then
- return a lookup key for that object. Optionally takes in a max_models
- parameter, if that is set it will only hold that many models in memory at
- once before evicting one (using LRU logic).
- """
- def __init__(
- self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
- """
- Args:
- mh_map: A map from keys to model handlers which can be used to load a
- model.
- max_models: The maximum number of models to load at any given time
- before evicting 1 from memory (using LRU logic). Leave as None to
- allow unlimited models.
- """
- self._max_models = max_models
- self._mh_map: Dict[str, ModelHandler] = mh_map
- self._proxy_map: Dict[str, str] = {}
- self._tag_map: Dict[
- str, multi_process_shared.MultiProcessShared] = OrderedDict()
-
- def load(self, key: str) -> str:
- """
- Loads the appropriate model for the given key into memory.
- Args:
- key: the key associated with the model we'd like to load.
- Returns:
- the tag we can use to access the model using multi_process_shared.py.
- """
- # Map the key for a model to a unique tag that will persist until the model
- # is released. This needs to be unique between releasing/reacquiring th
- # model because otherwise the ProxyManager will try to reuse the model that
- # has been released and deleted.
- if key in self._tag_map:
- self._tag_map.move_to_end(key)
- else:
- self._tag_map[key] = uuid.uuid4().hex
-
- tag = self._tag_map[key]
- mh = self._mh_map[key]
-
- if self._max_models is not None and self._max_models < len(self._tag_map):
- # If we're about to exceed our LRU size, release the last used model.
- tag_to_remove = self._tag_map.popitem(last=False)[1]
- shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
- shared_handle.release(model_to_remove)
-
- # Load the new model
- shared_handle = multi_process_shared.MultiProcessShared(
- mh.load_model, tag=tag)
- model_reference = shared_handle.acquire()
- self._proxy_map[tag] = (shared_handle, model_reference)
-
- return tag
-
- def increment_max_models(self, increment: int):
- """
- Increments the number of models that this instance of a _ModelManager is
- able to hold.
- Args:
- increment: the amount by which we are incrementing the number of models.
- """
- if self._max_models is None:
- raise ValueError(
- "Cannot increment max_models if self._max_models is None (unlimited" +
- " models mode).")
- self._max_models += increment
-
-
class _MetricsCollector:
"""
A metrics collector that tracks ML related performance and memory usage.
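For reference, a minimal sketch of how the _ModelManager added above could be driven directly, mirroring what KeyedModelHandler.run_inference does (handler_a and handler_b are hypothetical unkeyed ModelHandlers, and the sketch assumes the multi_process_shared module used by base.py is importable from apache_beam.utils):

    from apache_beam.ml.inference import base
    from apache_beam.utils import multi_process_shared

    # Hypothetical handlers keyed by id; max_models=1 forces LRU eviction.
    manager = base._ModelManager(
        {'a': handler_a, 'b': handler_b}, max_models=1)

    tag_a = manager.load('a')  # loads model 'a', returns its shared tag
    model_a = multi_process_shared.MultiProcessShared(
        handler_a.load_model, tag=tag_a).acquire()

    tag_b = manager.load('b')  # evicts the least recently used model ('a')
    manager.increment_max_models(1)  # allow both models to stay resident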
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index be91efb9479..c79189718a9 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -254,6 +254,71 @@ class RunInferenceBaseTest(unittest.TestCase):
base.KeyedModelHandler(FakeModelHandler()))
assert_that(actual, equal_to(expected), label='assert:inferences')
+ def test_run_inference_impl_with_keyed_examples_many_model_handlers(self):
+ with TestPipeline() as pipeline:
+ examples = [1, 5, 3, 10]
+ keyed_examples = [(i, example) for i, example in enumerate(examples)]
+ expected = [(i, example + 1) for i, example in enumerate(examples)]
+ expected[0] = (0, 200)
+ pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+ mhs = [
+ base.KeyMhMapping([0],
+ FakeModelHandler(
+ state=200, multi_process_shared=True)),
+ base.KeyMhMapping([1, 2, 3],
+ FakeModelHandler(multi_process_shared=True))
+ ]
+ actual = pcoll | base.RunInference(base.KeyedModelHandler(mhs))
+ assert_that(actual, equal_to(expected), label='assert:inferences')
+
+ def test_keyed_many_model_handlers_validation(self):
+ def mult_two(example: str) -> int:
+ return int(example) * 2
+
+ mhs = [
+ base.KeyMhMapping(
+ [0],
+ FakeModelHandler(
+ state=200,
+ multi_process_shared=True).with_preprocess_fn(mult_two)),
+ base.KeyMhMapping([1, 2, 3],
+ FakeModelHandler(multi_process_shared=True))
+ ]
+ with self.assertRaises(ValueError):
+ base.KeyedModelHandler(mhs)
+
+ mhs = [
+ base.KeyMhMapping(
+ [0],
+ FakeModelHandler(
+ state=200,
+ multi_process_shared=True).with_postprocess_fn(mult_two)),
+ base.KeyMhMapping([1, 2, 3],
+ FakeModelHandler(multi_process_shared=True))
+ ]
+ with self.assertRaises(ValueError):
+ base.KeyedModelHandler(mhs)
+
+ mhs = [
+ base.KeyMhMapping([0],
+ FakeModelHandler(
+ state=200, multi_process_shared=True)),
+ base.KeyMhMapping([0, 1, 2, 3],
+ FakeModelHandler(multi_process_shared=True))
+ ]
+ with self.assertRaises(ValueError):
+ base.KeyedModelHandler(mhs)
+
+ mhs = [
+ base.KeyMhMapping([],
+ FakeModelHandler(
+ state=200, multi_process_shared=True)),
+ base.KeyMhMapping([0, 1, 2, 3],
+ FakeModelHandler(multi_process_shared=True))
+ ]
+ with self.assertRaises(ValueError):
+ base.KeyedModelHandler(mhs)
+
def test_run_inference_impl_with_maybe_keyed_examples(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]