Posted to commits@beam.apache.org by da...@apache.org on 2022/09/18 02:17:08 UTC
[beam] branch master updated: Add drop_example flag to the RunInference and Model Handler (#23266)
This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new f477b85f230 Add drop_example flag to the RunInference and Model Handler (#23266)
f477b85f230 is described below
commit f477b85f230ebb5dbd6b62540da078a33e3318ce
Author: Anand Inguva <34...@users.noreply.github.com>
AuthorDate: Sat Sep 17 22:16:57 2022 -0400
Add drop_example flag to the RunInference and Model Handler (#23266)
* Add drop_example flag to the RunInference and Model Handler
* Pass drop_example to the _convert_to_result
* Refactor _convert_to_result
* Return _convert_to_result
* Fixup lint
* Code update based on review suggestions
* Refactor class name
* Refactor TensorRT to add drop_example
* Add typing.Optional to the type hint for drop_example
* Add _convert_to_result for tensorRT run_inference
* Fixup lint
---
sdks/python/apache_beam/ml/inference/base.py | 52 +++++++++++++++++-----
sdks/python/apache_beam/ml/inference/base_test.py | 49 ++++++++++++++++++--
.../apache_beam/ml/inference/pytorch_inference.py | 35 +++++----------
.../apache_beam/ml/inference/sklearn_inference.py | 36 +++++----------
.../apache_beam/ml/inference/tensorrt_inference.py | 16 ++++---
5 files changed, 118 insertions(+), 70 deletions(-)
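
For context, this is how a pipeline opts in to the new flag at transform
construction time. A minimal sketch, assuming a pickled sklearn estimator;
the bucket path is a placeholder, not part of this commit:

    import numpy
    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

    # Placeholder URI; any pickled estimator implementing predict(X) works.
    model_handler = SklearnModelHandlerNumpy(model_uri='gs://my-bucket/model.pkl')

    with beam.Pipeline() as pipeline:
      _ = (
          pipeline
          | beam.Create([numpy.array([1.0]), numpy.array([2.0])])
          # drop_example=True emits PredictionResult(None, prediction)
          # instead of carrying each input example through.
          | RunInference(model_handler, drop_example=True)
          | beam.Map(print))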
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 8b88809d329..4f59bc43ad3 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -82,6 +82,24 @@ def _to_microseconds(time_ns: int) -> int:
return int(time_ns / _NANOSECOND_TO_MICROSECOND)
+def _convert_to_result(
+ batch: Iterable,
+ predictions: Union[Iterable, Dict[Any, Iterable]],
+ drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
+ if isinstance(predictions, dict):
+ # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
+ # key_type2: Iterable<val_type2>, ...} where each Iterable is of
+ # length batch_size, to a list of dictionaries:
+ # [{key_type1: value_type1, key_type2: value_type2}]
+ predictions = [
+ dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
+ ]
+ if drop_example:
+ return [PredictionResult(None, y) for _, y in zip(batch, predictions)]
+
+ return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+
class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
"""Has the ability to load and apply an ML model."""
def load_model(self) -> ModelT:
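
To make the helper's dict branch concrete, a sketch of its inputs and
outputs with toy values (illustrative only, not from the commit):

    # A model may return one dict of per-key iterables for the whole batch:
    batch = ['a', 'b']
    predictions = {'label': [0, 1], 'score': [0.9, 0.4]}
    # _convert_to_result(batch, predictions) transposes it into one dict
    # per example:
    #   [PredictionResult('a', {'label': 0, 'score': 0.9}),
    #    PredictionResult('b', {'label': 1, 'score': 0.4})]
    # With drop_example=True the example slot becomes None:
    #   [PredictionResult(None, {'label': 0, 'score': 0.9}),
    #    PredictionResult(None, {'label': 1, 'score': 0.4})]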
@@ -92,7 +110,8 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
self,
batch: Sequence[ExampleT],
model: ModelT,
- inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
+ inference_args: Optional[Dict[str, Any]] = None,
+ drop_example: Optional[bool] = False) -> Iterable[PredictionT]:
"""Runs inferences on a batch of examples.
Args:
@@ -100,7 +119,8 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
model: The model used to make inferences.
inference_args: Extra arguments for models whose inference call requires
extra parameters.
-
+ drop_example: Boolean flag indicating whether to drop the example
+ from the PredictionResult.
Returns:
An Iterable of Predictions.
"""
@@ -170,7 +190,8 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
self,
batch: Sequence[Tuple[KeyT, ExampleT]],
model: ModelT,
- inference_args: Optional[Dict[str, Any]] = None
+ inference_args: Optional[Dict[str, Any]] = None,
+ drop_example: Optional[bool] = False
) -> Iterable[Tuple[KeyT, PredictionT]]:
keys, unkeyed_batch = zip(*batch)
return zip(
@@ -225,7 +246,8 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
self,
batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
model: ModelT,
- inference_args: Optional[Dict[str, Any]] = None
+ inference_args: Optional[Dict[str, Any]] = None,
+ drop_example: Optional[bool] = False
) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
# Really the input should be
# Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
@@ -273,7 +295,9 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
model_handler: ModelHandler[ExampleT, PredictionT, Any],
clock=time,
inference_args: Optional[Dict[str, Any]] = None,
- metrics_namespace: Optional[str] = None):
+ metrics_namespace: Optional[str] = None,
+ drop_example: Optional[bool] = False,
+ ):
"""A transform that takes a PCollection of examples (or features) to be used
on an ML model. It will then output inferences (or predictions) for those
examples in a PCollection of PredictionResults, containing the input
@@ -291,11 +315,13 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
inference_args: Extra arguments for models whose inference call requires
extra parameters.
metrics_namespace: Namespace of the transform to collect metrics.
- """
+ drop_example: Boolean flag indicating whether to drop the example
+ from the PredictionResult.
+ """
self._model_handler = model_handler
self._inference_args = inference_args
self._clock = clock
self._metrics_namespace = metrics_namespace
+ self._drop_example = drop_example
# TODO(BEAM-14046): Add and link to help documentation.
@classmethod
@@ -327,7 +353,10 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
| 'BeamML_RunInference' >> (
beam.ParDo(
_RunInferenceDoFn(
- self._model_handler, self._clock, self._metrics_namespace),
+ model_handler=self._model_handler,
+ clock=self._clock,
+ metrics_namespace=self._metrics_namespace,
+ drop_example=self._drop_example),
self._inference_args).with_resource_hints(**resource_hints)))
@@ -385,19 +414,22 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
self,
model_handler: ModelHandler[ExampleT, PredictionT, Any],
clock,
- metrics_namespace):
+ metrics_namespace: Optional[str],
+ drop_example: Optional[bool] = False):
"""A DoFn implementation generic to frameworks.
Args:
model_handler: An implementation of ModelHandler.
clock: A clock implementing time_ns. *Used for unit testing.*
metrics_namespace: Namespace of the transform to collect metrics.
- """
+ drop_example: Boolean flag indicating whether to drop the example
+ from the PredictionResult.
+ """
self._model_handler = model_handler
self._shared_model_handle = shared.Shared()
self._clock = clock
self._model = None
self._metrics_namespace = metrics_namespace
+ self._drop_example = drop_example
def _load_model(self):
def load():
@@ -427,7 +459,7 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
def process(self, batch, inference_args):
start_time = _to_microseconds(self._clock.time_ns())
result_generator = self._model_handler.run_inference(
- batch, self._model, inference_args)
+ batch, self._model, inference_args, self._drop_example)
predictions = list(result_generator)
end_time = _to_microseconds(self._clock.time_ns())
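
With this change the flag threads RunInference -> _RunInferenceDoFn.process
-> ModelHandler.run_inference, so a custom handler only needs to accept it
and forward it. A sketch modeled on the test handler below; MyHandler and
its stand-in model are hypothetical, and _convert_to_result is
module-private:

    from typing import Any, Dict, Iterable, Optional, Sequence

    from apache_beam.ml.inference import base

    class MyHandler(base.ModelHandler[int, base.PredictionResult, Any]):
      def load_model(self):
        return lambda x: x * 2  # stand-in "model"

      def run_inference(
          self,
          batch: Sequence[int],
          model: Any,
          inference_args: Optional[Dict[str, Any]] = None,
          drop_example: Optional[bool] = False
      ) -> Iterable[base.PredictionResult]:
        predictions = [model(example) for example in batch]
        # Forward the flag to the shared helper added in base.py above.
        return base._convert_to_result(
            batch=batch, predictions=predictions, drop_example=drop_example)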
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 278485666a0..1f74f1868c0 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -48,13 +48,38 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
self,
batch: Sequence[int],
model: FakeModel,
- inference_args=None) -> Iterable[int]:
+ inference_args=None,
+ drop_example=False) -> Iterable[int]:
if self._fake_clock:
self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds
for example in batch:
yield model.predict(example)
+class FakeModelHandlerReturnsPredictionResult(
+ base.ModelHandler[int, base.PredictionResult, FakeModel]):
+ def __init__(self, clock=None):
+ self._fake_clock = clock
+
+ def load_model(self):
+ if self._fake_clock:
+ self._fake_clock.current_time_ns += 500_000_000 # 500ms
+ return FakeModel()
+
+ def run_inference(
+ self,
+ batch: Sequence[int],
+ model: FakeModel,
+ inference_args=None,
+ drop_example=False) -> Iterable[base.PredictionResult]:
+ if self._fake_clock:
+ self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds
+
+ predictions = [model.predict(example) for example in batch]
+ return base._convert_to_result(
+ batch=batch, predictions=predictions, drop_example=drop_example)
+
+
class FakeClock:
def __init__(self):
# Start at 10 seconds.
@@ -70,7 +95,8 @@ class ExtractInferences(beam.DoFn):
class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
- def run_inference(self, batch, unused_model, inference_args=None):
+ def run_inference(
+ self, batch, model, inference_args=None, drop_example=False):
if len(batch) < 100:
raise ValueError('Unexpectedly small batch')
return batch
@@ -80,14 +106,16 @@ class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
class FakeModelHandlerFailsOnInferenceArgs(FakeModelHandler):
- def run_inference(self, batch, unused_model, inference_args=None):
+ def run_inference(
+ self, batch, model, inference_args=None, drop_example=False):
raise ValueError(
'run_inference should not be called because error should already be '
'thrown from the validate_inference_args check.')
class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler):
- def run_inference(self, batch, unused_model, inference_args=None):
+ def run_inference(
+ self, batch, model, inference_args=None, drop_example=False):
if not inference_args:
raise ValueError('inference_args should exist')
return batch
@@ -251,6 +279,19 @@ class RunInferenceBaseTest(unittest.TestCase):
| 'RunKeyed' >> base.RunInference(model_handler))
pipeline.run()
+ def test_drop_example_prediction_result(self):
+ def assert_drop_example(prediction_result):
+ assert prediction_result.example is None
+
+ pipeline = TestPipeline()
+ examples = [1, 3, 5]
+ model_handler = FakeModelHandlerReturnsPredictionResult()
+ _ = (
+ pipeline | 'keyed' >> beam.Create(examples)
+ | 'RunKeyed' >> base.RunInference(model_handler, drop_example=True)
+ | beam.Map(assert_drop_example))
+ pipeline.run()
+
if __name__ == '__main__':
unittest.main()
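
A natural companion to the new test (not part of this commit) would pin
down the default behavior, i.e. that examples are kept when the flag is
omitted:

    def test_keeps_example_by_default(self):
      def assert_keeps_example(prediction_result):
        assert prediction_result.example is not None

      pipeline = TestPipeline()
      examples = [1, 3, 5]
      model_handler = FakeModelHandlerReturnsPredictionResult()
      _ = (
          pipeline | beam.Create(examples)
          | base.RunInference(model_handler)  # drop_example defaults to False
          | beam.Map(assert_keeps_example))
      pipeline.run()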
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index d97205937e2..9d80e2fc1ef 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -25,12 +25,12 @@ from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence
-from typing import Union
import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import _convert_to_result
from apache_beam.utils.annotations import experimental
__all__ = [
@@ -83,23 +83,6 @@ def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
return examples
-def _convert_to_result(
- batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
-) -> Iterable[PredictionResult]:
- if isinstance(predictions, dict):
- # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
- # key_type2: Iterable<val_type2>, ...} where each Iterable is of
- # length batch_size, to a list of dictionaries:
- # [{key_type1: value_type1, key_type2: value_type2}]
- predictions_per_tensor = [
- dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
- ]
- return [
- PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
- ]
- return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
-
-
class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
PredictionResult,
torch.nn.Module]):
@@ -152,8 +135,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
self,
batch: Sequence[torch.Tensor],
model: torch.nn.Module,
- inference_args: Optional[Dict[str, Any]] = None
- ) -> Iterable[PredictionResult]:
+ inference_args: Optional[Dict[str, Any]] = None,
+ drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
"""
Runs inferences on a batch of Tensors and returns an Iterable of
Tensor Predictions.
@@ -170,7 +153,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
inference_args: Non-batchable arguments required as inputs to the model's
forward() function. Unlike Tensors in `batch`, these parameters will
not be dynamically batched
-
+ drop_example: Boolean flag indicating whether to drop the example
+ from the PredictionResult.
Returns:
An Iterable of type PredictionResult.
"""
@@ -182,7 +166,7 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
batched_tensors = torch.stack(batch)
batched_tensors = _convert_to_device(batched_tensors, self._device)
predictions = model(batched_tensors, **inference_args)
- return _convert_to_result(batch, predictions)
+ return _convert_to_result(batch, predictions, drop_example=drop_example)
def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""
@@ -259,7 +243,8 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
self,
batch: Sequence[Dict[str, torch.Tensor]],
model: torch.nn.Module,
- inference_args: Optional[Dict[str, Any]] = None
+ inference_args: Optional[Dict[str, Any]] = None,
+ drop_example: Optional[bool] = False,
) -> Iterable[PredictionResult]:
"""
Runs inferences on a batch of Keyed Tensors and returns an Iterable of
@@ -277,6 +262,8 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
inference_args: Non-batchable arguments required as inputs to the model's
forward() function. Unlike Tensors in `batch`, these parameters will
not be dynamically batched
+ drop_example: Boolean flag indicating whether to drop the example
+ from the PredictionResult.
Returns:
An Iterable of type PredictionResult.
@@ -300,7 +287,7 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
key_to_batched_tensors[key] = batched_tensors
predictions = model(**key_to_batched_tensors, **inference_args)
- return _convert_to_result(batch, predictions)
+ return _convert_to_result(batch, predictions, drop_example=drop_example)
def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 8fd1899ef96..62e22f4fd5a 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -23,7 +23,6 @@ from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Sequence
-from typing import Union
import numpy
import pandas
@@ -32,6 +31,7 @@ from sklearn.base import BaseEstimator
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import _convert_to_result
from apache_beam.utils.annotations import experimental
try:
@@ -67,23 +67,6 @@ def _load_model(model_uri, file_type):
raise AssertionError('Unsupported serialization type.')
-def _convert_to_result(
- batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
-) -> Iterable[PredictionResult]:
- if isinstance(predictions, dict):
- # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
- # key_type2: Iterable<val_type2>, ...} where each Iterable is of
- # length batch_size, to a list of dictionaries:
- # [{key_type1: value_type1, key_type2: value_type2}]
- predictions_per_tensor = [
- dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
- ]
- return [
- PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
- ]
- return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
-
-
class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
PredictionResult,
BaseEstimator]):
@@ -114,8 +97,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
self,
batch: Sequence[numpy.ndarray],
model: BaseEstimator,
- inference_args: Optional[Dict[str, Any]] = None
- ) -> Iterable[PredictionResult]:
+ inference_args: Optional[Dict[str, Any]] = None,
+ drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
"""Runs inferences on a batch of numpy arrays.
Args:
@@ -124,6 +107,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
model: A numpy model or pipeline. Must implement predict(X).
Where the parameter X is a numpy array.
inference_args: Any additional arguments for an inference.
+ drop_example: Boolean flag indicating whether to drop the example
+ from the PredictionResult.
Returns:
An Iterable of type PredictionResult.
@@ -132,7 +117,7 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
vectorized_batch = numpy.stack(batch, axis=0)
predictions = model.predict(vectorized_batch)
- return _convert_to_result(batch, predictions)
+ return _convert_to_result(batch, predictions, drop_example=drop_example)
def get_num_bytes(self, batch: Sequence[pandas.DataFrame]) -> int:
"""
@@ -183,8 +168,8 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
self,
batch: Sequence[pandas.DataFrame],
model: BaseEstimator,
- inference_args: Optional[Dict[str, Any]] = None
- ) -> Iterable[PredictionResult]:
+ inference_args: Optional[Dict[str, Any]] = None,
+ drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
"""
Runs inferences on a batch of pandas dataframes.
@@ -194,7 +179,8 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
model: A dataframe model or pipeline. Must implement predict(X).
Where the parameter X is a pandas dataframe.
inference_args: Any additional arguments for an inference.
-
+ drop_example: Boolean flag indicating whether to drop the example
+ from the PredictionResult.
Returns:
An Iterable of type PredictionResult.
"""
@@ -210,7 +196,7 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
vectorized_batch.iloc[[i]] for i in range(vectorized_batch.shape[0])
]
- return _convert_to_result(splits, predictions)
+ return _convert_to_result(splits, predictions, drop_example=drop_example)
def get_num_bytes(self, batch: Sequence[pandas.DataFrame]) -> int:
"""
diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
index 8ff65658c6b..64e2e6dcaa6 100644
--- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
@@ -33,6 +33,7 @@ import numpy as np
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import _convert_to_result
from apache_beam.utils.annotations import experimental
LOGGER = logging.getLogger("TensorRTEngineHandlerNumPy")
@@ -225,7 +226,8 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
self,
batch: Sequence[np.ndarray],
engine: TensorRTEngine,
- inference_args: Optional[Dict[str, Any]] = None
+ inference_args: Optional[Dict[str, Any]] = None,
+ drop_example: Optional[bool] = False,
) -> Iterable[PredictionResult]:
"""
Runs inferences on a batch of Tensors and returns an Iterable of
@@ -270,11 +272,11 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
stream))
_assign_or_fail(cuda.cuStreamSynchronize(stream))
- return [
- PredictionResult(
- x, [prediction[idx] for prediction in cpu_allocations]) for idx,
- x in enumerate(batch)
- ]
+ predictions = []
+ for idx in range(len(batch)):
+ predictions.append([prediction[idx] for prediction in cpu_allocations])
+
+ return _convert_to_result(batch, predictions, drop_example=drop_example)
def get_num_bytes(self, batch: Sequence[np.ndarray]) -> int:
"""
@@ -287,4 +289,4 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
"""
Returns a namespace for metrics collected by the RunInference transform.
"""
- return 'RunInferenceTensorRT'
+ return 'BeamML_TensorRT'
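
One side effect of the final hunk: metrics from the TensorRT handler are
now filed under the 'BeamML_TensorRT' namespace, so queries against the old
'RunInferenceTensorRT' name stop matching. A sketch of querying the renamed
namespace after a run (pipeline construction elided):

    from apache_beam.metrics.metric import MetricsFilter

    result = pipeline.run()
    result.wait_until_finish()
    tensorrt_metrics = result.metrics().query(
        MetricsFilter().with_namespace('BeamML_TensorRT'))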