Posted to commits@beam.apache.org by bh...@apache.org on 2022/05/25 16:37:50 UTC
[beam] branch master updated: [BEAM-14044] Allow ModelLoader to forward BatchElements args (#17527)
This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 0a6fa95a04d [BEAM-14044] Allow ModelLoader to forward BatchElements args (#17527)
0a6fa95a04d is described below
commit 0a6fa95a04dc9485510e3bd034513d566155f61a
Author: zwestrick <10...@users.noreply.github.com>
AuthorDate: Wed May 25 09:37:43 2022 -0700
[BEAM-14044] Allow ModelLoader to forward BatchElements args (#17527)
* Updates ModelLoader so it can supply keyword arguments that RunInference forwards to beam.BatchElements (see the sketch after the diffstat below)
* Update base.py
* Adds unit test for batch arg forwarding
* Fixes run_inference_base -> base
* lint changes
* lint changes
* fmt fix
* lint changes
---
sdks/python/apache_beam/ml/inference/base.py | 7 ++++++-
sdks/python/apache_beam/ml/inference/base_test.py | 22 ++++++++++++++++++++++
2 files changed, 28 insertions(+), 1 deletion(-)
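
For orientation, here is a minimal sketch, assuming only the ModelLoader API
visible in the diff below, of a loader that opts into larger batches via the
new hook. The class name, method bodies, and batch sizes are illustrative and
are not part of the commit:

    from typing import Any, Mapping

    from apache_beam.ml.inference import base


    class MyModelLoader(base.ModelLoader):
      def load_model(self):
        ...  # load and return the model (placeholder)

      def get_inference_runner(self):
        ...  # return an InferenceRunner for the model (placeholder)

      def batch_elements_kwargs(self) -> Mapping[str, Any]:
        # Forwarded verbatim to beam.BatchElements by RunInference.
        # Values here are illustrative, not recommendations.
        return {'min_batch_size': 100, 'max_batch_size': 500}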
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 3c7e6fe0e3d..49753c4e7a3 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -38,6 +38,7 @@ from typing import Any
from typing import Generic
from typing import Iterable
from typing import List
+from typing import Mapping
from typing import TypeVar
import apache_beam as beam
@@ -82,6 +83,10 @@ class ModelLoader(Generic[T]):
"""Returns an implementation of InferenceRunner for this model."""
raise NotImplementedError(type(self))
+ def batch_elements_kwargs(self) -> Mapping[str, Any]:
+ """Returns kwargs suitable for beam.BatchElements."""
+ return {}
+
class RunInference(beam.PTransform):
"""An extensible transform for running inferences."""
@@ -95,7 +100,7 @@ class RunInference(beam.PTransform):
return (
pcoll
# TODO(BEAM-14044): Hook into the batching DoFn APIs.
- | beam.BatchElements()
+ | beam.BatchElements(**self._model_loader.batch_elements_kwargs())
| beam.ParDo(_RunInferenceDoFn(self._model_loader, self._clock)))
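
In effect, the returned kwargs are splatted into beam.BatchElements, so a
loader whose batch_elements_kwargs() returns {'min_batch_size': 100} produces
a pipeline equivalent to this hand-written form (pcoll is a placeholder;
min_batch_size is an existing BatchElements tuning parameter):

    # Equivalent expansion for batch_elements_kwargs() == {'min_batch_size': 100}:
    batched = pcoll | beam.BatchElements(min_batch_size=100)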
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 55936f63ed4..384ee9426d0 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -72,6 +72,21 @@ class ExtractInferences(beam.DoFn):
yield prediction_result.inference
+class FakeInferenceRunnerNeedsBigBatch(FakeInferenceRunner):
+ def run_inference(self, batch, unused_model):
+ if len(batch) < 100:
+ raise ValueError('Unexpectedly small batch')
+ return batch
+
+
+class FakeLoaderWithBatchArgForwarding(FakeModelLoader):
+ def get_inference_runner(self):
+ return FakeInferenceRunnerNeedsBigBatch()
+
+ def batch_elements_kwargs(self):
+ return {'min_batch_size': 9999}
+
+
class RunInferenceBaseTest(unittest.TestCase):
def test_run_inference_impl_simple_examples(self):
with TestPipeline() as pipeline:
@@ -142,6 +157,13 @@ class RunInferenceBaseTest(unittest.TestCase):
self.assertEqual(load_model_latency.result.count, 1)
self.assertEqual(load_model_latency.result.mean, 50)
+ def test_forwards_batch_args(self):
+ examples = list(range(100))
+ with TestPipeline() as pipeline:
+ pcoll = pipeline | 'start' >> beam.Create(examples)
+ actual = pcoll | base.RunInference(FakeLoaderWithBatchArgForwarding())
+ assert_that(actual, equal_to(examples), label='assert:inferences')
+
if __name__ == '__main__':
unittest.main()
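
The new test can also be exercised on its own; from a Beam checkout with the
Python SDK installed, a typical invocation might be:

    pytest sdks/python/apache_beam/ml/inference/base_test.py -k test_forwards_batch_args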