You are viewing a plain-text version of this content; the hyperlink to the canonical version did not survive the conversion to plain text.
Posted to commits@beam.apache.org by bh...@apache.org on 2022/05/25 16:37:50 UTC

[beam] branch master updated: [BEAM-14044] Allow ModelLoader to forward BatchElements args (#17527)

This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 0a6fa95a04d [BEAM-14044] Allow ModelLoader to forward BatchElements args (#17527)
0a6fa95a04d is described below

commit 0a6fa95a04dc9485510e3bd034513d566155f61a
Author: zwestrick <10...@users.noreply.github.com>
AuthorDate: Wed May 25 09:37:43 2022 -0700

    [BEAM-14044] Allow ModelLoader to forward BatchElements args (#17527)
    
    * Updates ModelLoader to allow defining arguments to BatchElements
    
    * Update base.py
    
    * Adds unit test for batch arg forwarding
    
    * Fixes run_inference_base -> base
    
    * lint changes
    
    * lint changes
    
    * fmt fix
    
    * lint changes
---
 sdks/python/apache_beam/ml/inference/base.py      |  7 ++++++-
 sdks/python/apache_beam/ml/inference/base_test.py | 22 ++++++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 3c7e6fe0e3d..49753c4e7a3 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -38,6 +38,7 @@ from typing import Any
 from typing import Generic
 from typing import Iterable
 from typing import List
+from typing import Mapping
 from typing import TypeVar
 
 import apache_beam as beam
@@ -82,6 +83,10 @@ class ModelLoader(Generic[T]):
     """Returns an implementation of InferenceRunner for this model."""
     raise NotImplementedError(type(self))
 
+  def batch_elements_kwargs(self) -> Mapping[str, Any]:
+    """Returns kwargs suitable for beam.BatchElements."""
+    return {}
+
 
 class RunInference(beam.PTransform):
   """An extensible transform for running inferences."""
@@ -95,7 +100,7 @@ class RunInference(beam.PTransform):
     return (
         pcoll
         # TODO(BEAM-14044): Hook into the batching DoFn APIs.
-        | beam.BatchElements()
+        | beam.BatchElements(**self._model_loader.batch_elements_kwargs())
         | beam.ParDo(_RunInferenceDoFn(self._model_loader, self._clock)))
 
 
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 55936f63ed4..384ee9426d0 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -72,6 +72,21 @@ class ExtractInferences(beam.DoFn):
     yield prediction_result.inference
 
 
+class FakeInferenceRunnerNeedsBigBatch(FakeInferenceRunner):
+  def run_inference(self, batch, unused_model):
+    if len(batch) < 100:
+      raise ValueError('Unexpectedly small batch')
+    return batch
+
+
+class FakeLoaderWithBatchArgForwarding(FakeModelLoader):
+  def get_inference_runner(self):
+    return FakeInferenceRunnerNeedsBigBatch()
+
+  def batch_elements_kwargs(self):
+    return {'min_batch_size': 9999}
+
+
 class RunInferenceBaseTest(unittest.TestCase):
   def test_run_inference_impl_simple_examples(self):
     with TestPipeline() as pipeline:
@@ -142,6 +157,13 @@ class RunInferenceBaseTest(unittest.TestCase):
     self.assertEqual(load_model_latency.result.count, 1)
     self.assertEqual(load_model_latency.result.mean, 50)
 
+  def test_forwards_batch_args(self):
+    examples = list(range(100))
+    with TestPipeline() as pipeline:
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(FakeLoaderWithBatchArgForwarding())
+      assert_that(actual, equal_to(examples), label='assert:inferences')
+
 
 if __name__ == '__main__':
   unittest.main()