Posted to commits@beam.apache.org by ro...@apache.org on 2022/05/31 23:49:49 UTC

[beam] branch master updated: Add typing information to RunInference. (#17762)

This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new ca33943808c Add typing information to RunInference. (#17762)
ca33943808c is described below

commit ca33943808c56ce634c92eb85f865285c71ee048
Author: Robert Bradshaw <ro...@gmail.com>
AuthorDate: Tue May 31 16:49:42 2022 -0700

    Add typing information to RunInference. (#17762)
---
 sdks/python/apache_beam/ml/inference/base.py       | 38 +++++++++++++---------
 sdks/python/apache_beam/ml/inference/base_test.py  |  7 ++--
 sdks/python/apache_beam/ml/inference/pytorch.py    | 17 ++++++----
 .../apache_beam/ml/inference/sklearn_inference.py  | 16 +++++----
 .../transforms/periodicsequence_test.py            |  2 +-
 sdks/python/apache_beam/transforms/ptransform.py   |  7 ++--
 6 files changed, 53 insertions(+), 34 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index ad8ea5868f7..15f94451b8f 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -51,7 +51,9 @@ except ImportError:
 _NANOSECOND_TO_MILLISECOND = 1_000_000
 _NANOSECOND_TO_MICROSECOND = 1_000
 
-T = TypeVar('T')
+ModelT = TypeVar('ModelT')
+ExampleT = TypeVar('ExampleT')
+PredictionT = TypeVar('PredictionT')
 
 
 def _to_milliseconds(time_ns: int) -> int:
@@ -62,14 +64,15 @@ def _to_microseconds(time_ns: int) -> int:
   return int(time_ns / _NANOSECOND_TO_MICROSECOND)
 
 
-class InferenceRunner:
+class InferenceRunner(Generic[ExampleT, PredictionT, ModelT]):
   """Implements running inferences for a framework."""
-  def run_inference(self, batch: List[Any], model: Any) -> Iterable[Any]:
+  def run_inference(self, batch: List[ExampleT],
+                    model: ModelT) -> Iterable[PredictionT]:
     """Runs inferences on a batch of examples and
     returns an Iterable of Predictions."""
     raise NotImplementedError(type(self))
 
-  def get_num_bytes(self, batch: Any) -> int:
+  def get_num_bytes(self, batch: List[ExampleT]) -> int:
     """Returns the number of bytes of data for a batch."""
     return len(pickle.dumps(batch))
 
@@ -78,13 +81,14 @@ class InferenceRunner:
     return 'RunInference'
 
 
-class ModelLoader(Generic[T]):
+class ModelLoader(Generic[ExampleT, PredictionT, ModelT]):
   """Has the ability to load an ML model."""
-  def load_model(self) -> T:
+  def load_model(self) -> ModelT:
     """Loads and initializes a model for processing."""
     raise NotImplementedError(type(self))
 
-  def get_inference_runner(self) -> InferenceRunner:
+  def get_inference_runner(
+      self) -> InferenceRunner[ExampleT, PredictionT, ModelT]:
     """Returns an implementation of InferenceRunner for this model."""
     raise NotImplementedError(type(self))
 
@@ -97,19 +101,22 @@ class ModelLoader(Generic[T]):
     return {}
 
 
-class RunInference(beam.PTransform):
+class RunInference(beam.PTransform[beam.PCollection[ExampleT],
+                                   beam.PCollection[PredictionT]]):
   """An extensible transform for running inferences.
   Args:
       model_loader: An implementation of ModelLoader.
       clock: A clock implementing get_current_time_in_microseconds.
   """
-  def __init__(self, model_loader: ModelLoader, clock=time):
+  def __init__(
+      self, model_loader: ModelLoader[ExampleT, PredictionT, Any], clock=time):
     self._model_loader = model_loader
     self._clock = clock
 
   # TODO(BEAM-14208): Add batch_size back off in the case there
   # are functional reasons large batch sizes cannot be handled.
-  def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
+  def expand(
+      self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]:
     resource_hints = self._model_loader.get_resource_hints()
     return (
         pcoll
@@ -170,14 +177,12 @@ class _MetricsCollector:
     self._inference_request_batch_byte_size.update(examples_byte_size)
 
 
-class _RunInferenceDoFn(beam.DoFn):
+class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
   """A DoFn implementation generic to frameworks."""
-  def __init__(self, model_loader: ModelLoader, clock):
+  def __init__(
+      self, model_loader: ModelLoader[ExampleT, PredictionT, Any], clock):
     self._model_loader = model_loader
-    self._inference_runner = model_loader.get_inference_runner()
     self._shared_model_handle = shared.Shared()
-    self._metrics_collector = _MetricsCollector(
-        self._inference_runner.get_metrics_namespace())
     self._clock = clock
     self._model = None
 
@@ -199,6 +204,9 @@ class _RunInferenceDoFn(beam.DoFn):
     return self._shared_model_handle.acquire(load)
 
   def setup(self):
+    self._inference_runner = self._model_loader.get_inference_runner()
+    self._metrics_collector = _MetricsCollector(
+        self._inference_runner.get_metrics_namespace())
     self._model = self._load_model()
 
   def process(self, batch):
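
For illustration, here is a minimal sketch of how a framework integration fills in the new type parameters (ExampleT, PredictionT, ModelT); LinearModel and the two subclasses below are hypothetical names, not part of this patch:

    from typing import Iterable
    from typing import List

    from apache_beam.ml.inference import base


    class LinearModel:
      """A hypothetical stand-in model."""
      def predict(self, example: float) -> float:
        return 2 * example + 1


    class LinearInferenceRunner(
        base.InferenceRunner[float, float, LinearModel]):
      def run_inference(self, batch: List[float],
                        model: LinearModel) -> Iterable[float]:
        # ExampleT=float in, PredictionT=float out.
        for example in batch:
          yield model.predict(example)


    class LinearModelLoader(base.ModelLoader[float, float, LinearModel]):
      def load_model(self) -> LinearModel:
        return LinearModel()

      def get_inference_runner(self) -> LinearInferenceRunner:
        return LinearInferenceRunner()
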
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index d4bf6518bda..41b166ba78d 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -21,6 +21,7 @@ import pickle
 import unittest
 from typing import Any
 from typing import Iterable
+from typing import List
 
 import apache_beam as beam
 from apache_beam.metrics.metric import MetricsFilter
@@ -35,18 +36,18 @@ class FakeModel:
     return example + 1
 
 
-class FakeInferenceRunner(base.InferenceRunner):
+class FakeInferenceRunner(base.InferenceRunner[int, int, FakeModel]):
   def __init__(self, clock=None):
     self._fake_clock = clock
 
-  def run_inference(self, batch: Any, model: Any) -> Iterable[Any]:
+  def run_inference(self, batch: List[int], model: FakeModel) -> Iterable[int]:
     if self._fake_clock:
       self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
     for example in batch:
       yield model.predict(example)
 
 
-class FakeModelLoader(base.ModelLoader):
+class FakeModelLoader(base.ModelLoader[int, int, FakeModel]):
   def __init__(self, clock=None):
     self._fake_clock = clock
 
diff --git a/sdks/python/apache_beam/ml/inference/pytorch.py b/sdks/python/apache_beam/ml/inference/pytorch.py
index d7e24d61823..d591c6867d8 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch.py
@@ -30,7 +30,9 @@ from apache_beam.ml.inference.base import InferenceRunner
 from apache_beam.ml.inference.base import ModelLoader
 
 
-class PytorchInferenceRunner(InferenceRunner):
+class PytorchInferenceRunner(InferenceRunner[torch.Tensor,
+                                             PredictionResult,
+                                             torch.nn.Module]):
   """
   This class runs Pytorch inferences with the run_inference method. It also has
   other methods to get the bytes of a batch of Tensors as well as the namespace
@@ -66,7 +68,9 @@ class PytorchInferenceRunner(InferenceRunner):
     return 'RunInferencePytorch'
 
 
-class PytorchModelLoader(ModelLoader):
+class PytorchModelLoader(ModelLoader[torch.Tensor,
+                                     PredictionResult,
+                                     torch.nn.Module]):
   """ Implementation of the ModelLoader interface for PyTorch.
 
       NOTE: This API and its implementation are under development and
@@ -96,18 +100,17 @@ class PytorchModelLoader(ModelLoader):
     else:
       self._device = torch.device('cpu')
     self._model_class = model_class
-    self.model_params = model_params
-    self._inference_runner = PytorchInferenceRunner(device=self._device)
+    self._model_params = model_params
 
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
-    model = self._model_class(**self.model_params)
+    model = self._model_class(**self._model_params)
     model.to(self._device)
     file = FileSystems.open(self._state_dict_path, 'rb')
     model.load_state_dict(torch.load(file))
     model.eval()
     return model
 
-  def get_inference_runner(self) -> InferenceRunner:
+  def get_inference_runner(self) -> PytorchInferenceRunner:
     """Returns a Pytorch implementation of InferenceRunner."""
-    return self._inference_runner
+    return PytorchInferenceRunner(device=self._device)
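
Note that get_inference_runner now constructs a fresh PytorchInferenceRunner per call instead of caching one in __init__, likely so the runner (like the metrics collector in the base.py hunks) is created at worker setup time rather than captured at pipeline construction. A hedged usage sketch; the constructor keyword names are inferred from the attributes in the hunk above, and the path and model class are hypothetical:

    import apache_beam as beam
    import torch

    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.pytorch import PytorchModelLoader


    class MyModule(torch.nn.Module):  # hypothetical model class
      def __init__(self, input_dim: int):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, 1)

      def forward(self, x):
        return self.linear(x)


    model_loader = PytorchModelLoader(
        state_dict_path='gs://my-bucket/model.pth',  # hypothetical path
        model_class=MyModule,
        model_params={'input_dim': 4})

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([torch.ones(4)])
          # With the new generics: PCollection[torch.Tensor] in,
          # PCollection[PredictionResult] out.
          | RunInference(model_loader))
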
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 80530146c2c..7f91169ea43 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -23,6 +23,7 @@ from typing import Iterable
 from typing import List
 
 import numpy
+from sklearn.base import BaseEstimator
 
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.inference.api import PredictionResult
@@ -41,9 +42,11 @@ class ModelFileType(enum.Enum):
   JOBLIB = 2
 
 
-class SklearnInferenceRunner(InferenceRunner):
+class SklearnInferenceRunner(InferenceRunner[numpy.ndarray,
+                                             PredictionResult,
+                                             BaseEstimator]):
   def run_inference(self, batch: List[numpy.ndarray],
-                    model: Any) -> Iterable[PredictionResult]:
+                    model: BaseEstimator) -> Iterable[PredictionResult]:
     # vectorize data for better performance
     vectorized_batch = numpy.stack(batch, axis=0)
     predictions = model.predict(vectorized_batch)
@@ -54,7 +57,9 @@ class SklearnInferenceRunner(InferenceRunner):
     return sum(sys.getsizeof(element) for element in batch)
 
 
-class SklearnModelLoader(ModelLoader):
+class SklearnModelLoader(ModelLoader[numpy.ndarray,
+                                     PredictionResult,
+                                     BaseEstimator]):
   """ Implementation of the ModelLoader interface for scikit learn.
 
       NOTE: This API and its implementation are under development and
@@ -66,9 +71,8 @@ class SklearnModelLoader(ModelLoader):
       model_uri: str = ''):
     self._model_file_type = model_file_type
     self._model_uri = model_uri
-    self._inference_runner = SklearnInferenceRunner()
 
-  def load_model(self):
+  def load_model(self) -> BaseEstimator:
     """Loads and initializes a model for processing."""
     file = FileSystems.open(self._model_uri, 'rb')
     if self._model_file_type == ModelFileType.PICKLE:
@@ -84,4 +88,4 @@ class SklearnModelLoader(ModelLoader):
     raise AssertionError('Unsupported serialization type.')
 
   def get_inference_runner(self) -> SklearnInferenceRunner:
-    return self._inference_runner
+    return SklearnInferenceRunner()
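
A matching usage sketch for the scikit-learn path, using the constructor arguments shown in the hunk above; the model URI is hypothetical:

    import apache_beam as beam
    import numpy

    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.sklearn_inference import ModelFileType
    from apache_beam.ml.inference.sklearn_inference import SklearnModelLoader

    loader = SklearnModelLoader(
        model_file_type=ModelFileType.PICKLE,
        model_uri='gs://my-bucket/model.pkl')  # hypothetical URI

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([numpy.array([1.0, 2.0])])
          # PCollection[numpy.ndarray] in, PCollection[PredictionResult] out.
          | RunInference(loader))
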
diff --git a/sdks/python/apache_beam/transforms/periodicsequence_test.py b/sdks/python/apache_beam/transforms/periodicsequence_test.py
index b18bf75d070..e2fe264fbce 100644
--- a/sdks/python/apache_beam/transforms/periodicsequence_test.py
+++ b/sdks/python/apache_beam/transforms/periodicsequence_test.py
@@ -76,7 +76,7 @@ class PeriodicSequenceTest(unittest.TestCase):
       assert_that(actual, equal_to(k))
 
   def test_periodicimpulse_default_start(self):
-    default_parameters = inspect.signature(PeriodicImpulse).parameters
+    default_parameters = inspect.signature(PeriodicImpulse.__init__).parameters
     it = default_parameters["start_timestamp"].default
     duration = 1
     et = it + duration
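
This test tweak likely follows from PTransform becoming Generic (next file): on some Python versions, inspect.signature(cls) for a Generic subclass resolves against typing.Generic's __new__ and reports (*args, **kwds), hiding the __init__ defaults, so the test inspects __init__ directly. A minimal reproduction sketch under that assumption (MyTransform is hypothetical):

    import inspect
    from typing import Generic, TypeVar

    T = TypeVar('T')


    class MyTransform(Generic[T]):
      def __init__(self, start_timestamp: float = 0.0):
        self.start_timestamp = start_timestamp


    # Inspecting __init__ itself reliably exposes the defaults.
    params = inspect.signature(MyTransform.__init__).parameters
    assert params['start_timestamp'].default == 0.0
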
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 86421ed1676..f3e57951e37 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -49,6 +49,7 @@ from typing import TYPE_CHECKING
 from typing import Any
 from typing import Callable
 from typing import Dict
+from typing import Generic
 from typing import List
 from typing import Mapping
 from typing import Optional
@@ -99,6 +100,8 @@ __all__ = [
 _LOGGER = logging.getLogger(__name__)
 
 T = TypeVar('T')
+InputT = TypeVar('InputT')
+OutputT = TypeVar('OutputT')
 PTransformT = TypeVar('PTransformT', bound='PTransform')
 ConstructorFn = Callable[
     ['beam_runner_api_pb2.PTransform', Optional[Any], 'PipelineContext'], Any]
@@ -328,7 +331,7 @@ class _ZipPValues(object):
         self.visit(p, sibling, pairs, context)
 
 
-class PTransform(WithTypeHints, HasDisplayData):
+class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]):
   """A transform object used to modify one or more PCollections.
 
   Subclasses must define an expand() method that will be used when the transform
@@ -522,7 +525,7 @@ class PTransform(WithTypeHints, HasDisplayData):
     transform.label = new_label
     return transform
 
-  def expand(self, input_or_inputs):
+  def expand(self, input_or_inputs: InputT) -> OutputT:
     raise NotImplementedError
 
   def __str__(self):
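
With PTransform now generic in its input and output types, subclasses can declare both in the class header, as RunInference does above. A minimal sketch (IntsToStrs is hypothetical):

    import apache_beam as beam


    class IntsToStrs(beam.PTransform[beam.PCollection[int],
                                     beam.PCollection[str]]):
      def expand(
          self, pcoll: beam.PCollection[int]) -> beam.PCollection[str]:
        # InputT = PCollection[int], OutputT = PCollection[str].
        return pcoll | beam.Map(str)
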