Posted to commits@beam.apache.org by ro...@apache.org on 2022/06/10 21:31:33 UTC
[beam] branch master updated: Make keying of examples explicit. (#21777)
This is an automated email from the ASF dual-hosted git repository.
robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new b8e2e85ab1f Make keying of examples explicit. (#21777)
b8e2e85ab1f is described below
commit b8e2e85ab1fb37a2f89ed20d88730e591ea3bf7e
Author: Robert Bradshaw <ro...@gmail.com>
AuthorDate: Fri Jun 10 14:31:28 2022 -0700
Make keying of examples explicit. (#21777)
This decouples the keying logic from the DoFn and helps with type inference.
There is both a KeyedModelHandler that expects keys and a MaybeKeyedModelHandler that preserves the old behavior.
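As a usage sketch (not part of the commit itself): FakeAddOneHandler below is a hypothetical stand-in for any unkeyed ModelHandler, and shows how the keyed wrapper is applied.

    import apache_beam as beam
    from apache_beam.ml.inference import base

    class FakeAddOneHandler(base.ModelHandler[int, int, object]):
      """Hypothetical unkeyed handler whose 'inference' just adds one."""
      def load_model(self):
        return object()  # no real model is needed for this sketch

      def run_inference(self, batch, model, **kwargs):
        return [example + 1 for example in batch]

    with beam.Pipeline() as p:
      keyed = p | beam.Create([('a', 1), ('b', 2)])
      # KeyedModelHandler strips the keys before inference and re-attaches
      # them to the predictions, yielding ('a', 2) and ('b', 3).
      _ = keyed | base.RunInference(base.KeyedModelHandler(FakeAddOneHandler()))

Wrapping the same handler in MaybeKeyedModelHandler instead accepts either keyed or unkeyed elements, subject to the tuple caveat noted in its docstring below.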
---
sdks/python/apache_beam/ml/inference/base.py | 134 ++++++++++++++++++---
sdks/python/apache_beam/ml/inference/base_test.py | 24 +++-
.../apache_beam/ml/inference/pytorch_inference.py | 6 +-
.../apache_beam/ml/inference/sklearn_inference.py | 12 +-
.../ml/inference/sklearn_inference_test.py | 2 +-
5 files changed, 145 insertions(+), 33 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 534512a44f3..ae07ac0531e 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -35,9 +35,11 @@ import time
from typing import Any
from typing import Generic
from typing import Iterable
-from typing import List
from typing import Mapping
+from typing import Sequence
+from typing import Tuple
from typing import TypeVar
+from typing import Union
import apache_beam as beam
from apache_beam.utils import shared
@@ -54,6 +56,7 @@ _NANOSECOND_TO_MICROSECOND = 1_000
ModelT = TypeVar('ModelT')
ExampleT = TypeVar('ExampleT')
PredictionT = TypeVar('PredictionT')
+KeyT = TypeVar('KeyT')
def _to_milliseconds(time_ns: int) -> int:
@@ -70,13 +73,13 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
"""Loads and initializes a model for processing."""
raise NotImplementedError(type(self))
- def run_inference(self, batch: List[ExampleT], model: ModelT,
+ def run_inference(self, batch: Sequence[ExampleT], model: ModelT,
**kwargs) -> Iterable[PredictionT]:
"""Runs inferences on a batch of examples and
returns an Iterable of Predictions."""
raise NotImplementedError(type(self))
- def get_num_bytes(self, batch: List[ExampleT]) -> int:
+ def get_num_bytes(self, batch: Sequence[ExampleT]) -> int:
"""Returns the number of bytes of data for a batch."""
return len(pickle.dumps(batch))
@@ -93,6 +96,111 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
return {}
+class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
+                        ModelHandler[Tuple[KeyT, ExampleT],
+                                     Tuple[KeyT, PredictionT],
+                                     ModelT]):
+  """A ModelHandler that takes keyed examples and returns keyed predictions.
+
+  For example, if the original model was used with RunInference to take a
+  PCollection[E] to a PCollection[P], this would take a
+  PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], allowing one to
+  associate the outputs with the inputs based on the key.
+  """
+  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+    self._unkeyed = unkeyed
+
+  def load_model(self) -> ModelT:
+    return self._unkeyed.load_model()
+
+  def run_inference(
+      self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT,
+      **kwargs) -> Iterable[Tuple[KeyT, PredictionT]]:
+    keys, unkeyed_batch = zip(*batch)
+    return zip(
+        keys, self._unkeyed.run_inference(unkeyed_batch, model, **kwargs))
+
+  def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
+    keys, unkeyed_batch = zip(*batch)
+    return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+
+  def get_metrics_namespace(self) -> str:
+    return self._unkeyed.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._unkeyed.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._unkeyed.batch_elements_kwargs()
+
+
+class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
+                             ModelHandler[Union[ExampleT, Tuple[KeyT,
+                                                                ExampleT]],
+                                          Union[PredictionT,
+                                                Tuple[KeyT, PredictionT]],
+                                          ModelT]):
+  """A ModelHandler that takes possibly keyed examples and returns possibly
+  keyed predictions.
+
+  For example, if the original model was used with RunInference to take a
+  PCollection[E] to a PCollection[P], this would take either PCollection[E] to a
+  PCollection[P] or PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]],
+  depending on whether the elements happen to be tuples, allowing one to
+  associate the outputs with the inputs based on the key.
+
+  Note that this cannot be used if E happens to be a tuple type. In addition,
+  either all examples should be keyed, or none of them.
+  """
+  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+    self._unkeyed = unkeyed
+
+  def load_model(self) -> ModelT:
+    return self._unkeyed.load_model()
+
+  def run_inference(
+      self,
+      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
+      model: ModelT,
+      **kwargs
+  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
+    # Really the input should be
+    # Union[Sequence[ExampleT], Sequence[Tuple[KeyT, ExampleT]]]
+    # but there's not a good way to express (or check) that.
+    if isinstance(batch[0], tuple):
+      is_keyed = True
+      keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
+    else:
+      is_keyed = False
+      unkeyed_batch = batch  # type: ignore[assignment]
+    unkeyed_results = self._unkeyed.run_inference(
+        unkeyed_batch, model, **kwargs)
+    if is_keyed:
+      return zip(keys, unkeyed_results)
+    else:
+      return unkeyed_results
+
+  def get_num_bytes(
+      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
+    # MyPy can't follow the branching logic.
+    if isinstance(batch[0], tuple):
+      keys, unkeyed_batch = zip(*batch)  # type: ignore[arg-type]
+      return len(
+          pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+    else:
+      return self._unkeyed.get_num_bytes(batch)  # type: ignore[arg-type]
+
+  def get_metrics_namespace(self) -> str:
+    return self._unkeyed.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._unkeyed.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._unkeyed.batch_elements_kwargs()
+
+
class RunInference(beam.PTransform[beam.PCollection[ExampleT],
beam.PCollection[PredictionT]]):
"""An extensible transform for running inferences.
@@ -205,32 +313,18 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
     self._model = self._load_model()
   def process(self, batch, **kwargs):
-    # Process supports both keyed data, and example only data.
-    # First keys and samples are separated (if there are keys)
-    has_keys = isinstance(batch[0], tuple)
-    if has_keys:
-      examples = [example for _, example in batch]
-      keys = [key for key, _ in batch]
-    else:
-      examples = batch
-      keys = None
-
     start_time = _to_microseconds(self._clock.time_ns())
     result_generator = self._model_handler.run_inference(
-        examples, self._model, **kwargs)
+        batch, self._model, **kwargs)
     predictions = list(result_generator)
     end_time = _to_microseconds(self._clock.time_ns())
     inference_latency = end_time - start_time
-    num_bytes = self._model_handler.get_num_bytes(examples)
+    num_bytes = self._model_handler.get_num_bytes(batch)
     num_elements = len(batch)
     self._metrics_collector.update(num_elements, num_bytes, inference_latency)
-    # Keys are recombined with predictions in the RunInference PTransform.
-    if has_keys:
-      yield from zip(keys, predictions)
-    else:
-      yield from predictions
+    return predictions
   def finish_bundle(self):
     # TODO(BEAM-13970): Figure out why there is a cache.
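The key handling in KeyedModelHandler.run_inference above is a plain zip round trip. A standalone sketch of the same logic (values made up):

    batch = [('a', 1), ('b', 2), ('c', 3)]
    keys, unkeyed_batch = zip(*batch)             # ('a', 'b', 'c'), (1, 2, 3)
    predictions = [x + 1 for x in unkeyed_batch]  # stand-in for the model call
    print(list(zip(keys, predictions)))           # [('a', 2), ('b', 3), ('c', 4)]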
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 3ea2a9db12b..52f8f883203 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -20,7 +20,7 @@
import pickle
import unittest
from typing import Iterable
-from typing import List
+from typing import Sequence
import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
@@ -44,7 +44,7 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
self._fake_clock.current_time_ns += 500_000_000 # 500ms
return FakeModel()
- def run_inference(self, batch: List[int], model: FakeModel,
+ def run_inference(self, batch: Sequence[int], model: FakeModel,
**kwargs) -> Iterable[int]:
if self._fake_clock:
self._fake_clock.current_time_ns += 3_000_000 # 3 milliseconds
@@ -98,9 +98,27 @@ class RunInferenceBaseTest(unittest.TestCase):
       keyed_examples = [(i, example) for i, example in enumerate(examples)]
       expected = [(i, example + 1) for i, example in enumerate(examples)]
       pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
-      actual = pcoll | base.RunInference(FakeModelHandler())
+      actual = pcoll | base.RunInference(
+          base.KeyedModelHandler(FakeModelHandler()))
       assert_that(actual, equal_to(expected), label='assert:inferences')
+  def test_run_inference_impl_with_maybe_keyed_examples(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [example + 1 for example in examples]
+      keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
+      model_handler = base.MaybeKeyedModelHandler(FakeModelHandler())
+
+      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
+      actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
+      assert_that(actual, equal_to(expected), label='CheckUnkeyed')
+
+      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
+      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
+          model_handler)
+      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+
def test_run_inference_impl_kwargs(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
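The new test exercises both input shapes accepted by MaybeKeyedModelHandler. The docstring caveat that it cannot be used when the example type is itself a tuple follows from the isinstance(batch[0], tuple) check in run_inference; a small sketch of the failure mode (values made up):

    batch = [(0.5, 0.25)]  # one unkeyed example that happens to be a tuple
    is_keyed = isinstance(batch[0], tuple)  # the check run_inference applies
    print(is_keyed)  # True -- 0.5 would wrongly be split off as a key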
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 3a4fb2926f8..1a1afaaf1c8 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -22,7 +22,7 @@ from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
-from typing import List
+from typing import Sequence
from typing import Union
import torch
@@ -87,7 +87,7 @@ class PytorchModelHandler(ModelHandler[torch.Tensor,
def run_inference(
self,
- batch: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
+ batch: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]],
model: torch.nn.Module,
**kwargs) -> Iterable[PredictionResult]:
"""
@@ -119,7 +119,7 @@ class PytorchModelHandler(ModelHandler[torch.Tensor,
predictions = model(batched_tensors, **prediction_params)
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
- def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
+ def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
"""Returns the number of bytes of data for a batch of Tensors."""
# If elements in `batch` are provided as dictionaries from key to Tensors
if isinstance(batch[0], dict):
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 3c8eddfd7d3..5ca6a18b5d1 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -20,7 +20,7 @@ import pickle
import sys
from typing import Any
from typing import Iterable
-from typing import List
+from typing import Sequence
from typing import Union
import numpy
@@ -75,7 +75,7 @@ class SklearnModelHandler(ModelHandler[Union[numpy.ndarray, pandas.DataFrame],
def run_inference(
self,
- batch: List[Union[numpy.ndarray, pandas.DataFrame]],
+ batch: Sequence[Union[numpy.ndarray, pandas.DataFrame]],
model: BaseEstimator,
**kwargs) -> Iterable[PredictionResult]:
# TODO(github.com/apache/beam/issues/21769): Use supplied input type hint.
@@ -86,7 +86,7 @@ class SklearnModelHandler(ModelHandler[Union[numpy.ndarray, pandas.DataFrame],
raise ValueError('Unsupported data type.')
@staticmethod
- def _predict_np_array(batch: List[numpy.ndarray],
+ def _predict_np_array(batch: Sequence[numpy.ndarray],
model: Any) -> Iterable[PredictionResult]:
# vectorize data for better performance
vectorized_batch = numpy.stack(batch, axis=0)
@@ -94,7 +94,7 @@ class SklearnModelHandler(ModelHandler[Union[numpy.ndarray, pandas.DataFrame],
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
@staticmethod
- def _predict_pandas_dataframe(batch: List[pandas.DataFrame],
+ def _predict_pandas_dataframe(batch: Sequence[pandas.DataFrame],
model: Any) -> Iterable[PredictionResult]:
# sklearn_inference currently only supports single-row dataframes.
for dataframe in batch:
@@ -113,11 +113,11 @@ class SklearnModelHandler(ModelHandler[Union[numpy.ndarray, pandas.DataFrame],
]
def get_num_bytes(
- self, batch: List[Union[numpy.ndarray, pandas.DataFrame]]) -> int:
+ self, batch: Sequence[Union[numpy.ndarray, pandas.DataFrame]]) -> int:
"""Returns the number of bytes of data for a batch."""
if isinstance(batch[0], numpy.ndarray):
return sum(sys.getsizeof(element) for element in batch)
elif isinstance(batch[0], pandas.DataFrame):
- data_frames: List[pandas.DataFrame] = batch
+ data_frames: Sequence[pandas.DataFrame] = batch
return sum(df.memory_usage(deep=True).sum() for df in data_frames)
raise ValueError('Unsupported data type.')
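One likely motivation for the List-to-Sequence hint changes here and in pytorch_inference.py: KeyedModelHandler hands the unkeyed handler the output of zip(*batch), which is a tuple, and a tuple satisfies Sequence but not List under mypy. A sketch (function names are illustrative only):

    from typing import List, Sequence

    def takes_list(batch: List[int]) -> int:
      return len(batch)

    def takes_sequence(batch: Sequence[int]) -> int:
      return len(batch)

    keys, unkeyed = zip(*[('a', 1), ('b', 2)])  # unkeyed is the tuple (1, 2)
    takes_sequence(unkeyed)  # type-checks: a tuple is a Sequence
    takes_list(unkeyed)      # runs, but mypy flags it: Tuple is not List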
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
index 91eb86e2de4..0d7294eb406 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py
@@ -265,7 +265,7 @@ class SkLearnRunInferenceTest(unittest.TestCase):
pcoll = pipeline | 'start' >> beam.Create(keyed_rows)
actual = pcoll | api.RunInference(
- SklearnModelHandler(model_uri=temp_file_name))
+ base.KeyedModelHandler(SklearnModelHandler(model_uri=temp_file_name)))
expected = [
('0', api.PredictionResult(splits[0], 5)),
('1', api.PredictionResult(splits[1], 8)),