Posted to commits@beam.apache.org by da...@apache.org on 2022/09/18 02:17:08 UTC

[beam] branch master updated: Add drop_example flag to the RunInference and Model Handler (#23266)

This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f477b85f230 Add drop_example flag to the RunInference and Model Handler (#23266)
f477b85f230 is described below

commit f477b85f230ebb5dbd6b62540da078a33e3318ce
Author: Anand Inguva <34...@users.noreply.github.com>
AuthorDate: Sat Sep 17 22:16:57 2022 -0400

    Add drop_example flag to the RunInference and Model Handler (#23266)
    
    * Add drop_example flag to the RunInference and Model Handler
    
    * Pass drop_example to the _convert_to_result
    
    * Refactor _convert_to_result
    
    * Return _convert_to_result
    
    * Fixup lint
    
    * Code update based on Suggestions
    
    * Refactor class name
    
    * Refactor TensorRT to add drop_example
    
    * Add typing.Optional to the type hint for drop_example
    
    * Add _convert_to_result for tensorRT run_inference
    
    * fixup lint
---
 sdks/python/apache_beam/ml/inference/base.py       | 52 +++++++++++++++++-----
 sdks/python/apache_beam/ml/inference/base_test.py  | 49 ++++++++++++++++++--
 .../apache_beam/ml/inference/pytorch_inference.py  | 35 +++++----------
 .../apache_beam/ml/inference/sklearn_inference.py  | 36 +++++----------
 .../apache_beam/ml/inference/tensorrt_inference.py | 16 ++++---
 5 files changed, 118 insertions(+), 70 deletions(-)
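
For orientation before the diff, here is a minimal sketch of how a pipeline author would use the new flag. The model path 'gs://my-bucket/model.pkl' is hypothetical; any model handler that forwards drop_example behaves the same way, and with drop_example=True each emitted PredictionResult carries None in its example field:

    import numpy

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

    # Hypothetical model location; drop_example works with any handler that
    # threads the flag through to _convert_to_result.
    model_handler = SklearnModelHandlerNumpy(model_uri='gs://my-bucket/model.pkl')

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([numpy.array([1.0]), numpy.array([2.0])])
          # drop_example=True replaces PredictionResult.example with None, so
          # large inputs are not carried into the output PCollection.
          | RunInference(model_handler, drop_example=True)
          | beam.Map(print))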

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 8b88809d329..4f59bc43ad3 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -82,6 +82,24 @@ def _to_microseconds(time_ns: int) -> int:
   return int(time_ns / _NANOSECOND_TO_MICROSECOND)
 
 
+def _convert_to_result(
+    batch: Iterable,
+    predictions: Union[Iterable, Dict[Any, Iterable]],
+    drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
+  if isinstance(predictions, dict):
+    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
+    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
+    # length batch_size, to a list of dictionaries:
+    # [{key_type1: value_type1, key_type2: value_type2}]
+    predictions = [
+        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
+    ]
+  if drop_example:
+    return [PredictionResult(None, y) for x, y in zip(batch, predictions)]
+
+  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+
 class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
   """Has the ability to load and apply an ML model."""
   def load_model(self) -> ModelT:
@@ -92,7 +110,8 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
       self,
       batch: Sequence[ExampleT],
       model: ModelT,
-      inference_args: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:
+      inference_args: Optional[Dict[str, Any]] = None,
+      drop_example: Optional[bool] = False) -> Iterable[PredictionT]:
     """Runs inferences on a batch of examples.
 
     Args:
@@ -100,7 +119,8 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
       model: The model used to make inferences.
       inference_args: Extra arguments for models whose inference call requires
         extra parameters.
-
+      drop_example: Boolean flag indicating whether to
+        drop the example from PredictionResult
     Returns:
       An Iterable of Predictions.
     """
@@ -170,7 +190,8 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
       self,
       batch: Sequence[Tuple[KeyT, ExampleT]],
       model: ModelT,
-      inference_args: Optional[Dict[str, Any]] = None
+      inference_args: Optional[Dict[str, Any]] = None,
+      drop_example: Optional[bool] = False
   ) -> Iterable[Tuple[KeyT, PredictionT]]:
     keys, unkeyed_batch = zip(*batch)
     return zip(
@@ -225,7 +246,8 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
       self,
       batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
       model: ModelT,
-      inference_args: Optional[Dict[str, Any]] = None
+      inference_args: Optional[Dict[str, Any]] = None,
+      drop_example: Optional[bool] = False
   ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
     # Really the input should be
     #    Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
@@ -273,7 +295,9 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
       model_handler: ModelHandler[ExampleT, PredictionT, Any],
       clock=time,
       inference_args: Optional[Dict[str, Any]] = None,
-      metrics_namespace: Optional[str] = None):
+      metrics_namespace: Optional[str] = None,
+      drop_example: Optional[bool] = False,
+  ):
     """A transform that takes a PCollection of examples (or features) to be used
     on an ML model. It will then output inferences (or predictions) for those
     examples in a PCollection of PredictionResults, containing the input
@@ -291,11 +315,13 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
         inference_args: Extra arguments for models whose inference call requires
           extra parameters.
         metrics_namespace: Namespace of the transform to collect metrics.
-    """
+        drop_example: Boolean flag indicating whether to
+          drop the example from PredictionResult    """
     self._model_handler = model_handler
     self._inference_args = inference_args
     self._clock = clock
     self._metrics_namespace = metrics_namespace
+    self._drop_example = drop_example
 
   # TODO(BEAM-14046): Add and link to help documentation.
   @classmethod
@@ -327,7 +353,10 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
         | 'BeamML_RunInference' >> (
             beam.ParDo(
                 _RunInferenceDoFn(
-                    self._model_handler, self._clock, self._metrics_namespace),
+                    model_handler=self._model_handler,
+                    clock=self._clock,
+                    metrics_namespace=self._metrics_namespace,
+                    drop_example=self._drop_example),
                 self._inference_args).with_resource_hints(**resource_hints)))
 
 
@@ -385,19 +414,22 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
       self,
       model_handler: ModelHandler[ExampleT, PredictionT, Any],
       clock,
-      metrics_namespace):
+      metrics_namespace: Optional[str],
+      drop_example: Optional[bool] = False):
     """A DoFn implementation generic to frameworks.
 
       Args:
         model_handler: An implementation of ModelHandler.
         clock: A clock implementing time_ns. *Used for unit testing.*
         metrics_namespace: Namespace of the transform to collect metrics.
-    """
+        drop_example: Boolean flag indicating whether to
+          drop the example from PredictionResult    """
     self._model_handler = model_handler
     self._shared_model_handle = shared.Shared()
     self._clock = clock
     self._model = None
     self._metrics_namespace = metrics_namespace
+    self._drop_example = drop_example
 
   def _load_model(self):
     def load():
@@ -427,7 +459,7 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
   def process(self, batch, inference_args):
     start_time = _to_microseconds(self._clock.time_ns())
     result_generator = self._model_handler.run_inference(
-        batch, self._model, inference_args)
+        batch, self._model, inference_args, self._drop_example)
     predictions = list(result_generator)
 
     end_time = _to_microseconds(self._clock.time_ns())
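
The dict branch of the shared _convert_to_result above transposes framework output of the form {output_name: batch_of_values} into one dict per example before pairing it with the input. A small illustrative sketch of that reshaping (the values are made up):

    # Framework-style output: one iterable per named output, each of length batch_size.
    predictions = {'label': [0, 1], 'score': [0.9, 0.2]}

    # Same transposition as in _convert_to_result:
    per_example = [
        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
    ]
    # per_example == [{'label': 0, 'score': 0.9}, {'label': 1, 'score': 0.2}]

    batch = ['x0', 'x1']
    # PredictionResult pairs batch[i] with per_example[i]; with
    # drop_example=True the example slot is None instead of 'x0'/'x1'.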
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 278485666a0..1f74f1868c0 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -48,13 +48,38 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
       self,
       batch: Sequence[int],
       model: FakeModel,
-      inference_args=None) -> Iterable[int]:
+      inference_args=None,
+      drop_example=False) -> Iterable[int]:
     if self._fake_clock:
       self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
     for example in batch:
       yield model.predict(example)
 
 
+class FakeModelHandlerReturnsPredictionResult(
+    base.ModelHandler[int, base.PredictionResult, FakeModel]):
+  def __init__(self, clock=None):
+    self._fake_clock = clock
+
+  def load_model(self):
+    if self._fake_clock:
+      self._fake_clock.current_time_ns += 500_000_000  # 500ms
+    return FakeModel()
+
+  def run_inference(
+      self,
+      batch: Sequence[int],
+      model: FakeModel,
+      inference_args=None,
+      drop_example=False) -> Iterable[base.PredictionResult]:
+    if self._fake_clock:
+      self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
+
+    predictions = [model.predict(example) for example in batch]
+    return base._convert_to_result(
+        batch=batch, predictions=predictions, drop_example=drop_example)
+
+
 class FakeClock:
   def __init__(self):
     # Start at 10 seconds.
@@ -70,7 +95,8 @@ class ExtractInferences(beam.DoFn):
 
 
 class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
-  def run_inference(self, batch, unused_model, inference_args=None):
+  def run_inference(
+      self, batch, model, inference_args=None, drop_example=False):
     if len(batch) < 100:
       raise ValueError('Unexpectedly small batch')
     return batch
@@ -80,14 +106,16 @@ class FakeModelHandlerNeedsBigBatch(FakeModelHandler):
 
 
 class FakeModelHandlerFailsOnInferenceArgs(FakeModelHandler):
-  def run_inference(self, batch, unused_model, inference_args=None):
+  def run_inference(
+      self, batch, model, inference_args=None, drop_example=False):
     raise ValueError(
         'run_inference should not be called because error should already be '
         'thrown from the validate_inference_args check.')
 
 
 class FakeModelHandlerExpectedInferenceArgs(FakeModelHandler):
-  def run_inference(self, batch, unused_model, inference_args=None):
+  def run_inference(
+      self, batch, model, inference_args=None, drop_example=False):
     if not inference_args:
       raise ValueError('inference_args should exist')
     return batch
@@ -251,6 +279,19 @@ class RunInferenceBaseTest(unittest.TestCase):
           | 'RunKeyed' >> base.RunInference(model_handler))
       pipeline.run()
 
+  def test_drop_example_prediction_result(self):
+    def assert_drop_example(prediction_result):
+      assert prediction_result.example is None
+
+    pipeline = TestPipeline()
+    examples = [1, 3, 5]
+    model_handler = FakeModelHandlerReturnsPredictionResult()
+    _ = (
+        pipeline | 'keyed' >> beam.Create(examples)
+        | 'RunKeyed' >> base.RunInference(model_handler, drop_example=True)
+        | beam.Map(assert_drop_example))
+    pipeline.run()
+
 
 if __name__ == '__main__':
   unittest.main()
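
Outside the test suite, a user-defined handler that wants to honor the flag can follow the same pattern as FakeModelHandlerReturnsPredictionResult: accept drop_example in run_inference and forward it to the shared helper. A hedged sketch under that assumption (MyModel and MyModelHandler are hypothetical names, not part of the SDK):

    from typing import Any, Dict, Iterable, Optional, Sequence

    from apache_beam.ml.inference import base


    class MyModel:
      """Placeholder model exposing the predict() method the handler expects."""
      def predict(self, example: int) -> int:
        return example * 2


    class MyModelHandler(base.ModelHandler[int, base.PredictionResult, MyModel]):
      def load_model(self) -> MyModel:
        return MyModel()

      def run_inference(
          self,
          batch: Sequence[int],
          model: MyModel,
          inference_args: Optional[Dict[str, Any]] = None,
          drop_example: Optional[bool] = False) -> Iterable[base.PredictionResult]:
        predictions = [model.predict(example) for example in batch]
        # Forwarding the flag is what makes RunInference(..., drop_example=True)
        # take effect for this handler.
        return base._convert_to_result(
            batch=batch, predictions=predictions, drop_example=drop_example)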
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index d97205937e2..9d80e2fc1ef 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -25,12 +25,12 @@ from typing import Dict
 from typing import Iterable
 from typing import Optional
 from typing import Sequence
-from typing import Union
 
 import torch
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import _convert_to_result
 from apache_beam.utils.annotations import experimental
 
 __all__ = [
@@ -83,23 +83,6 @@ def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
   return examples
 
 
-def _convert_to_result(
-    batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
-) -> Iterable[PredictionResult]:
-  if isinstance(predictions, dict):
-    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
-    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
-    # length batch_size, to a list of dictionaries:
-    # [{key_type1: value_type1, key_type2: value_type2}]
-    predictions_per_tensor = [
-        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
-    ]
-    return [
-        PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
-    ]
-  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
-
-
 class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
                                              PredictionResult,
                                              torch.nn.Module]):
@@ -152,8 +135,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
       self,
       batch: Sequence[torch.Tensor],
       model: torch.nn.Module,
-      inference_args: Optional[Dict[str, Any]] = None
-  ) -> Iterable[PredictionResult]:
+      inference_args: Optional[Dict[str, Any]] = None,
+      drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
     """
     Runs inferences on a batch of Tensors and returns an Iterable of
     Tensor Predictions.
@@ -170,7 +153,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
       inference_args: Non-batchable arguments required as inputs to the model's
         forward() function. Unlike Tensors in `batch`, these parameters will
         not be dynamically batched
-
+      drop_example: Boolean flag indicating whether to
+        drop the example from PredictionResult
     Returns:
       An Iterable of type PredictionResult.
     """
@@ -182,7 +166,7 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
       batched_tensors = torch.stack(batch)
       batched_tensors = _convert_to_device(batched_tensors, self._device)
       predictions = model(batched_tensors, **inference_args)
-      return _convert_to_result(batch, predictions)
+      return _convert_to_result(batch, predictions, drop_example=drop_example)
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """
@@ -259,7 +243,8 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
       self,
       batch: Sequence[Dict[str, torch.Tensor]],
       model: torch.nn.Module,
-      inference_args: Optional[Dict[str, Any]] = None
+      inference_args: Optional[Dict[str, Any]] = None,
+      drop_example: Optional[bool] = False,
   ) -> Iterable[PredictionResult]:
     """
     Runs inferences on a batch of Keyed Tensors and returns an Iterable of
@@ -277,6 +262,8 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
       inference_args: Non-batchable arguments required as inputs to the model's
         forward() function. Unlike Tensors in `batch`, these parameters will
         not be dynamically batched
+      drop_example: Boolean flag indicating whether to
+        drop the example from PredictionResult
 
     Returns:
       An Iterable of type PredictionResult.
@@ -300,7 +287,7 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
         key_to_batched_tensors[key] = batched_tensors
       predictions = model(**key_to_batched_tensors, **inference_args)
 
-      return _convert_to_result(batch, predictions)
+      return _convert_to_result(batch, predictions, drop_example=drop_example)
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
     """
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 8fd1899ef96..62e22f4fd5a 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -23,7 +23,6 @@ from typing import Dict
 from typing import Iterable
 from typing import Optional
 from typing import Sequence
-from typing import Union
 
 import numpy
 import pandas
@@ -32,6 +31,7 @@ from sklearn.base import BaseEstimator
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import _convert_to_result
 from apache_beam.utils.annotations import experimental
 
 try:
@@ -67,23 +67,6 @@ def _load_model(model_uri, file_type):
   raise AssertionError('Unsupported serialization type.')
 
 
-def _convert_to_result(
-    batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
-) -> Iterable[PredictionResult]:
-  if isinstance(predictions, dict):
-    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
-    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
-    # length batch_size, to a list of dictionaries:
-    # [{key_type1: value_type1, key_type2: value_type2}]
-    predictions_per_tensor = [
-        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
-    ]
-    return [
-        PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
-    ]
-  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
-
-
 class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
                                             PredictionResult,
                                             BaseEstimator]):
@@ -114,8 +97,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
       self,
       batch: Sequence[numpy.ndarray],
       model: BaseEstimator,
-      inference_args: Optional[Dict[str, Any]] = None
-  ) -> Iterable[PredictionResult]:
+      inference_args: Optional[Dict[str, Any]] = None,
+      drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
     """Runs inferences on a batch of numpy arrays.
 
     Args:
@@ -124,6 +107,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
       model: A numpy model or pipeline. Must implement predict(X).
         Where the parameter X is a numpy array.
       inference_args: Any additional arguments for an inference.
+      drop_example: Boolean flag indicating whether to
+        drop the example from PredictionResult
 
     Returns:
       An Iterable of type PredictionResult.
@@ -132,7 +117,7 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
     vectorized_batch = numpy.stack(batch, axis=0)
     predictions = model.predict(vectorized_batch)
 
-    return _convert_to_result(batch, predictions)
+    return _convert_to_result(batch, predictions, drop_example=drop_example)
 
   def get_num_bytes(self, batch: Sequence[pandas.DataFrame]) -> int:
     """
@@ -183,8 +168,8 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
       self,
       batch: Sequence[pandas.DataFrame],
       model: BaseEstimator,
-      inference_args: Optional[Dict[str, Any]] = None
-  ) -> Iterable[PredictionResult]:
+      inference_args: Optional[Dict[str, Any]] = None,
+      drop_example: Optional[bool] = False) -> Iterable[PredictionResult]:
     """
     Runs inferences on a batch of pandas dataframes.
 
@@ -194,7 +179,8 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
       model: A dataframe model or pipeline. Must implement predict(X).
         Where the parameter X is a pandas dataframe.
       inference_args: Any additional arguments for an inference.
-
+      drop_example: Boolean flag indicating whether to
+        drop the example from PredictionResult
     Returns:
       An Iterable of type PredictionResult.
     """
@@ -210,7 +196,7 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
         vectorized_batch.iloc[[i]] for i in range(vectorized_batch.shape[0])
     ]
 
-    return _convert_to_result(splits, predictions)
+    return _convert_to_result(splits, predictions, drop_example=drop_example)
 
   def get_num_bytes(self, batch: Sequence[pandas.DataFrame]) -> int:
     """
diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
index 8ff65658c6b..64e2e6dcaa6 100644
--- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
@@ -33,6 +33,7 @@ import numpy as np
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.inference.base import ModelHandler
 from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import _convert_to_result
 from apache_beam.utils.annotations import experimental
 
 LOGGER = logging.getLogger("TensorRTEngineHandlerNumPy")
@@ -225,7 +226,8 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
       self,
       batch: Sequence[np.ndarray],
       engine: TensorRTEngine,
-      inference_args: Optional[Dict[str, Any]] = None
+      inference_args: Optional[Dict[str, Any]] = None,
+      drop_example: Optional[bool] = False,
   ) -> Iterable[PredictionResult]:
     """
     Runs inferences on a batch of Tensors and returns an Iterable of
@@ -270,11 +272,11 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
                 stream))
       _assign_or_fail(cuda.cuStreamSynchronize(stream))
 
-      return [
-          PredictionResult(
-              x, [prediction[idx] for prediction in cpu_allocations]) for idx,
-          x in enumerate(batch)
-      ]
+      predictions = []
+      for idx in range(len(batch)):
+        predictions.append([prediction[idx] for prediction in cpu_allocations])
+
+      return _convert_to_result(batch, predictions, drop_example=drop_example)
 
   def get_num_bytes(self, batch: Sequence[np.ndarray]) -> int:
     """
@@ -287,4 +289,4 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     """
     Returns a namespace for metrics collected by the RunInference transform.
     """
-    return 'RunInferenceTensorRT'
+    return 'BeamML_TensorRT'
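
The TensorRT change above swaps the inline PredictionResult list comprehension for the shared helper; the new loop that builds `predictions` is just a transposition of cpu_allocations from per-output arrays to per-example lists. A toy sketch of that equivalence, with plain lists standing in for the device allocations:

    # Two model outputs, batch of three examples (stand-ins for cpu_allocations).
    cpu_allocations = [[10, 11, 12], [20, 21, 22]]
    batch = ['a', 'b', 'c']

    predictions = []
    for idx in range(len(batch)):
      predictions.append([prediction[idx] for prediction in cpu_allocations])
    # predictions == [[10, 20], [11, 21], [12, 22]]

    # Equivalent to the previous inline form:
    # [PredictionResult(x, [p[idx] for p in cpu_allocations])
    #  for idx, x in enumerate(batch)]
    # except that _convert_to_result now also honors drop_example.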