You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text version).
Posted to commits@beam.apache.org by da...@apache.org on 2023/05/13 01:09:18 UTC

[beam] branch users/damccorm/mpsRi created (now dbdc95a6552)

This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a change to branch users/damccorm/mpsRi
in repository https://gitbox.apache.org/repos/asf/beam.git


      at dbdc95a6552 Allow model handlers to request multi_process_shared model

This branch includes the following new commits:

     new dbdc95a6552 Allow model handlers to request multi_process_shared model

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[beam] 01/01: Allow model handlers to request multi_process_shared model

Posted by da...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/mpsRi
in repository https://gitbox.apache.org/repos/asf/beam.git

commit dbdc95a6552b242563e3d7c9a08d9f95449ca869
Author: Danny McCormick <da...@google.com>
AuthorDate: Fri May 12 21:08:29 2023 -0400

    Allow model handlers to request multi_process_shared model
---
 sdks/python/apache_beam/ml/inference/base.py      |  36 ++++-
 sdks/python/apache_beam/ml/inference/base_test.py | 174 +++++++++++++++++++++-
 2 files changed, 205 insertions(+), 5 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 0a62c26887b..a60f7365b42 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -33,6 +33,7 @@ import pickle
 import sys
 import threading
 import time
+import uuid
 from typing import Any
 from typing import Callable
 from typing import Dict
@@ -47,6 +48,7 @@ from typing import TypeVar
 from typing import Union
 
 import apache_beam as beam
+from apache_beam.utils import multi_process_shared
 from apache_beam.utils import shared
 
 try:
@@ -227,6 +229,15 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     inference result in order from first applied to last applied."""
     return _PostProcessingModelHandler(self, fn)
 
+  def share_model_across_processes(self) -> bool:
+    """Returns a boolean representing whether or not a model should
+    be shared across multiple processes instead of being loaded per process.
+    This is primarily useful for large models that can't fit multiple copies in
+    memory. Multi-process support may vary by runner, but this will fall back
+    to loading per process as necessary. See
+    https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
+    return False
+
 
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],
@@ -290,6 +301,9 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._unkeyed.get_postprocess_fns()
 
+  def share_model_across_processes(self) -> bool:
+    return self._unkeyed.share_model_across_processes()
+
 
 class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                              ModelHandler[Union[ExampleT, Tuple[KeyT,
@@ -379,6 +393,9 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._unkeyed.get_postprocess_fns()
 
+  def share_model_across_processes(self) -> bool:
+    return self._unkeyed.share_model_across_processes()
+
 
 class _PreProcessingModelHandler(Generic[ExampleT,
                                          PredictionT,
@@ -538,6 +555,9 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
     self._with_exception_handling = False
     self._watch_model_pattern = watch_model_pattern
     self._kwargs = kwargs
+    # Generate a random tag to use for shared.py and multi_process_shared.py to
+    # allow us to effectively disambiguate in multi-model settings.
+    self._model_tag = uuid.uuid4().hex
 
   def _get_model_metadata_pcoll(self, pipeline):
     # avoid circular imports.
@@ -623,7 +643,8 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
             self._model_handler,
             self._clock,
             self._metrics_namespace,
-            self._enable_side_input_loading),
+            self._enable_side_input_loading,
+            self._model_tag),
         self._inference_args,
         beam.pvalue.AsSingleton(
             self._model_metadata_pcoll,
@@ -780,7 +801,8 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
       model_handler: ModelHandler[ExampleT, PredictionT, Any],
       clock,
       metrics_namespace,
-      enable_side_input_loading: bool = False):
+      enable_side_input_loading: bool = False,
+      model_tag: str = "RunInference"):
     """A DoFn implementation generic to frameworks.
 
       Args:
@@ -789,6 +811,7 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
         metrics_namespace: Namespace of the transform to collect metrics.
         enable_side_input_loading: Bool to indicate if model updates
             with side inputs.
+        model_tag: Tag to use to disambiguate models in multi-model settings.
     """
     self._model_handler = model_handler
     self._shared_model_handle = shared.Shared()
@@ -797,6 +820,7 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
     self._metrics_namespace = metrics_namespace
     self._enable_side_input_loading = enable_side_input_loading
     self._side_input_path = None
+    self._model_tag = model_tag
 
   def _load_model(self, side_input_model_path: Optional[str] = None):
     def load():
@@ -815,7 +839,13 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
 
     # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing
     # model.
-    model = self._shared_model_handle.acquire(load, tag=side_input_model_path)
+    if self._model_handler.share_model_across_processes():
+      # TODO - make this a more robust tag than 'RunInference'
+      model = multi_process_shared.MultiProcessShared(
+          load, tag=side_input_model_path or self._model_tag).acquire()
+    else:
+      model = self._shared_model_handle.acquire(
+          load, tag=side_input_model_path or self._model_tag)
     # since shared_model_handle is shared across threads, the model path
     # might not get updated in the model handler
     # because we directly get cached weak ref model from shared cache, instead
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 455dfe208b1..afd336f9a47 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -50,11 +50,17 @@ class FakeModel:
 
 class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   def __init__(
-      self, clock=None, min_batch_size=1, max_batch_size=9999, **kwargs):
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      multi_process_shared=False,
+      **kwargs):
     self._fake_clock = clock
     self._min_batch_size = min_batch_size
     self._max_batch_size = max_batch_size
     self._env_vars = kwargs.get('env_vars', {})
+    self._multi_process_shared = multi_process_shared
 
   def load_model(self):
     if self._fake_clock:
@@ -66,6 +72,12 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
       batch: Sequence[int],
       model: FakeModel,
       inference_args=None) -> Iterable[int]:
