You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ro...@apache.org on 2022/05/31 23:49:49 UTC
[beam] branch master updated: Add typing information to RunInference. (#17762)
This is an automated email from the ASF dual-hosted git repository.
robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new ca33943808c Add typing information to RunInference. (#17762)
ca33943808c is described below
commit ca33943808c56ce634c92eb85f865285c71ee048
Author: Robert Bradshaw <ro...@gmail.com>
AuthorDate: Tue May 31 16:49:42 2022 -0700
Add typing information to RunInference. (#17762)
---
sdks/python/apache_beam/ml/inference/base.py | 38 +++++++++++++---------
sdks/python/apache_beam/ml/inference/base_test.py | 7 ++--
sdks/python/apache_beam/ml/inference/pytorch.py | 17 ++++++----
.../apache_beam/ml/inference/sklearn_inference.py | 16 +++++----
.../transforms/periodicsequence_test.py | 2 +-
sdks/python/apache_beam/transforms/ptransform.py | 7 ++--
6 files changed, 53 insertions(+), 34 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index ad8ea5868f7..15f94451b8f 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -51,7 +51,9 @@ except ImportError:
_NANOSECOND_TO_MILLISECOND = 1_000_000
_NANOSECOND_TO_MICROSECOND = 1_000
-T = TypeVar('T')
+ModelT = TypeVar('ModelT')
+ExampleT = TypeVar('ExampleT')
+PredictionT = TypeVar('PredictionT')
def _to_milliseconds(time_ns: int) -> int:
@@ -62,14 +64,15 @@ def _to_microseconds(time_ns: int) -> int:
return int(time_ns / _NANOSECOND_TO_MICROSECOND)
-class InferenceRunner:
+class InferenceRunner(Generic[ExampleT, PredictionT, ModelT]):
"""Implements running inferences for a framework."""
- def run_inference(self, batch: List[Any], model: Any) -> Iterable[Any]:
+ def run_inference(self, batch: List[ExampleT],
+ model: ModelT) -> Iterable[PredictionT]:
"""Runs inferences on a batch of examples and
returns an Iterable of Predictions."""
raise NotImplementedError(type(self))
- def get_num_bytes(self, batch: Any) -> int:
+ def get_num_bytes(self, batch: List[ExampleT]) -> int:
"""Returns the number of bytes of data for a batch."""
return len(pickle.dumps(batch))
@@ -78,13 +81,14 @@ class InferenceRunner:
return 'RunInference'
-class ModelLoader(Generic[T]):
+class ModelLoader(Generic[ExampleT, PredictionT, ModelT]):
"""Has the ability to load an ML model."""
- def load_model(self) -> T:
+ def load_model(self) -> ModelT:
"""Loads and initializes a model for processing."""
raise NotImplementedError(type(self))
- def get_inference_runner(self) -> InferenceRunner:
+ def get_inference_runner(
+ self) -> InferenceRunner[ExampleT, PredictionT, ModelT]:
"""Returns an implementation of InferenceRunner for this model."""
raise NotImplementedError(type(self))
@@ -97,19 +101,22 @@ class ModelLoader(Generic[T]):
return {}
-class RunInference(beam.PTransform):
+class RunInference(beam.PTransform[beam.PCollection[ExampleT],
+ beam.PCollection[PredictionT]]):
"""An extensible transform for running inferences.
Args:
model_loader: An implementation of ModelLoader.
clock: A clock implementing get_current_time_in_microseconds.
"""
- def __init__(self, model_loader: ModelLoader, clock=time):
+ def __init__(
+ self, model_loader: ModelLoader[ExampleT, PredictionT, Any], clock=time):
self._model_loader = model_loader
self._clock = clock
# TODO(BEAM-14208): Add batch_size back off in the case there
# are functional reasons large batch sizes cannot be handled.
- def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
+ def expand(
+ self, pcoll: beam.PCollection[ExampleT]) -> beam.PCollection[PredictionT]:
resource_hints = self._model_loader.get_resource_hints()
return (
pcoll
@@ -170,14 +177,12 @@ class _MetricsCollector:
self._inference_request_batch_byte_size.update(examples_byte_size)
-class _RunInferenceDoFn(beam.DoFn):
+class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
"""A DoFn implementation generic to frameworks."""
- def __init__(self, model_loader: ModelLoader, clock):
+ def __init__(
+ self, model_loader: ModelLoader[ExampleT, PredictionT, Any], clock):
self._model_loader = model_loader
- self._inference_runner = model_loader.get_inference_runner()
self._shared_model_handle = shared.Shared()
- self._metrics_collector = _MetricsCollector(
- self._inference_runner.get_metrics_namespace())
self._clock = clock
self._model = None
@@ -199,6 +204,9 @@ class _RunInferenceDoFn(beam.DoFn):
return self._shared_model_handle.acquire(load)
def setup(self):
+ self._inference_runner = self._model_loader.get_inference_runner()
+ self._metrics_collector = _MetricsCollector(
+ self._inference_runner.get_metrics_namespace())
self._model = self._load_model()
def process(self, batch):
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index d4bf6518bda..41b166ba78d 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -21,6 +21,7 @@ import pickle
import unittest
from typing import Any
from typing import Iterable
+from typing import List
import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
@@ -35,18 +36,18 @@ class FakeModel:
return example + 1
-class FakeInferenceRunner(base.InferenceRunner):
+class FakeInferenceRunner(base.InferenceRunner[int, int, FakeModel]):
def __init__(self, clock=None):
self._fake_clock = clock
- def run_inference(self, batch: Any, model: Any) -> Iterable[Any]:
+ def run_inference(self, batch: List[int], model: FakeModel) -> Iterable[int]:
if self._fake_clock:
self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds
for example in batch:
yield model.predict(example)
-class FakeModelLoader(base.ModelLoader):
+class FakeModelLoader(base.ModelLoader[int, int, FakeModel]):
def __init__(self, clock=None):
self._fake_clock = clock
diff --git a/sdks/python/apache_beam/ml/inference/pytorch.py b/sdks/python/apache_beam/ml/inference/pytorch.py
index d7e24d61823..d591c6867d8 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch.py
@@ -30,7 +30,9 @@ from apache_beam.ml.inference.base import InferenceRunner
from apache_beam.ml.inference.base import ModelLoader
-class PytorchInferenceRunner(InferenceRunner):
+class PytorchInferenceRunner(InferenceRunner[torch.Tensor,
+ PredictionResult,
+ torch.nn.Module]):
"""
This class runs Pytorch inferences with the run_inference method. It also has
other methods to get the bytes of a batch of Tensors as well as the namespace
@@ -66,7 +68,9 @@ class PytorchInferenceRunner(InferenceRunner):
return 'RunInferencePytorch'
-class PytorchModelLoader(ModelLoader):
+class PytorchModelLoader(ModelLoader[torch.Tensor,
+ PredictionResult,
+ torch.nn.Module]):
""" Implementation of the ModelLoader interface for PyTorch.
NOTE: This API and its implementation are under development and
@@ -96,18 +100,17 @@ class PytorchModelLoader(ModelLoader):
else:
self._device = torch.device('cpu')
self._model_class = model_class
- self.model_params = model_params
- self._inference_runner = PytorchInferenceRunner(device=self._device)
+ self._model_params = model_params
def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
- model = self._model_class(**self.model_params)
+ model = self._model_class(**self._model_params)
model.to(self._device)
file = FileSystems.open(self._state_dict_path, 'rb')
model.load_state_dict(torch.load(file))
model.eval()
return model
- def get_inference_runner(self) -> InferenceRunner:
+ def get_inference_runner(self) -> PytorchInferenceRunner:
"""Returns a Pytorch implementation of InferenceRunner."""
- return self._inference_runner
+ return PytorchInferenceRunner(device=self._device)
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 80530146c2c..7f91169ea43 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -23,6 +23,7 @@ from typing import Iterable
from typing import List
import numpy
+from sklearn.base import BaseEstimator
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.api import PredictionResult
@@ -41,9 +42,11 @@ class ModelFileType(enum.Enum):
JOBLIB = 2
-class SklearnInferenceRunner(InferenceRunner):
+class SklearnInferenceRunner(InferenceRunner[numpy.ndarray,
+ PredictionResult,
+ BaseEstimator]):
def run_inference(self, batch: List[numpy.ndarray],
- model: Any) -> Iterable[PredictionResult]:
+ model: BaseEstimator) -> Iterable[PredictionResult]:
# vectorize data for better performance
vectorized_batch = numpy.stack(batch, axis=0)
predictions = model.predict(vectorized_batch)
@@ -54,7 +57,9 @@ class SklearnInferenceRunner(InferenceRunner):
return sum(sys.getsizeof(element) for element in batch)
-class SklearnModelLoader(ModelLoader):
+class SklearnModelLoader(ModelLoader[numpy.ndarray,
+ PredictionResult,
+ BaseEstimator]):
""" Implementation of the ModelLoader interface for scikit learn.
NOTE: This API and its implementation are under development and
@@ -66,9 +71,8 @@ class SklearnModelLoader(ModelLoader):
model_uri: str = ''):
self._model_file_type = model_file_type
self._model_uri = model_uri
- self._inference_runner = SklearnInferenceRunner()
- def load_model(self):
+ def load_model(self) -> BaseEstimator:
"""Loads and initializes a model for processing."""
file = FileSystems.open(self._model_uri, 'rb')
if self._model_file_type == ModelFileType.PICKLE:
@@ -84,4 +88,4 @@ class SklearnModelLoader(ModelLoader):
raise AssertionError('Unsupported serialization type.')
def get_inference_runner(self) -> SklearnInferenceRunner:
- return self._inference_runner
+ return SklearnInferenceRunner()
diff --git a/sdks/python/apache_beam/transforms/periodicsequence_test.py b/sdks/python/apache_beam/transforms/periodicsequence_test.py
index b18bf75d070..e2fe264fbce 100644
--- a/sdks/python/apache_beam/transforms/periodicsequence_test.py
+++ b/sdks/python/apache_beam/transforms/periodicsequence_test.py
@@ -76,7 +76,7 @@ class PeriodicSequenceTest(unittest.TestCase):
assert_that(actual, equal_to(k))
def test_periodicimpulse_default_start(self):
- default_parameters = inspect.signature(PeriodicImpulse).parameters
+ default_parameters = inspect.signature(PeriodicImpulse.__init__).parameters
it = default_parameters["start_timestamp"].default
duration = 1
et = it + duration
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 86421ed1676..f3e57951e37 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -49,6 +49,7 @@ from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Generic
from typing import List
from typing import Mapping
from typing import Optional
@@ -99,6 +100,8 @@ __all__ = [
_LOGGER = logging.getLogger(__name__)
T = TypeVar('T')
+InputT = TypeVar('InputT')
+OutputT = TypeVar('OutputT')
PTransformT = TypeVar('PTransformT', bound='PTransform')
ConstructorFn = Callable[
['beam_runner_api_pb2.PTransform', Optional[Any], 'PipelineContext'], Any]
@@ -328,7 +331,7 @@ class _ZipPValues(object):
self.visit(p, sibling, pairs, context)
-class PTransform(WithTypeHints, HasDisplayData):
+class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]):
"""A transform object used to modify one or more PCollections.
Subclasses must define an expand() method that will be used when the transform
@@ -522,7 +525,7 @@ class PTransform(WithTypeHints, HasDisplayData):
transform.label = new_label
return transform
- def expand(self, input_or_inputs):
+ def expand(self, input_or_inputs: InputT) -> OutputT:
raise NotImplementedError
def __str__(self):