+    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
+    if self._multi_process_shared != multi_process_shared_loaded:
+      raise Exception(
+          f'Loaded model of type {type(model)}, was' +
+          f'{"" if self._multi_process_shared else " not"} ' +
+          'expecting multi_process_shared_model')
     if self._fake_clock:
       self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
     for example in batch:
@@ -80,13 +92,21 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
         'max_batch_size': self._max_batch_size
     }
 
+  def share_model_across_processes(self):
+    return self._multi_process_shared
+
 
 class FakeModelHandlerReturnsPredictionResult(
     base.ModelHandler[int, base.PredictionResult, FakeModel]):
-  def __init__(self, clock=None, model_id='fake_model_id_default'):
+  def __init__(
+      self,
+      clock=None,
+      model_id='fake_model_id_default',
+      multi_process_shared=False):
     self.model_id = model_id
     self._fake_clock = clock
     self._env_vars = {}
+    self._multi_process_shared = multi_process_shared
 
   def load_model(self):
     return FakeModel()
@@ -96,6 +116,12 @@ class FakeModelHandlerReturnsPredictionResult(
       batch: Sequence[int],
       model: FakeModel,
       inference_args=None) -> Iterable[base.PredictionResult]:
+    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
+    if self._multi_process_shared != multi_process_shared_loaded:
+      raise Exception(
+          f'Loaded model of type {type(model)}, was' +
+          f'{"" if self._multi_process_shared else " not"} ' +
+          'expecting multi_process_shared_model')
     for example in batch:
       yield base.PredictionResult(
           model_id=self.model_id,
@@ -105,6 +131,9 @@ class FakeModelHandlerReturnsPredictionResult(
   def update_model_path(self, model_path: Optional[str] = None):
     self.model_id = model_path if model_path else self.model_id
 
+  def share_model_across_processes(self):
+    return self._multi_process_shared
+
 
 class FakeClock:
   def __init__(self):
@@ -156,6 +185,15 @@ class RunInferenceBaseTest(unittest.TestCase):
       actual = pcoll | base.RunInference(FakeModelHandler())
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_run_inference_impl_simple_examples_multi_process_shared(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler(multi_process_shared=True))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
   def test_run_inference_impl_with_keyed_examples(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
@@ -183,6 +221,35 @@ class RunInferenceBaseTest(unittest.TestCase):
           model_handler)
       assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
 
+  def test_run_inference_impl_with_keyed_examples_multi_process_shared(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [(i, example + 1) for i, example in enumerate(examples)]
+      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+      actual = pcoll | base.RunInference(
+          base.KeyedModelHandler(FakeModelHandler(multi_process_shared=True)))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_with_maybe_keyed_examples_multi_process_shared(
+      self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [example + 1 for example in examples]
+      keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
+      model_handler = base.MaybeKeyedModelHandler(
+          FakeModelHandler(multi_process_shared=True))
+
+      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
+      actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
+      assert_that(actual, equal_to(expected), label='CheckUnkeyed')
+
+      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
+      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
+          model_handler)
+      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+
   def test_run_inference_preprocessing(self):
     def mult_two(example: str) -> int:
       return int(example) * 2
@@ -634,6 +701,31 @@ class RunInferenceBaseTest(unittest.TestCase):
         'singleton view. First two elements encountered are' in str(
             e.exception))
 
+  def test_run_inference_with_iterable_side_input_multi_process_shared(self):
+    test_pipeline = TestPipeline()
+    side_input = (
+        test_pipeline | "CreateDummySideInput" >> beam.Create(
+            [base.ModelMetadata(1, 1), base.ModelMetadata(2, 2)])
+        | "ApplySideInputWindow" >> beam.WindowInto(
+            window.GlobalWindows(),
+            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
+            accumulation_mode=trigger.AccumulationMode.DISCARDING))
+
+    test_pipeline.options.view_as(StandardOptions).streaming = True
+    with self.assertRaises(ValueError) as e:
+      _ = (
+          test_pipeline
+          | beam.Create([1, 2, 3, 4])
+          | base.RunInference(
+              FakeModelHandler(multi_process_shared=True),
+              model_metadata_pcoll=side_input))
+      test_pipeline.run()
+
+    self.assertTrue(
+        'PCollection of size 2 with more than one element accessed as a '
+        'singleton view. First two elements encountered are' in str(
+            e.exception))
+
   def test_run_inference_empty_side_input(self):
     model_handler = FakeModelHandlerReturnsPredictionResult()
     main_input_elements = [1, 2]
@@ -727,6 +819,84 @@ class RunInferenceBaseTest(unittest.TestCase):
 
       assert_that(result_pcoll, equal_to(expected_result))
 
+  def test_run_inference_side_input_in_batch_multi_process_shared(self):
+    first_ts = math.floor(time.time()) - 30
+    interval = 7
+
+    sample_main_input_elements = ([
+        first_ts - 2,
+        first_ts + 1,
+        first_ts + 8,
+        first_ts + 15,
+        first_ts + 22,
+    ])
+
+    sample_side_input_elements = [
+        (first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
+        # if model_id is empty string, we use the default model
+        # handler model URI.
+        (
+            first_ts + 8,
+            base.ModelMetadata(
+                model_id='fake_model_id_1', model_name='fake_model_id_1')),
+        (
+            first_ts + 15,
+            base.ModelMetadata(
+                model_id='fake_model_id_2', model_name='fake_model_id_2'))
+    ]
+
+    model_handler = FakeModelHandlerReturnsPredictionResult(
+        multi_process_shared=True)
+
+    # applying GroupByKey to utilize windowing according to
+    # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections
+    class _EmitElement(beam.DoFn):
+      def process(self, element):
+        for e in element:
+          yield e
+
+    with TestPipeline() as pipeline:
+      side_input = (
+          pipeline
+          |
+          "CreateSideInputElements" >> beam.Create(sample_side_input_elements)
+          | beam.Map(lambda x: TimestampedValue(x[1], x[0]))
+          | beam.WindowInto(
+              window.FixedWindows(interval),
+              accumulation_mode=trigger.AccumulationMode.DISCARDING)
+          | beam.Map(lambda x: ('key', x))
+          | beam.GroupByKey()
+          | beam.Map(lambda x: x[1])
+          | "EmitSideInput" >> beam.ParDo(_EmitElement()))
+
+      result_pcoll = (
+          pipeline
+          | beam.Create(sample_main_input_elements)
+          | "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x))
+          | "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
+          | beam.Map(lambda x: ('key', x))
+          | "MainInputGBK" >> beam.GroupByKey()
+          | beam.Map(lambda x: x[1])
+          | beam.ParDo(_EmitElement())
+          | "RunInference" >> base.RunInference(
+              model_handler, model_metadata_pcoll=side_input))
+
+      expected_model_id_order = [
+          'fake_model_id_default',
+          'fake_model_id_default',
+          'fake_model_id_1',
+          'fake_model_id_2',
+          'fake_model_id_2'
+      ]
+      expected_result = [
+          base.PredictionResult(
+              example=sample_main_input_elements[i],
+              inference=sample_main_input_elements[i] + 1,
+              model_id=expected_model_id_order[i]) for i in range(5)
+      ]
+
+      assert_that(result_pcoll, equal_to(expected_result))
+
   @unittest.skipIf(
       not TestPipeline().get_pipeline_options().view_as(
           StandardOptions).streaming,